Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorRegression.cs @ 18079

Last change on this file since 18079 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 13.7 KB
Line 
1#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HEAL.Attic;
32using HeuristicLab.Problems.DataAnalysis;
33using LibSVM;
34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Support vector machine regression data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// Wrapper around libSVM: the training partition of the problem data is scaled with a
  /// <see cref="RangeTransform"/>, a nu-SVR or epsilon-SVR model is trained on it, and the
  /// resulting model is evaluated on the training and test partitions.
  /// </remarks>
  [Item("Support Vector Regression (SVM)", "Support vector machine regression data analysis algorithm (wrapper for libSVM).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 110)]
  [StorableType("645A21E5-EF07-46BF-AA04-A616165F0EF4")]
  public sealed class SupportVectorRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";
    private const string DegreeParameterName = "Degree";
    private const string CreateSolutionParameterName = "CreateSolution";

    #region parameter properties
    public IConstrainedValueParameter<StringValue> SvmTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> KernelTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    public IValueParameter<DoubleValue> NuParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    public IValueParameter<DoubleValue> CostParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    public IValueParameter<DoubleValue> GammaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    public IValueParameter<DoubleValue> EpsilonParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    public IValueParameter<IntValue> DegreeParameter {
      get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion
    #region properties
    public StringValue SvmType {
      get { return SvmTypeParameter.Value; }
      set { SvmTypeParameter.Value = value; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.Value; }
      set { KernelTypeParameter.Value = value; }
    }
    public DoubleValue Nu {
      get { return NuParameter.Value; }
    }
    public DoubleValue Cost {
      get { return CostParameter.Value; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.Value; }
    }
    public DoubleValue Epsilon {
      get { return EpsilonParameter.Value; }
    }
    public IntValue Degree {
      get { return DegreeParameter.Value; }
    }
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion
    [StorableConstructor]
    private SupportVectorRegression(StorableConstructorFlag _) : base(_) { }
    private SupportVectorRegression(SupportVectorRegression original, Cloner cloner)
      : base(original, cloner) {
    }
    /// <summary>
    /// Creates the algorithm with an empty <see cref="RegressionProblem"/> and the default settings
    /// NU_SVR, RBF kernel, nu = 0.5, C = 1.0, gamma = 1.0, epsilon = 0.1 and degree = 3.
    /// </summary>
    public SupportVectorRegression()
      : base() {
      Problem = new RegressionProblem();

      List<StringValue> svrTypes = (from type in new List<string> { "NU_SVR", "EPSILON_SVR" }
                                    select new StringValue(type).AsReadOnly())
                                   .ToList();
      ItemSet<StringValue> svrTypeSet = new ItemSet<StringValue>(svrTypes);
      List<StringValue> kernelTypes = (from type in new List<string> { "LINEAR", "POLY", "SIGMOID", "RBF" }
                                       select new StringValue(type).AsReadOnly())
                                      .ToList();
      ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);
      // defaults: svrTypes[0] is "NU_SVR", kernelTypes[3] is "RBF"
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", svrTypeSet, svrTypes[0]));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, kernelTypes[3]));
      Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter of the nu-SVR.", new DoubleValue(0.5)));
      Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of epsilon-SVR and nu-SVR.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR.", new DoubleValue(0.1)));
      Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;
    }
    /// <summary>
    /// Adds parameters that are missing from instances persisted by older versions.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region backwards compatibility (change with 3.4)

      if (!Parameters.ContainsKey(DegreeParameterName)) {
        Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName,
          "The degree parameter for the polynomial kernel function.", new IntValue(3)));
      }
      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
        Parameters[CreateSolutionParameterName].Hidden = true;
      }
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorRegression(this, cloner);
    }

    #region support vector regression
    /// <summary>
    /// Trains an SVR model on the problem's training partition. The method adds a solution
    /// (if <see cref="CreateSolution"/> is set), the number of support vectors and error metrics
    /// to the results.
    /// </summary>
    /// <param name="cancellationToken">Currently not observed; training runs to completion.</param>
    protected override void Run(CancellationToken cancellationToken) {
      IRegressionProblemData problemData = Problem.ProblemData;
      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
      int nSv;
      ISupportVectorMachineModel model;
      Run(problemData, selectedInputVariables, SvmType.Value, KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, Degree.Value, out model, out nSv);

      if (CreateSolution) {
        var solution = new SupportVectorRegressionSolution((SupportVectorMachineModel)model, (IRegressionProblemData)problemData.Clone());
        Results.Add(new Result("Support vector regression solution", "The support vector regression solution.", solution));
      }

      Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVR solution.", new IntValue(nSv)));

      AddErrorMetricResults(problemData, model);
    }

    /// <summary>
    /// Adds mean squared, mean absolute and mean relative errors of <paramref name="model"/>
    /// on the training and test partitions. A metric whose calculator reports an error is
    /// stored as double.MaxValue.
    /// </summary>
    private void AddErrorMetricResults(IRegressionProblemData problemData, ISupportVectorMachineModel model) {
      var ds = problemData.Dataset;
      var trainRows = problemData.TrainingIndices;
      var testRows = problemData.TestIndices;
      // Materialize once: each sequence is consumed by three calculators below.
      var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows).ToArray();
      var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows).ToArray();
      var yPredTrain = model.GetEstimatedValues(ds, trainRows).ToArray();
      var yPredTest = model.GetEstimatedValues(ds, testRows).ToArray();

      // The online calculators take (originalValues, estimatedValues). The order does not matter for
      // MSE and MAE. The percentage-error calculator normalizes by the original values, so the
      // target values must be passed first.
      OnlineCalculatorError error;
      var trainMse = OnlineMeanSquaredErrorCalculator.Calculate(yTrain, yPredTrain, out error);
      if (error != OnlineCalculatorError.None) trainMse = double.MaxValue;
      var testMse = OnlineMeanSquaredErrorCalculator.Calculate(yTest, yPredTest, out error);
      if (error != OnlineCalculatorError.None) testMse = double.MaxValue;

      Results.Add(new Result("Mean squared error (training)", "The mean of squared errors of the SVR solution on the training partition.", new DoubleValue(trainMse)));
      Results.Add(new Result("Mean squared error (test)", "The mean of squared errors of the SVR solution on the test partition.", new DoubleValue(testMse)));

      var trainMae = OnlineMeanAbsoluteErrorCalculator.Calculate(yTrain, yPredTrain, out error);
      if (error != OnlineCalculatorError.None) trainMae = double.MaxValue;
      var testMae = OnlineMeanAbsoluteErrorCalculator.Calculate(yTest, yPredTest, out error);
      if (error != OnlineCalculatorError.None) testMae = double.MaxValue;

      Results.Add(new Result("Mean absolute error (training)", "The mean of absolute errors of the SVR solution on the training partition.", new DoubleValue(trainMae)));
      Results.Add(new Result("Mean absolute error (test)", "The mean of absolute errors of the SVR solution on the test partition.", new DoubleValue(testMae)));

      var trainRelErr = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(yTrain, yPredTrain, out error);
      if (error != OnlineCalculatorError.None) trainRelErr = double.MaxValue;
      var testRelErr = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(yTest, yPredTest, out error);
      if (error != OnlineCalculatorError.None) testRelErr = double.MaxValue;

      Results.Add(new Result("Average relative error (training)", "The mean of relative errors of the SVR solution on the training partition.", new DoubleValue(trainRelErr)));
      Results.Add(new Result("Average relative error (test)", "The mean of relative errors of the SVR solution on the test partition.", new DoubleValue(testRelErr)));
    }

    // BackwardsCompatibility3.4
    #region Backwards compatible code, remove with 3.5
    /// <summary>
    /// Old API: trains a model and wraps it in a <see cref="SupportVectorRegressionSolution"/>.
    /// </summary>
    /// <param name="trainingR2">R² of the solution on the training partition.</param>
    /// <param name="testR2">R² of the solution on the test partition.</param>
    /// <param name="nSv">Number of support vectors of the trained model.</param>
    public static SupportVectorRegressionSolution CreateSupportVectorRegressionSolution(
      IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon, int degree,
      out double trainingR2, out double testR2, out int nSv) {
      ISupportVectorMachineModel model;
      Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, epsilon, degree, out model, out nSv);

      var solution = new SupportVectorRegressionSolution((SupportVectorMachineModel)model, (IRegressionProblemData)problemData.Clone());
      trainingR2 = solution.TrainingRSquared;
      testR2 = solution.TestRSquared;
      return solution;
    }
    #endregion

    /// <summary>
    /// Trains a libSVM regression model on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Provides the dataset, target variable and training rows.</param>
    /// <param name="allowedInputVariables">Input variables used for training and stored in the model.</param>
    /// <param name="svmType">"NU_SVR" or "EPSILON_SVR".</param>
    /// <param name="kernelType">"LINEAR", "POLY", "SIGMOID" or "RBF".</param>
    /// <param name="cost">libSVM C parameter.</param>
    /// <param name="nu">libSVM nu parameter.</param>
    /// <param name="gamma">Kernel gamma parameter.</param>
    /// <param name="epsilon">Passed to libSVM as <c>p</c>.</param>
    /// <param name="degree">Degree of the polynomial kernel.</param>
    /// <param name="model">The trained model, including the input scaling transform.</param>
    /// <param name="nSv">Number of support vectors of the trained model.</param>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> or <paramref name="allowedInputVariables"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="svmType"/> or <paramref name="kernelType"/> is not recognized.</exception>
    public static void Run(IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon, int degree,
      out ISupportVectorMachineModel model, out int nSv) {
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");

      var dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<int> rows = problemData.TrainingIndices;
      // Snapshot once so the training problem and the model are built from the same variable list.
      string[] inputVariables = allowedInputVariables.ToArray();

      // The type names are resolved first, so invalid names fail before any data is copied.
      svm_parameter parameter = new svm_parameter {
        svm_type = GetSvmType(svmType),
        kernel_type = GetKernelType(kernelType),
        C = cost,
        nu = nu,
        gamma = gamma,
        p = epsilon,
        cache_size = 500,
        probability = 0,
        eps = 0.001,
        degree = degree,
        shrinking = 1,
        coef0 = 0
      };

      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, inputVariables, rows);
      RangeTransform rangeTransform = RangeTransform.Compute(problem);
      svm_problem scaledProblem = rangeTransform.Scale(problem);
      var svmModel = svm.svm_train(scaledProblem, parameter);
      nSv = svmModel.SV.Length;

      model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, inputVariables);
    }

    /// <summary>Maps an SVR type name to the corresponding libSVM constant.</summary>
    private static int GetSvmType(string svmType) {
      switch (svmType) {
        case "NU_SVR": return svm_parameter.NU_SVR;
        case "EPSILON_SVR": return svm_parameter.EPSILON_SVR;
        default: throw new ArgumentException("Unknown SVM type: " + svmType, "svmType");
      }
    }

    /// <summary>Maps a kernel type name to the corresponding libSVM constant.</summary>
    private static int GetKernelType(string kernelType) {
      switch (kernelType) {
        case "LINEAR": return svm_parameter.LINEAR;
        case "POLY": return svm_parameter.POLY;
        case "SIGMOID": return svm_parameter.SIGMOID;
        case "RBF": return svm_parameter.RBF;
        default: throw new ArgumentException("Unknown kernel type: " + kernelType, "kernelType");
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.