Changeset 11009 for branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
- Timestamp:
- 06/12/14 13:26:18 (10 years ago)
- Location:
- branches/DataPreprocessing
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/DataPreprocessing
- Property svn:ignore
-
old new 8 8 FxCopResults.txt 9 9 Google.ProtocolBuffers-0.9.1.dll 10 Google.ProtocolBuffers-2.4.1.473.dll 10 11 HeuristicLab 3.3.5.1.ReSharper.user 11 12 HeuristicLab 3.3.6.0.ReSharper.user 12 13 HeuristicLab.4.5.resharper.user 13 14 HeuristicLab.ExtLibs.6.0.ReSharper.user 15 HeuristicLab.Scripting.Development 14 16 HeuristicLab.resharper.user 15 17 ProtoGen.exe … … 17 19 _ReSharper.HeuristicLab 18 20 _ReSharper.HeuristicLab 3.3 21 _ReSharper.HeuristicLab 3.3 Tests 19 22 _ReSharper.HeuristicLab.ExtLibs 20 23 bin 21 24 protoc.exe 22 _ReSharper.HeuristicLab 3.3 Tests23 Google.ProtocolBuffers-2.4.1.473.dll
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
/branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis (added) merged: 10321-10322 /trunk/sources/HeuristicLab.Algorithms.DataAnalysis merged: 10963
- Property svn:mergeinfo changed
-
branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r9456 r11009 35 35 [Item("RandomForestModel", "Represents a random forest for regression and classification.")] 36 36 public sealed class RandomForestModel : NamedItem, IRandomForestModel { 37 37 // not persisted 38 38 private alglib.decisionforest randomForest; 39 public alglib.decisionforest RandomForest { 40 get { return randomForest; } 41 set { 42 if (value != randomForest) { 43 if (value == null) throw new ArgumentNullException(); 44 randomForest = value; 45 OnChanged(EventArgs.Empty); 46 } 47 } 48 } 49 50 [Storable] 51 private string targetVariable; 52 [Storable] 53 private string[] allowedInputVariables; 39 private alglib.decisionforest RandomForest { 40 get { 41 // recalculate lazily 42 if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel(); 43 return randomForest; 44 } 45 } 46 47 // instead of storing the data of the model itself 48 // we instead only store data necessary to recalculate the same model lazily on demand 49 [Storable] 50 private int seed; 51 [Storable] 52 private IDataAnalysisProblemData originalTrainingData; 54 53 [Storable] 55 54 private double[] classValues; 55 [Storable] 56 private int nTrees; 57 [Storable] 58 private double r; 59 [Storable] 60 private double m; 61 62 56 63 [StorableConstructor] 57 64 private RandomForestModel(bool deserializing) 58 65 : base(deserializing) { 59 if (deserializing)60 66 // for backwards compatibility (loading old solutions) 67 randomForest = new alglib.decisionforest(); 61 68 } 62 69 private RandomForestModel(RandomForestModel original, Cloner cloner) … … 67 74 randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees; 68 75 randomForest.innerobj.nvars = original.randomForest.innerobj.nvars; 69 randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone(); 70 targetVariable = original.targetVariable; 71 allowedInputVariables = (string[])original.allowedInputVariables.Clone(); 72 if (original.classValues != null) 73 
this.classValues = (double[])original.classValues.Clone(); 74 } 75 public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) 76 // we assume that the trees array (double[]) is immutable in alglib 77 randomForest.innerobj.trees = original.randomForest.innerobj.trees; 78 79 // allowedInputVariables is immutable so we don't need to clone 80 allowedInputVariables = original.allowedInputVariables; 81 82 // clone data which is necessary to rebuild the model 83 this.seed = original.seed; 84 this.originalTrainingData = cloner.Clone(original.originalTrainingData); 85 // classvalues is immutable so we don't need to clone 86 this.classValues = original.classValues; 87 this.nTrees = original.nTrees; 88 this.r = original.r; 89 this.m = original.m; 90 } 91 92 // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel 93 private RandomForestModel(alglib.decisionforest randomForest, 94 int seed, IDataAnalysisProblemData originalTrainingData, 95 int nTrees, double r, double m, double[] classValues = null) 76 96 : base() { 77 97 this.name = ItemName; 78 98 this.description = ItemDescription; 99 // the model itself 79 100 this.randomForest = randomForest; 80 this.targetVariable = targetVariable; 81 this.allowedInputVariables = allowedInputVariables.ToArray(); 82 if (classValues != null) 83 this.classValues = (double[])classValues.Clone(); 101 // data which is necessary for recalculation of the model 102 this.seed = seed; 103 this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone(); 104 this.classValues = classValues; 105 this.nTrees = nTrees; 106 this.r = r; 107 this.m = m; 84 108 } 85 109 … … 88 112 } 89 113 114 private void RecalculateModel() { 115 double rmsError, oobRmsError, relClassError, oobRelClassError; 116 var regressionProblemData = originalTrainingData as IRegressionProblemData; 
117 var classificationProblemData = originalTrainingData as IClassificationProblemData; 118 if (regressionProblemData != null) { 119 var model = CreateRegressionModel(regressionProblemData, 120 nTrees, r, m, seed, out rmsError, out oobRmsError, 121 out relClassError, out oobRelClassError); 122 randomForest = model.randomForest; 123 } else if (classificationProblemData != null) { 124 var model = CreateClassificationModel(classificationProblemData, 125 nTrees, r, m, seed, out rmsError, out oobRmsError, 126 out relClassError, out oobRelClassError); 127 randomForest = model.randomForest; 128 } 129 } 130 90 131 public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) { 91 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 132 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 133 AssertInputMatrix(inputData); 92 134 93 135 int n = inputData.GetLength(0); … … 100 142 x[column] = inputData[row, column]; 101 143 } 102 alglib.dfprocess( randomForest, x, ref y);144 alglib.dfprocess(RandomForest, x, ref y); 103 145 yield return y[0]; 104 146 } … … 106 148 107 149 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 108 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 150 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 151 AssertInputMatrix(inputData); 109 152 110 153 int n = inputData.GetLength(0); 111 154 int columns = inputData.GetLength(1); 112 155 double[] x = new double[columns]; 113 double[] y = new double[ randomForest.innerobj.nclasses];156 double[] y = new double[RandomForest.innerobj.nclasses]; 114 157 115 158 for (int row = 0; row < n; row++) { … … 144 187 } 145 188 146 #region events 147 public event EventHandler Changed; 148 private void OnChanged(EventArgs e) { 149 var handlers = Changed; 150 if (handlers != null) 151 handlers(this, 
e); 152 } 153 #endregion 154 155 #region persistence 189 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 190 out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) { 191 192 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 193 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices); 194 195 alglib.dfreport rep; 196 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep); 197 198 rmsError = rep.rmserror; 199 avgRelError = rep.avgrelerror; 200 outOfBagAvgRelError = rep.oobavgrelerror; 201 outOfBagRmsError = rep.oobrmserror; 202 203 return new RandomForestModel(dForest, 204 seed, problemData, 205 nTrees, r, m); 206 } 207 208 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 209 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 210 211 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 212 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices); 213 214 var classValues = problemData.ClassValues.ToArray(); 215 int nClasses = classValues.Length; 216 217 // map original class values to values [0..nClasses-1] 218 var classIndices = new Dictionary<double, double>(); 219 for (int i = 0; i < nClasses; i++) { 220 classIndices[classValues[i]] = i; 221 } 222 223 int nRows = inputMatrix.GetLength(0); 224 int nColumns = inputMatrix.GetLength(1); 225 for (int row = 0; row < nRows; row++) { 226 inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]]; 227 } 228 229 alglib.dfreport rep; 230 var dForest = 
CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep); 231 232 rmsError = rep.rmserror; 233 outOfBagRmsError = rep.oobrmserror; 234 relClassificationError = rep.relclserror; 235 outOfBagRelClassificationError = rep.oobrelclserror; 236 237 return new RandomForestModel(dForest, 238 seed, problemData, 239 nTrees, r, m, classValues); 240 } 241 242 private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) { 243 AssertParameters(r, m); 244 AssertInputMatrix(inputMatrix); 245 246 int info = 0; 247 alglib.math.rndobject = new System.Random(seed); 248 var dForest = new alglib.decisionforest(); 249 rep = new alglib.dfreport(); 250 int nRows = inputMatrix.GetLength(0); 251 int nColumns = inputMatrix.GetLength(1); 252 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 253 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 254 255 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 256 if (info != 1) throw new ArgumentException("Error in calculation of random forest model"); 257 return dForest; 258 } 259 260 private static void AssertParameters(double r, double m) { 261 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1."); 262 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1."); 263 } 264 265 private static void AssertInputMatrix(double[,] inputMatrix) { 266 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 267 throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset."); 268 } 269 270 #region persistence for backwards compatibility 271 // when the 
originalTrainingData is null this means the model was loaded from an old file 272 // therefore, we cannot use the new persistence mechanism because the original data is not available anymore 273 // in such cases we still store the compete model 274 private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } } 275 276 private string[] allowedInputVariables; 277 [Storable(Name = "allowedInputVariables")] 278 private string[] AllowedInputVariables { 279 get { 280 if (IsCompatibilityLoaded) return allowedInputVariables; 281 else return originalTrainingData.AllowedInputVariables.ToArray(); 282 } 283 set { allowedInputVariables = value; } 284 } 156 285 [Storable] 157 286 private int RandomForestBufSize { 158 287 get { 159 return randomForest.innerobj.bufsize; 288 if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize; 289 else return 0; 160 290 } 161 291 set { … … 166 296 private int RandomForestNClasses { 167 297 get { 168 return randomForest.innerobj.nclasses; 298 if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses; 299 else return 0; 169 300 } 170 301 set { … … 175 306 private int RandomForestNTrees { 176 307 get { 177 return randomForest.innerobj.ntrees; 308 if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees; 309 else return 0; 178 310 } 179 311 set { … … 184 316 private int RandomForestNVars { 185 317 get { 186 return randomForest.innerobj.nvars; 318 if (IsCompatibilityLoaded) return randomForest.innerobj.nvars; 319 else return 0; 187 320 } 188 321 set { … … 193 326 private double[] RandomForestTrees { 194 327 get { 195 return randomForest.innerobj.trees; 328 if (IsCompatibilityLoaded) return randomForest.innerobj.trees; 329 else return new double[] { }; 196 330 } 197 331 set {
Note: See TracChangeset for help on using the changeset viewer.