- Timestamp:
- 09/02/14 09:16:52 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs
r11308 r11326 95 95 /// <param name="nFolds">The number of folds to generate</param> 96 96 /// <returns>A sequence of folds representing each a sequence of row numbers</returns> 97 public static IEnumerable<IEnumerable<int>> GenerateFolds(I RegressionProblemData problemData, int nFolds) {97 public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int nFolds) { 98 98 int size = problemData.TrainingPartition.Size; 99 99 … … 108 108 } 109 109 110 /// <summary> 111 /// Performs crossvalidation 112 /// </summary> 113 /// <param name="problemData">The problem data</param> 114 /// <param name="parameters">The svm parameters</param> 115 /// <param name="folds">The svm_problem instances for each fold</param> 116 /// <param name="avgTestMSE">The average test mean squared error (not used atm)</param> 117 public static void CrossValidate(IRegressionProblemData problemData, svm_parameter parameters, IEnumerable<IEnumerable<int>> folds, out double avgTestMSE) { 118 avgTestMSE = 0; 119 110 public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numFolds, out double avgTestMse) { 111 avgTestMse = 0; 112 var folds = GenerateFolds(problemData, numFolds).ToList(); 120 113 var calc = new OnlineMeanSquaredErrorCalculator(); 121 var ds = problemData.Dataset; 122 var targetVariable = problemData.TargetVariable; 123 var inputVariables = problemData.AllowedInputVariables; 124 125 var svmProblem = CreateSvmProblem(ds, targetVariable, inputVariables, problemData.TrainingIndices); 126 var partitions = folds.ToList(); 127 128 for (int i = 0; i < partitions.Count; ++i) { 129 var test = partitions[i]; 130 var training = new List<int>(); 131 for (int j = 0; j < i; ++j) 132 training.AddRange(partitions[j]); 133 134 for (int j = i + 1; j < partitions.Count; ++j) 135 training.AddRange(partitions[j]); 136 137 var p = CreateSvmProblem(ds, targetVariable, inputVariables, training); 138 var model = svm.svm_train(p, parameters); 114 var targetVariable = GetTargetVariableName(problemData); 115 for (int i = 0; i < numFolds; ++i) { 116 int p = i; // avoid "access to modified closure" warning below 117 var training = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>()); 118 var testRows = folds[i]; 119 var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, training); 120 var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows); 121 122 var model = svm.svm_train(trainingSvmProblem, parameters); 139 123 calc.Reset(); 140 foreach (var row in test) { 141 calc.Add(svmProblem.y[row], svm.svm_predict(model, svmProblem.x[row])); 142 } 143 double error = calc.MeanSquaredError; 144 avgTestMSE += error; 145 } 146 147 avgTestMSE /= partitions.Count; 148 } 149 150 /// <summary> 151 /// Dynamically generate a setter for svm_parameter fields 152 /// </summary> 153 /// <param name="parameters"></param> 154 /// <param name="fieldName"></param> 155 /// <returns></returns> 124 for (int j = 0; j < testSvmProblem.l; ++j) 125 calc.Add(testSvmProblem.y[j], svm.svm_predict(model, testSvmProblem.x[j])); 126 avgTestMse += calc.MeanSquaredError; 127 } 128 avgTestMse /= numFolds; 129 } 130 131 public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, Tuple<svm_problem, svm_problem>[] partitions, out double avgTestMse) { 132 avgTestMse = 0; 133 var calc = new OnlineMeanSquaredErrorCalculator(); 134 foreach (Tuple<svm_problem, svm_problem> tuple in partitions) { 135 var trainingSvmProblem = tuple.Item1; 136 var testSvmProblem = tuple.Item2; 137 var model = svm.svm_train(trainingSvmProblem, parameters); 138 calc.Reset(); 139 for (int i = 0; i < testSvmProblem.l; ++i) 140 calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i])); 141 avgTestMse += calc.MeanSquaredError; 142 } 143 avgTestMse /= partitions.Length; 144 } 145 156 146 private static Action<svm_parameter, double> GenerateSetter(string fieldName) { 157 147 var targetExp = Expression.Parameter(typeof(svm_parameter)); 158 148 var valueExp = Expression.Parameter(typeof(double)); 159 160 // Expression.Property can be used here as well161 149 var fieldExp = Expression.Field(targetExp, fieldName); 162 150 var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type)); … … 165 153 } 166 154 167 public static svm_parameter GridSearch(I RegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {155 public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) { 168 156 DoubleValue mse = new DoubleValue(Double.MaxValue); 169 157 var bestParam = DefaultParameters(); 170 158 171 159 // search for C, gamma and epsilon parameter combinations 172 173 160 var pNames = parameterRanges.Keys.ToList(); 174 161 var pRanges = pNames.Select(x => parameterRanges[x]); … … 176 163 var crossProduct = pRanges.CartesianProduct(); 177 164 var setters = pNames.Select(GenerateSetter).ToList(); 165 var folds = GenerateFolds(problemData, numberOfFolds).ToList(); 166 167 var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds]; 168 var targetVariable = GetTargetVariableName(problemData); 169 170 for (int i = 0; i < numberOfFolds; ++i) { 171 int p = i; // avoid "access to modified closure" warning below 172 var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>()); 173 var testRows = folds[i]; 174 var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows); 175 var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows); 176 partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem); 177 } 178 178 179 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => { 179 180 // foreach (var nuple in crossProduct) { … … 184 185 s(parameters, list[i]); 185 186 } 186 double testMSE; 187 CrossValidate(problemData, parameters, folds, out testMSE); 188 if (testMSE < mse.Value) { 189 lock (mse) { mse.Value = testMSE; } 190 lock (bestParam) { // set best parameter values to the best found so far 191 bestParam = (svm_parameter)parameters.Clone(); 192 } 187 double testMse; 188 CrossValidate(problemData, parameters, partitions, out testMse); 189 if (testMse < mse.Value) { 190 lock (mse) { mse.Value = testMse; } 191 lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); } // set best parameter values to the best found so far 193 192 } 194 193 }); 195 194 return bestParam; 196 195 } 196 197 private static string GetTargetVariableName(IDataAnalysisProblemData problemData) { 198 var regressionProblemData = problemData as IRegressionProblemData; 199 var classificationProblemData = problemData as IClassificationProblemData; 200 201 if (regressionProblemData != null) 202 return regressionProblemData.TargetVariable; 203 if (classificationProblemData != null) 204 return classificationProblemData.TargetVariable; 205 206 throw new ArgumentException("Problem data is neither regression or classification problem data."); 207 } 197 208 } 198 209 }
Note: See TracChangeset
for help on using the changeset viewer.