Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Tests/HeuristicLab.Scripting-3.3/Script Sources/GridSearchSVMRegressionScriptSource.cs @ 16147

Last change on this file since 16147 was 12740, checked in by abeham, 10 years ago

#2301: merged 12292,12293 to stable

File size: 5.8 KB
RevLine 
[11514]1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Linq.Expressions;
5
6using HeuristicLab.Algorithms.DataAnalysis;
[12740]7using HeuristicLab.Common;
[11514]8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Parameters;
11using HeuristicLab.Problems.DataAnalysis;
12using HeuristicLab.Scripting;
13
14using LibSVM;
15
public class SVMRegressionCrossValidationScript : HeuristicLab.Scripting.CSharpScriptBase {
  /* Maximum degree of parallelism (specifies whether or not the grid search should be parallelized) */
  const int maximumDegreeOfParallelism = 4;

  /* Number of crossvalidation folds: */
  const int numberOfFolds = 5;

  /* Specify whether the folds should be shuffled */
  const bool shuffleFolds = false;

  /* The tunable SVM parameters (each key must be the name of a public field of svm_parameter, see GenerateGetter):
     - "svm_type" selects the regression formulation.
       Valid values: svm_parameter.NU_SVR, svm_parameter.EPSILON_SVR (the only types listed in svmTypes below)
     - "kernel_type" specifies the kernel to be used: linear, polynomial, radial basis or sigmoidal.
       Valid values: svm_parameter.LINEAR, svm_parameter.POLY, svm_parameter.RBF, svm_parameter.SIGMOID
     - "C" (penalty factor) affects the trade-off between complexity and proportion of nonseparable samples and must be selected by the user. Can have any positive value.
     - "gamma" is the kernel coefficient (used by the polynomial, radial basis and sigmoid kernels).
     - "nu" is an upper bound on the fraction of margin errors and a lower bound of the fraction of support vectors relative to the total number of training examples.
     - "eps" (epsilon) determines the level of accuracy of the approximated function. It controls the width of the epsilon-insensitive zone used to fit the training data.
       With optimal values of epsilon, the parameter C has negligible effect.
       NOTE(review): in LibSVM, svm_parameter.eps is the stopping tolerance and the epsilon-insensitive width is svm_parameter.p — confirm which one this range is meant to tune.
     - "degree" represents the degree of the polynomial kernel.
     Comment or uncomment the parameter ranges below as needed.  */

  static readonly Dictionary<string, IEnumerable<double>> svmParameterRanges = new Dictionary<string, IEnumerable<double>> {
        { "svm_type", new List<double> {svm_parameter.NU_SVR } },
        { "kernel_type", new List<double> { svm_parameter.RBF }},
        { "C", SequenceGenerator.GenerateSteps(-1m, 12, 1).Select(x => Math.Pow(2, (double)x)) },
        { "gamma", SequenceGenerator.GenerateSteps(-4m, -1, 1).Select(x => Math.Pow(2, (double)x)) },
//        { "eps", SequenceGenerator.GenerateSteps(-8m, -1, 1).Select(x => Math.Pow(2, (double)x)) },
        { "nu" , SequenceGenerator.GenerateSteps(-10m, 0, 1m).Select(x => Math.Pow(2, (double)x)) },
//        { "degree", SequenceGenerator.GenerateSteps(1m, 4, 1).Select(x => (double)x) }
  };

  /* Maps LibSVM's integer SVM type constants to the names expected by CreateSupportVectorRegressionSolution. */
  static readonly Dictionary<int, string> svmTypes = new Dictionary<int, string> {
    { svm_parameter.NU_SVR, "NU_SVR" },
    { svm_parameter.EPSILON_SVR, "EPSILON_SVR" }
  };

  /* Maps LibSVM's integer kernel type constants to the names expected by CreateSupportVectorRegressionSolution. */
  static readonly Dictionary<int, string> kernelTypes = new Dictionary<int, string> {
    { svm_parameter.LINEAR, "LINEAR" },
    { svm_parameter.POLY, "POLY" },
    { svm_parameter.RBF, "RBF" },
    { svm_parameter.SIGMOID, "SIGMOID" }
  };

