source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 17097

Last change on this file since 17097 was 17097, checked in by mkommend, 3 months ago

#2520: Merged 16565 - 16579 into stable.

File size: 18.6 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Data;
32using HeuristicLab.Parameters;
33using HEAL.Attic;
34using HeuristicLab.Problems.DataAnalysis;
35using HeuristicLab.Random;
36
37namespace HeuristicLab.Algorithms.DataAnalysis {
38  [Item("RFParameter", "A random forest parameter collection")]
39  [StorableType("40E482DA-63C5-4D39-97C7-63701CF1D021")]
40  public class RFParameter : ParameterCollection {
41    public RFParameter() {
42      base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50)));
43      base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1)));
44      base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1)));
45    }
46
47    [StorableConstructor]
48    protected RFParameter(StorableConstructorFlag _) : base(_) {
49    }
50
51    protected RFParameter(RFParameter original, Cloner cloner)
52      : base(original, cloner) {
53      this.N = original.N;
54      this.R = original.R;
55      this.M = original.M;
56    }
57
58    public override IDeepCloneable Clone(Cloner cloner) {
59      return new RFParameter(this, cloner);
60    }
61
62    private IFixedValueParameter<IntValue> NParameter {
63      get { return (IFixedValueParameter<IntValue>)base["N"]; }
64    }
65
66    private IFixedValueParameter<DoubleValue> RParameter {
67      get { return (IFixedValueParameter<DoubleValue>)base["R"]; }
68    }
69
70    private IFixedValueParameter<DoubleValue> MParameter {
71      get { return (IFixedValueParameter<DoubleValue>)base["M"]; }
72    }
73
74    public int N {
75      get { return NParameter.Value.Value; }
76      set { NParameter.Value.Value = value; }
77    }
78
79    public double R {
80      get { return RParameter.Value.Value; }
81      set { RParameter.Value.Value = value; }
82    }
83
84    public double M {
85      get { return MParameter.Value.Value; }
86      set { MParameter.Value.Value = value; }
87    }
88  }
89
90  public static class RandomForestUtil {
91    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
92      avgTestMse = 0;
93      var ds = problemData.Dataset;
94      var targetVariable = GetTargetVariableName(problemData);
95      foreach (var tuple in partitions) {
96        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
97        var trainingRandomForestPartition = tuple.Item1;
98        var testRandomForestPartition = tuple.Item2;
99        var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
100        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
101        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
102        OnlineCalculatorError calculatorError;
103        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
104        if (calculatorError != OnlineCalculatorError.None)
105          mse = double.NaN;
106        avgTestMse += mse;
107      }
108      avgTestMse /= partitions.Length;
109    }
110
111    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
112      avgTestAccuracy = 0;
113      var ds = problemData.Dataset;
114      var targetVariable = GetTargetVariableName(problemData);
115      foreach (var tuple in partitions) {
116        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
117        var trainingRandomForestPartition = tuple.Item1;
118        var testRandomForestPartition = tuple.Item2;
119        var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
120        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
121        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
122        OnlineCalculatorError calculatorError;
123        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
124        if (calculatorError != OnlineCalculatorError.None)
125          accuracy = double.NaN;
126        avgTestAccuracy += accuracy;
127      }
128      avgTestAccuracy /= partitions.Length;
129    }
130
131    /// <summary>
132    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
133    /// </summary>
134    /// <param name="problemData">The regression problem data</param>
135    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
136    /// <param name="seed">The random seed (required by the random forest model)</param>
137    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
138    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
139      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
140      var crossProduct = parameterRanges.Values.CartesianProduct();
141      double bestOutOfBagRmsError = double.MaxValue;
142      RFParameter bestParameters = new RFParameter();
143
144      var locker = new object();
145      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
146        var parameterValues = parameterCombination.ToList();
147        var parameters = new RFParameter();
148        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
149        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
150        RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
151
152        lock (locker) {
153          if (bestOutOfBagRmsError > outOfBagRmsError) {
154            bestOutOfBagRmsError = outOfBagRmsError;
155            bestParameters = (RFParameter)parameters.Clone();
156          }
157        }
158      });
159      return bestParameters;
160    }
161
162    /// <summary>
163    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
164    /// </summary>
165    /// <param name="problemData">The classification problem data</param>
166    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
167    /// <param name="seed">The random seed (required by the random forest model)</param>
168    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
169    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
170      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
171      var crossProduct = parameterRanges.Values.CartesianProduct();
172
173      double bestOutOfBagRmsError = double.MaxValue;
174      RFParameter bestParameters = new RFParameter();
175
176      var locker = new object();
177      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
178        var parameterValues = parameterCombination.ToList();
179        var parameters = new RFParameter();
180        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
181        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
182        RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
183                                                                out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
184
185        lock (locker) {
186          if (bestOutOfBagRmsError > outOfBagRmsError) {
187            bestOutOfBagRmsError = outOfBagRmsError;
188            bestParameters = (RFParameter)parameters.Clone();
189          }
190        }
191      });
192      return bestParameters;
193    }
194
195    /// <summary>
196    /// Grid search with crossvalidation
197    /// </summary>
198    /// <param name="problemData">The regression problem data</param>
199    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
200    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
201    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
202    /// <param name="seed">The random seed (required by the random forest model)</param>
203    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
204    /// <returns>The best parameter values found by the grid search</returns>
205    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
206      DoubleValue mse = new DoubleValue(Double.MaxValue);
207      RFParameter bestParameter = new RFParameter();
208
209      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
210      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
211      var crossProduct = parameterRanges.Values.CartesianProduct();
212
213      var locker = new object();
214      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
215        var parameterValues = parameterCombination.ToList();
216        double testMSE;
217        var parameters = new RFParameter();
218        for (int i = 0; i < setters.Count; ++i) {
219          setters[i](parameters, parameterValues[i]);
220        }
221        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE);
222
223        lock (locker) {
224          if (testMSE < mse.Value) {
225            mse.Value = testMSE;
226            bestParameter = (RFParameter)parameters.Clone();
227          }
228        }
229      });
230      return bestParameter;
231    }
232
233    /// <summary>
234    /// Grid search with crossvalidation
235    /// </summary>
236    /// <param name="problemData">The classification problem data</param>
237    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
238    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
239    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
240    /// <param name="seed">The random seed (for shuffling)</param>
241    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
242    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
243      DoubleValue accuracy = new DoubleValue(0);
244      RFParameter bestParameter = new RFParameter();
245
246      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
247      var crossProduct = parameterRanges.Values.CartesianProduct();
248      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
249
250      var locker = new object();
251      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
252        var parameterValues = parameterCombination.ToList();
253        double testAccuracy;
254        var parameters = new RFParameter();
255        for (int i = 0; i < setters.Count; ++i) {
256          setters[i](parameters, parameterValues[i]);
257        }
258        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);
259
260        lock (locker) {
261          if (testAccuracy > accuracy.Value) {
262            accuracy.Value = testAccuracy;
263            bestParameter = (RFParameter)parameters.Clone();
264          }
265        }
266      });
267      return bestParameter;
268    }
269
270    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
271      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
272      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
273
274      for (int i = 0; i < numberOfFolds; ++i) {
275        int p = i; // avoid "access to modified closure" warning
276        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
277        var testRows = folds[i];
278        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
279      }
280      return partitions;
281    }
282
283    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
284      var random = new MersenneTwister((uint)Environment.TickCount);
285      if (problemData is IRegressionProblemData) {
286        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
287        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
288      }
289      if (problemData is IClassificationProblemData) {
290        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
291        // otherwise, generate folds normally
292        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
293      }
294      throw new ArgumentException("Problem data is neither regression or classification problem data.");
295    }
296
297    /// <summary>
298    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
299    /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
300    /// the corresponding parts from each class label.
301    /// </summary>
302    /// <param name="problemData">The classification problem data.</param>
303    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
304    /// <param name="random">The random generator used to shuffle the folds.</param>
305    /// <returns>An enumerable sequece of folds, where a fold is represented by a sequence of row indices.</returns>
306    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
307      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
308      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
309      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
310      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
311      while (enumerators.All(e => e.MoveNext())) {
312        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
313      }
314    }
315
316    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
317      // if number of folds is greater than the number of values, some empty folds will be returned
318      if (valuesCount < numberOfFolds) {
319        for (int i = 0; i < numberOfFolds; ++i)
320          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
321      } else {
322        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder
323        int start = 0, end = f;
324        for (int i = 0; i < numberOfFolds; ++i) {
325          if (r > 0) {
326            ++end;
327            --r;
328          }
329          yield return values.Skip(start).Take(end - start);
330          start = end;
331          end += f;
332        }
333      }
334    }
335
336    private static Action<RFParameter, double> GenerateSetter(string field) {
337      var targetExp = Expression.Parameter(typeof(RFParameter));
338      var valueExp = Expression.Parameter(typeof(double));
339      var fieldExp = Expression.Property(targetExp, field);
340      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
341      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
342      return setter;
343    }
344
345    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
346      var regressionProblemData = problemData as IRegressionProblemData;
347      var classificationProblemData = problemData as IClassificationProblemData;
348
349      if (regressionProblemData != null)
350        return regressionProblemData.TargetVariable;
351      if (classificationProblemData != null)
352        return classificationProblemData.TargetVariable;
353
354      throw new ArgumentException("Problem data is neither regression or classification problem data.");
355    }
356  }
357}
Note: See TracBrowser for help on using the repository browser.