Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorRegression.cs @ 13846

Last change on this file since 13846 was 13297, checked in by gkronber, 9 years ago

#2454: merged r13238:13239 from trunk to stable

File size: 13.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using LibSVM;
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Support vector machine regression data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// Wrapper around libSVM: the training partition of the problem data is scaled with a
  /// <see cref="RangeTransform"/> computed from the training rows, an epsilon-SVR or nu-SVR model
  /// is trained on it, and error metrics are reported for the training and test partitions.
  /// </remarks>
  [Item("Support Vector Regression (SVM)", "Support vector machine regression data analysis algorithm (wrapper for libSVM).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 110)]
  [StorableClass]
  public sealed class SupportVectorRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";
    private const string DegreeParameterName = "Degree";
    private const string CreateSolutionParameterName = "CreateSolution";

    // Descriptions shared between the constructor and the backwards-compatibility hook so that
    // freshly created and deserialized instances define these parameters identically.
    private const string DegreeParameterDescription = "The degree parameter for the polynomial kernel function.";
    private const string CreateSolutionParameterDescription = "Flag that indicates if a solution should be produced at the end of the run";

    #region parameter properties
    public IConstrainedValueParameter<StringValue> SvmTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> KernelTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    public IValueParameter<DoubleValue> NuParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    public IValueParameter<DoubleValue> CostParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    public IValueParameter<DoubleValue> GammaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    public IValueParameter<DoubleValue> EpsilonParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    public IValueParameter<IntValue> DegreeParameter {
      get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>SVM type name; one of "NU_SVR" or "EPSILON_SVR".</summary>
    public StringValue SvmType {
      get { return SvmTypeParameter.Value; }
      set { SvmTypeParameter.Value = value; }
    }
    /// <summary>Kernel type name; one of "LINEAR", "POLY", "SIGMOID" or "RBF".</summary>
    public StringValue KernelType {
      get { return KernelTypeParameter.Value; }
      set { KernelTypeParameter.Value = value; }
    }
    public DoubleValue Nu {
      get { return NuParameter.Value; }
    }
    public DoubleValue Cost {
      get { return CostParameter.Value; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.Value; }
    }
    public DoubleValue Epsilon {
      get { return EpsilonParameter.Value; }
    }
    public IntValue Degree {
      get { return DegreeParameter.Value; }
    }
    /// <summary>Whether a <see cref="SupportVectorRegressionSolution"/> is added to the results after a run.</summary>
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private SupportVectorRegression(bool deserializing) : base(deserializing) { }
    private SupportVectorRegression(SupportVectorRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with an empty <see cref="RegressionProblem"/> and default parameters
    /// (NU_SVR, RBF kernel, nu = 0.5, C = 1.0, gamma = 1.0, epsilon = 0.1, degree = 3).
    /// </summary>
    public SupportVectorRegression()
      : base() {
      Problem = new RegressionProblem();

      List<StringValue> svrTypes = new[] { "NU_SVR", "EPSILON_SVR" }
        .Select(type => new StringValue(type).AsReadOnly())
        .ToList();
      ItemSet<StringValue> svrTypeSet = new ItemSet<StringValue>(svrTypes);
      List<StringValue> kernelTypes = new[] { "LINEAR", "POLY", "SIGMOID", "RBF" }
        .Select(type => new StringValue(type).AsReadOnly())
        .ToList();
      ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);

      // defaults: NU_SVR (svrTypes[0]) and RBF kernel (kernelTypes[3])
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", svrTypeSet, svrTypes[0]));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, kernelTypes[3]));
      Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter of the nu-SVR.", new DoubleValue(0.5)));
      Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of epsilon-SVR and nu-SVR.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR.", new DoubleValue(0.1)));
      AddDegreeParameter();
      AddCreateSolutionParameter();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region backwards compatibility (change with 3.4)
      // Older serialized instances may lack these parameters; add them with their defaults.
      if (!Parameters.ContainsKey(DegreeParameterName)) {
        AddDegreeParameter();
      }
      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
        AddCreateSolutionParameter();
      }
      #endregion
    }

    /// <summary>Adds the polynomial-kernel degree parameter with its default value 3.</summary>
    private void AddDegreeParameter() {
      Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, DegreeParameterDescription, new IntValue(3)));
    }

    /// <summary>Adds the hidden create-solution flag with its default value true.</summary>
    private void AddCreateSolutionParameter() {
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, CreateSolutionParameterDescription, new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorRegression(this, cloner);
    }

    #region support vector regression
    /// <summary>
    /// Trains an SVR on the training partition using the allowed input variables of the problem,
    /// optionally adds the resulting solution, and reports the number of support vectors together
    /// with error metrics on the training and test partitions.
    /// </summary>
    protected override void Run() {
      IRegressionProblemData problemData = Problem.ProblemData;
      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
      int nSv;
      ISupportVectorMachineModel model;
      Run(problemData, selectedInputVariables, SvmType.Value, KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, Degree.Value, out model, out nSv);

      if (CreateSolution) {
        var solution = new SupportVectorRegressionSolution((SupportVectorMachineModel)model, (IRegressionProblemData)problemData.Clone());
        Results.Add(new Result("Support vector regression solution", "The support vector regression solution.", solution));
      }

      Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVR solution.", new IntValue(nSv)));

      AddRegressionMetrics(problemData, model);
    }

    /// <summary>
    /// Evaluates <paramref name="model"/> on the training and test partitions of
    /// <paramref name="problemData"/> and adds mean squared error, mean absolute error and
    /// average relative error for both partitions to the results. A metric whose calculator
    /// reports an error is stored as <see cref="double.MaxValue"/>.
    /// </summary>
    private void AddRegressionMetrics(IRegressionProblemData problemData, ISupportVectorMachineModel model) {
      var ds = problemData.Dataset;
      var trainRows = problemData.TrainingIndices;
      var testRows = problemData.TestIndices;
      // materialize once; each sequence is consumed by three calculators below
      var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows).ToArray();
      var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows).ToArray();
      var yPredTrain = model.GetEstimatedValues(ds, trainRows).ToArray();
      var yPredTest = model.GetEstimatedValues(ds, testRows).ToArray();

      // The online calculators take (originalValues, estimatedValues). MSE and MAE are symmetric,
      // but the relative (percentage) error is taken relative to the original values, so the
      // target values must be passed first; passing the predictions first normalized by them.
      OnlineCalculatorError error;
      var trainMse = OnlineMeanSquaredErrorCalculator.Calculate(yTrain, yPredTrain, out error);
      trainMse = MaxValueOnError(trainMse, error);
      var testMse = OnlineMeanSquaredErrorCalculator.Calculate(yTest, yPredTest, out error);
      testMse = MaxValueOnError(testMse, error);

      Results.Add(new Result("Mean squared error (training)", "The mean of squared errors of the SVR solution on the training partition.", new DoubleValue(trainMse)));
      Results.Add(new Result("Mean squared error (test)", "The mean of squared errors of the SVR solution on the test partition.", new DoubleValue(testMse)));

      var trainMae = OnlineMeanAbsoluteErrorCalculator.Calculate(yTrain, yPredTrain, out error);
      trainMae = MaxValueOnError(trainMae, error);
      var testMae = OnlineMeanAbsoluteErrorCalculator.Calculate(yTest, yPredTest, out error);
      testMae = MaxValueOnError(testMae, error);

      Results.Add(new Result("Mean absolute error (training)", "The mean of absolute errors of the SVR solution on the training partition.", new DoubleValue(trainMae)));
      Results.Add(new Result("Mean absolute error (test)", "The mean of absolute errors of the SVR solution on the test partition.", new DoubleValue(testMae)));

      var trainRelErr = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(yTrain, yPredTrain, out error);
      trainRelErr = MaxValueOnError(trainRelErr, error);
      var testRelErr = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(yTest, yPredTest, out error);
      testRelErr = MaxValueOnError(testRelErr, error);

      Results.Add(new Result("Average relative error (training)", "The mean of relative errors of the SVR solution on the training partition.", new DoubleValue(trainRelErr)));
      Results.Add(new Result("Average relative error (test)", "The mean of relative errors of the SVR solution on the test partition.", new DoubleValue(testRelErr)));
    }

    /// <summary>
    /// Returns <paramref name="value"/>, or <see cref="double.MaxValue"/> if the calculator
    /// reported any error state.
    /// </summary>
    private static double MaxValueOnError(double value, OnlineCalculatorError error) {
      return error == OnlineCalculatorError.None ? value : double.MaxValue;
    }

    // BackwardsCompatibility3.4
    #region Backwards compatible code, remove with 3.5
    /// <summary>
    /// Old API: trains an SVR (see <see cref="Run(IRegressionProblemData, IEnumerable{string}, string, string, double, double, double, double, int, out ISupportVectorMachineModel, out int)"/>)
    /// and wraps it in a solution over a clone of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="trainingR2">R² of the solution on the training partition.</param>
    /// <param name="testR2">R² of the solution on the test partition.</param>
    /// <param name="nSv">Number of support vectors of the trained model.</param>
    public static SupportVectorRegressionSolution CreateSupportVectorRegressionSolution(
      IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon, int degree,
      out double trainingR2, out double testR2, out int nSv) {
      ISupportVectorMachineModel model;
      Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, epsilon, degree, out model, out nSv);

      var solution = new SupportVectorRegressionSolution((SupportVectorMachineModel)model, (IRegressionProblemData)problemData.Clone());
      trainingR2 = solution.TrainingRSquared;
      testR2 = solution.TestRSquared;
      return solution;
    }
    #endregion

    /// <summary>
    /// Trains a libSVM regression model on the training rows of <paramref name="problemData"/>.
    /// The inputs are scaled by a <see cref="RangeTransform"/> computed from the training data;
    /// the returned model carries that transform together with the target and input variables.
    /// </summary>
    /// <param name="svmType">"NU_SVR" or "EPSILON_SVR".</param>
    /// <param name="kernelType">"LINEAR", "POLY", "SIGMOID" or "RBF".</param>
    /// <param name="cost">The C (cost) parameter.</param>
    /// <param name="nu">The nu parameter (nu-SVR).</param>
    /// <param name="gamma">The kernel gamma parameter.</param>
    /// <param name="epsilon">The epsilon of the epsilon-SVR loss (libSVM's <c>p</c>).</param>
    /// <param name="degree">Degree of the polynomial kernel.</param>
    /// <param name="model">The trained model.</param>
    /// <param name="nSv">Number of support vectors of the trained model.</param>
    /// <exception cref="ArgumentException">Unknown <paramref name="svmType"/> or <paramref name="kernelType"/>.</exception>
    public static void Run(IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon, int degree,
      out ISupportVectorMachineModel model, out int nSv) {
      var dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<int> rows = problemData.TrainingIndices;

      svm_parameter parameter = new svm_parameter {
        svm_type = GetSvmType(svmType),
        kernel_type = GetKernelType(kernelType),
        C = cost,
        nu = nu,
        gamma = gamma,
        p = epsilon,
        cache_size = 500,   // libSVM kernel cache size (MB)
        probability = 0,    // no probability estimates
        eps = 0.001,        // termination tolerance of the solver
        degree = degree,
        shrinking = 1,      // use shrinking heuristics
        coef0 = 0
      };

      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
      RangeTransform rangeTransform = RangeTransform.Compute(problem);
      svm_problem scaledProblem = rangeTransform.Scale(problem);
      var svmModel = svm.svm_train(scaledProblem, parameter);
      nSv = svmModel.SV.Length;

      model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables);
    }

    /// <summary>Maps an SVM type name to the corresponding libSVM constant.</summary>
    private static int GetSvmType(string svmType) {
      switch (svmType) {
        case "NU_SVR": return svm_parameter.NU_SVR;
        case "EPSILON_SVR": return svm_parameter.EPSILON_SVR;
        default: throw new ArgumentException("Unknown SVM type");
      }
    }

    /// <summary>Maps a kernel type name to the corresponding libSVM constant.</summary>
    private static int GetKernelType(string kernelType) {
      switch (kernelType) {
        case "LINEAR": return svm_parameter.LINEAR;
        case "POLY": return svm_parameter.POLY;
        case "SIGMOID": return svm_parameter.SIGMOID;
        case "RBF": return svm_parameter.RBF;
        default: throw new ArgumentException("Unknown kernel type");
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.