source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs @ 14230

Last change on this file since 14230 was 14230, checked in by gkronber, 3 years ago

#2631: minor change

File size: 15.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Represents a random forest model for regression and classification.
  /// </summary>
  /// <remarks>
  /// The alglib forest itself is not persisted (except for models loaded from old files, see the
  /// backwards-compatibility region at the end of the class). Instead the seed, a clone of the
  /// training data and the algorithm parameters are stored, and the forest is rebuilt lazily
  /// from them the first time it is needed.
  /// </remarks>
  [StorableClass]
  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
    // not persisted; an empty forest is rebuilt on demand by the RandomForest property
    private alglib.decisionforest randomForest;

    /// <summary>
    /// Gets the alglib forest. When the forest contains no trees (e.g. right after deserialization)
    /// it is first rebuilt from the stored training data and parameters.
    /// </summary>
    private alglib.decisionforest RandomForest {
      get {
        // recalculate lazily
        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
        return randomForest;
      }
    }

    /// <summary>
    /// Gets the names of the input variables used by the model.
    /// </summary>
    /// <remarks>
    /// Delegates to <see cref="AllowedInputVariables"/> so that models loaded from old files,
    /// which have no original training data, do not throw a <see cref="NullReferenceException"/>.
    /// </remarks>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return AllowedInputVariables; }
    }


    // instead of storing the data of the model itself
    // we only store the data necessary to recalculate the same model lazily on demand
    [Storable]
    private int seed;
    [Storable]
    private IDataAnalysisProblemData originalTrainingData;
    [Storable]
    private double[] classValues;
    [Storable]
    private int nTrees;
    [Storable]
    private double r;
    [Storable]
    private double m;


    [StorableConstructor]
    private RandomForestModel(bool deserializing)
      : base(deserializing) {
      // for backwards compatibility (loading old solutions): the compatibility properties below
      // write directly into this object during deserialization; if it ends up without trees
      // it is rebuilt lazily by RandomForest
      randomForest = new alglib.decisionforest();
    }

    private RandomForestModel(RandomForestModel original, Cloner cloner)
      : base(original, cloner) {
      // copy the forest state of the original instead of rebuilding the forest
      randomForest = new alglib.decisionforest();
      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
      // we assume that the trees array (double[]) is immutable in alglib
      randomForest.innerobj.trees = original.randomForest.innerobj.trees;

      // allowedInputVariables is immutable so we don't need to clone
      allowedInputVariables = original.allowedInputVariables;

      // clone data which is necessary to rebuild the model
      this.seed = original.seed;
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      // classValues is immutable so we don't need to clone
      this.classValues = original.classValues;
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;
    }

    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    private RandomForestModel(string targetVariable, alglib.decisionforest randomForest,
      int seed, IDataAnalysisProblemData originalTrainingData,
      int nTrees, double r, double m, double[] classValues = null)
      : base(targetVariable) {
      this.name = ItemName;
      this.description = ItemDescription;
      // the model itself
      this.randomForest = randomForest;
      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModel(this, cloner);
    }

    /// <summary>
    /// Rebuilds the alglib forest from the stored training data, seed and parameters
    /// by calling the matching static factory method and taking over its forest.
    /// </summary>
    /// <remarks>
    /// Leaves the forest unchanged when the stored training data is neither regression nor
    /// classification problem data (e.g. null for models loaded from old files).
    /// </remarks>
    private void RecalculateModel() {
      // the factory methods require these out parameters; the error values are not needed here
      double rmsError, oobRmsError, relClassError, oobRelClassError;
      var regressionProblemData = originalTrainingData as IRegressionProblemData;
      var classificationProblemData = originalTrainingData as IClassificationProblemData;
      if (regressionProblemData != null) {
        var model = CreateRegressionModel(regressionProblemData,
                                          nTrees, r, m, seed, out rmsError, out oobRmsError,
                                          out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      } else if (classificationProblemData != null) {
        var model = CreateClassificationModel(classificationProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
        randomForest = model.randomForest;
      }
    }

    /// <summary>
    /// Returns, for each requested row, the first output of <c>alglib.dfprocess</c> for that row
    /// (the forest's estimated value for regression models).
    /// </summary>
    /// <param name="dataset">Dataset containing the input variables.</param>
    /// <param name="rows">Rows to evaluate.</param>
    /// <exception cref="NotSupportedException">The inputs contain NaN or infinity values.</exception>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[1];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(RandomForest, x, ref y);
        yield return y[0];
      }
    }

    /// <summary>
    /// Returns, for each requested row, the population variance of the raw outputs of all trees.
    /// </summary>
    /// <param name="dataset">Dataset containing the input variables.</param>
    /// <param name="rows">Rows to evaluate.</param>
    /// <exception cref="NotSupportedException">The inputs contain NaN or infinity values.</exception>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      var forest = RandomForest;
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      // one raw output per tree
      double[] ys = new double[forest.innerobj.ntrees];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dforest.dfprocessraw(forest.innerobj, x, ref ys);
        yield return ys.VariancePop();
      }
    }

    /// <summary>
    /// Returns, for each requested row, the original class value whose forest output is largest
    /// (ties resolve to the lowest class index).
    /// </summary>
    /// <param name="dataset">Dataset containing the input variables.</param>
    /// <param name="rows">Rows to evaluate.</param>
    /// <exception cref="NotSupportedException">The inputs contain NaN or infinity values.</exception>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
      AssertInputMatrix(inputData);

      // read the (lazily rebuilt) forest once and use it consistently below
      var forest = RandomForest;
      int n = inputData.GetLength(0);
      int columns = inputData.GetLength(1);
      double[] x = new double[columns];
      double[] y = new double[forest.innerobj.nclasses];

      for (int row = 0; row < n; row++) {
        for (int column = 0; column < columns; column++) {
          x[column] = inputData[row, column];
        }
        alglib.dfprocess(forest, x, ref y);
        // find the class with the largest probability value
        int maxProbClassIndex = 0;
        double maxProb = y[0];
        for (int i = 1; i < y.Length; i++) {
          if (maxProb < y[i]) {
            maxProb = y[i];
            maxProbClassIndex = i;
          }
        }
        // map the internal class index back to the original class value
        yield return classValues[maxProbClassIndex];
      }
    }


    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
    }
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
    }

    /// <summary>
    /// Trains a regression random forest on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Problem data providing inputs, target and training indices.</param>
    /// <param name="nTrees">Number of trees.</param>
    /// <param name="r">Fraction of the training rows used per tree, in (0, 1].</param>
    /// <param name="m">Fraction of the input variables considered per split, in (0, 1].</param>
    /// <param name="seed">Seed for alglib's random number generator.</param>
    /// <param name="rmsError">Root mean squared error on the training rows.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean squared error.</param>
    /// <param name="avgRelError">Average relative error on the training rows.</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error.</param>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      // forward the out parameters in the order of the target overload's signature
      return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed,
        out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    }

    /// <summary>
    /// Trains a regression random forest on the given rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Problem data providing inputs and target.</param>
    /// <param name="trainingIndices">Rows used for training.</param>
    /// <param name="nTrees">Number of trees.</param>
    /// <param name="r">Fraction of the training rows used per tree, in (0, 1].</param>
    /// <param name="m">Fraction of the input variables considered per split, in (0, 1].</param>
    /// <param name="seed">Seed for alglib's random number generator.</param>
    /// <param name="rmsError">Root mean squared error on the training rows.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean squared error.</param>
    /// <param name="avgRelError">Average relative error on the training rows.</param>
    /// <param name="outOfBagAvgRelError">Out-of-bag average relative error.</param>
    /// <exception cref="ArgumentException">r or m is out of range, or alglib reports an error.</exception>
    /// <exception cref="NotSupportedException">The training data contains NaN or infinity values.</exception>
    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
      // the target variable is appended as the last column, as expected by alglib
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);

      rmsError = rep.rmserror;
      avgRelError = rep.avgrelerror;
      outOfBagAvgRelError = rep.oobavgrelerror;
      outOfBagRmsError = rep.oobrmserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m);
    }

    /// <summary>
    /// Trains a classification random forest on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Problem data providing inputs, target, class values and training indices.</param>
    /// <param name="nTrees">Number of trees.</param>
    /// <param name="r">Fraction of the training rows used per tree, in (0, 1].</param>
    /// <param name="m">Fraction of the input variables considered per split, in (0, 1].</param>
    /// <param name="seed">Seed for alglib's random number generator.</param>
    /// <param name="rmsError">Root mean squared error on the training rows.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean squared error.</param>
    /// <param name="relClassificationError">Relative classification error on the training rows.</param>
    /// <param name="outOfBagRelClassificationError">Out-of-bag relative classification error.</param>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
      return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    }

    /// <summary>
    /// Trains a classification random forest on the given rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Problem data providing inputs, target and class values.</param>
    /// <param name="trainingIndices">Rows used for training.</param>
    /// <param name="nTrees">Number of trees.</param>
    /// <param name="r">Fraction of the training rows used per tree, in (0, 1].</param>
    /// <param name="m">Fraction of the input variables considered per split, in (0, 1].</param>
    /// <param name="seed">Seed for alglib's random number generator.</param>
    /// <param name="rmsError">Root mean squared error on the training rows.</param>
    /// <param name="outOfBagRmsError">Out-of-bag root mean squared error.</param>
    /// <param name="relClassificationError">Relative classification error on the training rows.</param>
    /// <param name="outOfBagRelClassificationError">Out-of-bag relative classification error.</param>
    /// <exception cref="ArgumentException">r or m is out of range, or alglib reports an error.</exception>
    /// <exception cref="NotSupportedException">The training data contains NaN or infinity values.</exception>
    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed,
      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

      // the target variable is appended as the last column, as expected by alglib
      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

      var classValues = problemData.ClassValues.ToArray();
      int nClasses = classValues.Length;

      // map original class values to values [0..nClasses-1]
      var classIndices = new Dictionary<double, double>();
      for (int i = 0; i < nClasses; i++) {
        classIndices[classValues[i]] = i;
      }

      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      for (int row = 0; row < nRows; row++) {
        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
      }

      alglib.dfreport rep;
      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

      rmsError = rep.rmserror;
      outOfBagRmsError = rep.oobrmserror;
      relClassificationError = rep.relclserror;
      outOfBagRelClassificationError = rep.oobrelclserror;

      return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
    }

    /// <summary>
    /// Builds an alglib decision forest from a matrix whose last column holds the target.
    /// </summary>
    /// <param name="seed">Seed assigned to alglib's random number generator before training.</param>
    /// <param name="inputMatrix">Training matrix; the last column is the target (class index for classification).</param>
    /// <param name="nTrees">Number of trees.</param>
    /// <param name="r">Fraction of rows per tree; the sample size is at least 1.</param>
    /// <param name="m">Fraction of input variables per split; the feature count is at least 1.</param>
    /// <param name="nClasses">1 for regression, the number of classes for classification.</param>
    /// <param name="rep">alglib training report.</param>
    /// <exception cref="ArgumentException">r or m is out of range, or alglib reports an error.</exception>
    /// <exception cref="NotSupportedException">The matrix contains NaN or infinity values.</exception>
    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      AssertParameters(r, m);
      AssertInputMatrix(inputMatrix);

      int info = 0;
      // NOTE(review): this assigns a static alglib field, so concurrent training may interfere — confirm thread usage
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }

    private static void AssertParameters(double r, double m) {
      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    }

    private static void AssertInputMatrix(double[,] inputMatrix) {
      if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    }

    #region persistence for backwards compatibility
    // when the originalTrainingData is null this means the model was loaded from an old file
    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
    // in such cases we still store the complete model
    // for new models the properties below store empty placeholders (0 / empty array); on loading,
    // the empty trees array makes RandomForest rebuild the forest from the stored training data
    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }

    private string[] allowedInputVariables;
    [Storable(Name = "allowedInputVariables")]
    private string[] AllowedInputVariables {
      get {
        if (IsCompatibilityLoaded) return allowedInputVariables;
        else return originalTrainingData.AllowedInputVariables.ToArray();
      }
      set { allowedInputVariables = value; }
    }
    [Storable]
    private int RandomForestBufSize {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
        else return 0;
      }
      set {
        randomForest.innerobj.bufsize = value;
      }
    }
    [Storable]
    private int RandomForestNClasses {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
        else return 0;
      }
      set {
        randomForest.innerobj.nclasses = value;
      }
    }
    [Storable]
    private int RandomForestNTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
        else return 0;
      }
      set {
        randomForest.innerobj.ntrees = value;
      }
    }
    [Storable]
    private int RandomForestNVars {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
        else return 0;
      }
      set {
        randomForest.innerobj.nvars = value;
      }
    }
    [Storable]
    private double[] RandomForestTrees {
      get {
        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
        else return new double[] { };
      }
      set {
        randomForest.innerobj.trees = value;
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.