Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 11362

Last change on this file since 11362 was 11362, checked in by bburlacu, 10 years ago

#2237: Addressed part of the comments above:

  • Methods are similar to the ones from SupportVectorMachineUtil
  • Cleaned up sample scripts
  • Elapsed time is shown in seconds
  • Included demo problem
  • Added stratified cross-validation (shuffling is turned off by default)
  • Added different GridSearch methods with/without cross-validation.
  • Fixed bug in fold generation when the number of folds is larger than the number of values
File size: 14.7 KB
RevLine 
[11315]1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
29using HeuristicLab.Common;
[11362]30using HeuristicLab.Core;
[11315]31using HeuristicLab.Data;
32using HeuristicLab.Problems.DataAnalysis;
[11362]33using HeuristicLab.Random;
[11315]34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// One random forest parameter combination, as produced and returned by the grid searches in <see cref="RandomForestUtil"/>.
  /// </summary>
  /// <remarks>
  /// The fields are public and lower-case because RandomForestUtil assigns them by name (via compiled expression trees)
  /// from the keys of the parameter-range dictionary passed to GridSearch.
  /// </remarks>
  public class RFParameter : ICloneable {
    public double n; // number of trees; a double so it can come from a numeric range, truncated to int when a model is built
    public double m; // forwarded as the "m" argument of RandomForestModel.Create*Model -- NOTE(review): presumably the ratio of variables per split; confirm
    public double r; // forwarded as the "r" argument of RandomForestModel.Create*Model -- NOTE(review): presumably the ratio of training rows per tree; confirm

    /// <summary>Creates a new <see cref="RFParameter"/> holding the same n, m and r values.</summary>
    public object Clone() { return new RFParameter { n = this.n, m = this.m, r = this.r }; }
  }

  public static class RandomForestUtil {
    /// <summary>
    /// Trains one regression forest per (training rows, test rows) partition and averages the resulting test mean squared errors.
    /// </summary>
    /// <param name="problemData">The regression problem providing the dataset and target variable.</param>
    /// <param name="partitions">Pairs of (training rows, test rows). Each test sequence is enumerated twice, so it must yield the same rows every time.</param>
    /// <param name="nTrees">Number of trees per forest.</param>
    /// <param name="r">Forwarded unchanged to <c>RandomForestModel.CreateRegressionModel</c>.</param>
    /// <param name="m">Forwarded unchanged to <c>RandomForestModel.CreateRegressionModel</c>.</param>
    /// <param name="seed">Random seed forwarded to the model builder.</param>
    /// <param name="avgTestMse">Mean of the per-partition test MSEs; NaN if any partition's MSE could not be computed or if there are no partitions.</param>
    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
      avgTestMse = 0;
      var ds = problemData.Dataset;
      var targetVariable = GetTargetVariableName(problemData);
      foreach (var tuple in partitions) {
        // The training statistics reported through these out parameters are not needed for cross-validation.
        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
        var trainingRows = tuple.Item1;
        var testRows = tuple.Item2;
        var model = RandomForestModel.CreateRegressionModel(problemData, trainingRows, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
        var estimatedValues = model.GetEstimatedValues(ds, testRows);
        var targetValues = ds.GetDoubleValues(targetVariable, testRows);
        OnlineCalculatorError calculatorError;
        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
        if (calculatorError != OnlineCalculatorError.None)
          mse = double.NaN;
        avgTestMse += mse;
      }
      avgTestMse /= partitions.Length;
    }

    /// <summary>
    /// Trains one classification forest per (training rows, test rows) partition and averages the resulting test accuracies.
    /// </summary>
    /// <param name="problemData">The classification problem providing the dataset and target variable.</param>
    /// <param name="partitions">Pairs of (training rows, test rows). Each test sequence is enumerated twice, so it must yield the same rows every time.</param>
    /// <param name="nTrees">Number of trees per forest.</param>
    /// <param name="r">Forwarded unchanged to <c>RandomForestModel.CreateClassificationModel</c>.</param>
    /// <param name="m">Forwarded unchanged to <c>RandomForestModel.CreateClassificationModel</c>.</param>
    /// <param name="seed">Random seed forwarded to the model builder.</param>
    /// <param name="avgTestAccuracy">Mean of the per-partition test accuracies; NaN if any partition's accuracy could not be computed or if there are no partitions.</param>
    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
      avgTestAccuracy = 0;
      var ds = problemData.Dataset;
      var targetVariable = GetTargetVariableName(problemData);
      foreach (var tuple in partitions) {
        // The training statistics reported through these out parameters are not needed for cross-validation.
        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
        var trainingRows = tuple.Item1;
        var testRows = tuple.Item2;
        var model = RandomForestModel.CreateClassificationModel(problemData, trainingRows, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
        var estimatedValues = model.GetEstimatedClassValues(ds, testRows);
        var targetValues = ds.GetDoubleValues(targetVariable, testRows);
        OnlineCalculatorError calculatorError;
        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
        if (calculatorError != OnlineCalculatorError.None)
          accuracy = double.NaN;
        avgTestAccuracy += accuracy;
      }
      avgTestAccuracy /= partitions.Length;
    }

    /// <summary>
    /// Evaluates every combination in the Cartesian product of <paramref name="parameterRanges"/> (possibly in parallel)
    /// and returns the combination with the best score. A combination replaces the current best only when its score is
    /// strictly better, so NaN scores never win.
    /// </summary>
    /// <param name="parameterRanges">Maps an <see cref="RFParameter"/> field name ("n", "m" or "r") to the values to try.</param>
    /// <param name="maxDegreeOfParallelism">Upper bound on the number of combinations evaluated concurrently.</param>
    /// <param name="evaluate">Scores one combination. It is called from several threads at once, so it must be thread-safe.</param>
    /// <param name="maximize">True if higher scores are better, false if lower scores are better.</param>
    /// <param name="initialBestScore">Score a combination has to beat to be selected at all.</param>
    /// <param name="fallback">Returned when no combination beats <paramref name="initialBestScore"/>.</param>
    private static RFParameter FindBestParameters(Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism,
                                                  Func<RFParameter, double> evaluate, bool maximize, double initialBestScore, RFParameter fallback) {
      // Keys and Values of the same Dictionary enumerate in matching order, so setters[i] belongs to the i-th value of each combination.
      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var crossProduct = parameterRanges.Values.CartesianProduct();
      var locker = new object();
      double bestScore = initialBestScore;
      RFParameter bestParameters = fallback;

      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
        var parameterValues = parameterCombination.ToList();
        var parameters = new RFParameter();
        for (int i = 0; i < setters.Count; ++i) {
          setters[i](parameters, parameterValues[i]);
        }
        double score = evaluate(parameters);
        // Compare and update under one lock on an object that is never reassigned. A check outside the lock
        // would let a thread holding a worse score overwrite a better result stored in the meantime.
        lock (locker) {
          bool isBetter = maximize ? score > bestScore : score < bestScore;
          if (isBetter) {
            bestScore = score;
            bestParameters = parameters; // 'parameters' is private to this iteration, so no copy is needed
          }
        }
      });
      return bestParameters;
    }

    /// <summary>
    /// Grid search for regression without cross-validation. Each combination is trained once on the training indices and
    /// scored by the model's out-of-bag RMS error; the authors' rationale is that the out-of-bag estimate of a random forest
    /// is already unbiased.
    /// </summary>
    /// <returns>The combination with the lowest out-of-bag RMS error, or a default <see cref="RFParameter"/> if none had a finite error.</returns>
    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      return FindBestParameters(parameterRanges, maxDegreeOfParallelism, parameters => {
        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
        // NOTE(review): CrossValidate passes these out arguments in a different order (rmsError, avgRelError, outOfBagRmsError,
        // outOfBagAvgRelError). Only one ordering can match RandomForestModel.CreateRegressionModel -- confirm that the value
        // returned below really is the out-of-bag RMS error.
        RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
                                                out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
        return outOfBagRmsError;
      }, maximize: false, initialBestScore: double.MaxValue, fallback: new RFParameter());
    }

    /// <summary>
    /// Grid search for classification without cross-validation. Each combination is trained once on the training indices and
    /// scored by the second error value reported by the model builder (treated here as the out-of-bag RMS error).
    /// </summary>
    /// <returns>The combination with the lowest such error, or a default <see cref="RFParameter"/> if none had a finite error.</returns>
    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      return FindBestParameters(parameterRanges, maxDegreeOfParallelism, parameters => {
        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
        // NOTE(review): CrossValidate passes these out arguments in a different order; confirm against
        // RandomForestModel.CreateClassificationModel which ordering is correct.
        RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
                                                    out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
        return outOfBagRmsError;
      }, maximize: false, initialBestScore: double.MaxValue, fallback: new RFParameter());
    }

    /// <summary>
    /// Grid search for regression using k-fold cross-validation on the training indices. Each combination is scored by
    /// its average test MSE over the folds.
    /// </summary>
    /// <param name="numberOfFolds">Number of folds; must be at least 1.</param>
    /// <param name="shuffleFolds">Whether the training indices are shuffled before being split into folds.</param>
    /// <returns>The combination with the lowest average test MSE, or { n = 1, m = 0.1, r = 0.1 } if none had a finite MSE.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numberOfFolds"/> is less than 1.</exception>
    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
      return FindBestParameters(parameterRanges, maxDegreeOfParallelism, parameters => {
        double testMSE;
        CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMSE);
        return testMSE;
      }, maximize: false, initialBestScore: double.MaxValue, fallback: new RFParameter { n = 1, m = 0.1, r = 0.1 });
    }

    /// <summary>
    /// Grid search for classification using k-fold cross-validation on the training indices. Each combination is scored
    /// by its average test accuracy over the folds.
    /// </summary>
    /// <param name="numberOfFolds">Number of folds; must be at least 1.</param>
    /// <param name="shuffleFolds">Whether stratified folds (with shuffled row order) are generated instead of plain contiguous folds.</param>
    /// <returns>The combination with the highest average test accuracy, or { n = 1, m = 0.1, r = 0.1 } if none scored above 0.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numberOfFolds"/> is less than 1.</exception>
    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
      return FindBestParameters(parameterRanges, maxDegreeOfParallelism, parameters => {
        double testAccuracy;
        CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy);
        return testAccuracy;
      }, maximize: true, initialBestScore: 0, fallback: new RFParameter { n = 1, m = 0.1, r = 0.1 });
    }

    /// <summary>
    /// Builds one (training rows, test rows) pair per fold. The test rows are the fold itself; the training rows are all
    /// other folds concatenated.
    /// </summary>
    /// <remarks>
    /// Everything is materialized into lists. The partitions are enumerated repeatedly, and concurrently, by the grid
    /// searches, and a lazily shuffled source would yield different rows on each enumeration.
    /// </remarks>
    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).Select(f => f.ToList()).ToList();
      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[folds.Count];

      for (int i = 0; i < folds.Count; ++i) {
        int p = i; // avoid "access to modified closure" warning
        var trainingRows = folds.Where((fold, j) => j != p).SelectMany(fold => fold).ToList();
        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, folds[i]);
      }
      return partitions;
    }

    /// <summary>
    /// Splits the training indices of <paramref name="problemData"/> into <paramref name="numberOfFolds"/> folds.
    /// Regression data is split into contiguous folds, optionally after shuffling the indices. Classification data is
    /// split into stratified folds when <paramref name="shuffleFolds"/> is true, and into plain contiguous folds otherwise.
    /// When there are fewer training indices than folds, some folds are empty.
    /// </summary>
    /// <returns>Exactly <paramref name="numberOfFolds"/> folds; each fold is a sequence of row indices.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numberOfFolds"/> is less than 1.</exception>
    /// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
      if (numberOfFolds < 1)
        throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be at least 1.");
      var random = new MersenneTwister((uint)Environment.TickCount);
      if (problemData is IRegressionProblemData) {
        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
        return GenerateFolds(trainingIndices, numberOfFolds);
      }
      if (problemData is IClassificationProblemData) {
        return shuffleFolds
          ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random)
          : GenerateFolds(problemData.TrainingIndices, numberOfFolds);
      }
      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }

    /// <summary>
    /// Stratified fold generation from classification data. Stratification means that each fold receives roughly the same
    /// share of every class label. The training rows are grouped by class label and each group is split into
    /// <paramref name="numberOfFolds"/> contiguous parts, in training-index order. Fold i is the union of part i of every
    /// group, with its row order shuffled.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
    /// <param name="random">The random generator used to shuffle the row order inside each fold.</param>
    /// <returns>An enumerable sequence of exactly <paramref name="numberOfFolds"/> folds, each a materialized list of row indices.</returns>
    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
      // One list of numberOfFolds parts per class label.
      var foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index)
                                      .Select(g => GenerateFolds(g, numberOfFolds).ToList())
                                      .ToList();
      // Iterate a fixed number of times. A "while every enumerator can advance" condition is vacuously true forever
      // when there are no class groups at all (empty training set).
      for (int i = 0; i < numberOfFolds; ++i) {
        int foldIndex = i;
        yield return foldsByClass.SelectMany(parts => parts[foldIndex]).OrderBy(x => random.Next()).ToList();
      }
    }

    /// <summary>
    /// Splits <paramref name="values"/>, in their given order, into <paramref name="numberOfFolds"/> contiguous folds.
    /// Fold sizes differ by at most one: the first (count % numberOfFolds) folds get one extra element. With fewer values
    /// than folds, the trailing folds are empty.
    /// </summary>
    /// <remarks>
    /// The source is enumerated exactly once and every fold is a materialized list. Each fold therefore yields the same
    /// elements on every enumeration, even when <paramref name="values"/> is a lazily shuffled query. Fold sizes come from
    /// the actual number of elements.
    /// </remarks>
    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int numberOfFolds) {
      var items = values.ToList();
      int baseSize = items.Count / numberOfFolds, remainder = items.Count % numberOfFolds;
      int start = 0;
      for (int i = 0; i < numberOfFolds; ++i) {
        int size = i < remainder ? baseSize + 1 : baseSize;
        yield return items.GetRange(start, size);
        start += size;
      }
    }

    /// <summary>
    /// Compiles a delegate that assigns a double to the <see cref="RFParameter"/> field named <paramref name="field"/>,
    /// converting the value to the field's type.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown by <see cref="Expression.Field(Expression, string)"/> if no such field exists.</exception>
    private static Action<RFParameter, double> GenerateSetter(string field) {
      var targetExp = Expression.Parameter(typeof(RFParameter));
      var valueExp = Expression.Parameter(typeof(double));
      var fieldExp = Expression.Field(targetExp, field);
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }

    /// <summary>Returns the target variable of regression or classification problem data.</summary>
    /// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;

      if (regressionProblemData != null)
        return regressionProblemData.TargetVariable;
      if (classificationProblemData != null)
        return classificationProblemData.TargetVariable;

      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.