Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 12329

Last change on this file since 12329 was 12012, checked in by ascheibe, 10 years ago

#2212 merged r12008, r12009, r12010 back into trunk

File size: 16.3 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Data;
32using HeuristicLab.Parameters;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
34using HeuristicLab.Problems.DataAnalysis;
35using HeuristicLab.Random;
36
37namespace HeuristicLab.Algorithms.DataAnalysis {
38  [Item("RFParameter", "A random forest parameter collection")]
39  [StorableClass]
40  public class RFParameter : ParameterCollection {
41    public RFParameter() {
42      base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50)));
43      base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1)));
44      base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1)));
45    }
46
47    [StorableConstructor]
48    protected RFParameter(bool deserializing)
49      : base(deserializing) {
50    }
51
52    protected RFParameter(RFParameter original, Cloner cloner)
53      : base(original, cloner) {
54      this.N = original.N;
55      this.R = original.R;
56      this.M = original.M;
57    }
58
59    public override IDeepCloneable Clone(Cloner cloner) {
60      return new RFParameter(this, cloner);
61    }
62
63    private IFixedValueParameter<IntValue> NParameter {
64      get { return (IFixedValueParameter<IntValue>)base["N"]; }
65    }
66
67    private IFixedValueParameter<DoubleValue> RParameter {
68      get { return (IFixedValueParameter<DoubleValue>)base["R"]; }
69    }
70
71    private IFixedValueParameter<DoubleValue> MParameter {
72      get { return (IFixedValueParameter<DoubleValue>)base["M"]; }
73    }
74
75    public int N {
76      get { return NParameter.Value.Value; }
77      set { NParameter.Value.Value = value; }
78    }
79
80    public double R {
81      get { return RParameter.Value.Value; }
82      set { RParameter.Value.Value = value; }
83    }
84
85    public double M {
86      get { return MParameter.Value.Value; }
87      set { MParameter.Value.Value = value; }
88    }
89  }
90
91  public static class RandomForestUtil {
92    private static readonly object locker = new object();
93
94    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
95      avgTestMse = 0;
96      var ds = problemData.Dataset;
97      var targetVariable = GetTargetVariableName(problemData);
98      foreach (var tuple in partitions) {
99        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
100        var trainingRandomForestPartition = tuple.Item1;
101        var testRandomForestPartition = tuple.Item2;
102        var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
103        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
104        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
105        OnlineCalculatorError calculatorError;
106        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
107        if (calculatorError != OnlineCalculatorError.None)
108          mse = double.NaN;
109        avgTestMse += mse;
110      }
111      avgTestMse /= partitions.Length;
112    }
113
114    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
115      avgTestAccuracy = 0;
116      var ds = problemData.Dataset;
117      var targetVariable = GetTargetVariableName(problemData);
118      foreach (var tuple in partitions) {
119        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
120        var trainingRandomForestPartition = tuple.Item1;
121        var testRandomForestPartition = tuple.Item2;
122        var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
123        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
124        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
125        OnlineCalculatorError calculatorError;
126        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
127        if (calculatorError != OnlineCalculatorError.None)
128          accuracy = double.NaN;
129        avgTestAccuracy += accuracy;
130      }
131      avgTestAccuracy /= partitions.Length;
132    }
133
134    // grid search without cross-validation since in the case of random forests, the out-of-bag estimate is unbiased
135    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
136      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
137      var crossProduct = parameterRanges.Values.CartesianProduct();
138      double bestOutOfBagRmsError = double.MaxValue;
139      RFParameter bestParameters = new RFParameter();
140
141      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
142        var parameterValues = parameterCombination.ToList();
143        var parameters = new RFParameter();
144        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
145        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
146        RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
147
148        lock (locker) {
149          if (bestOutOfBagRmsError > outOfBagRmsError) {
150            bestOutOfBagRmsError = outOfBagRmsError;
151            bestParameters = (RFParameter)parameters.Clone();
152          }
153        }
154      });
155      return bestParameters;
156    }
157
158    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
159      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
160      var crossProduct = parameterRanges.Values.CartesianProduct();
161
162      double bestOutOfBagRmsError = double.MaxValue;
163      RFParameter bestParameters = new RFParameter();
164
165      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
166        var parameterValues = parameterCombination.ToList();
167        var parameters = new RFParameter();
168        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
169        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
170        RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
171                                                                out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
172
173        lock (locker) {
174          if (bestOutOfBagRmsError > outOfBagRmsError) {
175            bestOutOfBagRmsError = outOfBagRmsError;
176            bestParameters = (RFParameter)parameters.Clone();
177          }
178        }
179      });
180      return bestParameters;
181    }
182
183    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
184      DoubleValue mse = new DoubleValue(Double.MaxValue);
185      RFParameter bestParameter = new RFParameter();
186
187      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
188      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
189      var crossProduct = parameterRanges.Values.CartesianProduct();
190
191      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
192        var parameterValues = parameterCombination.ToList();
193        double testMSE;
194        var parameters = new RFParameter();
195        for (int i = 0; i < setters.Count; ++i) {
196          setters[i](parameters, parameterValues[i]);
197        }
198        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE);
199
200        lock (locker) {
201          if (testMSE < mse.Value) {
202            mse.Value = testMSE;
203            bestParameter = (RFParameter)parameters.Clone();
204          }
205        }
206      });
207      return bestParameter;
208    }
209
210    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
211      DoubleValue accuracy = new DoubleValue(0);
212      RFParameter bestParameter = new RFParameter();
213
214      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
215      var crossProduct = parameterRanges.Values.CartesianProduct();
216      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
217
218      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
219        var parameterValues = parameterCombination.ToList();
220        double testAccuracy;
221        var parameters = new RFParameter();
222        for (int i = 0; i < setters.Count; ++i) {
223          setters[i](parameters, parameterValues[i]);
224        }
225        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);
226
227        lock (locker) {
228          if (testAccuracy > accuracy.Value) {
229            accuracy.Value = testAccuracy;
230            bestParameter = (RFParameter)parameters.Clone();
231          }
232        }
233      });
234      return bestParameter;
235    }
236
237    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
238      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
239      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
240
241      for (int i = 0; i < numberOfFolds; ++i) {
242        int p = i; // avoid "access to modified closure" warning
243        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
244        var testRows = folds[i];
245        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
246      }
247      return partitions;
248    }
249
250    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
251      var random = new MersenneTwister((uint)Environment.TickCount);
252      if (problemData is IRegressionProblemData) {
253        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
254        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
255      }
256      if (problemData is IClassificationProblemData) {
257        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
258        // otherwise, generate folds normally
259        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
260      }
261      throw new ArgumentException("Problem data is neither regression or classification problem data.");
262    }
263
264    /// <summary>
265    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
266    /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
267    /// the corresponding parts from each class label.
268    /// </summary>
269    /// <param name="problemData">The classification problem data.</param>
270    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
271    /// <param name="random">The random generator used to shuffle the folds.</param>
272    /// <returns>An enumerable sequece of folds, where a fold is represented by a sequence of row indices.</returns>
273    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
274      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
275      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
276      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
277      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
278      while (enumerators.All(e => e.MoveNext())) {
279        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
280      }
281    }
282
283    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
284      // if number of folds is greater than the number of values, some empty folds will be returned
285      if (valuesCount < numberOfFolds) {
286        for (int i = 0; i < numberOfFolds; ++i)
287          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
288      } else {
289        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder
290        int start = 0, end = f;
291        for (int i = 0; i < numberOfFolds; ++i) {
292          if (r > 0) {
293            ++end;
294            --r;
295          }
296          yield return values.Skip(start).Take(end - start);
297          start = end;
298          end += f;
299        }
300      }
301    }
302
303    private static Action<RFParameter, double> GenerateSetter(string field) {
304      var targetExp = Expression.Parameter(typeof(RFParameter));
305      var valueExp = Expression.Parameter(typeof(double));
306      var fieldExp = Expression.Property(targetExp, field);
307      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
308      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
309      return setter;
310    }
311
312    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
313      var regressionProblemData = problemData as IRegressionProblemData;
314      var classificationProblemData = problemData as IClassificationProblemData;
315
316      if (regressionProblemData != null)
317        return regressionProblemData.TargetVariable;
318      if (classificationProblemData != null)
319        return classificationProblemData.TargetVariable;
320
321      throw new ArgumentException("Problem data is neither regression or classification problem data.");
322    }
323  }
324}
Note: See TracBrowser for help on using the repository browser.