Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs @ 12632

Last change on this file since 12632 was 12632, checked in by gkronber, 9 years ago

#2261 implemented node expansion using a priority queue (and changed parameter MaxDepth to MaxSize). Moved unit tests to a separate project.

File size: 3.2 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 * and the BEACON Center for the Study of Evolution in Action.
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Regression model that predicts with a single binary decision tree.
  /// The tree is stored as a flat array of <see cref="TreeNode"/> values; the root is element 0.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    // trees are represented as a flat array; child links are indices into that array
    public struct TreeNode {
      // value stored in VarName to mark a terminal node
      public static readonly string NO_VARIABLE = string.Empty;
      public string VarName { get; set; } // name of the variable for splitting or NO_VARIABLE if terminal node
      public double Val { get; set; } // threshold for inner nodes, predicted value for terminal nodes
      public int LeftIdx { get; set; } // index of the node to follow when the variable value is <= Val
      public int RightIdx { get; set; } // index of the node to follow otherwise

      // VarName is not part of the hash; nodes that agree in all fields still produce equal hash codes
      public override int GetHashCode() {
        return LeftIdx ^ RightIdx ^ Val.GetHashCode();
      }
    }

    // flat node array, element 0 is the root; the array reference is shared between clones
    [Storable]
    public readonly TreeNode[] tree;

    [StorableConstructor]
    private RegressionTreeModel(bool serializing) : base(serializing) { }

    // cloning ctor
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      this.tree = original.tree; // shallow clone, tree must be readonly
    }

    /// <summary>
    /// Creates a model for the given flat tree. The array is stored as is (no copy is made).
    /// </summary>
    /// <param name="tree">Nodes of the tree; element 0 is used as the root.</param>
    public RegressionTreeModel(TreeNode[] tree)
      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
      this.tree = tree;
    }

    /// <summary>
    /// Walks the tree from <paramref name="nodeIdx"/> down to a terminal node and returns that node's value.
    /// </summary>
    /// <param name="t">Flat node array.</param>
    /// <param name="nodeIdx">Index of the node at which the walk starts.</param>
    /// <param name="ds">Dataset providing the variable values.</param>
    /// <param name="row">Row of the dataset to predict.</param>
    /// <returns>The <c>Val</c> of the terminal node that is reached.</returns>
    /// <exception cref="InvalidOperationException">
    /// No terminal node was reached within <c>t.Length</c> steps (the child indices form a cycle).
    /// </exception>
    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) {
      // Iterative instead of recursive: the depth of the walk is only bounded by the tree size,
      // and a stack overflow caused by deep recursion would terminate the whole process.
      // A path in a well-formed tree visits each node at most once, so t.Length bounds the walk.
      var remainingSteps = t.Length;
      while (true) {
        var node = t[nodeIdx];
        if (node.VarName == TreeNode.NO_VARIABLE)
          return node.Val;
        if (--remainingSteps <= 0)
          throw new InvalidOperationException("Tree traversal did not reach a terminal node; the child indices of the tree are cyclic.");
        nodeIdx = ds.GetDoubleValue(node.VarName, row) <= node.Val
          ? node.LeftIdx
          : node.RightIdx;
      }
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>
    /// Returns one prediction per requested row, in the order of <paramref name="rows"/>.
    /// The result is evaluated lazily while it is enumerated.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    }

    /// <summary>
    /// Wraps this model in a regression solution built on a new <c>RegressionProblemData</c>
    /// created from <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
Note: See TracBrowser for help on using the repository browser.