Changeset 12515 for branches/HiveStatistics/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
- Timestamp: 06/25/15 18:21:19 (9 years ago)
- Location: branches/HiveStatistics/sources/HeuristicLab.Algorithms.DataAnalysis
- Files: 5 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HiveStatistics/sources/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed:
  /trunk/sources/HeuristicLab.Algorithms.DataAnalysis (added) merged: 12504,12509
branches/HiveStatistics/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestClassification.cs
r12012 r12515 20 20 #endregion 21 21 22 using System;23 using System.Collections.Generic;24 using System.Linq;25 22 using HeuristicLab.Common; 26 23 using HeuristicLab.Core; … … 36 33 /// </summary> 37 34 [Item("Random Forest Classification", "Random forest classification data analysis algorithm (wrapper for ALGLIB).")] 38 [Creatable( "Data Analysis")]35 [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 120)] 39 36 [StorableClass] 40 37 public sealed class RandomForestClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> { -
branches/HiveStatistics/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r12012 r12515 129 129 } 130 130 131 public IEnumerable<double> GetEstimatedValues( Dataset dataset, IEnumerable<int> rows) {131 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 132 132 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 133 133 AssertInputMatrix(inputData); … … 147 147 } 148 148 149 public IEnumerable<double> GetEstimatedClassValues( Dataset dataset, IEnumerable<int> rows) {149 public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 150 150 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, AllowedInputVariables, rows); 151 151 AssertInputMatrix(inputData); … … 205 205 outOfBagRmsError = rep.oobrmserror; 206 206 207 return new RandomForestModel(dForest, seed, problemData,nTrees, r, m);207 return new RandomForestModel(dForest, seed, problemData, nTrees, r, m); 208 208 } 209 209 … … 242 242 outOfBagRelClassificationError = rep.oobrelclserror; 243 243 244 return new RandomForestModel(dForest, seed, problemData,nTrees, r, m, classValues);244 return new RandomForestModel(dForest, seed, problemData, nTrees, r, m, classValues); 245 245 } 246 246 -
branches/HiveStatistics/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestRegression.cs
r12012 r12515 20 20 #endregion 21 21 22 using System;23 using System.Collections.Generic;24 using System.Linq;25 22 using HeuristicLab.Common; 26 23 using HeuristicLab.Core; … … 36 33 /// </summary> 37 34 [Item("Random Forest Regression", "Random forest regression data analysis algorithm (wrapper for ALGLIB).")] 38 [Creatable( "Data Analysis")]35 [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)] 39 36 [StorableClass] 40 37 public sealed class RandomForestRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> { -
branches/HiveStatistics/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
r12012 r12515 90 90 91 91 public static class RandomForestUtil { 92 private static readonly object locker = new object();93 94 92 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) { 95 93 avgTestMse = 0; … … 132 130 } 133 131 134 // grid search without cross-validation since in the case of random forests, the out-of-bag estimate is unbiased 132 /// <summary> 133 /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased) 134 /// </summary> 135 /// <param name="problemData">The regression problem data</param> 136 /// <param name="parameterRanges">The ranges for each parameter in the grid search</param> 137 /// <param name="seed">The random seed (required by the random forest model)</param> 138 /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param> 135 139 public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 136 140 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); … … 139 143 RFParameter bestParameters = new RFParameter(); 140 144 145 var locker = new object(); 141 146 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 142 147 var parameterValues = parameterCombination.ToList(); … … 156 161 } 157 162 163 /// <summary> 164 /// Grid search without crossvalidation (since for random forests the out-of-bag estimate is unbiased) 165 /// </summary> 166 /// <param name="problemData">The classification problem data</param> 167 /// <param name="parameterRanges">The ranges for each parameter in the grid search</param> 168 /// <param name="seed">The random seed (required by the random forest model)</param> 169 
/// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param> 158 170 public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 159 171 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); … … 163 175 RFParameter bestParameters = new RFParameter(); 164 176 177 var locker = new object(); 165 178 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 166 179 var parameterValues = parameterCombination.ToList(); … … 181 194 } 182 195 196 /// <summary> 197 /// Grid search with crossvalidation 198 /// </summary> 199 /// <param name="problemData">The regression problem data</param> 200 /// <param name="numberOfFolds">The number of folds for crossvalidation</param> 201 /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param> 202 /// <param name="parameterRanges">The ranges for each parameter in the grid search</param> 203 /// <param name="seed">The random seed (required by the random forest model)</param> 204 /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param> 205 /// <returns>The best parameter values found by the grid search</returns> 183 206 public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 184 207 DoubleValue mse = new DoubleValue(Double.MaxValue); … … 189 212 var crossProduct = parameterRanges.Values.CartesianProduct(); 190 213 214 var locker = new object(); 191 215 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 192 216 var parameterValues = 
parameterCombination.ToList(); … … 208 232 } 209 233 234 /// <summary> 235 /// Grid search with crossvalidation 236 /// </summary> 237 /// <param name="problemData">The classification problem data</param> 238 /// <param name="numberOfFolds">The number of folds for crossvalidation</param> 239 /// <param name="shuffleFolds">Specifies whether the folds should be shuffled</param> 240 /// <param name="parameterRanges">The ranges for each parameter in the grid search</param> 241 /// <param name="seed">The random seed (for shuffling)</param> 242 /// <param name="maxDegreeOfParallelism">The maximum allowed number of threads (to parallelize the grid search)</param> 210 243 public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 211 244 DoubleValue accuracy = new DoubleValue(0); … … 216 249 var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds); 217 250 251 var locker = new object(); 218 252 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 219 253 var parameterValues = parameterCombination.ToList();
Note: See TracChangeset for help on using the changeset viewer.