Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 17399

Last change on this file since 17399 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 20.1 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
29using HEAL.Attic;
30using HeuristicLab.Common;
31using HeuristicLab.Core;
32using HeuristicLab.Data;
33using HeuristicLab.Parameters;
34using HeuristicLab.Problems.DataAnalysis;
35using HeuristicLab.Random;
36
37namespace HeuristicLab.Algorithms.DataAnalysis {
[Item("RFParameter", "A random forest parameter collection")]
[StorableType("40E482DA-63C5-4D39-97C7-63701CF1D021")]
public class RFParameter : ParameterCollection {
  // Keys under which the three random forest hyper-parameters live in this collection.
  // The public property names N, R and M must stay as they are: RandomForestUtil builds
  // setters for them by name via expression trees.
  private const string NParameterName = "N";
  private const string MParameterName = "M";
  private const string RParameterName = "R";

  /// <summary>
  /// Creates the collection with its default values: N = 50 trees, M = 0.1 and R = 0.1.
  /// </summary>
  public RFParameter() {
    var treeCount = new FixedValueParameter<IntValue>(NParameterName, "The number of random forest trees", new IntValue(50));
    var featureRatio = new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1));
    var rowRatio = new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1));
    base.Add(treeCount);
    base.Add(featureRatio);
    base.Add(rowRatio);
  }

  /// <summary>Deserialization constructor used by HEAL.Attic.</summary>
  [StorableConstructor]
  protected RFParameter(StorableConstructorFlag _) : base(_) {
  }

  /// <summary>
  /// Cloning constructor: clones the base collection, then copies the N, R and M values of <paramref name="original"/>.
  /// </summary>
  protected RFParameter(RFParameter original, Cloner cloner)
    : base(original, cloner) {
    this.N = original.N;
    this.R = original.R;
    this.M = original.M;
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new RFParameter(this, cloner);
  }

  /// <summary>The number of trees in the forest.</summary>
  public int N {
    get { return NParameter.Value.Value; }
    set { NParameter.Value.Value = value; }
  }

  /// <summary>The ratio of training rows used to build each tree (expected in (0, 1]).</summary>
  public double R {
    get { return RParameter.Value.Value; }
    set { RParameter.Value.Value = value; }
  }

  /// <summary>The ratio of features used to build each tree (expected in (0, 1]).</summary>
  public double M {
    get { return MParameter.Value.Value; }
    set { MParameter.Value.Value = value; }
  }

  // Typed accessors for the underlying fixed-value parameters; the casts throw if the
  // stored parameter has an unexpected type.
  private IFixedValueParameter<IntValue> NParameter {
    get { return (IFixedValueParameter<IntValue>)base[NParameterName]; }
  }

  private IFixedValueParameter<DoubleValue> RParameter {
    get { return (IFixedValueParameter<DoubleValue>)base[RParameterName]; }
  }

  private IFixedValueParameter<DoubleValue> MParameter {
    get { return (IFixedValueParameter<DoubleValue>)base[MParameterName]; }
  }
}
89
/// <summary>
/// Helpers for building alglib random forest models and for tuning the random forest
/// hyper-parameters (N, R, M) by grid search, either on the out-of-bag error or by k-fold cross-validation.
/// </summary>
public static class RandomForestUtil {
  /// <summary>
  /// Checks that the ratio parameters of a random forest lie in the interval (0, 1].
  /// </summary>
  /// <param name="r">The ratio of the training rows used to build each tree.</param>
  /// <param name="m">The ratio of the input features used to build each tree.</param>
  /// <exception cref="ArgumentException">If <paramref name="r"/> or <paramref name="m"/> is outside (0, 1] or NaN.</exception>
  public static void AssertParameters(double r, double m) {
    // Written as !(0 < x <= 1) so that NaN is rejected too; "x <= 0 || x > 1" is false for NaN.
    if (!(r > 0 && r <= 1)) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
    if (!(m > 0 && m <= 1)) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
  }

