Ignore:
Timestamp:
09/03/14 15:15:41 (8 years ago)
Author:
bburlacu
Message:

#2237: Refactored random forest grid search and added support for symbolic classification.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs

    r11315 r11338  
    4444      var targetExp = Expression.Parameter(typeof(RFParameter));
    4545      var valueExp = Expression.Parameter(typeof(double));
    46 
    47       // Expression.Property can be used here as well
    4846      var fieldExp = Expression.Field(targetExp, field);
    4947      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
     
    5351
    5452    /// <summary>
    55     /// Generate a collection of training indices corresponding to folds in the data (used for crossvalidation)
     53    /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
    5654    /// </summary>
    5755    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
    5856    /// <param name="problemData">The problem data</param>
    59     /// <param name="nFolds">The number of folds to generate</param>
     57    /// <param name="numberOfFolds">The number of folds to generate</param>
    6058    /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
    61     public static IEnumerable<IEnumerable<int>> GenerateFolds(IRegressionProblemData problemData, int nFolds) {
     59    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
    6260      int size = problemData.TrainingPartition.Size;
    63 
    64       int foldSize = size / nFolds; // rounding to integer
    65       var trainingIndices = problemData.TrainingIndices;
    66 
    67       for (int i = 0; i < nFolds; ++i) {
    68         int n = i * foldSize;
    69         int s = n + 2 * foldSize > size ? foldSize + size % foldSize : foldSize;
    70         yield return trainingIndices.Skip(n).Take(s);
    71       }
    72     }
    73 
    74     public static void CrossValidate(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, RFParameter parameter, int seed, out double avgTestMSE) {
    75       CrossValidate(problemData, folds, (int)Math.Round(parameter.n), parameter.m, parameter.r, seed, out avgTestMSE);
    76     }
    77 
    78     public static void CrossValidate(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, int nTrees, double m, double r, int seed, out double avgTestMSE) {
    79       avgTestMSE = 0;
    80 
     61      int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
     62      int start = 0, end = f;
     63      for (int i = 0; i < numberOfFolds; ++i) {
     64        if (r > 0) { ++end; --r; }
     65        yield return problemData.TrainingIndices.Skip(start).Take(end - start);
     66        start = end;
     67        end += f;
     68      }
     69    }
     70
     71    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
     72      var folds = GenerateFolds(problemData, numberOfFolds).ToList();
     73      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
     74
     75      for (int i = 0; i < numberOfFolds; ++i) {
     76        int p = i; // avoid "access to modified closure" warning
     77        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
     78        var testRows = folds[i];
     79        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
     80      }
     81      return partitions;
     82    }
     83
     84    public static void CrossValidate(IDataAnalysisProblemData problemData, int numberOfFolds, RFParameter parameters, int seed, out double error) {
     85      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
     86      CrossValidate(problemData, partitions, parameters, seed, out error);
     87    }
     88
     89    // user should call the more specific CrossValidate methods
     90    public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double error) {
     91      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out error);
     92    }
     93
     94    public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double error) {
     95      var regressionProblemData = problemData as IRegressionProblemData;
     96      var classificationProblemData = problemData as IClassificationProblemData;
     97      if (regressionProblemData != null)
     98        CrossValidate(regressionProblemData, partitions, nTrees, m, r, seed, out error);
     99      else if (classificationProblemData != null)
     100        CrossValidate(classificationProblemData, partitions, nTrees, m, r, seed, out error);
     101      else throw new ArgumentException("Problem data is neither regression or classification problem data.");
     102    }
     103
     104    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
     105      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
     106    }
     107
     108    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
     109      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
     110    }
     111
     112    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
     113      avgTestMse = 0;
    81114      var ds = problemData.Dataset;
    82       var targetVariable = problemData.TargetVariable;
    83 
    84       var partitions = folds.ToList();
    85 
    86       for (int i = 0; i < partitions.Count; ++i) {
     115      var targetVariable = GetTargetVariableName(problemData);
     116      foreach (var tuple in partitions) {
    87117        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
    88         var test = partitions[i];
    89         var training = new List<int>();
    90         for (int j = 0; j < i; ++j)
    91           training.AddRange(partitions[j]);
    92 
    93         for (int j = i + 1; j < partitions.Count; ++j)
    94           training.AddRange(partitions[j]);
    95 
    96         var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, m, r, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError, training);
    97         var estimatedValues = model.GetEstimatedValues(ds, test);
    98         var outputValues = ds.GetDoubleValues(targetVariable, test);
    99 
     118        var trainingRandomForestPartition = tuple.Item1;
     119        var testRandomForestPartition = tuple.Item2;
     120        var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
     121        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
     122        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
    100123        OnlineCalculatorError calculatorError;
    101         double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, outputValues, out calculatorError);
     124        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
    102125        if (calculatorError != OnlineCalculatorError.None)
    103126          mse = double.NaN;
    104         avgTestMSE += mse;
    105       }
    106 
    107       avgTestMSE /= partitions.Count;
    108     }
    109 
    110     public static RFParameter GridSearch(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     127        avgTestMse += mse;
     128      }
     129      avgTestMse /= partitions.Length;
     130    }
     131
     132    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
     133      avgTestAccuracy = 0;
     134      var ds = problemData.Dataset;
     135      var targetVariable = GetTargetVariableName(problemData);
     136      foreach (var tuple in partitions) {
     137        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
     138        var trainingRandomForestPartition = tuple.Item1;
     139        var testRandomForestPartition = tuple.Item2;
     140        var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
     141        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
     142        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
     143        OnlineCalculatorError calculatorError;
     144        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
     145        if (calculatorError != OnlineCalculatorError.None)
     146          accuracy = double.NaN;
     147        avgTestAccuracy += accuracy;
     148      }
     149      avgTestAccuracy /= partitions.Length;
     150    }
     151
     152    public static RFParameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     153      var regressionProblemData = problemData as IRegressionProblemData;
     154      var classificationProblemData = problemData as IClassificationProblemData;
     155
     156      if (regressionProblemData != null)
     157        return GridSearch(regressionProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
     158      if (classificationProblemData != null)
     159        return GridSearch(classificationProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
     160
     161      throw new ArgumentException("Problem data is neither regression or classification problem data.");
     162    }
     163
     164    private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    111165      DoubleValue mse = new DoubleValue(Double.MaxValue);
    112166      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
     
    115169      var pRanges = pNames.Select(x => parameterRanges[x]);
    116170      var setters = pNames.Select(GenerateSetter).ToList();
    117 
     171      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
    118172      var crossProduct = pRanges.CartesianProduct();
    119173
     
    126180          s(parameters, list[i]);
    127181        }
    128         CrossValidate(problemData, folds, parameters, seed, out testMSE);
     182        CrossValidate(problemData, partitions, parameters, seed, out testMSE);
    129183        if (testMSE < mse.Value) {
    130           lock (mse) {
    131             mse.Value = testMSE;
    132           }
    133           lock (bestParameter) {
    134             bestParameter = (RFParameter)parameters.Clone();
    135           }
     184          lock (mse) { mse.Value = testMSE; }
     185          lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); }
    136186        }
    137187      });
    138188      return bestParameter;
    139189    }
     190
     191    private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     192      DoubleValue accuracy = new DoubleValue(0);
     193      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
     194
     195      var pNames = parameterRanges.Keys.ToList();
     196      var pRanges = pNames.Select(x => parameterRanges[x]);
     197      var setters = pNames.Select(GenerateSetter).ToList();
     198      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
     199      var crossProduct = pRanges.CartesianProduct();
     200
     201      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
     202        var list = nuple.ToList();
     203        double testAccuracy;
     204        var parameters = new RFParameter();
     205        for (int i = 0; i < pNames.Count; ++i) {
     206          var s = setters[i];
     207          s(parameters, list[i]);
     208        }
     209        CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);
     210        if (testAccuracy > accuracy.Value) {
     211          lock (accuracy) { accuracy.Value = testAccuracy; }
     212          lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); }
     213        }
     214      });
     215      return bestParameter;
     216    }
     217
     218    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
     219      var regressionProblemData = problemData as IRegressionProblemData;
     220      var classificationProblemData = problemData as IClassificationProblemData;
     221
     222      if (regressionProblemData != null)
     223        return regressionProblemData.TargetVariable;
     224      if (classificationProblemData != null)
     225        return classificationProblemData.TargetVariable;
     226
     227      throw new ArgumentException("Problem data is neither regression or classification problem data.");
     228    }
    140229  }
    141230}
Note: See TracChangeset for help on using the changeset viewer.