Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs @ 12597

Last change on this file since 12597 was 12590, checked in by gkronber, 9 years ago

#2261: preparations for trunk integration (adapt to current trunk version, add license headers, add comments, improve code quality)

File size: 3.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Regression model backed by a binary decision tree that is stored as a flat array of
  /// <see cref="TreeNode"/> values. Prediction starts at index 0 (the root); inner nodes
  /// reference their children by array index.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    /// <summary>
    /// One node of the flat tree representation.
    /// An inner node splits on <see cref="varName"/> at threshold <see cref="val"/>;
    /// a terminal node has <see cref="varName"/> equal to <see cref="NO_VARIABLE"/> and
    /// carries the predicted value in <see cref="val"/>.
    /// </summary>
    public struct TreeNode : System.IEquatable<TreeNode> {
      /// <summary>Value of <see cref="varName"/> that marks a terminal (leaf) node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;

      /// <summary>Name of the split variable, or <see cref="NO_VARIABLE"/> for a terminal node.</summary>
      public string varName;
      /// <summary>Split threshold for an inner node; predicted value for a terminal node.</summary>
      public double val;
      /// <summary>Index of the child visited when the row's variable value is &lt;= <see cref="val"/>.</summary>
      public int leftIdx;
      /// <summary>Index of the child visited otherwise (value &gt; <see cref="val"/>, or NaN).</summary>
      public int rightIdx;

      /// <summary>Field-wise equality (ordinal, null-safe comparison of <see cref="varName"/>).</summary>
      public bool Equals(TreeNode other) {
        return string.Equals(varName, other.varName)
          && val.Equals(other.val)
          && leftIdx == other.leftIdx
          && rightIdx == other.rightIdx;
      }

      // Overridden together with GetHashCode so the two stay consistent and equality no longer
      // goes through the boxing, reflection-based ValueType.Equals.
      public override bool Equals(object obj) {
        return obj is TreeNode && Equals((TreeNode)obj);
      }

      // varName is intentionally not hashed; equal nodes still produce equal hashes.
      public override int GetHashCode() {
        return leftIdx ^ rightIdx ^ val.GetHashCode();
      }
    }

    /// <summary>
    /// The flat node array; index 0 is the root. Clones share this array, so it must not be
    /// modified after construction. <c>readonly</c> only protects the field reference, not the elements.
    /// </summary>
    [Storable]
    public readonly TreeNode[] tree;

    [StorableConstructor]
    private RegressionTreeModel(bool serializing) : base(serializing) { }

    /// <summary>Cloning constructor.</summary>
    /// <param name="original">Model to clone.</param>
    /// <param name="cloner">Cloner tracking already-cloned objects.</param>
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      // Shallow copy: the node array is shared with the original (see the remarks on `tree`).
      this.tree = original.tree;
    }

    /// <summary>Creates a model from a flat node array whose root is at index 0.</summary>
    /// <param name="tree">The node array. It is stored by reference, not copied.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="tree"/> is null.</exception>
    public RegressionTreeModel(TreeNode[] tree)
      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
      if (tree == null) throw new System.ArgumentNullException("tree");
      this.tree = tree;
    }

    // Walks from nodeIdx down to a terminal node and returns its value.
    // The walk is iterative rather than recursive, so deep trees cannot overflow the call stack.
    // An acyclic path visits at most t.Length nodes, so exceeding that bound means the array is
    // malformed (e.g. cyclic). That case throws instead of looping forever.
    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) {
      for (int visited = 0; visited < t.Length; visited++) {
        var node = t[nodeIdx];
        if (node.varName == TreeNode.NO_VARIABLE)
          return node.val;
        // A NaN value fails the <= test and therefore goes to the right child.
        nodeIdx = ds.GetDoubleValue(node.varName, row) <= node.val ? node.leftIdx : node.rightIdx;
      }
      throw new System.InvalidOperationException("Malformed regression tree: no terminal node was reached.");
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>Returns one prediction per entry of <paramref name="rows"/>; evaluation is deferred until enumeration.</summary>
    /// <param name="ds">Dataset that supplies the split-variable values.</param>
    /// <param name="rows">Row indices to predict.</param>
    public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    }

    /// <summary>
    /// Wraps this model in a <see cref="RegressionSolution"/> together with a new
    /// <see cref="RegressionProblemData"/> constructed from <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
Note: See TracBrowser for help on using the repository browser.