Free cookie consent management tool by TermsFeed Policy Generator

source: branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 14853

Last change on this file since 14853 was 14368, checked in by gkronber, 8 years ago

#2690: recalculate RF if required before extracting a specific tree

File size: 18.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
32namespace HeuristicLab.Algorithms.DataAnalysis {
33  /// <summary>
34  /// Represents a random forest model for regression and classification
35  /// </summary>
36  [StorableClass]
37  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
38  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
// not persisted: the ALGLIB forest is rebuilt on demand (see RecalculateModel)
private alglib.decisionforest randomForest;

/// <summary>
/// Gets the ALGLIB forest. When it holds no tree data yet, RecalculateModel() is invoked first.
/// </summary>
private alglib.decisionforest RandomForest {
  get {
    double[] treeData = randomForest.innerobj.trees;
    bool isBuilt = treeData != null && treeData.Length > 0;
    if (!isBuilt) RecalculateModel();
    return randomForest;
  }
}
48
/// <summary>
/// Gets the names of the input variables the model uses for prediction.
/// </summary>
/// <remarks>
/// Reads the names through AllowedInputVariables instead of originalTrainingData directly,
/// because originalTrainingData is null for models loaded from old files and a direct
/// access would then fail with a NullReferenceException.
/// </remarks>
public override IEnumerable<string> VariablesUsedForPrediction {
  get { return AllowedInputVariables; }
}
52
/// <summary>
/// Gets the stored number of trees (the nTrees value the model was created with).
/// </summary>
public int NumberOfTrees {
  get {
    return this.nTrees;
  }
}
56
// The forest itself is not stored. Only the values needed to rebuild
// exactly the same forest lazily on demand are persisted.

// seed for the System.Random instance that is installed before the forest is built
[Storable] private int seed;
// problem data the model was trained on; null for models loaded from old files
[Storable] private IDataAnalysisProblemData originalTrainingData;
// original class values, indexed by the class index used inside the forest (null for regression)
[Storable] private double[] classValues;
// number of trees to build
[Storable] private int nTrees;
// fraction of the training rows used as sample size
[Storable] private double r;
// fraction of the input variables used as feature count
[Storable] private double m;
71
/// <summary>
/// Storable constructor (marked with [StorableConstructor]).
/// </summary>
[StorableConstructor]
private RandomForestModel(bool deserializing) : base(deserializing) {
  // for backwards compatibility (loading old solutions): the compatibility setters
  // (RandomForestBufSize, RandomForestTrees, ...) write into randomForest.innerobj,
  // so an empty forest has to exist up front
  this.randomForest = new alglib.decisionforest();
}
/// <summary>
/// Cloning constructor: copies the forest description and the data needed to rebuild the model.
/// </summary>
private RandomForestModel(RandomForestModel original, Cloner cloner)
  : base(original, cloner) {
  // data which is necessary to rebuild the model
  this.seed = original.seed;
  this.nTrees = original.nTrees;
  this.r = original.r;
  this.m = original.m;
  // classValues is treated as immutable, so the array is shared instead of cloned
  this.classValues = original.classValues;
  this.originalTrainingData = cloner.Clone(original.originalTrainingData);

  // allowedInputVariables is treated as immutable as well
  this.allowedInputVariables = original.allowedInputVariables;

  // the field is read directly (not the lazy property) so cloning never triggers a rebuild
  var source = original.randomForest.innerobj;
  this.randomForest = new alglib.decisionforest();
  var target = this.randomForest.innerobj;
  target.bufsize = source.bufsize;
  target.nclasses = source.nclasses;
  target.ntrees = source.ntrees;
  target.nvars = source.nvars;
  // assumes the trees array (double[]) is never mutated by alglib, so it is shared
  target.trees = source.trees;
}
100
/// <summary>
/// Creates a model around an already built forest. Instances can only be created through the
/// static factory methods CreateRegressionModel and CreateClassificationModel.
/// </summary>
private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
  int seed, IDataAnalysisProblemData originalTrainingData,
  int nTrees, double r, double m, double[] classValues = null)
  : base(targetVariable) {
  this.name = ItemName;
  this.description = ItemDescription;

  // settings and data required to recalculate the model later on
  this.nTrees = nTrees;
  this.r = r;
  this.m = m;
  this.seed = seed;
  this.classValues = classValues;
  // a private copy of the problem data is kept
  var trainingDataCopy = (IDataAnalysisProblemData)originalTrainingData.Clone();
  this.originalTrainingData = trainingDataCopy;

  // the model itself
  this.randomForest = randomForest;
}
118
/// <summary>
/// Creates a copy of this model by means of the cloning constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new RandomForestModel(this, cloner);
  return copy;
}
122
/// <summary>
/// Rebuilds the forest from the stored training data and settings (nTrees, r, m, seed).
/// Nothing happens when the training data is neither regression nor classification problem data.
/// </summary>
private void RecalculateModel() {
  // the factory methods demand these out arguments; the values are not used here
  double rmsError, oobRmsError, relClassError, oobRelClassError;

  var regressionData = originalTrainingData as IRegressionProblemData;
  if (regressionData != null) {
    var rebuilt = CreateRegressionModel(regressionData, nTrees, r, m, seed,
      out rmsError, out oobRmsError, out relClassError, out oobRelClassError);
    randomForest = rebuilt.randomForest;
    return;
  }

  var classificationData = originalTrainingData as IClassificationProblemData;
  if (classificationData != null) {
    var rebuilt = CreateClassificationModel(classificationData, nTrees, r, m, seed,
      out rmsError, out oobRmsError, out relClassError, out oobRelClassError);
    randomForest = rebuilt.randomForest;
  }
}
139
/// <summary>
/// Lazily yields the forest output (first element of the dfprocess result) for each of the given rows.
/// </summary>
/// <exception cref="NotSupportedException">The input values contain NaN or infinity.</exception>
public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
  double[,] inputs = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
  AssertInputMatrix(inputs);

  int rowCount = inputs.GetLength(0);
  int columnCount = inputs.GetLength(1);
  var features = new double[columnCount];
  var output = new double[1];

  for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
    // copy the current row into the one-dimensional buffer dfprocess works on
    for (int colIdx = 0; colIdx < columnCount; colIdx++)
      features[colIdx] = inputs[rowIdx, colIdx];

    alglib.dfprocess(RandomForest, features, ref output);
    yield return output[0];
  }
}
157
/// <summary>
/// Lazily yields, for each of the given rows, VariancePop() of the values dfprocessraw writes
/// into a buffer with one slot per tree (presumably the individual tree outputs).
/// </summary>
/// <exception cref="NotSupportedException">The input values contain NaN or infinity.</exception>
public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
  double[,] inputs = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
  AssertInputMatrix(inputs);

  int rowCount = inputs.GetLength(0);
  int columnCount = inputs.GetLength(1);
  var features = new double[columnCount];
  var treeOutputs = new double[this.RandomForest.innerobj.ntrees];

  for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
    // copy the current row into the one-dimensional buffer dfprocessraw works on
    for (int colIdx = 0; colIdx < columnCount; colIdx++)
      features[colIdx] = inputs[rowIdx, colIdx];

    alglib.dforest.dfprocessraw(RandomForest.innerobj, features, ref treeOutputs);
    yield return treeOutputs.VariancePop();
  }
}
175
/// <summary>
/// Lazily yields the predicted class value for each of the given rows: the class whose
/// entry in the dfprocess output is largest (the first one wins on ties).
/// </summary>
/// <exception cref="NotSupportedException">The input values contain NaN or infinity.</exception>
public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
  double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
  AssertInputMatrix(inputData);

  int n = inputData.GetLength(0);
  int columns = inputData.GetLength(1);
  double[] x = new double[columns];
  double[] y = new double[RandomForest.innerobj.nclasses];

  for (int row = 0; row < n; row++) {
    for (int column = 0; column < columns; column++) {
      x[column] = inputData[row, column];
    }
    // always go through the lazy property (as the other estimation methods do)
    // instead of relying on an earlier statement having triggered the rebuild
    alglib.dfprocess(RandomForest, x, ref y);
    // find the class with the largest probability value
    int maxProbClassIndex = 0;
    double maxProb = y[0];
    for (int i = 1; i < y.Length; i++) {
      if (maxProb < y[i]) {
        maxProb = y[i];
        maxProbClassIndex = i;
      }
    }
    // translate the internal class index back to the original class value
    yield return classValues[maxProbClassIndex];
  }
}
202
/// <summary>
/// Converts one tree of the forest into a symbolic expression tree
/// (program root -> start -> nested variable conditions with constant leaves).
/// </summary>
/// <remarks>
/// Relies on the internal array representation of alglib staying stable:
///   W[Offs]  - size of the sub-array of the tree
///   node info:
///   W[K+0]   - variable number   (-1 for a leaf node)
///   W[K+1]   - threshold         (class/value for a leaf node)
///   W[K+2]   - ">=" branch index (absent for a leaf node)
/// NOTE(review): treeIdx - 1 trees are skipped, so treeIdx 0 and 1 both select the first tree.
/// This looks like a 1-based index - confirm against the callers.
/// </remarks>
public ISymbolicExpressionTree ExtractTree(int treeIdx) {
  var rf = RandomForest;

  // skip irrelevant trees: the first entry of each tree is the size of its sub-array
  int treeStart = 0;
  int treesToSkip = treeIdx - 1;
  for (int skipped = 0; skipped < treesToSkip; skipped++) {
    treeStart += (int)Math.Round(rf.innerobj.trees[treeStart]);
  }

  var constSy = new Constant();
  var varCondSy = new VariableCondition() { IgnoreSlope = true };

  // the root node of a tree directly follows its size entry
  var treeNode = CreateRegressionTreeRec(rf.innerobj.trees, treeStart, treeStart + 1, constSy, varCondSy);

  var startNode = new StartSymbol().CreateTreeNode();
  startNode.AddSubtree(treeNode);
  var rootNode = new ProgramRootSymbol().CreateTreeNode();
  rootNode.AddSubtree(startNode);
  return new SymbolicExpressionTree(rootNode);
}
231
/// <summary>
/// Recursively converts the alglib tree node stored at index k into a symbolic expression tree node.
/// </summary>
/// <remarks>
/// Mirrors the navigation of alglib's dfprocessinternal for a single tree:
///   k = offs + 1;                                     // root
///   while (true) {
///     if (trees[k] == -1) { /* leaf: trees[k + 1] is the value (or class index) */ break; }
///     if (x[(int)Math.Round(trees[k])] &lt; trees[k + 1]) k = k + innernodewidth;
///     else k = offs + (int)Math.Round(trees[k + 2]);
///   }
/// Here the inner node width is 3, i.e. the "less than" child directly follows its parent.
/// </remarks>
private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {
  bool isLeaf = trees[k] == -1.0;
  if (isLeaf) {
    var leaf = (ConstantTreeNode)constSy.CreateTreeNode();
    leaf.Value = trees[k + 1];
    return leaf;
  }

  var split = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
  int variableIndex = (int)Math.Round(trees[k]);
  split.VariableName = AllowedInputVariables[variableIndex];
  split.Threshold = trees[k + 1];
  split.Slope = double.PositiveInfinity;

  int lessThanChild = k + 3;
  var lessThanBranch = CreateRegressionTreeRec(trees, offset, lessThanChild, constSy, varCondSy);
  int otherChild = offset + (int)Math.Round(trees[k + 2]);
  var otherBranch = CreateRegressionTreeRec(trees, offset, otherChild, constSy, varCondSy);

  // not 100% correct: the interpreter uses if (x <= thres) left() else right(),
  // whereas the random forest uses if (x < thres) left() else right()
  split.AddSubtree(lessThanBranch);
  split.AddSubtree(otherBranch);
  return split;
}
279
280
/// <summary>
/// Wraps this model and a new RegressionProblemData built from the given problem data
/// into a RandomForestRegressionSolution.
/// </summary>
public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
  var solutionData = new RegressionProblemData(problemData);
  return new RandomForestRegressionSolution(this, solutionData);
}
/// <summary>
/// Wraps this model and a new ClassificationProblemData built from the given problem data
/// into a RandomForestClassificationSolution.
/// </summary>
public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
  var solutionData = new ClassificationProblemData(problemData);
  return new RandomForestClassificationSolution(this, solutionData);
}
287
/// <summary>
/// Builds a random forest regression model from the training partition of the given problem data.
/// </summary>
/// <param name="problemData">Regression problem data; its TrainingIndices select the rows used.</param>
/// <param name="nTrees">Number of trees to build.</param>
/// <param name="r">Fraction of the rows used as sample size; must be in (0, 1].</param>
/// <param name="m">Fraction of the input variables used as feature count; must be in (0, 1].</param>
/// <param name="seed">Seed of the random number generator used while building.</param>
/// <param name="rmsError">RMS error reported by alglib.</param>
/// <param name="outOfBagRmsError">Out-of-bag RMS error reported by alglib.</param>
/// <param name="avgRelError">Average relative error reported by alglib.</param>
/// <param name="outOfBagAvgRelError">Out-of-bag average relative error reported by alglib.</param>
/// <returns>The newly built model.</returns>
public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
  // the out arguments must be forwarded in exactly the order declared by the index-based
  // overload; any other order hands the caller the wrong error measures
  return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
    out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
}
292
/// <summary>
/// Builds a random forest regression model from the given rows of the problem data
/// and reports the error measures calculated by alglib.
/// </summary>
public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
  // the target variable becomes the last column of the matrix
  var columnNames = problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable });
  double[,] trainingMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, columnNames, trainingIndices);

  // a single "class" makes alglib build a regression forest
  alglib.dfreport report;
  var forest = CreateRandomForestModel(seed, trainingMatrix, nTrees, r, m, 1, out report);

  rmsError = report.rmserror;
  outOfBagRmsError = report.oobrmserror;
  avgRelError = report.avgrelerror;
  outOfBagAvgRelError = report.oobavgrelerror;

  return new RandomForestModel(problemData.TargetVariable, forest, seed, problemData, nTrees, r, m);
}
308
/// <summary>
/// Builds a random forest classification model from the training partition of the given problem data.
/// </summary>
public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
  IEnumerable<int> trainingRows = problemData.TrainingIndices;
  return CreateClassificationModel(problemData, trainingRows, nTrees, r, m, seed,
    out rmsError, out outOfBagRmsError,
    out relClassificationError, out outOfBagRelClassificationError);
}
313
/// <summary>
/// Builds a random forest classification model from the given rows of the problem data
/// and reports the error measures calculated by alglib.
/// </summary>
public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
  // the target variable becomes the last column of the matrix
  var columnNames = problemData.AllowedInputVariables.Concat(new[] { problemData.TargetVariable });
  double[,] trainingMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, columnNames, trainingIndices);

  var classValues = problemData.ClassValues.ToArray();
  int classCount = classValues.Length;

  // map original class values to values [0..nClasses-1]
  var indexOfClass = new Dictionary<double, double>();
  for (int classIdx = 0; classIdx < classCount; classIdx++)
    indexOfClass[classValues[classIdx]] = classIdx;

  // replace every target value by its class index
  int rowCount = trainingMatrix.GetLength(0);
  int targetColumn = trainingMatrix.GetLength(1) - 1;
  for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
    double originalClass = trainingMatrix[rowIdx, targetColumn];
    trainingMatrix[rowIdx, targetColumn] = indexOfClass[originalClass];
  }

  alglib.dfreport report;
  var forest = CreateRandomForestModel(seed, trainingMatrix, nTrees, r, m, classCount, out report);

  rmsError = report.rmserror;
  outOfBagRmsError = report.oobrmserror;
  relClassificationError = report.relclserror;
  outOfBagRelClassificationError = report.oobrelclserror;

  return new RandomForestModel(problemData.TargetVariable, forest, seed, problemData, nTrees, r, m, classValues);
}
345
/// <summary>
/// Validates the arguments and builds an alglib decision forest from a matrix
/// whose last column holds the target and whose remaining columns hold the inputs.
/// </summary>
/// <exception cref="ArgumentException">r or m is out of range, or alglib reports a failure.</exception>
/// <exception cref="NotSupportedException">The matrix contains NaN or infinity.</exception>
private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
  AssertParameters(r, m);
  AssertInputMatrix(inputMatrix);

  int rowCount = inputMatrix.GetLength(0);
  int inputCount = inputMatrix.GetLength(1) - 1;
  // both sizes are fractions of the available rows / inputs, but never less than one
  int sampleSize = Math.Max((int)Math.Round(r * rowCount), 1);
  int featureCount = Math.Max((int)Math.Round(m * inputCount), 1);
  int flags = alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs;

  // alglib draws its random numbers from this static generator
  alglib.math.rndobject = new System.Random(seed);
  var forest = new alglib.decisionforest();
  rep = new alglib.dfreport();
  int info = 0;
  alglib.dforest.dfbuildinternal(inputMatrix, rowCount, inputCount, nClasses, nTrees, sampleSize, featureCount, flags, ref info, forest.innerobj, rep.innerobj);

  // any info value other than 1 is treated as failure
  if (info == 1) return forest;
  throw new ArgumentException("Error in calculation of random forest model");
}
363
/// <summary>
/// Checks that both r and m lie in the interval (0, 1].
/// </summary>
/// <exception cref="ArgumentException">r or m is NaN or lies outside of (0, 1].</exception>
private static void AssertParameters(double r, double m) {
  // written as a negated in-range test: every comparison with NaN is false,
  // so "r <= 0 || r > 1" would silently accept NaN
  if (!(r > 0 && r <= 1)) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
  if (!(m > 0 && m <= 1)) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
}
368
/// <summary>
/// Checks that the matrix contains neither NaN nor infinite values.
/// </summary>
/// <exception cref="NotSupportedException">At least one element is NaN or infinite.</exception>
private static void AssertInputMatrix(double[,] inputMatrix) {
  bool hasInvalidValue = inputMatrix.Cast<double>().Any(value => double.IsNaN(value) || double.IsInfinity(value));
  if (!hasInvalidValue) return;
  throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
}
373
374    #region persistence for backwards compatibility
// A missing originalTrainingData means that the model was loaded from an old file.
// The new persistence mechanism cannot be used then, because the original data
// is not available anymore; in such cases the complete model is still stored.
private bool IsCompatibilityLoaded {
  get {
    return originalTrainingData == null;
  }
}
379
// backing field; only read for models loaded from old files
private string[] allowedInputVariables;

