Index: /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/GradientBoostedTreesModelView.Designer.cs
===================================================================
--- /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/GradientBoostedTreesModelView.Designer.cs (revision 12372)
+++ /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/GradientBoostedTreesModelView.Designer.cs (revision 12372)
@@ -0,0 +1,114 @@
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using HeuristicLab.Optimization.Views;
+
+namespace HeuristicLab.Algorithms.DataAnalysis.Views {
+ partial class GradientBoostedTreesModelView { // designer-generated part: a list box inside a "Trees" group box next to a view host
+ /// <summary>
+ /// Required designer variable.
+ /// </summary>
+ private System.ComponentModel.IContainer components = null;
+
+ /// <summary>
+ /// Clean up any resources being used.
+ /// </summary>
+ /// <param name="disposing">true if managed resources should be disposed; otherwise, false.</param>
+ protected override void Dispose(bool disposing) {
+ if (disposing && (components != null)) {
+ components.Dispose();
+ }
+ base.Dispose(disposing);
+ }
+
+ #region Component Designer generated code
+
+ /// <summary>
+ /// Required method for Designer support - do not modify
+ /// the contents of this method with the code editor.
+ /// </summary>
+ private void InitializeComponent() {
+ this.listBox = new System.Windows.Forms.ListBox();
+ this.modelsGroupBox = new System.Windows.Forms.GroupBox();
+ this.viewHost = new HeuristicLab.MainForm.WindowsForms.ViewHost();
+ this.modelsGroupBox.SuspendLayout();
+ this.SuspendLayout();
+ //
+ // listBox
+ //
+ this.listBox.Dock = System.Windows.Forms.DockStyle.Fill;
+ this.listBox.FormattingEnabled = true;
+ this.listBox.Location = new System.Drawing.Point(3, 16);
+ this.listBox.Name = "listBox";
+ this.listBox.Size = new System.Drawing.Size(153, 387);
+ this.listBox.TabIndex = 0;
+ this.listBox.SelectedIndexChanged += new System.EventHandler(this.listBox_SelectedIndexChanged);
+ //
+ // modelsGroupBox
+ //
+ this.modelsGroupBox.Anchor = ((System.Windows.Forms.AnchorStyles)(((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Bottom)
+ | System.Windows.Forms.AnchorStyles.Left)));
+ this.modelsGroupBox.Controls.Add(this.listBox);
+ this.modelsGroupBox.Location = new System.Drawing.Point(3, 3);
+ this.modelsGroupBox.Name = "modelsGroupBox";
+ this.modelsGroupBox.Size = new System.Drawing.Size(159, 406);
+ this.modelsGroupBox.TabIndex = 1;
+ this.modelsGroupBox.TabStop = false;
+ this.modelsGroupBox.Text = "Trees";
+ //
+ // viewHost
+ //
+ this.viewHost.Anchor = ((System.Windows.Forms.AnchorStyles)((((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Bottom)
+ | System.Windows.Forms.AnchorStyles.Left)
+ | System.Windows.Forms.AnchorStyles.Right)));
+ this.viewHost.Caption = "View";
+ this.viewHost.Content = null;
+ this.viewHost.Enabled = false;
+ this.viewHost.Location = new System.Drawing.Point(168, 3);
+ this.viewHost.Name = "viewHost";
+ this.viewHost.ReadOnly = false;
+ this.viewHost.Size = new System.Drawing.Size(138, 406);
+ this.viewHost.TabIndex = 2;
+ this.viewHost.ViewsLabelVisible = true;
+ this.viewHost.ViewType = null;
+ //
+ // GradientBoostedTreesModelView
+ //
+ this.AllowDrop = true;
+ this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Inherit;
+ this.Controls.Add(this.viewHost);
+ this.Controls.Add(this.modelsGroupBox);
+ this.Name = "GradientBoostedTreesModelView";
+ this.Size = new System.Drawing.Size(309, 412);
+ this.modelsGroupBox.ResumeLayout(false);
+ this.ResumeLayout(false);
+
+ }
+
+ #endregion
+
+ private System.Windows.Forms.ListBox listBox;
+ private System.Windows.Forms.GroupBox modelsGroupBox;
+ private MainForm.WindowsForms.ViewHost viewHost;
+
+
+ }
+}
Index: /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/GradientBoostedTreesModelView.cs
===================================================================
--- /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/GradientBoostedTreesModelView.cs (revision 12372)
+++ /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/GradientBoostedTreesModelView.cs (revision 12372)
@@ -0,0 +1,71 @@
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using System.Linq;
+using System.Windows.Forms;
+using GradientBoostedTrees;
+using HeuristicLab.MainForm;
+using HeuristicLab.MainForm.WindowsForms;
+using HeuristicLab.Problems.DataAnalysis;
+
+namespace HeuristicLab.Algorithms.DataAnalysis.Views {
+ // Lists the individual models of a GradientBoostedTreesModel; the model selected in the list is shown in the view host.
+ [View("Gradient boosted trees model")]
+ [Content(typeof(GradientBoostedTreesModel), true)]
+ public partial class GradientBoostedTreesModelView : AsynchronousContentView {
+ public new GradientBoostedTreesModel Content {
+ get { return (GradientBoostedTreesModel)base.Content; }
+ set { base.Content = value; }
+ }
+
+ public GradientBoostedTreesModelView() : base() {
+ InitializeComponent();
+ viewHost.ViewsLabelVisible = false;
+ }
+
+ protected override void OnContentChanged() {
+ base.OnContentChanged();
+ // always drop the model shown so far: it belongs to the previous content and must not stay visible
+ viewHost.Content = null;
+ if (Content == null) {
+ listBox.Items.Clear();
+ } else {
+ PopulateModelsList();
+ }
+ }
+
+ // Replaces the list entries with the models of the current content (Content must not be null here).
+ private void PopulateModelsList() {
+ listBox.BeginUpdate();
+ listBox.Items.Clear();
+ listBox.Items.AddRange(Content.Models.ToArray());
+ listBox.EndUpdate();
+ }
+
+ // Shows the selected model in the view host, or leaves the view host empty when nothing is selected.
+ private void listBox_SelectedIndexChanged(object sender, System.EventArgs e) {
+ var idx = listBox.SelectedIndex;
+ viewHost.Content = null;
+ if (idx < 0) return;
+ viewHost.Content = (IRegressionModel)listBox.Items[idx];
+ }
+ }
+}
Index: /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/HeuristicLab.Algorithms.DataAnalysis.Views-3.4.csproj
===================================================================
--- /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/HeuristicLab.Algorithms.DataAnalysis.Views-3.4.csproj (revision 12371)
+++ /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/HeuristicLab.Algorithms.DataAnalysis.Views-3.4.csproj (revision 12372)
@@ -125,4 +125,16 @@
+
+ UserControl
+
+
+ RegressionTreeModelView.cs
+
+
+ UserControl
+
+
+ GradientBoostedTreesModelView.cs
+
UserControl
@@ -244,4 +256,12 @@
HeuristicLab.Data-3.3
False
+
+
+ {423bd94f-963a-438e-ba45-3bb3d61cd03b}
+ HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Views-3.4
+
+
+ {06D4A186-9319-48A0-BADE-A2058D462EEA}
+ HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4
Index: /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/RegressionTreeModelView.Designer.cs
===================================================================
--- /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/RegressionTreeModelView.Designer.cs (revision 12372)
+++ /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/RegressionTreeModelView.Designer.cs (revision 12372)
@@ -0,0 +1,85 @@
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using HeuristicLab.Optimization.Views;
+
+namespace HeuristicLab.Algorithms.DataAnalysis.Views {
+ partial class RegressionTreeModelView { // designer-generated part: a single SymbolicExpressionTreeChart docked to fill the view
+ /// <summary>
+ /// Required designer variable.
+ /// </summary>
+ private System.ComponentModel.IContainer components = null;
+
+ /// <summary>
+ /// Clean up any resources being used.
+ /// </summary>
+ /// <param name="disposing">true if managed resources should be disposed; otherwise, false.</param>
+ protected override void Dispose(bool disposing) {
+ if (disposing && (components != null)) {
+ components.Dispose();
+ }
+ base.Dispose(disposing);
+ }
+
+ #region Component Designer generated code
+
+ /// <summary>
+ /// Required method for Designer support - do not modify
+ /// the contents of this method with the code editor.
+ /// </summary>
+ private void InitializeComponent() {
+ this.components = new System.ComponentModel.Container();
+ this.symbolicExpressionTreeChart = new HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Views.SymbolicExpressionTreeChart();
+ this.SuspendLayout();
+ //
+ // symbolicExpressionTreeChart
+ //
+ this.symbolicExpressionTreeChart.BackgroundColor = System.Drawing.Color.White;
+ this.symbolicExpressionTreeChart.Dock = System.Windows.Forms.DockStyle.Fill;
+ this.symbolicExpressionTreeChart.LineColor = System.Drawing.Color.Black;
+ this.symbolicExpressionTreeChart.Location = new System.Drawing.Point(0, 0);
+ this.symbolicExpressionTreeChart.Name = "symbolicExpressionTreeChart";
+ this.symbolicExpressionTreeChart.Size = new System.Drawing.Size(246, 242);
+ this.symbolicExpressionTreeChart.Spacing = 5;
+ this.symbolicExpressionTreeChart.SuspendRepaint = false;
+ this.symbolicExpressionTreeChart.TabIndex = 0;
+ this.symbolicExpressionTreeChart.TextFont = new System.Drawing.Font("Microsoft Sans Serif", 12F);
+ this.symbolicExpressionTreeChart.Tree = null;
+ //
+ // RegressionTreeModelView
+ //
+ this.AllowDrop = true;
+ this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Inherit;
+ this.Controls.Add(this.symbolicExpressionTreeChart);
+ this.Name = "RegressionTreeModelView";
+ this.Size = new System.Drawing.Size(246, 242);
+ this.ResumeLayout(false);
+
+ }
+
+ #endregion
+
+ private Encodings.SymbolicExpressionTreeEncoding.Views.SymbolicExpressionTreeChart symbolicExpressionTreeChart;
+
+
+
+ }
+}
Index: /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/RegressionTreeModelView.cs
===================================================================
--- /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/RegressionTreeModelView.cs (revision 12372)
+++ /branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/RegressionTreeModelView.cs (revision 12372)
@@ -0,0 +1,184 @@
+#region License Information
+/* HeuristicLab
+ * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ *
+ * This file is part of HeuristicLab.
+ *
+ * HeuristicLab is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * HeuristicLab is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
+ */
+#endregion
+
+using System.Drawing;
+using System.Windows.Forms;
+using GradientBoostedTrees;
+using HeuristicLab.Common;
+using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
+using HeuristicLab.MainForm;
+using HeuristicLab.MainForm.WindowsForms;
+using HeuristicLab.PluginInfrastructure;
+
+namespace HeuristicLab.Algorithms.DataAnalysis.Views {
+ // Displays a RegressionTreeModel as a tree chart made of decision nodes and constant leaf nodes.
+ [View("Regression tree model")]
+ [Content(typeof(RegressionTreeModel), true)]
+ public partial class RegressionTreeModelView : AsynchronousContentView {
+ public new RegressionTreeModel Content {
+ get { return (RegressionTreeModel)base.Content; }
+ set { base.Content = value; }
+ }
+
+ public RegressionTreeModelView() : base() {
+ InitializeComponent();
+ symbolicExpressionTreeChart.TextFont = new System.Drawing.Font(FontFamily.GenericSerif, 8F);
+ }
+
+ // Rebuilds the displayed tree from the new content; the chart is cleared when there is no content.
+ protected override void OnContentChanged() {
+ base.OnContentChanged();
+ if (Content == null) {
+ symbolicExpressionTreeChart.Tree = null;
+ } else {
+ symbolicExpressionTreeChart.Tree = CreateSymbTree(Content);
+ }
+ }
+
+ // Converts the flat node array of the model into a symbolic expression tree; the root node is stored at index 0.
+ private ISymbolicExpressionTree CreateSymbTree(RegressionTreeModel regressionTreeModel) {
+ return new SymbolicExpressionTree(CreateSymbTree(regressionTreeModel.tree, 0));
+ }
+
+ // Builds the subtree rooted at tree[nodeIdx] by following leftIdx/rightIdx, the same links the model uses for prediction.
+ private ISymbolicExpressionTreeNode CreateSymbTree(RegressionTreeModel.TreeNode[] tree, int nodeIdx) {
+ var node = tree[nodeIdx];
+ if (node.varName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
+ // terminal node: val holds the value returned for rows that end up here
+ var leaf = (ConstantTreeNode)constantSy.CreateTreeNode();
+ leaf.Value = node.val;
+ return leaf;
+ }
+ // inner node: val holds the threshold for varName; left subtree first, then right subtree
+ var decision = (DecisionTreeNode)decisionSy.CreateTreeNode();
+ decision.VariableName = node.varName;
+ decision.Threshold = node.val;
+ decision.AddSubtree(CreateSymbTree(tree, node.leftIdx));
+ decision.AddSubtree(CreateSymbTree(tree, node.rightIdx));
+ return decision;
+ }
+
+ #region helper types for symbols
+ private static readonly Decision decisionSy = new Decision();
+ private static readonly Constant constantSy = new Constant();
+
+ // Helper symbol and node types used by this view to build the displayed tree; Clone is not implemented for any of them.
+ // TODO use simple symbols
+ // Symbol of an inner (split) node; it always has exactly two subtrees.
+ [NonDiscoverableType]
+ public sealed class Decision : Symbol {
+ private const int minimumArity = 2;
+ private const int maximumArity = 2;
+
+ public override int MinimumArity {
+ get { return minimumArity; }
+ }
+ public override int MaximumArity {
+ get { return maximumArity; }
+ }
+
+ public Decision() : base("Decision", "") { }
+ public override IDeepCloneable Clone(Cloner cloner) {
+ throw new System.NotImplementedException();
+ }
+ public override ISymbolicExpressionTreeNode CreateTreeNode() {
+ return new DecisionTreeNode(this);
+ }
+ }
+ // Tree node of a Decision symbol; carries the split variable and its threshold.
+ [NonDiscoverableType]
+ public sealed class DecisionTreeNode : SymbolicExpressionTreeNode {
+ public new Decision Symbol {
+ get { return (Decision)base.Symbol; }
+ }
+
+ public double Threshold { get; set; }
+ public string VariableName { get; set; }
+
+ public DecisionTreeNode(Decision symbol) : base(symbol) { }
+
+ public override bool HasLocalParameters {
+ get { return true; }
+ }
+ public override IDeepCloneable Clone(Cloner cloner) {
+ throw new System.NotImplementedException();
+ }
+
+ // Text form "variable <= threshold", threshold in exponential notation with four decimals.
+ public override string ToString() {
+ return string.Format("{0} <= {1:E4}", VariableName, Threshold);
+ }
+ }
+
+ // Symbol of a terminal node; it has no subtrees.
+ [NonDiscoverableType]
+ public sealed class Constant : Symbol {
+ private const int minimumArity = 0;
+ private const int maximumArity = 0;
+
+ public override int MinimumArity {
+ get { return minimumArity; }
+ }
+ public override int MaximumArity {
+ get { return maximumArity; }
+ }
+
+ public Constant() : base("Constant", "") { }
+ public override IDeepCloneable Clone(Cloner cloner) {
+ throw new System.NotImplementedException();
+ }
+ public override ISymbolicExpressionTreeNode CreateTreeNode() {
+ return new ConstantTreeNode(this);
+ }
+ }
+
+ // Tree node of a Constant symbol; carries the value of a terminal node.
+ [NonDiscoverableType]
+ public sealed class ConstantTreeNode : SymbolicExpressionTreeTerminalNode {
+ public new Constant Symbol {
+ get { return (Constant)base.Symbol; }
+ }
+
+ private double constantValue;
+ public double Value {
+ get { return constantValue; }
+ set { constantValue = value; }
+ }
+
+ private ConstantTreeNode() : base() { }
+ public ConstantTreeNode(Constant constantSymbol) : base(constantSymbol) { }
+
+ public override bool HasLocalParameters {
+ get { return true; }
+ }
+ public override IDeepCloneable Clone(Cloner cloner) {
+ throw new System.NotImplementedException();
+ }
+
+ // Text form of the value in exponential notation with four decimals.
+ public override string ToString() {
+ return constantValue.ToString("E4");
+ }
+ }
+
+ #endregion
+ }
+}
Index: /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs
===================================================================
--- /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs (revision 12371)
+++ /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs (revision 12372)
@@ -8,12 +8,15 @@
namespace GradientBoostedTrees {
- [Item("GradientBoostedTreesSolution", "")]
[StorableClass]
+ [Item("Gradient boosted tree model", "")]
public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {
[Storable]
private readonly IList models;
+ public IEnumerable Models { get { return models; } }
+
[Storable]
private readonly IList weights;
+ public IEnumerable Weights { get { return weights; } }
[StorableConstructor]
@@ -25,5 +28,5 @@
}
public GradientBoostedTreesModel(IEnumerable models, IEnumerable weights)
- : base() {
+ : base("Gradient boosted tree model", string.Empty) {
this.models = new List(models);
this.weights = new List(weights);
Index: /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
===================================================================
--- /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs (revision 12371)
+++ /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs (revision 12372)
@@ -33,4 +33,8 @@
private readonly double[] outx;
private readonly int[] outSortedIdx;
+
+ private RegressionTreeModel.TreeNode[] tree; // tree is represented as a flat array of nodes
+ private int curTreeNodeIdx; // the index where the next tree node is stored
+
private readonly IList nodeQueue; //TODO
@@ -128,22 +132,29 @@
}
}
+
+ // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes)
+ int numNodes = (int)Math.Pow(2, maxDepth) - 1;
+ //this.tree = new RegressionTreeModel.TreeNode[numNodes];
+ this.tree = Enumerable.Range(0, numNodes).Select(_=>new RegressionTreeModel.TreeNode()).ToArray();
+ this.curTreeNodeIdx = 0;
+
// start and end idx are inclusive
- var tree = CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);
+ CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);
return new RegressionTreeModel(tree);
}
// startIdx and endIdx are inclusive
- private RegressionTreeModel.TreeNode CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
+ private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
Contract.Assert(endIdx - startIdx >= 0);
Contract.Assert(startIdx >= 0);
Contract.Assert(endIdx < internalIdx.Length);
- RegressionTreeModel.TreeNode t;
// TODO: stop when y is constant
// TODO: use priority queue of nodes to be expanded (sorted by improvement) instead of the recursion to maximum depth
if (maxDepth <= 1 || endIdx - startIdx == 0) {
- // max depth reached or only one element
- t = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx));
- return t;
+ // max depth reached or only one element
+ tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
+ tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
+ curTreeNodeIdx++;
} else {
int i, j;
@@ -154,5 +165,8 @@
// if bestVariableName is NO_VARIABLE then no split was possible anymore
if (bestVariableName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
- return new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx));
+ // no further split is possible: store a terminal node
+ tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
+ tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
+ curTreeNodeIdx++;
} else {
@@ -214,10 +228,16 @@
Debug.Assert(j <= endIdx);
- t = new RegressionTreeModel.TreeNode(bestVariableName,
- threshold,
- CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch),
- CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch));
-
- return t;
+ var parentIdx = curTreeNodeIdx;
+ tree[parentIdx].varName = bestVariableName;
+ tree[parentIdx].val = threshold;
+ curTreeNodeIdx++;
+
+ // create left subtree
+ tree[parentIdx].leftIdx = curTreeNodeIdx;
+ CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch);
+
+ // create right subtree
+ tree[parentIdx].rightIdx = curTreeNodeIdx;
+ CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch);
}
}
@@ -272,4 +292,5 @@
// assumption is that the Average(y) = 0
private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {
+ if (string.IsNullOrEmpty(bestVar)) return;
// update variable relevance
double err = sumY * sumY / rows;
Index: /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
===================================================================
--- /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs (revision 12371)
+++ /branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs (revision 12372)
@@ -1,3 +1,4 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
@@ -12,29 +13,29 @@
public class RegressionTreeModel : NamedItem, IRegressionModel {
+ // trees are represented as a flat array
+ // object-graph-traversal has problems if this is defined as a struct. TODO investigate...
[StorableClass]
public class TreeNode {
public readonly static string NO_VARIABLE = string.Empty;
[Storable]
- public readonly string varName; // name of the variable for splitting or -1 if terminal node
+ public string varName; // name of the variable for splitting or NO_VARIABLE if terminal node
[Storable]
- public readonly double val; // threshold
+ public double val; // threshold (split node) or predicted value (terminal node)
[Storable]
- public readonly TreeNode left;
+ public int leftIdx;
[Storable]
- public readonly TreeNode right;
+ public int rightIdx;
+ public TreeNode() {
+ varName = NO_VARIABLE;
+ leftIdx = -1;
+ rightIdx = -1;
+ }
[StorableConstructor]
private TreeNode(bool deserializing) { }
-
- public TreeNode(string varName, double value, TreeNode left = null, TreeNode right = null) {
- this.varName = varName;
- this.val = value;
- this.left = left;
- this.right = right;
- }
}
[Storable]
- public readonly TreeNode tree;
+ public readonly TreeNode[] tree;
[StorableConstructor]
@@ -43,22 +44,20 @@
public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
: base(original, cloner) {
- this.tree = original.tree;
+ this.tree = original.tree; // shallow clone, tree must be readonly
}
- public RegressionTreeModel(TreeNode tree)
- : base() {
- this.name = ItemName;
- this.description = ItemDescription;
-
+ public RegressionTreeModel(TreeNode[] tree)
+ : base("RegressionTreeModel", "Represents a decision tree for regression.") {
this.tree = tree;
}
- private static double GetPredictionForRow(TreeNode t, Dataset ds, int row) {
- if (t.varName == TreeNode.NO_VARIABLE)
- return t.val;
- else if (ds.GetDoubleValue(t.varName, row) <= t.val)
- return GetPredictionForRow(t.left, ds, row);
+ private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, Dataset ds, int row) {
+ var node = t[nodeIdx];
+ if (node.varName == TreeNode.NO_VARIABLE)
+ return node.val;
+ else if (ds.GetDoubleValue(node.varName, row) <= node.val)
+ return GetPredictionForRow(t, node.leftIdx, ds, row);
else
- return GetPredictionForRow(t.right, ds, row);
+ return GetPredictionForRow(t, node.rightIdx, ds, row);
}
@@ -68,5 +67,5 @@
public IEnumerable GetEstimatedValues(Dataset ds, IEnumerable rows) {
- return rows.Select(r => GetPredictionForRow(tree, ds, r));
+ return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
}