21 


22  using System;


23  using System.Collections.Generic;


24  using System.Linq;


25  using HeuristicLab.Common;


26  using HeuristicLab.Core;


27  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


28  using HeuristicLab.Problems.DataAnalysis;


29 


30  namespace HeuristicLab.Algorithms.DataAnalysis {


31  /// <summary>


32  /// Represents a random forest model for regression and classification


33  /// </summary>


34  [StorableClass]


35  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]


36  public sealed class RandomForestModel : NamedItem, IRandomForestModel {


// not persisted
private alglib.decisionforest randomForest;
/// <summary>
/// Gives access to the forest and rebuilds it on first use.
/// When the forest holds no tree data (null or empty array) RecalculateModel
/// is called before the forest is returned.
/// </summary>
private alglib.decisionforest RandomForest {
  get {
    // recalculate lazily
    var trees = randomForest.innerobj.trees;
    if (trees == null || trees.Length == 0) RecalculateModel();
    return randomForest;
  }
}


46 


// The forest itself is not stored. Only the inputs that are needed to
// recalculate the same model lazily on demand are persisted.

// seed of the random number generator used while building the forest
[Storable] private int seed;
// problem data the forest is (re)built from
[Storable] private IDataAnalysisProblemData originalTrainingData;
// original class values, indexed by the class index produced by the forest;
// stays null when the model is created by the regression factory
[Storable] private double[] classValues;
// number of trees passed to the forest builder
[Storable] private int nTrees;
// fraction (0, 1] of the training rows, used to derive the per-tree sample size
[Storable] private double r;
// fraction (0, 1] of the input variables, used to derive the feature count
[Storable] private double m;


61 


62 


/// <summary>
/// Constructor used during deserialization.
/// </summary>
[StorableConstructor]
private RandomForestModel(bool deserializing) : base(deserializing) {
  // for backwards compatibility (loading old solutions): start with an empty
  // forest object that the storable properties of this class can write into
  this.randomForest = new alglib.decisionforest();
}


/// <summary>
/// Cloning constructor. Copies the forest description of <paramref name="original"/>
/// as well as all data needed to rebuild the forest.
/// </summary>
private RandomForestModel(RandomForestModel original, Cloner cloner)
  : base(original, cloner) {
  // data which is necessary to rebuild the model
  seed = original.seed;
  nTrees = original.nTrees;
  r = original.r;
  m = original.m;
  originalTrainingData = cloner.Clone(original.originalTrainingData);
  // classValues is treated as immutable, so the array is shared instead of cloned
  classValues = original.classValues;

  // allowedInputVariables is treated as immutable as well
  allowedInputVariables = original.allowedInputVariables;

  // the forest itself: scalar settings are copied, the trees array is shared
  // (we assume that the trees array (double[]) is immutable in alglib)
  var source = original.randomForest.innerobj;
  randomForest = new alglib.decisionforest();
  var target = randomForest.innerobj;
  target.bufsize = source.bufsize;
  target.nclasses = source.nclasses;
  target.ntrees = source.ntrees;
  target.nvars = source.nvars;
  target.trees = source.trees;
}


91 


// Instances are only created through the static factory methods
// CreateRegressionModel and CreateClassificationModel.
/// <summary>
/// Wraps an already built forest together with the data required to rebuild it.
/// </summary>
/// <param name="randomForest">The built forest.</param>
/// <param name="seed">Seed that was used for building.</param>
/// <param name="originalTrainingData">Problem data used for building; a clone of it is stored.</param>
/// <param name="nTrees">Number of trees used for building.</param>
/// <param name="r">The R parameter used for building.</param>
/// <param name="m">The M parameter used for building.</param>
/// <param name="classValues">Original class values (optional, defaults to null).</param>
private RandomForestModel(alglib.decisionforest randomForest,
  int seed, IDataAnalysisProblemData originalTrainingData,
  int nTrees, double r, double m, double[] classValues = null)
  : base() {
  name = ItemName;
  description = ItemDescription;

  // the model itself
  this.randomForest = randomForest;

  // everything that is needed to recalculate the model later on
  this.seed = seed;
  this.nTrees = nTrees;
  this.r = r;
  this.m = m;
  this.classValues = classValues;
  this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
}


109 


/// <summary>
/// Creates a clone of this model via the cloning constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var clone = new RandomForestModel(this, cloner);
  return clone;
}


