source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 16243

Last change on this file since 16243 was 16243, checked in by mkommend, 2 years ago

#2955: Added IsProblemDataCompatible and IsDatasetCompatible to all DataAnalysisModels.

File size: 19.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
32namespace HeuristicLab.Algorithms.DataAnalysis {
33  /// <summary>
34  /// Represents a random forest model for regression and classification
35  /// </summary>
36  [StorableClass]
37  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
38  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
39    // not persisted
40    private alglib.decisionforest randomForest;
41    private alglib.decisionforest RandomForest {
42      get {
43        // recalculate lazily
44        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
45        return randomForest;
46      }
47    }
48
49    public override IEnumerable<string> VariablesUsedForPrediction {
50      get { return originalTrainingData.AllowedInputVariables; }
51    }
52
53    public int NumberOfTrees {
54      get { return nTrees; }
55    }
56
57    // instead of storing the data of the model itself
58    // we instead only store data necessary to recalculate the same model lazily on demand
59    [Storable]
60    private int seed;
61    [Storable]
62    private IDataAnalysisProblemData originalTrainingData;
63    [Storable]
64    private double[] classValues;
65    [Storable]
66    private int nTrees;
67    [Storable]
68    private double r;
69    [Storable]
70    private double m;
71
72    [StorableConstructor]
73    private RandomForestModel(bool deserializing)
74      : base(deserializing) {
75      // for backwards compatibility (loading old solutions)
76      randomForest = new alglib.decisionforest();
77    }
78    private RandomForestModel(RandomForestModel original, Cloner cloner)
79      : base(original, cloner) {
80      randomForest = new alglib.decisionforest();
81      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
82      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
83      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
84      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
85      // we assume that the trees array (double[]) is immutable in alglib
86      randomForest.innerobj.trees = original.randomForest.innerobj.trees;
87
88      // allowedInputVariables is immutable so we don't need to clone
89      allowedInputVariables = original.allowedInputVariables;
90
91      // clone data which is necessary to rebuild the model
92      this.seed = original.seed;
93      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
94      // classvalues is immutable so we don't need to clone
95      this.classValues = original.classValues;
96      this.nTrees = original.nTrees;
97      this.r = original.r;
98      this.m = original.m;
99    }
100
101    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
102    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
103      int seed, IDataAnalysisProblemData originalTrainingData,
104      int nTrees, double r, double m, double[] classValues = null)
105      : base(targetVariable) {
106      this.name = ItemName;
107      this.description = ItemDescription;
108      // the model itself
109      this.randomForest = randomForest;
110      // data which is necessary for recalculation of the model
111      this.seed = seed;
112      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
113      this.classValues = classValues;
114      this.nTrees = nTrees;
115      this.r = r;
116      this.m = m;
117    }
118
119    public override IDeepCloneable Clone(Cloner cloner) {
120      return new RandomForestModel(this, cloner);
121    }
122
123    private void RecalculateModel() {
124      double rmsError, oobRmsError, relClassError, oobRelClassError;
125      var regressionProblemData = originalTrainingData as IRegressionProblemData;
126      var classificationProblemData = originalTrainingData as IClassificationProblemData;
127      if (regressionProblemData != null) {
128        var model = CreateRegressionModel(regressionProblemData,
129                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
130                                              out relClassError, out oobRelClassError);
131        randomForest = model.randomForest;
132      } else if (classificationProblemData != null) {
133        var model = CreateClassificationModel(classificationProblemData,
134                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
135                                              out relClassError, out oobRelClassError);
136        randomForest = model.randomForest;
137      }
138    }
139
140    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
141      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
142      AssertInputMatrix(inputData);
143
144      int n = inputData.GetLength(0);
145      int columns = inputData.GetLength(1);
146      double[] x = new double[columns];
147      double[] y = new double[1];
148
149      for (int row = 0; row < n; row++) {
150        for (int column = 0; column < columns; column++) {
151          x[column] = inputData[row, column];
152        }
153        alglib.dfprocess(RandomForest, x, ref y);
154        yield return y[0];
155      }
156    }
157
158    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
159      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
160      AssertInputMatrix(inputData);
161
162      int n = inputData.GetLength(0);
163      int columns = inputData.GetLength(1);
164      double[] x = new double[columns];
165      double[] ys = new double[this.RandomForest.innerobj.ntrees];
166
167      for (int row = 0; row < n; row++) {
168        for (int column = 0; column < columns; column++) {
169          x[column] = inputData[row, column];
170        }
171        alglib.dforest.dfprocessraw(RandomForest.innerobj, x, ref ys);
172        yield return ys.VariancePop();
173      }
174    }
175
176    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
177      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
178      AssertInputMatrix(inputData);
179
180      int n = inputData.GetLength(0);
181      int columns = inputData.GetLength(1);
182      double[] x = new double[columns];
183      double[] y = new double[RandomForest.innerobj.nclasses];
184
185      for (int row = 0; row < n; row++) {
186        for (int column = 0; column < columns; column++) {
187          x[column] = inputData[row, column];
188        }
189        alglib.dfprocess(randomForest, x, ref y);
190        // find class for with the largest probability value
191        int maxProbClassIndex = 0;
192        double maxProb = y[0];
193        for (int i = 1; i < y.Length; i++) {
194          if (maxProb < y[i]) {
195            maxProb = y[i];
196            maxProbClassIndex = i;
197          }
198        }
199        yield return classValues[maxProbClassIndex];
200      }
201    }
202
203    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
204      var rf = RandomForest;
205      // hoping that the internal representation of alglib is stable
206
207      // TREE FORMAT
208      // W[Offs]      -   size of sub-array (for the tree)
209      //     node info:
210      // W[K+0]       -   variable number        (-1 for leaf mode)
211      // W[K+1]       -   threshold              (class/value for leaf node)
212      // W[K+2]       -   ">=" branch index      (absent for leaf node)
213
214      // skip irrelevant trees
215      int offset = 0;
216      for (int i = 0; i < treeIdx - 1; i++) {
217        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
218      }
219
220      var constSy = new Constant();
221      var varCondSy = new VariableCondition() { IgnoreSlope = true };
222
223      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, constSy, varCondSy);
224
225      var startNode = new StartSymbol().CreateTreeNode();
226      startNode.AddSubtree(node);
227      var root = new ProgramRootSymbol().CreateTreeNode();
228      root.AddSubtree(startNode);
229      return new SymbolicExpressionTree(root);
230    }
231
232    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {
233
234      // alglib source for evaluation of one tree (dfprocessinternal)
235      // offs = 0
236      //
237      // Set pointer to the root
238      //
239      // k = offs + 1;
240      //
241      // //
242      // // Navigate through the tree
243      // //
244      // while (true) {
245      //   if ((double)(df.trees[k]) == (double)(-1)) {
246      //     if (df.nclasses == 1) {
247      //       y[0] = y[0] + df.trees[k + 1];
248      //     } else {
249      //       idx = (int)Math.Round(df.trees[k + 1]);
250      //       y[idx] = y[idx] + 1;
251      //     }
252      //     break;
253      //   }
254      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
255      //     k = k + innernodewidth;
256      //   } else {
257      //     k = offs + (int)Math.Round(df.trees[k + 2]);
258      //   }
259      // }
260
261      if ((double)(trees[k]) == (double)(-1)) {
262        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
263        constNode.Value = trees[k + 1];
264        return constNode;
265      } else {
266        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
267        condNode.VariableName = AllowedInputVariables[(int)Math.Round(trees[k])];
268        condNode.Threshold = trees[k + 1];
269        condNode.Slope = double.PositiveInfinity;
270
271        var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy);
272        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);
273
274        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
275        condNode.AddSubtree(right);
276        return condNode;
277      }
278    }
279
280
281    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
282      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
283    }
284    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
285      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
286    }
287
288    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
289      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
290    }
291
292    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
293      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");
294
295      var regressionProblemData = problemData as IRegressionProblemData;
296      if (regressionProblemData != null)
297        return IsProblemDataCompatible(regressionProblemData, out errorMessage);
298
299      var classificationProblemData = problemData as IClassificationProblemData;
300      if (classificationProblemData != null)
301        return IsProblemDataCompatible(classificationProblemData, out errorMessage);
302
303      throw new ArgumentException("The problem data is not a regression nor a classification problem data. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
304    }
305
306    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
307      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
308      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
309       rmsError: out rmsError, outOfBagRmsError: out outOfBagRmsError, avgRelError: out avgRelError, outOfBagAvgRelError: out outOfBagAvgRelError);
310    }
311
312    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
313      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
314      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
315      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
316
317      alglib.dfreport rep;
318      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
319
320      rmsError = rep.rmserror;
321      outOfBagRmsError = rep.oobrmserror;
322      avgRelError = rep.avgrelerror;
323      outOfBagAvgRelError = rep.oobavgrelerror;
324
325      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
326    }
327
328    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
329      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
330      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
331        out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
332    }
333
334    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
335      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
336
337      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
338      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);
339
340      var classValues = problemData.ClassValues.ToArray();
341      int nClasses = classValues.Length;
342
343      // map original class values to values [0..nClasses-1]
344      var classIndices = new Dictionary<double, double>();
345      for (int i = 0; i < nClasses; i++) {
346        classIndices[classValues[i]] = i;
347      }
348
349      int nRows = inputMatrix.GetLength(0);
350      int nColumns = inputMatrix.GetLength(1);
351      for (int row = 0; row < nRows; row++) {
352        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
353      }
354
355      alglib.dfreport rep;
356      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
357
358      rmsError = rep.rmserror;
359      outOfBagRmsError = rep.oobrmserror;
360      relClassificationError = rep.relclserror;
361      outOfBagRelClassificationError = rep.oobrelclserror;
362
363      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
364    }
365
366    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
367      AssertParameters(r, m);
368      AssertInputMatrix(inputMatrix);
369
370      int info = 0;
371      alglib.math.rndobject = new System.Random(seed);
372      var dForest = new alglib.decisionforest();
373      rep = new alglib.dfreport();
374      int nRows = inputMatrix.GetLength(0);
375      int nColumns = inputMatrix.GetLength(1);
376      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
377      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
378
379      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
380      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
381      return dForest;
382    }
383
384    private static void AssertParameters(double r, double m) {
385      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
386      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
387    }
388
389    private static void AssertInputMatrix(double[,] inputMatrix) {
390      if (inputMatrix.ContainsNanOrInfinity())
391        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
392    }
393
394    #region persistence for backwards compatibility
395    // when the originalTrainingData is null this means the model was loaded from an old file
396    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
397    // in such cases we still store the compete model
398    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }
399
400    private string[] allowedInputVariables;
401    [Storable(Name = "allowedInputVariables")]
402    private string[] AllowedInputVariables {
403      get {
404        if (IsCompatibilityLoaded) return allowedInputVariables;
405        else return originalTrainingData.AllowedInputVariables.ToArray();
406      }
407      set { allowedInputVariables = value; }
408    }
409    [Storable]
410    private int RandomForestBufSize {
411      get {
412        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
413        else return 0;
414      }
415      set {
416        randomForest.innerobj.bufsize = value;
417      }
418    }
419    [Storable]
420    private int RandomForestNClasses {
421      get {
422        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
423        else return 0;
424      }
425      set {
426        randomForest.innerobj.nclasses = value;
427      }
428    }
429    [Storable]
430    private int RandomForestNTrees {
431      get {
432        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
433        else return 0;
434      }
435      set {
436        randomForest.innerobj.ntrees = value;
437      }
438    }
439    [Storable]
440    private int RandomForestNVars {
441      get {
442        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
443        else return 0;
444      }
445      set {
446        randomForest.innerobj.nvars = value;
447      }
448    }
449    [Storable]
450    private double[] RandomForestTrees {
451      get {
452        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
453        else return new double[] { };
454      }
455      set {
456        randomForest.innerobj.trees = value;
457      }
458    }
459    #endregion
460  }
461}
Note: See TracBrowser for help on using the repository browser.