Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 12417

Last change on this file since 12417 was 12009, checked in by ascheibe, 10 years ago

#2212 updated copyright year

File size: 14.6 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification.
  /// </summary>
  /// <remarks>
  /// The forest itself is not persisted. Only the data required to rebuild it (seed, training data,
  /// number of trees, r, m and the class values) is stored, and the forest is rebuilt lazily on first use.
  /// Models loaded from old files carry no training data; for those the complete forest is kept and persisted
  /// (see the backwards-compatibility region at the end of the class).
  /// </remarks>
  [StorableClass]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    // not persisted
    private alglib.decisionforest randomForest;

    /// <summary>
    /// Gets the forest, rebuilding it first when no trees are available yet.
    /// </summary>
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    // instead of storing the data of the model itself
    // we only store the data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;
    [Storable]
    private int nTrees;
    [Storable]
    private double r;
    [Storable]
    private double m;


    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions):
      // the storable property setters below write into this instance
      randomForest = new alglib.decisionforest();
    }

    /// <summary>
    /// Cloning constructor. Copies the forest description and the data needed to rebuild the forest.
    /// </summary>
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classValues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }

    /// <summary>
    /// Creates a clone of this model using the given <paramref name="cloner"/>.
    /// </summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

    /// <summary>
    /// Rebuilds the forest from the stored training data and parameters.
    /// Does nothing when the stored training data is neither regression nor classification data
    /// (including the case where it is null).
    /// </summary>
    /// <remarks>
    /// NOTE(review): the rebuild always goes through the overloads that train on
    /// <c>problemData.TrainingIndices</c>. A model that was created through one of the overloads taking explicit
    /// training indices is therefore rebuilt from the problem data's own training partition - confirm that this
    /// is intended.
    /// </remarks>
    private void RecalculateModel() {
      // the error measures are required by the factory signatures but are not needed here
      double rmsError, oobRmsError, relClassError, oobRelClassError;
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        var model = CreateRegressionModel(regressionProblemData,
                                          nTrees, r, m, seed, out rmsError, out oobRmsError,
                                          out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      } else if (classificationProblemData != null) {
        var model = CreateClassificationModel(classificationProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      }
    }

    /// <summary>
    /// Lazily evaluates the forest for the given rows of the dataset and yields the first output value per row.
    /// </summary>
    /// <param name="dataset">The dataset to read the allowed input variables from.</param>
    /// <param name="rows">The rows to evaluate.</param>
    /// <exception cref="NotSupportedException">Thrown when the input data contains NaN or infinity values.</exception>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Lazily evaluates the forest for the given rows of the dataset and yields, per row, the class value
    /// whose output value is the largest.
    /// </summary>
    /// <param name="dataset">The dataset to read the allowed input variables from.</param>
    /// <param name="rows">The rows to evaluate.</param>
    /// <exception cref="NotSupportedException">Thrown when the input data contains NaN or infinity values.</exception>
    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      // resolve the forest once through the lazy property (instead of reading the raw field)
      // so that it is guaranteed to be rebuilt before it is used
      alglib.decisionforest forest = RandomForest;
      double[] y = new double[forest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(forest, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    /// <summary>
    /// Creates a regression solution for this model on a copy of the given problem data.
    /// </summary>
    public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this);
    }
    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Creates a classification solution for this model on a copy of the given problem data.
    /// </summary>
    public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this);
    }
    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
      return CreateClassificationSolution(problemData);
    }

    /// <summary>
    /// Builds a regression forest on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The regression problem data.</param>
    /// <param name="nTrees">The number of trees in the forest.</param>
    /// <param name="r">The fraction of training rows used per tree; must be in (0, 1].</param>
    /// <param name="m">The fraction of input variables used; must be in (0, 1].</param>
    /// <param name="seed">The seed for the random number generator.</param>
    /// <param name="rmsError">The RMS error reported by alglib.</param>
    /// <param name="outOfBagRmsError">The out-of-bag RMS error reported by alglib.</param>
    /// <param name="avgRelError">The average relative error reported by alglib.</param>
    /// <param name="outOfBagAvgRelError">The out-of-bag average relative error reported by alglib.</param>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      // named arguments make sure every error measure ends up in the out parameter of the same name
      // (the four parameters share the same type, so a positional mix-up would go unnoticed by the compiler)
      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        rmsError: out rmsError, outOfBagRmsError: out outOfBagRmsError,
        avgRelError: out avgRelError, outOfBagAvgRelError: out outOfBagAvgRelError);
    }

    /// <summary>
    /// Builds a regression forest on the given rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The regression problem data.</param>
    /// <param name="trainingIndices">The rows used for training.</param>
    /// <param name="nTrees">The number of trees in the forest.</param>
    /// <param name="r">The fraction of training rows used per tree; must be in (0, 1].</param>
    /// <param name="m">The fraction of input variables used; must be in (0, 1].</param>
    /// <param name="seed">The seed for the random number generator.</param>
    /// <param name="rmsError">The RMS error reported by alglib.</param>
    /// <param name="outOfBagRmsError">The out-of-bag RMS error reported by alglib.</param>
    /// <param name="avgRelError">The average relative error reported by alglib.</param>
    /// <param name="outOfBagAvgRelError">The out-of-bag average relative error reported by alglib.</param>
    /// <exception cref="ArgumentException">Thrown when r or m is out of range or alglib reports a failure.</exception>
    /// <exception cref="NotSupportedException">Thrown when the training data contains NaN or infinity values.</exception>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      // the target variable is appended as the last column of the input matrix
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;

      return new RandomForestModel(dForest, seed, problemData, nTrees, r, m);
    }

    /// <summary>
    /// Builds a classification forest on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="nTrees">The number of trees in the forest.</param>
    /// <param name="r">The fraction of training rows used per tree; must be in (0, 1].</param>
    /// <param name="m">The fraction of input variables used; must be in (0, 1].</param>
    /// <param name="seed">The seed for the random number generator.</param>
    /// <param name="rmsError">The RMS error reported by alglib.</param>
    /// <param name="outOfBagRmsError">The out-of-bag RMS error reported by alglib.</param>
    /// <param name="relClassificationError">The relative classification error reported by alglib.</param>
    /// <param name="outOfBagRelClassificationError">The out-of-bag relative classification error reported by alglib.</param>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        rmsError: out rmsError, outOfBagRmsError: out outOfBagRmsError,
        relClassificationError: out relClassificationError, outOfBagRelClassificationError: out outOfBagRelClassificationError);
    }

    /// <summary>
    /// Builds a classification forest on the given rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="trainingIndices">The rows used for training.</param>
    /// <param name="nTrees">The number of trees in the forest.</param>
    /// <param name="r">The fraction of training rows used per tree; must be in (0, 1].</param>
    /// <param name="m">The fraction of input variables used; must be in (0, 1].</param>
    /// <param name="seed">The seed for the random number generator.</param>
    /// <param name="rmsError">The RMS error reported by alglib.</param>
    /// <param name="outOfBagRmsError">The out-of-bag RMS error reported by alglib.</param>
    /// <param name="relClassificationError">The relative classification error reported by alglib.</param>
    /// <param name="outOfBagRelClassificationError">The out-of-bag relative classification error reported by alglib.</param>
    /// <exception cref="ArgumentException">Thrown when r or m is out of range or alglib reports a failure.</exception>
    /// <exception cref="NotSupportedException">Thrown when the training data contains NaN or infinity values.</exception>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

      // the target variable is appended as the last column of the input matrix
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      var classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      // replace the class values in the target column by their class index
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
      }

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(dForest, seed, problemData, nTrees, r, m, classValues);
    }

    /// <summary>
    /// Validates the parameters and the input matrix and builds the alglib decision forest.
    /// The last column of <paramref name="inputMatrix"/> holds the target.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when r or m is out of range or alglib does not report success.</exception>
    /// <exception cref="NotSupportedException">Thrown when the input matrix contains NaN or infinity values.</exception>
    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      AssertParameters(r, m);
      AssertInputMatrix(inputMatrix);

      int info = 0;
      // NOTE(review): this replaces the random number generator object on alglib.math, which looks like
      // state shared by all alglib callers - confirm before building forests concurrently
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      // use at least one row per tree and at least one feature
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }

    /// <summary>
    /// Checks that both r and m lie in the interval (0, 1].
    /// </summary>
    private static void AssertParameters(double r, double m) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    /// <summary>
    /// Checks that the matrix contains neither NaN nor infinity values.
    /// </summary>
    private static void AssertInputMatrix(double[,] inputMatrix) {
      // a plain foreach visits every element of the two-dimensional array without boxing
      foreach (double value in inputMatrix) {
        if (Double.IsNaN(value) || Double.IsInfinity(value))
          throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
      }
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    // the following properties expose the forest to the persistence mechanism;
    // for models that can be rebuilt from their training data only empty values are returned
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.