Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 16811

Last change on this file since 16811 was 15788, checked in by gkronber, 7 years ago

#2902 merged r15783 and r15786 from trunk to stable

File size: 18.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification
  /// </summary>
  /// <remarks>
  /// Instead of persisting the alglib forest itself, the model persists the data needed to rebuild it
  /// (seed, training data, training rows and hyper-parameters) and rebuilds the forest lazily on first use.
  /// Models loaded from old files have no training data; for those the complete forest is persisted
  /// (see the backwards-compatibility region at the end of the class).
  /// </remarks>
  [StorableClass]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    // not persisted
    private alglib.decisionforest randomForest;
    /// <summary>
    /// The alglib forest. It is rebuilt from the stored training data when it contains no trees,
    /// which is the case directly after loading a model that was stored in the new format.
    /// </summary>
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    /// <summary>
    /// The input variables of the model.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      // go through the property instead of originalTrainingData: the latter is null
      // for models loaded from old files, which only persisted the variable names
      get { return AllowedInputVariables; }
    }

    /// <summary>
    /// The number of trees in the forest.
    /// </summary>
    public int NumberOfTrees {
      get {
        // nTrees was not persisted by old files; the restored forest knows its tree count
        return IsCompatibilityLoaded ? randomForest.innerobj.ntrees : nTrees;
      }
    }

    // instead of storing the data of the model itself
    // we instead only store data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;
    [Storable]
    private int nTrees;
    [Storable]
    private double r;
    [Storable]
    private double m;
    // rows the forest was trained on; null means originalTrainingData.TrainingIndices
    // (this is also the value for models stored before this field existed)
    [Storable]
    private int[] trainingIndices;

    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions)
      randomForest = new alglib.decisionforest();
    }
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classValues and trainingIndices are never modified after construction so we don't need to clone them
      this.classValues = original.classValues;
      this.trainingIndices = original.trainingIndices;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData, int[] trainingRows,
      int nTrees, double r, double m, double[] classValues = null)
      : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      // only remember the rows when they differ from the default training partition
      this.trainingIndices = trainingRows.SequenceEqual(originalTrainingData.TrainingIndices) ? null : trainingRows;
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

    /// <summary>
    /// The rows used to train the forest: the explicitly remembered rows, or else the training
    /// partition of the stored problem data. Must only be used when originalTrainingData is not null.
    /// </summary>
    private IEnumerable<int> TrainingRows {
      get {
        if (trainingIndices != null) return trainingIndices;
        return originalTrainingData.TrainingIndices;
      }
    }

    /// <summary>
    /// Rebuilds the alglib forest from the stored training data, rows, seed and hyper-parameters.
    /// Does nothing when no training data is available (models loaded from old files).
    /// </summary>
    private void RecalculateModel() {
      alglib.dfreport rep; // error statistics are not needed when rebuilding
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        randomForest = CreateRegressionForest(regressionProblemData, TrainingRows, nTrees, r, m, seed, out rep);
      } else if (classificationProblemData != null) {
        randomForest = CreateClassificationForest(classificationProblemData, classificationProblemData.ClassValues.ToArray(),
                                                  TrainingRows, nTrees, r, m, seed, out rep);
      }
    }

    /// <summary>
    /// Evaluates the forest for the given rows and yields the first output of alglib's dfprocess per row.
    /// </summary>
    /// <exception cref="NotSupportedException">The input data contains NaN or infinity.</exception>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      var rf = RandomForest; // resolve (and if necessary rebuild) the forest only once
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(rf, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Yields, per row, the population variance of the raw outputs returned by alglib's dfprocessraw.
    /// The output buffer is sized to the number of trees, so presumably there is one output per tree.
    /// </summary>
    /// <exception cref="NotSupportedException">The input data contains NaN or infinity.</exception>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      var rf = RandomForest; // resolve (and if necessary rebuild) the forest only once
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] ys = new double[rf.innerobj.ntrees];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dforest.dfprocessraw(rf.innerobj, x, ref ys);
        yield return ys.VariancePop();
      }
    }

    /// <summary>
    /// Yields, per row, the original class value whose alglib output is largest.
    /// On ties the class with the lowest index wins.
    /// </summary>
    /// <exception cref="NotSupportedException">The input data contains NaN or infinity.</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      var rf = RandomForest; // resolve (and if necessary rebuild) the forest only once
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[rf.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(rf, x, ref y);
        // find class for with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    /// <summary>
    /// Converts a single tree of the forest into a symbolic expression tree of variable conditions
    /// (inner nodes) and constants (leaves; for classification forests the leaf value is a class index).
    /// </summary>
    /// <param name="treeIdx">
    /// Selects the tree. NOTE(review): trees are skipped with <c>i &lt; treeIdx - 1</c>, so this behaves
    /// as a 1-based index (0 and 1 both select the first tree). Confirm against callers before changing.
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException">treeIdx is larger than the number of trees.</exception>
    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
      var rf = RandomForest;
      // hoping that the internal representation of alglib is stable

      // skipping more trees than exist would index past the end of the trees buffer
      if (treeIdx > rf.innerobj.ntrees)
        throw new ArgumentOutOfRangeException("treeIdx", treeIdx, "The tree index exceeds the number of trees in the forest.");

      // TREE FORMAT
      // W[Offs]      -   size of sub-array (for the tree)
      //     node info:
      // W[K+0]       -   variable number        (-1 for leaf mode)
      // W[K+1]       -   threshold              (class/value for leaf node)
      // W[K+2]       -   ">=" branch index      (absent for leaf node)

      // skip irrelevant trees
      int offset = 0;
      for (int i = 0; i < treeIdx - 1; i++) {
        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
      }

      var constSy = new Constant();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };
      // computed once here instead of once per inner node
      var variableNames = AllowedInputVariables;

      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, variableNames, constSy, varCondSy);

      var startNode = new StartSymbol().CreateTreeNode();
      startNode.AddSubtree(node);
      var root = new ProgramRootSymbol().CreateTreeNode();
      root.AddSubtree(startNode);
      return new SymbolicExpressionTree(root);
    }

    /// <summary>
    /// Recursively converts the alglib tree node at position k (tree starting at offset) into a tree node.
    /// </summary>
    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, string[] variableNames,
      Constant constSy, VariableCondition varCondSy) {

      // alglib source for evaluation of one tree (dfprocessinternal)
      // offs = 0
      //
      // Set pointer to the root
      //
      // k = offs + 1;
      //
      // //
      // // Navigate through the tree
      // //
      // while (true) {
      //   if ((double)(df.trees[k]) == (double)(-1)) {
      //     if (df.nclasses == 1) {
      //       y[0] = y[0] + df.trees[k + 1];
      //     } else {
      //       idx = (int)Math.Round(df.trees[k + 1]);
      //       y[idx] = y[idx] + 1;
      //     }
      //     break;
      //   }
      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
      //     k = k + innernodewidth;
      //   } else {
      //     k = offs + (int)Math.Round(df.trees[k + 2]);
      //   }
      // }

      if ((double)(trees[k]) == (double)(-1)) {
        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
        constNode.Value = trees[k + 1];
        return constNode;
      } else {
        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        condNode.VariableName = variableNames[(int)Math.Round(trees[k])];
        condNode.Threshold = trees[k + 1];
        condNode.Slope = double.PositiveInfinity;

        var left = CreateRegressionTreeRec(trees, offset, k + 3, variableNames, constSy, varCondSy);
        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), variableNames, constSy, varCondSy);

        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
        condNode.AddSubtree(right);
        return condNode;
      }
    }


    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    /// <summary>
    /// Trains a regression forest on the training partition of <paramref name="problemData"/>.
    /// </summary>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
       rmsError: out rmsError, outOfBagRmsError: out outOfBagRmsError, avgRelError: out avgRelError, outOfBagAvgRelError: out outOfBagAvgRelError);
    }

    /// <summary>
    /// Trains a regression forest on the given rows of <paramref name="problemData"/>.
    /// The rows are remembered so that the lazily rebuilt forest uses the same data.
    /// </summary>
    /// <exception cref="ArgumentException">r or m is not in (0, 1], or alglib reports a build error.</exception>
    /// <exception cref="NotSupportedException">The training data contains NaN or infinity.</exception>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      // materialize once: the rows are used for training and remembered for recalculation
      var rows = trainingIndices.ToArray();

      alglib.dfreport rep;
      var dForest = CreateRegressionForest(problemData, rows, nTrees, r, m, seed, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, rows, nTrees, r, m);
    }

    /// <summary>
    /// Trains a classification forest on the training partition of <paramref name="problemData"/>.
    /// </summary>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    }

    /// <summary>
    /// Trains a classification forest on the given rows of <paramref name="problemData"/>.
    /// The rows are remembered so that the lazily rebuilt forest uses the same data.
    /// </summary>
    /// <exception cref="ArgumentException">r or m is not in (0, 1], or alglib reports a build error.</exception>
    /// <exception cref="NotSupportedException">The training data contains NaN or infinity.</exception>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      // materialize once: the rows are used for training and remembered for recalculation
      var rows = trainingIndices.ToArray();
      var classValues = problemData.ClassValues.ToArray();

      alglib.dfreport rep;
      var dForest = CreateClassificationForest(problemData, classValues, rows, nTrees, r, m, seed, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, rows, nTrees, r, m, classValues);
    }

    /// <summary>
    /// Builds the alglib regression forest (one output) from the allowed inputs and the target of the given rows.
    /// </summary>
    private static alglib.decisionforest CreateRegressionForest(IRegressionProblemData problemData, IEnumerable<int> rows,
      int nTrees, double r, double m, int seed, out alglib.dfreport rep) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, rows);
      return CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
    }

    /// <summary>
    /// Builds the alglib classification forest from the allowed inputs and the target of the given rows.
    /// The target column is recoded from the original class values to their indices in classValues.
    /// </summary>
    /// <exception cref="KeyNotFoundException">A target value in the rows is not contained in classValues.</exception>
    private static alglib.decisionforest CreateClassificationForest(IClassificationProblemData problemData, double[] classValues, IEnumerable<int> rows,
      int nTrees, double r, double m, int seed, out alglib.dfreport rep) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, rows);

      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
      }

      return CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
    }

    /// <summary>
    /// Runs alglib's forest construction on inputMatrix, whose last column is the target.
    /// r is the fraction of rows sampled per tree and m the fraction of features considered per split;
    /// both are rounded and clamped to at least 1.
    /// </summary>
    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      AssertParameters(r, m);
      AssertInputMatrix(inputMatrix);

      int info = 0;
      // alglib draws its random numbers from this shared object; seeding it makes the build reproducible
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }

    private static void AssertParameters(double r, double m) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    private static void AssertInputMatrix(double[,] inputMatrix) {
      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    // new-format models store an empty array here, which makes the RandomForest property rebuild the forest after loading
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.