113 


/// <summary>
/// Rebuilds the forest from the stored training data and build parameters.
/// The forest is left untouched when the stored data is neither regression
/// nor classification problem data.
/// </summary>
private void RecalculateModel() {
  // the error measures reported by the factory methods are not used here
  double rmsError, oobRmsError, relClassError, oobRelClassError;
  RandomForestModel rebuilt = null;

  var regressionData = originalTrainingData as IRegressionProblemData;
  if (regressionData != null) {
    rebuilt = CreateRegressionModel(regressionData,
      nTrees, r, m, seed, out rmsError, out oobRmsError,
      out relClassError, out oobRelClassError);
  } else {
    var classificationData = originalTrainingData as IClassificationProblemData;
    if (classificationData != null) {
      rebuilt = CreateClassificationModel(classificationData,
        nTrees, r, m, seed, out rmsError, out oobRmsError,
        out relClassError, out oobRelClassError);
    }
  }

  if (rebuilt != null) randomForest = rebuilt.randomForest;
}


130 


/// <summary>
/// Lazily yields one estimated value per requested row.
/// </summary>
/// <param name="dataset">Dataset the input values are read from.</param>
/// <param name="rows">Rows for which values are estimated.</param>
/// <returns>The first output element of the forest for each row.</returns>
/// <exception cref="NotSupportedException">The input matrix contains NaN or infinity values.</exception>
public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
  double[,] matrix = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
  AssertInputMatrix(matrix);

  int rowCount = matrix.GetLength(0);
  int columnCount = matrix.GetLength(1);
  // buffers are reused for every row
  var features = new double[columnCount];
  var output = new double[1];

  foreach (int rowIndex in Enumerable.Range(0, rowCount)) {
    for (int c = 0; c < columnCount; c++)
      features[c] = matrix[rowIndex, c];
    alglib.dfprocess(RandomForest, features, ref output);
    yield return output[0];
  }
}


148 


/// <summary>
/// Lazily yields one estimated class value per requested row. For each row
/// the class with the largest forest output is selected and translated back
/// to its original class value.
/// </summary>
/// <param name="dataset">Dataset the input values are read from.</param>
/// <param name="rows">Rows for which class values are estimated.</param>
/// <returns>The original class value of the most probable class for each row.</returns>
/// <exception cref="NotSupportedException">The input matrix contains NaN or infinity values.</exception>
public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
  double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
  AssertInputMatrix(inputData);

  int n = inputData.GetLength(0);
  int columns = inputData.GetLength(1);
  double[] x = new double[columns];
  double[] y = new double[RandomForest.innerobj.nclasses];

  for (int row = 0; row < n; row++) {
    for (int column = 0; column < columns; column++) {
      x[column] = inputData[row, column];
    }
    // always go through the lazy property (as GetEstimatedValues does)
    // instead of touching the backing field directly
    alglib.dfprocess(RandomForest, x, ref y);
    // find the class with the largest probability value
    int maxProbClassIndex = 0;
    double maxProb = y[0];
    for (int i = 1; i < y.Length; i++) {
      if (maxProb < y[i]) {
        maxProb = y[i];
        maxProbClassIndex = i;
      }
    }
    yield return classValues[maxProbClassIndex];
  }
}


175 


/// <summary>
/// Creates a regression solution for this model, using a new
/// RegressionProblemData instance constructed from <paramref name="problemData"/>.
/// </summary>
public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
  var solutionData = new RegressionProblemData(problemData);
  return new RandomForestRegressionSolution(solutionData, this);
}


// explicit interface implementation, forwards to the strongly typed overload
IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
  IRegressionSolution solution = CreateRegressionSolution(problemData);
  return solution;
}


/// <summary>
/// Creates a classification solution for this model, using a new
/// ClassificationProblemData instance constructed from <paramref name="problemData"/>.
/// </summary>
public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
  var solutionData = new ClassificationProblemData(problemData);
  return new RandomForestClassificationSolution(solutionData, this);
}


