Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 11006

Last change on this file since 11006 was 11006, checked in by gkronber, 10 years ago

#1721: merged improved random forest persistence from trunk to stable branch

File size: 13.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification
  /// </summary>
  /// <remarks>
  /// Instead of persisting the alglib forest itself, the model persists only the data needed to
  /// rebuild the same forest (training data, seed, number of trees, R, M and class values).
  /// The forest is rebuilt lazily the first time it is needed.
  /// Models loaded from files written by older versions carry no training data. For those, the
  /// complete forest is persisted instead (see the backwards-compatibility region).
  /// </remarks>
  [StorableClass]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    // not persisted (except for models loaded from old files, see the backwards-compatibility region)
    private alglib.decisionforest randomForest;

    /// <summary>
    /// Gets the alglib forest. If it has no trees yet (e.g. after loading a model that stores
    /// only its training data), the forest is first rebuilt via <see cref="RecalculateModel"/>.
    /// </summary>
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    // instead of storing the data of the model itself
    // we only store the data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;
    [Storable]
    private int nTrees;
    [Storable]
    private double r;
    [Storable]
    private double m;


    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions): the storable RandomForest* properties
      // below write into this instance during deserialization, so it must exist up front
      randomForest = new alglib.decisionforest();
    }

    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      // copy the forest from the field (not the property) so that cloning never triggers a recalculation;
      // if the original has not been recalculated yet, the clone will recalculate lazily as well
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classvalues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      // keep a private copy so later changes to the caller's problem data cannot alter the recalculated model
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

    /// <summary>
    /// Rebuilds the alglib forest from the stored training data, seed and forest parameters.
    /// The forest is built directly, without going through the public factory methods, so the
    /// stored training data is not cloned again.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the stored training data is neither regression nor classification problem data
    /// (e.g. it is missing), so there is nothing to rebuild the forest from.
    /// </exception>
    private void RecalculateModel() {
      alglib.dfreport rep; // training statistics are not needed when recalculating
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        randomForest = CreateRegressionForest(regressionProblemData, nTrees, r, m, seed, out rep);
      } else if (classificationProblemData != null) {
        // class values are derived from the stored training data, just like in CreateClassificationModel
        var recalculatedClassValues = classificationProblemData.ClassValues.ToArray();
        randomForest = CreateClassificationForest(classificationProblemData, recalculatedClassValues, nTrees, r, m, seed, out rep);
      } else {
        // previously this case was silently ignored and the empty forest produced NaN values or index errors later on
        throw new InvalidOperationException("The random forest model cannot be recalculated because the original training data is not available.");
      }
    }

    /// <summary>
    /// Returns the forest output (first output value) for each of the given rows.
    /// </summary>
    /// <param name="dataset">The dataset containing the input variables.</param>
    /// <param name="rows">The row indices to evaluate.</param>
    /// <returns>A lazily evaluated sequence with one estimated value per row.</returns>
    /// <exception cref="NotSupportedException">
    /// Thrown on enumeration when the input contains NaN or infinity values.
    /// </exception>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Returns, for each of the given rows, the class value with the largest forest output.
    /// </summary>
    /// <param name="dataset">The dataset containing the input variables.</param>
    /// <param name="rows">The row indices to classify.</param>
    /// <returns>A lazily evaluated sequence with one estimated class value per row.</returns>
    /// <exception cref="NotSupportedException">
    /// Thrown on enumeration when the input contains NaN or infinity values.
    /// </exception>
    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      // go through the property (which recalculates lazily) once and reuse the result for all rows,
      // instead of mixing property and raw field accesses
      var forest = RandomForest;
      double[] x = new double[columns];
      double[] y = new double[forest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(forest, x, ref y);
        // pick the class index with the largest probability value (the first one wins on ties)
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        // translate the class index back to the original class value
        yield return classValues[maxProbClassIndex];
      }
    }

    public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this);
    }
    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return CreateRegressionSolution(problemData);
    }
    public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this);
    }
    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
      return CreateClassificationSolution(problemData);
    }

    /// <summary>
    /// Trains a random forest regression model on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The regression problem data; only the training indices are used.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Fraction of the training rows used per tree, in (0, 1].</param>
    /// <param name="m">Fraction of the input variables considered per split, in (0, 1].</param>
    /// <param name="seed">Seed for the random generator; needed to rebuild the same forest later.</param>
    /// <param name="rmsError">Training RMS error reported by alglib.</param>
    /// <param name="avgRelError">Training average relative error reported by alglib.</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error reported by alglib.</param>
    /// <param name="outOfBagRmsError">Out-of-bag RMS error reported by alglib.</param>
    /// <exception cref="ArgumentException">Thrown for invalid R/M values or when alglib reports an error.</exception>
    /// <exception cref="NotSupportedException">Thrown when the training data contains NaN or infinity values.</exception>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {

      alglib.dfreport rep;
      var dForest = CreateRegressionForest(problemData, nTrees, r, m, seed, out rep);

      rmsError = rep.rmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;
      outOfBagRmsError = rep.oobrmserror;

      return new RandomForestModel(dForest,
        seed, problemData,
        nTrees, r, m);
    }

    /// <summary>
    /// Trains a random forest classification model on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The classification problem data; only the training indices are used.</param>
    /// <param name="nTrees">Number of trees in the forest.</param>
    /// <param name="r">Fraction of the training rows used per tree, in (0, 1].</param>
    /// <param name="m">Fraction of the input variables considered per split, in (0, 1].</param>
    /// <param name="seed">Seed for the random generator; needed to rebuild the same forest later.</param>
    /// <param name="rmsError">Training RMS error reported by alglib.</param>
    /// <param name="outOfBagRmsError">Out-of-bag RMS error reported by alglib.</param>
    /// <param name="relClassificationError">Training relative classification error reported by alglib.</param>
    /// <param name="outOfBagRelClassificationError">Out-of-bag relative classification error reported by alglib.</param>
    /// <exception cref="ArgumentException">Thrown for invalid R/M values or when alglib reports an error.</exception>
    /// <exception cref="NotSupportedException">Thrown when the training data contains NaN or infinity values.</exception>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

      var classValues = problemData.ClassValues.ToArray();

      alglib.dfreport rep;
      var dForest = CreateClassificationForest(problemData, classValues, nTrees, r, m, seed, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(dForest,
        seed, problemData,
        nTrees, r, m, classValues);
    }

    // Builds the alglib forest for a regression problem: the allowed input variables followed by
    // the target variable (last column) over the training rows; one output (nClasses = 1).
    private static alglib.decisionforest CreateRegressionForest(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, out alglib.dfreport rep) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);

      return CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
    }

    // Builds the alglib forest for a classification problem. The target column holds the index of
    // each row's class in classValues instead of the original class value.
    private static alglib.decisionforest CreateClassificationForest(IClassificationProblemData problemData, double[] classValues, int nTrees, double r, double m, int seed, out alglib.dfreport rep) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);

      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      int nRows = inputMatrix.GetLength(0);
      int targetColumn = inputMatrix.GetLength(1) - 1;
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, targetColumn] = classIndices[inputMatrix[row, targetColumn]];
      }

      return CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
    }

    // Trains an alglib forest on inputMatrix. The last column is the target; all other columns are inputs.
    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      AssertParameters(r, m);
      AssertInputMatrix(inputMatrix);

      int info = 0;
      // seed alglib's random generator so the forest can be rebuilt from the stored seed
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      // convert the fractions R and M into absolute counts (at least one row / one variable)
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }

    private static void AssertParameters(double r, double m) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    private static void AssertInputMatrix(double[,] inputMatrix) {
      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }

    // The storable properties below persist the raw alglib forest only for compatibility-loaded models.
    // For all other models they store empty placeholders (0 / empty array); on loading, the empty trees
    // array makes the RandomForest property recalculate the forest.
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.