Changeset 15973 for branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
- Timestamp: 06/28/18 11:13:37 (6 years ago)
- Location:
- branches/2522_RefactorPluginInfrastructure
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2522_RefactorPluginInfrastructure
- Property svn:ignore
-
old new 24 24 protoc.exe 25 25 obj 26 .vs
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
-
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4
-
Property
svn:mergeinfo
set to
(toggle deleted branches)
/stable/HeuristicLab.Algorithms.DataAnalysis/3.4 merged eligible /trunk/HeuristicLab.Algorithms.DataAnalysis/3.4 merged eligible /branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4 10321-10322 /branches/Async/HeuristicLab.Algorithms.DataAnalysis/3.4 13329-15286 /branches/Benchmarking/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 6917-7005 /branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.4 9070-13099 /branches/CloningRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4 4656-4721 /branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4 5471-5808 /branches/DataAnalysis SolutionEnsembles/HeuristicLab.Algorithms.DataAnalysis/3.4 5815-6180 /branches/DataAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.4 4458-4459,4462,4464 /branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis/3.4 10085-11101 /branches/GP.Grammar.Editor/HeuristicLab.Algorithms.DataAnalysis/3.4 6284-6795 /branches/GP.Symbols (TimeLag, Diff, Integral)/HeuristicLab.Algorithms.DataAnalysis/3.4 5060 /branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 11570-12508 /branches/HeuristicLab.Problems.Orienteering/HeuristicLab.Algorithms.DataAnalysis/3.4 11130-12721 /branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4 13819-14091 /branches/HeuristicLab.TimeSeries/HeuristicLab.Algorithms.DataAnalysis/3.4 8116-8789 /branches/LogResidualEvaluator/HeuristicLab.Algorithms.DataAnalysis/3.4 10202-10483 /branches/NET40/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 5138-5162 /branches/ParallelEngine/HeuristicLab.Algorithms.DataAnalysis/3.4 5175-5192 /branches/ProblemInstancesRegressionAndClassification/HeuristicLab.Algorithms.DataAnalysis/3.4 7773-7810 /branches/QAPAlgorithms/HeuristicLab.Algorithms.DataAnalysis/3.4 6350-6627 /branches/Restructure trunk solution/HeuristicLab.Algorithms.DataAnalysis/3.4 6828 
/branches/SpectralKernelForGaussianProcesses/HeuristicLab.Algorithms.DataAnalysis/3.4 10204-10479 /branches/SuccessProgressAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.4 5370-5682 /branches/Trunk/HeuristicLab.Algorithms.DataAnalysis/3.4 6829-6865 /branches/VNS/HeuristicLab.Algorithms.DataAnalysis/3.4 5594-5752 /branches/Weighted TSNE/3.4 15451-15531 /branches/histogram/HeuristicLab.Algorithms.DataAnalysis/3.4 5959-6341 /branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4 14232-14825 /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 13402-15674
-
Property
svn:mergeinfo
set to
(toggle deleted branches)
-
branches/2522_RefactorPluginInfrastructure/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r12509 r15973 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-201 5Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 25 25 using HeuristicLab.Common; 26 26 using HeuristicLab.Core; 27 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 27 28 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 29 using HeuristicLab.Problems.DataAnalysis; 30 using HeuristicLab.Problems.DataAnalysis.Symbolic; 29 31 30 32 namespace HeuristicLab.Algorithms.DataAnalysis { … … 34 36 [StorableClass] 35 37 [Item("RandomForestModel", "Represents a random forest for regression and classification.")] 36 public sealed class RandomForestModel : NamedItem, IRandomForestModel {38 public sealed class RandomForestModel : ClassificationModel, IRandomForestModel { 37 39 // not persisted 38 40 private alglib.decisionforest randomForest; … … 45 47 } 46 48 49 public override IEnumerable<string> VariablesUsedForPrediction { 50 get { return originalTrainingData.AllowedInputVariables; } 51 } 52 53 public int NumberOfTrees { 54 get { return nTrees; } 55 } 56 47 57 // instead of storing the data of the model itself 48 58 // we instead only store data necessary to recalculate the same model lazily on demand … … 59 69 [Storable] 60 70 private double m; 61 62 71 63 72 [StorableConstructor] … … 91 100 92 101 // random forest models can only be created through the static factory methods CreateRegressionModel and CreateClassificationModel 93 private RandomForestModel( alglib.decisionforest randomForest,102 private RandomForestModel(string targetVariable, alglib.decisionforest randomForest, 94 103 int seed, IDataAnalysisProblemData originalTrainingData, 95 104 int nTrees, double r, double m, double[] classValues = null) 96 : base( ) {105 : base(targetVariable) { 97 106 this.name = ItemName; 98 107 this.description = 
ItemDescription; … … 130 139 131 140 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 132 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset,AllowedInputVariables, rows);141 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 133 142 AssertInputMatrix(inputData); 134 143 … … 147 156 } 148 157 149 public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 150 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 158 public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) { 159 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 160 AssertInputMatrix(inputData); 161 162 int n = inputData.GetLength(0); 163 int columns = inputData.GetLength(1); 164 double[] x = new double[columns]; 165 double[] ys = new double[this.RandomForest.innerobj.ntrees]; 166 167 for (int row = 0; row < n; row++) { 168 for (int column = 0; column < columns; column++) { 169 x[column] = inputData[row, column]; 170 } 171 alglib.dforest.dfprocessraw(RandomForest.innerobj, x, ref ys); 172 yield return ys.VariancePop(); 173 } 174 } 175 176 public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 177 double[,] inputData = dataset.ToArray(AllowedInputVariables, rows); 151 178 AssertInputMatrix(inputData); 152 179 … … 174 201 } 175 202 176 public IRandomForestRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 177 return new RandomForestRegressionSolution(new RegressionProblemData(problemData), this); 178 } 179 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 180 return CreateRegressionSolution(problemData); 181 } 182 public IRandomForestClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 183 return new RandomForestClassificationSolution(new 
ClassificationProblemData(problemData), this); 184 } 185 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { 186 return CreateClassificationSolution(problemData); 203 public ISymbolicExpressionTree ExtractTree(int treeIdx) { 204 var rf = RandomForest; 205 // hoping that the internal representation of alglib is stable 206 207 // TREE FORMAT 208 // W[Offs] - size of sub-array (for the tree) 209 // node info: 210 // W[K+0] - variable number (-1 for leaf mode) 211 // W[K+1] - threshold (class/value for leaf node) 212 // W[K+2] - ">=" branch index (absent for leaf node) 213 214 // skip irrelevant trees 215 int offset = 0; 216 for (int i = 0; i < treeIdx - 1; i++) { 217 offset = offset + (int)Math.Round(rf.innerobj.trees[offset]); 218 } 219 220 var constSy = new Constant(); 221 var varCondSy = new VariableCondition() { IgnoreSlope = true }; 222 223 var node = CreateRegressionTreeRec(rf.innerobj.trees, offset, offset + 1, constSy, varCondSy); 224 225 var startNode = new StartSymbol().CreateTreeNode(); 226 startNode.AddSubtree(node); 227 var root = new ProgramRootSymbol().CreateTreeNode(); 228 root.AddSubtree(startNode); 229 return new SymbolicExpressionTree(root); 230 } 231 232 private ISymbolicExpressionTreeNode CreateRegressionTreeRec(double[] trees, int offset, int k, Constant constSy, VariableCondition varCondSy) { 233 234 // alglib source for evaluation of one tree (dfprocessinternal) 235 // offs = 0 236 // 237 // Set pointer to the root 238 // 239 // k = offs + 1; 240 // 241 // // 242 // // Navigate through the tree 243 // // 244 // while (true) { 245 // if ((double)(df.trees[k]) == (double)(-1)) { 246 // if (df.nclasses == 1) { 247 // y[0] = y[0] + df.trees[k + 1]; 248 // } else { 249 // idx = (int)Math.Round(df.trees[k + 1]); 250 // y[idx] = y[idx] + 1; 251 // } 252 // break; 253 // } 254 // if ((double)(x[(int)Math.Round(df.trees[k])]) < (double)(df.trees[k + 1])) { 255 // k = k + innernodewidth; 256 
// } else { 257 // k = offs + (int)Math.Round(df.trees[k + 2]); 258 // } 259 // } 260 261 if ((double)(trees[k]) == (double)(-1)) { 262 var constNode = (ConstantTreeNode)constSy.CreateTreeNode(); 263 constNode.Value = trees[k + 1]; 264 return constNode; 265 } else { 266 var condNode = (VariableConditionTreeNode)varCondSy.CreateTreeNode(); 267 condNode.VariableName = AllowedInputVariables[(int)Math.Round(trees[k])]; 268 condNode.Threshold = trees[k + 1]; 269 condNode.Slope = double.PositiveInfinity; 270 271 var left = CreateRegressionTreeRec(trees, offset, k + 3, constSy, varCondSy); 272 var right = CreateRegressionTreeRec(trees, offset, offset + (int)Math.Round(trees[k + 2]), constSy, varCondSy); 273 274 condNode.AddSubtree(left); // not 100% correct because interpreter uses: if(x <= thres) left() else right() and RF uses if(x < thres) left() else right() (see above) 275 condNode.AddSubtree(right); 276 return condNode; 277 } 278 } 279 280 281 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 282 return new RandomForestRegressionSolution(this, new RegressionProblemData(problemData)); 283 } 284 public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 285 return new RandomForestClassificationSolution(this, new ClassificationProblemData(problemData)); 187 286 } 188 287 189 288 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 190 289 out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) { 191 return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError); 290 return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, 291 rmsError: out rmsError, outOfBagRmsError: out outOfBagRmsError, avgRelError: out 
avgRelError, outOfBagAvgRelError: out outOfBagAvgRelError); 192 292 } 193 293 … … 195 295 out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) { 196 296 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 197 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset,variables, trainingIndices);297 double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices); 198 298 199 299 alglib.dfreport rep; … … 201 301 202 302 rmsError = rep.rmserror; 303 outOfBagRmsError = rep.oobrmserror; 203 304 avgRelError = rep.avgrelerror; 204 305 outOfBagAvgRelError = rep.oobavgrelerror; 205 outOfBagRmsError = rep.oobrmserror; 206 207 return new RandomForestModel(dForest, seed, problemData, nTrees, r, m); 306 307 return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m); 208 308 } 209 309 210 310 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 211 311 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 212 return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError); 312 return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, 313 out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError); 213 314 } 214 315 … … 217 318 218 319 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 219 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset,variables, trainingIndices);320 double[,] inputMatrix = problemData.Dataset.ToArray(variables, trainingIndices); 220 321 221 322 var 
classValues = problemData.ClassValues.ToArray(); … … 242 343 outOfBagRelClassificationError = rep.oobrelclserror; 243 344 244 return new RandomForestModel( dForest, seed, problemData, nTrees, r, m, classValues);345 return new RandomForestModel(problemData.TargetVariable, dForest, seed, problemData, nTrees, r, m, classValues); 245 346 } 246 347 … … 269 370 270 371 private static void AssertInputMatrix(double[,] inputMatrix) { 271 if (inputMatrix.C ast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))372 if (inputMatrix.ContainsNanOrInfinity()) 272 373 throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset."); 273 374 }
Note: See TracChangeset for help on using the changeset viewer.