#region License Information /* HeuristicLab * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL) * * This file is part of HeuristicLab. * * HeuristicLab is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * HeuristicLab is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with HeuristicLab. If not, see . */ #endregion using System; using System.Collections.Generic; using System.Linq; using HeuristicLab.Common; using HeuristicLab.Core; using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; using HeuristicLab.Problems.DataAnalysis; namespace HeuristicLab.Algorithms.DataAnalysis { /// /// Represents a random forest model for regression and classification /// [StorableClass] [Item("RandomForestModel", "Represents a random forest for regression and classification.")] public sealed class RandomForestModel : NamedItem, IRandomForestModel { // not persisted private alglib.decisionforest randomForest; private alglib.decisionforest RandomForest { get { // recalculate lazily if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel(); return randomForest; } } // instead of storing the data of the model itself // we instead only store data necessary to recalculate the same model lazily on demand [Storable] private int seed; [Storable] private IDataAnalysisProblemData originalTrainingData; [Storable] private double[] classValues; [Storable] private int nTrees; [Storable] private double r; [Storable] private double m; [StorableConstructor] private RandomForestModel(bool deserializing) : base(deserializing) { // for backwards compatibility (loading old solutions) randomForest = new alglib.decisionforest(); } private RandomForestModel(RandomForestModel original, Cloner cloner) : base(original, cloner) { randomForest = new alglib.decisionforest(); randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize; randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses; randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees; randomForest.innerobj.nvars = original.randomForest.innerobj.nvars; // we assume that the trees array (double[]) is immutable in alglib randomForest.innerobj.trees = original.randomForest.innerobj.trees; // allowedInputVariables is immutable so we don't need to clone allowedInputVariables = original.allowedInputVariables; // clone data which is necessary to rebuild the model this.seed = original.seed; this.originalTrainingData = cloner.Clone(original.originalTrainingData); // classvalues is immutable so we don't need to clone this.classValues = original.classValues; this.nTrees = original.nTrees; this.r = original.r; this.m = original.m; } // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel private RandomForestModel(alglib.decisionforest randomForest, int seed, IDataAnalysisProblemData originalTrainingData, int nTrees, double r, double m, double[] classValues = null) : base() { this.name = ItemName; this.description = ItemDescription; // the model itself this.randomForest = randomForest; // data which is necessary for recalculation of the model this.seed = seed; this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone(); this.classValues = classValues; this.nTrees = nTrees; this.r = r; this.m = m; } public override IDeepCloneable Clone(Cloner cloner) { return new RandomForestModel(this, cloner); } private void RecalculateModel() { double rmsError, oobRmsError, relClassError, oobRelClassError; var regressionProblemData = originalTrainingData as IRegressionProblemData; var classificationProblemData = originalTrainingData as IClassificationProblemData; if (regressionProblemData != null) { var model = CreateRegressionModel(regressionProblemData, nTrees, r, m, seed, out rmsError, out oobRmsError, out relClassError, out oobRelClassError); randomForest = model.randomForest; } else if (classificationProblemData != null) { var model = CreateClassificationModel(classificationProblemData, nTrees, r, m, seed, out rmsError, out oobRmsError, out relClassError, out oobRelClassError); randomForest = model.randomForest; } } public IEnumerable GetEstimatedValues(Dataset dataset, IEnumerable rows) { double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); AssertInputMatrix(inputData); int n = inputData.GetLength(0); int columns = inputData.GetLength(1); double[] x = new double[columns]; double[] y = new double[1]; for (int row = 0; row < n; row++) { for (int column = 0; column < columns; column++) { x[column] = inputData[row, column]; } alglib.dfprocess(RandomForest, x, ref y); yield return y[0]; } } public IEnumerable GetEstimatedClassValues(Dataset dataset, IEnumerable rows) { double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); AssertInputMatrix(inputData); int n = inputData.GetLength(0); int columns = inputData.GetLength(1); double[] x = new double[columns]; double[] y = new double[RandomForest.innerobj.nclasses]; for (int row = 0; row < n; row++) { for (int column = 0; column < columns; column++) { x[column] = inputData[row, column]; } alglib.dfprocess(randomForest, x, ref y); // find class for with the largest probability value int maxProbClassIndex = 0; double maxProb = y[0]; for (int i = 1; i < y.Length; i++) { if (maxProb < y[i]) { maxProb = y[i]; maxProbClassIndex = i; } } yield return classValues[maxProbClassIndex]; } } public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this); } IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { return CreateRegressionSolution(problemData); } public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { return new RandomForestClassificationSolution(new ClassificationProblemData(problemData), this); } IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { return CreateClassificationSolution(problemData); } public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) { return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError); } public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable trainingIndices, int nTrees, double r, double m, int seed, out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) { var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices); alglib.dfreport rep; var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep); rmsError = rep.rmserror; avgRelError = rep.avgrelerror; outOfBagAvgRelError = rep.oobavgrelerror; outOfBagRmsError = rep.oobrmserror; return new RandomForestModel(dForest,seed, problemData,nTrees, r, m); } public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError); } public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable trainingIndices, int nTrees, double r, double m, int seed, out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices); var classValues = problemData.ClassValues.ToArray(); int nClasses = classValues.Length; // map original class values to values [0..nClasses-1] var classIndices = new Dictionary(); for (int i = 0; i < nClasses; i++) { classIndices[classValues[i]] = i; } int nRows = inputMatrix.GetLength(0); int nColumns = inputMatrix.GetLength(1); for (int row = 0; row < nRows; row++) { inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]]; } alglib.dfreport rep; var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep); rmsError = rep.rmserror; outOfBagRmsError = rep.oobrmserror; relClassificationError = rep.relclserror; outOfBagRelClassificationError = rep.oobrelclserror; return new RandomForestModel(dForest,seed, problemData,nTrees, r, m, classValues); } private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) { AssertParameters(r, m); AssertInputMatrix(inputMatrix); int info = 0; alglib.math.rndobject = new System.Random(seed); var dForest = new alglib.decisionforest(); rep = new alglib.dfreport(); int nRows = inputMatrix.GetLength(0); int nColumns = inputMatrix.GetLength(1); int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); if (info != 1) throw new ArgumentException("Error in calculation of random forest model"); return dForest; } private static void AssertParameters(double r, double m) { if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1."); if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1."); } private static void AssertInputMatrix(double[,] inputMatrix) { if (inputMatrix.Cast().Any(x => Double.IsNaN(x) || Double.IsInfinity(x))) throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset."); } #region persistence for backwards compatibility // when the originalTrainingData is null this means the model was loaded from an old file // therefore, we cannot use the new persistence mechanism because the original data is not available anymore // in such cases we still store the compete model private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } } private string[] allowedInputVariables; [Storable(Name = "allowedInputVariables")] private string[] AllowedInputVariables { get { if (IsCompatibilityLoaded) return allowedInputVariables; else return originalTrainingData.AllowedInputVariables.ToArray(); } set { allowedInputVariables = value; } } [Storable] private int RandomForestBufSize { get { if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize; else return 0; } set { randomForest.innerobj.bufsize = value; } } [Storable] private int RandomForestNClasses { get { if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses; else return 0; } set { randomForest.innerobj.nclasses = value; } } [Storable] private int RandomForestNTrees { get { if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees; else return 0; } set { randomForest.innerobj.ntrees = value; } } [Storable] private int RandomForestNVars { get { if (IsCompatibilityLoaded) return randomForest.innerobj.nvars; else return 0; } set { randomForest.innerobj.nvars = value; } } [Storable] private double[] RandomForestTrees { get { if (IsCompatibilityLoaded) return randomForest.innerobj.trees; else return new double[] { }; } set { randomForest.innerobj.trees = value; } } #endregion } }