Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/08/14 18:10:20 (11 years ago)
Author:
gkronber
Message:

#1721: changed the RF model to store the original problem data directly and fixed bugs in backwards compatibility loading and saving.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs

    r10321 r10322  
    4040      get {
    4141        // recalculate lazily
    42         if (randomForest == null || randomForest.innerobj.trees == null) RecalculateModel();
     42        if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel();
    4343        return randomForest;
    4444      }
     
    5050    private int seed;
    5151    [Storable]
    52     private Dataset originalTrainingData;
    53     [Storable]
    54     private int[] trainingRows;
     52    private IDataAnalysisProblemData originalTrainingData;
    5553    [Storable]
    5654    private double[] classValues;
    57     [Storable]
    58     private string[] allowedInputVariables;
    59     [Storable]
    60     private string targetVariable;
    6155    [Storable]
    6256    private int nTrees;
     
    7569    private RandomForestModel(RandomForestModel original, Cloner cloner)
    7670      : base(original, cloner) {
    77       // clone the model if necessary
    78       if (original.randomForest != null) {
    79         randomForest = new alglib.decisionforest();
    80         randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
    81         randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
    82         randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
    83         randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
    84         // we assume that the trees array (double[]) is immutable in alglib
    85         randomForest.innerobj.trees = original.randomForest.innerobj.trees;
    86       }
     71      randomForest = new alglib.decisionforest();
     72      randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
     73      randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
     74      randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
     75      randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
     76      // we assume that the trees array (double[]) is immutable in alglib
     77      randomForest.innerobj.trees = original.randomForest.innerobj.trees;
     78     
     79      // allowedInputVariables is immutable so we don't need to clone
     80      allowedInputVariables = original.allowedInputVariables;
     81
    8782      // clone data which is necessary to rebuild the model
    8883      this.seed = original.seed;
    89       // Dataset is immutable so we do not need to clone
    90       this.originalTrainingData = original.originalTrainingData;
    91       this.targetVariable = original.targetVariable;
    92       // trainingRows, classvalues and allowedInputVariables are immutable
    93       this.trainingRows = original.trainingRows;
    94       this.allowedInputVariables = original.allowedInputVariables;
    95       this.classValues = original.classValues;  // null for regression problems
    96 
     84      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
     85      // classvalues is immutable so we don't need to clone
     86      this.classValues = original.classValues;
    9787      this.nTrees = original.nTrees;
    9888      this.r = original.r;
     
    10292    // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    10393    private RandomForestModel(alglib.decisionforest randomForest,
    104       int seed, Dataset originalTrainingData, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable,
     94      int seed, IDataAnalysisProblemData originalTrainingData,
    10595      int nTrees, double r, double m, double[] classValues = null)
    10696      : base() {
     
    111101      // data which is necessary for recalculation of the model
    112102      this.seed = seed;
    113       this.originalTrainingData = originalTrainingData;
    114       this.trainingRows = (int[])trainingRows.ToArray().Clone();
    115       this.targetVariable = targetVariable;
    116       this.allowedInputVariables = (string[])allowedInputVariables.ToArray().Clone();
    117       this.classValues = classValues; // null for regression problems
     103      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
     104      this.classValues = classValues;
    118105      this.nTrees = nTrees;
    119106      this.r = r;
     
    127114    private void RecalculateModel() {
    128115      double rmsError, oobRmsError, relClassError, oobRelClassError;
    129       if (classValues == null) {
    130         var model = CreateRegressionModel(originalTrainingData, trainingRows, allowedInputVariables, targetVariable,
     116      var regressionProblemData = originalTrainingData as IRegressionProblemData;
     117      var classificationProblemData = originalTrainingData as IClassificationProblemData;
     118      if (regressionProblemData != null) {
     119        var model = CreateRegressionModel(regressionProblemData,
    131120                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
    132121                                              out relClassError, out oobRelClassError);
    133122        randomForest = model.randomForest;
    134       } else {
    135         var model = CreateClassificationModel(originalTrainingData, trainingRows, allowedInputVariables, targetVariable,
    136                                               classValues, nTrees, r, m, seed, out rmsError, out oobRmsError,
     123      } else if (classificationProblemData != null) {
     124        var model = CreateClassificationModel(classificationProblemData,
     125                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
    137126                                              out relClassError, out oobRelClassError);
    138127        randomForest = model.randomForest;
     
    141130
    142131    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    143       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     132      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
    144133      AssertInputMatrix(inputData);
    145134
     
    159148
    160149    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    161       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     150      double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows);
    162151      AssertInputMatrix(inputData);
    163152
     
    200189    public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    201190      out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
    202       return CreateRegressionModel(problemData.Dataset, problemData.TrainingIndices,
    203                                      problemData.AllowedInputVariables, problemData.TargetVariable,
    204                                      nTrees, r, m, seed,
    205                                      out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    206     }
    207 
    208     private static RandomForestModel CreateRegressionModel(
    209       // prob param
    210         Dataset dataset, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable,
    211       // alg param
    212         int nTrees, double r, double m, int seed,
    213       // results
    214         out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
    215 
    216       var variables = allowedInputVariables.Concat(new string[] { targetVariable });
    217       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, variables, trainingRows);
     191
     192      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     193      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
    218194
    219195      alglib.dfreport rep;
     
    226202
    227203      return new RandomForestModel(dForest,
    228         seed, dataset, trainingRows, allowedInputVariables, targetVariable,
     204        seed, problemData,
    229205        nTrees, r, m);
    230206    }
     
    232208    public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    233209      out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
    234       return CreateClassificationModel(problemData.Dataset, problemData.TrainingIndices,
    235                                        problemData.AllowedInputVariables, problemData.TargetVariable,
    236                                        problemData.ClassValues.ToArray(),
    237                                        nTrees, r, m, seed,
    238                                        out rmsError, out outOfBagRmsError, out relClassificationError,
    239                                        out outOfBagRelClassificationError);
    240     }
    241 
    242     private static RandomForestModel CreateClassificationModel(
    243       // prob param
    244         Dataset dataset, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable, double[] classValues,
    245       // alg param
    246         int nTrees, double r, double m, int seed,
    247       // results
    248         out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
    249 
    250       var variables = allowedInputVariables.Concat(new string[] { targetVariable });
    251       double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, variables, trainingRows);
    252 
     210
     211      var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     212      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
     213
     214      var classValues = problemData.ClassValues.ToArray();
    253215      int nClasses = classValues.Length;
    254216
     
    274236
    275237      return new RandomForestModel(dForest,
    276         seed, dataset, trainingRows, allowedInputVariables, targetVariable,
     238        seed, problemData,
    277239        nTrees, r, m, classValues);
    278240    }
     
    284246      int info = 0;
    285247      alglib.math.rndobject = new System.Random(seed);
    286       alglib.decisionforest dForest = new alglib.decisionforest();
     248      var dForest = new alglib.decisionforest();
    287249      rep = new alglib.dfreport();
    288250      int nRows = inputMatrix.GetLength(0);
     
    307269
    308270    #region persistence for backwards compatibility
    309     [Storable(AllowOneWay = true)]
     271    // when the originalTrainingData is null this means the model was loaded from an old file
     272    // therefore, we cannot use the new persistence mechanism because the original data is not available anymore
     273    // in such cases we still store the compete model
     274    private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } }
     275
     276    private string[] allowedInputVariables;
     277    [Storable(Name = "allowedInputVariables")]
     278    private string[] AllowedInputVariables {
     279      get {
     280        if (IsCompatibilityLoaded) return allowedInputVariables;
     281        else return originalTrainingData.AllowedInputVariables.ToArray();
     282      }
     283      set { allowedInputVariables = value; }
     284    }
     285    [Storable]
    310286    private int RandomForestBufSize {
     287      get {
     288        if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize;
     289        else return 0;
     290      }
    311291      set {
    312292        randomForest.innerobj.bufsize = value;
    313293      }
    314294    }
    315     [Storable(AllowOneWay = true)]
     295    [Storable]
    316296    private int RandomForestNClasses {
     297      get {
     298        if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses;
     299        else return 0;
     300      }
    317301      set {
    318302        randomForest.innerobj.nclasses = value;
    319303      }
    320304    }
    321     [Storable(AllowOneWay = true)]
     305    [Storable]
    322306    private int RandomForestNTrees {
     307      get {
     308        if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees;
     309        else return 0;
     310      }
    323311      set {
    324312        randomForest.innerobj.ntrees = value;
    325313      }
    326314    }
    327     [Storable(AllowOneWay = true)]
     315    [Storable]
    328316    private int RandomForestNVars {
     317      get {
     318        if (IsCompatibilityLoaded) return randomForest.innerobj.nvars;
     319        else return 0;
     320      }
    329321      set {
    330322        randomForest.innerobj.nvars = value;
    331323      }
    332324    }
    333     [Storable(AllowOneWay = true)]
     325    [Storable]
    334326    private double[] RandomForestTrees {
     327      get {
     328        if (IsCompatibilityLoaded) return randomForest.innerobj.trees;
     329        else return new double[] { };
     330      }
    335331      set {
    336332        randomForest.innerobj.trees = value;
Note: See TracChangeset for help on using the changeset viewer.