Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/14/14 16:15:28 (10 years ago)
Author:
bburlacu
Message:

#2237: Addressed part of the comments above:

  • Methods are similar to the ones from SupportVectorMachineUtil
  • Cleaned up sample scripts
  • Elapsed time is shown in seconds
  • Included demo problem
  • Added stratified crossvalidation (shuffling is turned off by default)
  • Added different GridSearch methods with/without crossvalidation.
  • Fixed bug in fold generation when the number of folds is larger than the number of values
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r11343 r11362  
    2828using System.Threading.Tasks;
    2929using HeuristicLab.Common;
     30using HeuristicLab.Core;
    3031using HeuristicLab.Data;
    3132using HeuristicLab.Problems.DataAnalysis;
     33using HeuristicLab.Random;
    3234
    3335namespace HeuristicLab.Algorithms.DataAnalysis {
     
    4143
    4244  public static class RandomForestUtil {
    43     private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
    44       CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
    45     }
    4645    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    4746      avgTestMse = 0;
     
    6362      avgTestMse /= partitions.Length;
    6463    }
    65 
    66     private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
    67       CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
    68     }
    6964    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    7065      avgTestAccuracy = 0;
     
    8782    }
    8883
    89     private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     84    // grid search without cross-validation since in the case of random forests, the out-of-bag estimate is unbiased
     85    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     86      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     87      var crossProduct = parameterRanges.Values.CartesianProduct();
     88      double bestOutOfBagRmsError = double.MaxValue;
     89      RFParameter bestParameters = new RFParameter();
     90
     91      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
     92        var parameterValues = parameterCombination.ToList();
     93        double testMSE;
     94        var parameters = new RFParameter();
     95        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
     96        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
     97        var model = RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
     98                                                            out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
     99        if (bestOutOfBagRmsError > outOfBagRmsError) {
     100          lock (bestParameters) {
     101            bestOutOfBagRmsError = outOfBagRmsError;
     102            bestParameters = (RFParameter)parameters.Clone();
     103          }
     104        }
     105      });
     106      return bestParameters;
     107    }
     108
     109    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     110      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     111      var crossProduct = parameterRanges.Values.CartesianProduct();
     112
     113      double bestOutOfBagRmsError = double.MaxValue;
     114      RFParameter bestParameters = new RFParameter();
     115
     116      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
     117        var parameterValues = parameterCombination.ToList();
     118        var parameters = new RFParameter();
     119        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
     120        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
     121        var model = RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
     122                                                                out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
     123        if (bestOutOfBagRmsError > outOfBagRmsError) {
     124          lock (bestParameters) {
     125            bestOutOfBagRmsError = outOfBagRmsError;
     126            bestParameters = (RFParameter)parameters.Clone();
     127          }
     128        }
     129      });
     130      return bestParameters;
     131    }
     132
     133    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    90134      DoubleValue mse = new DoubleValue(Double.MaxValue);
    91       RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
     135      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
    92136
    93137      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     
    102146          setters[i](parameters, parameterValues[i]);
    103147        }
    104         CrossValidate(problemData, partitions, parameters, seed, out testMSE);
     148        CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMSE);
    105149        if (testMSE < mse.Value) {
    106150          lock (mse) {
     
    113157    }
    114158
    115     private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     159    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    116160      DoubleValue accuracy = new DoubleValue(0);
    117       RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
    118 
    119       var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    120       var crossProduct = parameterRanges.Values.CartesianProduct();
    121       var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
     161      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
     162
     163      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     164      var crossProduct = parameterRanges.Values.CartesianProduct();
     165      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
    122166
    123167      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
     
    128172          setters[i](parameters, parameterValues[i]);
    129173        }
    130         CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);
     174        CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy);
    131175        if (testAccuracy > accuracy.Value) {
    132176          lock (accuracy) {
     
    139183    }
    140184
    141     /// <summary>
    142     /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
    143     /// </summary>
    144     /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
    145     /// <param name="problemData">The problem data</param>
    146     /// <param name="numberOfFolds">The number of folds to generate</param>
    147     /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
    148     private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
    149       int size = problemData.TrainingPartition.Size;
    150       int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
    151       int start = 0, end = f;
    152       for (int i = 0; i < numberOfFolds; ++i) {
    153         if (r > 0) { ++end; --r; }
    154         yield return problemData.TrainingIndices.Skip(start).Take(end - start);
    155         start = end;
    156         end += f;
    157       }
    158     }
    159 
    160     private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
    161       var folds = GenerateFolds(problemData, numberOfFolds).ToList();
     185    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
     186      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
    162187      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
    163188
     
    171196    }
    172197
     198    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
     199      var random = new MersenneTwister((uint)Environment.TickCount);
     200      if (problemData is IRegressionProblemData) {
     201        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
     202        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
     203      }
     204      if (problemData is IClassificationProblemData) {
     205        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
     206        // otherwise, generate folds normally
     207        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
     208      }
     209      throw new ArgumentException("Problem data is neither regression or classification problem data.");
     210    }
     211
     212    /// <summary>
     213    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
     214    /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
     215    /// the corresponding parts from each class label.
     216    /// </summary>
     217    /// <param name="problemData">The classification problem data.</param>
     218    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
     219    /// <param name="random">The random generator used to shuffle the folds.</param>
     220    /// <returns>An enumerable sequece of folds, where a fold is represented by a sequence of row indices.</returns>
     221    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
     222      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     223      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
     224      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
     225      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
     226      while (enumerators.All(e => e.MoveNext())) {
     227        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
     228      }
     229    }
     230
     231    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
     232      // if number of folds is greater than the number of values, some empty folds will be returned
     233      if (valuesCount < numberOfFolds) {
     234        for (int i = 0; i < numberOfFolds; ++i)
     235          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
     236      } else {
     237        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder
     238        int start = 0, end = f;
     239        for (int i = 0; i < numberOfFolds; ++i) {
     240          if (r > 0) {
     241            ++end;
     242            --r;
     243          }
     244          yield return values.Skip(start).Take(end - start);
     245          start = end;
     246          end += f;
     247        }
     248      }
     249    }
    173250
    174251    private static Action<RFParameter, double> GenerateSetter(string field) {
Note: See TracChangeset for help on using the changeset viewer.