source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 14027

Last change on this file since 14027 was 14027, checked in by mkommend, 8 years ago

#2604: Merged r13826, r13921, r13922, r13941, r13992, r13993, r14000 into stable.

File size: 14.5 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification
  /// </summary>
  [StorableClass]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    // not persisted
    private alglib.decisionforest randomForest;
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return originalTrainingData.AllowedInputVariables; }
    }


    // instead of storing the data of the model itself
    // we only store the data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;
    [Storable]
    private int nTrees;
    [Storable]
    private double r;
    [Storable]
    private double m;


    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions)
      randomForest = new alglib.decisionforest();
    }
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classValues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

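    // Rebuilds the alglib decision forest from the stored seed, training data and parameters.
    // Since the same seed and data are used, recalculation is intended to reproduce the
    // originally trained forest, which is why the forest itself is not persisted.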
    private void RecalculateModel() {
      double rmsError, oobRmsError, relClassError, oobRelClassError;
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        var model = CreateRegressionModel(regressionProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      } else if (classificationProblemData != null) {
        var model = CreateClassificationModel(classificationProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      }
    }

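    // Regression estimates: evaluates the forest row by row; alglib.dfprocess writes the
    // forest prediction (aggregated over all trees) into y[0].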
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        yield return y[0];
      }
    }

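    // Classification estimates: alglib.dfprocess writes one probability per class into y;
    // the index of the largest probability is mapped back to the original class value
    // through the classValues array.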
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[RandomForest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }


    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

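    // Trains a new regression forest on the training partition of problemData.
    // r is the fraction of training rows sampled for each tree and m the fraction of
    // input features considered per split, both in (0, 1]; training and out-of-bag error
    // measures are returned through the out parameters.
    //
    // Usage sketch (parameter values are illustrative only):
    //   double rmsError, oobRmsError, avgRelError, oobAvgRelError;
    //   var model = RandomForestModel.CreateRegressionModel(problemData, 50, 0.3, 0.5, 42,
    //     out rmsError, out oobRmsError, out avgRelError, out oobAvgRelError);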
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    }

    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;
      outOfBagRmsError = rep.oobrmserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
    }

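    // Trains a new classification forest on the training partition of problemData.
    // r and m have the same meaning as for CreateRegressionModel; before training, the
    // original class values are remapped to the indices 0..nClasses-1 expected by alglib.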
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    }

    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      var classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }
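      // e.g. class values { 2.0, 5.0, 7.0 } become indices { 0, 1, 2 }; GetEstimatedClassValues
      // later inverts this mapping through the stored classValues array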

      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
      }

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    }

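    // Shared training routine for the regression and classification factories: seeds alglib's
    // global RNG for reproducibility, derives the per-tree sample size from r and the number of
    // candidate features per split from m, and builds the forest via alglib.dforest.dfbuildinternal
    // with the dfusestrongsplits and dfuseevs options.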
    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      AssertParameters(r, m);
      AssertInputMatrix(inputMatrix);

      int info = 0;
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }

    private static void AssertParameters(double r, double m) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    private static void AssertInputMatrix(double[,] inputMatrix) {
      if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}