Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/12/14 10:16:17 (10 years ago)
Author:
gkronber
Message:

#1721: merged improved random forest persistence from trunk to stable branch

Location:
stable
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r9456 r11006  
    3535  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    3636  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    37 
     37    // not persisted
    3838    private alglib.decisionforest randomForest;
    39     public alglib.decisionforest RandomForest {
    40       get { return randomForest; }
    41       set {
    42         if (value != randomForest) {
    43           if (value == null) throw new ArgumentNullException();
    44           randomForest = value;
    45           OnChanged(EventArgs.Empty);
    46         }
    47       }
    48     }
    49 
    50     [Storable]
    51     private string targetVariable;
    52     [Storable]
    53     private string[] allowedInputVariables;
     39    private alglib.decisionforest RandomForest {
     40      get {
     41        // recalculate lazily
     42        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
     43        return randomForest;
     44      }
     45    }
     46
     47    // instead of storing the data of the model itself
     48    // we instead only store data necessary to recalculate the same model lazily on demand
     49    [Storable]
     50    private int seed;
     51    [Storable]
     52    private IDataAnalysisProblemData originalTrainingData;
    5453    [Storable]
    5554    private double[] classValues;
     55    [Storable]
     56    private int nTrees;
     57    [Storable]
     58    private double r;
     59    [Storable]
     60    private double m;
     61
     62
    5663    [StorableConstructor]
    5764    private RandomForestModel(bool deserializing)
    5865      : base(deserializing) {
    59       if (deserializing)
    60         randomForest = new alglib.decisionforest();
     66      // for backwards compatibility (loading old solutions)
     67      randomForest = new alglib.decisionforest();
    6168    }
    6269    private RandomForestModel(RandomForestModel original, Cloner cloner)
     
    6774      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
    6875      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
    69       randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();
    70       targetVariable = original.targetVariable;
    71       allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    72       if (original.classValues != null)
    73         this.classValues = (double[])original.classValues.Clone();
    74     }
    75     public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
     76      // we assume that the trees array (double[]) is immutable in alglib
     77      randomForest.innerobj.trees = original.randomForest.innerobj.trees;
     78     
     79      // allowedInputVariables is immutable so we don't need to clone
     80      allowedInputVariables = original.allowedInputVariables;
     81
     82      // clone data which is necessary to rebuild the model
     83      this.seed = original.seed;
     84      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
     85      // classvalues is immutable so we don't need to clone
     86      this.classValues = original.classValues;
     87      this.nTrees = original.nTrees;
     88      this.r = original.r;
     89      this.m = original.m;
     90    }
     91
     92    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
     93    private RandomForestModel(alglib.decisionforest randomForest,
     94      int seed, IDataAnalysisProblemData originalTrainingData,
     95      int nTrees, double r, double m, double[] classValues = null)
    7696      : base() {
    7797      this.name = ItemName;
    7898      this.description = ItemDescription;
     99      // the model itself
    79100      this.randomForest = randomForest;
    80       this.targetVariable = targetVariable;
    81       this.allowedInputVariables = allowedInputVariables.ToArray();
    82       if (classValues != null)
    83         this.classValues = (double[])classValues.Clone();
     101      // data which is necessary for recalculation of the model
     102      this.seed = seed;
     103      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
     104      this.classValues = classValues;
     105      this.nTrees = nTrees;
     106      this.r = r;
     107      this.m = m;
    84108    }
    85109
     
    88112    }
    89113
     114    private void RecalculateModel() {
     115      double rmsError, oobRmsError, relClassError, oobRelClassError;
     116      var regressionProblemData = originalTrainingData as IRegressionProblemData;
     117      var classificationProblemData = originalTrainingData as IClassificationProblemData;
     118      if (regressionProblemData != null) {
     119        var model = CreateRegressionModel(regressionProblemData,
     120                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
     121                                              out relClassError, out oobRelClassError);
     122        randomForest = model.randomForest;
     123      } else if (classificationProblemData != null) {
     124        var model = CreateClassificationModel(classificationProblemData,
     125                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
     126                                              out relClassError, out oobRelClassError);
     127        randomForest = model.randomForest;
     128      }
     129    }
     130
    90131    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    91       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     132      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
     133      AssertInputMatrix(inputData);
    92134
    93135      int n = inputData.GetLength(0);
     
    100142          x[column] = inputData[row, column];
    101143        }
    102         alglib.dfprocess(randomForest, x, ref y);
     144        alglib.dfprocess(RandomForest, x, ref y);
    103145        yield return y[0];
    104146      }
     
    106148
    107149    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    108       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     150      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
     151      AssertInputMatrix(inputData);
    109152
    110153      int n = inputData.GetLength(0);
    111154      int columns = inputData.GetLength(1);
    112155      double[] x = new double[columns];
    113       double[] y = new double[randomForest.innerobj.nclasses];
     156      double[] y = new double[RandomForest.innerobj.nclasses];
    114157
    115158      for (int row = 0; row < n; row++) {
     
    144187    }
    145188
    146     #region events
    147     public event EventHandler Changed;
    148     private void OnChanged(EventArgs e) {
    149       var handlers = Changed;
    150       if (handlers != null)
    151         handlers(this, e);
    152     }
    153     #endregion
    154 
    155     #region persistence
     189    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
     190      out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
     191
     192      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     193      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
     194
     195      alglib.dfreport rep;
     196      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     197
     198      rmsError = rep.rmserror;
     199      avgRelError = rep.avgrelerror;
     200      outOfBagAvgRelError = rep.oobavgrelerror;
     201      outOfBagRmsError = rep.oobrmserror;
     202
     203      return new RandomForestModel(dForest,
     204        seed, problemData,
     205        nTrees, r, m);
     206    }
     207
     208    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     209      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
     210
     211      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     212      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
     213
     214      var classValues = problemData.ClassValues.ToArray();
     215      int nClasses = classValues.Length;
     216
     217      // map original class values to values [0..nClasses-1]
     218      var classIndices = new Dictionary<double, double>();
     219      for (int i = 0; i < nClasses; i++) {
     220        classIndices[classValues[i]] = i;
     221      }
     222
     223      int nRows = inputMatrix.GetLength(0);
     224      int nColumns = inputMatrix.GetLength(1);
     225      for (int row = 0; row < nRows; row++) {
     226        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
     227      }
     228
     229      alglib.dfreport rep;
     230      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     231
     232      rmsError = rep.rmserror;
     233      outOfBagRmsError = rep.oobrmserror;
     234      relClassificationError = rep.relclserror;
     235      outOfBagRelClassificationError = rep.oobrelclserror;
     236
     237      return new RandomForestModel(dForest,
     238        seed, problemData,
     239        nTrees, r, m, classValues);
     240    }
     241
     242    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
     243      AssertParameters(r, m);
     244      AssertInputMatrix(inputMatrix);
     245
     246      int info = 0;
     247      alglib.math.rndobject = new System.Random(seed);
     248      var dForest = new alglib.decisionforest();
     249      rep = new alglib.dfreport();
     250      int nRows = inputMatrix.GetLength(0);
     251      int nColumns = inputMatrix.GetLength(1);
     252      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     253      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
     254
     255      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
     256      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
     257      return dForest;
     258    }
     259
     260    private static void AssertParameters(double r, double m) {
     261      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
     262      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
     263    }
     264
     265    private static void AssertInputMatrix(double[,] inputMatrix) {
     266      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
     267        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
     268    }
     269
     270    #region persistence for backwards compatibility
     271    // when the originalTrainingData is null this means the model was loaded from an old file
     272    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
     273    // in such cases we still store the compete model
     274    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }
     275
     276    private string[] allowedInputVariables;
     277    [Storable(Name = "allowedInputVariables")]
     278    private string[] AllowedInputVariables {
     279      get {
     280        if (IsCompatibilityLoaded) return allowedInputVariables;
     281        else return originalTrainingData.AllowedInputVariables.ToArray();
     282      }
     283      set { allowedInputVariables = value; }
     284    }
    156285    [Storable]
    157286    private int RandomForestBufSize {
    158287      get {
    159         return randomForest.innerobj.bufsize;
     288        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
     289        else return 0;
    160290      }
    161291      set {
     
    166296    private int RandomForestNClasses {
    167297      get {
    168         return randomForest.innerobj.nclasses;
     298        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
     299        else return 0;
    169300      }
    170301      set {
     
    175306    private int RandomForestNTrees {
    176307      get {
    177         return randomForest.innerobj.ntrees;
     308        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
     309        else return 0;
    178310      }
    179311      set {
     
    184316    private int RandomForestNVars {
    185317      get {
    186         return randomForest.innerobj.nvars;
     318        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
     319        else return 0;
    187320      }
    188321      set {
     
    193326    private double[] RandomForestTrees {
    194327      get {
    195         return randomForest.innerobj.trees;
     328        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
     329        else return new double[] { };
    196330      }
    197331      set {
Note: See TracChangeset for help on using the changeset viewer.