source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 11338

Last change on this file since 11338 was 11338, checked in by bburlacu, 8 years ago

#2237: Refactored random forest grid search and added support for symbolic classification.

File size: 12.1 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
29using HeuristicLab.Common;
30using HeuristicLab.Data;
31using HeuristicLab.Problems.DataAnalysis;
32
33namespace HeuristicLab.Algorithms.DataAnalysis {
34  public class RFParameter : ICloneable {
35    public double n; // number of trees
36    public double m;
37    public double r;
38
39    public object Clone() { return new RFParameter { n = this.n, m = this.m, r = this.r }; }
40  }
41
42  public static class RandomForestUtil {
43    private static Action<RFParameter, double> GenerateSetter(string field) {
44      var targetExp = Expression.Parameter(typeof(RFParameter));
45      var valueExp = Expression.Parameter(typeof(double));
46      var fieldExp = Expression.Field(targetExp, field);
47      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
48      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
49      return setter;
50    }
51
52    /// <summary>
53    /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
54    /// </summary>
55    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
56    /// <param name="problemData">The problem data</param>
57    /// <param name="numberOfFolds">The number of folds to generate</param>
58    /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
59    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
60      int size = problemData.TrainingPartition.Size;
61      int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
62      int start = 0, end = f;
63      for (int i = 0; i < numberOfFolds; ++i) {
64        if (r > 0) { ++end; --r; }
65        yield return problemData.TrainingIndices.Skip(start).Take(end - start);
66        start = end;
67        end += f;
68      }
69    }
70
71    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
72      var folds = GenerateFolds(problemData, numberOfFolds).ToList();
73      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
74
75      for (int i = 0; i < numberOfFolds; ++i) {
76        int p = i; // avoid "access to modified closure" warning
77        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
78        var testRows = folds[i];
79        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
80      }
81      return partitions;
82    }
83
84    public static void CrossValidate(IDataAnalysisProblemData problemData, int numberOfFolds, RFParameter parameters, int seed, out double error) {
85      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
86      CrossValidate(problemData, partitions, parameters, seed, out error);
87    }
88
89    // user should call the more specific CrossValidate methods
90    public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double error) {
91      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out error);
92    }
93
94    public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double error) {
95      var regressionProblemData = problemData as IRegressionProblemData;
96      var classificationProblemData = problemData as IClassificationProblemData;
97      if (regressionProblemData != null)
98        CrossValidate(regressionProblemData, partitions, nTrees, m, r, seed, out error);
99      else if (classificationProblemData != null)
100        CrossValidate(classificationProblemData, partitions, nTrees, m, r, seed, out error);
101      else throw new ArgumentException("Problem data is neither regression or classification problem data.");
102    }
103
104    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
105      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
106    }
107
108    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
109      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
110    }
111
112    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
113      avgTestMse = 0;
114      var ds = problemData.Dataset;
115      var targetVariable = GetTargetVariableName(problemData);
116      foreach (var tuple in partitions) {
117        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
118        var trainingRandomForestPartition = tuple.Item1;
119        var testRandomForestPartition = tuple.Item2;
120        var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
121        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
122        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
123        OnlineCalculatorError calculatorError;
124        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
125        if (calculatorError != OnlineCalculatorError.None)
126          mse = double.NaN;
127        avgTestMse += mse;
128      }
129      avgTestMse /= partitions.Length;
130    }
131
132    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
133      avgTestAccuracy = 0;
134      var ds = problemData.Dataset;
135      var targetVariable = GetTargetVariableName(problemData);
136      foreach (var tuple in partitions) {
137        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
138        var trainingRandomForestPartition = tuple.Item1;
139        var testRandomForestPartition = tuple.Item2;
140        var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
141        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
142        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
143        OnlineCalculatorError calculatorError;
144        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
145        if (calculatorError != OnlineCalculatorError.None)
146          accuracy = double.NaN;
147        avgTestAccuracy += accuracy;
148      }
149      avgTestAccuracy /= partitions.Length;
150    }
151
152    public static RFParameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
153      var regressionProblemData = problemData as IRegressionProblemData;
154      var classificationProblemData = problemData as IClassificationProblemData;
155
156      if (regressionProblemData != null)
157        return GridSearch(regressionProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
158      if (classificationProblemData != null)
159        return GridSearch(classificationProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
160
161      throw new ArgumentException("Problem data is neither regression or classification problem data.");
162    }
163
164    private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
165      DoubleValue mse = new DoubleValue(Double.MaxValue);
166      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
167
168      var pNames = parameterRanges.Keys.ToList();
169      var pRanges = pNames.Select(x => parameterRanges[x]);
170      var setters = pNames.Select(GenerateSetter).ToList();
171      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
172      var crossProduct = pRanges.CartesianProduct();
173
174      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
175        var list = nuple.ToList();
176        double testMSE;
177        var parameters = new RFParameter();
178        for (int i = 0; i < pNames.Count; ++i) {
179          var s = setters[i];
180          s(parameters, list[i]);
181        }
182        CrossValidate(problemData, partitions, parameters, seed, out testMSE);
183        if (testMSE < mse.Value) {
184          lock (mse) { mse.Value = testMSE; }
185          lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); }
186        }
187      });
188      return bestParameter;
189    }
190
191    private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
192      DoubleValue accuracy = new DoubleValue(0);
193      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
194
195      var pNames = parameterRanges.Keys.ToList();
196      var pRanges = pNames.Select(x => parameterRanges[x]);
197      var setters = pNames.Select(GenerateSetter).ToList();
198      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
199      var crossProduct = pRanges.CartesianProduct();
200
201      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
202        var list = nuple.ToList();
203        double testAccuracy;
204        var parameters = new RFParameter();
205        for (int i = 0; i < pNames.Count; ++i) {
206          var s = setters[i];
207          s(parameters, list[i]);
208        }
209        CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);
210        if (testAccuracy > accuracy.Value) {
211          lock (accuracy) { accuracy.Value = testAccuracy; }
212          lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); }
213        }
214      });
215      return bestParameter;
216    }
217
218    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
219      var regressionProblemData = problemData as IRegressionProblemData;
220      var classificationProblemData = problemData as IClassificationProblemData;
221
222      if (regressionProblemData != null)
223        return regressionProblemData.TargetVariable;
224      if (classificationProblemData != null)
225        return classificationProblemData.TargetVariable;
226
227      throw new ArgumentException("Problem data is neither regression or classification problem data.");
228    }
229  }
230}
Note: See TracBrowser for help on using the repository browser.