source: branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs @ 14387

Last change on this file since 14387 was 14185, checked in by swagner, 3 years ago

#2526: Updated year of copyrights in license headers

File size: 12.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using LibSVM;
33
34namespace HeuristicLab.Algorithms.DataAnalysis {
35  /// <summary>
36  /// Support vector machine classification data analysis algorithm.
37  /// </summary>
38  [Item("Support Vector Classification (SVM)", "Support vector machine classification data analysis algorithm (wrapper for libSVM).")]
39  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 110)]
40  [StorableClass]
41  public sealed class SupportVectorClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
// Keys under which the algorithm's parameters are stored in the Parameters collection.
// The parameter properties and AfterDeserialization look parameters up by these names.

// SVM formulation and kernel selection.
private const string SvmTypeParameterName = "SvmType";
private const string KernelTypeParameterName = "KernelType";

// Numeric settings handed to the SVM trainer.
private const string NuParameterName = "Nu";
private const string CostParameterName = "Cost";
private const string GammaParameterName = "Gamma";
private const string DegreeParameterName = "Degree";

// Controls whether a solution object is added to the results.
private const string CreateSolutionParameterName = "CreateSolution";
49
50    #region parameter properties
/// <summary>Parameter holding the SVM type; looked up under the key "SvmType".</summary>
public IConstrainedValueParameter<StringValue> SvmTypeParameter {
  get {
    var parameter = Parameters[SvmTypeParameterName];
    return (IConstrainedValueParameter<StringValue>)parameter;
  }
}

/// <summary>Parameter holding the kernel type; looked up under the key "KernelType".</summary>
public IConstrainedValueParameter<StringValue> KernelTypeParameter {
  get {
    var parameter = Parameters[KernelTypeParameterName];
    return (IConstrainedValueParameter<StringValue>)parameter;
  }
}

/// <summary>Parameter holding the nu value; looked up under the key "Nu".</summary>
public IValueParameter<DoubleValue> NuParameter {
  get {
    var parameter = Parameters[NuParameterName];
    return (IValueParameter<DoubleValue>)parameter;
  }
}

/// <summary>Parameter holding the cost (C) value; looked up under the key "Cost".</summary>
public IValueParameter<DoubleValue> CostParameter {
  get {
    var parameter = Parameters[CostParameterName];
    return (IValueParameter<DoubleValue>)parameter;
  }
}

/// <summary>Parameter holding the kernel gamma value; looked up under the key "Gamma".</summary>
public IValueParameter<DoubleValue> GammaParameter {
  get {
    var parameter = Parameters[GammaParameterName];
    return (IValueParameter<DoubleValue>)parameter;
  }
}

/// <summary>Parameter holding the polynomial kernel degree; looked up under the key "Degree".</summary>
public IValueParameter<IntValue> DegreeParameter {
  get {
    var parameter = Parameters[DegreeParameterName];
    return (IValueParameter<IntValue>)parameter;
  }
}

/// <summary>Parameter holding the create-solution flag; looked up under the key "CreateSolution".</summary>
public IFixedValueParameter<BoolValue> CreateSolutionParameter {
  get {
    var parameter = Parameters[CreateSolutionParameterName];
    return (IFixedValueParameter<BoolValue>)parameter;
  }
}
72    #endregion
73    #region properties
/// <summary>Selected SVM type; reads and writes the value of <see cref="SvmTypeParameter"/>.</summary>
public StringValue SvmType {
  get {
    return SvmTypeParameter.Value;
  }
  set {
    SvmTypeParameter.Value = value;
  }
}

/// <summary>Selected kernel type; reads and writes the value of <see cref="KernelTypeParameter"/>.</summary>
public StringValue KernelType {
  get {
    return KernelTypeParameter.Value;
  }
  set {
    KernelTypeParameter.Value = value;
  }
}

