Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/29/16 10:36:52 (8 years ago)
Author:
pfleck
Message:

#2597

  • Merged recent trunk changes.
  • Adapted VariablesUsedForPrediction property for RegressionSolutionTargetResponseGradientView.
  • Fixed a reference (.dll to project ref).
Location:
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis

  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs

    r13157 r13948  
    3333  [Item("Gradient boosted tree model", "")]
    3434  // this is essentially a collection of weighted regression models
    35   public sealed class GradientBoostedTreesModel : NamedItem, IGradientBoostedTreesModel {
     35  public sealed class GradientBoostedTreesModel : RegressionModel, IGradientBoostedTreesModel {
    3636    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
    3737    #region Backwards compatible code, remove with 3.5
     
    5858    #endregion
    5959
     60    public override IEnumerable<string> VariablesUsedForPrediction {
     61      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
     62    }
     63
    6064    private readonly IList<IRegressionModel> models;
    6165    public IEnumerable<IRegressionModel> Models { get { return models; } }
     
    7781    }
    7882    [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")]
    79     public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
    80       : base("Gradient boosted tree model", string.Empty) {
     83    internal GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
     84      : base(string.Empty, "Gradient boosted tree model", string.Empty) {
    8185      this.models = new List<IRegressionModel>(models);
    8286      this.weights = new List<double>(weights);
     
    8993    }
    9094
    91     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     95    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    9296      // allocate target array, then go over all models and add up the weighted estimation for each row
    9397      if (!rows.Any()) return Enumerable.Empty<double>(); // return immediately if rows is empty. This prevents multiple iteration over lazy rows enumerable.
     
    105109    }
    106110
    107     public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     111    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    108112      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    109113    }
     114
    110115  }
    111116}
  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs

    r13157 r13948  
    2222
    2323using System.Collections.Generic;
     24using System.Linq;
    2425using HeuristicLab.Common;
    2526using HeuristicLab.Core;
     
    3334  // recalculate the actual GBT model on demand
    3435  [Item("Gradient boosted tree model", "")]
    35   public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IGradientBoostedTreesModel {
     36  public sealed class GradientBoostedTreesModelSurrogate : RegressionModel, IGradientBoostedTreesModel {
    3637    // don't store the actual model!
    3738    private IGradientBoostedTreesModel actualModel; // the actual model is only recalculated when necessary
     
    5556
    5657
     58    public override IEnumerable<string> VariablesUsedForPrediction {
     59      get { return actualModel.Models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
     60    }
     61
    5762    [StorableConstructor]
    5863    private GradientBoostedTreesModelSurrogate(bool deserializing) : base(deserializing) { }
     
    7378
    7479    // create only the surrogate model without an actual model
    75     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)
    76       : base("Gradient boosted tree model", string.Empty) {
     80    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed,
     81      ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)
     82      : base(trainingProblemData.TargetVariable, "Gradient boosted tree model", string.Empty) {
    7783      this.trainingProblemData = trainingProblemData;
    7884      this.seed = seed;
     
    8692
    8793    // wrap an actual model in a surrogate
    88     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IGradientBoostedTreesModel model)
     94    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed,
     95      ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu,
     96      IGradientBoostedTreesModel model)
    8997      : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) {
    9098      this.actualModel = model;
     
    96104
    97105    // forward message to actual model (recalculate model first if necessary)
    98     public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
     106    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    99107      if (actualModel == null) actualModel = RecalculateModel();
    100108      return actualModel.GetEstimatedValues(dataset, rows);
    101109    }
    102110
    103     public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     111    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    104112      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    105113    }
    106 
    107114
    108115    private IGradientBoostedTreesModel RecalculateModel() {
  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r13065 r13948  
    180180
    181181
    182     // processes potential splits from the queue as long as splits are left and the maximum size of the tree is not reached
     182    // processes potential splits from the queue as long as splits are remaining and the maximum size of the tree is not reached
    183183    private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) {
    184184      while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop
     
    204204
    205205        // overwrite existing leaf node with an internal node
    206         tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx);
     206        tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx, weightLeft: (splitIdx - startIdx + 1) / (double)(endIdx - startIdx + 1));
    207207      }
    208208    }
  • branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r13030 r13948  
    3434  [StorableClass]
    3535  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
    36   public sealed class RegressionTreeModel : NamedItem, IRegressionModel {
     36  public sealed class RegressionTreeModel : RegressionModel {
     37    public override IEnumerable<string> VariablesUsedForPrediction {
     38      get { return tree.Select(t => t.VarName).Where(v => v != TreeNode.NO_VARIABLE); }
     39    }
    3740
    3841    // trees are represented as a flat array   
     
    4043      public readonly static string NO_VARIABLE = null;
    4144
    42       public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1)
     45      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0)
    4346        : this() {
    4447        VarName = varName;
     
    4649        LeftIdx = leftIdx;
    4750        RightIdx = rightIdx;
    48       }
    49 
    50       public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node
    51       public double Val { get; private set; } // threshold
    52       public int LeftIdx { get; private set; }
    53       public int RightIdx { get; private set; }
     51        WeightLeft = weightLeft;
     52      }
     53
     54      public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node
     55      public double Val { get; internal set; } // threshold
     56      public int LeftIdx { get; internal set; }
     57      public int RightIdx { get; internal set; }
     58      public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left sub-tree)
     59
    5460
    5561      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
     
    6470            LeftIdx.Equals(other.LeftIdx) &&
    6571            RightIdx.Equals(other.RightIdx) &&
     72            WeightLeft.Equals(other.WeightLeft) &&
    6673            EqualStrings(VarName, other.VarName);
    6774        } else {
     
    7986    private TreeNode[] tree;
    8087
    81     [Storable]
     88    #region old storable format
     89    // remove with HL 3.5
     90    [Storable(AllowOneWay = true)]
    8291    // to prevent storing the references to data caches in nodes
    83     // TODO seemingly it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
     92    // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
    8493    private Tuple<string, double, int, int>[] SerializedTree {
    85       get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
    86       set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); }
    87     }
     94      // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
     95      set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models
     96    }
     97    #endregion
     98    #region new storable format
     99    [Storable]
     100    private string[] SerializedTreeVarNames {
     101      get { return tree.Select(t => t.VarName).ToArray(); }
     102      set {
     103        if (tree == null) tree = new TreeNode[value.Length];
     104        for (int i = 0; i < value.Length; i++) {
     105          tree[i].VarName = value[i];
     106        }
     107      }
     108    }
     109    [Storable]
     110    private double[] SerializedTreeValues {
     111      get { return tree.Select(t => t.Val).ToArray(); }
     112      set {
     113        if (tree == null) tree = new TreeNode[value.Length];
     114        for (int i = 0; i < value.Length; i++) {
     115          tree[i].Val = value[i];
     116        }
     117      }
     118    }
     119    [Storable]
     120    private int[] SerializedTreeLeftIdx {
     121      get { return tree.Select(t => t.LeftIdx).ToArray(); }
     122      set {
     123        if (tree == null) tree = new TreeNode[value.Length];
     124        for (int i = 0; i < value.Length; i++) {
     125          tree[i].LeftIdx = value[i];
     126        }
     127      }
     128    }
     129    [Storable]
     130    private int[] SerializedTreeRightIdx {
     131      get { return tree.Select(t => t.RightIdx).ToArray(); }
     132      set {
     133        if (tree == null) tree = new TreeNode[value.Length];
     134        for (int i = 0; i < value.Length; i++) {
     135          tree[i].RightIdx = value[i];
     136        }
     137      }
     138    }
     139    [Storable]
     140    private double[] SerializedTreeWeightLeft {
     141      get { return tree.Select(t => t.WeightLeft).ToArray(); }
     142      set {
     143        if (tree == null) tree = new TreeNode[value.Length];
     144        for (int i = 0; i < value.Length; i++) {
     145          tree[i].WeightLeft = value[i];
     146        }
     147      }
     148    }
     149    #endregion
     150
     151
     152
    88153
    89154    [StorableConstructor]
     
    98163    }
    99164
    100     internal RegressionTreeModel(TreeNode[] tree)
    101       : base("RegressionTreeModel", "Represents a decision tree for regression.") {
     165    internal RegressionTreeModel(TreeNode[] tree, string target = "Target")
     166      : base(target, "RegressionTreeModel", "Represents a decision tree for regression.") {
    102167      this.tree = tree;
    103168    }
     
    108173        if (node.VarName == TreeNode.NO_VARIABLE)
    109174          return node.Val;
    110 
    111         if (columnCache[nodeIdx][row] <= node.Val)
     175        if (columnCache[nodeIdx] == null) {
     176          if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab.");
     177          // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees)
     178          return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) +
     179                 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row);
     180        } else if (columnCache[nodeIdx][row] <= node.Val)
    112181          nodeIdx = node.LeftIdx;
    113182        else
     
    121190    }
    122191
    123     public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
     192    public override IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
    124193      // lookup columns for variableNames in one pass over the tree to speed up evaluation later on
    125194      ReadOnlyCollection<double>[] columnCache = new ReadOnlyCollection<double>[tree.Length];
     
    127196      for (int i = 0; i < tree.Length; i++) {
    128197        if (tree[i].VarName != TreeNode.NO_VARIABLE) {
    129           columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
     198          // tree models also support calculating estimations if not all variables used for training are available in the dataset
     199          if (ds.ColumnNames.Contains(tree[i].VarName))
     200            columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
    130201        }
    131202      }
     
    133204    }
    134205
    135     public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
     206    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    136207      return new RegressionSolution(this, new RegressionProblemData(problemData));
    137208    }
     
    148219      } else {
    149220        return
    150           TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val))
    151         + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} >  {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val));
    152       }
    153     }
     221          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft))
     222        + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2}  >  {3:F} ({4:N3}))", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft));
     223      }
     224    }
     225
    154226  }
    155227}
Note: See TracChangeset for help on using the changeset viewer.