- Timestamp: 09/03/14 15:15:41 (10 years ago)
- Location: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest
- Files: 2 edited
Legend:
- Unmodified lines carry no prefix
- Added lines are prefixed with +
- Removed lines are prefixed with -
- … marks omitted unchanged lines
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModel.cs (r11315 → r11338)

   public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
-    out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError) {
+    out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError) {
     return CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError, problemData.TrainingIndices);
   }

   public static RandomForestModel CreateRegressionModel(IRegressionProblemData problemData, int nTrees, double r, double m, int seed,
-    out double rmsError, out double avgRelError, out double outOfBagAvgRelError, out double outOfBagRmsError, IEnumerable<int> trainingIndices) {
+    out double rmsError, out double outOfBagRmsError, out double avgRelError, out double outOfBagAvgRelError, IEnumerable<int> trainingIndices) {
     var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
     double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
…
   public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
     out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError) {
+    return CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out outOfBagRmsError, out relClassificationError, out outOfBagRelClassificationError, problemData.TrainingIndices);
+  }
+
+  public static RandomForestModel CreateClassificationModel(IClassificationProblemData problemData, int nTrees, double r, double m, int seed,
+    out double rmsError, out double outOfBagRmsError, out double relClassificationError, out double outOfBagRelClassificationError, IEnumerable<int> trainingIndices) {

     var variables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable });
-    double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, problemData.TrainingIndices);
+    double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(problemData.Dataset, variables, trainingIndices);

     var classValues = problemData.ClassValues.ToArray();
…
   private static void AssertInputMatrix(double[,] inputMatrix) {
-    if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
+    if (inputMatrix.Cast<double>().Any(x => Double.IsNaN(x) || Double.IsInfinity(x)))
       throw new NotSupportedException("Random forest modeling does not support NaN or infinity values in the input dataset.");
   }
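Note (not part of the changeset): the classification API now mirrors the regression API by accepting an explicit set of training row indices, which the reworked cross-validation code in RandomForestUtil.cs relies on. A minimal, hedged usage sketch of the new overload; the helper name, problemData and all numeric values are assumptions, only the method signature comes from the diff above:

  // Hypothetical helper illustrating the new overload; parameter values are arbitrary examples.
  static RandomForestModel BuildOnSubset(IClassificationProblemData problemData, IEnumerable<int> trainingIndices) {
    double rmsError, outOfBagRmsError, relClassificationError, outOfBagRelClassificationError;
    // nTrees = 50, r = 0.3, m = 0.5, seed = 12345 are placeholder values.
    return RandomForestModel.CreateClassificationModel(problemData, 50, 0.3, 0.5, 12345,
      out rmsError, out outOfBagRmsError,
      out relClassificationError, out outOfBagRelClassificationError,
      trainingIndices);
  }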
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestUtil.cs (r11315 → r11338)

     var targetExp = Expression.Parameter(typeof(RFParameter));
     var valueExp = Expression.Parameter(typeof(double));
-
-    // Expression.Property can be used here as well
     var fieldExp = Expression.Field(targetExp, field);
     var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
…
   /// <summary>
-  /// Generate a collection of training indices corresponding to folds in the data (used for crossvalidation)
+  /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
   /// </summary>
   /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
   /// <param name="problemData">The problem data</param>
-  /// <param name="nFolds">The number of folds to generate</param>
+  /// <param name="numberOfFolds">The number of folds to generate</param>
   /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
-  public static IEnumerable<IEnumerable<int>> GenerateFolds(IRegressionProblemData problemData, int nFolds) {
+  public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
     int size = problemData.TrainingPartition.Size;
-
-    int foldSize = size / nFolds; // rounding to integer
-    var trainingIndices = problemData.TrainingIndices;
-
-    for (int i = 0; i < nFolds; ++i) {
-      int n = i * foldSize;
-      int s = n + 2 * foldSize > size ? foldSize + size % foldSize : foldSize;
-      yield return trainingIndices.Skip(n).Take(s);
-    }
-  }
-
-  public static void CrossValidate(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, RFParameter parameter, int seed, out double avgTestMSE) {
-    CrossValidate(problemData, folds, (int)Math.Round(parameter.n), parameter.m, parameter.r, seed, out avgTestMSE);
-  }
-
-  public static void CrossValidate(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, int nTrees, double m, double r, int seed, out double avgTestMSE) {
-    avgTestMSE = 0;
-
+    int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
+    int start = 0, end = f;
+    for (int i = 0; i < numberOfFolds; ++i) {
+      if (r > 0) { ++end; --r; }
+      yield return problemData.TrainingIndices.Skip(start).Take(end - start);
+      start = end;
+      end += f;
+    }
+  }
+
+  private static Tuple<IEnumerable<int>, IEnumerable<int>>[] GenerateRandomForestPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
+    var folds = GenerateFolds(problemData, numberOfFolds).ToList();
+    var partitions = new Tuple<IEnumerable<int>, IEnumerable<int>>[numberOfFolds];
+
+    for (int i = 0; i < numberOfFolds; ++i) {
+      int p = i; // avoid "access to modified closure" warning
+      var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
+      var testRows = folds[i];
+      partitions[i] = new Tuple<IEnumerable<int>, IEnumerable<int>>(trainingRows, testRows);
+    }
+    return partitions;
+  }
+
+  public static void CrossValidate(IDataAnalysisProblemData problemData, int numberOfFolds, RFParameter parameters, int seed, out double error) {
+    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
+    CrossValidate(problemData, partitions, parameters, seed, out error);
+  }
+
+  // user should call the more specific CrossValidate methods
+  public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double error) {
+    CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out error);
+  }
+
+  public static void CrossValidate(IDataAnalysisProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double error) {
+    var regressionProblemData = problemData as IRegressionProblemData;
+    var classificationProblemData = problemData as IClassificationProblemData;
+    if (regressionProblemData != null)
+      CrossValidate(regressionProblemData, partitions, nTrees, m, r, seed, out error);
+    else if (classificationProblemData != null)
+      CrossValidate(classificationProblemData, partitions, nTrees, m, r, seed, out error);
+    else throw new ArgumentException("Problem data is neither regression or classification problem data.");
+  }
+
+  private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
+    CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
+  }
+
+  private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, RFParameter parameters, int seed, out double avgTestMse) {
+    CrossValidate(problemData, partitions, (int)parameters.n, parameters.m, parameters.r, seed, out avgTestMse);
+  }
+
+  private static void CrossValidate(IRegressionProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestMse) {
+    avgTestMse = 0;
     var ds = problemData.Dataset;
-    var targetVariable = problemData.TargetVariable;
-
-    var partitions = folds.ToList();
-
-    for (int i = 0; i < partitions.Count; ++i) {
+    var targetVariable = GetTargetVariableName(problemData);
+    foreach (var tuple in partitions) {
       double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
-      var test = partitions[i];
-      var training = new List<int>();
-      for (int j = 0; j < i; ++j)
-        training.AddRange(partitions[j]);
-
-      for (int j = i + 1; j < partitions.Count; ++j)
-        training.AddRange(partitions[j]);
-
-      var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, m, r, seed, out rmsError, out avgRelError, out outOfBagAvgRelError, out outOfBagRmsError, training);
-      var estimatedValues = model.GetEstimatedValues(ds, test);
-      var outputValues = ds.GetDoubleValues(targetVariable, test);
-
+      var trainingRandomForestPartition = tuple.Item1;
+      var testRandomForestPartition = tuple.Item2;
+      var model = RandomForestModel.CreateRegressionModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
+      var estimatedValues = model.GetEstimatedValues(ds, testRandomForestPartition);
+      var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
       OnlineCalculatorError calculatorError;
-      double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, outputValues, out calculatorError);
+      double mse = OnlineMeanSquaredErrorCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
       if (calculatorError != OnlineCalculatorError.None)
         mse = double.NaN;
-      avgTestMSE += mse;
-    }
-
-    avgTestMSE /= partitions.Count;
-  }
-
-  public static RFParameter GridSearch(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
+      avgTestMse += mse;
+    }
+    avgTestMse /= partitions.Length;
+  }
+
+  private static void CrossValidate(IClassificationProblemData problemData, Tuple<IEnumerable<int>, IEnumerable<int>>[] partitions, int nTrees, double r, double m, int seed, out double avgTestAccuracy) {
+    avgTestAccuracy = 0;
+    var ds = problemData.Dataset;
+    var targetVariable = GetTargetVariableName(problemData);
+    foreach (var tuple in partitions) {
+      double rmsError, avgRelError, outOfBagAvgRelError, outOfBagRmsError;
+      var trainingRandomForestPartition = tuple.Item1;
+      var testRandomForestPartition = tuple.Item2;
+      var model = RandomForestModel.CreateClassificationModel(problemData, nTrees, r, m, seed, out rmsError, out avgRelError, out outOfBagRmsError, out outOfBagAvgRelError, trainingRandomForestPartition);
+      var estimatedValues = model.GetEstimatedClassValues(ds, testRandomForestPartition);
+      var targetValues = ds.GetDoubleValues(targetVariable, testRandomForestPartition);
+      OnlineCalculatorError calculatorError;
+      double accuracy = OnlineAccuracyCalculator.Calculate(estimatedValues, targetValues, out calculatorError);
+      if (calculatorError != OnlineCalculatorError.None)
+        accuracy = double.NaN;
+      avgTestAccuracy += accuracy;
+    }
+    avgTestAccuracy /= partitions.Length;
+  }
+
+  public static RFParameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
+    var regressionProblemData = problemData as IRegressionProblemData;
+    var classificationProblemData = problemData as IClassificationProblemData;
+
+    if (regressionProblemData != null)
+      return GridSearch(regressionProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
+    if (classificationProblemData != null)
+      return GridSearch(classificationProblemData, numberOfFolds, parameterRanges, seed, maxDegreeOfParallelism);
+
+    throw new ArgumentException("Problem data is neither regression or classification problem data.");
+  }
+
+  private static RFParameter GridSearch(IRegressionProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
     DoubleValue mse = new DoubleValue(Double.MaxValue);
     RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
…
     var pRanges = pNames.Select(x => parameterRanges[x]);
     var setters = pNames.Select(GenerateSetter).ToList();
-
+    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
     var crossProduct = pRanges.CartesianProduct();

…
         s(parameters, list[i]);
       }
-      CrossValidate(problemData, folds, parameters, seed, out testMSE);
+      CrossValidate(problemData, partitions, parameters, seed, out testMSE);
       if (testMSE < mse.Value) {
-        lock (mse) {
-          mse.Value = testMSE;
-        }
-        lock (bestParameter) {
-          bestParameter = (RFParameter)parameters.Clone();
-        }
+        lock (mse) { mse.Value = testMSE; }
+        lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); }
       }
     });
     return bestParameter;
   }
+
+  private static RFParameter GridSearch(IClassificationProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int seed = 12345, int maxDegreeOfParallelism = 1) {
+    DoubleValue accuracy = new DoubleValue(0);
+    RFParameter bestParameter = new RFParameter { n = 1, m = 0.1, r = 0.1 }; // some random defaults
+
+    var pNames = parameterRanges.Keys.ToList();
+    var pRanges = pNames.Select(x => parameterRanges[x]);
+    var setters = pNames.Select(GenerateSetter).ToList();
+    var partitions = GenerateRandomForestPartitions(problemData, numberOfFolds);
+    var crossProduct = pRanges.CartesianProduct();
+
+    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
+      var list = nuple.ToList();
+      double testAccuracy;
+      var parameters = new RFParameter();
+      for (int i = 0; i < pNames.Count; ++i) {
+        var s = setters[i];
+        s(parameters, list[i]);
+      }
+      CrossValidate(problemData, partitions, parameters, seed, out testAccuracy);
+      if (testAccuracy > accuracy.Value) {
+        lock (accuracy) { accuracy.Value = testAccuracy; }
+        lock (bestParameter) { bestParameter = (RFParameter)parameters.Clone(); }
+      }
+    });
+    return bestParameter;
+  }
+
+  private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
+    var regressionProblemData = problemData as IRegressionProblemData;
+    var classificationProblemData = problemData as IClassificationProblemData;
+
+    if (regressionProblemData != null)
+      return regressionProblemData.TargetVariable;
+    if (classificationProblemData != null)
+      return classificationProblemData.TargetVariable;
+
+    throw new ArgumentException("Problem data is neither regression or classification problem data.");
+  }
 }
}
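Note (not part of the changeset): the rewritten GenerateFolds spreads the remainder over the first folds; for example, 10 training rows split into 3 folds gives f = 3 and r = 1, so the fold sizes are 4, 3 and 3. Below is a hedged sketch of how the new problem-type-agnostic GridSearch entry point might be called. The containing class is assumed to be RandomForestUtil (matching the file name); problemData, the fold count and the candidate ranges are assumed values, and the keys "n", "m" and "r" address the RFParameter fields set via the compiled setters from GenerateSetter:

  // Hypothetical helper, not part of the changeset.
  static RFParameter TuneRandomForest(IDataAnalysisProblemData problemData) {
    // Candidate values per RFParameter field; "n" is cast to int (number of trees)
    // inside CrossValidate, while "r" and "m" remain doubles.
    var parameterRanges = new Dictionary<string, IEnumerable<double>> {
      { "n", new double[] { 50, 100, 200 } },
      { "r", new double[] { 0.1, 0.3, 0.5 } },
      { "m", new double[] { 0.2, 0.5, 1.0 } }
    };
    // problemData may be regression or classification data; the public GridSearch
    // overload dispatches to the matching private implementation and returns the
    // parameter set with the best cross-validated MSE or accuracy.
    return RandomForestUtil.GridSearch(problemData, 5, parameterRanges, 12345, 4);
  }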