Free cookie consent management tool by TermsFeed Policy Generator

Changeset 12658


Ignore:
Timestamp:
07/07/15 17:42:24 (9 years ago)
Author:
gkronber
Message:

#2261: made TreeNode immutable and prevented modification of the TreeNode[] tree in RegressionTreeModel;
added a ToString() override to make debugging easier and to enable inspection in unit tests

Location:
branches/GBT-trunkintegration
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12632 r12658  
    207207        Debug.Assert(endIdx < internalIdx.Length);
    208208
    209         // transform the leaf node into an internal node
    210         tree[f.ParentNodeIdx].VarName = f.SplittingVariable;
    211         tree[f.ParentNodeIdx].Val = f.SplittingThreshold;
    212 
    213209        // split partition into left and right
    214210        int splitIdx;
     
    218214
    219215        // create two leaf nodes (and enqueue best splits for both)
    220         tree[f.ParentNodeIdx].LeftIdx = CreateLeafNode(startIdx, splitIdx, lineSearch);
    221         tree[f.ParentNodeIdx].RightIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch);
     216        var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, lineSearch);
     217        var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch);
     218
     219        // overwrite existing leaf node with an internal node
     220        tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx);
    222221      }
    223222    }
     
    226225    // returns the index of the newly created tree node
    227226    private int CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {
    228       tree[curTreeNodeIdx].VarName = RegressionTreeModel.TreeNode.NO_VARIABLE;
    229       tree[curTreeNodeIdx].Val = lineSearch(internalIdx, startIdx, endIdx);
     227      // write a leaf node
     228      var val = lineSearch(internalIdx, startIdx, endIdx);
     229      tree[curTreeNodeIdx] = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, val);
    230230
    231231      EnqueuePartitionSplit(curTreeNodeIdx, startIdx, endIdx);
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r12635 r12658  
    2121#endregion
    2222
     23using System;
    2324using System.Collections.Generic;
     25using System.Globalization;
    2426using System.Linq;
     27using System.Text;
    2528using HeuristicLab.Common;
    2629using HeuristicLab.Core;
     
    3134  [StorableClass]
    3235  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
    33   public class RegressionTreeModel : NamedItem, IRegressionModel {
     36  public sealed class RegressionTreeModel : NamedItem, IRegressionModel {
    3437
    3538    // trees are represented as a flat array
    36     public struct TreeNode {
     39    internal struct TreeNode {
    3740      public readonly static string NO_VARIABLE = string.Empty;
    38       public string VarName { get; set; } // name of the variable for splitting or NO_VARIABLE if terminal node
    39       public double Val { get; set; } // threshold
    40       public int LeftIdx { get; set; }
    41       public int RightIdx { get; set; }
    4241
     42      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1)
     43        : this() {
     44        VarName = varName;
     45        Val = val;
     46        LeftIdx = leftIdx;
     47        RightIdx = rightIdx;
     48      }
     49
     50      public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node
     51      public double Val { get; private set; } // threshold
     52      public int LeftIdx { get; private set; }
     53      public int RightIdx { get; private set; }
     54
     55      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
    4356      public override int GetHashCode() {
    4457        return LeftIdx ^ RightIdx ^ Val.GetHashCode();
     58      }
     59      // necessary because of GetHashCode override
     60      public override bool Equals(object obj) {
     61        if (obj is TreeNode) {
     62          var other = (TreeNode)obj;
     63          return Val.Equals(other.Val) &&
     64            VarName.Equals(other.VarName) &&
     65            LeftIdx.Equals(other.LeftIdx) &&
     66            RightIdx.Equals(other.RightIdx);
     67        } else {
     68          return false;
     69        }
    4570      }
    4671    }
    4772
    4873    [Storable]
    49     public readonly TreeNode[] tree;
     74    private readonly TreeNode[] tree;
    5075
    5176    [StorableConstructor]
    5277    private RegressionTreeModel(bool serializing) : base(serializing) { }
    5378    // cloning ctor
    54     public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
     79    private RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
    5580      : base(original, cloner) {
    5681      this.tree = original.tree; // shallow clone, tree must be readonly
    5782    }
    5883
    59     public RegressionTreeModel(TreeNode[] tree)
     84    internal RegressionTreeModel(TreeNode[] tree)
    6085      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
    6186      this.tree = tree;
     
    84109      return new RegressionSolution(this, new RegressionProblemData(problemData));
    85110    }
     111
     112    // mainly for debugging
     113    public override string ToString() {
     114      return TreeToString(0, "");
     115    }
     116
     117    private string TreeToString(int idx, string part) {
     118      var n = tree[idx];
     119      if (n.VarName == TreeNode.NO_VARIABLE) {
     120        return string.Format(CultureInfo.InvariantCulture, "{0} -> {1:F}{2}", part, n.Val, Environment.NewLine);
     121      } else {
     122        return
     123          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val))
     124        + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} >  {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val));
     125      }
     126    }
     127
    86128  }
    87129
  • branches/GBT-trunkintegration/Tests/Test.cs

    r12632 r12658  
    1111using Microsoft.VisualStudio.TestTools.UnitTesting;
    1212
    13 namespace HeuristicLab.Algorithms.DataAnalysis.GradientBoostedTrees {
     13namespace HeuristicLab.Algorithms.DataAnalysis {
    1414  [TestClass()]
    1515  public class Test {
     
    2121        var xy = new double[,]
    2222        {
    23           {1, 20, 0},
    24           {1, 20, 0},
    25           {2, 10, 0},
    26           {2, 10, 0},
    27         };
    28         var allVariables = new string[] { "y", "x1", "x2" };
    29 
    30         // x1 <= 15 -> 2
    31         // x1 >  15 -> 1
    32         BuildTree(xy, allVariables, 10);
    33       }
    34 
    35 
    36       {
    37         var xy = new double[,]
    38         {
    39           {1, 20,  1},
    40           {1, 20, -1},
    41           {2, 10, -1},
    42           {2, 10, 1},
     23          {-1, 20, 0},
     24          {-1, 20, 0},
     25          { 1, 10, 0},
     26          { 1, 10, 0},
     27        };
     28        var allVariables = new string[] { "y", "x1", "x2" };
     29
     30        // x1 <= 15 -> 1
     31        // x1 >  15 -> -1
     32        BuildTree(xy, allVariables, 10);
     33      }
     34
     35
     36      {
     37        var xy = new double[,]
     38        {
     39          {-1, 20,  1},
     40          {-1, 20, -1},
     41          { 1, 10, -1},
     42          { 1, 10, 1},
    4343        };
    4444        var allVariables = new string[] { "y", "x1", "x2" };
    4545
    4646        // ignore irrelevant variables
    47         // x1 <= 15 -> 2
    48         // x1 >  15 -> 1
     47        // x1 <= 15 -> 1
     48        // x1 >  15 -> -1
    4949        BuildTree(xy, allVariables, 10);
    5050      }
     
    5454        var xy = new double[,]
    5555        {
    56           {1, 20,  1},
    57           {2, 20, -1},
    58           {3, 10, -1},
    59           {4, 10, 1},
    60         };
    61 
    62         var allVariables = new string[] { "y", "x1", "x2" };
    63 
    64         // x1 <= 15 AND x2 <= 0 -> 3
    65         // x1 <= 15 AND x2 >  0 -> 4
    66         // x1 >  15 AND x2 <= 0 -> 2
    67         // x1 >  15 AND x2 >  0 -> 1
     56          {-2, 20,  1},
     57          {-1, 20, -1},
     58          { 1, 10, -1},
     59          { 2, 10, 1},
     60        };
     61
     62        var allVariables = new string[] { "y", "x1", "x2" };
     63
     64        // x1 <= 15 AND x2 <= 0 -> 1
     65        // x1 <= 15 AND x2 >  0 -> 2
     66        // x1 >  15 AND x2 <= 0 -> -1
     67        // x1 >  15 AND x2 >  0 -> -2
    6868        BuildTree(xy, allVariables, 10);
    6969      }
     
    7373        var xy = new double[,]
    7474        {
    75           {0.5, 20,  1},
    76           {1.5, 20,  1},
    77           {1.5, 20, -1},
    78           {2.5, 20, -1},
    79           {2.5, 10, -1},
    80           {3.5, 10, -1},
    81           {3.5, 10, 1},
    82           {4.5, 10, 1},
    83         };
    84 
    85         var allVariables = new string[] { "y", "x1", "x2" };
    86 
    87         // x1 <= 15 AND x2 <= 0 -> 3
    88         // x1 <= 15 AND x2 >  0 -> 4
    89         // x1 >  15 AND x2 <= 0 -> 2
    90         // x1 >  15 AND x2 >  0 -> 1
     75          {-2.5, 20,  1},
     76          {-1.5, 20,  1},
     77          {-1.5, 20, -1},
     78          {0.5, 20, -1},
     79          {0.5, 10, -1},
     80          {1.5, 10, -1},
     81          {1.5, 10, 1},
     82          {2.5, 10, 1},
     83        };
     84
     85        var allVariables = new string[] { "y", "x1", "x2" };
     86
     87        // x1 <= 15 AND x2 <= 0 -> 1
     88        // x1 <= 15 AND x2 >  0 -> 2
     89        // x1 >  15 AND x2 <= 0 -> -1
     90        // x1 >  15 AND x2 >  0 -> -2
    9191        BuildTree(xy, allVariables, 10);
    9292      }
     
    9797        var xy = new double[,]
    9898        {
    99           {10, 1, 1},
    100           {1, 1, 2},
    101           {1, 2, 1},
    102           {10, 2, 2},
     99          { 1, 1, 1},
     100          {-1, 1, 2},
     101          {-1, 2, 1},
     102          { 1, 2, 2},
    103103        };
    104104
     
    106106
    107107        // split cannot be found
    108         // -> 5.50
     108        // -> 0.0
    109109        BuildTree(xy, allVariables, 3);
    110110      }
     
    113113        var xy = new double[,]
    114114        {
    115           {10, 1, 1},
    116           {1, 1, 2},
    117           {1, 2, 1},
    118           {10.1, 2, 2},
     115          { 1, 1, 1},
     116          {-1, 1, 2},
     117          {-1, 2, 1},
     118          { 1.0001, 2, 2},
    119119        };
    120120
    121121        var allVariables = new string[] { "y", "x1", "x2" };
    122122        // (two possible solutions)
    123         // x2 <= 1.5 -> 5.50
    124         // x2 >  1.5 -> 5.55
     123        // x2 <= 1.5 -> 0
     124        // x2 >  1.5 -> 0 (not quite)
    125125        BuildTree(xy, allVariables, 3);
    126126
    127         // x1 <= 1.5 AND x2 <= 1.5 -> 10
    128         // x1 <= 1.5 AND x2 >  1.5 -> 1
    129         // x1 >  1.5 AND x2 <= 1.5 -> 1
    130         // x1 >  1.5 AND x2 >  1.5 -> 10.1
     127        // x1 <= 1.5 AND x2 <= 1.5 -> 1
     128        // x1 <= 1.5 AND x2 >  1.5 -> -1
     129        // x1 >  1.5 AND x2 <= 1.5 -> -1
     130        // x1 >  1.5 AND x2 >  1.5 -> 1 (not quite)
    131131        BuildTree(xy, allVariables, 7);
    132132      }
     
    155155          {-1, 1, 2},
    156156          {-1, 2, 1},
    157           { 1, 2, 2},
     157          { 3, 2, 2},
    158158        };
    159159
     
    162162        // x2 <= 1.5 -> -1.0
    163163        // x2 >  1.5 AND x1 <= 1.5 -> -1.0
    164         // x2 >  1.5 AND x1 >  1.5 ->  1.0
     164        // x2 >  1.5 AND x1 >  1.5 ->  3.0
    165165        BuildTree(xy, allVariables, 10);
    166166      }
     
    269269      var builder = new RegressionTreeBuilder(problemData, rand);
    270270      var model = (GradientBoostedTreesModel)builder.CreateRegressionTree(maxDepth, 1, 1); // maximal depth and use all rows and cols
    271       var constM = model.Models.First() as ConstantRegressionModel;
    272271      var treeM = model.Models.Skip(1).First() as RegressionTreeModel;
    273       WriteTree(treeM.tree, 0, "", constM.Constant);
     272
     273      Console.WriteLine(treeM.ToString());
    274274      Console.WriteLine();
    275     }
    276 
    277     private void WriteTree(RegressionTreeModel.TreeNode[] tree, int idx, string partialRule, double offset) {
    278       var n = tree[idx];
    279       if (n.VarName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
    280         Console.WriteLine("{0} -> {1:F}", partialRule, n.Val + offset);
    281       } else {
    282         WriteTree(tree, n.LeftIdx,
    283           string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}",
    284           partialRule,
    285           string.IsNullOrEmpty(partialRule) ? "" : " and ",
    286           n.VarName,
    287           n.Val), offset);
    288         WriteTree(tree, n.RightIdx,
    289           string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} >  {3:F}",
    290           partialRule,
    291           string.IsNullOrEmpty(partialRule) ? "" : " and ",
    292           n.VarName,
    293           n.Val), offset);
    294       }
    295275    }
    296276    #endregion
Note: See TracChangeset for help on using the changeset viewer.