Changeset 17226 for branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
- Timestamp:
- 08/29/19 15:57:35 (5 years ago)
- Location:
- branches/2521_ProblemRefactoring
- Files:
-
- 9 edited
- 3 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/2521_ProblemRefactoring
- Property svn:mergeinfo changed
-
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
-
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4
- Property svn:mergeinfo changed
-
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
r16723 r17226 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-2019Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 20 20 #endregion 21 21 22 using System.Collections.Generic; 23 using System.Linq; 22 24 using System.Threading; 25 using HEAL.Attic; 26 using HeuristicLab.Algorithms.DataAnalysis.RandomForest; 23 27 using HeuristicLab.Common; 24 28 using HeuristicLab.Core; … … 26 30 using HeuristicLab.Optimization; 27 31 using HeuristicLab.Parameters; 28 using HEAL.Attic;29 32 using HeuristicLab.Problems.DataAnalysis; 30 33 … … 43 46 private const string SeedParameterName = "Seed"; 44 47 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 45 private const string CreateSolutionParameterName = "CreateSolution";48 private const string ModelCreationParameterName = "ModelCreation"; 46 49 47 50 #region parameter properties … … 61 64 get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 62 65 } 63 p ublic IFixedValueParameter<BoolValue> CreateSolutionParameter {64 get { return (IFixedValueParameter< BoolValue>)Parameters[CreateSolutionParameterName]; }66 private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter { 67 get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; } 65 68 } 66 69 #endregion … … 86 89 set { SetSeedRandomlyParameter.Value.Value = value; } 87 90 } 88 public bool CreateSolution {89 get { return CreateSolutionParameter.Value.Value; }90 set { CreateSolutionParameter.Value.Value = value; }91 public ModelCreation ModelCreation { 92 get { return ModelCreationParameter.Value.Value; } 93 set { ModelCreationParameter.Value.Value = value; } 91 94 } 92 95 #endregion … … 105 108 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); 106 109 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 107 Parameters.Add(new FixedValueParameter< BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));108 Parameters[ CreateSolutionParameterName].Hidden = true;110 Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model))); 111 Parameters[ModelCreationParameterName].Hidden = true; 109 112 110 113 Problem = new ClassificationProblem(); … … 121 124 if (!Parameters.ContainsKey((SetSeedRandomlyParameterName))) 122 125 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 123 if (!Parameters.ContainsKey(CreateSolutionParameterName)) { 124 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 125 Parameters[CreateSolutionParameterName].Hidden = true; 126 127 // parameter type has been changed 128 if (Parameters.ContainsKey("CreateSolution")) { 129 var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>; 130 Parameters.Remove(createSolutionParam); 131 132 ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly; 133 Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value))); 134 Parameters[ModelCreationParameterName].Hidden = true; 135 } else if (!Parameters.ContainsKey(ModelCreationParameterName)) { 136 // very old version contains neither ModelCreationParameter nor CreateSolutionParameter 137 Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model))); 138 Parameters[ModelCreationParameterName].Hidden = true; 126 139 } 127 140 #endregion … … 138 151 139 152 var model = CreateRandomForestClassificationModel(Problem.ProblemData, NumberOfTrees, R, M, Seed, out rmsError, out relClassificationError, out outOfBagRmsError, out outOfBagRelClassificationError); 153 140 154 Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the random forest regression solution on the training set.", new DoubleValue(rmsError))); 141 155 Results.Add(new Result("Relative classification error", "Relative classification error of the random forest regression solution on the training set.", new PercentValue(relClassificationError))); … … 143 157 Results.Add(new Result("Relative classification error (out-of-bag)", "The out-of-bag relative classification error of the random forest regression solution.", new PercentValue(outOfBagRelClassificationError))); 144 158 145 if (CreateSolution) { 146 var solution = new RandomForestClassificationSolution(model, (IClassificationProblemData)Problem.ProblemData.Clone()); 159 160 IClassificationSolution solution = null; 161 if (ModelCreation == ModelCreation.Model) { 162 solution = model.CreateClassificationSolution(Problem.ProblemData); 163 } else if (ModelCreation == ModelCreation.SurrogateModel) { 164 var problemData = Problem.ProblemData; 165 var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M, problemData.ClassValues.ToArray()); 166 167 solution = surrogateModel.CreateClassificationSolution(problemData); 168 } 169 170 if (solution != null) { 147 171 Results.Add(new Result(RandomForestClassificationModelResultName, "The random forest classification solution.", solution)); 148 172 } … … 157 181 } 158 182 159 public static RandomForestModel CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 183 public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 184 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 185 var model = CreateRandomForestClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 186 return model; 187 } 188 189 public static RandomForestModelFull CreateRandomForestClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed, 160 190 out double rmsError, out double relClassificationError, out double outOfBagRmsError, out double outOfBagRelClassificationError) { 161 return RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, 162 rmsError: out rmsError, relClassificationError: out relClassificationError, outOfBagRmsError: out outOfBagRmsError, outOfBagRelClassificationError: out outOfBagRelClassificationError); 191 192 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 193 double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices); 194 195 var classValues = problemData.ClassValues.ToArray(); 196 int nClasses = classValues.Length; 197 198 // map original class values to values [0..nClasses-1] 199 var classIndices = new Dictionary<double, double>(); 200 for (int i = 0; i < nClasses; i++) { 201 classIndices[classValues[i]] = i; 202 } 203 204 int nRows = inputMatrix.GetLength(0); 205 int nColumns = inputMatrix.GetLength(1); 206 for (int row = 0; row < nRows; row++) { 207 inputMatrix[row, nColumns - 1] = classIndices[inputMatrix[row, nColumns - 1]]; 208 } 209 210 alglib.dfreport rep; 211 var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep); 212 213 rmsError = rep.rmserror; 214 outOfBagRmsError = rep.oobrmserror; 215 relClassificationError = rep.relclserror; 216 outOfBagRelClassificationError = rep.oobrelclserror; 217 218 return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables, classValues); 163 219 } 164 220 #endregion -
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassificationSolution.cs
r16723 r17226 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-2019Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. -
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r16801 r17226 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-2019Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 23 23 using System.Collections.Generic; 24 24 using System.Linq; 25 using HEAL.Attic; 25 26 using HeuristicLab.Common; 26 27 using HeuristicLab.Core; 27 28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 28 using HEAL.Attic;29 29 using HeuristicLab.Problems.DataAnalysis; 30 30 using HeuristicLab.Problems.DataAnalysis.Symbolic; … … 34 34 /// Represents a random forest model for regression and classification 35 35 /// </summary> 36 [StorableType("A4F688CD-1F42-4103-8449-7DE52AEF6C69")] 36 [Obsolete("This class only exists for backwards compatibility reasons for stored models with the XML Persistence. Use RFModelSurrogate or RFModelFull instead.")] 37 [StorableType("9AA4CCC2-CD75-4471-8DF6-949E5B783642")] 37 38 [Item("RandomForestModel", "Represents a random forest for regression and classification.")] 38 39 public sealed class RandomForestModel : ClassificationModel, IRandomForestModel { … … 139 140 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 140 141 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 141 AssertInputMatrix(inputData);142 RandomForestUtil.AssertInputMatrix(inputData); 142 143 143 144 int n = inputData.GetLength(0); … … 157 158 public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) { 158 159 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 159 AssertInputMatrix(inputData);160 RandomForestUtil.AssertInputMatrix(inputData); 160 161 161 162 int n = inputData.GetLength(0); … … 175 176 public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 176 177 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 177 AssertInputMatrix(inputData);178 RandomForestUtil.AssertInputMatrix(inputData); 178 179 179 180 int n = inputData.GetLength(0); … … 315 316 316 317 alglib.dfreport rep; 317 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep);318 var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep); 318 319 319 320 rmsError = rep.rmserror; … … 353 354 354 355 alglib.dfreport rep; 355 var dForest = CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep);356 var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, nClasses, out rep); 356 357 357 358 rmsError = rep.rmserror; … … 361 362 362 363 return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues); 363 }364 365 private static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) {366 AssertParameters(r, m);367 AssertInputMatrix(inputMatrix);368 369 int info = 0;370 alglib.math.rndobject = new System.Random(seed);371 var dForest = new alglib.decisionforest();372 rep = new alglib.dfreport();373 int nRows = inputMatrix.GetLength(0);374 int nColumns = inputMatrix.GetLength(1);375 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1);376 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1);377 378 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj);379 if (info != 1) throw new ArgumentException("Error in calculation of random forest model");380 return dForest;381 }382 383 private static void AssertParameters(double r, double m) {384 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1.");385 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1.");386 }387 388 private static void AssertInputMatrix(double[,] inputMatrix) {389 if (inputMatrix.ContainsNanOrInfinity())390 throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");391 364 } 392 365 -
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
r16723 r17226 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-2019Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 20 20 #endregion 21 21 22 using System.Collections.Generic; 23 using System.Linq; 22 24 using System.Threading; 25 using HEAL.Attic; 26 using HeuristicLab.Algorithms.DataAnalysis.RandomForest; 23 27 using HeuristicLab.Common; 24 28 using HeuristicLab.Core; … … 26 30 using HeuristicLab.Optimization; 27 31 using HeuristicLab.Parameters; 28 using HEAL.Attic;29 32 using HeuristicLab.Problems.DataAnalysis; 30 33 … … 43 46 private const string SeedParameterName = "Seed"; 44 47 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 45 private const string CreateSolutionParameterName = "CreateSolution";48 private const string ModelCreationParameterName = "ModelCreation"; 46 49 47 50 #region parameter properties … … 61 64 get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 62 65 } 63 p ublic IFixedValueParameter<BoolValue> CreateSolutionParameter {64 get { return (IFixedValueParameter< BoolValue>)Parameters[CreateSolutionParameterName]; }66 private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter { 67 get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; } 65 68 } 66 69 #endregion … … 86 89 set { SetSeedRandomlyParameter.Value.Value = value; } 87 90 } 88 public bool CreateSolution {89 get { return CreateSolutionParameter.Value.Value; }90 set { CreateSolutionParameter.Value.Value = value; }91 public ModelCreation ModelCreation { 92 get { return ModelCreationParameter.Value.Value; } 93 set { ModelCreationParameter.Value.Value = value; } 91 94 } 92 95 #endregion … … 104 107 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); 105 108 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 106 Parameters.Add(new FixedValueParameter< BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));107 Parameters[ CreateSolutionParameterName].Hidden = true;109 Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model))); 110 Parameters[ModelCreationParameterName].Hidden = true; 108 111 109 112 Problem = new RegressionProblem(); … … 120 123 if (!Parameters.ContainsKey((SetSeedRandomlyParameterName))) 121 124 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 122 if (!Parameters.ContainsKey(CreateSolutionParameterName)) { 123 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 124 Parameters[CreateSolutionParameterName].Hidden = true; 125 126 // parameter type has been changed 127 if (Parameters.ContainsKey("CreateSolution")) { 128 var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>; 129 Parameters.Remove(createSolutionParam); 130 131 ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly; 132 Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value))); 133 Parameters[ModelCreationParameterName].Hidden = true; 134 } else if (!Parameters.ContainsKey(ModelCreationParameterName)) { 135 // very old version contains neither ModelCreationParameter nor CreateSolutionParameter 136 Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model))); 137 Parameters[ModelCreationParameterName].Hidden = true; 125 138 } 126 139 #endregion … … 143 156 Results.Add(new Result("Average relative error (out-of-bag)", "The out-of-bag average of relative errors of the random forest regression solution.", new PercentValue(outOfBagAvgRelError))); 144 157 145 if (CreateSolution) { 146 var solution = new RandomForestRegressionSolution(model, (IRegressionProblemData)Problem.ProblemData.Clone()); 158 IRegressionSolution solution = null; 159 if (ModelCreation == ModelCreation.Model) { 160 solution = model.CreateRegressionSolution(Problem.ProblemData); 161 } else if (ModelCreation == ModelCreation.SurrogateModel) { 162 var problemData = Problem.ProblemData; 163 var surrogateModel = new RandomForestModelSurrogate(model, problemData.TargetVariable, problemData, Seed, NumberOfTrees, R, M); 164 solution = surrogateModel.CreateRegressionSolution(problemData); 165 } 166 167 if (solution != null) { 147 168 Results.Add(new Result(RandomForestRegressionModelResultName, "The random forest regression solution.", solution)); 148 169 } 149 170 } 171 150 172 151 173 // keep for compatibility with old API … … 157 179 } 158 180 159 public static RandomForestModel CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees, 160 double r, double m, int seed, 161 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 162 return RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, 163 rmsError: out rmsError, avgRelError: out avgRelError, outOfBagRmsError: out outOfBagRmsError, outOfBagAvgRelError: out outOfBagAvgRelError); 181 public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, int nTrees, 182 double r, double m, int seed, 183 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 184 var model = CreateRandomForestRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 185 return model; 186 } 187 188 public static RandomForestModelFull CreateRandomForestRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed, 189 out double rmsError, out double avgRelError, out double outOfBagRmsError, out double outOfBagAvgRelError) { 190 191 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 192 double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices); 193 194 alglib.dfreport rep; 195 var dForest = RandomForestUtil.CreateRandomForestModel(seed, inputMatrix, nTrees, r, m, 1, out rep); 196 197 rmsError = rep.rmserror; 198 outOfBagRmsError = rep.oobrmserror; 199 avgRelError = rep.avgrelerror; 200 outOfBagAvgRelError = rep.oobavgrelerror; 201 202 return new RandomForestModelFull(dForest, problemData.TargetVariable, problemData.AllowedInputVariables); 164 203 } 165 204 -
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegressionSolution.cs
r16723 r17226 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-2019Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. -
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
r16723 r17226 2 2 3 3 /* HeuristicLab 4 * Copyright (C) 2002-2019Heuristic and Evolutionary Algorithms Laboratory (HEAL)4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) 5 5 * 6 6 * This file is part of HeuristicLab. … … 27 27 using System.Linq.Expressions; 28 28 using System.Threading.Tasks; 29 using HEAL.Attic; 29 30 using HeuristicLab.Common; 30 31 using HeuristicLab.Core; 31 32 using HeuristicLab.Data; 32 33 using HeuristicLab.Parameters; 33 using HEAL.Attic;34 34 using HeuristicLab.Problems.DataAnalysis; 35 35 using HeuristicLab.Random; … … 89 89 90 90 public static class RandomForestUtil { 91 public static void AssertParameters(double r, double m) { 92 if (r <= 0 || r > 1) throw new ArgumentException("The R parameter for random forest modeling must be between 0 and 1."); 93 if (m <= 0 || m > 1) throw new ArgumentException("The M parameter for random forest modeling must be between 0 and 1."); 94 } 95 96 public static void AssertInputMatrix(double[,] inputMatrix) { 97 if (inputMatrix.ContainsNanOrInfinity()) 98 throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset."); 99 } 100 101 internal static alglib.decisionforest CreateRandomForestModel(int seed, double[,] inputMatrix, int nTrees, double r, double m, int nClasses, out alglib.dfreport rep) { 102 RandomForestUtil.AssertParameters(r, m); 103 RandomForestUtil.AssertInputMatrix(inputMatrix); 104 105 int info = 0; 106 alglib.math.rndobject = new System.Random(seed); 107 var dForest = new alglib.decisionforest(); 108 rep = new alglib.dfreport(); 109 int nRows = inputMatrix.GetLength(0); 110 int nColumns = inputMatrix.GetLength(1); 111 int sampleSize = Math.Max((int)Math.Round(r * nRows), 1); 112 int nFeatures = Math.Max((int)Math.Round(m * (nColumns - 1)), 1); 113 114 alglib.dforest.dfbuildinternal(inputMatrix, nRows, nColumns - 1, nClasses, nTrees, sampleSize, nFeatures, alglib.dforest.dfusestrongsplits + alglib.dforest.dfuseevs, ref info, dForest.innerobj, rep.innerobj); 115 if (info != 1) throw new ArgumentException("Error in calculation of random forest model"); 116 return dForest; 117 } 118 119 91 120 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) { 92 121 avgTestMse = 0;
Note: See TracChangeset
for help on using the changeset viewer.