// explicit interface implementation, forwards to the strongly typed overload
IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
  IClassificationSolution solution = CreateClassificationSolution(problemData);
  return solution;
}


188 


/// <summary>
/// Builds a forest on the training partition of <paramref name="problemData"/>
/// (built with nClasses = 1) and wraps it in a model.
/// </summary>
/// <param name="problemData">Regression problem data; the target variable is appended as last column of the input matrix.</param>
/// <param name="nTrees">Number of trees.</param>
/// <param name="r">Fraction (0, 1] of the training rows used to derive the sample size.</param>
/// <param name="m">Fraction (0, 1] of the input variables used to derive the feature count.</param>
/// <param name="seed">Seed for the random number generator.</param>
/// <param name="rmsError">Receives the rmserror value of the alglib report.</param>
/// <param name="avgRelError">Receives the avgrelerror value of the alglib report.</param>
/// <param name="outOfBagAvgRelError">Receives the oobavgrelerror value of the alglib report.</param>
/// <param name="outOfBagRmsError">Receives the oobrmserror value of the alglib report.</param>
public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
  out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {

  // inputs first, target last
  var targetColumn = new string[] { problemData.TargetVariable };
  var columnNames = problemData.AllowedInputVariables.Concat(targetColumn);
  double[,] trainingMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, columnNames, problemData.TrainingIndices);

  alglib.dfreport report;
  alglib.decisionforest forest = CreateRandomForestModel(seed, trainingMatrix, nTrees, r, m, 1, out report);

  // hand the quality measures of the report back to the caller
  rmsError = report.rmserror;
  avgRelError = report.avgrelerror;
  outOfBagAvgRelError = report.oobavgrelerror;
  outOfBagRmsError = report.oobrmserror;

  return new RandomForestModel(forest, seed, problemData, nTrees, r, m);
}


207 


/// <summary>
/// Builds a classification forest on the training partition of
/// <paramref name="problemData"/> and wraps it in a model.
/// </summary>
/// <param name="problemData">Classification problem data; the target variable is appended as last column of the input matrix.</param>
/// <param name="nTrees">Number of trees.</param>
/// <param name="r">Fraction (0, 1] of the training rows used to derive the sample size.</param>
/// <param name="m">Fraction (0, 1] of the input variables used to derive the feature count.</param>
/// <param name="seed">Seed for the random number generator.</param>
/// <param name="rmsError">Receives the rmserror value of the alglib report.</param>
/// <param name="outOfBagRmsError">Receives the oobrmserror value of the alglib report.</param>
/// <param name="relClassificationError">Receives the relclserror value of the alglib report.</param>
/// <param name="outOfBagRelClassificationError">Receives the oobrelclserror value of the alglib report.</param>
public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
  out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {

  var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
  double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);

  var classValues = problemData.ClassValues.ToArray();
  int nClasses = classValues.Length;

  // map original class values to values [0..nClasses-1]
  var classIndices = new Dictionary<double, double>();
  for (int i = 0; i < nClasses; i++) {
    classIndices[classValues[i]] = i;
  }

  // the target is the last column; replace each class value by its class index
  int nRows = inputMatrix.GetLength(0);
  int nColumns = inputMatrix.GetLength(1);
  int targetColumn = nColumns - 1;
  for (int row = 0; row < nRows; row++) {
    inputMatrix[row, targetColumn] = classIndices[inputMatrix[row, targetColumn]];
  }

  alglib.dfreport rep;
  var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);

  rmsError = rep.rmserror;
  outOfBagRmsError = rep.oobrmserror;
  relClassificationError = rep.relclserror;
  outOfBagRelClassificationError = rep.oobrelclserror;

  return new RandomForestModel(dForest,
    seed, problemData,
    nTrees, r, m, classValues);
}


241 


