Changeset 10322 for branches/1721-RandomForestPersistence
- Timestamp:
- 01/08/14 18:10:20 (11 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r10321 r10322 40 40 get { 41 41 // recalculate lazily 42 if (randomForest == null || randomForest.innerobj.trees == null) RecalculateModel();42 if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel(); 43 43 return randomForest; 44 44 } … … 50 50 private int seed; 51 51 [Storable] 52 private Dataset originalTrainingData; 53 [Storable] 54 private int[] trainingRows; 52 private IDataAnalysisProblemData originalTrainingData; 55 53 [Storable] 56 54 private double[] classValues; 57 [Storable]58 private string[] allowedInputVariables;59 [Storable]60 private string targetVariable;61 55 [Storable] 62 56 private int nTrees; … … 75 69 private RandomForestModel(RandomForestModel original, Cloner cloner) 76 70 : base(original, cloner) { 77 // clone the model if necessary 78 if (original.randomForest != null) { 79 randomForest = new alglib.decisionforest(); 80 randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize; 81 randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses; 82 randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees; 83 randomForest.innerobj.nvars = original.randomForest.innerobj.nvars; 84 // we assume that the trees array (double[]) is immutable in alglib 85 randomForest.innerobj.trees = original.randomForest.innerobj.trees; 86 } 71 randomForest = new alglib.decisionforest(); 72 randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize; 73 randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses; 74 randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees; 75 randomForest.innerobj.nvars = original.randomForest.innerobj.nvars; 76 // we assume that the trees array (double[]) is immutable in alglib 77 randomForest.innerobj.trees = original.randomForest.innerobj.trees; 78 79 // allowedInputVariables is immutable so we don't need to clone 80 allowedInputVariables = original.allowedInputVariables; 81 87 82 // clone data which is necessary to rebuild the model 88 83 this.seed = original.seed; 89 // Dataset is immutable so we do not need to clone 90 this.originalTrainingData = original.originalTrainingData; 91 this.targetVariable = original.targetVariable; 92 // trainingRows, classvalues and allowedInputVariables are immutable 93 this.trainingRows = original.trainingRows; 94 this.allowedInputVariables = original.allowedInputVariables; 95 this.classValues = original.classValues; // null for regression problems 96 84 this.originalTrainingData = cloner.Clone(original.originalTrainingData); 85 // classvalues is immutable so we don't need to clone 86 this.classValues = original.classValues; 97 87 this.nTrees = original.nTrees; 98 88 this.r = original.r; … … 102 92 // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel 103 93 private RandomForestModel(alglib.decisionforest randomForest, 104 int seed, Dataset originalTrainingData, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable,94 int seed, IDataAnalysisProblemData originalTrainingData, 105 95 int nTrees, double r, double m, double[] classValues = null) 106 96 : base() { … … 111 101 // data which is necessary for recalculation of the model 112 102 this.seed = seed; 113 this.originalTrainingData = originalTrainingData; 114 this.trainingRows = (int[])trainingRows.ToArray().Clone(); 115 this.targetVariable = targetVariable; 116 this.allowedInputVariables = (string[])allowedInputVariables.ToArray().Clone(); 117 this.classValues = classValues; // null for regression problems 103 this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone(); 104 this.classValues = classValues; 118 105 this.nTrees = nTrees; 119 106 this.r = r; … … 127 114 private void RecalculateModel() { 128 115 double rmsError, oobRmsError, relClassError, oobRelClassError; 129 if (classValues == null) { 130 var model = CreateRegressionModel(originalTrainingData, trainingRows, allowedInputVariables, targetVariable, 116 var regressionProblemData = originalTrainingData as IRegressionProblemData; 117 var classificationProblemData = originalTrainingData as IClassificationProblemData; 118 if (regressionProblemData != null) { 119 var model = CreateRegressionModel(regressionProblemData, 131 120 nTrees, r, m, seed, out rmsError, out oobRmsError, 132 121 out relClassError, out oobRelClassError); 133 122 randomForest = model.randomForest; 134 } else {135 var model = CreateClassificationModel( originalTrainingData, trainingRows, allowedInputVariables, targetVariable,136 classValues,nTrees, r, m, seed, out rmsError, out oobRmsError,123 } else if (classificationProblemData != null) { 124 var model = CreateClassificationModel(classificationProblemData, 125 nTrees, r, m, seed, out rmsError, out oobRmsError, 137 126 out relClassError, out oobRelClassError); 138 127 randomForest = model.randomForest; … … 141 130 142 131 public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) { 143 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);132 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 144 133 AssertInputMatrix(inputData); 145 134 … … 159 148 160 149 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 161 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);150 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 162 151 AssertInputMatrix(inputData); 163 152 … … 200 189 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 201 190 out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) { 202 return CreateRegressionModel(problemData.Dataset, problemData.TrainingIndices, 203 problemData.AllowedInputVariables, problemData.TargetVariable, 204 nTrees, r, m, seed, 205 out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 206 } 207 208 private static RandomForestModel CreateRegressionModel( 209 // prob param 210 Dataset dataset, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable, 211 // alg param 212 int nTrees, double r, double m, int seed, 213 // results 214 out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) { 215 216 var variables = allowedInputVariables.Concat(new string[] { targetVariable }); 217 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, variables, trainingRows); 191 192 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 193 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices); 218 194 219 195 alglib.dfreport rep; … … 226 202 227 203 return new RandomForestModel(dForest, 228 seed, dataset, trainingRows, allowedInputVariables, targetVariable,204 seed, problemData, 229 205 nTrees, r, m); 230 206 } … … 232 208 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 233 209 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 234 return CreateClassificationModel(problemData.Dataset, problemData.TrainingIndices, 235 problemData.AllowedInputVariables, problemData.TargetVariable, 236 problemData.ClassValues.ToArray(), 237 nTrees, r, m, seed, 238 out rmsError, out outOfBagRmsError, out relClassificationError, 239 out outOfBagRelClassificationError); 240 } 241 242 private static RandomForestModel CreateClassificationModel( 243 // prob param 244 Dataset dataset, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable, double[] classValues, 245 // alg param 246 int nTrees, double r, double m, int seed, 247 // results 248 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 249 250 var variables = allowedInputVariables.Concat(new string[] { targetVariable }); 251 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, variables, trainingRows); 252 210 211 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 212 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices); 213 214 var classValues = problemData.ClassValues.ToArray(); 253 215 int nClasses = classValues.Length; 254 216 … … 274 236 275 237 return new RandomForestModel(dForest, 276 seed, dataset, trainingRows, allowedInputVariables, targetVariable,238 seed, problemData, 277 239 nTrees, r, m, classValues); 278 240 } … … 284 246 int info = 0; 285 247 alglib.math.rndobject = new System.Random(seed); 286 alglib.decisionforestdForest = new alglib.decisionforest();248 var dForest = new alglib.decisionforest(); 287 249 rep = new alglib.dfreport(); 288 250 int nRows = inputMatrix.GetLength(0); … … 307 269 308 270 #region persistence for backwards compatibility 309 [Storable(AllowOneWay = true)] 271 // when the originalTrainingData is null this means the model was loaded from an old file 272 // therefore, we cannot use the new persistence mechanism because the original data is not available anymore 273 // in such cases we still store the compete model 274 private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } } 275 276 private string[] allowedInputVariables; 277 [Storable(Name = "allowedInputVariables")] 278 private string[] AllowedInputVariables { 279 get { 280 if (IsCompatibilityLoaded) return allowedInputVariables; 281 else return originalTrainingData.AllowedInputVariables.ToArray(); 282 } 283 set { allowedInputVariables = value; } 284 } 285 [Storable] 310 286 private int RandomForestBufSize { 287 get { 288 if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize; 289 else return 0; 290 } 311 291 set { 312 292 randomForest.innerobj.bufsize = value; 313 293 } 314 294 } 315 [Storable (AllowOneWay = true)]295 [Storable] 316 296 private int RandomForestNClasses { 297 get { 298 if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses; 299 else return 0; 300 } 317 301 set { 318 302 randomForest.innerobj.nclasses = value; 319 303 } 320 304 } 321 [Storable (AllowOneWay = true)]305 [Storable] 322 306 private int RandomForestNTrees { 307 get { 308 if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees; 309 else return 0; 310 } 323 311 set { 324 312 randomForest.innerobj.ntrees = value; 325 313 } 326 314 } 327 [Storable (AllowOneWay = true)]315 [Storable] 328 316 private int RandomForestNVars { 317 get { 318 if (IsCompatibilityLoaded) return randomForest.innerobj.nvars; 319 else return 0; 320 } 329 321 set { 330 322 randomForest.innerobj.nvars = value; 331 323 } 332 324 } 333 [Storable (AllowOneWay = true)]325 [Storable] 334 326 private double[] RandomForestTrees { 327 get { 328 if (IsCompatibilityLoaded) return randomForest.innerobj.trees; 329 else return new double[] { }; 330 } 335 331 set { 336 332 randomForest.innerobj.trees = value;
Note: See TracChangeset
for help on using the changeset viewer.