Changeset 10321 for branches/1721-RandomForestPersistence
- Timestamp: 01/08/14 17:33:53 (11 years ago)
- Location: branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
- Files: 3 edited
branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
(r9456 → r10321) CreateRandomForestClassificationSolution no longer builds the ALGLIB forest inline; parameter checks, input-matrix preparation, class-value mapping and the dfbuildinternal call move into RandomForestModel, and the method now delegates to the new factory method RandomForestModel.CreateClassificationModel:

      public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
        out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) {
    -   if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    -   if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    -
    -   alglib.math.rndobject = new System.Random(seed);
    -
    -   Dataset dataset = problemData.Dataset;
    -   string targetVariable = problemData.TargetVariable;
    -   IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    -   IEnumerable<int> rows = problemData.TrainingIndices;
    -   double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    -   if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    -     throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset.");
    -
    -   int info = 0;
    -   alglib.decisionforest dForest = new alglib.decisionforest();
    -   alglib.dfreport rep = new alglib.dfreport();
    -   int nRows = inputMatrix.GetLength(0);
    -   int nColumns = inputMatrix.GetLength(1);
    -   int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    -   int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    -
    -   double[] classValues = problemData.ClassValues.ToArray();
    -   int nClasses = problemData.Classes;
    -   // map original class values to values [0..nClasses-1]
    -   Dictionary<double, double> classIndices = new Dictionary<double, double>();
    -   for (int i = 0; i < nClasses; i++) {
    -     classIndices[classValues[i]] = i;
    -   }
    -   for (int row = 0; row < nRows; row++) {
    -     inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
    -   }
    -   // execute random forest algorithm
    -   alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    -   if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution");
    -
    -   rmsError = rep.rmserror;
    -   outOfBagRmsError = rep.oobrmserror;
    -   relClassificationError = rep.relclserror;
    -   outOfBagRelClassificationError = rep.oobrelclserror;
    -   return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues));
    +   var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError);
    +   return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
      }
      #endregion
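For orientation, a hedged sketch of how client code could call the new factory method directly (not part of the changeset): the method names and parameter order come from the diff above, but the enclosing helper, the problemData instance and the parameter values are illustrative assumptions.

    // Illustrative only: 'problemData' is assumed to be an IClassificationProblemData prepared elsewhere.
    IClassificationSolution BuildSolution(IClassificationProblemData problemData) {
      double rmsError, oobRmsError, relClassError, oobRelClassError;
      // nTrees = 50, r = 0.3, m = 0.5, seed = 1234 are arbitrary example values
      var model = RandomForestModel.CreateClassificationModel(problemData, 50, 0.3, 0.5, 1234,
        out rmsError, out oobRmsError, out relClassError, out oobRelClassError);
      // wrap the model in a solution, exactly as CreateRandomForestClassificationSolution now does
      return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model);
    }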
branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
(r9456 → r10321) RandomForestModel no longer persists the ALGLIB decision forest itself. It now stores only the data needed to rebuild the forest (seed, training data, training rows, allowed input variables, target variable, class values, nTrees, r, m) and recalculates the forest lazily on first access. Models are created through static factory methods, the Changed event is removed, and the old storable properties become one-way setters kept for backwards compatibility:

      [Item("RandomForestModel", "Represents a random forest for regression and classification.")]
      public sealed class RandomForestModel : NamedItem, IRandomForestModel {
        // not persisted
        private alglib.decisionforest randomForest;
    -   public alglib.decisionforest RandomForest {
    -     get { return randomForest; }
    -     set {
    -       if (value != randomForest) {
    -         if (value == null) throw new ArgumentNullException();
    -         randomForest = value;
    -         OnChanged(EventArgs.Empty);
    -       }
    -     }
    -   }
    -
    +   private alglib.decisionforest RandomForest {
    +     get {
    +       // recalculate lazily
    +       if (randomForest == null || randomForest.innerobj.trees == null) RecalculateModel();
    +       return randomForest;
    +     }
    +   }
    +
    +   // instead of storing the data of the model itself
    +   // we instead only store data necessary to recalculate the same model lazily on demand
    +   [Storable]
    +   private int seed;
    +   [Storable]
    +   private Dataset originalTrainingData;
    +   [Storable]
    +   private int[] trainingRows;
    +   [Storable]
    +   private double[] classValues;
    +   [Storable]
    +   private string[] allowedInputVariables;
        [Storable]
        private string targetVariable;
        [Storable]
    -   private string[] allowedInputVariables;
    -   [Storable]
    -   private double[] classValues;
    +   private int nTrees;
    +   [Storable]
    +   private double r;
    +   [Storable]
    +   private double m;
    +
        [StorableConstructor]
        private RandomForestModel(bool deserializing)
          : base(deserializing) {
    -     if (deserializing)
    -       randomForest = new alglib.decisionforest();
    +     // for backwards compatibility (loading old solutions)
    +     randomForest = new alglib.decisionforest();
        }
        private RandomForestModel(RandomForestModel original, Cloner cloner)
          : base(original, cloner) {
    -     randomForest = new alglib.decisionforest();
    -     randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
    -     randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
    -     randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
    -     randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
    -     randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone();
    -     targetVariable = original.targetVariable;
    -     allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    -     if (original.classValues != null)
    -       this.classValues = (double[])original.classValues.Clone();
    -   }
    -   public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
    +     // clone the model if necessary
    +     if (original.randomForest != null) {
    +       randomForest = new alglib.decisionforest();
    +       randomForest.innerobj.bufsize = original.randomForest.innerobj.bufsize;
    +       randomForest.innerobj.nclasses = original.randomForest.innerobj.nclasses;
    +       randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees;
    +       randomForest.innerobj.nvars = original.randomForest.innerobj.nvars;
    +       // we assume that the trees array (double[]) is immutable in alglib
    +       randomForest.innerobj.trees = original.randomForest.innerobj.trees;
    +     }
    +     // clone data which is necessary to rebuild the model
    +     this.seed = original.seed;
    +     // Dataset is immutable so we do not need to clone
    +     this.originalTrainingData = original.originalTrainingData;
    +     this.targetVariable = original.targetVariable;
    +     // trainingRows, classValues and allowedInputVariables are immutable
    +     this.trainingRows = original.trainingRows;
    +     this.allowedInputVariables = original.allowedInputVariables;
    +     this.classValues = original.classValues; // null for regression problems
    +
    +     this.nTrees = original.nTrees;
    +     this.r = original.r;
    +     this.m = original.m;
    +   }
    +
    +   // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel
    +   private RandomForestModel(alglib.decisionforest randomForest,
    +     int seed, Dataset originalTrainingData, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable,
    +     int nTrees, double r, double m, double[] classValues = null)
          : base() {
          this.name = ItemName;
          this.description = ItemDescription;
    +     // the model itself
          this.randomForest = randomForest;
    +     // data which is necessary for recalculation of the model
    +     this.seed = seed;
    +     this.originalTrainingData = originalTrainingData;
    +     this.trainingRows = (int[])trainingRows.ToArray().Clone();
          this.targetVariable = targetVariable;
    -     this.allowedInputVariables = allowedInputVariables.ToArray();
    -     if (classValues != null)
    -       this.classValues = (double[])classValues.Clone();
    +     this.allowedInputVariables = (string[])allowedInputVariables.ToArray().Clone();
    +     this.classValues = classValues; // null for regression problems
    +     this.nTrees = nTrees;
    +     this.r = r;
    +     this.m = m;
        }
        ...
        }

    +   private void RecalculateModel() {
    +     double rmsError, oobRmsError, relClassError, oobRelClassError;
    +     if (classValues == null) {
    +       var model = CreateRegressionModel(originalTrainingData, trainingRows, allowedInputVariables, targetVariable,
    +         nTrees, r, m, seed, out rmsError, out oobRmsError,
    +         out relClassError, out oobRelClassError);
    +       randomForest = model.randomForest;
    +     } else {
    +       var model = CreateClassificationModel(originalTrainingData, trainingRows, allowedInputVariables, targetVariable,
    +         classValues, nTrees, r, m, seed, out rmsError, out oobRmsError,
    +         out relClassError, out oobRelClassError);
    +       randomForest = model.randomForest;
    +     }
    +   }
    +
        public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
          double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
    +     AssertInputMatrix(inputData);

          int n = inputData.GetLength(0);
        ...
            x[column] = inputData[row, column];
          }
    -     alglib.dfprocess(randomForest, x, ref y);
    +     alglib.dfprocess(RandomForest, x, ref y);
          yield return y[0];
        }
        ...
        public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
          double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
    +     AssertInputMatrix(inputData);

          int n = inputData.GetLength(0);
          int columns = inputData.GetLength(1);
          double[] x = new double[columns];
    -     double[] y = new double[randomForest.innerobj.nclasses];
    +     double[] y = new double[RandomForest.innerobj.nclasses];

          for (int row = 0; row < n; row++) {
        ...
        }

    -   #region events
    -   public event EventHandler Changed;
    -   private void OnChanged(EventArgs e) {
    -     var handlers = Changed;
    -     if (handlers != null)
    -       handlers(this, e);
    -   }
    -   #endregion
    -
    -   #region persistence
    -   [Storable]
    +   public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
    +     out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
    +     return CreateRegressionModel(problemData.Dataset, problemData.TrainingIndices,
    +       problemData.AllowedInputVariables, problemData.TargetVariable,
    +       nTrees, r, m, seed,
    +       out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError);
    +   }
    +
    +   private static RandomForestModel CreateRegressionModel(
    +     // prob param
    +     Dataset dataset, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable,
    +     // alg param
    +     int nTrees, double r, double m, int seed,
    +     // results
    +     out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
    +
    +     var variables = allowedInputVariables.Concat(new string[] { targetVariable });
    +     double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, variables, trainingRows);
    +
    +     alglib.dfreport rep;
    +     var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);
    +
    +     rmsError = rep.rmserror;
    +     avgRelError = rep.avgrelerror;
    +     outOfBagAvgRelError = rep.oobavgrelerror;
    +     outOfBagRmsError = rep.oobrmserror;
    +
    +     return new RandomForestModel(dForest,
    +       seed, dataset, trainingRows, allowedInputVariables, targetVariable,
    +       nTrees, r, m);
    +   }
    +
    +   public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
    +     out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
    +     return CreateClassificationModel(problemData.Dataset, problemData.TrainingIndices,
    +       problemData.AllowedInputVariables, problemData.TargetVariable,
    +       problemData.ClassValues.ToArray(),
    +       nTrees, r, m, seed,
    +       out rmsError, out outOfBagRmsError, out relClassificationError,
    +       out outOfBagRelClassificationError);
    +   }
    +
    +   private static RandomForestModel CreateClassificationModel(
    +     // prob param
    +     Dataset dataset, IEnumerable<int> trainingRows, IEnumerable<string> allowedInputVariables, string targetVariable, double[] classValues,
    +     // alg param
    +     int nTrees, double r, double m, int seed,
    +     // results
    +     out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
    +
    +     var variables = allowedInputVariables.Concat(new string[] { targetVariable });
    +     double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, variables, trainingRows);
    +
    +     int nClasses = classValues.Length;
    +
    +     // map original class values to values [0..nClasses-1]
    +     var classIndices = new Dictionary<double, double>();
    +     for (int i = 0; i < nClasses; i++) {
    +       classIndices[classValues[i]] = i;
    +     }
    +
    +     int nRows = inputMatrix.GetLength(0);
    +     int nColumns = inputMatrix.GetLength(1);
    +     for (int row = 0; row < nRows; row++) {
    +       inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]];
    +     }
    +
    +     alglib.dfreport rep;
    +     var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);
    +
    +     rmsError = rep.rmserror;
    +     outOfBagRmsError = rep.oobrmserror;
    +     relClassificationError = rep.relclserror;
    +     outOfBagRelClassificationError = rep.oobrelclserror;
    +
    +     return new RandomForestModel(dForest,
    +       seed, dataset, trainingRows, allowedInputVariables, targetVariable,
    +       nTrees, r, m, classValues);
    +   }
    +
    +   private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {
    +     AssertParameters(r, m);
    +     AssertInputMatrix(inputMatrix);
    +
    +     int info = 0;
    +     alglib.math.rndobject = new System.Random(seed);
    +     alglib.decisionforest dForest = new alglib.decisionforest();
    +     rep = new alglib.dfreport();
    +     int nRows = inputMatrix.GetLength(0);
    +     int nColumns = inputMatrix.GetLength(1);
    +     int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    +     int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    +
    +     alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    +     if (info != 1) throw new ArgumentException("Error in calculation of random forest model");
    +     return dForest;
    +   }
    +
    +   private static void AssertParameters(double r, double m) {
    +     if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");
    +     if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");
    +   }
    +
    +   private static void AssertInputMatrix(double[,] inputMatrix) {
    +     if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    +       throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
    +   }
    +
    +   #region persistence for backwards compatibility
    +   [Storable(AllowOneWay = true)]
        private int RandomForestBufSize {
    -     get {
    -       return randomForest.innerobj.bufsize;
    -     }
          set {
            randomForest.innerobj.bufsize = value;
          }
        }
    -   [Storable]
    +   [Storable(AllowOneWay = true)]
        private int RandomForestNClasses {
    -     get {
    -       return randomForest.innerobj.nclasses;
    -     }
          set {
            randomForest.innerobj.nclasses = value;
          }
        }
    -   [Storable]
    +   [Storable(AllowOneWay = true)]
        private int RandomForestNTrees {
    -     get {
    -       return randomForest.innerobj.ntrees;
    -     }
          set {
            randomForest.innerobj.ntrees = value;
          }
        }
    -   [Storable]
    +   [Storable(AllowOneWay = true)]
        private int RandomForestNVars {
    -     get {
    -       return randomForest.innerobj.nvars;
    -     }
          set {
            randomForest.innerobj.nvars = value;
          }
        }
    -   [Storable]
    +   [Storable(AllowOneWay = true)]
        private double[] RandomForestTrees {
    -     get {
    -       return randomForest.innerobj.trees;
    -     }
          set {
            randomForest.innerobj.trees = value;
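The design idea of this file, persist only the training parameters and rebuild the expensive native model lazily on first access, is a general pattern. Below is a minimal, self-contained sketch of that pattern for illustration only; it is not HeuristicLab code, the class and member names are invented, and it uses the BCL [Serializable]/[NonSerialized] attributes in place of HeuristicLab's [Storable] machinery.

    using System;

    // Sketch of the "persist parameters, rebuild lazily" pattern (hypothetical example).
    [Serializable]
    public sealed class LazilyRebuiltModel {
      // persisted: the inputs needed to deterministically retrain the model
      private readonly int seed;
      private readonly double[] trainingTargets;

      // not persisted: the expensive artifact, rebuilt on demand after deserialization
      [NonSerialized]
      private double[] fittedWeights;

      public LazilyRebuiltModel(int seed, double[] trainingTargets) {
        this.seed = seed;
        this.trainingTargets = (double[])trainingTargets.Clone();
      }

      private double[] FittedWeights {
        get {
          // analogous to RandomForestModel.RandomForest: rebuild if missing (e.g. after loading)
          if (fittedWeights == null) fittedWeights = Refit();
          return fittedWeights;
        }
      }

      private double[] Refit() {
        // deterministic because the seed is stored together with the training data
        var rnd = new Random(seed);
        var w = new double[trainingTargets.Length];
        for (int i = 0; i < w.Length; i++) w[i] = trainingTargets[i] + rnd.NextDouble() * 1e-3;
        return w;
      }

      public double Predict(int index) { return FittedWeights[index]; }
    }

The trade-off mirrored here is the one the changeset makes: much smaller persisted files and reproducible results via the stored seed, at the cost of recomputing the forest the first time a loaded model is used.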
branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
(r9456 → r10321) Analogous to the classification case, CreateRandomForestRegressionSolution now delegates to the new factory method RandomForestModel.CreateRegressionModel:

      public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
        out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) {
    -   if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1.");
    -   if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1.");
    -
    -   alglib.math.rndobject = new System.Random(seed);
    -
    -   Dataset dataset = problemData.Dataset;
    -   string targetVariable = problemData.TargetVariable;
    -   IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
    -   IEnumerable<int> rows = problemData.TrainingIndices;
    -   double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);
    -   if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
    -     throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset.");
    -
    -   int info = 0;
    -   alglib.decisionforest dForest = new alglib.decisionforest();
    -   alglib.dfreport rep = new alglib.dfreport();
    -   int nRows = inputMatrix.GetLength(0);
    -   int nColumns = inputMatrix.GetLength(1);
    -   int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);
    -   int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);
    -
    -   alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);
    -   if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution");
    -
    -   rmsError = rep.rmserror;
    -   avgRelError = rep.avgrelerror;
    -   outOfBagAvgRelError = rep.oobavgrelerror;
    -   outOfBagRmsError = rep.oobrmserror;
    -
    -   return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables));
    +   var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError);
    +   return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model);
      }
      #endregion
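Note that the outward-facing API is unchanged; existing callers keep working. A hedged sketch (the enclosing RandomForestRegression class is assumed to expose the static method shown in the diff, and the variable names and parameter values are illustrative):

    // Illustrative only: 'problemData' is assumed to be an IRegressionProblemData prepared elsewhere.
    double rmsError, avgRelError, oobRmsError, oobAvgRelError;
    var solution = RandomForestRegression.CreateRandomForestRegressionSolution(problemData, 50, 0.3, 0.5, 1234,
      out rmsError, out avgRelError, out oobRmsError, out oobAvgRelError);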