Changeset 17045 for branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
- Timestamp: 06/27/19 15:46:20 (6 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r16763 r17045 23 23 using System.Collections.Generic; 24 24 using System.Linq; 25 using HEAL.Attic; 25 26 using HeuristicLab.Common; 26 27 using HeuristicLab.Core; 27 28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 28 using HEAL.Attic;29 29 using HeuristicLab.Problems.DataAnalysis; 30 30 using HeuristicLab.Problems.DataAnalysis.Symbolic; … … 34 34 /// Represents a random forest model for regression and classification 35 35 /// </summary> 36 [StorableType("A4F688CD-1F42-4103-8449-7DE52AEF6C69")] 36 [Obsolete("This class only exists for backwards compatibility reasons. Use RFModelSurrogate or RFModelFull instead.")] 37 [StorableType("9AA4CCC2-CD75-4471-8DF6-949E5B783642")] 37 38 [Item("RandomForestModel", "Represents a random forest for regression and classification.")] 38 39 public sealed class RandomForestModel : ClassificationModel, IRandomForestModel { … … 139 140 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 140 141 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 141 AssertInputMatrix(inputData);142 RandomForestUtil.AssertInputMatrix(inputData); 142 143 143 144 int n = inputData.GetLength(0); … … 157 158 public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) { 158 159 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 159 AssertInputMatrix(inputData);160 RandomForestUtil.AssertInputMatrix(inputData); 160 161 161 162 int n = inputData.GetLength(0); … … 175 176 public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 176 177 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 177 AssertInputMatrix(inputData);178 RandomForestUtil.AssertInputMatrix(inputData); 178 179 179 180 int n = inputData.GetLength(0); … … 315 316 316 317 alglib.dfreport rep; 317 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);318 var dForest = 
RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep); 318 319 319 320 rmsError = rep.rmserror; … … 353 354 354 355 alglib.dfreport rep; 355 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);356 var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep); 356 357 357 358 rmsError = rep.rmserror; … … 361 362 362 363 return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues); 363 }364 365 private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {366 AssertParameters(r, m);367 AssertInputMatrix(inputMatrix);368 369 int info = 0;370 alglib.math.rndobject = new System.Random(seed);371 var dForest = new alglib.decisionforest();372 rep = new alglib.dfreport();373 int nRows = inputMatrix.GetLength(0);374 int nColumns = inputMatrix.GetLength(1);375 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);376 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);377 378 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);379 if (info != 1) throw new ArgumentException("Error in calculation of random forest model");380 return dForest;381 }382 383 private static void AssertParameters(double r, double m) {384 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");385 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");386 }387 388 private static void AssertInputMatrix(double[,] inputMatrix) {389 if (inputMatrix.ContainsNanOrInfinity())390 throw new NotSupportedException("Random forest modeling does not support 
NaN or infinity values in the input dataset.");391 364 } 392 365
Note: See TracChangeset for help on using the changeset viewer.