Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 15703

Last change on this file since 15703 was 15584, checked in by swagner, 7 years ago

#2640: Updated year of copyrights in license headers on stable

File size: 18.6 KB
RevLine 
#region License Information

/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */

#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
29using HeuristicLab.Common;
[11901]30using HeuristicLab.Core;
[11315]31using HeuristicLab.Data;
[11901]32using HeuristicLab.Parameters;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[11315]34using HeuristicLab.Problems.DataAnalysis;
[11901]35using HeuristicLab.Random;
[11315]36
37namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// Parameter collection holding the three random forest hyper-parameters:
/// N (number of trees), M (ratio of features per tree) and R (ratio of training rows per tree).
/// </summary>
[Item("RFParameter", "A random forest parameter collection")]
[StorableClass]
public class RFParameter : ParameterCollection {
  private const string NParameterName = "N";
  private const string MParameterName = "M";
  private const string RParameterName = "R";

  /// <summary>
  /// Creates the collection with the defaults N = 50, M = 0.1 and R = 0.1.
  /// </summary>
  public RFParameter() {
    // parameters are registered in the order N, M, R
    base.Add(new FixedValueParameter<IntValue>(NParameterName, "The number of random forest trees", new IntValue(50)));
    base.Add(new FixedValueParameter<DoubleValue>(MParameterName, "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1)));
    base.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1)));
  }

  /// <summary>Deserialization constructor used by the persistence framework.</summary>
  [StorableConstructor]
  protected RFParameter(bool deserializing)
    : base(deserializing) {
  }

  /// <summary>Cloning constructor; copies the current N, R and M values from <paramref name="original"/>.</summary>
  protected RFParameter(RFParameter original, Cloner cloner)
    : base(original, cloner) {
    this.N = original.N;
    this.R = original.R;
    this.M = original.M;
  }

  /// <summary>Creates a deep clone of this parameter collection.</summary>
  public override IDeepCloneable Clone(Cloner cloner) {
    return new RFParameter(this, cloner);
  }

  #region Typed parameter accessors
  private IFixedValueParameter<IntValue> NParameter {
    get { return (IFixedValueParameter<IntValue>)base[NParameterName]; }
  }

  private IFixedValueParameter<DoubleValue> RParameter {
    get { return (IFixedValueParameter<DoubleValue>)base[RParameterName]; }
  }

  private IFixedValueParameter<DoubleValue> MParameter {
    get { return (IFixedValueParameter<DoubleValue>)base[MParameterName]; }
  }
  #endregion

  /// <summary>The number of trees in the forest.</summary>
  public int N {
    get { return NParameter.Value.Value; }
    set { NParameter.Value.Value = value; }
  }

  /// <summary>The ratio of the training set used to build each tree (0 &lt; r &lt;= 1).</summary>
  public double R {
    get { return RParameter.Value.Value; }
    set { RParameter.Value.Value = value; }
  }

  /// <summary>The ratio of features used to build each tree (0 &lt; m &lt;= 1).</summary>
  public double M {
    get { return MParameter.Value.Value; }
    set { MParameter.Value.Value = value; }
  }
}
[11315]90
/// <summary>
/// Grid-search utilities for tuning the random forest hyper-parameters N (number of trees),
/// R (ratio of training rows per tree) and M (ratio of features per tree), either by the
/// out-of-bag error or by k-fold cross-validation.
/// </summary>
public static class RandomForestUtil {
  /// <summary>
  /// Trains a regression forest on the training rows of every partition and averages the
  /// mean squared error measured on the corresponding test rows.
  /// </summary>
  /// <param name="avgTestMse">The test MSE averaged over all partitions; NaN if the MSE of any partition could not be computed.</param>
  private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    avgTestMse = 0;
    var ds = problemData.Dataset;
    var targetVariable = GetTargetVariableName(problemData);
    foreach (var tuple in partitions) {
      // the training statistics are not used here; the out arguments are passed in the same order as in GridSearch
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      var trainingRandomForestPartition = tuple.Item1;
      var testRandomForestPartition = tuple.Item2;
      var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
      var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
      var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
      OnlineCalculatorError calculatorError;
      double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
      if (calculatorError != OnlineCalculatorError.None)
        mse = double.NaN;
      avgTestMse += mse;
    }
    avgTestMse /= partitions.Length;
  }

  /// <summary>
  /// Trains a classification forest on the training rows of every partition and averages the
  /// accuracy measured on the corresponding test rows.
  /// </summary>
  /// <param name="avgTestAccuracy">The test accuracy averaged over all partitions; NaN if the accuracy of any partition could not be computed.</param>
  private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    avgTestAccuracy = 0;
    var ds = problemData.Dataset;
    var targetVariable = GetTargetVariableName(problemData);
    foreach (var tuple in partitions) {
      // the training statistics are not used here; the out arguments are passed in the same order as in GridSearch
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      var trainingRandomForestPartition = tuple.Item1;
      var testRandomForestPartition = tuple.Item2;
      var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
      var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
      var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
      OnlineCalculatorError calculatorError;
      double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
      if (calculatorError != OnlineCalculatorError.None)
        accuracy = double.NaN;
      avgTestAccuracy += accuracy;
    }
    avgTestAccuracy /= partitions.Length;
  }

  /// <summary>
  /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased).
  /// The parameter combination with the lowest out-of-bag RMS error is returned.
  /// </summary>
  /// <param name="problemData">The regression problem data</param>
  /// <param name="parameterRanges">The ranges for each parameter in the grid search (keys: "N", "R", "M")</param>
  /// <param name="seed">The random seed (required by the random forest model)</param>
  /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
  /// <returns>The best parameter values found by the grid search</returns>
  public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();
    double bestOutOfBagRmsError = double.MaxValue;
    RFParameter bestParameters = new RFParameter();

    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = CreateParameters(setters, parameterCombination);
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      // NOTE(review): the out arguments are passed positionally; confirm their order matches the RandomForestModel signature
      RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

      lock (locker) {
        if (bestOutOfBagRmsError > outOfBagRmsError) {
          bestOutOfBagRmsError = outOfBagRmsError;
          bestParameters = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameters;
  }

  /// <summary>
  /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased).
  /// The parameter combination with the lowest out-of-bag RMS error is returned.
  /// </summary>
  /// <param name="problemData">The classification problem data</param>
  /// <param name="parameterRanges">The ranges for each parameter in the grid search (keys: "N", "R", "M")</param>
  /// <param name="seed">The random seed (required by the random forest model)</param>
  /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
  /// <returns>The best parameter values found by the grid search</returns>
  public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();
    double bestOutOfBagRmsError = double.MaxValue;
    RFParameter bestParameters = new RFParameter();

    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = CreateParameters(setters, parameterCombination);
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      // NOTE(review): the out arguments are passed positionally; confirm their order matches the RandomForestModel signature
      RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
                                                  out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

      lock (locker) {
        if (bestOutOfBagRmsError > outOfBagRmsError) {
          bestOutOfBagRmsError = outOfBagRmsError;
          bestParameters = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameters;
  }

  /// <summary>
  /// Grid search with crossvalidation. The parameter combination with the lowest average test MSE is returned.
  /// </summary>
  /// <param name="problemData">The regression problem data</param>
  /// <param name="numberOfFolds">The number of folds for crossvalidation (at least one)</param>
  /// <param name="shuffleFolds">Specifies whether the training rows should be shuffled before they are split into folds</param>
  /// <param name="parameterRanges">The ranges for each parameter in the grid search (keys: "N", "R", "M")</param>
  /// <param name="seed">The random seed (used for shuffling the folds and by the random forest model)</param>
  /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
  /// <returns>The best parameter values found by the grid search</returns>
  /// <exception cref="ArgumentOutOfRangeException">If <paramref name="numberOfFolds"/> is smaller than one.</exception>
  public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    double bestTestMse = double.MaxValue;
    RFParameter bestParameter = new RFParameter();

    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    // the partitions are fully materialized, so they can be shared read-only between the parallel workers
    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds, new MersenneTwister((uint)seed));
    var crossProduct = parameterRanges.Values.CartesianProduct();

    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = CreateParameters(setters, parameterCombination);
      double testMse;
      CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMse);

      lock (locker) {
        if (testMse < bestTestMse) {
          bestTestMse = testMse;
          bestParameter = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameter;
  }

  /// <summary>
  /// Grid search with crossvalidation. The parameter combination with the highest average test accuracy is returned.
  /// </summary>
  /// <param name="problemData">The classification problem data</param>
  /// <param name="numberOfFolds">The number of folds for crossvalidation (at least one)</param>
  /// <param name="shuffleFolds">Specifies whether the folds should be shuffled (this also enables stratified fold generation)</param>
  /// <param name="parameterRanges">The ranges for each parameter in the grid search (keys: "N", "R", "M")</param>
  /// <param name="seed">The random seed (used for shuffling the folds and by the random forest model)</param>
  /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
  /// <returns>The best parameter values found by the grid search</returns>
  /// <exception cref="ArgumentOutOfRangeException">If <paramref name="numberOfFolds"/> is smaller than one.</exception>
  public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    double bestTestAccuracy = 0;
    RFParameter bestParameter = new RFParameter();

    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();
    // the partitions are fully materialized, so they can be shared read-only between the parallel workers
    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds, new MersenneTwister((uint)seed));

    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = CreateParameters(setters, parameterCombination);
      double testAccuracy;
      CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);

      lock (locker) {
        if (testAccuracy > bestTestAccuracy) {
          bestTestAccuracy = testAccuracy;
          bestParameter = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameter;
  }

  /// <summary>
  /// Creates a fresh <see cref="RFParameter"/> and applies the i-th setter with the i-th value of the combination.
  /// </summary>
  private static RFParameter CreateParameters(IList<Action<RFParameter, double>> setters, IEnumerable<double> parameterCombination) {
    var parameterValues = parameterCombination.ToList();
    var parameters = new RFParameter();
    for (int i = 0; i < setters.Count; ++i) {
      setters[i](parameters, parameterValues[i]);
    }
    return parameters;
  }

  /// <summary>
  /// Builds the (training rows, test rows) pair for every fold: fold i is the test set and
  /// the union of all other folds is the training set. All row sets are materialized.
  /// </summary>
  private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds, IRandom random) {
    var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds, random);
    var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];

    for (int i = 0; i < numberOfFolds; ++i) {
      int p = i; // the query below is evaluated eagerly, but keep the copy to avoid a modified-closure warning
      var trainingRows = folds.Where((fold, j) => j != p).SelectMany(fold => fold).ToArray();
      partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, folds[i]);
    }
    return partitions;
  }

  /// <summary>
  /// Splits the training rows of <paramref name="problemData"/> into <paramref name="numberOfFolds"/> folds.
  /// When <paramref name="shuffleFolds"/> is set, a time-seeded random generator is used (results are not reproducible).
  /// Every returned fold is a materialized list, so repeated enumeration yields the same rows.
  /// </summary>
  /// <exception cref="ArgumentOutOfRangeException">If <paramref name="numberOfFolds"/> is smaller than one.</exception>
  /// <exception cref="ArgumentException">If the problem data is neither regression nor classification problem data.</exception>
  public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
    return GenerateFolds(problemData, numberOfFolds, shuffleFolds, new MersenneTwister((uint)Environment.TickCount));
  }

  /// <summary>
  /// Splits the training rows into folds using <paramref name="random"/> for shuffling.
  /// Regression data is shuffled (optionally) and split into contiguous folds. Classification data
  /// is split into stratified folds when shuffling is enabled, and into contiguous folds otherwise.
  /// </summary>
  private static List<List<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds, IRandom random) {
    if (numberOfFolds < 1) throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be at least one.");

    if (problemData is IRegressionProblemData) {
      // materialize once: a deferred OrderBy would produce a new permutation on every enumeration
      var trainingIndices = problemData.TrainingIndices.ToList();
      if (shuffleFolds) trainingIndices = Shuffle(trainingIndices, random);
      return SplitIntoFolds(trainingIndices, numberOfFolds);
    }

    var classificationProblemData = problemData as IClassificationProblemData;
    if (classificationProblemData != null) {
      // when shuffle is enabled do stratified folds generation, some folds may have zero elements
      // otherwise, generate folds normally
      if (shuffleFolds) return GenerateFoldsStratified(classificationProblemData, numberOfFolds, random);
      return SplitIntoFolds(problemData.TrainingIndices.ToList(), numberOfFolds);
    }

    throw new ArgumentException("Problem data is neither regression or classification problem data.");
  }

  /// <summary>
  /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
  /// The samples are grouped by class label, each group is shuffled and split into <paramref name="numberOfFolds"/> parts. The final folds are formed
  /// by joining the corresponding parts from each class label.
  /// </summary>
  /// <param name="problemData">The classification problem data.</param>
  /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
  /// <param name="random">The random generator used to shuffle the rows of each class and of each fold.</param>
  /// <returns>Exactly <paramref name="numberOfFolds"/> folds, where a fold is represented by a list of row indices.</returns>
  private static List<List<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
    var trainingIndices = problemData.TrainingIndices.ToList();
    var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingIndices);
    var foldsByClass = trainingIndices
      .Zip(values, (i, v) => new { Index = i, Value = v })
      .GroupBy(x => x.Value, x => x.Index)
      .Select(g => SplitIntoFolds(Shuffle(g, random), numberOfFolds))
      .ToList();

    // iterate a fixed number of times (the previous "while all enumerators advance" loop never terminated without any class)
    var folds = new List<List<int>>(numberOfFolds);
    for (int i = 0; i < numberOfFolds; ++i) {
      int foldIndex = i;
      folds.Add(foldsByClass.SelectMany(classFolds => classFolds[foldIndex]).OrderBy(x => random.Next()).ToList());
    }
    return folds;
  }

  /// <summary>
  /// Returns the values in random order as a materialized list.
  /// </summary>
  private static List<T> Shuffle<T>(IEnumerable<T> values, IRandom random) {
    return values.OrderBy(x => random.Next()).ToList();
  }

  /// <summary>
  /// Splits <paramref name="values"/> into <paramref name="numberOfFolds"/> contiguous folds whose sizes differ by at most one;
  /// the first (Count % numberOfFolds) folds get one extra element. If there are fewer values than folds, the trailing folds are empty.
  /// </summary>
  private static List<List<T>> SplitIntoFolds<T>(List<T> values, int numberOfFolds) {
    int foldSize = values.Count / numberOfFolds;
    int remainder = values.Count % numberOfFolds;
    var folds = new List<List<T>>(numberOfFolds);
    int start = 0;
    for (int i = 0; i < numberOfFolds; ++i) {
      int size = foldSize + (i < remainder ? 1 : 0);
      folds.Add(values.GetRange(start, size));
      start += size;
    }
    return folds;
  }

  /// <summary>
  /// Compiles a setter that assigns a double to the <see cref="RFParameter"/> property named <paramref name="field"/>,
  /// converting the value to the property type (e.g. truncation to int for N).
  /// </summary>
  private static Action<RFParameter, double> GenerateSetter(string field) {
    var targetExp = Expression.Parameter(typeof(RFParameter));
    var valueExp = Expression.Parameter(typeof(double));
    var fieldExp = Expression.Property(targetExp, field);
    var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
    var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
    return setter;
  }

  /// <summary>
  /// Returns the target variable of regression or classification problem data.
  /// </summary>
  /// <exception cref="ArgumentException">If the problem data is neither regression nor classification problem data.</exception>
  private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
    var regressionProblemData = problemData as IRegressionProblemData;
    var classificationProblemData = problemData as IClassificationProblemData;

    if (regressionProblemData != null)
      return regressionProblemData.TargetVariable;
    if (classificationProblemData != null)
      return classificationProblemData.TargetVariable;

    throw new ArgumentException("Problem data is neither regression or classification problem data.");
  }
}
358}
Note: See TracBrowser for help on using the repository browser.