source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 14345

Last change on this file since 14345 was 14345, checked in by gkronber, 4 years ago

#2690: implemented methods to generate symbolic expression tree solutions for decision tree models (random forest and gradient boosted) as well as views which make it possible to inspect each of the individual trees in a GBT and RF solution

File size: 18.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
32namespace HeuristicLab.Algorithms.DataAnalysis {
33  /// <summary>
34  /// Represents a random forest model for regression and classification
35  /// </summary>
36  [StorableClass]
37  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
38  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
39    // not persisted
40    private alglib.decisionforest randomForest;
41    private alglib.decisionforest RandomForest {
42      get {
43        // recalculate lazily
44        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
45        return randomForest;
46      }
47    }
48
    // the inputs the forest may reference; delegates to the stored training data
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return originalTrainingData.AllowedInputVariables; }
    }

    // number of trees in the ensemble (training parameter, not read from alglib)
    public int NumberOfTrees {
      get { return nTrees; }
    }

    // instead of storing the data of the model itself
    // we instead only store data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed; // RNG seed used for training; makes the lazy rebuild deterministic
    [Storable]
    private IDataAnalysisProblemData originalTrainingData; // snapshot of the training problem data
    [Storable]
    private double[] classValues; // original class values (classification only, null for regression)
    [Storable]
    private int nTrees; // number of trees to build
    [Storable]
    private double r; // fraction of rows sampled per tree (0, 1]
    [Storable]
    private double m; // fraction of features considered per split (0, 1]
71
    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions):
      // create an empty forest so the compatibility Storable properties below
      // have a target object to write the deserialized alglib fields into
      randomForest = new alglib.decisionforest();
    }
78    private RandomForestModel(RandomForestModel original, Cloner cloner)
79      : base(original, cloner) {
80      randomForest = new alglib.decisionforest();
81      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
82      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
83      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
84      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
85      // we assume that the trees array (double[]) is immutable in alglib
86      randomForest.innerobj.trees = original.randomForest.innerobj.trees;
87
88      // allowedInputVariables is immutable so we don't need to clone
89      allowedInputVariables = original.allowedInputVariables;
90
91      // clone data which is necessary to rebuild the model
92      this.seed = original.seed;
93      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
94      // classvalues is immutable so we don't need to clone
95      this.classValues = original.classValues;
96      this.nTrees = original.nTrees;
97      this.r = original.r;
98      this.m = original.m;
99    }
100
    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      // clone so later changes to the caller's problem data cannot affect the lazy rebuild
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues; // null for regression models
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }
118
    // deep clone via the cloning constructor above
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }
122
123    private void RecalculateModel() {
124      double rmsError, oobRmsError, relClassError, oobRelClassError;
125      var regressionProblemData = originalTrainingData as IRegressionProblemData;
126      var classificationProblemData = originalTrainingData as IClassificationProblemData;
127      if (regressionProblemData != null) {
128        var model = CreateRegressionModel(regressionProblemData,
129                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
130                                              out relClassError, out oobRelClassError);
131        randomForest = model.randomForest;
132      } else if (classificationProblemData != null) {
133        var model = CreateClassificationModel(classificationProblemData,
134                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
135                                              out relClassError, out oobRelClassError);
136        randomForest = model.randomForest;
137      }
138    }
139
140    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
141      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
142      AssertInputMatrix(inputData);
143
144      int n = inputData.GetLength(0);
145      int columns = inputData.GetLength(1);
146      double[] x = new double[columns];
147      double[] y = new double[1];
148
149      for (int row = 0; row < n; row++) {
150        for (int column = 0; column < columns; column++) {
151          x[column] = inputData[row, column];
152        }
153        alglib.dfprocess(RandomForest, x, ref y);
154        yield return y[0];
155      }
156    }
157
158    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
159      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
160      AssertInputMatrix(inputData);
161
162      int n = inputData.GetLength(0);
163      int columns = inputData.GetLength(1);
164      double[] x = new double[columns];
165      double[] ys = new double[this.RandomForest.innerobj.ntrees];
166
167      for (int row = 0; row < n; row++) {
168        for (int column = 0; column < columns; column++) {
169          x[column] = inputData[row, column];
170        }
171        alglib.dforest.dfprocessraw(RandomForest.innerobj, x, ref ys);
172        yield return ys.VariancePop();
173      }
174    }
175
176    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
177      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
178      AssertInputMatrix(inputData);
179
180      int n = inputData.GetLength(0);
181      int columns = inputData.GetLength(1);
182      double[] x = new double[columns];
183      double[] y = new double[RandomForest.innerobj.nclasses];
184
185      for (int row = 0; row < n; row++) {
186        for (int column = 0; column < columns; column++) {
187          x[column] = inputData[row, column];
188        }
189        alglib.dfprocess(randomForest, x, ref y);
190        // find class for with the largest probability value
191        int maxProbClassIndex = 0;
192        double maxProb = y[0];
193        for (int i = 1; i < y.Length; i++) {
194          if (maxProb < y[i]) {
195            maxProb = y[i];
196            maxProbClassIndex = i;
197          }
198        }
199        yield return classValues[maxProbClassIndex];
200      }
201    }
202
203    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
204      // hoping that the internal representation of alglib is stable
205
206      // TREE FORMAT
207      // W[Offs]      -   size of sub-array (for the tree)
208      //     node info:
209      // W[K+0]       -   variable number        (-1 for leaf mode)
210      // W[K+1]       -   threshold              (class/value for leaf node)
211      // W[K+2]       -   ">=" branch index      (absent for leaf node)
212
213      // skip irrelevant trees
214      int offset = 0;
215      for (int i = 0; i < treeIdx - 1; i++) {
216        offset = offset + (int)Math.Round(randomForest.innerobj.trees[offset]);
217      }
218
219      var constSy = new Constant();
220      var varCondSy = new VariableCondition() { IgnoreSlope = true };
221
222      var node = CreateRegressionTreeRec(randomForest.innerobj.trees, offset, offset + 1, constSy, varCondSy);
223
224      var startNode = new StartSymbol().CreateTreeNode();
225      startNode.AddSubtree(node);
226      var root = new ProgramRootSymbol().CreateTreeNode();
227      root.AddSubtree(startNode);
228      return new SymbolicExpressionTree(root);
229    }
230
    // Recursively converts the flat alglib tree array into a symbolic expression
    // subtree. 'offset' is the start of the current tree's sub-array, 'k' the
    // index of the current node; leaves become ConstantTreeNodes, inner nodes
    // become VariableConditionTreeNodes. The layout is mirrored from alglib's
    // dfprocessinternal (see reference implementation below).
    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {

      // alglib source for evaluation of one tree (dfprocessinternal)
      // offs = 0
      //
      // Set pointer to the root
      //
      // k = offs + 1;
      //
      // //
      // // Navigate through the tree
      // //
      // while (true) {
      //   if ((double)(df.trees[k]) == (double)(-1)) {
      //     if (df.nclasses == 1) {
      //       y[0] = y[0] + df.trees[k + 1];
      //     } else {
      //       idx = (int)Math.Round(df.trees[k + 1]);
      //       y[idx] = y[idx] + 1;
      //     }
      //     break;
      //   }
      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
      //     k = k + innernodewidth;
      //   } else {
      //     k = offs + (int)Math.Round(df.trees[k + 2]);
      //   }
      // }

      // trees[k] == -1 marks a leaf; trees[k+1] then holds the predicted value
      if ((double)(trees[k]) == (double)(-1)) {
        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
        constNode.Value = trees[k + 1];
        return constNode;
      } else {
        // inner node: trees[k] is the variable index, trees[k+1] the threshold
        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        condNode.VariableName = AllowedInputVariables[(int)Math.Round(trees[k])];
        condNode.Threshold = trees[k + 1];
        condNode.Slope = double.PositiveInfinity; // hard (step) split, no soft transition

        // "<" branch is stored directly after this node (k + 3);
        // ">=" branch index is relative to the tree's offset
        var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy);
        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);

        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
        condNode.AddSubtree(right);
        return condNode;
      }
    }
