Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 11315

Last change on this file since 11315 was 11315, checked in by bburlacu, 10 years ago

#2237: Added RandomForestUtil class implementing fold generation, cross-validation and grid search. Overloaded CreateRegressionModel method to accept a user-specified data partition.

File size: 14.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification.
  /// The alglib decision forest itself is not persisted; instead the model stores the
  /// original training data and the parameters (seed, nTrees, r, m, class values) needed
  /// to deterministically rebuild the identical forest lazily on demand.
  /// </summary>
  [StorableClass]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    // not persisted; rebuilt lazily from the stored training data and parameters
    private alglib.decisionforest randomForest;
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily: an empty trees array means the forest has not been built yet
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    // instead of storing the data of the model itself
    // we instead only store data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;                          // RNG seed for the deterministic rebuild
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;              // null for regression models
    [Storable]
    private int nTrees;                        // number of trees in the forest
    [Storable]
    private double r;                          // fraction of rows sampled per tree, 0 < r <= 1
    [Storable]
    private double m;                          // fraction of features considered per split, 0 < m <= 1


    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions)
      randomForest = new alglib.decisionforest();
    }
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classValues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

    // Rebuilds the forest from the stored training data using the same seed and
    // parameters, so the recalculated model is identical to the originally created one.
    // Does nothing if originalTrainingData is neither regression nor classification data
    // (e.g. for models loaded from old files, where the forest is persisted directly).
    private void RecalculateModel() {
      double rmsError, oobRmsError, relClassError, oobRelClassError;
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        var model = CreateRegressionModel(regressionProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      } else if (classificationProblemData != null) {
        var model = CreateClassificationModel(classificationProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      }
    }

    /// <summary>
    /// Returns the forest's regression estimate for each of the given rows of the dataset.
    /// Evaluation is deferred (iterator); the forest is rebuilt lazily on first enumeration.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Returns the estimated class value (the original class value, not the internal
    /// class index) for each of the given rows of the dataset. The forest outputs one
    /// probability per class; the class with the largest probability is returned.
    /// </summary>
    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[RandomForest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        // use the lazy RandomForest property (not the bare field) so the forest is
        // guaranteed to be built regardless of evaluation order
        alglib.dfprocess(RandomForest, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }

    public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this);
    }
    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return CreateRegressionSolution(problemData);
    }
    public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this);
    }
    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
      return CreateClassificationSolution(problemData);
    }

    /// <summary>
    /// Creates a random forest regression model trained on the problem data's training partition.
    /// </summary>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
      return CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError, problemData.TrainingIndices);
    }

    /// <summary>
    /// Creates a random forest regression model trained on a user-specified data partition.
    /// </summary>
    /// <param name="trainingIndices">The dataset rows to train on.</param>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError, IEnumerable<int> trainingIndices) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      // BUGFIX: train on the user-specified partition; previously problemData.TrainingIndices
      // was used here, silently ignoring the trainingIndices parameter
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;
      outOfBagRmsError = rep.oobrmserror;

      return new RandomForestModel(dForest,
        seed, problemData,
        nTrees, r, m);
    }

    /// <summary>
    /// Creates a random forest classification model trained on the problem data's training partition.
    /// Original class values are mapped to internal class indices [0..nClasses-1] before training.
    /// </summary>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);

      var classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      // the target variable is the last column of the input matrix
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
      }

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(dForest,
        seed, problemData,
        nTrees, r, m, classValues);
    }

    // Trains the alglib decision forest. The last column of inputMatrix is the target;
    // nClasses = 1 means regression. The global alglib RNG is seeded to make training
    // deterministic and reproducible for lazy recalculation.
    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      AssertParameters(r, m);
      AssertInputMatrix(inputMatrix);

      int info = 0;
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      // at least one sample row and one feature must be used
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }

    private static void AssertParameters(double r, double m) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    private static void AssertInputMatrix(double[,] inputMatrix) {
      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.