#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using HeuristicLab.Common;
using HeuristicLab.Data;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.DataAnalysis {
  public class RFParameter : ICloneable {
    public double n; // number of trees
    public double m;
    public double r;

    public object Clone() {
      return new RFParameter { n = this.n, m = this.m, r = this.r };
    }
  }

  public static class RandomForestUtil {
    // Builds a compiled setter delegate for the given RFParameter field via an expression tree.
    // Grid search uses these setters to assign parameter values by field name.
    private static Action<RFParameter, double> GenerateSetter(string field) {
      var targetExp = Expression.Parameter(typeof(RFParameter));
      var valueExp = Expression.Parameter(typeof(double));
      var fieldExp = Expression.Field(targetExp, field);
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }
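
    // Illustrative sketch of how GenerateSetter behaves (variable names here are hypothetical,
    // not part of the surrounding code): the compiled delegate writes a double into the named
    // RFParameter field, so the keys of a parameter-range dictionary ("n", "m", "r") can be
    // mapped onto RFParameter fields generically.
    //
    //   var setN = GenerateSetter("n");
    //   var p = new RFParameter();
    //   setN(p, 50); // equivalent to p.n = 50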

    /// <summary>
    /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
    /// </summary>
    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
    /// <param name="problemData">The problem data</param>
    /// <param name="numberOfFolds">The number of folds to generate</param>
    /// <returns>A sequence of folds, each represented as a sequence of row numbers</returns>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
      int size = problemData.TrainingPartition.Size;
      int f = size / numberOfFolds, r = size % numberOfFolds; // base fold size and remainder
      int start = 0, end = f;
      for (int i = 0; i < numberOfFolds; ++i) {
        if (r > 0) { ++end; --r; } // distribute the remainder over the first folds
        yield return problemData.TrainingIndices.Skip(start).Take(end - start);
        start = end;
        end += f;
      }
    }

    // Pairs every fold with the union of the remaining folds: Item1 holds the training rows,
    // Item2 the test rows of the corresponding crossvalidation partition.
    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
      var folds = GenerateFolds(problemData, numberOfFolds).ToList();
      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];

      for (int i = 0; i < numberOfFolds; ++i) {
        int p = i; // avoid "access to modified closure" warning
        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
        var testRows = folds[i];
        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
      }
      return partitions;
    }
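
    // Worked example (illustrative, assuming a training partition of 10 rows and numberOfFolds = 3):
    //   GenerateFolds(problemData, 3)                 -> folds of sizes 4, 3, 3 (the remainder 10 % 3 widens the first fold)
    //   GenerateRandomForestPartitions(problemData, 3) -> partition 0 trains on folds 1 and 2 (6 rows)
    //                                                     and tests on fold 0 (4 rows), and so on.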

    public static void CrossValidate(IDataAnalysisProblemData problemData, int numberOfFolds, RFParameter parameters, int seed, out double error) {
      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
      CrossValidate(problemData, partitions, parameters, seed, out error);
    }

    // user should call the more specific CrossValidate methods
    public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double error) {
      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out error);
    }

    public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double error) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;
      if (regressionProblemData != null)
        CrossValidate(regressionProblemData, partitions, nTrees, m, r, seed, out error);
      else if (classificationProblemData != null)
        CrossValidate(classificationProblemData, partitions, nTrees, m, r, seed, out error);
      else
        throw new ArgumentException("Problem data is neither regression nor classification problem data.");
    }

    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
    }

    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestAccuracy) {
      CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestAccuracy);
    }

    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
      avgTestMse = 0;
      var ds = problemData.Dataset;
      var targetVariable = GetTargetVariableName(problemData);
      foreach (var tuple in partitions) {
        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
        var trainingRandomForestPartition = tuple.Item1;
        var testRandomForestPartition = tuple.Item2;
        var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
        OnlineCalculatorError calculatorError;
        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
        if (calculatorError != OnlineCalculatorError.None)
          mse = double.NaN;
        avgTestMse += mse;
      }
      avgTestMse /= partitions.Length;
    }

    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
      avgTestAccuracy = 0;
      var ds = problemData.Dataset;
      var targetVariable = GetTargetVariableName(problemData);
      foreach (var tuple in partitions) {
        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
        var trainingRandomForestPartition = tuple.Item1;
        var testRandomForestPartition = tuple.Item2;
        var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
        OnlineCalculatorError calculatorError;
        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
        if (calculatorError != OnlineCalculatorError.None)
          accuracy = double.NaN;
        avgTestAccuracy += accuracy;
      }
      avgTestAccuracy /= partitions.Length;
    }
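
    // Illustrative usage of the public GridSearch method below (the dictionary keys must match
    // the RFParameter field names because GenerateSetter resolves them via Expression.Field;
    // the concrete values are arbitrary examples, not recommended defaults):
    //
    //   var parameterRanges = new Dictionary<string, IEnumerable<double>> {
    //     { "n", new double[] { 50, 100, 200 } }, // number of trees
    //     { "r", new double[] { 0.3, 0.5, 0.7 } },
    //     { "m", new double[] { 0.2, 0.5, 1.0 } }
    //   };
    //   var best = RandomForestUtil.GridSearch(problemData, 5, parameterRanges, seed: 12345, maxDegreeOfParallelism: 4);
    //
    // Every combination of the supplied values is evaluated by crossvalidation and the
    // best-performing RFParameter setting is returned.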

    public static RFParameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;
      if (regressionProblemData != null)
        return GridSearch(regressionProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
      if (classificationProblemData != null)
        return GridSearch(classificationProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
      throw new ArgumentException("Problem data is neither regression nor classification problem data.");
    }

    private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      DoubleValue mse = new DoubleValue(Double.MaxValue);
      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
      var locker = new object();

      var pNames = parameterRanges.Keys.ToList();
      var pRanges = pNames.Select(x => parameterRanges[x]);
      var setters = pNames.Select(GenerateSetter).ToList();
      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
      var crossProduct = pRanges.CartesianProduct();

      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
        var list = nuple.ToList();
        double testMSE;
        var parameters = new RFParameter();
        for (int i = 0; i < pNames.Count; ++i) {
          var s = setters[i];
          s(parameters, list[i]);
        }
        CrossValidate(problemData, partitions, parameters, seed, out testMSE);
        // compare and update under a single lock so the stored error and the stored
        // parameter setting always belong to the same grid point
        lock (locker) {
          if (testMSE < mse.Value) {
            mse.Value = testMSE;
            bestParameter = (RFParameter)parameters.Clone();
          }
        }
      });
      return bestParameter;
    }

    private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      DoubleValue accuracy = new DoubleValue(0);
      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
      var locker = new object();

      var pNames = parameterRanges.Keys.ToList();
      var pRanges = pNames.Select(x => parameterRanges[x]);
      var setters = pNames.Select(GenerateSetter).ToList();
      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
      var crossProduct = pRanges.CartesianProduct();

      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
        var list = nuple.ToList();
        double testAccuracy;
        var parameters = new RFParameter();
        for (int i = 0; i < pNames.Count; ++i) {
          var s = setters[i];
          s(parameters, list[i]);
        }
        CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);
        // compare and update under a single lock so the stored accuracy and the stored
        // parameter setting always belong to the same grid point
        lock (locker) {
          if (testAccuracy > accuracy.Value) {
            accuracy.Value = testAccuracy;
            bestParameter = (RFParameter)parameters.Clone();
          }
        }
      });
      return bestParameter;
    }

    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;

      if (regressionProblemData != null)
        return regressionProblemData.TargetVariable;
      if (classificationProblemData != null)
        return classificationProblemData.TargetVariable;
      throw new ArgumentException("Problem data is neither regression nor classification problem data.");
    }
  }
}