#region License Information
/* HeuristicLab
* Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;
namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// One parameter combination for random forest grid search.
/// Fields are doubles (even the tree count) so that a single
/// Action&lt;RFParameter, double&gt; setter type can assign any of them by name.
/// </summary>
public class RFParameter : ICloneable {
  public double n; // number of trees
  public double m; // m parameter (presumably the feature-sampling ratio — confirm against RandomForestModel)
  public double r; // r parameter (presumably the row-sampling ratio — confirm against RandomForestModel)

  /// <summary>Returns an independent copy carrying the same parameter values.</summary>
  public object Clone() {
    var copy = new RFParameter();
    copy.n = n;
    copy.m = m;
    copy.r = r;
    return copy;
  }
}
/// <summary>
/// Helper routines for tuning random forest parameters by grid search — with explicit
/// cross-validation or using the (unbiased) out-of-bag estimate — for regression and
/// classification problems.
/// </summary>
public static class RandomForestUtil {
  /// <summary>
  /// Trains one regression forest per partition and averages the mean squared error on the
  /// respective test rows. A failed calculation poisons the average with NaN so the
  /// parameter combination cannot be selected.
  /// </summary>
  private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
    avgTestMse = 0;
    var ds = problemData.Dataset;
    var targetVariable = GetTargetVariableName(problemData);
    foreach (var tuple in partitions) {
      double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
      var trainingRandomForestPartition = tuple.Item1;
      var testRandomForestPartition = tuple.Item2;
      var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
      var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
      OnlineCalculatorError calculatorError;
      double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
      if (calculatorError != OnlineCalculatorError.None)
        mse = double.NaN; // NaN propagates into the average, ruling this combination out
      avgTestMse += mse;
    }
    avgTestMse /= partitions.Length;
  }

  /// <summary>
  /// Trains one classification forest per partition and averages the accuracy on the
  /// respective test rows. A failed calculation poisons the average with NaN.
  /// </summary>
  private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
    avgTestAccuracy = 0;
    var ds = problemData.Dataset;
    var targetVariable = GetTargetVariableName(problemData);
    foreach (var tuple in partitions) {
      double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
      var trainingRandomForestPartition = tuple.Item1;
      var testRandomForestPartition = tuple.Item2;
      var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
      var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
      var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
      OnlineCalculatorError calculatorError;
      double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
      if (calculatorError != OnlineCalculatorError.None)
        accuracy = double.NaN;
      avgTestAccuracy += accuracy;
    }
    avgTestAccuracy /= partitions.Length;
  }

  /// <summary>
  /// Grid search without cross-validation: the out-of-bag estimate of random forests is
  /// unbiased, so it is used directly for model selection. Minimizes the out-of-bag RMS error.
  /// </summary>
  /// <param name="problemData">The regression problem data.</param>
  /// <param name="parameterRanges">Maps RFParameter field names ("n", "m", "r") to the candidate values for that field.</param>
  /// <param name="seed">Seed for the random forest machinery (the search itself is exhaustive).</param>
  /// <param name="maxDegreeOfParallelism">Upper bound on concurrently evaluated combinations.</param>
  public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();
    double bestOutOfBagRmsError = double.MaxValue;
    RFParameter bestParameters = new RFParameter();
    // Dedicated lock target: locking bestParameters itself would be unsafe because the
    // reference is reassigned while the monitor is held.
    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameterValues = parameterCombination.ToList();
      var parameters = new RFParameter();
      for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      // NOTE(review): the out-argument order here differs from the CrossValidate call site above;
      // only outOfBagRmsError is consumed — verify positions against RandomForestModel.CreateRegressionModel.
      RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
                                              out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
      if (bestOutOfBagRmsError > outOfBagRmsError) { // cheap unsynchronized pre-check
        lock (locker) {
          // re-check under the lock; without it two threads could both pass the pre-check
          // and the worse result could overwrite the better one (lost update)
          if (bestOutOfBagRmsError > outOfBagRmsError) {
            bestOutOfBagRmsError = outOfBagRmsError;
            bestParameters = (RFParameter)parameters.Clone();
          }
        }
      }
    });
    return bestParameters;
  }

  /// <summary>
  /// Grid search without cross-validation for classification problems; selects the parameter
  /// combination with the smallest out-of-bag RMS error reported by the forest.
  /// </summary>
  public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();
    double bestOutOfBagRmsError = double.MaxValue;
    RFParameter bestParameters = new RFParameter();
    var locker = new object(); // see regression overload: never lock the reassigned best-parameter reference
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameterValues = parameterCombination.ToList();
      var parameters = new RFParameter();
      for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); }
      double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError;
      // NOTE(review): out-argument order differs from the CrossValidate call site; only
      // outOfBagRmsError is consumed — verify against RandomForestModel.CreateClassificationModel.
      RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed,
                                                  out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
      if (bestOutOfBagRmsError > outOfBagRmsError) {
        lock (locker) {
          if (bestOutOfBagRmsError > outOfBagRmsError) { // double-check to avoid lost updates
            bestOutOfBagRmsError = outOfBagRmsError;
            bestParameters = (RFParameter)parameters.Clone();
          }
        }
      }
    });
    return bestParameters;
  }

  /// <summary>
  /// Grid search with k-fold cross-validation for regression; minimizes the average test MSE
  /// over the folds.
  /// </summary>
  public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    double bestMse = double.MaxValue;
    RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    // BUGFIX: shuffleFolds was previously ignored here (the classification overload passed it).
    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
    var crossProduct = parameterRanges.Values.CartesianProduct();
    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameterValues = parameterCombination.ToList();
      var parameters = new RFParameter();
      for (int i = 0; i < setters.Count; ++i) {
        setters[i](parameters, parameterValues[i]);
      }
      double testMse;
      CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMse);
      if (testMse < bestMse) {
        lock (locker) {
          if (testMse < bestMse) { // double-check under the lock to avoid lost updates
            bestMse = testMse;
            bestParameter = (RFParameter)parameters.Clone();
          }
        }
      }
    });
    return bestParameter;
  }

  /// <summary>
  /// Grid search with k-fold cross-validation for classification; maximizes the average test
  /// accuracy over the folds.
  /// </summary>
  public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
    double bestAccuracy = 0;
    RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 };
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var crossProduct = parameterRanges.Values.CartesianProduct();
    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds);
    var locker = new object();
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameterValues = parameterCombination.ToList();
      var parameters = new RFParameter();
      for (int i = 0; i < setters.Count; ++i) {
        setters[i](parameters, parameterValues[i]);
      }
      double testAccuracy;
      CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy);
      if (testAccuracy > bestAccuracy) {
        lock (locker) {
          if (testAccuracy > bestAccuracy) { // double-check under the lock to avoid lost updates
            bestAccuracy = testAccuracy;
            bestParameter = (RFParameter)parameters.Clone();
          }
        }
      }
    });
    return bestParameter;
  }

  /// <summary>
  /// Builds one (training rows, test rows) pair per fold: fold i serves as the test set and
  /// the union of all other folds as the training set.
  /// </summary>
  private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
    var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
    var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
    for (int i = 0; i < numberOfFolds; ++i) {
      int p = i; // avoid "access to modified closure" warning
      var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
      var testRows = folds[i];
      partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
    }
    return partitions;
  }

  /// <summary>
  /// Splits the training rows of the problem data into <paramref name="numberOfFolds"/> folds.
  /// Note: folds are shuffled with a time-dependent seed, so repeated calls produce different splits.
  /// </summary>
  public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) {
    var random = new MersenneTwister((uint)Environment.TickCount);
    if (problemData is IRegressionProblemData) {
      var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
      return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
    }
    if (problemData is IClassificationProblemData) {
      // when shuffle is enabled do stratified folds generation, some folds may have zero elements
      // otherwise, generate folds normally
      return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
    }
    throw new ArgumentException("Problem data is neither regression or classification problem data.");
  }

  /// <summary>
  /// Stratified fold generation from classification data: ensures a similar distribution of
  /// class labels in each fold. Samples are grouped by class label, each group is split into
  /// <paramref name="numberOfFolds"/> parts, and the corresponding parts of every group are
  /// joined (and shuffled) into the final folds.
  /// </summary>
  /// <param name="problemData">The classification problem data.</param>
  /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
  /// <param name="random">The random generator used to shuffle the folds.</param>
  /// <returns>An enumerable sequence of folds, where a fold is represented by a sequence of row indices.</returns>
  private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
    var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
    IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
    var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
    // advance all per-class fold sequences in lock step; stop when any class runs out
    while (enumerators.All(e => e.MoveNext())) {
      yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
    }
  }

  /// <summary>
  /// Splits a sequence of <paramref name="valuesCount"/> elements into
  /// <paramref name="numberOfFolds"/> contiguous folds whose sizes differ by at most one.
  /// If there are fewer values than folds, the surplus folds are empty.
  /// </summary>
  private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
    // if number of folds is greater than the number of values, some empty folds will be returned
    if (valuesCount < numberOfFolds) {
      for (int i = 0; i < numberOfFolds; ++i)
        yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
    } else {
      int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // base fold size and remainder
      int start = 0, end = f;
      for (int i = 0; i < numberOfFolds; ++i) {
        if (r > 0) { // distribute the remainder one element at a time over the first r folds
          ++end;
          --r;
        }
        yield return values.Skip(start).Take(end - start);
        start = end;
        end += f;
      }
    }
  }

  /// <summary>
  /// Compiles a setter delegate that assigns a double to the RFParameter field named
  /// <paramref name="field"/> ("n", "m" or "r").
  /// </summary>
  private static Action<RFParameter, double> GenerateSetter(string field) {
    var targetExp = Expression.Parameter(typeof(RFParameter));
    var valueExp = Expression.Parameter(typeof(double));
    var fieldExp = Expression.Field(targetExp, field);
    var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
    return Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile();
  }

  /// <summary>
  /// Returns the target variable name of a regression or classification problem;
  /// throws for any other problem data type.
  /// </summary>
  private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
    var regressionProblemData = problemData as IRegressionProblemData;
    var classificationProblemData = problemData as IClassificationProblemData;
    if (regressionProblemData != null)
      return regressionProblemData.TargetVariable;
    if (classificationProblemData != null)
      return classificationProblemData.TargetVariable;
    throw new ArgumentException("Problem data is neither regression or classification problem data.");
  }
}
}