/// <summary>Current value of <see cref="NuParameter"/> (read-only accessor).</summary>
public DoubleValue Nu {
  get {
    return NuParameter.Value;
  }
}

/// <summary>Current value of <see cref="CostParameter"/> (read-only accessor).</summary>
public DoubleValue Cost {
  get {
    return CostParameter.Value;
  }
}

/// <summary>Current value of <see cref="GammaParameter"/> (read-only accessor).</summary>
public DoubleValue Gamma {
  get {
    return GammaParameter.Value;
  }
}

/// <summary>Current value of <see cref="DegreeParameter"/> (read-only accessor).</summary>
public IntValue Degree {
  get {
    return DegreeParameter.Value;
  }
}

/// <summary>
/// Whether <c>Run()</c> adds a solution to the results; reads and writes the boolean
/// wrapped by <see cref="CreateSolutionParameter"/>.
/// </summary>
public bool CreateSolution {
  get {
    return CreateSolutionParameter.Value.Value;
  }
  set {
    CreateSolutionParameter.Value.Value = value;
  }
}
98    #endregion
/// <summary>
/// Constructor marked with <c>[StorableConstructor]</c>; it only forwards the flag to the base class.
/// </summary>
/// <param name="deserializing">Flag passed through unchanged to the base constructor.</param>
[StorableConstructor]
private SupportVectorClassification(bool deserializing)
  : base(deserializing) {
}