/// <summary>
/// Builds an alglib decision forest from a matrix whose last column holds the target.
/// </summary>
/// <param name="seed">Seed for the System.Random instance that is assigned to alglib.math.rndobject before building.</param>
/// <param name="inputMatrix">Training matrix: input variables followed by the target column. Must not contain NaN or infinity.</param>
/// <param name="nTrees">Number of trees.</param>
/// <param name="r">Fraction (0, 1] of the rows; the sample size is round(r * rows), at least 1.</param>
/// <param name="m">Fraction (0, 1] of the input columns; the feature count is round(m * inputColumns), at least 1.</param>
/// <param name="nClasses">Number of classes passed to the builder.</param>
/// <param name="rep">Receives the report filled by the builder.</param>
/// <exception cref="ArgumentException">r or m is out of range, or the builder reports an info code other than 1.</exception>
/// <exception cref="NotSupportedException">The matrix contains NaN or infinity values.</exception>
private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
  AssertParameters(r, m);
  AssertInputMatrix(inputMatrix);

  int info = 0;
  alglib.math.rndobject = new System.Random(seed);
  var dForest = new alglib.decisionforest();
  rep = new alglib.dfreport();
  int nRows = inputMatrix.GetLength(0);
  int nColumns = inputMatrix.GetLength(1);
  // the last column is the target and does not count as input variable
  int nInputs = nColumns - 1;
  int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
  int nFeatures = Math.Max((int)Math.Round(m * nInputs), 1);

  alglib.dforest.dfbuildinternal(inputMatrix, nRows, nInputs, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
  if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
  return dForest;
}


259 


/// <summary>
/// Checks that both fractions lie in the interval (0, 1].
/// </summary>
/// <param name="r">The R parameter.</param>
/// <param name="m">The M parameter.</param>
/// <exception cref="ArgumentException">r or m is not in (0, 1]; this includes NaN.</exception>
private static void AssertParameters(double r, double m) {
  // written as negated range checks so that NaN (for which every comparison is false) is rejected as well
  if (!(r > 0 && r <= 1)) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
  if (!(m > 0 && m <= 1)) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
}


264 


/// <summary>
/// Checks that the matrix contains only finite values.
/// </summary>
/// <param name="inputMatrix">Matrix to check.</param>
/// <exception cref="NotSupportedException">At least one element is NaN or infinity.</exception>
private static void AssertInputMatrix(double[,] inputMatrix) {
  if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
}


269 


270  #region persistence for backwards compatibility


// A model without originalTrainingData was loaded from an old file.
// For such a model the forest cannot be rebuilt because the original data is
// not available anymore, so the complete model is still stored in that case.
private bool IsCompatibilityLoaded {
  get { return null == originalTrainingData; }
}


275 


// backing store, only read for models loaded from old files
private string[] allowedInputVariables;
// Names of the input variables: the stored array for models loaded from old
// files, otherwise a fresh array taken from the original training data.
[Storable(Name = "allowedInputVariables")]
private string[] AllowedInputVariables {
  get {
    return IsCompatibilityLoaded
      ? allowedInputVariables
      : originalTrainingData.AllowedInputVariables.ToArray();
  }
  set { allowedInputVariables = value; }
}


// bufsize of the forest; reads as 0 unless the model was loaded from an old file
[Storable]
private int RandomForestBufSize {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.bufsize : 0; }
  set { randomForest.innerobj.bufsize = value; }
}


// nclasses of the forest; reads as 0 unless the model was loaded from an old file
[Storable]
private int RandomForestNClasses {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.nclasses : 0; }
  set { randomForest.innerobj.nclasses = value; }
}


// ntrees of the forest; reads as 0 unless the model was loaded from an old file
[Storable]
private int RandomForestNTrees {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.ntrees : 0; }
  set { randomForest.innerobj.ntrees = value; }
}


// nvars of the forest; reads as 0 unless the model was loaded from an old file
[Storable]
private int RandomForestNVars {
  get { return IsCompatibilityLoaded ? randomForest.innerobj.nvars : 0; }
  set { randomForest.innerobj.nvars = value; }
}


// trees array of the forest; reads as an empty array unless the model was
// loaded from an old file
[Storable]
private double[] RandomForestTrees {
  get {
    if (!IsCompatibilityLoaded) return new double[0];
    return randomForest.innerobj.trees;
  }
  set { randomForest.innerobj.trees = value; }
}


335  #endregion


336  }


337  }

