Index: /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs
===================================================================
--- /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs (revision 11338)
+++ /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs (revision 11339)
@@ -84,22 +84,46 @@
}
- ///
- /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
- ///
- /// This method is aimed to be lightweight and as such does not clone the dataset.
- /// The problem data
- /// The number of folds to generate
- /// A sequence of folds representing each a sequence of row numbers
- public static IEnumerable> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
- int size = problemData.TrainingPartition.Size;
- int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
- int start = 0, end = f;
- for (int i = 0; i < numberOfFolds; ++i) {
- if (r > 0) { ++end; --r; }
- yield return problemData.TrainingIndices.Skip(start).Take(end - start);
- start = end;
- end += f;
+ public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
+ var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
+ CalculateCrossValidationPartitions(partitions, parameters, out avgTestMse);
+ }
+
+ public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary> parameterRanges, int maxDegreeOfParallelism = 1) {
+ DoubleValue mse = new DoubleValue(Double.MaxValue);
+ var bestParam = DefaultParameters();
+ var crossProduct = parameterRanges.Values.CartesianProduct();
+ var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
+ var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
+ Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
+ var parameters = DefaultParameters();
+ var parameterValues = parameterCombination.ToList();
+ for (int i = 0; i < parameterValues.Count; ++i) {
+ setters[i](parameters, parameterValues[i]);
+ }
+ double testMse;
+ CalculateCrossValidationPartitions(partitions, parameters, out testMse);
+ if (testMse < mse.Value) {
+ lock (mse) { mse.Value = testMse; }
+ lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); }
+ }
+ });
+ return bestParam;
+ }
+
+ private static void CalculateCrossValidationPartitions(Tuple[] partitions, svm_parameter parameters, out double avgTestMse) {
+ avgTestMse = 0;
+ var calc = new OnlineMeanSquaredErrorCalculator();
+ foreach (Tuple tuple in partitions) {
+ var trainingSvmProblem = tuple.Item1;
+ var testSvmProblem = tuple.Item2;
+ var model = svm.svm_train(trainingSvmProblem, parameters);
+ calc.Reset();
+ for (int i = 0; i < testSvmProblem.l; ++i)
+ calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
+ avgTestMse += calc.MeanSquaredError;
}
+ avgTestMse /= partitions.Length;
}
+
private static Tuple[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
@@ -118,22 +142,21 @@
}
- public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
- var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
- CrossValidate(problemData, parameters, partitions, out avgTestMse);
- }
-
- public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, Tuple[] partitions, out double avgTestMse) {
- avgTestMse = 0;
- var calc = new OnlineMeanSquaredErrorCalculator();
- foreach (Tuple tuple in partitions) {
- var trainingSvmProblem = tuple.Item1;
- var testSvmProblem = tuple.Item2;
- var model = svm.svm_train(trainingSvmProblem, parameters);
- calc.Reset();
- for (int i = 0; i < testSvmProblem.l; ++i)
- calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
- avgTestMse += calc.MeanSquaredError;
+ ///
+ /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
+ ///
+ /// This method is aimed to be lightweight and as such does not clone the dataset.
+ /// The problem data
+ /// The number of folds to generate
+ /// A sequence of folds representing each a sequence of row numbers
+ private static IEnumerable> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
+ int size = problemData.TrainingPartition.Size;
+ int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
+ int start = 0, end = f;
+ for (int i = 0; i < numberOfFolds; ++i) {
+ if (r > 0) { ++end; --r; }
+ yield return problemData.TrainingIndices.Skip(start).Take(end - start);
+ start = end;
+ end += f;
}
- avgTestMse /= partitions.Length;
}
@@ -145,29 +168,4 @@
var setter = Expression.Lambda>(assignExp, targetExp, valueExp).Compile();
return setter;
- }
-
- public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary> parameterRanges, int maxDegreeOfParallelism = 1) {
- DoubleValue mse = new DoubleValue(Double.MaxValue);
- var bestParam = DefaultParameters();
- var pNames = parameterRanges.Keys.ToList();
- var pRanges = pNames.Select(x => parameterRanges[x]);
- var crossProduct = pRanges.CartesianProduct();
- var setters = pNames.Select(GenerateSetter).ToList();
- var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
- Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
- var list = nuple.ToList();
- var parameters = DefaultParameters();
- for (int i = 0; i < pNames.Count; ++i) {
- var s = setters[i];
- s(parameters, list[i]);
- }
- double testMse;
- CrossValidate(problemData, parameters, partitions, out testMse);
- if (testMse < mse.Value) {
- lock (mse) { mse.Value = testMse; }
- lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); }
- }
- });
- return bestParam;
}
@@ -183,4 +181,5 @@
throw new ArgumentException("Problem data is neither regression or classification problem data.");
}
+
}
}