/// <summary>
/// Cloning constructor used by <c>Clone</c>; all copying is delegated to the base class.
/// </summary>
/// <param name="original">The instance being copied.</param>
/// <param name="cloner">The cloner passed on to the base constructor.</param>
private SupportVectorClassification(SupportVectorClassification original, Cloner cloner)
  : base(original, cloner) {
}
/// <summary>
/// Creates the algorithm with a new <see cref="ClassificationProblem"/> and registers all
/// parameters with their defaults: SVM type "NU_SVC", kernel "RBF", nu 0.5, cost 1.0,
/// gamma 1.0, degree 3, and a hidden "CreateSolution" flag set to true.
/// </summary>
public SupportVectorClassification()
  : base() {
  Problem = new ClassificationProblem();

  // Selectable SVM types; each entry is wrapped via AsReadOnly().
  List<StringValue> svmTypes = new List<string> { "NU_SVC", "C_SVC" }
    .Select(name => new StringValue(name).AsReadOnly())
    .ToList();
  // Selectable kernel types, wrapped the same way.
  List<StringValue> kernelTypes = new List<string> { "LINEAR", "POLY", "SIGMOID", "RBF" }
    .Select(name => new StringValue(name).AsReadOnly())
    .ToList();

  ItemSet<StringValue> svmTypeSet = new ItemSet<StringValue>(svmTypes);
  ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);

  StringValue defaultSvmType = svmTypes[0];        // "NU_SVC"
  StringValue defaultKernelType = kernelTypes[3];  // "RBF"

  Parameters.Add(new ConstrainedValueParameter<StringValue>(
    SvmTypeParameterName, "The type of SVM to use.", svmTypeSet, defaultSvmType));
  Parameters.Add(new ConstrainedValueParameter<StringValue>(
    KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, defaultKernelType));
  Parameters.Add(new ValueParameter<DoubleValue>(
    NuParameterName, "The value of the nu parameter nu-SVC.", new DoubleValue(0.5)));
  Parameters.Add(new ValueParameter<DoubleValue>(
    CostParameterName, "The value of the C (cost) parameter of C-SVC.", new DoubleValue(1.0)));
  Parameters.Add(new ValueParameter<DoubleValue>(
    GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
  Parameters.Add(new ValueParameter<IntValue>(
    DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
  Parameters.Add(new FixedValueParameter<BoolValue>(
    CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));

  // The create-solution switch is registered but flagged as hidden.
  Parameters[CreateSolutionParameterName].Hidden = true;
}
/// <summary>
/// Hook registered for <c>HookType.AfterDeserialization</c>. Adds the "Degree" and
/// "CreateSolution" parameters, with the same defaults the constructor uses, when the
/// restored parameter collection does not contain them.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  #region backwards compatibility (change with 3.4)
  bool degreeMissing = !Parameters.ContainsKey(DegreeParameterName);
  if (degreeMissing) {
    var degreeParameter = new ValueParameter<IntValue>(
      DegreeParameterName,
      "The degree parameter for the polynomial kernel function.",
      new IntValue(3));
    Parameters.Add(degreeParameter);
  }

  bool createSolutionMissing = !Parameters.ContainsKey(CreateSolutionParameterName);
  if (createSolutionMissing) {
    var createSolutionParameter = new FixedValueParameter<BoolValue>(
      CreateSolutionParameterName,
      "Flag that indicates if a solution should be produced at the end of the run",
      new BoolValue(true));
    Parameters.Add(createSolutionParameter);
    // Same as in the constructor: the flag is registered but hidden.
    Parameters[CreateSolutionParameterName].Hidden = true;
  }
  #endregion
}
139
/// <summary>Creates a copy of this algorithm by invoking the private cloning constructor.</summary>
/// <param name="cloner">The cloner handed to the cloning constructor.</param>
/// <returns>The newly constructed <see cref="SupportVectorClassification"/>.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new SupportVectorClassification(this, cloner);
  return copy;
}
143
144    #region support vector classification
/// <summary>
/// Trains an SVM classifier with the currently configured parameters and publishes the
/// outcome in <c>Results</c>:
/// the solution (only when <see cref="CreateSolution"/> is set), the accuracy on the
/// training and test partitions, and the number of support vectors.
/// </summary>
protected override void Run() {
  IClassificationProblemData problemData = Problem.ProblemData;
  IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
  int nSv;
  ISupportVectorMachineModel model;

  // Translate the string-valued type parameters into the numeric codes the static Run expects.
  Run(problemData, selectedInputVariables, GetSvmType(SvmType.Value), GetKernelType(KernelType.Value),
    Cost.Value, Nu.Value, Gamma.Value, Degree.Value, out model, out nSv);

  if (CreateSolution) {
    // The solution receives its own clone of the problem data.
    var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
    Results.Add(new Result("Support vector classification solution", "The support vector classification solution.",
      solution));
  }

  // calculate classification metrics
  var ds = problemData.Dataset;
  var trainRows = problemData.TrainingIndices;
  var testRows = problemData.TestIndices;
  var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows);
  var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows);
  var yPredTrain = model.GetEstimatedClassValues(ds, trainRows);
  var yPredTest = model.GetEstimatedClassValues(ds, testRows);

  // NOTE(review): double.MaxValue is kept as the fallback when the calculator reports an
  // error; it is an unusual sentinel for an accuracy value - confirm whether consumers rely on it.
  OnlineCalculatorError error;
  var trainAccuracy = OnlineAccuracyCalculator.Calculate(yPredTrain, yTrain, out error);
  if (error != OnlineCalculatorError.None) trainAccuracy = double.MaxValue;
  var testAccuracy = OnlineAccuracyCalculator.Calculate(yPredTest, yTest, out error);
  if (error != OnlineCalculatorError.None) testAccuracy = double.MaxValue;

  // Result names are unchanged; only the descriptions were corrected (they previously
  // described a mean squared error of an SVR solution).
  Results.Add(new Result("Accuracy (training)", "The accuracy of the SVC solution on the training partition.", new DoubleValue(trainAccuracy)));
  Results.Add(new Result("Accuracy (test)", "The accuracy of the SVC solution on the test partition.", new DoubleValue(testAccuracy)));

  Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVC solution.",
    new IntValue(nSv)));
}
183
/// <summary>
/// Convenience overload that accepts the SVM and kernel type by name, converts them to
/// their numeric codes, and delegates to the integer-based overload.
/// </summary>
/// <param name="problemData">Problem data passed through to the integer-based overload.</param>
/// <param name="allowedInputVariables">Input variable names passed through to the integer-based overload.</param>
/// <param name="svmType">"NU_SVC" or "C_SVC".</param>
/// <param name="kernelType">"LINEAR", "POLY", "SIGMOID" or "RBF".</param>
/// <param name="cost">Cost value passed through.</param>
/// <param name="nu">Nu value passed through.</param>
/// <param name="gamma">Gamma value passed through.</param>
/// <param name="degree">Degree value passed through.</param>
/// <param name="trainingAccuracy">Set by the integer-based overload.</param>
/// <param name="testAccuracy">Set by the integer-based overload.</param>
/// <param name="nSv">Set by the integer-based overload.</param>
/// <returns>The solution returned by the integer-based overload.</returns>
/// <exception cref="ArgumentException">The SVM type or kernel type name is not recognised.</exception>
public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
  string svmType, string kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
  int svmTypeCode = GetSvmType(svmType);
  int kernelTypeCode = GetKernelType(kernelType);

  return CreateSupportVectorClassificationSolution(
    problemData, allowedInputVariables, svmTypeCode, kernelTypeCode,
    cost, nu, gamma, degree,
    out trainingAccuracy, out testAccuracy, out nSv);
}
189
190    // BackwardsCompatibility3.4
191    #region Backwards compatible code, remove with 3.5
/// <summary>
/// Trains a model through the static <c>Run</c> method, wraps it together with a clone of the
/// problem data in a solution, and reports the accuracies exposed by that solution.
/// </summary>
/// <param name="problemData">Problem data used for training; a clone is stored in the solution.</param>
/// <param name="allowedInputVariables">Input variable names forwarded to <c>Run</c>.</param>
/// <param name="svmType">Numeric SVM type code forwarded to <c>Run</c>.</param>
/// <param name="kernelType">Numeric kernel type code forwarded to <c>Run</c>.</param>
/// <param name="cost">Cost value forwarded to <c>Run</c>.</param>
/// <param name="nu">Nu value forwarded to <c>Run</c>.</param>
/// <param name="gamma">Gamma value forwarded to <c>Run</c>.</param>
/// <param name="degree">Degree value forwarded to <c>Run</c>.</param>
/// <param name="trainingAccuracy">Receives the solution's <c>TrainingAccuracy</c>.</param>
/// <param name="testAccuracy">Receives the solution's <c>TestAccuracy</c>.</param>
/// <param name="nSv">Receives the support vector count reported by <c>Run</c>.</param>
/// <returns>The newly created solution.</returns>
public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
  int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
  ISupportVectorMachineModel trainedModel;
  Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, degree, out trainedModel, out nSv);

  // The solution constructor needs the concrete model type and gets its own copy of the data.
  var concreteModel = (SupportVectorMachineModel)trainedModel;
  var problemDataCopy = (IClassificationProblemData)problemData.Clone();
  var result = new SupportVectorClassificationSolution(concreteModel, problemDataCopy);

  trainingAccuracy = result.TrainingAccuracy;
  testAccuracy = result.TestAccuracy;
  return result;
}
204
205    #endregion
206
/// <summary>
/// Builds an SVM problem from the training rows of <paramref name="problemData"/>, applies a
/// <c>RangeTransform</c> computed from that problem, trains with <c>svm.svm_train</c>, and
/// wraps the result in a <c>SupportVectorMachineModel</c>.
/// </summary>
/// <param name="problemData">Supplies the dataset, target variable, training indices, class values and classification penalties.</param>
/// <param name="allowedInputVariables">Input variable names used to build the SVM problem and stored in the model.</param>
/// <param name="svmType">Assigned to <c>svm_parameter.svm_type</c>.</param>
/// <param name="kernelType">Assigned to <c>svm_parameter.kernel_type</c>.</param>
/// <param name="cost">Assigned to <c>svm_parameter.C</c>.</param>
/// <param name="nu">Assigned to <c>svm_parameter.nu</c>.</param>
/// <param name="gamma">Assigned to <c>svm_parameter.gamma</c>.</param>
/// <param name="degree">Assigned to <c>svm_parameter.degree</c>.</param>
/// <param name="model">Receives the trained model.</param>
/// <param name="nSv">Receives the length of the trained model's <c>SV</c> array.</param>
public static void Run(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
  int svmType, int kernelType, double cost, double nu, double gamma, int degree,
  out ISupportVectorMachineModel model, out int nSv) {
  var dataset = problemData.Dataset;
  string targetVariable = problemData.TargetVariable;
  IEnumerable<int> rows = problemData.TrainingIndices;

  // Caller-supplied settings first, then the fixed settings of this wrapper.
  svm_parameter parameter = new svm_parameter();
  parameter.svm_type = svmType;
  parameter.kernel_type = kernelType;
  parameter.C = cost;
  parameter.nu = nu;
  parameter.gamma = gamma;
  parameter.cache_size = 500;
  parameter.probability = 0;
  parameter.eps = 0.001;
  parameter.degree = degree;
  parameter.shrinking = 1;
  parameter.coef0 = 0;

  // One weight per class: the sum of the classification penalties of that class against
  // every other class. The class value is cast to int to serve as the weight label.
  var weightLabels = new List<int>();
  var weights = new List<double>();
  foreach (double classValue in problemData.ClassValues) {
    double penaltySum = 0.0;
    foreach (double other in problemData.ClassValues) {
      if (classValue.IsAlmost(other)) continue; // skip the class itself
      penaltySum += problemData.GetClassificationPenalty(classValue, other);
    }
    weightLabels.Add((int)classValue);
    weights.Add(penaltySum);
  }
  parameter.weight_label = weightLabels.ToArray();
  parameter.weight = weights.ToArray();

  // The transform is computed from the training problem and kept in the model.
  svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
  RangeTransform rangeTransform = RangeTransform.Compute(problem);
  svm_problem scaledProblem = rangeTransform.Scale(problem);

  var svmModel = svm.svm_train(scaledProblem, parameter);
  nSv = svmModel.SV.Length;
  model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);
}
251
/// <summary>Maps an SVM type name to the corresponding <c>svm_parameter</c> constant.</summary>
/// <param name="svmType">"NU_SVC" or "C_SVC" (exact match).</param>
/// <returns><c>svm_parameter.NU_SVC</c> or <c>svm_parameter.C_SVC</c>.</returns>
/// <exception cref="ArgumentException">The name is neither of the two supported values.</exception>
private static int GetSvmType(string svmType) {
  switch (svmType) {
    case "NU_SVC":
      return svm_parameter.NU_SVC;
    case "C_SVC":
      return svm_parameter.C_SVC;
    default:
      throw new ArgumentException("Unknown SVM type");
  }
}
257
/// <summary>Maps a kernel type name to the corresponding <c>svm_parameter</c> constant.</summary>
/// <param name="kernelType">"LINEAR", "POLY", "SIGMOID" or "RBF" (exact match).</param>
/// <returns>The <c>svm_parameter</c> constant with the same name.</returns>
/// <exception cref="ArgumentException">The name is none of the four supported values.</exception>
private static int GetKernelType(string kernelType) {
  switch (kernelType) {
    case "LINEAR":
      return svm_parameter.LINEAR;
    case "POLY":
      return svm_parameter.POLY;
    case "SIGMOID":
      return svm_parameter.SIGMOID;
    case "RBF":
      return svm_parameter.RBF;
    default:
      throw new ArgumentException("Unknown kernel type");
  }
}
265    #endregion
266  }
267}
Note: See TracBrowser for help on using the repository browser.