
source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 14185

Last change on this file since 14185 was 14185, checked in by swagner, 8 years ago

#2526: Updated year of copyrights in license headers

File size: 15.2 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification
  /// </summary>
  [StorableClass]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    // not persisted
    private alglib.decisionforest randomForest;
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return originalTrainingData.AllowedInputVariables; }
    }


    // instead of storing the data of the model itself
    // we only store the data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;
    [Storable]
    private int nTrees;
    [Storable]
    private double r;
    [Storable]
    private double m;


    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions)
      randomForest = new alglib.decisionforest();
    }
    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classValues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

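    // Rebuild the forest from the stored seed, training data and parameters.
    // CreateRandomForestModel seeds the global ALGLIB RNG with the stored seed,
    // so the lazily recalculated forest matches the forest that was built originally.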
    private void RecalculateModel() {
      double rmsError, oobRmsError, relClassError, oobRelClassError;
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        var model = CreateRegressionModel(regressionProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      } else if (classificationProblemData != null) {
        var model = CreateClassificationModel(classificationProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      }
    }

    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        yield return y[0];
      }
    }

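    // Estimates the prediction variance for each row by querying every tree of the
    // forest individually (dfprocessraw yields the raw output of each tree) and
    // computing the population variance over these per-tree predictions.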
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] ys = new double[RandomForest.innerobj.ntrees]; // one raw output per tree

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dforest.dfprocessraw(RandomForest.innerobj, x, ref ys);
        yield return ys.VariancePop();
      }
    }

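    // For classification forests dfprocess returns one probability per class;
    // the predicted class is the one with the largest probability, mapped back to
    // the original class value via the stored classValues array.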
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[RandomForest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        yield return classValues[maxProbClassIndex];
      }
    }


    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

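    // r is the relative size of the training subsample drawn for each tree and m the
    // relative number of input variables considered when splitting; both are converted
    // to absolute counts in CreateRandomForestModel below.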
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    }

    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;
      outOfBagRmsError = rep.oobrmserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
    }

    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    }

    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      var classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;

      // map the original class values to consecutive indices [0..nClasses-1] in the last
      // column of the input matrix, as expected by the ALGLIB decision forest for classification
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
      }

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    }

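    // Builds the ALGLIB decision forest: the global ALGLIB RNG is seeded explicitly so
    // that construction is reproducible, r and m are converted to absolute row and
    // feature counts (at least 1 each), and dfbuildinternal is called with the
    // dfusestrongsplits and dfuseevs options.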
    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      AssertParameters(r, m);
      AssertInputMatrix(inputMatrix);

      int info = 0;
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }

    private static void AssertParameters(double r, double m) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    private static void AssertInputMatrix(double[,] inputMatrix) {
      if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}