Ignore:
Timestamp:
06/29/16 10:36:52 (5 years ago)
Author:
pfleck
Message:

#2597

  • Merged recent trunk changes.
  • Adapted VariablesUsedForPrediction property for RegressionSolutionTargetResponseGradientView.
  • Fixed a reference (replaced the .dll reference with a project reference).
Location:
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis

  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r13030 r13948  
    3434  [StorableClass]
    3535  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
    36   public sealed class RegressionTreeModel : NamedItem, IRegressionModel {
     36  public sealed class RegressionTreeModel : RegressionModel {
     37    public override IEnumerable<string> VariablesUsedForPrediction {
     38      get { return tree.Select(t => t.VarName).Where(v => v != TreeNode.NO_VARIABLE); }
     39    }
    3740
    3841    // trees are represented as a flat array   
     
    4043      public readonly static string NO_VARIABLE = null;
    4144
    42       public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1)
     45      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0)
    4346        : this() {
    4447        VarName = varName;
     
    4649        LeftIdx = leftIdx;
    4750        RightIdx = rightIdx;
    48       }
    49 
    50       public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node
    51       public double Val { get; private set; } // threshold
    52       public int LeftIdx { get; private set; }
    53       public int RightIdx { get; private set; }
     51        WeightLeft = weightLeft;
     52      }
     53
     54      public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node
     55      public double Val { get; internal set; } // threshold
     56      public int LeftIdx { get; internal set; }
     57      public int RightIdx { get; internal set; }
     58      public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left sub-tree)
     59
    5460
    5561      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
     
    6470            LeftIdx.Equals(other.LeftIdx) &&
    6571            RightIdx.Equals(other.RightIdx) &&
     72            WeightLeft.Equals(other.WeightLeft) &&
    6673            EqualStrings(VarName, other.VarName);
    6774        } else {
     
    7986    private TreeNode[] tree;
    8087
    81     [Storable]
     88    #region old storable format
     89    // remove with HL 3.4
     90    [Storable(AllowOneWay = true)]
    8291    // to prevent storing the references to data caches in nodes
    83     // TODO seemingly it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
     92    // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
    8493    private Tuple<string, double, int, int>[] SerializedTree {
    85       get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
    86       set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); }
    87     }
     94      // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
     95      set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models
     96    }
     97    #endregion
     98    #region new storable format
     99    [Storable]
     100    private string[] SerializedTreeVarNames {
     101      get { return tree.Select(t => t.VarName).ToArray(); }
     102      set {
     103        if (tree == null) tree = new TreeNode[value.Length];
     104        for (int i = 0; i < value.Length; i++) {
     105          tree[i].VarName = value[i];
     106        }
     107      }
     108    }
     109    [Storable]
     110    private double[] SerializedTreeValues {
     111      get { return tree.Select(t => t.Val).ToArray(); }
     112      set {
     113        if (tree == null) tree = new TreeNode[value.Length];
     114        for (int i = 0; i < value.Length; i++) {
     115          tree[i].Val = value[i];
     116        }
     117      }
     118    }
     119    [Storable]
     120    private int[] SerializedTreeLeftIdx {
     121      get { return tree.Select(t => t.LeftIdx).ToArray(); }
     122      set {
     123        if (tree == null) tree = new TreeNode[value.Length];
     124        for (int i = 0; i < value.Length; i++) {
     125          tree[i].LeftIdx = value[i];
     126        }
     127      }
     128    }
     129    [Storable]
     130    private int[] SerializedTreeRightIdx {
     131      get { return tree.Select(t => t.RightIdx).ToArray(); }
     132      set {
     133        if (tree == null) tree = new TreeNode[value.Length];
     134        for (int i = 0; i < value.Length; i++) {
     135          tree[i].RightIdx = value[i];
     136        }
     137      }
     138    }
     139    [Storable]
     140    private double[] SerializedTreeWeightLeft {
     141      get { return tree.Select(t => t.WeightLeft).ToArray(); }
     142      set {
     143        if (tree == null) tree = new TreeNode[value.Length];
     144        for (int i = 0; i < value.Length; i++) {
     145          tree[i].WeightLeft = value[i];
     146        }
     147      }
     148    }
     149    #endregion
     150
     151
     152
    88153
    89154    [StorableConstructor]
     
    98163    }
    99164
    100     internal RegressionTreeModel(TreeNode[] tree)
    101       : base("RegressionTreeModel", "Represents a decision tree for regression.") {
     165    internal RegressionTreeModel(TreeNode[] tree, string target = "Target")
     166      : base(target, "RegressionTreeModel", "Represents a decision tree for regression.") {
    102167      this.tree = tree;
    103168    }
     
    108173        if (node.VarName == TreeNode.NO_VARIABLE)
    109174          return node.Val;
    110 
    111         if (columnCache[nodeIdx][row] <= node.Val)
     175        if (columnCache[nodeIdx] == null) {
     176          if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab.");
     177          // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees)
     178          return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) +
     179                 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row);
     180        } else if (columnCache[nodeIdx][row] <= node.Val)
    112181          nodeIdx = node.LeftIdx;
    113182        else
     
    121190    }
    122191
    123     public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
     192    public override IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
    124193      // lookup columns for variableNames in one pass over the tree to speed up evaluation later on
    125194      ReadOnlyCollection<double>[] columnCache = new ReadOnlyCollection<double>[tree.Length];
     
    127196      for (int i = 0; i < tree.Length; i++) {
    128197        if (tree[i].VarName != TreeNode.NO_VARIABLE) {
    129           columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
     198          // tree models also support calculating estimations if not all variables used for training are available in the dataset
     199          if (ds.ColumnNames.Contains(tree[i].VarName))
     200            columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
    130201        }
    131202      }
     
    133204    }
    134205
    135     public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     206    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    136207      return new RegressionSolution(this, new RegressionProblemData(problemData));
    137208    }
     
    148219      } else {
    149220        return
    150           TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val))
    151         + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} >  {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val));
    152       }
    153     }
     221          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft))
     222        + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2}  >  {3:F} ({4:N3}))", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft));
     223      }
     224    }
     225
    154226  }
    155227}
Note: See TracChangeset for help on using the changeset viewer.