[11514] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using System.Linq.Expressions;
|
---|
| 5 |
|
---|
| 6 | using HeuristicLab.Algorithms.DataAnalysis;
|
---|
[12292] | 7 | using HeuristicLab.Common;
|
---|
[11514] | 8 | using HeuristicLab.Core;
|
---|
| 9 | using HeuristicLab.Data;
|
---|
| 10 | using HeuristicLab.Parameters;
|
---|
| 11 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 12 | using HeuristicLab.Scripting;
|
---|
| 13 |
|
---|
| 14 | using LibSVM;
|
---|
| 15 |
|
---|
/// <summary>
/// HeuristicLab C# script that runs a cross-validated grid search over SVM parameters for a
/// classification problem found in the script's variable store, then stores the resulting
/// solution ("bestSolution") and the selected parameter values ("bestParameters") back into it.
/// </summary>
public class SVMClassificationCrossValidationScript : HeuristicLab.Scripting.CSharpScriptBase {
  /* Maximum degree of parallelism (specifies whether or not the grid search should be parallelized) */
  private const int maximumDegreeOfParallelism = 4;

  /* Number of crossvalidation folds: */
  private const int numberOfFolds = 5;

  /* Specify whether the folds should be shuffled */
  private const bool shuffleFolds = false;

  /* The tunable SVM parameters:
     - "C" (penalty factor) affects the trade-off between complexity and proportion of nonseparable samples and must be selected by the user. Can have any positive value.
     - "nu" is an upper bound on the fraction of margin errors and a lower bound of the fraction of support vectors relative to the total number of training examples.
     - "eps" (epsilon) determines the level of accuracy of the approximated function. It controls the width of the epsilon-insensitive zone used to fit the training data.
       With optimal values of epsilon, the parameter C has negligible effect.
       NOTE(review): in LIBSVM, svm_parameter.eps is the stopping tolerance, while the epsilon-insensitive zone is set by svm_parameter.p - confirm which one is meant here.
     - "degree" represents the degree of the polynomial kernel
     - "kernel_type" specifies the kernel to be used: linear, polynomial, radial basis or sigmoidal.
       Valid values: svm_parameter.LINEAR, svm_parameter.POLY, svm_parameter.RBF, svm_parameter.SIGMOID
     Comment or uncomment the parameter ranges below as needed.

     Every key must be the exact name of a field of svm_parameter: Main() reads the selected
     value back from the best parameter set by field name (see GenerateGetter). */
  private static readonly Dictionary<string, IEnumerable<double>> svmParameterRanges = new Dictionary<string, IEnumerable<double>> {
    { "svm_type", new List<double> { svm_parameter.NU_SVC } },
    { "kernel_type", new List<double> { svm_parameter.RBF } },
    { "C", SequenceGenerator.GenerateSteps(-1m, 10, 1).Select(x => Math.Pow(2, (double)x)) },
    { "gamma", SequenceGenerator.GenerateSteps(-4m, -1, 1).Select(x => Math.Pow(2, (double)x)) },
    // { "eps", SequenceGenerator.GenerateSteps(-8m, -1, 1).Select(x => Math.Pow(2, (double)x)) },
    { "nu", SequenceGenerator.GenerateSteps(-10m, 0, 1m).Select(x => Math.Pow(2, (double)x)) },
    // { "degree", SequenceGenerator.GenerateSteps(1m, 4, 1).Select(x => (double)x) }
  };

  /* Maps the svm_type value of the best parameter set to the name passed to the solution factory.
     Only the types listed here can be looked up; any other svm_type fails with KeyNotFoundException. */
  private static readonly Dictionary<int, string> svmTypes = new Dictionary<int, string> {
    { svm_parameter.NU_SVC, "NU_SVC" },
    { svm_parameter.C_SVC, "C_SVC" }
  };

  /* Maps the kernel_type value of the best parameter set to the name passed to the solution factory. */
  private static readonly Dictionary<int, string> kernelTypes = new Dictionary<int, string> {
    { svm_parameter.LINEAR, "LINEAR" },
    { svm_parameter.POLY, "POLY" },
    { svm_parameter.RBF, "RBF" },
    { svm_parameter.SIGMOID, "SIGMOID" }
  };

