Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 17154

Last change on this file since 17154 was 17154, checked in by gkronber, 5 years ago

#2952: merged relevant revisions from branch to trunk

Merged revision(s) 17045-17153 from branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis:
#2952: Intermediate commit of refactoring RF models that is not yet finished.

........
#2952: Corrected evaluation in RF models.

........
#2952: Finished implementation of different RF models.

........
#2952 Fixed triggering model recalculation when cloning.
........
#2952: merged r17137 from trunk to branch
........
#2952: re-added backwards compatibility code for very old versions of GBT and RF
........
#2952: hide parameter in backwards compatibility hook
........

17045-17153

File size: 20.1 KB
RevLine 
[11315]1#region License Information
2
3/* HeuristicLab
[16565]4 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[11315]5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
[17154]29using HEAL.Attic;
[11315]30using HeuristicLab.Common;
[11362]31using HeuristicLab.Core;
[11315]32using HeuristicLab.Data;
[11443]33using HeuristicLab.Parameters;
[11315]34using HeuristicLab.Problems.DataAnalysis;
[11362]35using HeuristicLab.Random;
[11315]36
37namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Parameter collection holding the three random forest hyper-parameters:
  /// N (number of trees), M (feature ratio per tree) and R (training-set ratio per tree).
  /// Exposes them as typed convenience properties on top of the underlying parameter store.
  /// </summary>
  [Item("RFParameter", "A random forest parameter collection")]
  [StorableType("40E482DA-63C5-4D39-97C7-63701CF1D021")]
  public class RFParameter : ParameterCollection {
    public RFParameter() {
      // Defaults: 50 trees, 10% of features and 10% of training rows per tree.
      base.Add(new FixedValueParameter<IntValue>("N", "The number of random forest trees", new IntValue(50)));
      base.Add(new FixedValueParameter<DoubleValue>("M", "The ratio of features that will be used in the construction of individual trees (0<m<=1)", new DoubleValue(0.1)));
      base.Add(new FixedValueParameter<DoubleValue>("R", "The ratio of the training set that will be used in the construction of individual trees (0<r<=1)", new DoubleValue(0.1)));
    }

    // Deserialization constructor used by the HEAL.Attic persistence framework.
    [StorableConstructor]
    protected RFParameter(StorableConstructorFlag _) : base(_) {
    }

    // Cloning constructor: the base clone copies the parameter collection,
    // the property assignments copy the current values into the cloned parameters.
    protected RFParameter(RFParameter original, Cloner cloner)
      : base(original, cloner) {
      this.N = original.N;
      this.R = original.R;
      this.M = original.M;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RFParameter(this, cloner);
    }

    // Typed accessors for the entries added in the constructor; the string keys
    // must match the names used in base.Add above.
    private IFixedValueParameter<IntValue> NParameter {
      get { return (IFixedValueParameter<IntValue>)base["N"]; }
    }

    private IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)base["R"]; }
    }

    private IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)base["M"]; }
    }

    // The number of trees in the forest.
    public int N {
      get { return NParameter.Value.Value; }
      set { NParameter.Value.Value = value; }
    }

    // Ratio of the training set sampled for each tree (0 < R <= 1).
    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }

    // Ratio of features considered for each tree (0 < M <= 1).
    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
  }
89
90  public static class RandomForestUtil {
[17154]91    public static void AssertParameters(double r, double m) {
92      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
93      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
94    }
95
96    public static void AssertInputMatrix(double[,] inputMatrix) {
97      if (inputMatrix.ContainsNanOrInfinity())
98        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
99    }
100
    /// <summary>
    /// Builds an alglib decision forest from the given matrix.
    /// The matrix is passed to dfbuildinternal with nVars = nColumns - 1, i.e. the last
    /// column is presumably the target variable — TODO confirm against alglib's expected layout.
    /// </summary>
    /// <param name="seed">Seed for alglib's global random number generator.</param>
    /// <param name="inputMatrix">One row per sample; must not contain NaN/infinity.</param>
    /// <param name="nTrees">Number of trees to build.</param>
    /// <param name="r">Ratio of rows sampled per tree (0 &lt; r &lt;= 1).</param>
    /// <param name="m">Ratio of features considered per tree (0 &lt; m &lt;= 1).</param>
    /// <param name="nClasses">Number of classes (alglib uses a special value for regression).</param>
    /// <param name="rep">Receives alglib's training report.</param>
    /// <exception cref="ArgumentException">Thrown on invalid r/m or when alglib reports failure.</exception>
    internal static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
      RandomForestUtil.AssertParameters(r, m);
      RandomForestUtil.AssertInputMatrix(inputMatrix);

      int info = 0;
      // NOTE(review): seeding works by replacing alglib's process-wide RNG object — this is a
      // global side effect; concurrent model builds would share (and race on) this generator.
      alglib.math.rndobject = new System.Random(seed);
      var dForest = new alglib.decisionforest();
      rep = new alglib.dfreport();
      int nRows = inputMatrix.GetLength(0);
      int nColumns = inputMatrix.GetLength(1);
      // At least one row and one feature are always used, regardless of how small r/m are.
      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);

      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
      // alglib signals success with info == 1; anything else is treated as a failed build.
      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
      return dForest;
    }
118
119
    /// <summary>
    /// Computes the test-set mean squared error averaged over all crossvalidation partitions
    /// for a random forest regression model with the given parameters.
    /// If any fold's MSE cannot be computed the NaN propagates into the average.
    /// </summary>
    /// <param name="problemData">The regression problem data.</param>
    /// <param name="partitions">Pairs of (training rows, test rows), one per fold.</param>
    /// <param name="nTrees">Number of trees (N).</param>
    /// <param name="r">Ratio of training rows sampled per tree.</param>
    /// <param name="m">Ratio of features considered per tree.</param>
    /// <param name="seed">Random seed forwarded to model construction.</param>
    /// <param name="avgTestMse">Receives the average test MSE over all folds.</param>
    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
      avgTestMse = 0;
      var ds = problemData.Dataset;
      var targetVariable = GetTargetVariableName(problemData);
      foreach (var tuple in partitions) {
        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
        var trainingRandomForestPartition = tuple.Item1;
        var testRandomForestPartition = tuple.Item2;
        // NOTE(review): the out arguments here are passed as (rmsError, avgRelError, outOfBagRmsError,
        // outOfBagAvgRelError) while the OOB GridSearch overload passes (rmsError, outOfBagRmsError,
        // avgRelError, outOfBagAvgRelError) — one of the two call sites likely mismatches
        // CreateRegressionModel's parameter order. Harmless here since the values are unused.
        var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
        OnlineCalculatorError calculatorError;
        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
        if (calculatorError != OnlineCalculatorError.None)
          mse = double.NaN;
        avgTestMse += mse;
      }
      avgTestMse /= partitions.Length;
    }
[11443]139
    /// <summary>
    /// Computes the test-set accuracy averaged over all crossvalidation partitions
    /// for a random forest classification model with the given parameters.
    /// If any fold's accuracy cannot be computed the NaN propagates into the average.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="partitions">Pairs of (training rows, test rows), one per fold.</param>
    /// <param name="nTrees">Number of trees (N).</param>
    /// <param name="r">Ratio of training rows sampled per tree.</param>
    /// <param name="m">Ratio of features considered per tree.</param>
    /// <param name="seed">Random seed forwarded to model construction.</param>
    /// <param name="avgTestAccuracy">Receives the average test accuracy over all folds.</param>
    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
      avgTestAccuracy = 0;
      var ds = problemData.Dataset;
      var targetVariable = GetTargetVariableName(problemData);
      foreach (var tuple in partitions) {
        // The error metrics are required by the API but not used for fold selection here.
        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
        var trainingRandomForestPartition = tuple.Item1;
        var testRandomForestPartition = tuple.Item2;
        var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
        OnlineCalculatorError calculatorError;
        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
        if (calculatorError != OnlineCalculatorError.None)
          accuracy = double.NaN;
        avgTestAccuracy += accuracy;
      }
      avgTestAccuracy /= partitions.Length;
    }
159
[12509]160    /// <summary>
161    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
162    /// </summary>
163    /// <param name="problemData">The regression problem data</param>
164    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
165    /// <param name="seed">The random seed (required by the random forest model)</param>
166    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
[11362]167    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
168      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
169      var crossProduct = parameterRanges.Values.CartesianProduct();
170      double bestOutOfBagRmsError = double.MaxValue;
171      RFParameter bestParameters = new RFParameter();
172
[12509]173      var locker = new object();
[11362]174      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
175        var parameterValues = parameterCombination.ToList();
176        var parameters = new RFParameter();
177        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
178        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
[11443]179        RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
[11426]180
181        lock (locker) {
182          if (bestOutOfBagRmsError > outOfBagRmsError) {
[11362]183            bestOutOfBagRmsError = outOfBagRmsError;
184            bestParameters = (RFParameter)parameters.Clone();
185          }
186        }
187      });
188      return bestParameters;
189    }
190
    /// <summary>
    /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased)
    /// </summary>
    /// <param name="problemData">The classification problem data</param>
    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
    /// <param name="seed">The random seed (required by the random forest model)</param>
    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
    /// <returns>The parameter combination with the lowest out-of-bag RMS error.</returns>
    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var crossProduct = parameterRanges.Values.CartesianProduct();

      // NOTE(review): despite being the classification overload, selection is based on the
      // out-of-bag RMS error reported by the model, not on accuracy.
      double bestOutOfBagRmsError = double.MaxValue;
      RFParameter bestParameters = new RFParameter();

      var locker = new object();
      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
        var parameterValues = parameterCombination.ToList();
        var parameters = new RFParameter();
        // Apply this combination's values via the precompiled property setters.
        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
        RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
                                                                out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

        // Serialize updates of the shared best-so-far state across worker threads.
        lock (locker) {
          if (bestOutOfBagRmsError > outOfBagRmsError) {
            bestOutOfBagRmsError = outOfBagRmsError;
            bestParameters = (RFParameter)parameters.Clone();
          }
        }
      });
      return bestParameters;
    }
223
[12509]224    /// <summary>
225    /// Grid search with crossvalidation
226    /// </summary>
227    /// <param name="problemData">The regression problem data</param>
228    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
229    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
230    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
231    /// <param name="seed">The random seed (required by the random forest model)</param>
232    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
233    /// <returns>The best parameter values found by the grid search</returns>
[11362]234    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
[11315]235      DoubleValue mse = new DoubleValue(Double.MaxValue);
[11443]236      RFParameter bestParameter = new RFParameter();
[11315]237
[11343]238      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
[11338]239      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
[11343]240      var crossProduct = parameterRanges.Values.CartesianProduct();
[11315]241
[12509]242      var locker = new object();
[11343]243      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
244        var parameterValues = parameterCombination.ToList();
[11315]245        double testMSE;
246        var parameters = new RFParameter();
[11343]247        for (int i = 0; i < setters.Count; ++i) {
248          setters[i](parameters, parameterValues[i]);
[11315]249        }
[11443]250        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testMSE);
[11426]251
252        lock (locker) {
253          if (testMSE < mse.Value) {
[11343]254            mse.Value = testMSE;
255            bestParameter = (RFParameter)parameters.Clone();
256          }
[11315]257        }
258      });
259      return bestParameter;
260    }
[11338]261
    /// <summary>
    /// Grid search with crossvalidation
    /// </summary>
    /// <param name="problemData">The classification problem data</param>
    /// <param name="numberOfFolds">The number of folds for crossvalidation</param>
    /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param>
    /// <param name="parameterRanges">The ranges for each parameter in the grid search</param>
    /// <param name="seed">The random seed (for shuffling)</param>
    /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param>
    /// <returns>The parameter combination with the highest average test accuracy.</returns>
    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      // Accuracy is maximized here (unlike the regression overload, which minimizes MSE).
      DoubleValue accuracy = new DoubleValue(0);
      RFParameter bestParameter = new RFParameter();

      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var crossProduct = parameterRanges.Values.CartesianProduct();
      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);

      var locker = new object();
      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
        var parameterValues = parameterCombination.ToList();
        double testAccuracy;
        var parameters = new RFParameter();
        // Apply this combination's values via the precompiled property setters.
        for (int i = 0; i < setters.Count; ++i) {
          setters[i](parameters, parameterValues[i]);
        }
        CrossValidate(problemData, partitions, parameters.N, parameters.R, parameters.M, seed, out testAccuracy);

        // Serialize updates of the shared best-so-far state across worker threads.
        lock (locker) {
          if (testAccuracy > accuracy.Value) {
            accuracy.Value = testAccuracy;
            bestParameter = (RFParameter)parameters.Clone();
          }
        }
      });
      return bestParameter;
    }
298
[11362]299    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
300      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
[11343]301      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
302
303      for (int i = 0; i < numberOfFolds; ++i) {
304        int p = i; // avoid "access to modified closure" warning
305        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
306        var testRows = folds[i];
307        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
308      }
309      return partitions;
310    }
311
    /// <summary>
    /// Splits the training partition of <paramref name="problemData"/> into <paramref name="numberOfFolds"/> folds of row indices.
    /// For classification data with shuffling enabled, stratified folds are generated instead.
    /// </summary>
    /// <param name="problemData">Regression or classification problem data; any other kind raises an ArgumentException.</param>
    /// <param name="numberOfFolds">The number of folds to generate.</param>
    /// <param name="shuffleFolds">Whether the training rows should be shuffled before fold assignment.</param>
    /// <returns>A sequence of folds, each a sequence of training-row indices.</returns>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
      // NOTE(review): the RNG is seeded from Environment.TickCount, so shuffled folds are not
      // reproducible across calls — confirm whether a caller-supplied seed was intended here.
      var random = new MersenneTwister((uint)Environment.TickCount);
      if (problemData is IRegressionProblemData) {
        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
      }
      if (problemData is IClassificationProblemData) {
        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
        // otherwise, generate folds normally
        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
      }
      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }
[11343]325
[11362]326    /// <summary>
327    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
328    /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
329    /// the corresponding parts from each class label.
330    /// </summary>
331    /// <param name="problemData">The classification problem data.</param>
332    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
333    /// <param name="random">The random generator used to shuffle the folds.</param>
334    /// <returns>An enumerable sequece of folds, where a fold is represented by a sequence of row indices.</returns>
335    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
336      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
337      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
338      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
339      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
340      while (enumerators.All(e => e.MoveNext())) {
341        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
342      }
343    }
344
345    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
346      // if number of folds is greater than the number of values, some empty folds will be returned
347      if (valuesCount < numberOfFolds) {
348        for (int i = 0; i < numberOfFolds; ++i)
349          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
350      } else {
351        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder
352        int start = 0, end = f;
353        for (int i = 0; i < numberOfFolds; ++i) {
354          if (r > 0) {
355            ++end;
356            --r;
357          }
358          yield return values.Skip(start).Take(end - start);
359          start = end;
360          end += f;
361        }
362      }
363    }
364
[11343]365    private static Action<RFParameter, double> GenerateSetter(string field) {
366      var targetExp = Expression.Parameter(typeof(RFParameter));
367      var valueExp = Expression.Parameter(typeof(double));
[11443]368      var fieldExp = Expression.Property(targetExp, field);
[11343]369      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
370      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
371      return setter;
372    }
373
[11338]374    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
375      var regressionProblemData = problemData as IRegressionProblemData;
376      var classificationProblemData = problemData as IClassificationProblemData;
377
378      if (regressionProblemData != null)
379        return regressionProblemData.TargetVariable;
380      if (classificationProblemData != null)
381        return classificationProblemData.TargetVariable;
382
383      throw new ArgumentException("Problem data is neither regression or classification problem data.");
384    }
[11315]385  }
386}
Note: See TracBrowser for help on using the repository browser.