#region License Information
/* HeuristicLab
* Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see .
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
namespace HeuristicLab.Algorithms.DataAnalysis {
///
/// Represents a random forest model for regression and classification
///
[StorableClass]
[Item("RandomForestModel", "Represents a random forest for regression and classification.")]
public sealed class RandomForestModel : NamedItem, IRandomForestModel {
// not persisted
private alglib.decisionforest randomForest;
private alglib.decisionforest RandomForest {
get {
// recalculate lazily
if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
return randomForest;
}
}
// instead of storing the data of the model itself
// we instead only store data necessary to recalculate the same model lazily on demand
[Storable]
private int seed;
[Storable]
private IDataAnalysisProblemData originalTrainingData;
[Storable]
private double[] classValues;
[Storable]
private int nTrees;
[Storable]
private double r;
[Storable]
private double m;
[StorableConstructor]
private RandomForestModel(bool deserializing)
: base(deserializing) {
// for backwards compatibility (loading old solutions)
randomForest = new alglib.decisionforest();
}
private RandomForestModel(RandomForestModel original, Cloner cloner)
: base(original, cloner) {
randomForest = new alglib.decisionforest();
randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
// we assume that the trees array (double[]) is immutable in alglib
randomForest.innerobj.trees = original.randomForest.innerobj.trees;
// allowedInputVariables is immutable so we don't need to clone
allowedInputVariables = original.allowedInputVariables;
// clone data which is necessary to rebuild the model
this.seed = original.seed;
this.originalTrainingData = cloner.Clone(original.originalTrainingData);
// classvalues is immutable so we don't need to clone
this.classValues = original.classValues;
this.nTrees = original.nTrees;
this.r = original.r;
this.m = original.m;
}
// random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
private RandomForestModel(alglib.decisionforest randomForest,
int seed, IDataAnalysisProblemData originalTrainingData,
int nTrees, double r, double m, double[] classValues = null)
: base() {
this.name = ItemName;
this.description = ItemDescription;
// the model itself
this.randomForest = randomForest;
// data which is necessary for recalculation of the model
this.seed = seed;
this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
this.classValues = classValues;
this.nTrees = nTrees;
this.r = r;
this.m = m;
}
public override IDeepCloneable Clone(Cloner cloner) {
return new RandomForestModel(this, cloner);
}
private void RecalculateModel() {
double rmsError, oobRmsError, relClassError, oobRelClassError;
var regressionProblemData = originalTrainingData as IRegressionProblemData;
var classificationProblemData = originalTrainingData as IClassificationProblemData;
if (regressionProblemData != null) {
var model = CreateRegressionModel(regressionProblemData,
nTrees, r, m, seed, out rmsError, out oobRmsError,
out relClassError, out oobRelClassError);
randomForest = model.randomForest;
} else if (classificationProblemData != null) {
var model = CreateClassificationModel(classificationProblemData,
nTrees, r, m, seed, out rmsError, out oobRmsError,
out relClassError, out oobRelClassError);
randomForest = model.randomForest;
}
}
public IEnumerable GetEstimatedValues(IDataset dataset, IEnumerable rows) {
double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
AssertInputMatrix(inputData);
int n = inputData.GetLength(0);
int columns = inputData.GetLength(1);
double[] x = new double[columns];
double[] y = new double[1];
for (int row = 0; row < n; row++) {
for (int column = 0; column < columns; column++) {
x[column] = inputData[row, column];
}
alglib.dfprocess(RandomForest, x, ref y);
yield return y[0];
}
}
public IEnumerable GetEstimatedClassValues(IDataset dataset, IEnumerable rows) {
double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
AssertInputMatrix(inputData);
int n = inputData.GetLength(0);
int columns = inputData.GetLength(1);
double[] x = new double[columns];
double[] y = new double[RandomForest.innerobj.nclasses];
for (int row = 0; row < n; row++) {
for (int column = 0; column < columns; column++) {
x[column] = inputData[row, column];
}
alglib.dfprocess(randomForest, x, ref y);
// find class for with the largest probability value
int maxProbClassIndex = 0;
double maxProb = y[0];
for (int i = 1; i < y.Length; i++) {
if (maxProb < y[i]) {
maxProb = y[i];
maxProbClassIndex = i;
}
}
yield return classValues[maxProbClassIndex];
}
}
public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this);
}
IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
return CreateRegressionSolution(problemData);
}
public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this);
}
IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
return CreateClassificationSolution(problemData);
}
public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError);
}
public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable trainingIndices, int nTrees, double r, double m, int seed,
out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
alglib.dfreport rep;
var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
rmsError = rep.rmserror;
avgRelError = rep.avgrelerror;
outOfBagAvgRelError = rep.oobavgrelerror;
outOfBagRmsError = rep.oobrmserror;
return new RandomForestModel(dForest, seed, problemData, nTrees, r, m);
}
public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
}
public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable trainingIndices, int nTrees, double r, double m, int seed,
out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);
var classValues = problemData.ClassValues.ToArray();
int nClasses = classValues.Length;
// map original class values to values [0..nClasses-1]
var classIndices = new Dictionary();
for (int i = 0; i < nClasses; i++) {
classIndices[classValues[i]] = i;
}
int nRows = inputMatrix.GetLength(0);
int nColumns = inputMatrix.GetLength(1);
for (int row = 0; row < nRows; row++) {
inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
}
alglib.dfreport rep;
var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
rmsError = rep.rmserror;
outOfBagRmsError = rep.oobrmserror;
relClassificationError = rep.relclserror;
outOfBagRelClassificationError = rep.oobrelclserror;
return new RandomForestModel(dForest, seed, problemData, nTrees, r, m, classValues);
}
private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
AssertParameters(r, m);
AssertInputMatrix(inputMatrix);
int info = 0;
alglib.math.rndobject = new System.Random(seed);
var dForest = new alglib.decisionforest();
rep = new alglib.dfreport();
int nRows = inputMatrix.GetLength(0);
int nColumns = inputMatrix.GetLength(1);
int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
return dForest;
}
private static void AssertParameters(double r, double m) {
if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
}
private static void AssertInputMatrix(double[,] inputMatrix) {
if (inputMatrix.Cast().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
}
#region persistence for backwards compatibility
// when the originalTrainingData is null this means the model was loaded from an old file
// therefore, we cannot use the new persistence mechanism because the original data is not available anymore
// in such cases we still store the compete model
private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }
private string[] allowedInputVariables;
[Storable(Name = "allowedInputVariables")]
private string[] AllowedInputVariables {
get {
if (IsCompatibilityLoaded) return allowedInputVariables;
else return originalTrainingData.AllowedInputVariables.ToArray();
}
set { allowedInputVariables = value; }
}
[Storable]
private int RandomForestBufSize {
get {
if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
else return 0;
}
set {
randomForest.innerobj.bufsize = value;
}
}
[Storable]
private int RandomForestNClasses {
get {
if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
else return 0;
}
set {
randomForest.innerobj.nclasses = value;
}
}
[Storable]
private int RandomForestNTrees {
get {
if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
else return 0;
}
set {
randomForest.innerobj.ntrees = value;
}
}
[Storable]
private int RandomForestNVars {
get {
if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
else return 0;
}
set {
randomForest.innerobj.nvars = value;
}
}
[Storable]
private double[] RandomForestTrees {
get {
if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
else return new double[] { };
}
set {
randomForest.innerobj.trees = value;
}
}
#endregion
}
}