  /// <summary>
  /// Rejects input matrices containing NaN or infinite values, which random forest modeling does not support.
  /// </summary>
  /// <param name="inputMatrix">The matrix passed to the random forest training.</param>
  /// <exception cref="ArgumentNullException">If <paramref name="inputMatrix"/> is null.</exception>
  /// <exception cref="NotSupportedException">If the matrix contains NaN or infinity.</exception>
  public static void AssertInputMatrix(double[,] inputMatrix) {
    if (inputMatrix == null) throw new ArgumentNullException("inputMatrix");
    if (inputMatrix.ContainsNanOrInfinity())
      throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
  }

  /// <summary>
  /// Trains an alglib decision forest on <paramref name="inputMatrix"/>.
  /// </summary>
  /// <param name="seed">Seed for the random number generator used by alglib.</param>
  /// <param name="inputMatrix">
  /// Training matrix; alglib is told that it has (columns - 1) variables, i.e. the last column is the target
  /// (alglib XY-matrix layout).
  /// </param>
  /// <param name="nTrees">The number of trees to build.</param>
  /// <param name="r">Ratio of rows sampled per tree, in (0, 1].</param>
  /// <param name="m">Ratio of features passed to alglib, in (0, 1].</param>
  /// <param name="nClasses">Number of classes forwarded to alglib (presumably 1 for regression — see alglib docs).</param>
  /// <param name="rep">Receives the alglib training report.</param>
  /// <returns>The trained forest.</returns>
  /// <exception cref="ArgumentException">If the parameters are invalid or alglib reports an error.</exception>
  /// <exception cref="NotSupportedException">If the matrix contains NaN or infinity.</exception>
  internal static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
    RandomForestUtil.AssertParameters(r, m);
    RandomForestUtil.AssertInputMatrix(inputMatrix);

    int info = 0;
    // NOTE(review): alglib.math.rndobject is a static field, i.e. one generator shared by the whole process.
    // When models are built concurrently (e.g. GridSearch with maxDegreeOfParallelism > 1) threads overwrite and
    // share it, so the seed alone does not make results reproducible in that case.
    alglib.math.rndobject = new System.Random(seed);
    var dForest = new alglib.decisionforest();
    rep = new alglib.dfreport();
    int nRows = inputMatrix.GetLength(0);
    int nColumns = inputMatrix.GetLength(1);
    // at least one sampled row and one feature, even for tiny ratios
    int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

    alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    // alglib signals success with info == 1
    if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
    return dForest;
  }

  /// <summary>
  /// Trains one regression forest per partition (Item1 = training rows, Item2 = test rows) and returns the
  /// mean squared error on the test rows, averaged over all partitions. A partition whose MSE cannot be
  /// calculated contributes NaN, which makes the average NaN.
  /// </summary>
  private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    avgTestMse = 0;
    var ds = problemData.Dataset;
    var targetVariable = GetTargetVariableName(problemData);
    foreach (var tuple in partitions) {
      // only the test error is used; the training statistics returned by the model builder are discarded
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      var trainingRows = tuple.Item1;
      var testRows = tuple.Item2;
      var model = RandomForestModel.CreateRegressionModel(problemData, trainingRows, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
      var estimatedValues = model.GetEstimatedValues(ds, testRows);
      var targetValues = ds.GetDoubleValues(targetVariable, testRows);
      OnlineCalculatorError calculatorError;
      double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
      if (calculatorError != OnlineCalculatorError.None)
        mse = double.NaN;
      avgTestMse += mse;
    }
    avgTestMse /= partitions.Length;
  }

  /// <summary>
  /// Trains one classification forest per partition (Item1 = training rows, Item2 = test rows) and returns the
  /// classification accuracy on the test rows, averaged over all partitions. A partition whose accuracy cannot
  /// be calculated contributes NaN, which makes the average NaN.
  /// </summary>
  private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    avgTestAccuracy = 0;
    var ds = problemData.Dataset;
    var targetVariable = GetTargetVariableName(problemData);
    foreach (var tuple in partitions) {
      // only the test accuracy is used; the training statistics returned by the model builder are discarded
      double rmsError, outOfBagRmsError, relError, outOfBagRelError;
      var trainingRows = tuple.Item1;
      var testRows = tuple.Item2;
      var model = RandomForestModel.CreateClassificationModel(problemData, trainingRows, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relError, out outOfBagRelError);
      var estimatedValues = model.GetEstimatedClassValues(ds, testRows);
      var targetValues = ds.GetDoubleValues(targetVariable, testRows);
      OnlineCalculatorError calculatorError;
      double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
      if (calculatorError != OnlineCalculatorError.None)
        accuracy = double.NaN;
      avgTestAccuracy += accuracy;
    }
    avgTestAccuracy /= partitions.Length;
  }

  /// <summary>
  /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased).
  /// The combination with the smallest out-of-bag RMS error on the training rows wins.
  /// </summary>
  /// <param name="problemData">The regression problem data</param>
  /// <param name="parameterRanges">The ranges for each parameter in the grid search (keys must be "N", "R" or "M")</param>
  /// <param name="seed">The random seed (required by the random forest model)</param>
  /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
  /// <returns>A copy of the best parameters found, or default parameters if no combination produced a smaller error.</returns>
  public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();
    double bestOutOfBagRmsError = double.MaxValue;
    RFParameter bestParameters = new RFParameter();

    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = CreateParameters(setters, parameterCombination);
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

      lock (locker) {
        if (bestOutOfBagRmsError > outOfBagRmsError) {
          bestOutOfBagRmsError = outOfBagRmsError;
          bestParameters = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameters;
  }

  /// <summary>
  /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased).
  /// The combination with the smallest second out-value of the model builder (treated as out-of-bag RMS error —
  /// NOTE(review): confirm against RandomForestModel.CreateClassificationModel) wins.
  /// </summary>
  /// <param name="problemData">The classification problem data</param>
  /// <param name="parameterRanges">The ranges for each parameter in the grid search (keys must be "N", "R" or "M")</param>
  /// <param name="seed">The random seed (required by the random forest model)</param>
  /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
  /// <returns>A copy of the best parameters found, or default parameters if no combination produced a smaller error.</returns>
  public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();

    double bestOutOfBagRmsError = double.MaxValue;
    RFParameter bestParameters = new RFParameter();

    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = CreateParameters(setters, parameterCombination);
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
                                                  out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

      lock (locker) {
        if (bestOutOfBagRmsError > outOfBagRmsError) {
          bestOutOfBagRmsError = outOfBagRmsError;
          bestParameters = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameters;
  }

  /// <summary>
  /// Grid search with crossvalidation. The combination with the smallest average test MSE wins.
  /// </summary>
  /// <param name="problemData">The regression problem data</param>
  /// <param name="numberOfFolds">The number of folds for crossvalidation (at least 2)</param>
  /// <param name="shuffleFolds">Specifies whether the training rows are shuffled before being split into folds</param>
  /// <param name="parameterRanges">The ranges for each parameter in the grid search (keys must be "N", "R" or "M")</param>
  /// <param name="seed">The random seed (used for shuffling the folds and by the random forest model)</param>
  /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
  /// <returns>The best parameter values found by the grid search</returns>
  /// <exception cref="ArgumentOutOfRangeException">If <paramref name="numberOfFolds"/> is less than 2.</exception>
  public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    double bestTestMse = double.MaxValue;
    RFParameter bestParameter = new RFParameter();

    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds, seed);
    var crossProduct = parameterRanges.Values.CartesianProduct();

    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = CreateParameters(setters, parameterCombination);
      double testMse;
      CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMse);

      lock (locker) {
        // a NaN average (failed fold) never compares smaller and is therefore never selected
        if (testMse < bestTestMse) {
          bestTestMse = testMse;
          bestParameter = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameter;
  }

  /// <summary>
  /// Grid search with crossvalidation. The combination with the largest average test accuracy wins.
  /// </summary>
  /// <param name="problemData">The classification problem data</param>
  /// <param name="numberOfFolds">The number of folds for crossvalidation (at least 2)</param>
  /// <param name="shuffleFolds">Specifies whether stratified, shuffled folds are used instead of consecutive folds</param>
  /// <param name="parameterRanges">The ranges for each parameter in the grid search (keys must be "N", "R" or "M")</param>
  /// <param name="seed">The random seed (used for shuffling the folds and by the random forest model)</param>
  /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
  /// <returns>The best parameter values found by the grid search</returns>
  /// <exception cref="ArgumentOutOfRangeException">If <paramref name="numberOfFolds"/> is less than 2.</exception>
  public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    double bestAccuracy = 0;
    RFParameter bestParameter = new RFParameter();

    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();
    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds, seed);

    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = CreateParameters(setters, parameterCombination);
      double testAccuracy;
      CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);

      lock (locker) {
        // a NaN average (failed fold) never compares larger and is therefore never selected
        if (testAccuracy > bestAccuracy) {
          bestAccuracy = testAccuracy;
          bestParameter = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameter;
  }

  /// <summary>
  /// Creates a fresh parameter set and applies the i-th setter to the i-th value of the combination.
  /// </summary>
  private static RFParameter CreateParameters(IList<Action<RFParameter, double>> setters, IEnumerable<double> parameterCombination) {
    var parameterValues = parameterCombination.ToList();
    var parameters = new RFParameter();
    for (int i = 0; i < setters.Count; ++i) {
      setters[i](parameters, parameterValues[i]);
    }
    return parameters;
  }

  /// <summary>
  /// Builds the cross-validation partitions: partition i uses fold i as test rows and all other folds as training rows.
  /// All rows are materialized so that every parameter combination (possibly evaluated in parallel) sees exactly
  /// the same partitions and no shared random generator is touched during the search.
  /// </summary>
  private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds, int seed) {
    if (numberOfFolds < 2) throw new ArgumentOutOfRangeException("numberOfFolds", "Cross-validation requires at least two folds.");
    var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds, seed).Select(fold => fold.ToArray()).ToArray();
    var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[folds.Length];

    for (int i = 0; i < folds.Length; ++i) {
      int testFold = i; // avoid "access to modified closure" warning
      int[] trainingRows = folds.Where((fold, j) => j != testFold).SelectMany(fold => fold).ToArray();
      partitions[i] = Tuple.Create<IEnumerable<int>, IEnumerable<int>>(trainingRows, folds[i]);
    }
    return partitions;
  }

  /// <summary>
  /// Splits the training rows of <paramref name="problemData"/> into <paramref name="numberOfFolds"/> folds.
  /// The shuffling generator is seeded from the system tick count, so shuffled folds are not reproducible;
  /// use the overload taking a seed for reproducible folds.
  /// </summary>
  /// <param name="problemData">Regression or classification problem data.</param>
  /// <param name="numberOfFolds">The number of folds (at least 1).</param>
  /// <param name="shuffleFolds">
  /// Regression: shuffle the rows before splitting. Classification: build stratified, shuffled folds
  /// (some folds may be empty).
  /// </param>
  /// <returns>Exactly <paramref name="numberOfFolds"/> materialized folds of row indices.</returns>
  /// <exception cref="ArgumentOutOfRangeException">If <paramref name="numberOfFolds"/> is less than 1.</exception>
  /// <exception cref="ArgumentException">If the problem data is neither regression nor classification data.</exception>
  public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
    return GenerateFoldsCore(problemData, numberOfFolds, shuffleFolds, new MersenneTwister((uint)Environment.TickCount));
  }

  /// <summary>
  /// Same as <see cref="GenerateFolds(IDataAnalysisProblemData, int, bool)"/> but shuffles with a generator
  /// seeded by <paramref name="seed"/>, so the folds are reproducible.
  /// </summary>
  public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds, int seed) {
    return GenerateFoldsCore(problemData, numberOfFolds, shuffleFolds, new MersenneTwister((uint)seed));
  }

  // Shared implementation of the public GenerateFolds overloads.
  private static IEnumerable<IEnumerable<int>> GenerateFoldsCore(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds, IRandom random) {
    if (numberOfFolds < 1) throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be at least 1.");
    if (problemData is IRegressionProblemData) {
      // Materialize the (optionally shuffled) indices exactly once: a lazy OrderBy(random) would be re-evaluated,
      // and thereby re-shuffled, every time a fold is enumerated, producing overlapping folds.
      var trainingIndices = problemData.TrainingIndices.ToArray();
      if (shuffleFolds) trainingIndices = trainingIndices.OrderBy(x => random.Next()).ToArray();
      return SplitIntoFolds(trainingIndices, numberOfFolds);
    }
    var classificationProblemData = problemData as IClassificationProblemData;
    if (classificationProblemData != null) {
      // when shuffle is enabled do stratified folds generation, some folds may have zero elements
      // otherwise, generate consecutive folds
      return shuffleFolds
        ? GenerateFoldsStratified(classificationProblemData, numberOfFolds, random)
        : SplitIntoFolds(problemData.TrainingIndices, numberOfFolds);
    }
    throw new ArgumentException("Problem data is neither regression or classification problem data.");
  }

  /// <summary>
  /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
  /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
  /// the corresponding parts from each class label, shuffled with <paramref name="random"/>.
  /// </summary>
  /// <param name="problemData">The classification problem data.</param>
  /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
  /// <param name="random">The random generator used to shuffle the folds.</param>
  /// <returns>Exactly <paramref name="numberOfFolds"/> folds, each a list of row indices (empty when there are no training rows).</returns>
  private static List<List<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
    var trainingIndices = problemData.TrainingIndices.ToList();
    var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingIndices);
    var foldsByClass = trainingIndices.Zip(values, (i, v) => new { Index = i, Value = v })
                                      .GroupBy(x => x.Value, x => x.Index)
                                      .Select(g => SplitIntoFolds(g, numberOfFolds))
                                      .ToList();
    // Iterate over the fold index explicitly; this terminates even when there are no classes at all.
    var folds = new List<List<int>>(numberOfFolds);
    for (int i = 0; i < numberOfFolds; ++i) {
      int foldIndex = i; // avoid "access to modified closure" warning
      folds.Add(foldsByClass.SelectMany(classFolds => classFolds[foldIndex]).OrderBy(x => random.Next()).ToList());
    }
    return folds;
  }

  /// <summary>
  /// Splits <paramref name="values"/> into <paramref name="numberOfFolds"/> consecutive folds whose sizes differ by at
  /// most one; the first (count % numberOfFolds) folds receive one extra element. If there are fewer values than folds,
  /// the trailing folds are empty. Every fold is materialized.
  /// </summary>
  private static List<List<T>> SplitIntoFolds<T>(IEnumerable<T> values, int numberOfFolds) {
    var items = values.ToList();
    int baseSize = items.Count / numberOfFolds, remainder = items.Count % numberOfFolds;
    var folds = new List<List<T>>(numberOfFolds);
    int start = 0;
    for (int i = 0; i < numberOfFolds; ++i) {
      int size = baseSize + (i < remainder ? 1 : 0);
      folds.Add(items.GetRange(start, size));
      start += size;
    }
    return folds;
  }

  /// <summary>
  /// Compiles a setter that assigns a double to the RFParameter property named <paramref name="field"/>,
  /// converting the value to the property's type (e.g. int for N).
  /// </summary>
  /// <exception cref="ArgumentException">If RFParameter has no property with that name.</exception>
  private static Action<RFParameter, double> GenerateSetter(string field) {
    var targetExp = Expression.Parameter(typeof(RFParameter));
    var valueExp = Expression.Parameter(typeof(double));
    var fieldExp = Expression.Property(targetExp, field);
    var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
    var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
    return setter;
  }

  /// <summary>
  /// Returns the target variable name of regression or classification problem data.
  /// </summary>
  /// <exception cref="ArgumentException">If the problem data is neither regression nor classification data.</exception>
  private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
    var regressionProblemData = problemData as IRegressionProblemData;
    var classificationProblemData = problemData as IClassificationProblemData;

    if (regressionProblemData != null)
      return regressionProblemData.TargetVariable;
    if (classificationProblemData != null)
      return classificationProblemData.TargetVariable;

    throw new ArgumentException("Problem data is neither regression or classification problem data.");
  }
}
386}
Note: See TracBrowser for help on using the repository browser.