/// <summary>
/// Gets the input variable names: the stored array for models loaded from old files,
/// otherwise a new array created from the allowed input variables of the training data.
/// </summary>
[Storable(Name = "allowedInputVariables")]
private string[] AllowedInputVariables {
  get {
    return IsCompatibilityLoaded
      ? allowedInputVariables
      : originalTrainingData.AllowedInputVariables.ToArray();
  }
  set { allowedInputVariables = value; }
}
// Persists the forest's bufsize for models loaded from old files; 0 is stored otherwise.
[Storable]
private int RandomForestBufSize {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.bufsize : 0; }
  set { randomForest.innerobj.bufsize = value; }
}
// Persists the forest's nclasses for models loaded from old files; 0 is stored otherwise.
[Storable]
private int RandomForestNClasses {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.nclasses : 0; }
  set { randomForest.innerobj.nclasses = value; }
}
// Persists the forest's ntrees for models loaded from old files; 0 is stored otherwise.
[Storable]
private int RandomForestNTrees {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.ntrees : 0; }
  set { randomForest.innerobj.ntrees = value; }
}
// Persists the forest's nvars for models loaded from old files; 0 is stored otherwise.
[Storable]
private int RandomForestNVars {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.nvars : 0; }
  set { randomForest.innerobj.nvars = value; }
}
// Persists the forest's tree array for models loaded from old files; an empty array is
// stored otherwise, which makes the RandomForest property rebuild the forest after loading.
[Storable]
private double[] RandomForestTrees {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.trees : new double[0]; }
  set { randomForest.innerobj.trees = value; }
}
439    #endregion
440  }
441}
Note: See TracBrowser for help on using the repository browser.