source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 16763

Last change on this file since 16763 was 16763, checked in by gkronber, 6 months ago

#2955: changed error strings when trying to load an incompatible dataset for a model.

File size: 19.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HEAL.Attic;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification
  /// </summary>
  [StorableType("A4F688CD-1F42-4103-8449-7DE52AEF6C69")]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    // not persisted
    private alglib.decisionforest randomForest;

    /// <summary>
    /// The alglib forest. If it contains no trees (e.g. after loading a model that only stores
    /// the data necessary for recalculation) it is rebuilt on first access.
    /// </summary>
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    /// <summary>
    /// The input variables used by the model.
    /// Uses <see cref="AllowedInputVariables"/> because <c>originalTrainingData</c> is null for
    /// models loaded from old files.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return AllowedInputVariables; }
    }

    /// <summary>The number of trees requested when the model was created.</summary>
    public int NumberOfTrees {
      get { return nTrees; }
    }

    // instead of storing the data of the model itself
    // we instead only store data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;
    [Storable]
    private int nTrees;
    [Storable]
    private double r;
    [Storable]
    private double m;
    // Rows used for training, kept only when they differ from originalTrainingData.TrainingIndices.
    // null means the training partition of originalTrainingData was used. This is also the value
    // for files written before this field existed, which keeps their recalculation unchanged.
    [Storable]
    private int[] trainingIndices;

    [StorableConstructor]
    private RandomForestModel(StorableConstructorFlag _) : base(_) {
      // for backwards compatibility (loading old solutions)
      randomForest = new alglib.decisionforest();
    }
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classvalues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
      // trainingIndices is never modified after construction so we don't need to clone
      this.trainingIndices = original.trainingIndices;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null, int[] trainingIndices = null)
      : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
      // only remember the rows explicitly when they are not the default training partition,
      // so that the common case does not bloat persisted files
      if (trainingIndices != null && !trainingIndices.SequenceEqual(originalTrainingData.TrainingIndices))
        this.trainingIndices = trainingIndices;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

    /// <summary>
    /// Rebuilds the alglib forest from the stored training data, rows, seed and parameters.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// When the stored training data is neither regression nor classification problem data
    /// (e.g. it is missing), so the model cannot be rebuilt.
    /// </exception>
    private void RecalculateModel() {
      double rmsError, oobRmsError, relClassError, oobRelClassError;
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        IEnumerable<int> rows = trainingIndices ?? regressionProblemData.TrainingIndices;
        var model = CreateRegressionModel(regressionProblemData, rows,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      } else if (classificationProblemData != null) {
        IEnumerable<int> rows = trainingIndices ?? classificationProblemData.TrainingIndices;
        var model = CreateClassificationModel(classificationProblemData, rows,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      } else {
        // previously this silently left an empty forest which failed later with an obscure error
        throw new InvalidOperationException("The random forest model cannot be recalculated because no regression or classification training data is available.");
      }
    }

    /// <summary>
    /// Returns the forest output (first output component of alglib's dfprocess) for each of the given rows.
    /// </summary>
    /// <exception cref="NotSupportedException">If the inputs contain NaN or infinity values.</exception>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      // resolve the (lazily recalculated) forest once instead of once per row
      var rf = RandomForest;
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(rf, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Returns, for each row, the population variance of the raw outputs of the individual trees.
    /// </summary>
    /// <exception cref="NotSupportedException">If the inputs contain NaN or infinity values.</exception>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      var rf = RandomForest;
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] ys = new double[rf.innerobj.ntrees];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dforest.dfprocessraw(rf.innerobj, x, ref ys);
        yield return ys.VariancePop();
      }
    }

    /// <summary>
    /// Returns, for each row, the original class value whose output component from the forest is largest
    /// (the first such class on ties).
    /// </summary>
    /// <exception cref="NotSupportedException">If the inputs contain NaN or infinity values.</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      // use the lazy property (not the raw field) consistently with the other estimators
      var rf = RandomForest;
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[rf.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(rf, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    /// <summary>
    /// Converts one tree of the forest into a symbolic expression tree made of variable-condition and constant nodes.
    /// For classification forests the leaf constants are alglib's internal class indices, not the original
    /// class values (see the alglib evaluation code quoted in CreateRegressionTreeRec).
    /// </summary>
    /// <param name="treeIdx">Zero-based index of the tree, from 0 to the number of trees in the forest minus one.</param>
    /// <exception cref="ArgumentOutOfRangeException">If <paramref name="treeIdx"/> does not denote a tree of the forest.</exception>
    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
      var rf = RandomForest;
      if (treeIdx < 0 || treeIdx >= rf.innerobj.ntrees)
        throw new ArgumentOutOfRangeException("treeIdx", "The tree index must be between 0 and " + (rf.innerobj.ntrees - 1) + ".");
      // hoping that the internal representation of alglib is stable

      // TREE FORMAT
      // W[Offs]      -   size of sub-array (for the tree)
      //     node info:
      // W[K+0]       -   variable number        (-1 for leaf node)
      // W[K+1]       -   threshold              (class/value for leaf node)
      // W[K+2]       -   ">=" branch index      (absent for leaf node)

      // skip the treeIdx trees in front of the requested one (alglib's dfprocess does the same to reach tree i).
      // The previous bound (treeIdx - 1) mapped indices 0 and 1 to the same tree and made the last tree unreachable.
      int offset = 0;
      for (int i = 0; i < treeIdx; i++) {
        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
      }

      var constSy = new Constant();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };

      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, constSy, varCondSy);

      var startNode = new StartSymbol().CreateTreeNode();
      startNode.AddSubtree(node);
      var root = new ProgramRootSymbol().CreateTreeNode();
      root.AddSubtree(startNode);
      return new SymbolicExpressionTree(root);
    }

    /// <summary>
    /// Recursively converts the alglib tree node starting at position <paramref name="k"/> of
    /// <paramref name="trees"/> (tree starting at <paramref name="offset"/>) into a symbolic expression tree node.
    /// </summary>
    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {

      // alglib source for evaluation of one tree (dfprocessinternal)
      // offs = 0
      //
      // Set pointer to the root
      //
      // k = offs + 1;
      //
      // //
      // // Navigate through the tree
      // //
      // while (true) {
      //   if ((double)(df.trees[k]) == (double)(-1)) {
      //     if (df.nclasses == 1) {
      //       y[0] = y[0] + df.trees[k + 1];
      //     } else {
      //       idx = (int)Math.Round(df.trees[k + 1]);
      //       y[idx] = y[idx] + 1;
      //     }
      //     break;
      //   }
      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
      //     k = k + innernodewidth;
      //   } else {
      //     k = offs + (int)Math.Round(df.trees[k + 2]);
      //   }
      // }

      if ((double)(trees[k]) == (double)(-1)) {
        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
        constNode.Value = trees[k + 1];
        return constNode;
      } else {
        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        condNode.VariableName = AllowedInputVariables[(int)Math.Round(trees[k])];
        condNode.Threshold = trees[k + 1];
        condNode.Slope = double.PositiveInfinity;

        var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy);
        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);

        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
        condNode.AddSubtree(right);
        return condNode;
      }
    }


    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Dispatches the compatibility check to the regression or classification overload.
    /// </summary>
    /// <exception cref="ArgumentNullException">If <paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentException">If the problem data is neither regression nor classification data.</exception>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");

      var regressionProblemData = problemData as IRegressionProblemData;
      if (regressionProblemData != null)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not compatible with this random forest. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
    }

    /// <summary>Trains a regression forest on the training partition of <paramref name="problemData"/>.</summary>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
       rmsError: out rmsError, outOfBagRmsError: out outOfBagRmsError, avgRelError: out avgRelError, outOfBagAvgRelError: out outOfBagAvgRelError);
    }

    /// <summary>
    /// Trains a regression forest on the given rows of <paramref name="problemData"/>.
    /// The rows are remembered so that lazy recalculation reproduces the same model.
    /// </summary>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      // materialize once: the rows are used for training and remembered for recalculation
      int[] rows = trainingIndices.ToArray();
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, rows);

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, trainingIndices: rows);
    }

    /// <summary>Trains a classification forest on the training partition of <paramref name="problemData"/>.</summary>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    }

    /// <summary>
    /// Trains a classification forest on the given rows of <paramref name="problemData"/>.
    /// Class values are mapped to indices 0..nClasses-1 for alglib. The rows are remembered so that
    /// lazy recalculation reproduces the same model.
    /// </summary>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      // materialize once: the rows are used for training and remembered for recalculation
      int[] rows = trainingIndices.ToArray();
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, rows);

      var classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
      }

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues, rows);
    }

    /// <summary>
    /// Builds the alglib forest. The last column of <paramref name="inputMatrix"/> is the target.
    /// The seed is applied to alglib's shared random object before building.
    /// </summary>
    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      AssertParameters(r, m);
      AssertInputMatrix(inputMatrix);

      int info = 0;
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }

    private static void AssertParameters(double r, double m) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    private static void AssertInputMatrix(double[,] inputMatrix) {
      if (inputMatrix.ContainsNanOrInfinity())
        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.