source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs @ 11343

Last change on this file since 11343 was 11343, checked in by mkommend, 8 years ago

#2237: Corrected newly introduced bug in RandomForestModel and reorganized RandomForestUtil.

File size: 9.8 KB
Line 
#region License Information

/* HeuristicLab
 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */

#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Linq.Expressions;
28using System.Threading.Tasks;
29using HeuristicLab.Common;
30using HeuristicLab.Data;
31using HeuristicLab.Problems.DataAnalysis;
32
33namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// One candidate random forest configuration.
/// </summary>
/// <remarks>
/// The members are deliberately public fields with these exact names: the grid search
/// builds its setters with <c>Expression.Field</c> and looks the fields up by name.
/// </remarks>
public class RFParameter : ICloneable {
  /// <summary>Number of trees (cast to <c>int</c> before a model is built).</summary>
  public double n;

  /// <summary>
  /// Forwarded as the <c>m</c> argument when a random forest model is created.
  /// NOTE(review): presumably the ratio of features considered per split - confirm.
  /// </summary>
  public double m;

  /// <summary>
  /// Forwarded as the <c>r</c> argument when a random forest model is created.
  /// NOTE(review): presumably the ratio of training rows used per tree - confirm.
  /// </summary>
  public double r;

  /// <summary>Creates an independent copy carrying the same three values.</summary>
  /// <returns>A new <see cref="RFParameter"/> instance.</returns>
  public object Clone() {
    var copy = new RFParameter();
    copy.n = n;
    copy.m = m;
    copy.r = r;
    return copy;
  }
}
41
/// <summary>
/// Helpers for tuning random forest parameters: cross-validation of a single parameter
/// set over precomputed partitions, and an exhaustive grid search over parameter ranges.
/// </summary>
public static class RandomForestUtil {
  /// <summary>
  /// Cross-validates a regression random forest configured by <paramref name="parameters"/>.
  /// </summary>
  /// <param name="problemData">The regression problem data.</param>
  /// <param name="partitions">(training rows, test rows) pairs, one per fold.</param>
  /// <param name="parameters">The parameter set to evaluate; <c>n</c> is truncated to an int.</param>
  /// <param name="seed">Seed handed to the model builder.</param>
  /// <param name="avgTestMse">Mean of the per-fold test mean squared errors.</param>
  private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
    // The overload below is declared as (nTrees, r, m): r has to be passed before m.
    CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out avgTestMse);
  }

  /// <summary>
  /// For every partition, builds a regression model on the training rows and measures the
  /// mean squared error on the test rows; returns the mean over all partitions.
  /// </summary>
  /// <param name="problemData">The regression problem data.</param>
  /// <param name="partitions">(training rows, test rows) pairs, one per fold.</param>
  /// <param name="nTrees">Number of trees passed to the model builder.</param>
  /// <param name="r">Passed through as the <c>r</c> argument of the model builder.</param>
  /// <param name="m">Passed through as the <c>m</c> argument of the model builder.</param>
  /// <param name="seed">Seed handed to the model builder.</param>
  /// <param name="avgTestMse">
  /// Mean test MSE. A fold for which the calculator reports an error counts as NaN,
  /// which makes the whole average NaN.
  /// </param>
  private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    avgTestMse = 0;
    var ds = problemData.Dataset;
    var targetVariable = GetTargetVariableName(problemData);
    foreach (var tuple in partitions) {
      // The training and out-of-bag error measures are required by the factory signature but are not used here.
      double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
      var trainingRandomForestPartition = tuple.Item1;
      var testRandomForestPartition = tuple.Item2;
      var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
      var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
      OnlineCalculatorError calculatorError;
      double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
      if (calculatorError != OnlineCalculatorError.None)
        mse = double.NaN;
      avgTestMse += mse;
    }
    avgTestMse /= partitions.Length;
  }

  /// <summary>
  /// Cross-validates a classification random forest configured by <paramref name="parameters"/>.
  /// </summary>
  /// <param name="problemData">The classification problem data.</param>
  /// <param name="partitions">(training rows, test rows) pairs, one per fold.</param>
  /// <param name="parameters">The parameter set to evaluate; <c>n</c> is truncated to an int.</param>
  /// <param name="seed">Seed handed to the model builder.</param>
  /// <param name="avgTestMse">
  /// Receives the mean test accuracy (not an MSE, despite the parameter name, which is
  /// kept unchanged for compatibility).
  /// </param>
  private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
    // The overload below is declared as (nTrees, r, m): r has to be passed before m.
    CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out avgTestMse);
  }

  /// <summary>
  /// For every partition, builds a classification model on the training rows and measures
  /// the accuracy on the test rows; returns the mean over all partitions.
  /// </summary>
  /// <param name="problemData">The classification problem data.</param>
  /// <param name="partitions">(training rows, test rows) pairs, one per fold.</param>
  /// <param name="nTrees">Number of trees passed to the model builder.</param>
  /// <param name="r">Passed through as the <c>r</c> argument of the model builder.</param>
  /// <param name="m">Passed through as the <c>m</c> argument of the model builder.</param>
  /// <param name="seed">Seed handed to the model builder.</param>
  /// <param name="avgTestAccuracy">
  /// Mean test accuracy. A fold for which the calculator reports an error counts as NaN,
  /// which makes the whole average NaN.
  /// </param>
  private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    avgTestAccuracy = 0;
    var ds = problemData.Dataset;
    var targetVariable = GetTargetVariableName(problemData);
    foreach (var tuple in partitions) {
      // The training and out-of-bag error measures are required by the factory signature but are not used here.
      double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
      var trainingRandomForestPartition = tuple.Item1;
      var testRandomForestPartition = tuple.Item2;
      var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
      var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
      OnlineCalculatorError calculatorError;
      double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
      if (calculatorError != OnlineCalculatorError.None)
        accuracy = double.NaN;
      avgTestAccuracy += accuracy;
    }
    avgTestAccuracy /= partitions.Length;
  }

  /// <summary>
  /// Grid search for regression: returns the parameter combination with the lowest
  /// cross-validated test MSE.
  /// </summary>
  /// <param name="problemData">The regression problem data.</param>
  /// <param name="numberOfFolds">Number of cross-validation folds (must be positive).</param>
  /// <param name="parameterRanges">Maps an <see cref="RFParameter"/> field name to the values to try for it.</param>
  /// <param name="seed">Seed handed to the model builder.</param>
  /// <param name="maxDegreeOfParallelism">Upper bound on concurrently evaluated combinations.</param>
  /// <returns>
  /// The best combination found, or the fallback (n = 1, m = 0.1, r = 0.1) when no
  /// combination scores below <see cref="Double.MaxValue"/> (e.g. all scores are NaN).
  /// </returns>
  private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
    return FindBestParameters(parameterRanges, maxDegreeOfParallelism, Double.MaxValue,
      parameters => {
        double testMSE;
        CrossValidate(problemData, partitions, parameters, seed, out testMSE);
        return testMSE;
      },
      (candidate, best) => candidate < best);
  }

  /// <summary>
  /// Grid search for classification: returns the parameter combination with the highest
  /// cross-validated test accuracy.
  /// </summary>
  /// <param name="problemData">The classification problem data.</param>
  /// <param name="numberOfFolds">Number of cross-validation folds (must be positive).</param>
  /// <param name="parameterRanges">Maps an <see cref="RFParameter"/> field name to the values to try for it.</param>
  /// <param name="seed">Seed handed to the model builder.</param>
  /// <param name="maxDegreeOfParallelism">Upper bound on concurrently evaluated combinations.</param>
  /// <returns>
  /// The best combination found, or the fallback (n = 1, m = 0.1, r = 0.1) when no
  /// combination reaches an accuracy strictly above 0.
  /// </returns>
  private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
    return FindBestParameters(parameterRanges, maxDegreeOfParallelism, 0.0,
      parameters => {
        double testAccuracy;
        CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);
        return testAccuracy;
      },
      (candidate, best) => candidate > best);
  }

  /// <summary>
  /// Shared grid search loop: scores every combination from the cartesian product of the
  /// given ranges and keeps the best one.
  /// </summary>
  /// <param name="parameterRanges">Maps an <see cref="RFParameter"/> field name to the values to try for it.</param>
  /// <param name="maxDegreeOfParallelism">Upper bound on concurrently evaluated combinations.</param>
  /// <param name="initialScore">Score a combination has to beat to replace the fallback.</param>
  /// <param name="evaluate">Computes the score of one parameter set; may run on several threads at once.</param>
  /// <param name="isBetter">Returns true when the first score should replace the second.</param>
  /// <returns>The best-scoring parameter set, or the fallback if none beats <paramref name="initialScore"/>.</returns>
  private static RFParameter FindBestParameters(Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism, double initialScore, Func<RFParameter, double> evaluate, Func<double, double, bool> isBetter) {
    double bestScore = initialScore;
    RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // fallback if nothing beats initialScore
    var bestLock = new object();

    // setters[i] is applied to value i of a combination; this pairing relies on Keys and
    // Values of the same dictionary enumerating in the same order.
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();

    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameterValues = parameterCombination.ToList();
      var parameters = new RFParameter();
      for (int i = 0; i < setters.Count; ++i) {
        setters[i](parameters, parameterValues[i]);
      }
      double score = evaluate(parameters);

      // Compare and update under the same lock; comparing outside the lock would let a
      // worse result overwrite a better one written by another thread in between.
      lock (bestLock) {
        if (isBetter(score, bestScore)) {
          bestScore = score;
          bestParameter = (RFParameter)parameters.Clone();
        }
      }
    });
    return bestParameter;
  }

  /// <summary>
  /// Generate a collection of row sequences corresponding to folds in the data (used for cross-validation).
  /// </summary>
  /// <remarks>
  /// This method is aimed to be lightweight and as such does not clone the dataset.
  /// The training rows are split into consecutive chunks; the first
  /// <c>size % numberOfFolds</c> folds receive one extra row. If there are more folds than
  /// training rows, the surplus folds are empty.
  /// </remarks>
  /// <param name="problemData">The problem data</param>
  /// <param name="numberOfFolds">The number of folds to generate (must be positive)</param>
  /// <returns>A sequence of folds, each being a sequence of row numbers</returns>
  /// <exception cref="ArgumentOutOfRangeException">
  /// <paramref name="numberOfFolds"/> is not positive (raised when the result is first enumerated).
  /// </exception>
  private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
    if (numberOfFolds <= 0)
      throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be positive.");

    int size = problemData.TrainingPartition.Size;
    int f = size / numberOfFolds, r = size % numberOfFolds; // base fold size and number of leftover rows
    int start = 0, end = f;
    for (int i = 0; i < numberOfFolds; ++i) {
      if (r > 0) { ++end; --r; } // hand out one leftover row to each of the first folds
      yield return problemData.TrainingIndices.Skip(start).Take(end - start);
      start = end;
      end += f;
    }
  }

  /// <summary>
  /// Builds the cross-validation partitions: for fold i the test rows are fold i and the
  /// training rows are all other folds concatenated in fold order.
  /// </summary>
  /// <param name="problemData">The problem data</param>
  /// <param name="numberOfFolds">The number of folds to generate (must be positive)</param>
  /// <returns>An array of (training rows, test rows) pairs, one per fold.</returns>
  /// <exception cref="ArgumentOutOfRangeException"><paramref name="numberOfFolds"/> is not positive.</exception>
  private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
    if (numberOfFolds <= 0)
      throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be positive.");

    // Materialise the row indices once. The partitions are enumerated for every parameter
    // combination (possibly from several threads), and the deferred Skip/Take queries
    // would re-walk the training indices each time.
    int[][] folds = GenerateFolds(problemData, numberOfFolds).Select(fold => fold.ToArray()).ToArray();
    var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];

    for (int i = 0; i < numberOfFolds; ++i) {
      int testFold = i; // local copy so the lambda does not capture the loop variable
      int[] trainingRows = folds.Where((fold, j) => j != testFold).SelectMany(fold => fold).ToArray();
      partitions[i] = Tuple.Create<IEnumerable<int>, IEnumerable<int>>(trainingRows, folds[i]);
    }
    return partitions;
  }

  /// <summary>
  /// Compiles a delegate that assigns a double value to the named field of an <see cref="RFParameter"/>.
  /// </summary>
  /// <param name="field">Name of the <see cref="RFParameter"/> field to assign.</param>
  /// <returns>A setter taking the target instance and the value to store.</returns>
  private static Action<RFParameter, double> GenerateSetter(string field) {
    var targetExp = Expression.Parameter(typeof(RFParameter));
    var valueExp = Expression.Parameter(typeof(double));
    var fieldExp = Expression.Field(targetExp, field);
    // Convert to the field's declared type so the assignment type-checks.
    var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
    return Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
  }

  /// <summary>
  /// Returns the target variable name of regression or classification problem data.
  /// </summary>
  /// <param name="problemData">The problem data to inspect.</param>
  /// <returns>The name of the target variable.</returns>
  /// <exception cref="ArgumentException">The problem data is neither regression nor classification problem data.</exception>
  private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
    var regressionProblemData = problemData as IRegressionProblemData;
    if (regressionProblemData != null)
      return regressionProblemData.TargetVariable;

    var classificationProblemData = problemData as IClassificationProblemData;
    if (classificationProblemData != null)
      return classificationProblemData.TargetVariable;

    throw new ArgumentException("Problem data is neither regression or classification problem data.");
  }
}
195}
Note: See TracBrowser for help on using the repository browser.