- Timestamp:
- 09/14/14 16:15:28 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
r11343 r11362 28 28 using System.Threading.Tasks; 29 29 using HeuristicLab.Common; 30 using HeuristicLab.Core; 30 31 using HeuristicLab.Data; 31 32 using HeuristicLab.Problems.DataAnalysis; 33 using HeuristicLab.Random; 32 34 33 35 namespace HeuristicLab.Algorithms.DataAnalysis { … … 41 43 42 44 public static class RandomForestUtil { 43 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {44 CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);45 }46 45 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) { 47 46 avgTestMse = 0; … … 63 62 avgTestMse /= partitions.Length; 64 63 } 65 66 private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {67 CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);68 }69 64 private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) { 70 65 avgTestAccuracy = 0; … … 87 82 } 88 83 89 private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 84 // grid search without cross-validation since in the case of random forests, the out-of-bag estimate is unbiased 85 public static RFParameter GridSearch(IRegressionProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 86 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 87 var crossProduct = parameterRanges.Values.CartesianProduct(); 88 double bestOutOfBagRmsError = double.MaxValue; 89 RFParameter bestParameters = new RFParameter(); 90 91 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 92 var parameterValues = parameterCombination.ToList(); 93 double testMSE; 94 var parameters = new RFParameter(); 95 for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); } 96 double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError; 97 var model = RandomForestModel.CreateRegressionModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed, 98 out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 99 if (bestOutOfBagRmsError > outOfBagRmsError) { 100 lock (bestParameters) { 101 bestOutOfBagRmsError = outOfBagRmsError; 102 bestParameters = (RFParameter)parameters.Clone(); 103 } 104 } 105 }); 106 return bestParameters; 107 } 108 109 public static RFParameter GridSearch(IClassificationProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 110 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 111 var crossProduct = parameterRanges.Values.CartesianProduct(); 112 113 double bestOutOfBagRmsError = double.MaxValue; 114 RFParameter bestParameters = new RFParameter(); 115 116 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 117 var parameterValues = parameterCombination.ToList(); 118 var parameters = new RFParameter(); 119 for (int i = 0; i < setters.Count; ++i) { setters[i](parameters, parameterValues[i]); } 120 double rmsError, outOfBagRmsError, avgRelError, outOfBagAvgRelError; 121 var model = RandomForestModel.CreateClassificationModel(problemData, problemData.TrainingIndices, (int)parameters.n, parameters.r, parameters.m, seed, 122 out rmsError, out outOfBagRmsError, out avgRelError, out outOfBagAvgRelError); 123 if (bestOutOfBagRmsError > outOfBagRmsError) { 124 lock (bestParameters) { 125 bestOutOfBagRmsError = outOfBagRmsError; 126 bestParameters = (RFParameter)parameters.Clone(); 127 } 128 } 129 }); 130 return bestParameters; 131 } 132 133 public static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 90 134 DoubleValue mse = new DoubleValue(Double.MaxValue); 91 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults135 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; 92 136 93 137 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); … … 102 146 setters[i](parameters, parameterValues[i]); 103 147 } 104 CrossValidate(problemData, partitions, parameters, seed, out testMSE);148 CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testMSE); 105 149 if (testMSE < mse.Value) { 106 150 lock (mse) { … … 113 157 } 114 158 115 p rivate static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {159 public static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, bool shuffleFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 116 160 DoubleValue accuracy = new DoubleValue(0); 117 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults118 119 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 120 var crossProduct = parameterRanges.Values.CartesianProduct(); 121 var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds );161 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; 162 163 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 164 var crossProduct = parameterRanges.Values.CartesianProduct(); 165 var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds, shuffleFolds); 122 166 123 167 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { … … 128 172 setters[i](parameters, parameterValues[i]); 129 173 } 130 CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);174 CrossValidate(problemData, partitions, (int)parameters.n, parameters.r, parameters.m, seed, out testAccuracy); 131 175 if (testAccuracy > accuracy.Value) { 132 176 lock (accuracy) { … … 139 183 } 140 184 141 /// <summary> 142 /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation) 143 /// </summary> 144 /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks> 145 /// <param name="problemData">The problem data</param> 146 /// <param name="numberOfFolds">The number of folds to generate</param> 147 /// <returns>A sequence of folds representing each a sequence of row numbers</returns> 148 private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) { 149 int size = problemData.TrainingPartition.Size; 150 int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder 151 int start = 0, end = f; 152 for (int i = 0; i < numberOfFolds; ++i) { 153 if (r > 0) { ++end; --r; } 154 yield return problemData.TrainingIndices.Skip(start).Take(end - start); 155 start = end; 156 end += f; 157 } 158 } 159 160 private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) { 161 var folds = GenerateFolds(problemData, numberOfFolds).ToList(); 185 private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) { 186 var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList(); 162 187 var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds]; 163 188 … … 171 196 } 172 197 198 public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = false) { 199 var random = new MersenneTwister((uint)Environment.TickCount); 200 if (problemData is IRegressionProblemData) { 201 var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices; 202 return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds); 203 } 204 if (problemData is IClassificationProblemData) { 205 // when shuffle is enabled do stratified folds generation, some folds may have zero elements 206 // otherwise, generate folds normally 207 return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds); 208 } 209 throw new ArgumentException("Problem data is neither regression or classification problem data."); 210 } 211 212 /// <summary> 213 /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold. 214 /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of 215 /// the corresponding parts from each class label. 216 /// </summary> 217 /// <param name="problemData">The classification problem data.</param> 218 /// <param name="numberOfFolds">The number of folds in which to split the data.</param> 219 /// <param name="random">The random generator used to shuffle the folds.</param> 220 /// <returns>An enumerable sequece of folds, where a fold is represented by a sequence of row indices.</returns> 221 private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) { 222 var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 223 var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList(); 224 IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds)); 225 var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList(); 226 while (enumerators.All(e => e.MoveNext())) { 227 yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList(); 228 } 229 } 230 231 private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) { 232 // if number of folds is greater than the number of values, some empty folds will be returned 233 if (valuesCount < numberOfFolds) { 234 for (int i = 0; i < numberOfFolds; ++i) 235 yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>(); 236 } else { 237 int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder 238 int start = 0, end = f; 239 for (int i = 0; i < numberOfFolds; ++i) { 240 if (r > 0) { 241 ++end; 242 --r; 243 } 244 yield return values.Skip(start).Take(end - start); 245 start = end; 246 end += f; 247 } 248 } 249 } 173 250 174 251 private static Action<RFParameter, double> GenerateSetter(string field) {
Note: See TracChangeset
for help on using the changeset viewer.