source: branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs @ 14542

Last change on this file since 14542 was 14542, checked in by gkronber, 3 years ago

#2650: merged r14504:14533 from trunk to branch

File size: 13.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis;
33using LibSVM;
34
35namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// Support vector machine classification data analysis algorithm.
/// </summary>
/// <remarks>
/// Wraps libSVM: a C-SVC or nu-SVC model is trained on the training partition of the problem data
/// (inputs are range-scaled before training) and classification accuracy is reported for the
/// training and test partitions.
/// </remarks>
[Item("Support Vector Classification (SVM)", "Support vector machine classification data analysis algorithm (wrapper for libSVM).")]
[Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 110)]
[StorableClass]
public sealed class SupportVectorClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
  private const string SvmTypeParameterName = "SvmType";
  private const string KernelTypeParameterName = "KernelType";
  private const string CostParameterName = "Cost";
  private const string NuParameterName = "Nu";
  private const string GammaParameterName = "Gamma";
  private const string DegreeParameterName = "Degree";
  private const string CreateSolutionParameterName = "CreateSolution";

  #region parameter properties
  public IConstrainedValueParameter<StringValue> SvmTypeParameter {
    get { return (IConstrainedValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
  }
  public IConstrainedValueParameter<StringValue> KernelTypeParameter {
    get { return (IConstrainedValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
  }
  public IValueParameter<DoubleValue> NuParameter {
    get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
  }
  public IValueParameter<DoubleValue> CostParameter {
    get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
  }
  public IValueParameter<DoubleValue> GammaParameter {
    get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
  }
  public IValueParameter<IntValue> DegreeParameter {
    get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
  }
  public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
  }
  #endregion
  #region properties
  /// <summary>Name of the selected SVM type ("NU_SVC" or "C_SVC").</summary>
  public StringValue SvmType {
    get { return SvmTypeParameter.Value; }
    set { SvmTypeParameter.Value = value; }
  }
  /// <summary>Name of the selected kernel ("LINEAR", "POLY", "SIGMOID" or "RBF").</summary>
  public StringValue KernelType {
    get { return KernelTypeParameter.Value; }
    set { KernelTypeParameter.Value = value; }
  }
  public DoubleValue Nu {
    get { return NuParameter.Value; }
  }
  public DoubleValue Cost {
    get { return CostParameter.Value; }
  }
  public DoubleValue Gamma {
    get { return GammaParameter.Value; }
  }
  public IntValue Degree {
    get { return DegreeParameter.Value; }
  }
  /// <summary>Whether a <see cref="SupportVectorClassificationSolution"/> is added to the results after the run.</summary>
  public bool CreateSolution {
    get { return CreateSolutionParameter.Value.Value; }
    set { CreateSolutionParameter.Value.Value = value; }
  }
  #endregion
  [StorableConstructor]
  private SupportVectorClassification(bool deserializing) : base(deserializing) { }
  private SupportVectorClassification(SupportVectorClassification original, Cloner cloner)
    : base(original, cloner) {
  }
  public SupportVectorClassification()
    : base() {
    Problem = new ClassificationProblem();

    List<StringValue> svmTypes = (from type in new List<string> { "NU_SVC", "C_SVC" }
                                  select new StringValue(type).AsReadOnly())
                                 .ToList();
    ItemSet<StringValue> svmTypeSet = new ItemSet<StringValue>(svmTypes);
    List<StringValue> kernelTypes = (from type in new List<string> { "LINEAR", "POLY", "SIGMOID", "RBF" }
                                     select new StringValue(type).AsReadOnly())
                                    .ToList();
    ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);
    // Defaults: NU_SVC (svmTypes[0]) with the RBF kernel (kernelTypes[3]).
    Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", svmTypeSet, svmTypes[0]));
    Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, kernelTypes[3]));
    Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter of nu-SVC.", new DoubleValue(0.5)));
    Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC.", new DoubleValue(1.0)));
    Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
    AddDegreeParameter();
    AddCreateSolutionParameter();
  }
  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
    #region backwards compatibility (change with 3.4)
    // Older persisted instances may lack parameters introduced later; add them with their defaults.
    if (!Parameters.ContainsKey(DegreeParameterName)) {
      AddDegreeParameter();
    }
    if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
      AddCreateSolutionParameter();
    }
    #endregion
  }

  // Adds the polynomial-kernel "Degree" parameter (default 3); shared by the constructor and the deserialization hook.
  private void AddDegreeParameter() {
    Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName,
      "The degree parameter for the polynomial kernel function.", new IntValue(3)));
  }

  // Adds the hidden "CreateSolution" flag (default true); shared by the constructor and the deserialization hook.
  private void AddCreateSolutionParameter() {
    Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
      "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    Parameters[CreateSolutionParameterName].Hidden = true;
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new SupportVectorClassification(this, cloner);
  }

  #region support vector classification
  /// <summary>
  /// Trains the SVM with the current parameter values, optionally adds the solution, and adds
  /// training/test accuracy and the number of support vectors to the results.
  /// </summary>
  protected override void Run(CancellationToken cancellationToken) {
    IClassificationProblemData problemData = Problem.ProblemData;
    IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
    int nSv;
    ISupportVectorMachineModel model;

    // NOTE(review): cancellationToken is not forwarded to the training call below.
    Run(problemData, selectedInputVariables, GetSvmType(SvmType.Value), GetKernelType(KernelType.Value), Cost.Value, Nu.Value, Gamma.Value, Degree.Value, out model, out nSv);

    if (CreateSolution) {
      var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
      Results.Add(new Result("Support vector classification solution", "The support vector classification solution.",
        solution));
    }

    // calculate classification metrics on the training and test partitions
    var ds = problemData.Dataset;
    var trainRows = problemData.TrainingIndices;
    var testRows = problemData.TestIndices;
    var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows);
    var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows);
    var yPredTrain = model.GetEstimatedClassValues(ds, trainRows);
    var yPredTest = model.GetEstimatedClassValues(ds, testRows);

    double trainAccuracy = CalculateAccuracy(yTrain, yPredTrain);
    double testAccuracy = CalculateAccuracy(yTest, yPredTest);

    Results.Add(new Result("Accuracy (training)", "The accuracy of the SVM classification solution on the training partition.", new DoubleValue(trainAccuracy)));
    Results.Add(new Result("Accuracy (test)", "The accuracy of the SVM classification solution on the test partition.", new DoubleValue(testAccuracy)));

    Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVM classification solution.",
      new IntValue(nSv)));
  }

  // Accuracy of estimated vs. original class values, or NaN when the calculator reports an error
  // (presumably e.g. an empty partition). NaN replaces the former double.MaxValue, which is not a valid accuracy.
  private static double CalculateAccuracy(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
    OnlineCalculatorError error;
    double accuracy = OnlineAccuracyCalculator.Calculate(originalValues, estimatedValues, out error);
    return error == OnlineCalculatorError.None ? accuracy : double.NaN;
  }

  /// <summary>
  /// Trains an SVM classifier and wraps it in a solution, using SVM and kernel type names.
  /// </summary>
  /// <param name="problemData">Problem data; the model is trained on its training partition.</param>
  /// <param name="allowedInputVariables">Input variables used as features.</param>
  /// <param name="svmType">SVM type name: "NU_SVC" or "C_SVC".</param>
  /// <param name="kernelType">Kernel name: "LINEAR", "POLY", "SIGMOID" or "RBF".</param>
  /// <param name="cost">C parameter (used by C-SVC).</param>
  /// <param name="nu">nu parameter (used by nu-SVC).</param>
  /// <param name="gamma">Kernel gamma parameter.</param>
  /// <param name="degree">Polynomial kernel degree.</param>
  /// <param name="trainingAccuracy">Receives the solution's training accuracy.</param>
  /// <param name="testAccuracy">Receives the solution's test accuracy.</param>
  /// <param name="nSv">Receives the number of support vectors.</param>
  /// <exception cref="ArgumentException">Unknown <paramref name="svmType"/> or <paramref name="kernelType"/>.</exception>
  public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
    string svmType, string kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
    return CreateSupportVectorClassificationSolution(problemData, allowedInputVariables, GetSvmType(svmType), GetKernelType(kernelType), cost, nu, gamma, degree,
      out trainingAccuracy, out testAccuracy, out nSv);
  }

  // BackwardsCompatibility3.4
  #region Backwards compatible code, remove with 3.5
  /// <summary>
  /// Trains an SVM classifier and wraps it in a solution, using libSVM's integer SVM/kernel type constants.
  /// The solution is built on a clone of <paramref name="problemData"/>; the reported accuracies are the solution's.
  /// </summary>
  public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
    int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {

    ISupportVectorMachineModel model;
    Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, degree, out model, out nSv);
    var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());

    trainingAccuracy = solution.TrainingAccuracy;
    testAccuracy = solution.TestAccuracy;

    return solution;
  }

  #endregion

  /// <summary>
  /// Trains a libSVM classifier on the training partition of <paramref name="problemData"/>.
  /// </summary>
  /// <remarks>
  /// Each class value gets a libSVM class weight equal to the sum of
  /// <c>problemData.GetClassificationPenalty(c, other)</c> over all other class values.
  /// Inputs are range-scaled with a <see cref="RangeTransform"/> computed on the training rows;
  /// the transform is stored in the returned model.
  /// </remarks>
  /// <param name="problemData">Problem data providing dataset, target, class values and training rows.</param>
  /// <param name="allowedInputVariables">Input variables used as features.</param>
  /// <param name="svmType">libSVM SVM type constant (e.g. <c>svm_parameter.C_SVC</c>).</param>
  /// <param name="kernelType">libSVM kernel type constant (e.g. <c>svm_parameter.RBF</c>).</param>
  /// <param name="cost">C parameter.</param>
  /// <param name="nu">nu parameter.</param>
  /// <param name="gamma">Kernel gamma parameter.</param>
  /// <param name="degree">Polynomial kernel degree.</param>
  /// <param name="model">Receives the trained model.</param>
  /// <param name="nSv">Receives the number of support vectors.</param>
  public static void Run(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
    int svmType, int kernelType, double cost, double nu, double gamma, int degree,
    out ISupportVectorMachineModel model, out int nSv) {
    var dataset = problemData.Dataset;
    string targetVariable = problemData.TargetVariable;
    IEnumerable<int> rows = problemData.TrainingIndices;

    svm_parameter parameter = new svm_parameter {
      svm_type = svmType,
      kernel_type = kernelType,
      C = cost,
      nu = nu,
      gamma = gamma,
      cache_size = 500,
      probability = 0,
      eps = 0.001,
      degree = degree,
      shrinking = 1,
      coef0 = 0
    };

    var weightLabels = new List<int>();
    var weights = new List<double>();
    foreach (double c in problemData.ClassValues) {
      double wSum = 0.0;
      foreach (double otherClass in problemData.ClassValues) {
        if (!c.IsAlmost(otherClass)) {
          wSum += problemData.GetClassificationPenalty(c, otherClass);
        }
      }
      // weight_label is an int array; NOTE(review): assumes class values are integral.
      weightLabels.Add((int)c);
      weights.Add(wSum);
    }
    parameter.weight_label = weightLabels.ToArray();
    parameter.weight = weights.ToArray();

    svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
    RangeTransform rangeTransform = RangeTransform.Compute(problem);
    svm_problem scaledProblem = rangeTransform.Scale(problem);
    var svmModel = svm.svm_train(scaledProblem, parameter);
    nSv = svmModel.SV.Length;

    model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);
  }

  // Maps an SVM type name to the libSVM constant; throws ArgumentException for unknown names (including null).
  private static int GetSvmType(string svmType) {
    switch (svmType) {
      case "NU_SVC": return svm_parameter.NU_SVC;
      case "C_SVC": return svm_parameter.C_SVC;
      default: throw new ArgumentException("Unknown SVM type");
    }
  }

  // Maps a kernel name to the libSVM constant; throws ArgumentException for unknown names (including null).
  private static int GetKernelType(string kernelType) {
    switch (kernelType) {
      case "LINEAR": return svm_parameter.LINEAR;
      case "POLY": return svm_parameter.POLY;
      case "SIGMOID": return svm_parameter.SIGMOID;
      case "RBF": return svm_parameter.RBF;
      default: throw new ArgumentException("Unknown kernel type");
    }
  }
  #endregion
}
268}
Note: See TracBrowser for help on using the repository browser.