#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Random forest hyperparameter vector. Fields are public (not properties) because
  /// <see cref="RandomForestUtil.GenerateSetter"/> builds field-assignment expressions
  /// from the grid-search dictionary keys ("n", "m", "r").
  /// </summary>
  public class RFParameter : ICloneable {
    public double n; // number of trees (cast to int at the call sites)
    public double m; // fraction of features considered per split
    public double r; // fraction of rows sampled per tree

    public object Clone() {
      return new RFParameter { n = this.n, m = this.m, r = this.r };
    }
  }

  /// <summary>
  /// Helpers for tuning random forest hyperparameters via grid search, either using the
  /// out-of-bag estimate directly or via k-fold cross-validation.
  /// </summary>
  public static class RandomForestUtil {
    /// <summary>
    /// Trains one regression forest per partition and averages the mean squared error on
    /// the corresponding test rows. A failed calculation contributes NaN, which poisons
    /// the average so the parameter combination is never selected.
    /// </summary>
    private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
      avgTestMse = 0;
      var ds = problemData.Dataset;
      var targetVariable = GetTargetVariableName(problemData);
      foreach (var tuple in partitions) {
        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
        var trainingRandomForestPartition = tuple.Item1;
        var testRandomForestPartition = tuple.Item2;
        var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed,
          out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
        var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
        OnlineCalculatorError calculatorError;
        double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
        if (calculatorError != OnlineCalculatorError.None)
          mse = double.NaN;
        avgTestMse += mse;
      }
      avgTestMse /= partitions.Length;
    }

    /// <summary>
    /// Trains one classification forest per partition and averages the accuracy on the
    /// corresponding test rows. A failed calculation contributes NaN.
    /// </summary>
    private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
      avgTestAccuracy = 0;
      var ds = problemData.Dataset;
      var targetVariable = GetTargetVariableName(problemData);
      foreach (var tuple in partitions) {
        double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
        var trainingRandomForestPartition = tuple.Item1;
        var testRandomForestPartition = tuple.Item2;
        var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed,
          out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
        var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
        var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
        OnlineCalculatorError calculatorError;
        double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
        if (calculatorError != OnlineCalculatorError.None)
          accuracy = double.NaN;
        avgTestAccuracy += accuracy;
      }
      avgTestAccuracy /= partitions.Length;
    }

    /// <summary>
    /// Grid search without cross-validation, since for random forests the out-of-bag
    /// estimate is unbiased. Returns the parameter combination with the lowest
    /// out-of-bag RMS error on the training partition.
    /// </summary>
    /// <param name="problemData">The regression problem data.</param>
    /// <param name="parameterRanges">Maps RFParameter field names ("n", "m", "r") to candidate values.</param>
    /// <param name="seed">Seed for the random forest training.</param>
    /// <param name="maxDegreeOfParallelism">Upper bound on concurrently evaluated combinations.</param>
    public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var crossProduct = parameterRanges.Values.CartesianProduct();
      double bestOutOfBagRmsError = double.MaxValue;
      RFParameter bestParameters = new RFParameter();
      // Dedicated lock object: locking on bestParameters itself would be unsafe because
      // that reference is reassigned inside the critical section.
      var locker = new object();

      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
        var parameterValues = parameterCombination.ToList();
        var parameters = new RFParameter();
        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
        // NOTE(review): the out-argument order here differs from the CrossValidate call
        // above (outOfBagRmsError before avgRelError) — verify against the
        // RandomForestModel.CreateRegressionModel signature.
        RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
          out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

        // Double-checked update so a concurrent thread cannot overwrite a better result.
        if (bestOutOfBagRmsError > outOfBagRmsError) {
          lock (locker) {
            if (bestOutOfBagRmsError > outOfBagRmsError) {
              bestOutOfBagRmsError = outOfBagRmsError;
              bestParameters = (RFParameter)parameters.Clone();
            }
          }
        }
      });
      return bestParameters;
    }

    /// <summary>
    /// Grid search for classification using the out-of-bag RMS error as the selection
    /// criterion (no cross-validation required).
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="parameterRanges">Maps RFParameter field names ("n", "m", "r") to candidate values.</param>
    /// <param name="seed">Seed for the random forest training.</param>
    /// <param name="maxDegreeOfParallelism">Upper bound on concurrently evaluated combinations.</param>
    public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var crossProduct = parameterRanges.Values.CartesianProduct();
      double bestOutOfBagRmsError = double.MaxValue;
      RFParameter bestParameters = new RFParameter();
      // Dedicated lock object: bestParameters is reassigned and must not be the monitor.
      var locker = new object();

      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
        var parameterValues = parameterCombination.ToList();
        var parameters = new RFParameter();
        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
        double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
        // NOTE(review): out-argument order differs from the CrossValidate call above —
        // verify against the RandomForestModel.CreateClassificationModel signature.
        RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
          out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);

        if (bestOutOfBagRmsError > outOfBagRmsError) {
          lock (locker) {
            if (bestOutOfBagRmsError > outOfBagRmsError) {
              bestOutOfBagRmsError = outOfBagRmsError;
              bestParameters = (RFParameter)parameters.Clone();
            }
          }
        }
      });
      return bestParameters;
    }

    /// <summary>
    /// Grid search with k-fold cross-validation for regression; selects the parameter
    /// combination with the lowest average test MSE over all folds.
    /// </summary>
    /// <param name="problemData">The regression problem data.</param>
    /// <param name="numberOfFolds">Number of cross-validation folds.</param>
    /// <param name="shuffleFolds">Whether to shuffle the training rows before folding.</param>
    /// <param name="parameterRanges">Maps RFParameter field names ("n", "m", "r") to candidate values.</param>
    /// <param name="seed">Seed for the random forest training.</param>
    /// <param name="maxDegreeOfParallelism">Upper bound on concurrently evaluated combinations.</param>
    public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      DoubleValue mse = new DoubleValue(Double.MaxValue);
      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
      var locker = new object();

      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      // Pass shuffleFolds through (previously dropped; the classification overload passes it).
      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
      var crossProduct = parameterRanges.Values.CartesianProduct();

      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
        var parameterValues = parameterCombination.ToList();
        double testMSE;
        var parameters = new RFParameter();
        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
        CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMSE);

        // Double-checked update to avoid overwriting a better (lower) MSE found concurrently.
        if (testMSE < mse.Value) {
          lock (locker) {
            if (testMSE < mse.Value) {
              mse.Value = testMSE;
              bestParameter = (RFParameter)parameters.Clone();
            }
          }
        }
      });
      return bestParameter;
    }

    /// <summary>
    /// Grid search with k-fold cross-validation for classification; selects the parameter
    /// combination with the highest average test accuracy over all folds.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="numberOfFolds">Number of cross-validation folds.</param>
    /// <param name="shuffleFolds">Whether to generate stratified, shuffled folds.</param>
    /// <param name="parameterRanges">Maps RFParameter field names ("n", "m", "r") to candidate values.</param>
    /// <param name="seed">Seed for the random forest training.</param>
    /// <param name="maxDegreeOfParallelism">Upper bound on concurrently evaluated combinations.</param>
    public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
      DoubleValue accuracy = new DoubleValue(0);
      RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
      var locker = new object();

      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var crossProduct = parameterRanges.Values.CartesianProduct();
      var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);

      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
        var parameterValues = parameterCombination.ToList();
        double testAccuracy;
        var parameters = new RFParameter();
        for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
        CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy);

        // Double-checked update to avoid overwriting a better (higher) accuracy found concurrently.
        if (testAccuracy > accuracy.Value) {
          lock (locker) {
            if (testAccuracy > accuracy.Value) {
              accuracy.Value = testAccuracy;
              bestParameter = (RFParameter)parameters.Clone();
            }
          }
        }
      });
      return bestParameter;
    }

    /// <summary>
    /// Builds the k train/test partitions for cross-validation: partition i uses fold i
    /// as the test set and the union of all other folds as the training set.
    /// </summary>
    private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
      var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];

      for (int i = 0; i < numberOfFolds; ++i) {
        int p = i; // avoid "access to modified closure" warning
        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
        var testRows = folds[i];
        partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
      }
      return partitions;
    }

    /// <summary>
    /// Splits the training rows of the problem data into <paramref name="numberOfFolds"/> folds.
    /// For classification data with shuffling enabled, stratified folds are generated
    /// (some folds may then be empty); otherwise rows are folded in order.
    /// </summary>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
      var random = new MersenneTwister((uint)Environment.TickCount);
      if (problemData is IRegressionProblemData) {
        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
      }
      if (problemData is IClassificationProblemData) {
        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
        // otherwise, generate folds normally
        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
      }
      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }

    /// <summary>
    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
    /// The samples are grouped by class label and each group is split into <paramref name="numberOfFolds"/> parts. The final folds are formed
    /// from the joining of the corresponding parts from each class label.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
    /// <param name="random">The random generator used to shuffle the folds.</param>
    /// <returns>An enumerable sequence of folds, where a fold is represented by a sequence of row indices.</returns>
    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
      while (enumerators.All(e => e.MoveNext())) {
        // join the i-th part of every class group, then shuffle the combined fold
        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
      }
    }

    /// <summary>
    /// Splits <paramref name="values"/> into <paramref name="numberOfFolds"/> contiguous folds.
    /// The first (valuesCount % numberOfFolds) folds receive one extra element; if there are
    /// fewer values than folds, the surplus folds are empty.
    /// </summary>
    private static IEnumerable<IEnumerable<int>> GenerateFolds(IEnumerable<int> values, int valuesCount, int numberOfFolds) {
      // if number of folds is greater than the number of values, some empty folds will be returned
      if (valuesCount < numberOfFolds) {
        for (int i = 0; i < numberOfFolds; ++i)
          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<int>();
      } else {
        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder
        int start = 0, end = f;
        for (int i = 0; i < numberOfFolds; ++i) {
          if (r > 0) {
            ++end;
            --r;
          }
          yield return values.Skip(start).Take(end - start);
          start = end;
          end += f;
        }
      }
    }

    /// <summary>
    /// Compiles a delegate that assigns a double value to the RFParameter field named
    /// <paramref name="field"/> ("n", "m" or "r"), so grid-search loops avoid reflection
    /// on every combination.
    /// </summary>
    private static Action<RFParameter, double> GenerateSetter(string field) {
      var targetExp = Expression.Parameter(typeof(RFParameter));
      var valueExp = Expression.Parameter(typeof(double));
      var fieldExp = Expression.Field(targetExp, field);
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }

    /// <summary>Returns the target variable name for regression or classification problem data.</summary>
    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;

      if (regressionProblemData != null)
        return regressionProblemData.TargetVariable;
      if (classificationProblemData != null)
        return classificationProblemData.TargetVariable;
      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }
  }
}