  /// <summary>
  /// Runs the grid search over <c>svmParameterRanges</c> and builds a classification solution
  /// from the parameter set the search returns.
  /// </summary>
  /// <param name="problemData">Problem data used for the grid search and for creating the solution.</param>
  /// <param name="bestParameters">Receives the parameter set returned by the grid search.</param>
  /// <param name="nSv">Receives the number of support vectors reported when the solution is created.</param>
  /// <param name="cvMse">Receives the cross-validation error value reported by the grid search.</param>
  /// <returns>The solution created with the best parameters.</returns>
  private static SupportVectorClassificationSolution SvmGridSearch(IClassificationProblemData problemData, out svm_parameter bestParameters, out int nSv, out double cvMse) {
    bestParameters = SupportVectorMachineUtil.GridSearch(out cvMse, problemData, svmParameterRanges, numberOfFolds, shuffleFolds, maximumDegreeOfParallelism);

    // The solution factory takes the SVM and kernel type by name, so translate the numeric codes.
    string svmType = svmTypes[bestParameters.svm_type];
    string kernelType = kernelTypes[bestParameters.kernel_type];

    // The training and test errors are required out arguments of the factory but are not used here.
    double trainingError, testError;
    return SupportVectorClassification.CreateSupportVectorClassificationSolution(problemData, problemData.AllowedInputVariables, svmType, kernelType,
      bestParameters.C, bestParameters.nu, bestParameters.gamma, bestParameters.degree, out trainingError, out testError, out nSv);
  }

  /// <summary>
  /// Script entry point: locates the classification problem (or problem data) in the variable
  /// store, runs the grid search, prints the result and stores "bestSolution" and
  /// "bestParameters" in the variable store.
  /// </summary>
  /// <exception cref="ArgumentException">No variable holds a classification problem or problem data.</exception>
  public override void Main() {
    var variables = (Variables)vars;

    // SingleOrDefault yields the default pair when nothing matches; with the standard LINQ
    // semantics it throws InvalidOperationException when more than one variable matches.
    var item = variables.SingleOrDefault(x => x.Value is IClassificationProblem || x.Value is IClassificationProblemData);
    if (item.Equals(default(KeyValuePair<string, object>))) {
      throw new ArgumentException("Could not find a suitable problem or problem data.");
    }

    string name = item.Key;
    IClassificationProblemData problemData;
    var problem = item.Value as IClassificationProblem;
    if (problem != null) {
      problemData = problem.ProblemData;
    } else {
      problemData = (IClassificationProblemData)item.Value;
    }

    int nSv; // number of support vectors
    double cvMse;
    svm_parameter bestParameters;
    var bestSolution = SvmGridSearch(problemData, out bestParameters, out nSv, out cvMse);

    vars["bestSolution"] = bestSolution;

    // The variable name is passed as a format argument instead of being concatenated into the
    // format string, so a name containing '{' or '}' cannot break the composite format.
    Console.WriteLine("{0} parameters: C = {1}, g = {2:0.000}, eps = {3:0.000}, nu = {4:0.000}, degree = {5}", name, bestParameters.C, bestParameters.gamma, bestParameters.eps, bestParameters.nu, bestParameters.degree);
    Console.WriteLine(name + " best solution accuracy (training): " + bestSolution.TrainingAccuracy + ", accuracy (test): " + bestSolution.TestAccuracy);

    // Publish the chosen value of every searched parameter, read from the best parameter set by field name.
    var bestParametersCollection = new ParameterCollection();
    foreach (var p in svmParameterRanges.Keys) {
      var getter = GenerateGetter(p);
      bestParametersCollection.Add(new FixedValueParameter<DoubleValue>(p, new DoubleValue(getter(bestParameters))));
    }
    vars["bestParameters"] = bestParametersCollection;
  }

  /// <summary>
  /// Compiles a delegate that reads the named field of an <c>svm_parameter</c> instance and
  /// converts its value to <see cref="double"/>.
  /// </summary>
  /// <param name="field">Name of the field to read.</param>
  /// <returns>A delegate returning the field value as a double.</returns>
  private static Func<svm_parameter, double> GenerateGetter(string field) {
    var paramExpr = Expression.Parameter(typeof(svm_parameter));
    var getterExpr = Expression.Convert(Expression.Field(paramExpr, field), typeof(double)); // cast to double
    return Expression.Lambda<Func<svm_parameter, double>>(getterExpr, paramExpr).Compile();
  }
}
|
---|
[11545] | 106 |
|
---|