#region License Information
/* HeuristicLab
* Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
extern alias alglib_3_7;
using System;
using System.Collections.Generic;
using System.Linq;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// Represents a random forest model for regression and classification
/// </summary>
[Obsolete("This class only exists for backwards compatibility reasons for stored models with the XML Persistence. Use RFModelSurrogate or RFModelFull instead.")]
[StorableType("9AA4CCC2-CD75-4471-8DF6-949E5B783642")]
[Item("RandomForestModel", "Represents a random forest for regression and classification.")]
public sealed class RandomForestModel : ClassificationModel, IRandomForestModel {
// not persisted directly (the field carries no [Storable] attribute)
private alglib_3_7.alglib.decisionforest randomForest;
/// <summary>
/// Lazy accessor for the alglib forest: when the backing object holds no tree data,
/// <see cref="RecalculateModel"/> is invoked before the forest is handed out.
/// </summary>
private alglib_3_7.alglib.decisionforest RandomForest {
  get {
    var treeData = randomForest.innerobj.trees;
    bool hasTreeData = treeData != null && treeData.Length != 0;
    // recalculate lazily
    if (!hasTreeData) RecalculateModel();
    return randomForest;
  }
}
/// <summary>
/// Names of the input variables this model uses for prediction.
/// </summary>
/// <remarks>
/// A model without attached training data (<see cref="IsCompatibilityLoaded"/>) reports the separately
/// stored variable names; dereferencing <c>originalTrainingData</c> in that state would throw a
/// NullReferenceException. In every other case the training data's allowed input variables are returned.
/// </remarks>
public override IEnumerable VariablesUsedForPrediction {
  get {
    if (IsCompatibilityLoaded) return allowedInputVariables;
    return originalTrainingData.AllowedInputVariables;
  }
}
/// <summary>
/// The tree count this model was configured with (the stored <c>nTrees</c> value).
/// </summary>
public int NumberOfTrees => nTrees;
// Instead of persisting the forest itself, only the data needed to
// recalculate the same model lazily on demand is stored.
// seed handed to the forest builder whenever the model is (re)built
[Storable]
private int seed;
// problem data the model is rebuilt from; null marks a model loaded in compatibility mode
[Storable]
private IDataAnalysisProblemData originalTrainingData;
// original class values, indexed by the class index the forest predicts; stays null when none are supplied (regression factory)
[Storable]
private double[] classValues;
// number of trees requested from the forest builder
[Storable]
private int nTrees;
// builder parameter forwarded unchanged; presumably the ratio of training rows used per tree — TODO confirm against RandomForestUtil
[Storable]
private double r;
// builder parameter forwarded unchanged; presumably the ratio of features considered per split — TODO confirm against RandomForestUtil
[Storable]
private double m;
/// <summary>
/// Deserialization constructor.
/// </summary>
[StorableConstructor]
private RandomForestModel(StorableConstructorFlag _) : base(_) {
  // for backwards compatibility (loading old solutions): the storable
  // compatibility properties write into this forest object, so it has to exist up front
  this.randomForest = new alglib_3_7.alglib.decisionforest();
}
/// <summary>
/// Cloning constructor: copies the rebuild settings, deep-clones the training data through
/// <paramref name="cloner"/> and shares the forest's tree array with <paramref name="original"/>.
/// </summary>
private RandomForestModel(RandomForestModel original, Cloner cloner)
  : base(original, cloner) {
  // data which is necessary to rebuild the model
  seed = original.seed;
  nTrees = original.nTrees;
  r = original.r;
  m = original.m;
  originalTrainingData = cloner.Clone(original.originalTrainingData);

  // classValues and allowedInputVariables are treated as immutable, so the references are shared
  classValues = original.classValues;
  allowedInputVariables = original.allowedInputVariables;

  // read the backing field (not the lazy property) so that cloning never triggers a recalculation
  var source = original.randomForest.innerobj;
  randomForest = new alglib_3_7.alglib.decisionforest();
  randomForest.innerobj.nvars = source.nvars;
  randomForest.innerobj.nclasses = source.nclasses;
  randomForest.innerobj.ntrees = source.ntrees;
  randomForest.innerobj.bufsize = source.bufsize;
  // we assume that the trees array (double[]) is immutable in alglib
  randomForest.innerobj.trees = source.trees;
}
/// <summary>
/// Wraps an already trained forest together with everything needed to rebuild it later.
/// Random forest models can only be created through the static factory methods
/// CreateRegressionModel and CreateClassificationModel.
/// </summary>
private RandomForestModel(string targetVariable, alglib_3_7.alglib.decisionforest randomForest,
  int seed, IDataAnalysisProblemData originalTrainingData,
  int nTrees, double r, double m, double[] classValues = null)
  : base(targetVariable) {
  name = ItemName;
  description = ItemDescription;

  // settings required for recalculating the model; the problem data is cloned rather than shared
  this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
  this.seed = seed;
  this.nTrees = nTrees;
  this.r = r;
  this.m = m;
  this.classValues = classValues;

  // the trained model itself
  this.randomForest = randomForest;
}
/// <summary>
/// Creates a copy of this model via the cloning constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) => new RandomForestModel(this, cloner);
/// <summary>
/// Rebuilds the forest from the stored training data and settings (nTrees, r, m, seed) and
/// replaces the backing forest with the result. The error measures reported by the factory
/// methods are discarded. Nothing happens when the training data is neither regression nor
/// classification problem data.
/// </summary>
private void RecalculateModel() {
  if (originalTrainingData is IRegressionProblemData regressionData) {
    var rebuilt = CreateRegressionModel(regressionData,
      nTrees, r, m, seed, out _, out _,
      out _, out _);
    randomForest = rebuilt.randomForest;
    return;
  }

  if (originalTrainingData is IClassificationProblemData classificationData) {
    var rebuilt = CreateClassificationModel(classificationData,
      nTrees, r, m, seed, out _, out _,
      out _, out _);
    randomForest = rebuilt.randomForest;
  }
}
/// <summary>
/// Lazily yields one forest output per requested row, using the first element of the
/// output buffer filled by alglib's dfprocess.
/// </summary>
/// <param name="dataset">Dataset providing the values of the allowed input variables.</param>
/// <param name="rows">Rows of <paramref name="dataset"/> to evaluate.</param>
public IEnumerable GetEstimatedValues(IDataset dataset, IEnumerable rows) {
  double[,] inputs = dataset.ToArray(AllowedInputVariables, rows);
  RandomForestUtil.AssertInputMatrix(inputs);

  int rowCount = inputs.GetLength(0);
  int featureCount = inputs.GetLength(1);
  // both buffers are reused for every row
  var features = new double[featureCount];
  var output = new double[1];

  int rowIndex = 0;
  while (rowIndex < rowCount) {
    for (int c = 0; c < featureCount; c++) {
      features[c] = inputs[rowIndex, c];
    }
    alglib_3_7.alglib.dfprocess(RandomForest, features, ref output);
    yield return output[0];
    rowIndex++;
  }
}
/// <summary>
/// Lazily yields, for every requested row, <c>VariancePop()</c> of the raw outputs that
/// dfprocessraw writes into a buffer with one slot per tree of the forest.
/// </summary>
/// <param name="dataset">Dataset providing the values of the allowed input variables.</param>
/// <param name="rows">Rows of <paramref name="dataset"/> to evaluate.</param>
public IEnumerable GetEstimatedVariances(IDataset dataset, IEnumerable rows) {
  double[,] inputs = dataset.ToArray(AllowedInputVariables, rows);
  RandomForestUtil.AssertInputMatrix(inputs);

  int rowCount = inputs.GetLength(0);
  int featureCount = inputs.GetLength(1);
  var features = new double[featureCount];
  // one slot per tree
  var rawOutputs = new double[this.RandomForest.innerobj.ntrees];

  int rowIndex = 0;
  while (rowIndex < rowCount) {
    for (int c = 0; c < featureCount; c++) {
      features[c] = inputs[rowIndex, c];
    }
    alglib_3_7.alglib.dforest.dfprocessraw(RandomForest.innerobj, features, ref rawOutputs);
    yield return rawOutputs.VariancePop();
    rowIndex++;
  }
}
/// <summary>
/// Lazily yields, for every requested row, the original class value whose output entry
/// (one per class, filled by alglib's dfprocess) is largest. Ties resolve to the lowest class index.
/// </summary>
/// <param name="dataset">Dataset providing the values of the allowed input variables.</param>
/// <param name="rows">Rows of <paramref name="dataset"/> to evaluate.</param>
public override IEnumerable GetEstimatedClassValues(IDataset dataset, IEnumerable rows) {
  double[,] inputData = dataset.ToArray(AllowedInputVariables, rows);
  RandomForestUtil.AssertInputMatrix(inputData);

  // Resolve the forest through the lazy property once and use that instance throughout,
  // instead of mixing the property with direct access to the backing field.
  var forest = RandomForest;

  int n = inputData.GetLength(0);
  int columns = inputData.GetLength(1);
  double[] x = new double[columns];
  double[] y = new double[forest.innerobj.nclasses];

  for (int row = 0; row < n; row++) {
    for (int column = 0; column < columns; column++) {
      x[column] = inputData[row, column];
    }
    alglib_3_7.alglib.dfprocess(forest, x, ref y);

    // find the class with the largest probability value
    int maxProbClassIndex = 0;
    double maxProb = y[0];
    for (int i = 1; i < y.Length; i++) {
      if (maxProb < y[i]) {
        maxProb = y[i];
        maxProbClassIndex = i;
      }
    }
    yield return classValues[maxProbClassIndex];
  }
}
/// <summary>
/// Converts one tree of the forest into a symbolic expression tree
/// (program root -> start -> variable-condition / number nodes).
/// </summary>
/// <param name="treeIdx">
/// Tree selector; <c>treeIdx - 1</c> trees are skipped, so 1 selects the first tree (0 does as well).
/// </param>
public ISymbolicExpressionTree ExtractTree(int treeIdx) {
  var forest = RandomForest;
  // hoping that the internal representation of alglib is stable

  // TREE FORMAT
  // W[Offs] - size of sub-array (for the tree)
  // node info:
  // W[K+0] - variable number (-1 for leaf mode)
  // W[K+1] - threshold (class/value for leaf node)
  // W[K+2] - ">=" branch index (absent for leaf node)

  // skip irrelevant trees: each tree starts with its own size, which leads to the next tree
  int treeStart = 0;
  int treesToSkip = treeIdx - 1;
  for (int skipped = 0; skipped < treesToSkip; skipped++) {
    treeStart += (int)Math.Round(forest.innerobj.trees[treeStart]);
  }

  var numberSymbol = new Number();
  var conditionSymbol = new VariableCondition() { IgnoreSlope = true };
  // the root node of the tree follows directly after the size entry
  var treeNode = CreateRegressionTreeRec(forest.innerobj.trees, treeStart, treeStart + 1, numberSymbol, conditionSymbol);

  var start = new StartSymbol().CreateTreeNode();
  start.AddSubtree(treeNode);

  var programRoot = new ProgramRootSymbol().CreateTreeNode();
  programRoot.AddSubtree(start);

  return new SymbolicExpressionTree(programRoot);
}
/// <summary>
/// Builds the symbolic expression sub-tree for the alglib tree node stored at index <paramref name="k"/>.
/// </summary>
/// <param name="trees">Flat alglib tree array.</param>
/// <param name="offset">Start index of the current tree inside <paramref name="trees"/>.</param>
/// <param name="k">Index of the node to convert.</param>
/// <param name="numSy">Symbol used to create leaf (number) nodes.</param>
/// <param name="varCondSy">Symbol used to create inner (variable condition) nodes.</param>
private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Number numSy, VariableCondition varCondSy) {
  // AllowedInputVariables may materialize a new array on every access, so it is resolved
  // once here instead of once per inner node during the recursion.
  return CreateRegressionTreeRec(trees, offset, k, numSy, varCondSy, AllowedInputVariables);
}

// Recursive worker: same contract as the five-argument overload, with the variable names passed in.
private static ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Number numSy, VariableCondition varCondSy, string[] variableNames) {
  // number of array entries of an inner node (variable, threshold, ">=" branch index);
  // corresponds to innernodewidth in the alglib code quoted below
  const int innerNodeWidth = 3;

  // alglib source for evaluation of one tree (dfprocessinternal)
  // offs = 0
  //
  // Set pointer to the root
  //
  // k = offs + 1;
  //
  // //
  // // Navigate through the tree
  // //
  // while (true) {
  //   if ((double)(df.trees[k]) == (double)(-1)) {
  //     if (df.nclasses == 1) {
  //       y[0] = y[0] + df.trees[k + 1];
  //     } else {
  //       idx = (int)Math.Round(df.trees[k + 1]);
  //       y[idx] = y[idx] + 1;
  //     }
  //     break;
  //   }
  //   if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) {
  //     k = k + innernodewidth;
  //   } else {
  //     k = offs + (int)Math.Round(df.trees[k + 2]);
  //   }
  // }

  if ((double)(trees[k]) == (double)(-1)) {
    // leaf: the value is stored right after the -1 marker
    var numNode = (NumberTreeNode)numSy.CreateTreeNode();
    numNode.Value = trees[k + 1];
    return numNode;
  } else {
    var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
    condNode.VariableName = variableNames[(int)Math.Round(trees[k])];
    condNode.Threshold = trees[k + 1];
    condNode.Slope = double.PositiveInfinity;

    // the "<" branch follows directly after this node, the ">=" branch is addressed relative to the tree start
    var left = CreateRegressionTreeRec(trees, offset, k + innerNodeWidth, numSy, varCondSy, variableNames);
    var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), numSy, varCondSy, variableNames);

    condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above)
    condNode.AddSubtree(right);
    return condNode;
  }
}
/// <summary>
/// Wraps this model and a new <c>RegressionProblemData</c> built from <paramref name="problemData"/>
/// in a <c>RandomForestRegressionSolution</c>.
/// </summary>
public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData)
  => new RandomForestRegressionSolution(this, new RegressionProblemData(problemData));
/// <summary>
/// Wraps this model and a new <c>ClassificationProblemData</c> built from <paramref name="problemData"/>
/// in a <c>RandomForestClassificationSolution</c>.
/// </summary>
public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData)
  => new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData));
/// <summary>
/// Checks regression problem data against this model by delegating to
/// <c>RegressionModel.IsProblemDataCompatible</c>.
/// </summary>
public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage)
  => RegressionModel.IsProblemDataCompatible(this, problemData, out errorMessage);
/// <summary>
/// Dispatches the compatibility check to the regression or classification overload,
/// depending on the runtime type of <paramref name="problemData"/> (regression is tested first).
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
  switch (problemData) {
    case null:
      throw new ArgumentNullException(nameof(problemData), "The provided problemData is null.");
    case IRegressionProblemData regressionData:
      return IsProblemDataCompatible(regressionData, out errorMessage);
    case IClassificationProblemData classificationData:
      return IsProblemDataCompatible(classificationData, out errorMessage);
    default:
      throw new ArgumentException("The problem data is not compatible with this random forest. Instead a " + problemData.GetType().GetPrettyName() + " was provided.", nameof(problemData));
  }
}
/// <summary>
/// Trains a regression forest on the training partition (<c>problemData.TrainingIndices</c>);
/// forwards to the overload that takes explicit row indices.
/// </summary>
public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
  var trainingRows = problemData.TrainingIndices;
  return CreateRegressionModel(problemData, trainingRows, nTrees, r, m, seed,
    out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
}
/// <summary>
/// Trains a regression forest on the given rows. The training matrix consists of the allowed
/// input variables followed by the target variable as last column.
/// </summary>
/// <param name="problemData">Problem data providing dataset, input variables and target variable.</param>
/// <param name="trainingIndices">Rows used for training.</param>
/// <param name="nTrees">Forwarded to the forest builder.</param>
/// <param name="r">Forwarded to the forest builder.</param>
/// <param name="m">Forwarded to the forest builder.</param>
/// <param name="seed">Forwarded to the forest builder.</param>
/// <param name="rmsError">Receives <c>rmserror</c> of the builder report.</param>
/// <param name="outOfBagRmsError">Receives <c>oobrmserror</c> of the builder report.</param>
/// <param name="avgRelError">Receives <c>avgrelerror</c> of the builder report.</param>
/// <param name="outOfBagAvgRelError">Receives <c>oobavgrelerror</c> of the builder report.</param>
public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable trainingIndices, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
  // inputs first, target last
  var columnNames = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
  double[,] trainingMatrix = problemData.Dataset.ToArray(columnNames, trainingIndices);

  // the literal 1 is the class-count argument (the classification factory passes nClasses here)
  var forest = RandomForestUtil.CreateRandomForestModelAlglib_3_7(seed, trainingMatrix, nTrees, r, m, 1, out var report);

  rmsError = report.rmserror;
  avgRelError = report.avgrelerror;
  outOfBagRmsError = report.oobrmserror;
  outOfBagAvgRelError = report.oobavgrelerror;

  return new RandomForestModel(problemData.TargetVariable, forest, seed, problemData, nTrees, r, m);
}
/// <summary>
/// Trains a classification forest on the training partition (<c>problemData.TrainingIndices</c>);
/// forwards to the overload that takes explicit row indices.
/// </summary>
public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
  var trainingRows = problemData.TrainingIndices;
  return CreateClassificationModel(problemData, trainingRows, nTrees, r, m, seed,
    out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError);
}
/// <summary>
/// Trains a classification forest on the given rows. The training matrix consists of the allowed
/// input variables followed by the target variable as last column; the target column is rewritten
/// from original class values to class indices [0..nClasses-1] before training.
/// </summary>
/// <param name="problemData">Problem data providing dataset, input variables, target variable and class values.</param>
/// <param name="trainingIndices">Rows used for training.</param>
/// <param name="nTrees">Forwarded to the forest builder.</param>
/// <param name="r">Forwarded to the forest builder.</param>
/// <param name="m">Forwarded to the forest builder.</param>
/// <param name="seed">Forwarded to the forest builder.</param>
/// <param name="rmsError">Receives <c>rmserror</c> of the builder report.</param>
/// <param name="outOfBagRmsError">Receives <c>oobrmserror</c> of the builder report.</param>
/// <param name="relClassificationError">Receives <c>relclserror</c> of the builder report.</param>
/// <param name="outOfBagRelClassificationError">Receives <c>oobrelclserror</c> of the builder report.</param>
public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable trainingIndices, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
  var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
  double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices);

  var classValues = problemData.ClassValues.ToArray();
  int nClasses = classValues.Length;

  // map original class values to values [0..nClasses-1]
  // (keys are the double class values, values their position in classValues)
  var classIndices = new Dictionary<double, int>(nClasses);
  for (int i = 0; i < nClasses; i++) {
    classIndices[classValues[i]] = i;
  }

  // the target is the last column of the matrix; a target value that is not among the
  // class values makes the dictionary lookup throw KeyNotFoundException
  int nRows = inputMatrix.GetLength(0);
  int targetColumn = inputMatrix.GetLength(1) - 1;
  for (int row = 0; row < nRows; row++) {
    inputMatrix[row, targetColumn] = classIndices[inputMatrix[row, targetColumn]];
  }

  var dForest = RandomForestUtil.CreateRandomForestModelAlglib_3_7(seed, inputMatrix, nTrees, r, m, nClasses, out var rep);

  rmsError = rep.rmserror;
  outOfBagRmsError = rep.oobrmserror;
  relClassificationError = rep.relclserror;
  outOfBagRelClassificationError = rep.oobrelclserror;

  return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues);
}
#region persistence for backwards compatibility
// When originalTrainingData is null, the model was loaded from an old file.
// In that case the new persistence mechanism cannot be used, because the original data is no longer available,
// so the complete model is still stored.
/// <summary>
/// True when no training data is attached to the model (<c>originalTrainingData</c> is null).
/// </summary>
private bool IsCompatibilityLoaded => originalTrainingData == null;
// variable names restored from storage; only read back while the model is in compatibility mode
private string[] allowedInputVariables;
/// <summary>
/// Input variable names of the model: the stored array in compatibility mode, otherwise a new
/// array created from the training data's allowed input variables on every access.
/// </summary>
[Storable(Name = "allowedInputVariables")]
private string[] AllowedInputVariables {
  get {
    return IsCompatibilityLoaded
      ? allowedInputVariables
      : originalTrainingData.AllowedInputVariables.ToArray();
  }
  set { allowedInputVariables = value; }
}
/// <summary>
/// Storable view on the forest's <c>bufsize</c>: the real value is exposed only in compatibility
/// mode (0 otherwise); the setter writes straight into the backing forest object.
/// </summary>
[Storable]
private int RandomForestBufSize {
  get => IsCompatibilityLoaded ? randomForest.innerobj.bufsize : 0;
  set => randomForest.innerobj.bufsize = value;
}
/// <summary>
/// Storable view on the forest's <c>nclasses</c>: the real value is exposed only in compatibility
/// mode (0 otherwise); the setter writes straight into the backing forest object.
/// </summary>
[Storable]
private int RandomForestNClasses {
  get => IsCompatibilityLoaded ? randomForest.innerobj.nclasses : 0;
  set => randomForest.innerobj.nclasses = value;
}
/// <summary>
/// Storable view on the forest's <c>ntrees</c>: the real value is exposed only in compatibility
/// mode (0 otherwise); the setter writes straight into the backing forest object.
/// </summary>
[Storable]
private int RandomForestNTrees {
  get => IsCompatibilityLoaded ? randomForest.innerobj.ntrees : 0;
  set => randomForest.innerobj.ntrees = value;
}
/// <summary>
/// Storable view on the forest's <c>nvars</c>: the real value is exposed only in compatibility
/// mode (0 otherwise); the setter writes straight into the backing forest object.
/// </summary>
[Storable]
private int RandomForestNVars {
  get => IsCompatibilityLoaded ? randomForest.innerobj.nvars : 0;
  set => randomForest.innerobj.nvars = value;
}
/// <summary>
/// Storable view on the forest's flat <c>trees</c> array: the real array is exposed only in
/// compatibility mode (a new empty array otherwise); the setter writes straight into the backing
/// forest object.
/// </summary>
[Storable]
private double[] RandomForestTrees {
  get => IsCompatibilityLoaded ? randomForest.innerobj.trees : new double[0];
  set => randomForest.innerobj.trees = value;
}
#endregion
}
}