using System;
using System.Collections.Generic;
using System.Linq;

using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Common;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Scripting;

/*
  Script that grid-searches Random Forest classification parameters for the single
  classification problem (or problem data) found in the script variables, trains a
  model with the best parameter set and stores the result in the variables
  "bestSolution" and "bestParameters".
*/
public class RFClassificationCrossValidationScript : HeuristicLab.Scripting.CSharpScriptBase {
  /* Maximum degree of parallelism (specifies whether or not the grid search should be parallelized) */
  const int maximumDegreeOfParallelism = 4;
  /* Number of crossvalidation folds: */
  const int numberOfFolds = 3;
  /* Specify whether the folds should be shuffled before doing crossvalidation */
  const bool shuffleFolds = true;

  /*
  The tunable Random Forest parameters:
  - "N" (number of trees). In the random forests literature, this is referred to as the ntree parameter.
    Larger number of trees produce more stable models and covariate importance estimates, but require more
    memory and a longer run time. For small datasets, 50 trees may be sufficient. For larger datasets, 500
    or more may be required. Please consult the random forests literature for extensive discussion of this
    parameter (e.g. Cutler et al., 2007; Strobl et al., 2007; Strobl et al., 2008).
    The grid below uses powers of two with exponents generated from 5 to 10 in steps of 1.

  - "R" The ratio of the training set that will be used in the construction of individual trees.
    NOTE(review): the rest of this description (including the valid range, presumably 0 < R <= 1) was lost
    from the file - confirm against the RandomForestUtil documentation.
    The grid below is generated from 0.05 to 0.66 in steps of 0.05.

  - "M" NOTE(review): the description of this parameter was lost from the file; presumably the ratio of
    features used in the construction of individual trees - confirm against RandomForestUtil.
    The grid below is generated from 0.1 to 1 in steps of 0.1.
  */
  private static readonly Dictionary<string, IEnumerable<double>> randomForestParameterRanges = new Dictionary<string, IEnumerable<double>> {
    { "N", SequenceGenerator.GenerateSteps(5m, 10, 1).Select(x => Math.Pow(2, (double)x)) },
    { "R", SequenceGenerator.GenerateSteps(0.05m, 0.66m, 0.05m).Select(x => (double)x) },
    { "M", SequenceGenerator.GenerateSteps(0.1m, 1, 0.1m).Select(x => (double)x) }
  };

  /*
    Runs the cross-validated overload of RandomForestUtil.GridSearch and trains a model
    with the best parameters on the training partition of problemData.

    problemData                 - the classification data to tune and train on
    numberOfCrossvalidationFolds - number of folds handed to the grid search
    bestParameters              - receives the parameter set selected by the grid search
    seed                        - seed passed to both the grid search and the model construction
  */
  private static RandomForestClassificationSolution GridSearchWithCrossvalidation(IClassificationProblemData problemData, int numberOfCrossvalidationFolds, out RFParameter bestParameters, int seed = 3141519) {
    // Use the caller-supplied fold count; the previous code silently ignored it
    // and always passed the class constant numberOfFolds.
    bestParameters = RandomForestUtil.GridSearch(problemData, numberOfCrossvalidationFolds, shuffleFolds, randomForestParameterRanges, seed, maximumDegreeOfParallelism);
    return CreateSolution(problemData, bestParameters, seed);
  }

  /*
    Runs the RandomForestUtil.GridSearch overload that takes no fold arguments and trains
    a model with the best parameters on the training partition of problemData.

    problemData    - the classification data to tune and train on
    bestParameters - receives the parameter set selected by the grid search
    seed           - seed passed to both the grid search and the model construction
  */
  private static RandomForestClassificationSolution GridSearch(IClassificationProblemData problemData, out RFParameter bestParameters, int seed = 3141519) {
    bestParameters = RandomForestUtil.GridSearch(problemData, randomForestParameterRanges, seed, maximumDegreeOfParallelism);
    return CreateSolution(problemData, bestParameters, seed);
  }

  /* Builds a random forest model on the training indices with the given parameters and wraps it in a solution. */
  private static RandomForestClassificationSolution CreateSolution(IClassificationProblemData problemData, RFParameter parameters, int seed) {
    // The error estimates are required out-arguments of the factory method; this script does not use them.
    double rmsError, outOfBagRmsError, relClassificationError, outOfBagRelClassificationError;
    var model = RandomForestClassification.CreateRandomForestClassificationModel(problemData, problemData.TrainingIndices, parameters.N, parameters.R, parameters.M, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    return (RandomForestClassificationSolution)model.CreateClassificationSolution(problemData);
  }

  /*
    Script entry point: locates the single variable holding a classification problem or
    classification problem data, runs the grid search, publishes the results as script
    variables and prints accuracies and the chosen parameters.

    Throws ArgumentException when no suitable variable exists. SingleOrDefault itself
    throws InvalidOperationException when more than one variable matches.
  */
  public override void Main() {
    var variables = (Variables)vars;
    var item = variables.SingleOrDefault(x => x.Value is IClassificationProblem || x.Value is IClassificationProblemData);
    // SingleOrDefault yields the default (empty) pair when nothing matched.
    if (item.Equals(default(KeyValuePair<string, object>))) {
      throw new ArgumentException("Could not find a suitable problem or problem data.");
    }

    // A problem wraps its data; a bare problem data variable is used directly.
    var problem = item.Value as IClassificationProblem;
    IClassificationProblemData problemData = problem != null
      ? problem.ProblemData
      : (IClassificationProblemData)item.Value;

    // NOTE(review): this runs the plain grid search, so numberOfFolds and shuffleFolds are
    // not used on this path. Call GridSearchWithCrossvalidation(problemData, numberOfFolds,
    // out bestParameters) instead if cross-validated tuning is intended - confirm.
    RFParameter bestParameters;
    var bestSolution = GridSearch(problemData, out bestParameters);
    vars["bestSolution"] = bestSolution;
    vars["bestParameters"] = bestParameters;

    Console.WriteLine("Accuracy (training): " + bestSolution.TrainingAccuracy + ", Accuracy (test): " + bestSolution.TestAccuracy);
    Console.WriteLine("Model parameters: n = {0}, r = {1}, m = {2}", bestParameters.N, bestParameters.R, bestParameters.M);
  }
}