[11514] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 |
|
---|
| 5 | using HeuristicLab.Algorithms.DataAnalysis;
|
---|
[12292] | 6 | using HeuristicLab.Common;
|
---|
[11514] | 7 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 8 | using HeuristicLab.Random;
|
---|
| 9 | using HeuristicLab.Scripting;
|
---|
| 10 |
|
---|
public class RFRegressionCrossValidationScript : HeuristicLab.Scripting.CSharpScriptBase {
  /* Maximum degree of parallelism (specifies whether or not the grid search should be parallelized) */
  const int maximumDegreeOfParallelism = 4;
  /* Number of crossvalidation folds: */
  const int numberOfFolds = 3;

  /* The tunable Random Forest parameters:
     - "n" (number of trees). In the random forests literature, this is referred to as the ntree parameter.
       Larger number of trees produce more stable models and covariate importance estimates, but require more memory and a longer run time.
       For small datasets, 50 trees may be sufficient. For larger datasets, 500 or more may be required. Please consult the random forests
       literature for extensive discussion of this parameter (e.g. Cutler et al., 2007; Strobl et al., 2007; Strobl et al., 2008).

     - "r" The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on
       the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve
       good generalization error.

     - "m" The ratio of features that will be used in the construction of individual trees (0<m<=1)
  */
  static Dictionary<string, IEnumerable<double>> randomForestParameterRanges = new Dictionary<string, IEnumerable<double>> {
    { "N", SequenceGenerator.GenerateSteps(5m, 10, 1).Select(x => Math.Pow(2, (double)x)) }, // 32, 64, ..., 1024 trees
    { "R", SequenceGenerator.GenerateSteps(0.05m, 0.66m, 0.05m).Select(x => (double)x) },
    { "M", SequenceGenerator.GenerateSteps(0.1m, 1, 0.1m).Select(x => (double)x) }
  };

  /// <summary>
  /// Tunes the random forest parameters via crossvalidated grid search, then trains a model on the
  /// full training partition using the best parameter combination found.
  /// </summary>
  /// <param name="problemData">Regression problem data providing the dataset and training partition.</param>
  /// <param name="bestParameters">Receives the best (N, R, M) combination found by the grid search.</param>
  /// <param name="seed">Random seed used for both the grid search and the final model construction.</param>
  /// <returns>A regression solution built from the model trained with the best parameters.</returns>
  private static RandomForestRegressionSolution GridSearchWithCrossvalidation(IRegressionProblemData problemData, out RFParameter bestParameters, int seed = 3141519) {
    double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    bestParameters = RandomForestUtil.GridSearch(problemData, numberOfFolds, randomForestParameterRanges, seed, maximumDegreeOfParallelism);
    var model = RandomForestRegression.CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, bestParameters.N, bestParameters.R, bestParameters.M, seed,
                                                                         out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    return (RandomForestRegressionSolution)model.CreateRegressionSolution(problemData);
  }

  /// <summary>
  /// Tunes the random forest parameters via a plain (non-crossvalidated) grid search, then trains a
  /// model on the full training partition using the best parameter combination found.
  /// </summary>
  /// <param name="problemData">Regression problem data providing the dataset and training partition.</param>
  /// <param name="bestParameters">Receives the best (N, R, M) combination found by the grid search.</param>
  /// <param name="seed">Random seed used for both the grid search and the final model construction.</param>
  /// <returns>A regression solution built from the model trained with the best parameters.</returns>
  private static RandomForestRegressionSolution GridSearch(IRegressionProblemData problemData, out RFParameter bestParameters, int seed = 3141519) {
    double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    // NOTE: a stray unused MersenneTwister instance was removed here; the seed is passed
    // directly to the grid search and model construction, so no local RNG is needed.
    bestParameters = RandomForestUtil.GridSearch(problemData, randomForestParameterRanges, seed, maximumDegreeOfParallelism);
    var model = RandomForestRegression.CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, bestParameters.N, bestParameters.R, bestParameters.M, seed,
                                                                         out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    return (RandomForestRegressionSolution)model.CreateRegressionSolution(problemData);
  }

  /// <summary>
  /// Script entry point: locates the first regression problem (or problem data) in the variable
  /// store, runs the grid search, and publishes the best solution and parameters back into vars.
  /// </summary>
  /// <exception cref="ArgumentException">Thrown when no regression problem/problem data is found in vars.</exception>
  public override void Main() {
    var variables = (Variables)vars;
    // Find exactly one variable holding either a problem or raw problem data;
    // SingleOrDefault also throws if more than one candidate exists.
    var item = variables.SingleOrDefault(x => x.Value is IRegressionProblem || x.Value is IRegressionProblemData);
    if (item.Equals(default(KeyValuePair<string, object>)))
      throw new ArgumentException("Could not find a suitable problem or problem data.");

    string name = item.Key;
    IRegressionProblemData problemData;
    if (item.Value is IRegressionProblem)
      problemData = ((IRegressionProblem)item.Value).ProblemData;
    else
      problemData = (IRegressionProblemData)item.Value;

    // No need to pre-allocate an RFParameter: it is an out argument and is
    // assigned inside GridSearch (the previous "new RFParameter()" was dead).
    RFParameter bestParameters;
    var bestSolution = GridSearch(problemData, out bestParameters);
    vars["bestSolution"] = bestSolution;
    vars["bestParameters"] = bestParameters;

    Console.WriteLine("R2 (training): " + bestSolution.TrainingRSquared + ", R2 (test): " + bestSolution.TestRSquared);
    Console.WriteLine("Model parameters: n = {0}, r = {1}, m = {2}", bestParameters.N, bestParameters.R, bestParameters.M);
  }
}
|
---|