1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 |
5 | using HeuristicLab.Algorithms.DataAnalysis;
6 | using HeuristicLab.Problems.DataAnalysis;
7 | using HeuristicLab.Problems.Instances.DataAnalysis;
8 | using HeuristicLab.Scripting;
9 |
/// <summary>
/// HeuristicLab script that tunes the random forest parameters N, R and M for a classification
/// problem via grid search and stores the resulting solution and parameters in the script variables.
/// </summary>
public class RFClassificationCrossValidationScript : HeuristicLab.Scripting.CSharpScriptBase {
  /* Maximum degree of parallelism (specifies whether or not the grid search should be parallelized) */
  const int maximumDegreeOfParallelism = 4;
  /* Number of crossvalidation folds: */
  const int numberOfFolds = 3;
  /* Specify whether the folds should be shuffled before doing crossvalidation */
  const bool shuffleFolds = true;

  /* The tunable Random Forest parameters:
     - "N" (number of trees). In the random forests literature, this is referred to as the ntree parameter.
       A larger number of trees produces more stable models and covariate importance estimates, but requires more memory and a longer run time.
       For small datasets, 50 trees may be sufficient. For larger datasets, 500 or more may be required. Please consult the random forests
       literature for extensive discussion of this parameter (e.g. Cutler et al., 2007; Strobl et al., 2007; Strobl et al., 2008).
       The range below generates exponents (start 5, end 10, step 1) and maps each exponent x to 2^x.

     - "R" The ratio of the training set that will be used in the construction of individual trees (0<r<=1). Should be adjusted depending on
       the noise level in the dataset in the range from 0.66 (low noise) to 0.05 (high noise). This parameter should be adjusted to achieve
       good generalization error.

     - "M" The ratio of features that will be used in the construction of individual trees (0<m<=1)
  */
  private static readonly Dictionary<string, IEnumerable<double>> randomForestParameterRanges = new Dictionary<string, IEnumerable<double>> {
    { "N", ValueGenerator.GenerateSteps(5m, 10, 1).Select(x => Math.Pow(2, (double)x)) },
    { "R", ValueGenerator.GenerateSteps(0.05m, 0.66m, 0.05m).Select(x => (double)x) },
    { "M", ValueGenerator.GenerateSteps(0.1m, 1, 0.1m).Select(x => (double)x) }
  };

  /// <summary>
  /// Trains a random forest classification model on the training partition of
  /// <paramref name="problemData"/> with the given parameters and wraps it in a solution.
  /// </summary>
  /// <param name="problemData">The classification problem data to train on.</param>
  /// <param name="parameters">The N, R and M values used to build the forest.</param>
  /// <param name="seed">Seed passed to the model construction.</param>
  /// <returns>The classification solution created from the trained model.</returns>
  private static RandomForestClassificationSolution BuildSolution(IClassificationProblemData problemData, RFParameter parameters, int seed) {
    // The error estimates are mandatory out-arguments of CreateClassificationModel; the script does not use them.
    double rmsError, outOfBagRmsError, relClassificationError, outOfBagRelClassificationError;
    var model = RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed,
      out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
    return (RandomForestClassificationSolution)model.CreateClassificationSolution(problemData);
  }

  /// <summary>
  /// Runs the grid search overload of RandomForestUtil that takes a fold count and a shuffle flag,
  /// then trains a final model with the best parameters found.
  /// </summary>
  /// <param name="problemData">The classification problem data.</param>
  /// <param name="numberOfCrossvalidationFolds">Number of folds handed to the grid search.</param>
  /// <param name="bestParameters">Receives the parameters returned by the grid search.</param>
  /// <param name="seed">Seed used for both the grid search and the final model.</param>
  /// <returns>The solution built from the best parameters.</returns>
  private static RandomForestClassificationSolution GridSearchWithCrossvalidation(IClassificationProblemData problemData, int numberOfCrossvalidationFolds, out RFParameter bestParameters,
    int seed = 3141519) {
    // Use the caller-supplied fold count; previously the class constant was passed and the parameter was silently ignored.
    bestParameters = RandomForestUtil.GridSearch(problemData, numberOfCrossvalidationFolds, shuffleFolds, randomForestParameterRanges, seed, maximumDegreeOfParallelism);
    return BuildSolution(problemData, bestParameters, seed);
  }

  /// <summary>
  /// Runs the grid search overload of RandomForestUtil that takes no fold count,
  /// then trains a final model with the best parameters found.
  /// </summary>
  /// <param name="problemData">The classification problem data.</param>
  /// <param name="bestParameters">Receives the parameters returned by the grid search.</param>
  /// <param name="seed">Seed used for both the grid search and the final model.</param>
  /// <returns>The solution built from the best parameters.</returns>
  private static RandomForestClassificationSolution GridSearch(IClassificationProblemData problemData, out RFParameter bestParameters, int seed = 3141519) {
    // NOTE(review): how this overload scores candidates is not visible here (presumably out-of-bag error) - confirm against RandomForestUtil.
    bestParameters = RandomForestUtil.GridSearch(problemData, randomForestParameterRanges, seed, maximumDegreeOfParallelism);
    return BuildSolution(problemData, bestParameters, seed);
  }

  /// <summary>
  /// Script entry point: looks up the classification problem (or problem data) among the script
  /// variables, runs the grid search, stores "bestSolution" and "bestParameters" in the variables
  /// and prints the accuracies and the selected parameters.
  /// </summary>
  public override void Main() {
    var variables = (Variables)vars;
    // SingleOrDefault yields the default pair when nothing matches and throws when more than one variable matches.
    var item = variables.SingleOrDefault(x => x.Value is IClassificationProblem || x.Value is IClassificationProblemData);
    if (item.Equals(default(KeyValuePair<string, object>)))
      throw new ArgumentException("Could not find a suitable problem or problem data.");

    // A problem wraps its data; otherwise the variable already is the problem data.
    var problem = item.Value as IClassificationProblem;
    IClassificationProblemData problemData = problem != null
      ? problem.ProblemData
      : (IClassificationProblemData)item.Value;

    // To tune with crossvalidation instead, call:
    //   GridSearchWithCrossvalidation(problemData, numberOfFolds, out bestParameters);
    RFParameter bestParameters;
    var bestSolution = GridSearch(problemData, out bestParameters);
    vars["bestSolution"] = bestSolution;
    vars["bestParameters"] = bestParameters;

    Console.WriteLine("Accuracy (training): " + bestSolution.TrainingAccuracy + ", Accuracy (test): " + bestSolution.TestAccuracy);
    Console.WriteLine("Model parameters: n = {0}, r = {1}, m = {2}", bestParameters.N, bestParameters.R, bestParameters.M);
  }
}