Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2434_crossvalidation/HeuristicLab.Tests/HeuristicLab.Scripting-3.3/Script Sources/GridSearchRFRegressionScriptSource.cs @ 16003

Last change on this file since 16003 was 12292, checked in by pfleck, 10 years ago

#2301 Removed the GenerateSteps from the ValueGenerator and put it into the new SequenceGenerator.
Adapted DataAnalysis-Instances and scripts (samples and unit tests).

File size: 4.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4
5using HeuristicLab.Algorithms.DataAnalysis;
6using HeuristicLab.Common;
7using HeuristicLab.Problems.DataAnalysis;
8using HeuristicLab.Random;
9using HeuristicLab.Scripting;
10
/* Script that tunes the parameters of a random forest regression model by grid search.
   It looks up a regression problem (or regression problem data) among the script variables,
   searches the parameter grid defined below, trains a final model with the best parameters found
   and stores the resulting solution and parameters back into the script variables. */
public class RFRegressionCrossValidationScript : HeuristicLab.Scripting.CSharpScriptBase {
  /* Maximum degree of parallelism that is passed on to the grid search */
  const int maximumDegreeOfParallelism = 4;
  /* Number of crossvalidation folds (only used by GridSearchWithCrossvalidation): */
  const int numberOfFolds = 3;
  /* Specify whether the crossvalidation folds should be shuffled (only used by GridSearchWithCrossvalidation) */
  const bool shuffleFolds = true;

  /* The tunable Random Forest parameters (the dictionary keys are the names of the RFParameter members):
     - "N" (number of trees). In the random forests literature, this is referred to as the ntree parameter.
       Larger number of trees produce more stable models and covariate importance estimates, but require more memory and a longer run time.
       For small datasets, 50 trees may be sufficient. For larger datasets, 500 or more may be required. Please consult the random forests
       literature for extensive discussion of this parameter (e.g. Cutler et al., 2007; Strobl et al., 2007; Strobl et al., 2008).
       The generated steps are used as exponents, i.e. the candidate values are powers of two.

     - "R" The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on
       the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve
       good generalization error.

     - "M" The ratio of features that will be used in the construction of individual trees (0<m<=1)

     The field is readonly because the ranges are shared by all searches and are never reassigned.
  */
  static readonly Dictionary<string, IEnumerable<double>> randomForestParameterRanges = new Dictionary<string, IEnumerable<double>> {
    { "N", SequenceGenerator.GenerateSteps(5m, 10, 1).Select(x => Math.Pow(2,(double)x)) },
    { "R", SequenceGenerator.GenerateSteps(0.05m, 0.66m, 0.05m).Select(x => (double)x) },
    { "M", SequenceGenerator.GenerateSteps(0.1m, 1, 0.1m).Select(x => (double)x) }
  };

  /* Trains a random forest on the training partition of problemData using the given parameters
     and wraps the resulting model into a regression solution.
     The error measures reported by CreateRegressionModel are required out arguments but are not used by this script. */
  private static RandomForestRegressionSolution CreateSolution(IRegressionProblemData problemData, RFParameter parameters, int seed) {
    double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
    var model = RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
                                                        out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    return (RandomForestRegressionSolution)model.CreateRegressionSolution(problemData);
  }

  /* Grid search variant that passes the number of folds and the shuffle flag to RandomForestUtil.GridSearch.
     bestParameters receives the parameter combination selected by the search; the returned solution is
     built from a model trained with exactly these parameters and the same seed.
     Not called from Main by default; replace the GridSearch call there to use it. */
  private static RandomForestRegressionSolution GridSearchWithCrossvalidation(IRegressionProblemData problemData, out RFParameter bestParameters, int seed = 3141519) {
    bestParameters = RandomForestUtil.GridSearch(problemData, numberOfFolds, shuffleFolds, randomForestParameterRanges, seed, maximumDegreeOfParallelism);
    return CreateSolution(problemData, bestParameters, seed);
  }

  /* Grid search variant that calls RandomForestUtil.GridSearch without fold settings.
     bestParameters receives the parameter combination selected by the search; the returned solution is
     built from a model trained with exactly these parameters and the same seed. */
  private static RandomForestRegressionSolution GridSearch(IRegressionProblemData problemData, out RFParameter bestParameters, int seed = 3141519) {
    bestParameters = RandomForestUtil.GridSearch(problemData, randomForestParameterRanges, seed, maximumDegreeOfParallelism);
    return CreateSolution(problemData, bestParameters, seed);
  }

  /* Entry point of the script.
     Throws an ArgumentException if no variable holds a regression problem or regression problem data.
     Note that SingleOrDefault itself throws an InvalidOperationException if more than one variable matches. */
  public override void Main() {
    var variables = (Variables)vars;
    var item = variables.SingleOrDefault(x => x.Value is IRegressionProblem || x.Value is IRegressionProblemData);
    // SingleOrDefault yields the default (empty) key/value pair when nothing matched
    if (item.Equals(default(KeyValuePair<string, object>))) {
      throw new ArgumentException("Could not find a suitable problem or problem data.");
    }

    // A problem carries its data in ProblemData; otherwise the variable is the problem data itself
    var problem = item.Value as IRegressionProblem;
    IRegressionProblemData problemData = problem != null
      ? problem.ProblemData
      : (IRegressionProblemData)item.Value;

    // No need to pre-allocate the parameters: the out argument is always assigned by the search
    RFParameter bestParameters;
    var bestSolution = GridSearch(problemData, out bestParameters);
    vars["bestSolution"] = bestSolution;
    vars["bestParameters"] = bestParameters;

    Console.WriteLine("R2 (training): " + bestSolution.TrainingRSquared + ", R2 (test): " + bestSolution.TestRSquared);
    Console.WriteLine("Model parameters: n = {0}, r = {1}, m = {2}", bestParameters.N, bestParameters.R, bestParameters.M);
  }
}
Note: See TracBrowser for help on using the repository browser.