  /// <summary>
  /// Runs a cross-validated grid search over <see cref="svmParameterRanges"/> and trains a
  /// regression solution with the best parameter combination found.
  /// </summary>
  /// <param name="problemData">The regression problem data to search on.</param>
  /// <param name="bestParameters">The best parameter combination returned by the grid search.</param>
  /// <param name="nSv">The number of support vectors of the trained solution.</param>
  /// <param name="cvMse">The cross-validation error (MSE, per its name) reported by the grid search for <paramref name="bestParameters"/>.</param>
  /// <returns>The solution trained with <paramref name="bestParameters"/>.</returns>
  private static SupportVectorRegressionSolution SvmGridSearch(IRegressionProblemData problemData, out svm_parameter bestParameters, out int nSv, out double cvMse) {
    bestParameters = SupportVectorMachineUtil.GridSearch(out cvMse, problemData, svmParameterRanges, numberOfFolds, shuffleFolds, maximumDegreeOfParallelism);
    // The solution factory takes the SVM and kernel types by name rather than by LibSVM's integer constants.
    string svmType = svmTypes[bestParameters.svm_type];
    string kernelType = kernelTypes[bestParameters.kernel_type];
    double trainingError, testError; // required out-arguments of the factory; not used here
    var solution = SupportVectorRegression.CreateSupportVectorRegressionSolution(problemData, problemData.AllowedInputVariables, svmType, kernelType,
                   bestParameters.C, bestParameters.nu, bestParameters.gamma, bestParameters.eps, bestParameters.degree, out trainingError, out testError, out nSv);
    return solution;
  }

  /// <summary>
  /// Script entry point. Locates the single regression problem (or regression problem data) among the
  /// script variables, runs the SVM grid search on it, prints the results, and stores the best solution
  /// and best parameters in the variables "bestSolution" and "bestParameters".
  /// </summary>
  /// <exception cref="ArgumentException">No, or more than one, suitable problem / problem data variable exists.</exception>
  public override void Main() {
    var variables = (Variables)vars;
    // Take at most two matches: enough to distinguish "none", "exactly one" and "ambiguous".
    var candidates = variables.Where(x => x.Value is IRegressionProblem || x.Value is IRegressionProblemData).Take(2).ToList();
    if (candidates.Count == 0)
      throw new ArgumentException("Could not find a suitable problem or problem data.");
    if (candidates.Count > 1)
      throw new ArgumentException("Found more than one suitable problem or problem data; keep only one of them in the variables.");

    var item = candidates[0];
    string name = item.Key;
    var problem = item.Value as IRegressionProblem;
    IRegressionProblemData problemData = problem != null ? problem.ProblemData : (IRegressionProblemData)item.Value;

    int nSv; // number of support vectors
    double cvMse;
    svm_parameter bestParameters;
    var bestSolution = SvmGridSearch(problemData, out bestParameters, out nSv, out cvMse);

    vars["bestSolution"] = bestSolution;
    // The variable name is passed as a format argument (not concatenated into the format string),
    // so braces in a variable name cannot cause a FormatException.
    Console.WriteLine("{0} parameters: C = {1}, g = {2:0.000}, eps = {3:0.000}, nu = {4:0.000}, degree = {5}",
      name, bestParameters.C, bestParameters.gamma, bestParameters.eps, bestParameters.nu, bestParameters.degree);
    Console.WriteLine("{0} best solution mse (training): {1}, mse (test): {2}", name, bestSolution.TrainingMeanSquaredError, bestSolution.TestMeanSquaredError);
    Console.WriteLine("{0} best solution R2 (training): {1}, R2 (test): {2}", name, bestSolution.TrainingRSquared, bestSolution.TestRSquared);

    // Record the best value of every field that was part of the search space.
    var bestParametersCollection = new ParameterCollection();
    foreach (var parameterName in svmParameterRanges.Keys) {
      var getter = GenerateGetter(parameterName);
      bestParametersCollection.Add(new FixedValueParameter<DoubleValue>(parameterName, new DoubleValue(getter(bestParameters))));
    }
    vars["bestParameters"] = bestParametersCollection;
  }

  /// <summary>
  /// Compiles a delegate that reads the public field named <paramref name="field"/> of an
  /// <see cref="svm_parameter"/> and converts its value to double (e.g. for int-valued fields such as svm_type).
  /// </summary>
  /// <param name="field">The name of a field of <see cref="svm_parameter"/>.</param>
  /// <returns>A getter returning the field's value as a double.</returns>
  /// <exception cref="ArgumentException">Thrown by Expression.Field when no such field exists.</exception>
  private static Func<svm_parameter, double> GenerateGetter(string field) {
    var paramExpr = Expression.Parameter(typeof(svm_parameter));
    var getterExpr = Expression.Convert(Expression.Field(paramExpr, field), typeof(double)); // cast to double
    return Expression.Lambda<Func<svm_parameter, double>>(getterExpr, paramExpr).Compile();
  }
}
[11545]107
Note: See TracBrowser for help on using the repository browser.