278
279
    // wraps this model in a regression solution; a new RegressionProblemData is
    // constructed from the given problem data
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }
    // wraps this model in a classification solution; a new ClassificationProblemData
    // is constructed from the given problem data
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }
286
287    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
288      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
289      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError);
290    }
291
292    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
293      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
294      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
295      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
296
297      alglib.dfreport rep;
298      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
299
300      rmsError = rep.rmserror;
301      avgRelError = rep.avgrelerror;
302      outOfBagAvgRelError = rep.oobavgrelerror;
303      outOfBagRmsError = rep.oobrmserror;
304
305      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
306    }
307
    /// <summary>
    /// Trains a classification random forest on all training indices of the problem data.
    /// </summary>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    }
312
313    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
314      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
315
316      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
317      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
318
319      var classValues = problemData.ClassValues.ToArray();
320      int nClasses = classValues.Length;
321
322      // map original class values to values [0..nClasses-1]
323      var classIndices = new Dictionary<double, double>();
324      for (int i = 0; i < nClasses; i++) {
325        classIndices[classValues[i]] = i;
326      }
327
328      int nRows = inputMatrix.GetLength(0);
329      int nColumns = inputMatrix.GetLength(1);
330      for (int row = 0; row < nRows; row++) {
331        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
332      }
333
334      alglib.dfreport rep;
335      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
336
337      rmsError = rep.rmserror;
338      outOfBagRmsError = rep.oobrmserror;
339      relClassificationError = rep.relclserror;
340      outOfBagRelClassificationError = rep.oobrelclserror;
341
342      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
343    }
344
345    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
346      AssertParameters(r, m);
347      AssertInputMatrix(inputMatrix);
348
349      int info = 0;
350      alglib.math.rndobject = new System.Random(seed);
351      var dForest = new alglib.decisionforest();
352      rep = new alglib.dfreport();
353      int nRows = inputMatrix.GetLength(0);
354      int nColumns = inputMatrix.GetLength(1);
355      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
356      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
357
358      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
359      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
360      return dForest;
361    }
362
363    private static void AssertParameters(double r, double m) {
364      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
365      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
366    }
367
368    private static void AssertInputMatrix(double[,] inputMatrix) {
369      if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
370        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
371    }
372
    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        // old files: use the separately stored names; new models: derive from training data
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    // the following Storable properties serialize the raw alglib forest fields
    // only for compatibility-loaded models; new models persist just the training
    // data and parameters and are rebuilt on demand (see RecalculateModel)
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { }; // empty marker => forest is rebuilt lazily on load
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
439  }
440}
Note: See TracBrowser for help on using the repository browser.