source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 11426

Last change on this file since 11426 was 11426, checked in by bburlacu, 5 years ago

#2237: Fixed thread synchronisation bug. Removed unused variables in GridSearch methods.

File size: 14.7 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Data;
32using HeuristicLab.Problems.DataAnalysis;
33using HeuristicLab.Random;
34
35namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// Mutable holder for one random forest parameter combination evaluated by the grid search.
/// The members are deliberately public fields: RandomForestUtil.GenerateSetter looks them up by name via Expression.Field.
/// </summary>
public class RFParameter : ICloneable {
  /// <summary>Number of trees (callers cast it to int before passing it on).</summary>
  public double n;

  /// <summary>Value forwarded as the "m" argument of the random forest model factory.</summary>
  public double m;

  /// <summary>Value forwarded as the "r" argument of the random forest model factory.</summary>
  public double r;

  /// <summary>Creates an independent RFParameter carrying the same n, m and r.</summary>
  /// <returns>A new RFParameter instance (boxed as object, as required by ICloneable).</returns>
  public object Clone() {
    var copy = new RFParameter();
    copy.n = n;
    copy.m = m;
    copy.r = r;
    return copy;
  }
}
43
44  public static class RandomForestUtil {
45    private static readonly object locker = new object();
46
/// <summary>
/// Trains one random forest regression model per partition and averages the mean squared error measured on each partition's test rows.
/// </summary>
/// <param name="problemData">Provides the dataset and the target variable.</param>
/// <param name="partitions">Pairs of row sets: Item1 is used for training, Item2 for evaluation.</param>
/// <param name="nTrees">Forwarded to RandomForestModel.CreateRegressionModel.</param>
/// <param name="r">Forwarded to RandomForestModel.CreateRegressionModel.</param>
/// <param name="m">Forwarded to RandomForestModel.CreateRegressionModel.</param>
/// <param name="seed">Forwarded to RandomForestModel.CreateRegressionModel.</param>
/// <param name="avgTestMse">
/// Sum of the per-partition test MSEs divided by the number of partitions.
/// Becomes NaN as soon as the error calculator reports a problem for any partition.
/// </param>
private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
  var dataset = problemData.Dataset;
  var target = GetTargetVariableName(problemData);
  double mseSum = 0;

  foreach (var partition in partitions) {
    var trainingRows = partition.Item1;
    var testRows = partition.Item2;

    // the factory reports these statistics through out arguments; none of them is used here
    double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
    var model = RandomForestModel.CreateRegressionModel(problemData, trainingRows, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);

    var predicted = model.GetEstimatedValues(dataset, testRows);
    var actual = dataset.GetDoubleValues(target, testRows);

    OnlineCalculatorError errorState;
    double foldMse = OnlineMeanSquaredErrorCalculator.Calculate(predicted, actual, out errorState);

    // a failed calculation poisons the sum with NaN instead of contributing a bogus value
    mseSum += errorState == OnlineCalculatorError.None ? foldMse : double.NaN;
  }

  avgTestMse = mseSum / partitions.Length;
}
/// <summary>
/// Trains one random forest classification model per partition and averages the accuracy measured on each partition's test rows.
/// </summary>
/// <param name="problemData">Provides the dataset and the target variable.</param>
/// <param name="partitions">Pairs of row sets: Item1 is used for training, Item2 for evaluation.</param>
/// <param name="nTrees">Forwarded to RandomForestModel.CreateClassificationModel.</param>
/// <param name="r">Forwarded to RandomForestModel.CreateClassificationModel.</param>
/// <param name="m">Forwarded to RandomForestModel.CreateClassificationModel.</param>
/// <param name="seed">Forwarded to RandomForestModel.CreateClassificationModel.</param>
/// <param name="avgTestAccuracy">
/// Sum of the per-partition test accuracies divided by the number of partitions.
/// Becomes NaN as soon as the accuracy calculator reports a problem for any partition.
/// </param>
private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
  var dataset = problemData.Dataset;
  var target = GetTargetVariableName(problemData);
  double accuracySum = 0;

  foreach (var partition in partitions) {
    var trainingRows = partition.Item1;
    var testRows = partition.Item2;

    // the factory reports these statistics through out arguments; none of them is used here
    double rmsError, avgRelError, outOfBagRmsError, outOfBagAvgRelError;
    var model = RandomForestModel.CreateClassificationModel(problemData, trainingRows, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);

    var predicted = model.GetEstimatedClassValues(dataset, testRows);
    var actual = dataset.GetDoubleValues(target, testRows);

    OnlineCalculatorError errorState;
    double foldAccuracy = OnlineAccuracyCalculator.Calculate(predicted, actual, out errorState);

    // a failed calculation poisons the sum with NaN instead of contributing a bogus value
    accuracySum += errorState == OnlineCalculatorError.None ? foldAccuracy : double.NaN;
  }

  avgTestAccuracy = accuracySum / partitions.Length;
}
85
/// <summary>
/// Grid search for random forest regression parameters without cross-validation: candidates are ranked by the
/// out-of-bag RMS error reported by model creation, since for random forests the out-of-bag estimate is unbiased.
/// </summary>
/// <param name="problemData">Regression problem; every candidate is trained on its TrainingIndices.</param>
/// <param name="parameterRanges">
/// Maps an RFParameter field name to the values to try for it; the cartesian product of all value lists is evaluated.
/// </param>
/// <param name="seed">Seed passed unchanged to every model creation call.</param>
/// <param name="maxDegreeOfParallelism">Upper bound for the number of combinations evaluated concurrently.</param>
/// <returns>
/// The combination with the smallest out-of-bag RMS error, or a default-constructed RFParameter when no
/// combination produced a value below double.MaxValue.
/// </returns>
public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
  var fieldSetters = parameterRanges.Keys.Select(GenerateSetter).ToList();
  var grid = parameterRanges.Values.CartesianProduct();
  var options = new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism };

  var winner = new RFParameter();
  var lowestOutOfBagRmsError = double.MaxValue;

  Parallel.ForEach(grid, options, combination => {
    // setters and values line up because both originate from the same dictionary
    var values = combination.ToList();
    var candidate = new RFParameter();
    int position = 0;
    foreach (var assign in fieldSetters) {
      assign(candidate, values[position]);
      ++position;
    }

    // NOTE(review): CrossValidate passes its out arguments to the same factory as (rms, avgRel, oobRms, oobAvgRel);
    // here the out-of-bag RMS variable sits in second position. Confirm against the RandomForestModel signature.
    double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)candidate.n, candidate.r, candidate.m, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

    // the best-so-far pair is shared by all worker threads
    lock (locker) {
      if (outOfBagRmsError < lowestOutOfBagRmsError) {
        lowestOutOfBagRmsError = outOfBagRmsError;
        winner = (RFParameter)candidate.Clone();
      }
    }
  });

  return winner;
}
109
/// <summary>
/// Grid search for random forest classification parameters without cross-validation: candidates are ranked by the
/// value model creation returns through the out argument this method names outOfBagRmsError (smaller is better).
/// </summary>
/// <param name="problemData">Classification problem; every candidate is trained on its TrainingIndices.</param>
/// <param name="parameterRanges">
/// Maps an RFParameter field name to the values to try for it; the cartesian product of all value lists is evaluated.
/// </param>
/// <param name="seed">Seed passed unchanged to every model creation call.</param>
/// <param name="maxDegreeOfParallelism">Upper bound for the number of combinations evaluated concurrently.</param>
/// <returns>
/// The combination with the smallest ranking value, or a default-constructed RFParameter when no
/// combination produced a value below double.MaxValue.
/// </returns>
public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
  var fieldSetters = parameterRanges.Keys.Select(GenerateSetter).ToList();
  var grid = parameterRanges.Values.CartesianProduct();
  var options = new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism };

  var winner = new RFParameter();
  var lowestOutOfBagRmsError = double.MaxValue;

  Parallel.ForEach(grid, options, combination => {
    // setters and values line up because both originate from the same dictionary
    var values = combination.ToList();
    var candidate = new RFParameter();
    int position = 0;
    foreach (var assign in fieldSetters) {
      assign(candidate, values[position]);
      ++position;
    }

    // NOTE(review): CrossValidate passes its out arguments to the same factory as (rms, avgRel, oobRms, oobAvgRel);
    // here the out-of-bag RMS variable sits in second position. Confirm against the RandomForestModel signature.
    double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)candidate.n, candidate.r, candidate.m, seed,
                                                out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

    // the best-so-far pair is shared by all worker threads
    lock (locker) {
      if (outOfBagRmsError < lowestOutOfBagRmsError) {
        lowestOutOfBagRmsError = outOfBagRmsError;
        winner = (RFParameter)candidate.Clone();
      }
    }
  });

  return winner;
}
134
/// <summary>
/// Grid search for random forest regression parameters using k-fold cross-validation: every parameter combination
/// is scored by its average test MSE over the folds and the combination with the smallest score wins.
/// </summary>
/// <param name="problemData">Regression problem whose training rows are split into folds.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="shuffleFolds">Forwarded to fold generation to request shuffled folds.</param>
/// <param name="parameterRanges">
/// Maps an RFParameter field name to the values to try for it; the cartesian product of all value lists is evaluated.
/// </param>
/// <param name="seed">Seed passed unchanged to every model creation call.</param>
/// <param name="maxDegreeOfParallelism">Upper bound for the number of combinations evaluated concurrently.</param>
/// <returns>
/// The combination with the smallest average test MSE, or the fallback { n = 1, m = 0.1, r = 0.1 } when no
/// combination scored below double.MaxValue (a NaN score never compares as smaller).
/// </returns>
public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
  double bestMse = double.MaxValue;
  RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };

  var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
  // honor shuffleFolds; it used to be dropped here, so regression folds were never shuffled
  // (the classification overload already forwards it)
  var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
  var crossProduct = parameterRanges.Values.CartesianProduct();

  Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
    var parameterValues = parameterCombination.ToList();
    var parameters = new RFParameter();
    for (int i = 0; i < setters.Count; ++i) {
      setters[i](parameters, parameterValues[i]);
    }

    double testMse;
    CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMse);

    // the best-so-far pair is shared by all worker threads
    lock (locker) {
      if (testMse < bestMse) {
        bestMse = testMse;
        bestParameter = (RFParameter)parameters.Clone();
      }
    }
  });
  return bestParameter;
}
161
/// <summary>
/// Grid search for random forest classification parameters using k-fold cross-validation: every parameter
/// combination is scored by its average test accuracy over the folds and the highest score wins.
/// </summary>
/// <param name="problemData">Classification problem whose training rows are split into folds.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="shuffleFolds">Forwarded to GenerateRandomForestPartitions.</param>
/// <param name="parameterRanges">
/// Maps an RFParameter field name to the values to try for it; the cartesian product of all value lists is evaluated.
/// </param>
/// <param name="seed">Seed passed unchanged to every model creation call.</param>
/// <param name="maxDegreeOfParallelism">Upper bound for the number of combinations evaluated concurrently.</param>
/// <returns>
/// The combination with the highest average test accuracy, or the fallback { n = 1, m = 0.1, r = 0.1 } when no
/// combination scored above 0.
/// </returns>
public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
  double highestAccuracy = 0;
  var winner = new RFParameter { n = 1, m = 0.1, r = 0.1 };

  var fieldSetters = parameterRanges.Keys.Select(GenerateSetter).ToList();
  var grid = parameterRanges.Values.CartesianProduct();
  var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
  var options = new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism };

  Parallel.ForEach(grid, options, combination => {
    // setters and values line up because both originate from the same dictionary
    var values = combination.ToList();
    var candidate = new RFParameter();
    int position = 0;
    foreach (var assign in fieldSetters) {
      assign(candidate, values[position]);
      ++position;
    }

    double candidateAccuracy;
    CrossValidate(problemData, partitions, (int)candidate.n, candidate.r, candidate.m, seed, out candidateAccuracy);

    // the best-so-far pair is shared by all worker threads
    lock (locker) {
      if (highestAccuracy < candidateAccuracy) {
        highestAccuracy = candidateAccuracy;
        winner = (RFParameter)candidate.Clone();
      }
    }
  });

  return winner;
}
188
/// <summary>
/// Builds the cross-validation partitions: for every fold a pair of (rows of all other folds, rows of this fold).
/// </summary>
/// <param name="problemData">Problem data handed to GenerateFolds.</param>
/// <param name="numberOfFolds">Number of folds and therefore the length of the returned array.</param>
/// <param name="shuffleFolds">Handed to GenerateFolds.</param>
/// <returns>An array whose i-th entry holds the training rows in Item1 and the test rows (fold i) in Item2.</returns>
private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
  var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
  var result = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];

  for (int testFold = 0; testFold < result.Length; ++testFold) {
    // per-iteration copy: the query below is deferred and must not observe later increments of the loop variable
    int heldOut = testFold;
    IEnumerable<int> trainingRows = folds.Where((fold, foldIndex) => foldIndex != heldOut).SelectMany(fold => fold);
    result[testFold] = Tuple.Create(trainingRows, folds[testFold]);
  }

  return result;
}
201
/// <summary>
/// Splits the training rows of a regression or classification problem into folds.
/// </summary>
/// <param name="problemData">Regression or classification problem data.</param>
/// <param name="numberOfFolds">The number of folds to generate.</param>
/// <param name="shuffleFolds">
/// When true, regression rows are randomly permuted before they are cut into folds and classification rows are
/// split with stratification. The random generator is seeded from Environment.TickCount, so shuffled folds are
/// not reproducible between calls.
/// </param>
/// <returns>A sequence of folds, each fold being a sequence of row indices.</returns>
/// <exception cref="ArgumentException">The problem data is neither regression nor classification problem data.</exception>
public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
  var random = new MersenneTwister((uint)Environment.TickCount);
  if (problemData is IRegressionProblemData) {
    // OrderBy is deferred and every fold is a lazy Skip/Take view over it. Without materializing the permutation
    // once, each enumeration of a fold would re-sort with fresh random keys, so folds would be cut from different
    // permutations (overlapping or missing rows) and the generator could be hit from several consumer threads.
    var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()).ToList() : problemData.TrainingIndices;
    return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
  }
  var classificationProblemData = problemData as IClassificationProblemData;
  if (classificationProblemData != null) {
    // when shuffle is enabled do stratified folds generation, some folds may have zero elements
    // otherwise, generate folds normally
    return shuffleFolds ? GenerateFoldsStratified(classificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
  }
  throw new ArgumentException("Problem data is neither regression or classification problem data.");
}
215
/// <summary>
/// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
/// The samples are grouped by class label and each group is split into <paramref name="numberOfFolds"/> parts. The final folds are formed by joining
/// the corresponding parts from each class label.
/// </summary>
/// <param name="problemData">The classification problem data.</param>
/// <param name="numberOfFolds">The number of folds in which to split the data.</param>
/// <param name="random">The random generator used to shuffle the rows inside each fold.</param>
/// <returns>
/// An enumerable sequence of at most <paramref name="numberOfFolds"/> folds, where a fold is represented by a list of row indices.
/// Folds may be empty, in particular when there are no training rows at all.
/// </returns>
private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
  var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
  var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
  // one fold sequence per class label
  IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
  var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
  try {
    // Bound the loop explicitly: with no class groups "All" is vacuously true, so looping on it alone
    // would yield empty folds forever.
    for (int fold = 0; fold < numberOfFolds; ++fold) {
      // advance every per-class sequence in lock-step; stop if any of them is exhausted
      if (!enumerators.All(e => e.MoveNext()))
        yield break;
      // join the current part of every class and shuffle the rows within the fold
      yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
    }
  } finally {
    foreach (var enumerator in enumerators)
      enumerator.Dispose();
  }
}
234
/// <summary>
/// Cuts a sequence into <paramref name="numberOfFolds"/> consecutive folds whose sizes differ by at most one element.
/// The folds are deferred Skip/Take views over <paramref name="values"/>.
/// </summary>
/// <typeparam name="T">Element type of the sequence.</typeparam>
/// <param name="values">The sequence to split.</param>
/// <param name="valuesCount">The number of elements the fold sizes are computed from.</param>
/// <param name="numberOfFolds">The number of folds to produce.</param>
/// <returns>
/// The folds in order. When there are fewer values than folds, the leading folds hold one element each and the
/// remaining folds are empty; otherwise the first (valuesCount % numberOfFolds) folds hold one extra element.
/// </returns>
private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
  if (valuesCount < numberOfFolds) {
    // not enough values to fill every fold: one element per fold, then empty folds
    for (int fold = 0; fold < numberOfFolds; ++fold) {
      if (fold < valuesCount)
        yield return values.Skip(fold).Take(1);
      else
        yield return Enumerable.Empty<T>();
    }
    yield break;
  }

  int baseSize = valuesCount / numberOfFolds;  // minimum number of elements per fold
  int oversized = valuesCount % numberOfFolds; // this many leading folds receive one extra element
  int offset = 0;
  for (int fold = 0; fold < numberOfFolds; ++fold) {
    int size = fold < oversized ? baseSize + 1 : baseSize;
    yield return values.Skip(offset).Take(size);
    offset += size;
  }
}
254
/// <summary>
/// Compiles a delegate that assigns a double to the RFParameter field with the given name.
/// </summary>
/// <param name="field">Name of the RFParameter field to assign.</param>
/// <returns>A delegate (target, value) that converts the value to the field's type and stores it in the field.</returns>
private static Action<RFParameter, double> GenerateSetter(string field) {
  ParameterExpression instance = Expression.Parameter(typeof(RFParameter));
  ParameterExpression newValue = Expression.Parameter(typeof(double));
  MemberExpression member = Expression.Field(instance, field);
  // the conversion keeps the setter usable for fields that are not declared as double
  UnaryExpression converted = Expression.Convert(newValue, member.Type);
  BinaryExpression assignment = Expression.Assign(member, converted);
  return Expression.Lambda<Action<RFParameter, double>>(assignment, instance, newValue).Compile();
}
263
/// <summary>
/// Returns the target variable name of regression or classification problem data.
/// </summary>
/// <param name="problemData">The problem data to inspect; the regression interface is checked first.</param>
/// <returns>The TargetVariable of the matching problem data interface.</returns>
/// <exception cref="ArgumentException">The problem data implements neither of the two interfaces.</exception>
private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
  var regression = problemData as IRegressionProblemData;
  if (regression != null)
    return regression.TargetVariable;

  var classification = problemData as IClassificationProblemData;
  if (classification != null)
    return classification.TargetVariable;

  throw new ArgumentException("Problem data is neither regression or classification problem data.");
}
275  }
276}
Note: See TracBrowser for help on using the repository browser.