source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 17942

Last change on this file since 17942 was 17942, checked in by gkronber, 9 months ago

#3117: fixed order of parameters in grid search method for RF and removed unused shuffleFolds parameter

File size: 21.5 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24extern alias alglib_3_7;
25
26using System;
27using System.Collections.Generic;
28using System.Linq;
29using System.Linq.Expressions;
30using System.Threading.Tasks;
31using HEAL.Attic;
32using HeuristicLab.Common;
33using HeuristicLab.Core;
34using HeuristicLab.Data;
35using HeuristicLab.Parameters;
36using HeuristicLab.Problems.DataAnalysis;
37using HeuristicLab.Random;
38
39namespace HeuristicLab.Algorithms.DataAnalysis {
40  [Item("RFParameter", "A random forest parameter collection")]
41  [StorableType("40E482DA-63C5-4D39-97C7-63701CF1D021")]
42  public class RFParameter : ParameterCollection {
43    public RFParameter() {
44      base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50)));
45      base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1)));
46      base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1)));
47    }
48
49    [StorableConstructor]
50    protected RFParameter(StorableConstructorFlag _) : base(_) {
51    }
52
53    protected RFParameter(RFParameter original, Cloner cloner)
54      : base(original, cloner) {
55      this.N = original.N;
56      this.R = original.R;
57      this.M = original.M;
58    }
59
60    public override IDeepCloneable Clone(Cloner cloner) {
61      return new RFParameter(this, cloner);
62    }
63
64    private IFixedValueParameter<IntValue> NParameter {
65      get { return (IFixedValueParameter<IntValue>)base["N"]; }
66    }
67
68    private IFixedValueParameter<DoubleValue> RParameter {
69      get { return (IFixedValueParameter<DoubleValue>)base["R"]; }
70    }
71
72    private IFixedValueParameter<DoubleValue> MParameter {
73      get { return (IFixedValueParameter<DoubleValue>)base["M"]; }
74    }
75
76    public int N {
77      get { return NParameter.Value.Value; }
78      set { NParameter.Value.Value = value; }
79    }
80
81    public double R {
82      get { return RParameter.Value.Value; }
83      set { RParameter.Value.Value = value; }
84    }
85
86    public double M {
87      get { return MParameter.Value.Value; }
88      set { MParameter.Value.Value = value; }
89    }
90  }
91
92  public static class RandomForestUtil {
93    public static void AssertParameters(double r, double m) {
94      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
95      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
96    }
97
98    public static void AssertInputMatrix(double[,] inputMatrix) {
99      if (inputMatrix.ContainsNanOrInfinity())
100        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
101    }
102
103    internal static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
104      RandomForestUtil.AssertParameters(r, m);
105      RandomForestUtil.AssertInputMatrix(inputMatrix);
106
107      int nRows = inputMatrix.GetLength(0);
108      int nColumns = inputMatrix.GetLength(1);
109
110      alglib.dfbuildercreate(out var dfbuilder);
111      alglib.dfbuildersetdataset(dfbuilder, inputMatrix, nRows, nColumns - 1, nClasses);
112      alglib.dfbuildersetimportancenone(dfbuilder); // do not calculate importance (TODO add this feature)
113      alglib.dfbuildersetrdfalgo(dfbuilder, 0); // only one algorithm supported in version 3.17
114      alglib.dfbuildersetrdfsplitstrength(dfbuilder, 2); // 0 = split at the random position, fastest one
115                                                         // 1 = split at the middle of the range
116                                                         // 2 = strong split at the best point of the range (default)
117      alglib.dfbuildersetrndvarsratio(dfbuilder, m);
118      alglib.dfbuildersetsubsampleratio(dfbuilder, r);
119      alglib.dfbuildersetseed(dfbuilder, seed);
120      alglib.dfbuilderbuildrandomforest(dfbuilder, nTrees, out var dForest, out rep);
121      return dForest;
122    }
123    internal static alglib_3_7.alglib.decisionforest CreateRandomForestModelAlglib_3_7(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib_3_7.alglib.dfreport rep) {
124      RandomForestUtil.AssertParameters(r, m);
125      RandomForestUtil.AssertInputMatrix(inputMatrix);
126
127      int info = 0;
128      alglib_3_7.alglib.math.rndobject = new System.Random(seed);
129      var dForest = new alglib_3_7.alglib.decisionforest();
130      rep = new alglib_3_7.alglib.dfreport();
131      int nRows = inputMatrix.GetLength(0);
132      int nColumns = inputMatrix.GetLength(1);
133      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
134      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
135
136      alglib_3_7.alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib_3_7.alglib.dforest.dfusestrongsplits + alglib_3_7.alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
137      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
138      return dForest;
139    }
140
141
142    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
143      avgTestMse = 0;
144      var ds = problemData.Dataset;
145      var targetVariable = GetTargetVariableName(problemData);
146      foreach (var tuple in partitions) {
147        var trainingRandomForestPartition = tuple.Item1;
148        var testRandomForestPartition = tuple.Item2;
149        var model = RandomForestRegression.CreateRandomForestRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed,
150                                                                             out var rmsError, out var avgRelError, out var outOfBagRmsError, out var outOfBagAvgRelError);
151        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
152        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
153        OnlineCalculatorError calculatorError;
154        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
155        if (calculatorError != OnlineCalculatorError.None)
156          mse = double.NaN;
157        avgTestMse += mse;
158      }
159      avgTestMse /= partitions.Length;
160    }
161
162    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
163      avgTestAccuracy = 0;
164      var ds = problemData.Dataset;
165      var targetVariable = GetTargetVariableName(problemData);
166      foreach (var tuple in partitions) {
167        var trainingRandomForestPartition = tuple.Item1;
168        var testRandomForestPartition = tuple.Item2;
169        var model = RandomForestClassification.CreateRandomForestClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed,
170                                                                                     out var rmsError, out var avgRelError, out var outOfBagRmsError, out var outOfBagAvgRelError);
171        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
172        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
173        OnlineCalculatorError calculatorError;
174        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
175        if (calculatorError != OnlineCalculatorError.None)
176          accuracy = double.NaN;
177        avgTestAccuracy += accuracy;
178      }
179      avgTestAccuracy /= partitions.Length;
180    }
181
182    /// <summary>
183    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
184    /// </summary>
185    /// <param name="problemData">The regression problem data</param>
186    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
187    /// <param name="seed">The random seed (required by the random forest model)</param>
188    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
189    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
190      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
191      var crossProduct = parameterRanges.Values.CartesianProduct();
192      double bestOutOfBagRmsError = double.MaxValue;
193      RFParameter bestParameters = new RFParameter();
194
195      var locker = new object();
196      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
197        var parameterValues = parameterCombination.ToList();
198        var parameters = new RFParameter();
199        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
200        RandomForestRegression.CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
201                                                                 out var rmsError, out var avgRelError, out var outOfBagRmsError, out var outOfBagAvgRelError);
202
203        lock (locker) {
204          if (bestOutOfBagRmsError > outOfBagRmsError) {
205            bestOutOfBagRmsError = outOfBagRmsError;
206            bestParameters = (RFParameter)parameters.Clone();
207          }
208        }
209      });
210      return bestParameters;
211    }
212
213    /// <summary>
214    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
215    /// </summary>
216    /// <param name="problemData">The classification problem data</param>
217    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
218    /// <param name="seed">The random seed (required by the random forest model)</param>
219    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
220    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
221      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
222      var crossProduct = parameterRanges.Values.CartesianProduct();
223
224      double bestOutOfBagRmsError = double.MaxValue;
225      RFParameter bestParameters = new RFParameter();
226
227      var locker = new object();
228      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
229        var parameterValues = parameterCombination.ToList();
230        var parameters = new RFParameter();
231        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
232        RandomForestClassification.CreateRandomForestClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
233                                                                         out var rmsError, out var avgRelError, out var outOfBagRmsError, out var outOfBagAvgRelError);
234
235        lock (locker) {
236          if (bestOutOfBagRmsError > outOfBagRmsError) {
237            bestOutOfBagRmsError = outOfBagRmsError;
238            bestParameters = (RFParameter)parameters.Clone();
239          }
240        }
241      });
242      return bestParameters;
243    }
244
245    /// <summary>
246    /// Grid search with crossvalidation
247    /// </summary>
248    /// <param name="problemData">The regression problem data</param>
249    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
250    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
251    /// <param name="seed">The random seed (required by the random forest model)</param>
252    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
253    /// <returns>The best parameter values found by the grid search</returns>
254    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
255      DoubleValue mse = new DoubleValue(Double.MaxValue);
256      RFParameter bestParameter = new RFParameter();
257
258      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
259      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
260      var crossProduct = parameterRanges.Values.CartesianProduct();
261
262      var locker = new object();
263      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
264        var parameterValues = parameterCombination.ToList();
265        double testMSE;
266        var parameters = new RFParameter();
267        for (int i = 0; i < setters.Count; ++i) {
268          setters[i](parameters, parameterValues[i]);
269        }
270        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE);
271
272        lock (locker) {
273          if (testMSE < mse.Value) {
274            mse.Value = testMSE;
275            bestParameter = (RFParameter)parameters.Clone();
276          }
277        }
278      });
279      return bestParameter;
280    }
281
282    /// <summary>
283    /// Grid search with crossvalidation
284    /// </summary>
285    /// <param name="problemData">The classification problem data</param>
286    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
287    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
288    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
289    /// <param name="seed">The random seed (for shuffling)</param>
290    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
291    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
292      DoubleValue accuracy = new DoubleValue(0);
293      RFParameter bestParameter = new RFParameter();
294
295      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
296      var crossProduct = parameterRanges.Values.CartesianProduct();
297      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
298
299      var locker = new object();
300      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
301        var parameterValues = parameterCombination.ToList();
302        double testAccuracy;
303        var parameters = new RFParameter();
304        for (int i = 0; i < setters.Count; ++i) {
305          setters[i](parameters, parameterValues[i]);
306        }
307        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);
308
309        lock (locker) {
310          if (testAccuracy > accuracy.Value) {
311            accuracy.Value = testAccuracy;
312            bestParameter = (RFParameter)parameters.Clone();
313          }
314        }
315      });
316      return bestParameter;
317    }
318
319    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
320      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
321      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
322
323      for (int i = 0; i < numberOfFolds; ++i) {
324        int p = i; // avoid "access to modified closure" warning
325        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
326        var testRows = folds[i];
327        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
328      }
329      return partitions;
330    }
331
332    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
333      var random = new MersenneTwister((uint)Environment.TickCount);
334      if (problemData is IRegressionProblemData) {
335        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
336        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
337      }
338      if (problemData is IClassificationProblemData) {
339        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
340        // otherwise, generate folds normally
341        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
342      }
343      throw new ArgumentException("Problem data is neither regression or classification problem data.");
344    }
345
346    /// <summary>
347    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
348    /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
349    /// the corresponding parts from each class label.
350    /// </summary>
351    /// <param name="problemData">The classification problem data.</param>
352    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
353    /// <param name="random">The random generator used to shuffle the folds.</param>
354    /// <returns>An enumerable sequece of folds, where a fold is represented by a sequence of row indices.</returns>
355    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
356      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
357      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
358      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
359      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
360      while (enumerators.All(e => e.MoveNext())) {
361        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
362      }
363    }
364
365    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
366      // if number of folds is greater than the number of values, some empty folds will be returned
367      if (valuesCount < numberOfFolds) {
368        for (int i = 0; i < numberOfFolds; ++i)
369          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
370      } else {
371        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder
372        int start = 0, end = f;
373        for (int i = 0; i < numberOfFolds; ++i) {
374          if (r > 0) {
375            ++end;
376            --r;
377          }
378          yield return values.Skip(start).Take(end - start);
379          start = end;
380          end += f;
381        }
382      }
383    }
384
385    private static Action<RFParameter, double> GenerateSetter(string field) {
386      var targetExp = Expression.Parameter(typeof(RFParameter));
387      var valueExp = Expression.Parameter(typeof(double));
388      var fieldExp = Expression.Property(targetExp, field);
389      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
390      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
391      return setter;
392    }
393
394    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
395      var regressionProblemData = problemData as IRegressionProblemData;
396      var classificationProblemData = problemData as IClassificationProblemData;
397
398      if (regressionProblemData != null)
399        return regressionProblemData.TargetVariable;
400      if (classificationProblemData != null)
401        return classificationProblemData.TargetVariable;
402
403      throw new ArgumentException("Problem data is neither regression or classification problem data.");
404    }
405  }
406}
Note: See TracBrowser for help on using the repository browser.