Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2434_crossvalidation/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 16654

Last change on this file since 16654 was 14029, checked in by gkronber, 8 years ago

#2434: merged trunk changes r12934:14026 from trunk to branch

File size: 14.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
30namespace HeuristicLab.Algorithms.DataAnalysis {
31  /// <summary>
32  /// Represents a random forest model for regression and classification
33  /// </summary>
34  [StorableClass]
35  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
36  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
37    // not persisted
38    private alglib.decisionforest randomForest;
39    private alglib.decisionforest RandomForest {
40      get {
41        // recalculate lazily
42        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
43        return randomForest;
44      }
45    }
46
47    public override IEnumerable<string> VariablesUsedForPrediction {
48      get { return originalTrainingData.AllowedInputVariables; }
49    }
50
51
52    // instead of storing the data of the model itself
53    // we instead only store data necessary to recalculate the same model lazily on demand
54    [Storable]
55    private int seed;
56    [Storable]
57    private IDataAnalysisProblemData originalTrainingData;
58    [Storable]
59    private double[] classValues;
60    [Storable]
61    private int nTrees;
62    [Storable]
63    private double r;
64    [Storable]
65    private double m;
66
67
68    [StorableConstructor]
69    private RandomForestModel(bool deserializing)
70      : base(deserializing) {
71      // for backwards compatibility (loading old solutions)
72      randomForest = new alglib.decisionforest();
73    }
74    private RandomForestModel(RandomForestModel original, Cloner cloner)
75      : base(original, cloner) {
76      randomForest = new alglib.decisionforest();
77      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
78      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
79      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
80      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
81      // we assume that the trees array (double[]) is immutable in alglib
82      randomForest.innerobj.trees = original.randomForest.innerobj.trees;
83
84      // allowedInputVariables is immutable so we don't need to clone
85      allowedInputVariables = original.allowedInputVariables;
86
87      // clone data which is necessary to rebuild the model
88      this.seed = original.seed;
89      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
90      // classvalues is immutable so we don't need to clone
91      this.classValues = original.classValues;
92      this.nTrees = original.nTrees;
93      this.r = original.r;
94      this.m = original.m;
95    }
96
97    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
98    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
99      int seed, IDataAnalysisProblemData originalTrainingData,
100      int nTrees, double r, double m, double[] classValues = null)
101      : base(targetVariable) {
102      this.name = ItemName;
103      this.description = ItemDescription;
104      // the model itself
105      this.randomForest = randomForest;
106      // data which is necessary for recalculation of the model
107      this.seed = seed;
108      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
109      this.classValues = classValues;
110      this.nTrees = nTrees;
111      this.r = r;
112      this.m = m;
113    }
114
115    public override IDeepCloneable Clone(Cloner cloner) {
116      return new RandomForestModel(this, cloner);
117    }
118
119    private void RecalculateModel() {
120      double rmsError, oobRmsError, relClassError, oobRelClassError;
121      var regressionProblemData = originalTrainingData as IRegressionProblemData;
122      var classificationProblemData = originalTrainingData as IClassificationProblemData;
123      if (regressionProblemData != null) {
124        var model = CreateRegressionModel(regressionProblemData,
125                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
126                                              out relClassError, out oobRelClassError);
127        randomForest = model.randomForest;
128      } else if (classificationProblemData != null) {
129        var model = CreateClassificationModel(classificationProblemData,
130                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
131                                              out relClassError, out oobRelClassError);
132        randomForest = model.randomForest;
133      }
134    }
135
136    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
137      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
138      AssertInputMatrix(inputData);
139
140      int n = inputData.GetLength(0);
141      int columns = inputData.GetLength(1);
142      double[] x = new double[columns];
143      double[] y = new double[1];
144
145      for (int row = 0; row < n; row++) {
146        for (int column = 0; column < columns; column++) {
147          x[column] = inputData[row, column];
148        }
149        alglib.dfprocess(RandomForest, x, ref y);
150        yield return y[0];
151      }
152    }
153
154    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
155      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
156      AssertInputMatrix(inputData);
157
158      int n = inputData.GetLength(0);
159      int columns = inputData.GetLength(1);
160      double[] x = new double[columns];
161      double[] y = new double[RandomForest.innerobj.nclasses];
162
163      for (int row = 0; row < n; row++) {
164        for (int column = 0; column < columns; column++) {
165          x[column] = inputData[row, column];
166        }
167        alglib.dfprocess(randomForest, x, ref y);
168        // find class for with the largest probability value
169        int maxProbClassIndex = 0;
170        double maxProb = y[0];
171        for (int i = 1; i < y.Length; i++) {
172          if (maxProb < y[i]) {
173            maxProb = y[i];
174            maxProbClassIndex = i;
175          }
176        }
177        yield return classValues[maxProbClassIndex];
178      }
179    }
180
181
182    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
183      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
184    }
185    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
186      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
187    }
188
189    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
190      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
191      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError);
192    }
193
194    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
195      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
196      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
197      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
198
199      alglib.dfreport rep;
200      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
201
202      rmsError = rep.rmserror;
203      avgRelError = rep.avgrelerror;
204      outOfBagAvgRelError = rep.oobavgrelerror;
205      outOfBagRmsError = rep.oobrmserror;
206
207      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
208    }
209
210    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
211      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
212      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
213    }
214
215    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
216      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
217
218      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
219      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
220
221      var classValues = problemData.ClassValues.ToArray();
222      int nClasses = classValues.Length;
223
224      // map original class values to values [0..nClasses-1]
225      var classIndices = new Dictionary<double, double>();
226      for (int i = 0; i < nClasses; i++) {
227        classIndices[classValues[i]] = i;
228      }
229
230      int nRows = inputMatrix.GetLength(0);
231      int nColumns = inputMatrix.GetLength(1);
232      for (int row = 0; row < nRows; row++) {
233        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
234      }
235
236      alglib.dfreport rep;
237      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
238
239      rmsError = rep.rmserror;
240      outOfBagRmsError = rep.oobrmserror;
241      relClassificationError = rep.relclserror;
242      outOfBagRelClassificationError = rep.oobrelclserror;
243
244      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
245    }
246
247    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
248      AssertParameters(r, m);
249      AssertInputMatrix(inputMatrix);
250
251      int info = 0;
252      alglib.math.rndobject = new System.Random(seed);
253      var dForest = new alglib.decisionforest();
254      rep = new alglib.dfreport();
255      int nRows = inputMatrix.GetLength(0);
256      int nColumns = inputMatrix.GetLength(1);
257      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
258      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
259
260      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
261      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
262      return dForest;
263    }
264
265    private static void AssertParameters(double r, double m) {
266      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
267      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
268    }
269
270    private static void AssertInputMatrix(double[,] inputMatrix) {
271      if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
272        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
273    }
274
275    #region persistence for backwards compatibility
276    // when the originalTrainingData is null this means the model was loaded from an old file
277    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
278    // in such cases we still store the compete model
279    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }
280
281    private string[] allowedInputVariables;
282    [Storable(Name = "allowedInputVariables")]
283    private string[] AllowedInputVariables {
284      get {
285        if (IsCompatibilityLoaded) return allowedInputVariables;
286        else return originalTrainingData.AllowedInputVariables.ToArray();
287      }
288      set { allowedInputVariables = value; }
289    }
290    [Storable]
291    private int RandomForestBufSize {
292      get {
293        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
294        else return 0;
295      }
296      set {
297        randomForest.innerobj.bufsize = value;
298      }
299    }
300    [Storable]
301    private int RandomForestNClasses {
302      get {
303        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
304        else return 0;
305      }
306      set {
307        randomForest.innerobj.nclasses = value;
308      }
309    }
310    [Storable]
311    private int RandomForestNTrees {
312      get {
313        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
314        else return 0;
315      }
316      set {
317        randomForest.innerobj.ntrees = value;
318      }
319    }
320    [Storable]
321    private int RandomForestNVars {
322      get {
323        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
324        else return 0;
325      }
326      set {
327        randomForest.innerobj.nvars = value;
328      }
329    }
330    [Storable]
331    private double[] RandomForestTrees {
332      get {
333        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
334        else return new double[] { };
335      }
336      set {
337        randomForest.innerobj.trees = value;
338      }
339    }
340    #endregion
341  }
342}
Note: See TracBrowser for help on using the repository browser.