Changeset 11343
- Timestamp:
- 09/04/14 17:31:46 (10 years ago)
- Location:
- trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs
r11338 r11343 189 189 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed, 190 190 out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) { 191 return CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError, problemData.TrainingIndices);192 } 193 194 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,195 out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError , IEnumerable<int> trainingIndices) {191 return CreateRegressionModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError); 192 } 193 194 public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed, 195 out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) { 196 196 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 197 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);197 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices); 198 198 199 199 alglib.dfreport rep; … … 205 205 outOfBagRmsError = rep.oobrmserror; 206 206 207 return new RandomForestModel(dForest, 208 seed, problemData, 209 nTrees, r, m); 207 return new RandomForestModel(dForest,seed, problemData,nTrees, r, m); 210 208 } 211 209 212 210 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed, 213 211 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 214 return CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError, problemData.TrainingIndices);215 } 216 217 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,218 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError , IEnumerable<int> trainingIndices) {212 return CreateClassificationModel(problemData, problemData.TrainingIndices, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError); 213 } 214 215 public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, IEnumerable<int> trainingIndices, int nTrees, double r, double m, int seed, 216 out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) { 219 217 220 218 var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); … … 244 242 outOfBagRelClassificationError = rep.oobrelclserror; 245 243 246 return new RandomForestModel(dForest, 247 seed, problemData, 248 nTrees, r, m, classValues); 244 return new RandomForestModel(dForest,seed, problemData,nTrees, r, m, classValues); 249 245 } 250 246 -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs
r11338 r11343 41 41 42 42 public static class RandomForestUtil { 43 private static Action<RFParameter, double> GenerateSetter(string field) { 44 var targetExp = Expression.Parameter(typeof(RFParameter)); 45 var valueExp = Expression.Parameter(typeof(double)); 46 var fieldExp = Expression.Field(targetExp, field); 47 var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type)); 48 var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile(); 49 return setter; 43 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) { 44 CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse); 45 } 46 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) { 47 avgTestMse = 0; 48 var ds = problemData.Dataset; 49 var targetVariable = GetTargetVariableName(problemData); 50 foreach (var tuple in partitions) { 51 double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError; 52 var trainingRandomForestPartition = tuple.Item1; 53 var testRandomForestPartition = tuple.Item2; 54 var model = RandomForestModel.CreateRegressionModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 55 var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition); 56 var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition); 57 OnlineCalculatorError calculatorError; 58 double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError); 59 if (calculatorError != OnlineCalculatorError.None) 60 mse = double.NaN; 61 avgTestMse += mse; 62 } 63 avgTestMse /= partitions.Length; 64 } 65 66 private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) { 67 CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse); 68 } 69 private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) { 70 avgTestAccuracy = 0; 71 var ds = problemData.Dataset; 72 var targetVariable = GetTargetVariableName(problemData); 73 foreach (var tuple in partitions) { 74 double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError; 75 var trainingRandomForestPartition = tuple.Item1; 76 var testRandomForestPartition = tuple.Item2; 77 var model = RandomForestModel.CreateClassificationModel(problemData, trainingRandomForestPartition, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError); 78 var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition); 79 var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition); 80 OnlineCalculatorError calculatorError; 81 double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError); 82 if (calculatorError != OnlineCalculatorError.None) 83 accuracy = double.NaN; 84 avgTestAccuracy += accuracy; 85 } 86 avgTestAccuracy /= partitions.Length; 87 } 88 89 private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 90 DoubleValue mse = new DoubleValue(Double.MaxValue); 91 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults 92 93 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 94 var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds); 95 var crossProduct = parameterRanges.Values.CartesianProduct(); 96 97 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 98 var parameterValues = parameterCombination.ToList(); 99 double testMSE; 100 var parameters = new RFParameter(); 101 for (int i = 0; i < setters.Count; ++i) { 102 setters[i](parameters, parameterValues[i]); 103 } 104 CrossValidate(problemData, partitions, parameters, seed, out testMSE); 105 if (testMSE < mse.Value) { 106 lock (mse) { 107 mse.Value = testMSE; 108 bestParameter = (RFParameter)parameters.Clone(); 109 } 110 } 111 }); 112 return bestParameter; 113 } 114 115 private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 116 DoubleValue accuracy = new DoubleValue(0); 117 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults 118 119 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 120 var crossProduct = parameterRanges.Values.CartesianProduct(); 121 var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds); 122 123 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 124 var parameterValues = parameterCombination.ToList(); 125 double testAccuracy; 126 var parameters = new RFParameter(); 127 for (int i = 0; i < setters.Count; ++i) { 128 setters[i](parameters, parameterValues[i]); 129 } 130 CrossValidate(problemData, partitions, parameters, seed, out testAccuracy); 131 if (testAccuracy > accuracy.Value) { 132 lock (accuracy) { 133 accuracy.Value = testAccuracy; 134 bestParameter = (RFParameter)parameters.Clone(); 135 } 136 } 137 }); 138 return bestParameter; 50 139 } 51 140 … … 57 146 /// <param name="numberOfFolds">The number of folds to generate</param> 58 147 /// <returns>A sequence of folds representing each a sequence of row numbers</returns> 59 p ublicstatic IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {148 private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) { 60 149 int size = problemData.TrainingPartition.Size; 61 150 int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder … … 82 171 } 83 172 84 public static void CrossValidate(IDataAnalysisProblemData problemData, int numberOfFolds, RFParameter parameters, int seed, out double error) {85 var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);86 CrossValidate(problemData, partitions, parameters, seed, out error);87 }88 173 89 // user should call the more specific CrossValidate methods 90 public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double error) { 91 CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out error); 92 } 93 94 public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double error) { 95 var regressionProblemData = problemData as IRegressionProblemData; 96 var classificationProblemData = problemData as IClassificationProblemData; 97 if (regressionProblemData != null) 98 CrossValidate(regressionProblemData, partitions, nTrees, m, r, seed, out error); 99 else if (classificationProblemData != null) 100 CrossValidate(classificationProblemData, partitions, nTrees, m, r, seed, out error); 101 else throw new ArgumentException("Problem data is neither regression or classification problem data."); 102 } 103 104 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) { 105 CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse); 106 } 107 108 private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) { 109 CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse); 110 } 111 112 private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) { 113 avgTestMse = 0; 114 var ds = problemData.Dataset; 115 var targetVariable = GetTargetVariableName(problemData); 116 foreach (var tuple in partitions) { 117 double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError; 118 var trainingRandomForestPartition = tuple.Item1; 119 var testRandomForestPartition = tuple.Item2; 120 var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition); 121 var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition); 122 var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition); 123 OnlineCalculatorError calculatorError; 124 double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError); 125 if (calculatorError != OnlineCalculatorError.None) 126 mse = double.NaN; 127 avgTestMse += mse; 128 } 129 avgTestMse /= partitions.Length; 130 } 131 132 private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) { 133 avgTestAccuracy = 0; 134 var ds = problemData.Dataset; 135 var targetVariable = GetTargetVariableName(problemData); 136 foreach (var tuple in partitions) { 137 double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError; 138 var trainingRandomForestPartition = tuple.Item1; 139 var testRandomForestPartition = tuple.Item2; 140 var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition); 141 var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition); 142 var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition); 143 OnlineCalculatorError calculatorError; 144 double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError); 145 if (calculatorError != OnlineCalculatorError.None) 146 accuracy = double.NaN; 147 avgTestAccuracy += accuracy; 148 } 149 avgTestAccuracy /= partitions.Length; 150 } 151 152 public static RFParameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 153 var regressionProblemData = problemData as IRegressionProblemData; 154 var classificationProblemData = problemData as IClassificationProblemData; 155 156 if (regressionProblemData != null) 157 return GridSearch(regressionProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism); 158 if (classificationProblemData != null) 159 return GridSearch(classificationProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism); 160 161 throw new ArgumentException("Problem data is neither regression or classification problem data."); 162 } 163 164 private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 165 DoubleValue mse = new DoubleValue(Double.MaxValue); 166 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults 167 168 var pNames = parameterRanges.Keys.ToList(); 169 var pRanges = pNames.Select(x => parameterRanges[x]); 170 var setters = pNames.Select(GenerateSetter).ToList(); 171 var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds); 172 var crossProduct = pRanges.CartesianProduct(); 173 174 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => { 175 var list = nuple.ToList(); 176 double testMSE; 177 var parameters = new RFParameter(); 178 for (int i = 0; i < pNames.Count; ++i) { 179 var s = setters[i]; 180 s(parameters, list[i]); 181 } 182 CrossValidate(problemData, partitions, parameters, seed, out testMSE); 183 if (testMSE < mse.Value) { 184 lock (mse) { mse.Value = testMSE; } 185 lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); } 186 } 187 }); 188 return bestParameter; 189 } 190 191 private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) { 192 DoubleValue accuracy = new DoubleValue(0); 193 RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults 194 195 var pNames = parameterRanges.Keys.ToList(); 196 var pRanges = pNames.Select(x => parameterRanges[x]); 197 var setters = pNames.Select(GenerateSetter).ToList(); 198 var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds); 199 var crossProduct = pRanges.CartesianProduct(); 200 201 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => { 202 var list = nuple.ToList(); 203 double testAccuracy; 204 var parameters = new RFParameter(); 205 for (int i = 0; i < pNames.Count; ++i) { 206 var s = setters[i]; 207 s(parameters, list[i]); 208 } 209 CrossValidate(problemData, partitions, parameters, seed, out testAccuracy); 210 if (testAccuracy > accuracy.Value) { 211 lock (accuracy) { accuracy.Value = testAccuracy; } 212 lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); } 213 } 214 }); 215 return bestParameter; 174 private static Action<RFParameter, double> GenerateSetter(string field) { 175 var targetExp = Expression.Parameter(typeof(RFParameter)); 176 var valueExp = Expression.Parameter(typeof(double)); 177 var fieldExp = Expression.Field(targetExp, field); 178 var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type)); 179 var setter = Expression.Lambda<Action<RFParameter, double>>(assignExp, targetExp, valueExp).Compile(); 180 return setter; 216 181 } 217 182
Note: See TracChangeset
for help on using the changeset viewer.