Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 11338

Last change on this file since 11338 was 11338, checked in by bburlacu, 10 years ago

#2237: Refactored random forest grid search and added support for symbolic classification.

File size: 14.6 KB
RevLine 
[6240]1#region License Information
2/* HeuristicLab
[11171]3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[6240]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
30namespace HeuristicLab.Algorithms.DataAnalysis {
31  /// <summary>
[6241]32  /// Represents a random forest model for regression and classification
[6240]33  /// </summary>
34  [StorableClass]
[6241]35  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
36  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    // not persisted
    // the actual alglib model; may be empty after deserialization and is then rebuilt on demand
    private alglib.decisionforest randomForest;
    // lazily rebuilds the forest from the stored training configuration when the
    // trees array is missing (e.g. after loading a model persisted with the new scheme)
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    // instead of storing the data of the model itself
    // we instead only store data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;  // RNG seed that makes the forest construction reproducible
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;  // training data the forest was built from
    [Storable]
    private double[] classValues;  // original class values (classification only; null for regression)
    [Storable]
    private int nTrees;  // number of trees in the forest
    [Storable]
    private double r;  // fraction of rows sampled per tree, in (0, 1]
    [Storable]
    private double m;  // fraction of features considered per split, in (0, 1]
61
62
    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions):
      // start with an empty forest; the storable compatibility properties
      // fill in the raw alglib fields afterwards during deserialization
      randomForest = new alglib.decisionforest();
    }
    // cloning constructor: copies the alglib forest structure field by field,
    // sharing by reference wherever immutability makes that safe
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classValues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }
[10963]91
    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      // clone the training data so later changes to the caller's problem data
      // cannot affect the lazily recalculated model
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      // stays null for regression models (only the classification factory passes it)
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }
109
    /// <summary>Creates a deep clone of this model via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }
113
[10963]114    private void RecalculateModel() {
115      double rmsError, oobRmsError, relClassError, oobRelClassError;
116      var regressionProblemData = originalTrainingData as IRegressionProblemData;
117      var classificationProblemData = originalTrainingData as IClassificationProblemData;
118      if (regressionProblemData != null) {
119        var model = CreateRegressionModel(regressionProblemData,
120                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
121                                              out relClassError, out oobRelClassError);
122        randomForest = model.randomForest;
123      } else if (classificationProblemData != null) {
124        var model = CreateClassificationModel(classificationProblemData,
125                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
126                                              out relClassError, out oobRelClassError);
127        randomForest = model.randomForest;
128      }
129    }
130
[6240]131    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
[10963]132      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
133      AssertInputMatrix(inputData);
[6240]134
135      int n = inputData.GetLength(0);
136      int columns = inputData.GetLength(1);
137      double[] x = new double[columns];
138      double[] y = new double[1];
139
140      for (int row = 0; row < n; row++) {
141        for (int column = 0; column < columns; column++) {
142          x[column] = inputData[row, column];
143        }
[10963]144        alglib.dfprocess(RandomForest, x, ref y);
[6240]145        yield return y[0];
146      }
147    }
148
[6241]149    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
[10963]150      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
151      AssertInputMatrix(inputData);
[6241]152
153      int n = inputData.GetLength(0);
154      int columns = inputData.GetLength(1);
155      double[] x = new double[columns];
[10963]156      double[] y = new double[RandomForest.innerobj.nclasses];
[6241]157
158      for (int row = 0; row < n; row++) {
159        for (int column = 0; column < columns; column++) {
160          x[column] = inputData[row, column];
161        }
162        alglib.dfprocess(randomForest, x, ref y);
163        // find class for with the largest probability value
164        int maxProbClassIndex = 0;
165        double maxProb = y[0];
166        for (int i = 1; i < y.Length; i++) {
167          if (maxProb < y[i]) {
168            maxProb = y[i];
169            maxProbClassIndex = i;
170          }
171        }
172        yield return classValues[maxProbClassIndex];
173      }
174    }
175
    /// <summary>Wraps this model and a new RegressionProblemData built from the given problem data in a regression solution.</summary>
    public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this);
    }
    // explicit interface implementation; forwards to the strongly typed method above
    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
      return CreateRegressionSolution(problemData);
    }
    /// <summary>Wraps this model and a new ClassificationProblemData built from the given problem data in a classification solution.</summary>
    public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this);
    }
    // explicit interface implementation; forwards to the strongly typed method above
    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
      return CreateClassificationSolution(problemData);
    }
[6603]188
[10963]189    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
[11338]190      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
[11315]191      return CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError, problemData.TrainingIndices);
192    }
[10963]193
[11315]194    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
[11338]195      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError, IEnumerable<int> trainingIndices) {
[10963]196      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
197      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
198
199      alglib.dfreport rep;
200      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
201
202      rmsError = rep.rmserror;
203      avgRelError = rep.avgrelerror;
204      outOfBagAvgRelError = rep.oobavgrelerror;
205      outOfBagRmsError = rep.oobrmserror;
206
207      return new RandomForestModel(dForest,
208        seed, problemData,
209        nTrees, r, m);
[6240]210    }
211
    /// <summary>
    /// Creates a random forest classification model using the problem's training rows;
    /// forwards to the overload below with problemData.TrainingIndices.
    /// </summary>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError, problemData.TrainingIndices);
    }
[10963]216
[11338]217    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
218      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError, IEnumerable<int> trainingIndices) {
219
[10963]220      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
[11338]221      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
[10963]222
223      var classValues = problemData.ClassValues.ToArray();
224      int nClasses = classValues.Length;
225
226      // map original class values to values [0..nClasses-1]
227      var classIndices = new Dictionary<double, double>();
228      for (int i = 0; i < nClasses; i++) {
229        classIndices[classValues[i]] = i;
230      }
231
232      int nRows = inputMatrix.GetLength(0);
233      int nColumns = inputMatrix.GetLength(1);
234      for (int row = 0; row < nRows; row++) {
235        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
236      }
237
238      alglib.dfreport rep;
239      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
240
241      rmsError = rep.rmserror;
242      outOfBagRmsError = rep.oobrmserror;
243      relClassificationError = rep.relclserror;
244      outOfBagRelClassificationError = rep.oobrelclserror;
245
246      return new RandomForestModel(dForest,
247        seed, problemData,
248        nTrees, r, m, classValues);
249    }
250
251    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
252      AssertParameters(r, m);
253      AssertInputMatrix(inputMatrix);
254
255      int info = 0;
256      alglib.math.rndobject = new System.Random(seed);
257      var dForest = new alglib.decisionforest();
258      rep = new alglib.dfreport();
259      int nRows = inputMatrix.GetLength(0);
260      int nColumns = inputMatrix.GetLength(1);
261      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
262      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
263
264      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
265      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
266      return dForest;
267    }
268
269    private static void AssertParameters(double r, double m) {
270      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
271      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
272    }
273
274    private static void AssertInputMatrix(double[,] inputMatrix) {
[11338]275      if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
[10963]276        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
277    }
278
    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        // old files: return the persisted names; new scheme: derive them from the training data
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    // the following storable properties persist the raw alglib forest fields,
    // but only for models loaded from old files; newly created models persist
    // just the training configuration and rebuild the forest lazily on demand
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
345  }
346}
Note: See TracBrowser for help on using the repository browser.