- Timestamp:
- 09/03/14 15:12:04 (10 years ago)
- Location:
- trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
TabularUnified trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs ¶
r11171 r11337 142 142 143 143 public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables, 144 string svmType, string kernelType, double cost, double nu, double gamma, int degree, 145 out double trainingAccuracy, out double testAccuracy, out int nSv) { 144 string svmType, string kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) { 145 return CreateSupportVectorClassificationSolution(problemData, allowedInputVariables, GetKernelType(svmType), GetKernelType(kernelType), cost, nu, gamma, degree, 146 out trainingAccuracy, out testAccuracy, out nSv); 147 } 148 149 public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables, 150 int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) { 146 151 Dataset dataset = problemData.Dataset; 147 152 string targetVariable = problemData.TargetVariable; … … 150 155 //extract SVM parameters from scope and set them 151 156 svm_parameter parameter = new svm_parameter(); 152 parameter.svm_type = GetSvmType(svmType);153 parameter.kernel_type = GetKernelType(kernelType);157 parameter.svm_type = svmType; 158 parameter.kernel_type = kernelType; 154 159 parameter.C = cost; 155 160 parameter.nu = nu; … … 161 166 parameter.shrinking = 1; 162 167 parameter.coef0 = 0; 163 164 168 165 169 var weightLabels = new List<int>(); -
TabularUnified trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs ¶
r11326 r11337 39 39 /// <returns>A problem data type that can be used to train a support vector machine.</returns> 40 40 public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) { 41 double[] targetVector = 42 dataset.GetDoubleValues(targetVariable, rowIndices).ToArray(); 43 41 double[] targetVector = dataset.GetDoubleValues(targetVariable, rowIndices).ToArray(); 44 42 svm_node[][] nodes = new svm_node[targetVector.Length][]; 45 List<svm_node> tempRow;46 43 int maxNodeIndex = 0; 47 44 int svmProblemRowIndex = 0; 48 45 List<string> inputVariablesList = inputVariables.ToList(); 49 46 foreach (int row in rowIndices) { 50 tempRow = new List<svm_node>();47 List<svm_node> tempRow = new List<svm_node>(); 51 48 int colIndex = 1; // make sure the smallest node index for SVM = 1 52 49 foreach (var inputVariable in inputVariablesList) { … … 62 59 nodes[svmProblemRowIndex++] = tempRow.ToArray(); 63 60 } 64 65 return new svm_problem() { l = targetVector.Length, y = targetVector, x = nodes }; 61 return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes }; 66 62 } 67 63 … … 89 85 90 86 /// <summary> 91 /// Generate a collection of training indices corresponding to folds in the data (used for crossvalidation)87 /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation) 92 88 /// </summary> 93 89 /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks> 94 90 /// <param name="problemData">The problem data</param> 95 /// <param name="n Folds">The number of folds to generate</param>91 /// <param name="numberOfFolds">The number of folds to generate</param> 96 92 /// <returns>A sequence of folds representing each a sequence of row numbers</returns> 97 public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int n Folds) {93 public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) { 98 94 int size = problemData.TrainingPartition.Size; 99 100 int foldSize = size / nFolds; // rounding to integer 101 var trainingIndices = problemData.TrainingIndices; 102 103 for (int i = 0; i < nFolds; ++i) { 104 int n = i * foldSize; 105 int s = n + 2 * foldSize > size ? foldSize + size % foldSize : foldSize; 106 yield return trainingIndices.Skip(n).Take(s); 95 int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder 96 int start = 0, end = f; 97 for (int i = 0; i < numberOfFolds; ++i) { 98 if (r > 0) { ++end; --r; } 99 yield return problemData.TrainingIndices.Skip(start).Take(end - start); 100 start = end; 101 end += f; 107 102 } 108 103 } 109 104 110 public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numFolds, out double avgTestMse) { 111 avgTestMse = 0; 112 var folds = GenerateFolds(problemData, numFolds).ToList(); 113 var calc = new OnlineMeanSquaredErrorCalculator(); 105 private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) { 106 var folds = GenerateFolds(problemData, numberOfFolds).ToList(); 114 107 var targetVariable = GetTargetVariableName(problemData); 115 for (int i = 0; i < numFolds; ++i) { 108 var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds]; 109 for (int i = 0; i < numberOfFolds; ++i) { 116 110 int p = i; // avoid "access to modified closure" warning below 117 var training = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());111 var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>()); 118 112 var testRows = folds[i]; 119 var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, training );113 var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows); 120 114 var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows); 115 partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem); 116 } 117 return partitions; 118 } 121 119 122 var model = svm.svm_train(trainingSvmProblem, parameters); 123 calc.Reset(); 124 for (int j = 0; j < testSvmProblem.l; ++j) 125 calc.Add(testSvmProblem.y[j], svm.svm_predict(model, testSvmProblem.x[j])); 126 avgTestMse += calc.MeanSquaredError; 127 } 128 avgTestMse /= numFolds; 120 public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) { 121 var partitions = GenerateSvmPartitions(problemData, numberOfFolds); 122 CrossValidate(problemData, parameters, partitions, out avgTestMse); 129 123 } 130 124 … … 156 150 DoubleValue mse = new DoubleValue(Double.MaxValue); 157 151 var bestParam = DefaultParameters(); 158 159 // search for C, gamma and epsilon parameter combinations160 152 var pNames = parameterRanges.Keys.ToList(); 161 153 var pRanges = pNames.Select(x => parameterRanges[x]); 162 163 154 var crossProduct = pRanges.CartesianProduct(); 164 155 var setters = pNames.Select(GenerateSetter).ToList(); 165 var folds = GenerateFolds(problemData, numberOfFolds).ToList(); 166 167 var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds]; 168 var targetVariable = GetTargetVariableName(problemData); 169 170 for (int i = 0; i < numberOfFolds; ++i) { 171 int p = i; // avoid "access to modified closure" warning below 172 var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>()); 173 var testRows = folds[i]; 174 var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows); 175 var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows); 176 partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem); 177 } 178 156 var partitions = GenerateSvmPartitions(problemData, numberOfFolds); 179 157 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => { 180 // foreach (var nuple in crossProduct) {181 158 var list = nuple.ToList(); 182 159 var parameters = DefaultParameters(); … … 189 166 if (testMse < mse.Value) { 190 167 lock (mse) { mse.Value = testMse; } 191 lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); } // set best parameter values to the best found so far168 lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); } 192 169 } 193 170 });
Note: See TracChangeset
for help on using the changeset viewer.