source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 17157

Last change on this file since 17157 was 17157, checked in by gkronber, 8 weeks ago

#2952: merged r17154 from trunk to stable

File size: 18.0 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HEAL.Attic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Problems.DataAnalysis;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification
  /// </summary>
  /// <remarks>
  /// Instances created through the static factory methods persist only the data needed to
  /// rebuild the forest (training data, seed, nTrees, r, m) and recalculate it lazily on demand.
  /// Instances loaded from old files have no training data (<c>originalTrainingData == null</c>)
  /// and persist the raw alglib forest through the properties in the compatibility region instead.
  /// </remarks>
  [Obsolete("This class only exists for backwards compatibility reasons for stored models with the XML Persistence. Use RFModelSurrogate or RFModelFull instead.")]
  [StorableType("9AA4CCC2-CD75-4471-8DF6-949E5B783642")]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    // not persisted
    private alglib.decisionforest randomForest;

    /// <summary>
    /// Gets the alglib forest; triggers <see cref="RecalculateModel"/> first when the tree array is null or empty.
    /// </summary>
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    /// <summary>
    /// Gets the names of the input variables of this model.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get {
        // models loaded from old files have no originalTrainingData; fall back to the
        // persisted variable names instead of failing with a NullReferenceException
        if (IsCompatibilityLoaded) return allowedInputVariables;
        return originalTrainingData.AllowedInputVariables;
      }
    }

    /// <summary>
    /// Gets the number of trees that was requested when the model was created.
    /// </summary>
    public int NumberOfTrees {
      get { return nTrees; }
    }

    // instead of storing the data of the model itself
    // we only store the data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;
    [Storable]
    private int nTrees;
    [Storable]
    private double r;
    [Storable]
    private double m;

    /// <summary>
    /// Deserialization constructor; creates an empty forest object that the storable
    /// compatibility properties write into.
    /// </summary>
    [StorableConstructor]
    private RandomForestModel(StorableConstructorFlag _) : base(_) {
      // for backwards compatibility (loading old solutions)
      randomForest = new alglib.decisionforest();
    }

    /// <summary>
    /// Cloning constructor. Copies the scalar forest fields, shares the tree array,
    /// the allowed input variables and the class values with <paramref name="original"/>,
    /// and clones the training data through <paramref name="cloner"/>.
    /// </summary>
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      // read the field (not the property) so that cloning never triggers a recalculation
      var source = original.randomForest.innerobj;
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = source.bufsize;
      randomForest.innerobj.nclasses = source.nclasses;
      randomForest.innerobj.ntrees = source.ntrees;
      randomForest.innerobj.nvars = source.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = source.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classValues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    /// <summary>
    /// Wraps an already trained forest together with everything needed to recalculate it.
    /// The training data is cloned; all other arguments are stored as given.
    /// </summary>
    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }

    /// <summary>
    /// Creates a deep clone of this model using the cloning constructor.
    /// </summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

    /// <summary>
    /// Rebuilds the forest from the stored training data, seed and parameters.
    /// Leaves <c>randomForest</c> untouched when the training data is neither regression
    /// nor classification problem data (which includes the case that it is null).
    /// </summary>
    private void RecalculateModel() {
      // the error measures are required by the factory signatures but not used here
      double rmsError, oobRmsError, relClassError, oobRelClassError;
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        var model = CreateRegressionModel(regressionProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      } else if (classificationProblemData != null) {
        var model = CreateClassificationModel(classificationProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      }
    }

    /// <summary>
    /// Lazily yields one forest output (the first element written by <c>alglib.dfprocess</c>)
    /// for each of the given rows.
    /// </summary>
    /// <param name="dataset">Dataset providing the values of the allowed input variables.</param>
    /// <param name="rows">Rows of <paramref name="dataset"/> to evaluate.</param>
    /// <returns>A deferred sequence; nothing is evaluated (or validated) before enumeration starts.</returns>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      // buffers are reused for every row
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Lazily yields, for each of the given rows, <c>VariancePop()</c> of the values written
    /// by <c>alglib.dforest.dfprocessraw</c> (presumably one raw output per tree; the buffer
    /// is sized to the number of trees).
    /// </summary>
    /// <param name="dataset">Dataset providing the values of the allowed input variables.</param>
    /// <param name="rows">Rows of <paramref name="dataset"/> to evaluate.</param>
    /// <returns>A deferred sequence; nothing is evaluated (or validated) before enumeration starts.</returns>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] ys = new double[this.RandomForest.innerobj.ntrees];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dforest.dfprocessraw(RandomForest.innerobj, x, ref ys);
        yield return ys.VariancePop();
      }
    }

    /// <summary>
    /// Lazily yields, for each of the given rows, the class value whose forest output is largest.
    /// Ties are resolved in favor of the class with the lowest index.
    /// </summary>
    /// <param name="dataset">Dataset providing the values of the allowed input variables.</param>
    /// <param name="rows">Rows of <paramref name="dataset"/> to evaluate.</param>
    /// <returns>A deferred sequence; nothing is evaluated (or validated) before enumeration starts.</returns>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
      RandomForestUtil.AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      // fetch the forest once through the lazy property and use the same instance
      // for sizing the output buffer and for prediction
      var rf = RandomForest;
      double[] x = new double[columns];
      double[] y = new double[rf.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(rf, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    /// <summary>
    /// Converts one tree of the forest into a symbolic expression tree made of
    /// variable-condition nodes (inner nodes) and constant nodes (leaves).
    /// </summary>
    /// <param name="treeIdx">
    /// Selects the tree: <c>treeIdx - 1</c> trees are skipped, so the values 0 and 1 both
    /// select the first tree. NOTE(review): presumably a 1-based index - confirm with callers.
    /// </param>
    /// <returns>A tree with a program root, a start node and the converted decision tree below it.</returns>
    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
      var rf = RandomForest;
      // hoping that the internal representation of alglib is stable

      // TREE FORMAT
      // W[Offs]      -   size of sub-array (for the tree)
      //     node info:
      // W[K+0]       -   variable number        (-1 for leaf node)
      // W[K+1]       -   threshold              (class/value for leaf node)
      // W[K+2]       -   ">=" branch index      (absent for leaf node)

      // skip irrelevant trees: the first entry of each tree holds the size of its sub-array
      int offset = 0;
      for (int i = 0; i < treeIdx - 1; i++) {
        offset = offset + (int)Math.Round(rf.innerobj.trees[offset]);
      }

      var constSy = new Constant();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };

      // the root node of the tree starts directly after the size entry
      var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, constSy, varCondSy);

      var startNode = new StartSymbol().CreateTreeNode();
      startNode.AddSubtree(node);
      var root = new ProgramRootSymbol().CreateTreeNode();
      root.AddSubtree(startNode);
      return new SymbolicExpressionTree(root);
    }

    /// <summary>
    /// Recursively converts the alglib node at position <paramref name="k"/> into a tree node.
    /// </summary>
    /// <param name="trees">Flat alglib tree array.</param>
    /// <param name="offset">Start of the current tree's sub-array; branch indices are relative to it.</param>
    /// <param name="k">Position of the node to convert.</param>
    /// <param name="constSy">Symbol used to create leaf nodes.</param>
    /// <param name="varCondSy">Symbol used to create inner (condition) nodes.</param>
    private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) {

      // alglib source for evaluation of one tree (dfprocessinternal)
      // offs = 0
      //
      // Set pointer to the root
      //
      // k = offs + 1;
      //
      // //
      // // Navigate through the tree
      // //
      // while (true) {
      //   if ((double)(df.trees[k]) == (double)(-1)) {
      //     if (df.nclasses == 1) {
      //       y[0] = y[0] + df.trees[k + 1];
      //     } else {
      //       idx = (int)Math.Round(df.trees[k + 1]);
      //       y[idx] = y[idx] + 1;
      //     }
      //     break;
      //   }
      //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
      //     k = k + innernodewidth;
      //   } else {
      //     k = offs + (int)Math.Round(df.trees[k + 2]);
      //   }
      // }

      // an inner node occupies three entries (variable, threshold, ">=" branch index);
      // this plays the role of innernodewidth in the alglib code above
      const int innerNodeWidth = 3;

      if (trees[k] == -1.0) {
        // leaf: the second entry holds the value (or class) of the leaf
        var constNode = (ConstantTreeNode)constSy.CreateTreeNode();
        constNode.Value = trees[k + 1];
        return constNode;
      } else {
        var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        condNode.VariableName = AllowedInputVariables[(int)Math.Round(trees[k])];
        condNode.Threshold = trees[k + 1];
        condNode.Slope = double.PositiveInfinity;

        // the "<" branch follows the node directly, the ">=" branch is addressed relative to the tree offset
        var left = CreateRegressionTreeRec(trees, offset, k + innerNodeWidth, constSy, varCondSy);
        var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy);

        condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
        condNode.AddSubtree(right);
        return condNode;
      }
    }


    /// <summary>
    /// Creates a regression solution for this model on a copy of <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }

    /// <summary>
    /// Creates a classification solution for this model on a copy of <paramref name="problemData"/>.
    /// </summary>
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    /// <summary>
    /// Checks regression problem data by delegating to <c>RegressionModel.IsProblemDataCompatible</c>.
    /// </summary>
    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
    }

    /// <summary>
    /// Dispatches to the regression or classification compatibility check depending on the
    /// runtime type of <paramref name="problemData"/>; regression is tested first.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      if (problemData == null) throw new ArgumentNullException("problemData", "The provided problemData is null.");

      var regressionProblemData = problemData as IRegressionProblemData;
      if (regressionProblemData != null)
        return IsProblemDataCompatible(regressionProblemData, out errorMessage);

      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null)
        return IsProblemDataCompatible(classificationProblemData, out errorMessage);

      throw new ArgumentException("The problem data is not compatible with this random forest. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", "problemData");
    }

    /// <summary>
    /// Trains a regression forest on the training partition of <paramref name="problemData"/>.
    /// See the overload taking explicit training indices for details.
    /// </summary>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
       rmsError: out rmsError, outOfBagRmsError: out outOfBagRmsError, avgRelError: out avgRelError, outOfBagAvgRelError: out outOfBagAvgRelError);
    }

    /// <summary>
    /// Trains a regression forest on the given rows. The allowed input variables form the
    /// leading columns of the training matrix, the target variable the last one.
    /// </summary>
    /// <param name="problemData">Problem data providing dataset, input variables and target variable.</param>
    /// <param name="trainingIndices">Rows used for training.</param>
    /// <param name="nTrees">Forwarded unchanged to <c>RandomForestUtil.CreateRandomForestModel</c>.</param>
    /// <param name="r">Forwarded unchanged to <c>RandomForestUtil.CreateRandomForestModel</c>.</param>
    /// <param name="m">Forwarded unchanged to <c>RandomForestUtil.CreateRandomForestModel</c>.</param>
    /// <param name="seed">Forwarded unchanged to <c>RandomForestUtil.CreateRandomForestModel</c>.</param>
    /// <param name="rmsError">Receives <c>rmserror</c> of the alglib report.</param>
    /// <param name="outOfBagRmsError">Receives <c>oobrmserror</c> of the alglib report.</param>
    /// <param name="avgRelError">Receives <c>avgrelerror</c> of the alglib report.</param>
    /// <param name="outOfBagAvgRelError">Receives <c>oobavgrelerror</c> of the alglib report.</param>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);

      alglib.dfreport rep;
      // 1 is passed in the position where the classification variant passes the number of classes
      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
    }

    /// <summary>
    /// Trains a classification forest on the training partition of <paramref name="problemData"/>.
    /// See the overload taking explicit training indices for details.
    /// </summary>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    }

    /// <summary>
    /// Trains a classification forest on the given rows. The target column of the training
    /// matrix is replaced by the index of each class value within <c>problemData.ClassValues</c>
    /// before training.
    /// </summary>
    /// <param name="problemData">Problem data providing dataset, input variables, target variable and class values.</param>
    /// <param name="trainingIndices">Rows used for training.</param>
    /// <param name="nTrees">Forwarded unchanged to <c>RandomForestUtil.CreateRandomForestModel</c>.</param>
    /// <param name="r">Forwarded unchanged to <c>RandomForestUtil.CreateRandomForestModel</c>.</param>
    /// <param name="m">Forwarded unchanged to <c>RandomForestUtil.CreateRandomForestModel</c>.</param>
    /// <param name="seed">Forwarded unchanged to <c>RandomForestUtil.CreateRandomForestModel</c>.</param>
    /// <param name="rmsError">Receives <c>rmserror</c> of the alglib report.</param>
    /// <param name="outOfBagRmsError">Receives <c>oobrmserror</c> of the alglib report.</param>
    /// <param name="relClassificationError">Receives <c>relclserror</c> of the alglib report.</param>
    /// <param name="outOfBagRelClassificationError">Receives <c>oobrelclserror</c> of the alglib report.</param>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);

      var classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      // the target variable is the last column of the matrix
      int nRows = inputMatrix.GetLength(0);
      int targetColumn = inputMatrix.GetLength(1) - 1;
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, targetColumn] = classIndices[inputMatrix[row, targetColumn]];
      }

      alglib.dfreport rep;
      var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    // Each storable property below returns the real forest data only for compatibility-loaded
    // models and a neutral value (0 / empty array) otherwise; the setters always write into the
    // forest object created by the storable constructor.

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.