Free cookie consent management tool by TermsFeed Policy Generator

Changeset 11006


Ignore:
Timestamp:
06/12/14 10:16:17 (10 years ago)
Author:
gkronber
Message:

#1721: merged improved random forest persistence from trunk to stable branch

Location:
stable
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r9456 r11006  
    131131    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    132132      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    133       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    134       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    135 
    136       alglib.math.rndobject = new System.Random(seed);
    137 
    138       Dataset dataset = problemData.Dataset;
    139       string targetVariable = problemData.TargetVariable;
    140       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    141       IEnumerable<int> rows = problemData.TrainingIndices;
    142       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    143       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    144         throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");
    145 
    146       int info = 0;
    147       alglib.decisionforest dForest = new alglib.decisionforest();
    148       alglib.dfreport rep = new alglib.dfreport(); ;
    149       int nRows = inputMatrix.GetLength(0);
    150       int nColumns = inputMatrix.GetLength(1);
    151       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    152       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    153 
    154 
    155       double[] classValues = problemData.ClassValues.ToArray();
    156       int nClasses = problemData.Classes;
    157       // map original class values to values [0..nClasses-1]
    158       Dictionary<double, double> classIndices = new Dictionary<double, double>();
    159       for (int i = 0; i < nClasses; i++) {
    160         classIndices[classValues[i]] = i;
    161       }
    162       for (int row = 0; row < nRows; row++) {
    163         inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
    164       }
    165       // execute random forest algorithm     
    166       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    167       if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");
    168 
    169       rmsError = rep.rmserror;
    170       outOfBagRmsError = rep.oobrmserror;
    171       relClassificationError = rep.relclserror;
    172       outOfBagRelClassificationError = rep.oobrelclserror;
    173       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues));
     133      var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     134      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
    174135    }
    175136    #endregion
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r9456 r11006  
    3535  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    3636  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    37 
     37    // not persisted
    3838    private alglib.decisionforest randomForest;
    39     public alglib.decisionforest RandomForest {
    40       get { return randomForest; }
    41       set {
    42         if (value != randomForest) {
    43           if (value == null) throw new ArgumentNullException();
    44           randomForest = value;
    45           OnChanged(EventArgs.Empty);
    46         }
    47       }
    48     }
    49 
    50     [Storable]
    51     private string targetVariable;
    52     [Storable]
    53     private string[] allowedInputVariables;
     39    private alglib.decisionforest RandomForest {
     40      get {
     41        // recalculate lazily
     42        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
     43        return randomForest;
     44      }
     45    }
     46
     47    // instead of storing the data of the model itself
     48    // we instead only store data necessary to recalculate the same model lazily on demand
     49    [Storable]
     50    private int seed;
     51    [Storable]
     52    private IDataAnalysisProblemData originalTrainingData;
    5453    [Storable]
    5554    private double[] classValues;
     55    [Storable]
     56    private int nTrees;
     57    [Storable]
     58    private double r;
     59    [Storable]
     60    private double m;
     61
     62
    5663    [StorableConstructor]
    5764    private RandomForestModel(bool deserializing)
    5865      : base(deserializing) {
    59       if (deserializing)
    60         randomForest = new alglib.decisionforest();
     66      // for backwards compatibility (loading old solutions)
     67      randomForest = new alglib.decisionforest();
    6168    }
    6269    private RandomForestModel(RandomForestModel original, Cloner cloner)
     
    6774      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
    6875      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
    69       randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();
    70       targetVariable = original.targetVariable;
    71       allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    72       if (original.classValues != null)
    73         this.classValues = (double[])original.classValues.Clone();
    74     }
    75     public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
     76      // we assume that the trees array (double[]) is immutable in alglib
     77      randomForest.innerobj.trees = original.randomForest.innerobj.trees;
     78     
     79      // allowedInputVariables is immutable so we don't need to clone
     80      allowedInputVariables = original.allowedInputVariables;
     81
     82      // clone data which is necessary to rebuild the model
     83      this.seed = original.seed;
     84      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
     85      // classvalues is immutable so we don't need to clone
     86      this.classValues = original.classValues;
     87      this.nTrees = original.nTrees;
     88      this.r = original.r;
     89      this.m = original.m;
     90    }
     91
     92    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
     93    private RandomForestModel(alglib.decisionforest randomForest,
     94      int seed, IDataAnalysisProblemData originalTrainingData,
     95      int nTrees, double r, double m, double[] classValues = null)
    7696      : base() {
    7797      this.name = ItemName;
    7898      this.description = ItemDescription;
     99      // the model itself
    79100      this.randomForest = randomForest;
    80       this.targetVariable = targetVariable;
    81       this.allowedInputVariables = allowedInputVariables.ToArray();
    82       if (classValues != null)
    83         this.classValues = (double[])classValues.Clone();
     101      // data which is necessary for recalculation of the model
     102      this.seed = seed;
     103      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
     104      this.classValues = classValues;
     105      this.nTrees = nTrees;
     106      this.r = r;
     107      this.m = m;
    84108    }
    85109
     
    88112    }
    89113
     114    private void RecalculateModel() {
     115      double rmsError, oobRmsError, relClassError, oobRelClassError;
     116      var regressionProblemData = originalTrainingData as IRegressionProblemData;
     117      var classificationProblemData = originalTrainingData as IClassificationProblemData;
     118      if (regressionProblemData != null) {
     119        var model = CreateRegressionModel(regressionProblemData,
     120                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
     121                                              out relClassError, out oobRelClassError);
     122        randomForest = model.randomForest;
     123      } else if (classificationProblemData != null) {
     124        var model = CreateClassificationModel(classificationProblemData,
     125                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
     126                                              out relClassError, out oobRelClassError);
     127        randomForest = model.randomForest;
     128      }
     129    }
     130
    90131    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    91       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     132      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
     133      AssertInputMatrix(inputData);
    92134
    93135      int n = inputData.GetLength(0);
     
    100142          x[column] = inputData[row, column];
    101143        }
    102         alglib.dfprocess(randomForest, x, ref y);
     144        alglib.dfprocess(RandomForest, x, ref y);
    103145        yield return y[0];
    104146      }
     
    106148
    107149    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    108       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     150      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
     151      AssertInputMatrix(inputData);
    109152
    110153      int n = inputData.GetLength(0);
    111154      int columns = inputData.GetLength(1);
    112155      double[] x = new double[columns];
    113       double[] y = new double[randomForest.innerobj.nclasses];
     156      double[] y = new double[RandomForest.innerobj.nclasses];
    114157
    115158      for (int row = 0; row < n; row++) {
     
    144187    }
    145188
    146     #region events
    147     public event EventHandler Changed;
    148     private void OnChanged(EventArgs e) {
    149       var handlers = Changed;
    150       if (handlers != null)
    151         handlers(this, e);
    152     }
    153     #endregion
    154 
    155     #region persistence
     189    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
     190      out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
     191
     192      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     193      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
     194
     195      alglib.dfreport rep;
     196      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     197
     198      rmsError = rep.rmserror;
     199      avgRelError = rep.avgrelerror;
     200      outOfBagAvgRelError = rep.oobavgrelerror;
     201      outOfBagRmsError = rep.oobrmserror;
     202
     203      return new RandomForestModel(dForest,
     204        seed, problemData,
     205        nTrees, r, m);
     206    }
     207
     208    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     209      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
     210
     211      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     212      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
     213
     214      var classValues = problemData.ClassValues.ToArray();
     215      int nClasses = classValues.Length;
     216
     217      // map original class values to values [0..nClasses-1]
     218      var classIndices = new Dictionary<double, double>();
     219      for (int i = 0; i < nClasses; i++) {
     220        classIndices[classValues[i]] = i;
     221      }
     222
     223      int nRows = inputMatrix.GetLength(0);
     224      int nColumns = inputMatrix.GetLength(1);
     225      for (int row = 0; row < nRows; row++) {
     226        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
     227      }
     228
     229      alglib.dfreport rep;
     230      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     231
     232      rmsError = rep.rmserror;
     233      outOfBagRmsError = rep.oobrmserror;
     234      relClassificationError = rep.relclserror;
     235      outOfBagRelClassificationError = rep.oobrelclserror;
     236
     237      return new RandomForestModel(dForest,
     238        seed, problemData,
     239        nTrees, r, m, classValues);
     240    }
     241
     242    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
     243      AssertParameters(r, m);
     244      AssertInputMatrix(inputMatrix);
     245
     246      int info = 0;
     247      alglib.math.rndobject = new System.Random(seed);
     248      var dForest = new alglib.decisionforest();
     249      rep = new alglib.dfreport();
     250      int nRows = inputMatrix.GetLength(0);
     251      int nColumns = inputMatrix.GetLength(1);
     252      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     253      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
     254
     255      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
     256      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
     257      return dForest;
     258    }
     259
     260    private static void AssertParameters(double r, double m) {
     261      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
     262      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
     263    }
     264
     265    private static void AssertInputMatrix(double[,] inputMatrix) {
     266      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
     267        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
     268    }
     269
     270    #region persistence for backwards compatibility
     271    // when the originalTrainingData is null this means the model was loaded from an old file
     272    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
     273    // in such cases we still store the compete model
     274    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }
     275
     276    private string[] allowedInputVariables;
     277    [Storable(Name = "allowedInputVariables")]
     278    private string[] AllowedInputVariables {
     279      get {
     280        if (IsCompatibilityLoaded) return allowedInputVariables;
     281        else return originalTrainingData.AllowedInputVariables.ToArray();
     282      }
     283      set { allowedInputVariables = value; }
     284    }
    156285    [Storable]
    157286    private int RandomForestBufSize {
    158287      get {
    159         return randomForest.innerobj.bufsize;
     288        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
     289        else return 0;
    160290      }
    161291      set {
     
    166296    private int RandomForestNClasses {
    167297      get {
    168         return randomForest.innerobj.nclasses;
     298        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
     299        else return 0;
    169300      }
    170301      set {
     
    175306    private int RandomForestNTrees {
    176307      get {
    177         return randomForest.innerobj.ntrees;
     308        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
     309        else return 0;
    178310      }
    179311      set {
     
    184316    private int RandomForestNVars {
    185317      get {
    186         return randomForest.innerobj.nvars;
     318        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
     319        else return 0;
    187320      }
    188321      set {
     
    193326    private double[] RandomForestTrees {
    194327      get {
    195         return randomForest.innerobj.trees;
     328        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
     329        else return new double[] { };
    196330      }
    197331      set {
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r9456 r11006  
    130130    public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    131131      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    132       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    133       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    134 
    135       alglib.math.rndobject = new System.Random(seed);
    136 
    137       Dataset dataset = problemData.Dataset;
    138       string targetVariable = problemData.TargetVariable;
    139       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    140       IEnumerable<int> rows = problemData.TrainingIndices;
    141       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    142       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    143         throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset.");
    144 
    145       int info = 0;
    146       alglib.decisionforest dForest = new alglib.decisionforest();
    147       alglib.dfreport rep = new alglib.dfreport(); ;
    148       int nRows = inputMatrix.GetLength(0);
    149       int nColumns = inputMatrix.GetLength(1);
    150       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    151       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    152 
    153       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    154       if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution");
    155 
    156       rmsError = rep.rmserror;
    157       avgRelError = rep.avgrelerror;
    158       outOfBagAvgRelError = rep.oobavgrelerror;
    159       outOfBagRmsError = rep.oobrmserror;
    160 
    161       return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables));
     132      var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     133      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model);
    162134    }
    163135    #endregion
Note: See TracChangeset for help on using the changeset viewer.