Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/07/15 17:42:24 (9 years ago)
Author:
gkronber
Message:

#2261: made TreeNode immutable and prevent change of TreeNode[] tree in RegressionTreeModel
ToString() override to make debugging easier and to enable inspection in unit test

Location:
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12632 r12658  
    207207        Debug.Assert(endIdx < internalIdx.Length);
    208208
    209         // transform the leaf node into an internal node
    210         tree[f.ParentNodeIdx].VarName = f.SplittingVariable;
    211         tree[f.ParentNodeIdx].Val = f.SplittingThreshold;
    212 
    213209        // split partition into left and right
    214210        int splitIdx;
     
    218214
    219215        // create two leaf nodes (and enqueue best splits for both)
    220         tree[f.ParentNodeIdx].LeftIdx = CreateLeafNode(startIdx, splitIdx, lineSearch);
    221         tree[f.ParentNodeIdx].RightIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch);
     216        var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, lineSearch);
     217        var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch);
     218
     219        // overwrite existing leaf node with an internal node
     220        tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx);
    222221      }
    223222    }
     
    226225    // returns the index of the newly created tree node
    227226    private int CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {
    228       tree[curTreeNodeIdx].VarName = RegressionTreeModel.TreeNode.NO_VARIABLE;
    229       tree[curTreeNodeIdx].Val = lineSearch(internalIdx, startIdx, endIdx);
     227      // write a leaf node
     228      var val = lineSearch(internalIdx, startIdx, endIdx);
     229      tree[curTreeNodeIdx] = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, val);
    230230
    231231      EnqueuePartitionSplit(curTreeNodeIdx, startIdx, endIdx);
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r12635 r12658  
    2121#endregion
    2222
     23using System;
    2324using System.Collections.Generic;
     25using System.Globalization;
    2426using System.Linq;
     27using System.Text;
    2528using HeuristicLab.Common;
    2629using HeuristicLab.Core;
     
    3134  [StorableClass]
    3235  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
    33   public class RegressionTreeModel : NamedItem, IRegressionModel {
     36  public sealed class RegressionTreeModel : NamedItem, IRegressionModel {
    3437
    3538    // trees are represented as a flat array
    36     public struct TreeNode {
     39    internal struct TreeNode {
    3740      public readonly static string NO_VARIABLE = string.Empty;
    38       public string VarName { get; set; } // name of the variable for splitting or NO_VARIABLE if terminal node
    39       public double Val { get; set; } // threshold
    40       public int LeftIdx { get; set; }
    41       public int RightIdx { get; set; }
    4241
     42      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1)
     43        : this() {
     44        VarName = varName;
     45        Val = val;
     46        LeftIdx = leftIdx;
     47        RightIdx = rightIdx;
     48      }
     49
     50      public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node
     51      public double Val { get; private set; } // threshold
     52      public int LeftIdx { get; private set; }
     53      public int RightIdx { get; private set; }
     54
     55      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
    4356      public override int GetHashCode() {
    4457        return LeftIdx ^ RightIdx ^ Val.GetHashCode();
     58      }
     59      // necessary because of GetHashCode override
     60      public override bool Equals(object obj) {
     61        if (obj is TreeNode) {
     62          var other = (TreeNode)obj;
     63          return Val.Equals(other.Val) &&
     64            VarName.Equals(other.VarName) &&
     65            LeftIdx.Equals(other.LeftIdx) &&
     66            RightIdx.Equals(other.RightIdx);
     67        } else {
     68          return false;
     69        }
    4570      }
    4671    }
    4772
    4873    [Storable]
    49     public readonly TreeNode[] tree;
     74    private readonly TreeNode[] tree;
    5075
    5176    [StorableConstructor]
    5277    private RegressionTreeModel(bool serializing) : base(serializing) { }
    5378    // cloning ctor
    54     public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
     79    private RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
    5580      : base(original, cloner) {
    5681      this.tree = original.tree; // shallow clone, tree must be readonly
    5782    }
    5883
    59     public RegressionTreeModel(TreeNode[] tree)
     84    internal RegressionTreeModel(TreeNode[] tree)
    6085      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
    6186      this.tree = tree;
     
    84109      return new RegressionSolution(this, new RegressionProblemData(problemData));
    85110    }
     111
     112    // mainly for debugging
     113    public override string ToString() {
     114      return TreeToString(0, "");
     115    }
     116
     117    private string TreeToString(int idx, string part) {
     118      var n = tree[idx];
     119      if (n.VarName == TreeNode.NO_VARIABLE) {
     120        return string.Format(CultureInfo.InvariantCulture, "{0} -> {1:F}{2}", part, n.Val, Environment.NewLine);
     121      } else {
     122        return
     123          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val))
     124        + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} >  {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val));
     125      }
     126    }
     127
    86128  }
    87129
Note: See TracChangeset for help on using the changeset viewer.