Changeset 10321


Ignore:
Timestamp:
01/08/14 17:33:53 (7 years ago)
Author:
gkronber
Message:

#1721 refactored RandormForestModel and changed persistence (store data and parameters instead of model)

Location:
branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs

    r9456 r10321  
    131131    public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    132132      out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    133       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    134       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    135 
    136       alglib.math.rndobject = new System.Random(seed);
    137 
    138       Dataset dataset = problemData.Dataset;
    139       string targetVariable = problemData.TargetVariable;
    140       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    141       IEnumerable<int> rows = problemData.TrainingIndices;
    142       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    143       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    144         throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");
    145 
    146       int info = 0;
    147       alglib.decisionforest dForest = new alglib.decisionforest();
    148       alglib.dfreport rep = new alglib.dfreport(); ;
    149       int nRows = inputMatrix.GetLength(0);
    150       int nColumns = inputMatrix.GetLength(1);
    151       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    152       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    153 
    154 
    155       double[] classValues = problemData.ClassValues.ToArray();
    156       int nClasses = problemData.Classes;
    157       // map original class values to values [0..nClasses-1]
    158       Dictionary<double, double> classIndices = new Dictionary<double, double>();
    159       for (int i = 0; i < nClasses; i++) {
    160         classIndices[classValues[i]] = i;
    161       }
    162       for (int row = 0; row < nRows; row++) {
    163         inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
    164       }
    165       // execute random forest algorithm     
    166       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    167       if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");
    168 
    169       rmsError = rep.rmserror;
    170       outOfBagRmsError = rep.oobrmserror;
    171       relClassificationError = rep.relclserror;
    172       outOfBagRelClassificationError = rep.oobrelclserror;
    173       return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues));
     133      var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
     134      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
    174135    }
    175136    #endregion
  • branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r9456 r10321  
    3535  [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
    3636  public sealed class RandomForestModel : NamedItem, IRandomForestModel {
    37 
     37    // not persisted
    3838    private alglib.decisionforest randomForest;
    39     public alglib.decisionforest RandomForest {
    40       get { return randomForest; }
    41       set {
    42         if (value != randomForest) {
    43           if (value == null) throw new ArgumentNullException();
    44           randomForest = value;
    45           OnChanged(EventArgs.Empty);
    46         }
    47       }
    48     }
    49 
     39    private alglib.decisionforest RandomForest {
     40      get {
     41        // recalculate lazily
     42        if (randomForest == null || randomForest.innerobj.trees == null) RecalculateModel();
     43        return randomForest;
     44      }
     45    }
     46
     47    // instead of storing the data of the model itself
     48    // we instead only store data necessary to recalculate the same model lazily on demand
     49    [Storable]
     50    private int seed;
     51    [Storable]
     52    private Dataset originalTrainingData;
     53    [Storable]
     54    private int[] trainingRows;
     55    [Storable]
     56    private double[] classValues;
     57    [Storable]
     58    private string[] allowedInputVariables;
    5059    [Storable]
    5160    private string targetVariable;
    5261    [Storable]
    53     private string[] allowedInputVariables;
    54     [Storable]
    55     private double[] classValues;
     62    private int nTrees;
     63    [Storable]
     64    private double r;
     65    [Storable]
     66    private double m;
     67
     68
    5669    [StorableConstructor]
    5770    private RandomForestModel(bool deserializing)
    5871      : base(deserializing) {
    59       if (deserializing)
    60         randomForest = new alglib.decisionforest();
     72      // for backwards compatibility (loading old solutions)
     73      randomForest = new alglib.decisionforest();
    6174    }
    6275    private RandomForestModel(RandomForestModel original, Cloner cloner)
    6376      : base(original, cloner) {
    64       randomForest = new alglib.decisionforest();
    65       randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
    66       randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
    67       randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
    68       randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
    69       randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();
    70       targetVariable = original.targetVariable;
    71       allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    72       if (original.classValues != null)
    73         this.classValues = (double[])original.classValues.Clone();
    74     }
    75     public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
     77      // clone the model if necessary
     78      if (original.randomForest != null) {
     79        randomForest = new alglib.decisionforest();
     80        randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
     81        randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
     82        randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
     83        randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
     84        // we assume that the trees array (double[]) is immutable in alglib
     85        randomForest.innerobj.trees = original.randomForest.innerobj.trees;
     86      }
     87      // clone data which is necessary to rebuild the model
     88      this.seed = original.seed;
     89      // Dataset is immutable so we do not need to clone
     90      this.originalTrainingData = original.originalTrainingData;
     91      this.targetVariable = original.targetVariable;
     92      // trainingRows, classvalues and allowedInputVariables are immutable
     93      this.trainingRows = original.trainingRows;
     94      this.allowedInputVariables = original.allowedInputVariables;
     95      this.classValues = original.classValues;  // null for regression problems
     96
     97      this.nTrees = original.nTrees;
     98      this.r = original.r;
     99      this.m = original.m;
     100    }
     101
     102    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
     103    private RandomForestModel(alglib.decisionforest randomForest,
     104      int seed, Dataset originalTrainingData, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable,
     105      int nTrees, double r, double m, double[] classValues = null)
    76106      : base() {
    77107      this.name = ItemName;
    78108      this.description = ItemDescription;
     109      // the model itself
    79110      this.randomForest = randomForest;
     111      // data which is necessary for recalculation of the model
     112      this.seed = seed;
     113      this.originalTrainingData = originalTrainingData;
     114      this.trainingRows = (int[])trainingRows.ToArray().Clone();
    80115      this.targetVariable = targetVariable;
    81       this.allowedInputVariables = allowedInputVariables.ToArray();
    82       if (classValues != null)
    83         this.classValues = (double[])classValues.Clone();
     116      this.allowedInputVariables = (string[])allowedInputVariables.ToArray().Clone();
     117      this.classValues = classValues; // null for regression problems
     118      this.nTrees = nTrees;
     119      this.r = r;
     120      this.m = m;
    84121    }
    85122
     
    88125    }
    89126
     127    private void RecalculateModel() {
     128      double rmsError, oobRmsError, relClassError, oobRelClassError;
     129      if (classValues == null) {
     130        var model = CreateRegressionModel(originalTrainingData, trainingRows, allowedInputVariables, targetVariable,
     131                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
     132                                              out relClassError, out oobRelClassError);
     133        randomForest = model.randomForest;
     134      } else {
     135        var model = CreateClassificationModel(originalTrainingData, trainingRows, allowedInputVariables, targetVariable,
     136                                              classValues, nTrees, r, m, seed, out rmsError, out oobRmsError,
     137                                              out relClassError, out oobRelClassError);
     138        randomForest = model.randomForest;
     139      }
     140    }
     141
    90142    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    91143      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     144      AssertInputMatrix(inputData);
    92145
    93146      int n = inputData.GetLength(0);
     
    100153          x[column] = inputData[row, column];
    101154        }
    102         alglib.dfprocess(randomForest, x, ref y);
     155        alglib.dfprocess(RandomForest, x, ref y);
    103156        yield return y[0];
    104157      }
     
    107160    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    108161      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     162      AssertInputMatrix(inputData);
    109163
    110164      int n = inputData.GetLength(0);
    111165      int columns = inputData.GetLength(1);
    112166      double[] x = new double[columns];
    113       double[] y = new double[randomForest.innerobj.nclasses];
     167      double[] y = new double[RandomForest.innerobj.nclasses];
    114168
    115169      for (int row = 0; row < n; row++) {
     
    144198    }
    145199
    146     #region events
    147     public event EventHandler Changed;
    148     private void OnChanged(EventArgs e) {
    149       var handlers = Changed;
    150       if (handlers != null)
    151         handlers(this, e);
    152     }
    153     #endregion
    154 
    155     #region persistence
    156     [Storable]
     200    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
     201      out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
     202      return CreateRegressionModel(problemData.Dataset, problemData.TrainingIndices,
     203                                     problemData.AllowedInputVariables, problemData.TargetVariable,
     204                                     nTrees, r, m, seed,
     205                                     out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
     206    }
     207
     208    private static RandomForestModel CreateRegressionModel(
     209      // prob param
     210        Dataset dataset, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable,
     211      // alg param
     212        int nTrees, double r, double m, int seed,
     213      // results
     214        out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
     215
     216      var variables = allowedInputVariables.Concat(new string[] { targetVariable });
     217      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, variables, trainingRows);
     218
     219      alglib.dfreport rep;
     220      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
     221
     222      rmsError = rep.rmserror;
     223      avgRelError = rep.avgrelerror;
     224      outOfBagAvgRelError = rep.oobavgrelerror;
     225      outOfBagRmsError = rep.oobrmserror;
     226
     227      return new RandomForestModel(dForest,
     228        seed, dataset, trainingRows, allowedInputVariables, targetVariable,
     229        nTrees, r, m);
     230    }
     231
     232    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     233      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
     234      return CreateClassificationModel(problemData.Dataset, problemData.TrainingIndices,
     235                                       problemData.AllowedInputVariables, problemData.TargetVariable,
     236                                       problemData.ClassValues.ToArray(),
     237                                       nTrees, r, m, seed,
     238                                       out rmsError, out outOfBagRmsError, out relClassificationError,
     239                                       out outOfBagRelClassificationError);
     240    }
     241
     242    private static RandomForestModel CreateClassificationModel(
     243      // prob param
     244        Dataset dataset, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable, double[] classValues,
     245      // alg param
     246        int nTrees, double r, double m, int seed,
     247      // results
     248        out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
     249
     250      var variables = allowedInputVariables.Concat(new string[] { targetVariable });
     251      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, variables, trainingRows);
     252
     253      int nClasses = classValues.Length;
     254
     255      // map original class values to values [0..nClasses-1]
     256      var classIndices = new Dictionary<double, double>();
     257      for (int i = 0; i < nClasses; i++) {
     258        classIndices[classValues[i]] = i;
     259      }
     260
     261      int nRows = inputMatrix.GetLength(0);
     262      int nColumns = inputMatrix.GetLength(1);
     263      for (int row = 0; row < nRows; row++) {
     264        inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
     265      }
     266
     267      alglib.dfreport rep;
     268      var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
     269
     270      rmsError = rep.rmserror;
     271      outOfBagRmsError = rep.oobrmserror;
     272      relClassificationError = rep.relclserror;
     273      outOfBagRelClassificationError = rep.oobrelclserror;
     274
     275      return new RandomForestModel(dForest,
     276        seed, dataset, trainingRows, allowedInputVariables, targetVariable,
     277        nTrees, r, m, classValues);
     278    }
     279
     280    private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
     281      AssertParameters(r, m);
     282      AssertInputMatrix(inputMatrix);
     283
     284      int info = 0;
     285      alglib.math.rndobject = new System.Random(seed);
     286      alglib.decisionforest dForest = new alglib.decisionforest();
     287      rep = new alglib.dfreport();
     288      int nRows = inputMatrix.GetLength(0);
     289      int nColumns = inputMatrix.GetLength(1);
     290      int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
     291      int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
     292
     293      alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
     294      if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
     295      return dForest;
     296    }
     297
     298    private static void AssertParameters(double r, double m) {
     299      if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
     300      if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
     301    }
     302
     303    private static void AssertInputMatrix(double[,] inputMatrix) {
     304      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
     305        throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
     306    }
     307
     308    #region persistence for backwards compatibility
     309    [Storable(AllowOneWay = true)]
    157310    private int RandomForestBufSize {
    158       get {
    159         return randomForest.innerobj.bufsize;
    160       }
    161311      set {
    162312        randomForest.innerobj.bufsize = value;
    163313      }
    164314    }
    165     [Storable]
     315    [Storable(AllowOneWay = true)]
    166316    private int RandomForestNClasses {
    167       get {
    168         return randomForest.innerobj.nclasses;
    169       }
    170317      set {
    171318        randomForest.innerobj.nclasses = value;
    172319      }
    173320    }
    174     [Storable]
     321    [Storable(AllowOneWay = true)]
    175322    private int RandomForestNTrees {
    176       get {
    177         return randomForest.innerobj.ntrees;
    178       }
    179323      set {
    180324        randomForest.innerobj.ntrees = value;
    181325      }
    182326    }
    183     [Storable]
     327    [Storable(AllowOneWay = true)]
    184328    private int RandomForestNVars {
    185       get {
    186         return randomForest.innerobj.nvars;
    187       }
    188329      set {
    189330        randomForest.innerobj.nvars = value;
    190331      }
    191332    }
    192     [Storable]
     333    [Storable(AllowOneWay = true)]
    193334    private double[] RandomForestTrees {
    194       get {
    195         return randomForest.innerobj.trees;
    196       }
    197335      set {
    198336        randomForest.innerobj.trees = value;
  • branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs

    r9456 r10321  
    130130    public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    131131      out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    132       if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    133       if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    134 
    135       alglib.math.rndobject = new System.Random(seed);
    136 
    137       Dataset dataset = problemData.Dataset;
    138       string targetVariable = problemData.TargetVariable;
    139       IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    140       IEnumerable<int> rows = problemData.TrainingIndices;
    141       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    142       if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    143         throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset.");
    144 
    145       int info = 0;
    146       alglib.decisionforest dForest = new alglib.decisionforest();
    147       alglib.dfreport rep = new alglib.dfreport(); ;
    148       int nRows = inputMatrix.GetLength(0);
    149       int nColumns = inputMatrix.GetLength(1);
    150       int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    151       int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    152 
    153       alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    154       if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution");
    155 
    156       rmsError = rep.rmserror;
    157       avgRelError = rep.avgrelerror;
    158       outOfBagAvgRelError = rep.oobavgrelerror;
    159       outOfBagRmsError = rep.oobrmserror;
    160 
    161       return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables));
     132      var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
     133      return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model);
    162134    }
    163135    #endregion
Note: See TracChangeset for help on using the changeset viewer.