- Timestamp:
- 09/04/14 14:16:37 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs
r11337 r11339 84 84 } 85 85 86 /// <summary> 87 /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation) 88 /// </summary> 89 /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks> 90 /// <param name="problemData">The problem data</param> 91 /// <param name="numberOfFolds">The number of folds to generate</param> 92 /// <returns>A sequence of folds representing each a sequence of row numbers</returns> 93 public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) { 94 int size = problemData.TrainingPartition.Size; 95 int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder 96 int start = 0, end = f; 97 for (int i = 0; i < numberOfFolds; ++i) { 98 if (r > 0) { ++end; --r; } 99 yield return problemData.TrainingIndices.Skip(start).Take(end - start); 100 start = end; 101 end += f; 86 public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) { 87 var partitions = GenerateSvmPartitions(problemData, numberOfFolds); 88 CalculateCrossValidationPartitions(partitions, parameters, out avgTestMse); 89 } 90 91 public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) { 92 DoubleValue mse = new DoubleValue(Double.MaxValue); 93 var bestParam = DefaultParameters(); 94 var crossProduct = parameterRanges.Values.CartesianProduct(); 95 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 96 var partitions = GenerateSvmPartitions(problemData, numberOfFolds); 97 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 98 var parameters = DefaultParameters(); 99 var parameterValues = parameterCombination.ToList(); 100 for (int i = 0; i < parameterValues.Count; ++i) { 101 setters[i](parameters, parameterValues[i]); 102 } 103 double testMse; 104 CalculateCrossValidationPartitions(partitions, parameters, out testMse); 105 if (testMse < mse.Value) { 106 lock (mse) { mse.Value = testMse; } 107 lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); } 108 } 109 }); 110 return bestParam; 111 } 112 113 private static void CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters, out double avgTestMse) { 114 avgTestMse = 0; 115 var calc = new OnlineMeanSquaredErrorCalculator(); 116 foreach (Tuple<svm_problem, svm_problem> tuple in partitions) { 117 var trainingSvmProblem = tuple.Item1; 118 var testSvmProblem = tuple.Item2; 119 var model = svm.svm_train(trainingSvmProblem, parameters); 120 calc.Reset(); 121 for (int i = 0; i < testSvmProblem.l; ++i) 122 calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i])); 123 avgTestMse += calc.MeanSquaredError; 102 124 } 125 avgTestMse /= partitions.Length; 103 126 } 127 104 128 105 129 private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) { … … 118 142 } 119 143 120 public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {121 var partitions = GenerateSvmPartitions(problemData, numberOfFolds);122 CrossValidate(problemData, parameters, partitions, out avgTestMse);123 }124 125 public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, Tuple<svm_problem, svm_problem>[] partitions, out double avgTestMse) {126 avgTestMse = 0;127 var calc = new OnlineMeanSquaredErrorCalculator();128 foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {129 var trainingSvmProblem = tuple.Item1;130 var testSvmProblem = tuple.Item2;131 var model = svm.svm_train(trainingSvmProblem, parameters);132 calc.Reset();133 for (int i = 0; i < testSvmProblem.l; ++i)134 calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));135 avgTestMse += calc.MeanSquaredError;144 /// <summary> 145 /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation) 146 /// </summary> 147 /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks> 148 /// <param name="problemData">The problem data</param> 149 /// <param name="numberOfFolds">The number of folds to generate</param> 150 /// <returns>A sequence of folds representing each a sequence of row numbers</returns> 151 private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) { 152 int size = problemData.TrainingPartition.Size; 153 int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder 154 int start = 0, end = f; 155 for (int i = 0; i < numberOfFolds; ++i) { 156 if (r > 0) { ++end; --r; } 157 yield return problemData.TrainingIndices.Skip(start).Take(end - start); 158 start = end; 159 end += f; 136 160 } 137 avgTestMse /= partitions.Length;138 161 } 139 162 … … 145 168 var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile(); 146 169 return setter; 147 }148 149 public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {150 DoubleValue mse = new DoubleValue(Double.MaxValue);151 var bestParam = DefaultParameters();152 var pNames = parameterRanges.Keys.ToList();153 var pRanges = pNames.Select(x => parameterRanges[x]);154 var crossProduct = pRanges.CartesianProduct();155 var setters = pNames.Select(GenerateSetter).ToList();156 var partitions = GenerateSvmPartitions(problemData, numberOfFolds);157 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {158 var list = nuple.ToList();159 var parameters = DefaultParameters();160 for (int i = 0; i < pNames.Count; ++i) {161 var s = setters[i];162 s(parameters, list[i]);163 }164 double testMse;165 CrossValidate(problemData, parameters, partitions, out testMse);166 if (testMse < mse.Value) {167 lock (mse) { mse.Value = testMse; }168 lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); }169 }170 });171 return bestParam;172 170 } 173 171 … … 183 181 throw new ArgumentException("Problem data is neither regression or classification problem data."); 184 182 } 183 185 184 } 186 185 }
Note: See TracChangeset
for help on using the changeset viewer.