Changeset 10963 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis
- Timestamp:
- 06/11/14 10:25:04 (10 years ago)
- Location:
- trunk/sources
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources
- Property svn:mergeinfo changed
/branches/1721-RandomForestPersistence (added) merged: 10321-10322
- Property svn:mergeinfo changed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
/branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis (added) merged: 10321-10322
- Property svn:mergeinfo changed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
r9456 r10963 131 131 public static IClassificationSolution CreateRandomForestClassificationSolution(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 132 132 out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) { 133 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1."); 134 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1."); 135 136 alglib.math.rndobject = new System.Random(seed); 137 138 Dataset dataset = problemData.Dataset; 139 string targetVariable = problemData.TargetVariable; 140 IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables; 141 IEnumerable<int> rows = problemData.TrainingIndices; 142 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows); 143 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 144 throw new NotSupportedException("Random forest classification does not support NaN or infinity values in the input dataset."); 145 146 int info = 0; 147 alglib.decisionforest dForest = new alglib.decisionforest(); 148 alglib.dfreport rep = new alglib.dfreport(); ; 149 int nRows = inputMatrix.GetLength(0); 150 int nColumns = inputMatrix.GetLength(1); 151 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 152 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 153 154 155 double[] classValues = problemData.ClassValues.ToArray(); 156 int nClasses = problemData.Classes; 157 // map original class values to values [0..nClasses-1] 158 Dictionary<double, double> classIndices = new Dictionary<double, double>(); 159 for (int i = 0; i < nClasses; i++) { 160 classIndices[classValues[i]] = i; 161 } 162 for (int row = 0; row < nRows; row++) { 163 inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]]; 164 } 165 // execute random forest algorithm 166 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 167 if (info != 1) throw new ArgumentException("Error in calculation of random forest classification solution"); 168 169 rmsError = rep.rmserror; 170 outOfBagRmsError = rep.oobrmserror; 171 relClassificationError = rep.relclserror; 172 outOfBagRelClassificationError = rep.oobrelclserror; 173 return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables, classValues)); 133 var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 134 return new RandomForestClassificationSolution((IClassificationProblemData)problemData.Clone(), model); 174 135 } 175 136 #endregion -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r9456 r10963 35 35 [Item("RandomForestModel", "Represents a random forest for regression and classification.")] 36 36 public sealed class RandomForestModel : NamedItem, IRandomForestModel { 37 37 // not persisted 38 38 private alglib.decisionforest randomForest; 39 public alglib.decisionforest RandomForest { 40 get { return randomForest; } 41 set { 42 if (value != randomForest) { 43 if (value == null) throw new ArgumentNullException(); 44 randomForest = value; 45 OnChanged(EventArgs.Empty); 46 } 47 } 48 } 49 50 [Storable] 51 private string targetVariable; 52 [Storable] 53 private string[] allowedInputVariables; 39 private alglib.decisionforest RandomForest { 40 get { 41 // recalculate lazily 42 if (randomForest.innerobj.trees == null || randomForest.innerobj.trees.Length == 0) RecalculateModel(); 43 return randomForest; 44 } 45 } 46 47 // instead of storing the data of the model itself 48 // we instead only store data necessary to recalculate the same model lazily on demand 49 [Storable] 50 private int seed; 51 [Storable] 52 private IDataAnalysisProblemData originalTrainingData; 54 53 [Storable] 55 54 private double[] classValues; 55 [Storable] 56 private int nTrees; 57 [Storable] 58 private double r; 59 [Storable] 60 private double m; 61 62 56 63 [StorableConstructor] 57 64 private RandomForestModel(bool deserializing) 58 65 : base(deserializing) { 59 if (deserializing)60 66 // for backwards compatibility (loading old solutions) 67 randomForest = new alglib.decisionforest(); 61 68 } 62 69 private RandomForestModel(RandomForestModel original, Cloner cloner) … … 67 74 randomForest.innerobj.ntrees = original.randomForest.innerobj.ntrees; 68 75 randomForest.innerobj.nvars = original.randomForest.innerobj.nvars; 69 randomForest.innerobj.trees = (double[])original.randomForest.innerobj.trees.Clone(); 70 targetVariable = original.targetVariable; 71 allowedInputVariables = (string[])original.allowedInputVariables.Clone(); 72 if (original.classValues != null) 73 this.classValues = (double[])original.classValues.Clone(); 74 } 75 public RandomForestModel(alglib.decisionforest randomForest, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) 76 // we assume that the trees array (double[]) is immutable in alglib 77 randomForest.innerobj.trees = original.randomForest.innerobj.trees; 78 79 // allowedInputVariables is immutable so we don't need to clone 80 allowedInputVariables = original.allowedInputVariables; 81 82 // clone data which is necessary to rebuild the model 83 this.seed = original.seed; 84 this.originalTrainingData = cloner.Clone(original.originalTrainingData); 85 // classvalues is immutable so we don't need to clone 86 this.classValues = original.classValues; 87 this.nTrees = original.nTrees; 88 this.r = original.r; 89 this.m = original.m; 90 } 91 92 // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel 93 private RandomForestModel(alglib.decisionforest randomForest, 94 int seed, IDataAnalysisProblemData originalTrainingData, 95 int nTrees, double r, double m, double[] classValues = null) 76 96 : base() { 77 97 this.name = ItemName; 78 98 this.description = ItemDescription; 99 // the model itself 79 100 this.randomForest = randomForest; 80 this.targetVariable = targetVariable; 81 this.allowedInputVariables = allowedInputVariables.ToArray(); 82 if (classValues != null) 83 this.classValues = (double[])classValues.Clone(); 101 // data which is necessary for recalculation of the model 102 this.seed = seed; 103 this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone(); 104 this.classValues = classValues; 105 this.nTrees = nTrees; 106 this.r = r; 107 this.m = m; 84 108 } 85 109 … … 88 112 } 89 113 114 private void RecalculateModel() { 115 double rmsError, oobRmsError, relClassError, oobRelClassError; 116 var regressionProblemData = originalTrainingData as IRegressionProblemData; 117 var classificationProblemData = originalTrainingData as IClassificationProblemData; 118 if (regressionProblemData != null) { 119 var model = CreateRegressionModel(regressionProblemData, 120 nTrees, r, m, seed, out rmsError, out oobRmsError, 121 out relClassError, out oobRelClassError); 122 randomForest = model.randomForest; 123 } else if (classificationProblemData != null) { 124 var model = CreateClassificationModel(classificationProblemData, 125 nTrees, r, m, seed, out rmsError, out oobRmsError, 126 out relClassError, out oobRelClassError); 127 randomForest = model.randomForest; 128 } 129 } 130 90 131 public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) { 91 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 132 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 133 AssertInputMatrix(inputData); 92 134 93 135 int n = inputData.GetLength(0); … … 100 142 x[column] = inputData[row, column]; 101 143 } 102 alglib.dfprocess( randomForest, x, ref y);144 alglib.dfprocess(RandomForest, x, ref y); 103 145 yield return y[0]; 104 146 } … … 106 148 107 149 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 108 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 150 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 151 AssertInputMatrix(inputData); 109 152 110 153 int n = inputData.GetLength(0); 111 154 int columns = inputData.GetLength(1); 112 155 double[] x = new double[columns]; 113 double[] y = new double[ randomForest.innerobj.nclasses];156 double[] y = new double[RandomForest.innerobj.nclasses]; 114 157 115 158 for (int row = 0; row < n; row++) { … … 144 187 } 145 188 146 #region events 147 public event EventHandler Changed; 148 private void OnChanged(EventArgs e) { 149 var handlers = Changed; 150 if (handlers != null) 151 handlers(this, e); 152 } 153 #endregion 154 155 #region persistence 189 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 190 out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) { 191 192 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 193 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices); 194 195 alglib.dfreport rep; 196 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep); 197 198 rmsError = rep.rmserror; 199 avgRelError = rep.avgrelerror; 200 outOfBagAvgRelError = rep.oobavgrelerror; 201 outOfBagRmsError = rep.oobrmserror; 202 203 return new RandomForestModel(dForest, 204 seed, problemData, 205 nTrees, r, m); 206 } 207 208 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 209 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 210 211 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 212 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices); 213 214 var classValues = problemData.ClassValues.ToArray(); 215 int nClasses = classValues.Length; 216 217 // map original class values to values [0..nClasses-1] 218 var classIndices = new Dictionary<double, double>(); 219 for (int i = 0; i < nClasses; i++) { 220 classIndices[classValues[i]] = i; 221 } 222 223 int nRows = inputMatrix.GetLength(0); 224 int nColumns = inputMatrix.GetLength(1); 225 for (int row = 0; row < nRows; row++) { 226 inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]]; 227 } 228 229 alglib.dfreport rep; 230 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep); 231 232 rmsError = rep.rmserror; 233 outOfBagRmsError = rep.oobrmserror; 234 relClassificationError = rep.relclserror; 235 outOfBagRelClassificationError = rep.oobrelclserror; 236 237 return new RandomForestModel(dForest, 238 seed, problemData, 239 nTrees, r, m, classValues); 240 } 241 242 private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) { 243 AssertParameters(r, m); 244 AssertInputMatrix(inputMatrix); 245 246 int info = 0; 247 alglib.math.rndobject = new System.Random(seed); 248 var dForest = new alglib.decisionforest(); 249 rep = new alglib.dfreport(); 250 int nRows = inputMatrix.GetLength(0); 251 int nColumns = inputMatrix.GetLength(1); 252 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 253 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 254 255 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 256 if (info != 1) throw new ArgumentException("Error in calculation of random forest model"); 257 return dForest; 258 } 259 260 private static void AssertParameters(double r, double m) { 261 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1."); 262 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1."); 263 } 264 265 private static void AssertInputMatrix(double[,] inputMatrix) { 266 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 267 throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset."); 268 } 269 270 #region persistence for backwards compatibility 271 // when the originalTrainingData is null this means the model was loaded from an old file 272 // therefore, we cannot use the new persistence mechanism because the original data is not available anymore 273 // in such cases we still store the compete model 274 private bool IsCompatibilityLoaded { get { return originalTrainingData == null; } } 275 276 private string[] allowedInputVariables; 277 [Storable(Name = "allowedInputVariables")] 278 private string[] AllowedInputVariables { 279 get { 280 if (IsCompatibilityLoaded) return allowedInputVariables; 281 else return originalTrainingData.AllowedInputVariables.ToArray(); 282 } 283 set { allowedInputVariables = value; } 284 } 156 285 [Storable] 157 286 private int RandomForestBufSize { 158 287 get { 159 return randomForest.innerobj.bufsize; 288 if (IsCompatibilityLoaded) return randomForest.innerobj.bufsize; 289 else return 0; 160 290 } 161 291 set { … … 166 296 private int RandomForestNClasses { 167 297 get { 168 return randomForest.innerobj.nclasses; 298 if (IsCompatibilityLoaded) return randomForest.innerobj.nclasses; 299 else return 0; 169 300 } 170 301 set { … … 175 306 private int RandomForestNTrees { 176 307 get { 177 return randomForest.innerobj.ntrees; 308 if (IsCompatibilityLoaded) return randomForest.innerobj.ntrees; 309 else return 0; 178 310 } 179 311 set { … … 184 316 private int RandomForestNVars { 185 317 get { 186 return randomForest.innerobj.nvars; 318 if (IsCompatibilityLoaded) return randomForest.innerobj.nvars; 319 else return 0; 187 320 } 188 321 set { … … 193 326 private double[] RandomForestTrees { 194 327 get { 195 return randomForest.innerobj.trees; 328 if (IsCompatibilityLoaded) return randomForest.innerobj.trees; 329 else return new double[] { }; 196 330 } 197 331 set { -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
r9456 r10963 130 130 public static IRegressionSolution CreateRandomForestRegressionSolution(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 131 131 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 132 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter in the random forest regression must be between 0 and 1."); 133 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter in the random forest regression must be between 0 and 1."); 134 135 alglib.math.rndobject = new System.Random(seed); 136 137 Dataset dataset = problemData.Dataset; 138 string targetVariable = problemData.TargetVariable; 139 IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables; 140 IEnumerable<int> rows = problemData.TrainingIndices; 141 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows); 142 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 143 throw new NotSupportedException("Random forest regression does not support NaN or infinity values in the input dataset."); 144 145 int info = 0; 146 alglib.decisionforest dForest = new alglib.decisionforest(); 147 alglib.dfreport rep = new alglib.dfreport(); ; 148 int nRows = inputMatrix.GetLength(0); 149 int nColumns = inputMatrix.GetLength(1); 150 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 151 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 152 153 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, 1, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 154 if (info != 1) throw new ArgumentException("Error in calculation of random forest regression solution"); 155 156 rmsError = rep.rmserror; 157 avgRelError = rep.avgrelerror; 158 outOfBagAvgRelError = rep.oobavgrelerror; 159 outOfBagRmsError = rep.oobrmserror; 160 161 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), new RandomForestModel(dForest, targetVariable, allowedInputVariables)); 132 var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 133 return new RandomForestRegressionSolution((IRegressionProblemData)problemData.Clone(), model); 162 134 } 163 135 #endregion
Note: See TracChangeset
for help on using the changeset viewer.