Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs

Last change on this file was 17180, checked in by swagner, 5 years ago

#2875: Removed years in copyrights

File size: 12.9 KB
RevLine 
[5626]1#region License Information
2/* HeuristicLab
[17180]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5626]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
[5759]23using System.Collections.Generic;
[5626]24using System.Linq;
[14523]25using System.Threading;
[5626]26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
[16565]31using HEAL.Attic;
[5626]32using HeuristicLab.Problems.DataAnalysis;
[8609]33using LibSVM;
[5626]34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Support vector machine classification data analysis algorithm.
  /// Trains a libSVM classifier (NU_SVC or C_SVC) on the training partition of the
  /// classification problem and reports training/test accuracy and the number of support vectors.
  /// </summary>
  [Item("Support Vector Classification (SVM)", "Support vector machine classification data analysis algorithm (wrapper for libSVM).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 110)]
  [StorableType("F15289E4-B648-4A92-AB01-14D769A33967")]
  public sealed class SupportVectorClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string DegreeParameterName = "Degree";
    private const string CreateSolutionParameterName = "CreateSolution";

    #region parameter properties
    public IConstrainedValueParameter<StringValue> SvmTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> KernelTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    public IValueParameter<DoubleValue> NuParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    public IValueParameter<DoubleValue> CostParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    public IValueParameter<DoubleValue> GammaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    public IValueParameter<IntValue> DegreeParameter {
      get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion
    #region properties
    public StringValue SvmType {
      get { return SvmTypeParameter.Value; }
      set { SvmTypeParameter.Value = value; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.Value; }
      set { KernelTypeParameter.Value = value; }
    }
    public DoubleValue Nu {
      get { return NuParameter.Value; }
    }
    public DoubleValue Cost {
      get { return CostParameter.Value; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.Value; }
    }
    public IntValue Degree {
      get { return DegreeParameter.Value; }
    }
    /// <summary>Whether <see cref="Run(CancellationToken)"/> adds the trained solution to the results.</summary>
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion
    [StorableConstructor]
    private SupportVectorClassification(StorableConstructorFlag _) : base(_) { }
    private SupportVectorClassification(SupportVectorClassification original, Cloner cloner)
      : base(original, cloner) {
    }
    /// <summary>
    /// Creates the algorithm with an empty <see cref="ClassificationProblem"/> and default parameters
    /// (NU_SVC, RBF kernel, nu = 0.5, C = 1, gamma = 1, degree = 3, solution creation enabled).
    /// </summary>
    public SupportVectorClassification()
      : base() {
      Problem = new ClassificationProblem();

      // The selectable options are created via AsReadOnly(); the constrained parameters below
      // only allow choosing one of these values.
      List<StringValue> svmTypes = (from type in new List<string> { "NU_SVC", "C_SVC" }
                                    select new StringValue(type).AsReadOnly())
                                   .ToList();
      ItemSet<StringValue> svmTypeSet = new ItemSet<StringValue>(svmTypes);
      List<StringValue> kernelTypes = (from type in new List<string> { "LINEAR", "POLY", "SIGMOID", "RBF" }
                                       select new StringValue(type).AsReadOnly())
                                      .ToList();
      ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);
      // Defaults: svmTypes[0] = NU_SVC, kernelTypes[3] = RBF.
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", svmTypeSet, svmTypes[0]));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, kernelTypes[3]));
      Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter of nu-SVC.", new DoubleValue(0.5)));
      Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;
    }
    /// <summary>
    /// Adds parameters that are missing in instances persisted by older versions,
    /// using the same defaults as the public constructor.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region backwards compatibility (change with 3.4)
      if (!Parameters.ContainsKey(DegreeParameterName)) {
        Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName,
          "The degree parameter for the polynomial kernel function.", new IntValue(3)));
      }
      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
          "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
        Parameters[CreateSolutionParameterName].Hidden = true;
      }
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorClassification(this, cloner);
    }

    #region support vector classification
    /// <summary>
    /// Trains an SVM on the training partition of the current problem. Adds the solution to the
    /// results when <see cref="CreateSolution"/> is set, and always adds the training/test
    /// accuracy and the number of support vectors.
    /// </summary>
    /// <param name="cancellationToken">Not observed by this implementation.</param>
    protected override void Run(CancellationToken cancellationToken) {
      IClassificationProblemData problemData = Problem.ProblemData;
      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
      int nSv;
      ISupportVectorMachineModel model;

      Run(problemData, selectedInputVariables, GetSvmType(SvmType.Value), GetKernelType(KernelType.Value), Cost.Value, Nu.Value, Gamma.Value, Degree.Value, out model, out nSv);

      if (CreateSolution) {
        var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
        Results.Add(new Result("Support vector classification solution", "The support vector classification solution.",
          solution));
      }

      // calculate classification metrics
      var ds = problemData.Dataset;
      var trainRows = problemData.TrainingIndices;
      var testRows = problemData.TestIndices;
      var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows);
      var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows);
      var yPredTrain = model.GetEstimatedClassValues(ds, trainRows);
      var yPredTest = model.GetEstimatedClassValues(ds, testRows);

      // Accuracy is a "higher is better" measure: if it cannot be computed (calculator error),
      // report NaN rather than a huge value that would read as a perfect score.
      OnlineCalculatorError error;
      var trainAccuracy = OnlineAccuracyCalculator.Calculate(yPredTrain, yTrain, out error);
      if (error != OnlineCalculatorError.None) trainAccuracy = double.NaN;
      var testAccuracy = OnlineAccuracyCalculator.Calculate(yPredTest, yTest, out error);
      if (error != OnlineCalculatorError.None) testAccuracy = double.NaN;

      Results.Add(new Result("Accuracy (training)", "The accuracy of the SVM classification solution on the training partition.", new DoubleValue(trainAccuracy)));
      Results.Add(new Result("Accuracy (test)", "The accuracy of the SVM classification solution on the test partition.", new DoubleValue(testAccuracy)));

      Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVM classification solution.",
        new IntValue(nSv)));
    }

    /// <summary>
    /// Trains an SVM classifier and wraps it in a solution, using SVM and kernel type names
    /// ("NU_SVC"/"C_SVC"; "LINEAR"/"POLY"/"SIGMOID"/"RBF").
    /// </summary>
    /// <exception cref="ArgumentException">An SVM or kernel type name is not supported.</exception>
    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
      return CreateSupportVectorClassificationSolution(problemData, allowedInputVariables, GetSvmType(svmType), GetKernelType(kernelType), cost, nu, gamma, degree,
        out trainingAccuracy, out testAccuracy, out nSv);
    }

    // BackwardsCompatibility3.4
    #region Backwards compatible code, remove with 3.5
    /// <summary>
    /// Trains an SVM classifier from libSVM type constants and returns the solution together
    /// with its training/test accuracy and number of support vectors.
    /// </summary>
    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {

      ISupportVectorMachineModel model;
      Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, degree, out model, out nSv);
      var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());

      trainingAccuracy = solution.TrainingAccuracy;
      testAccuracy = solution.TestAccuracy;

      return solution;
    }

    #endregion

    /// <summary>
    /// Trains a libSVM classifier on the training rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Problem data providing dataset, target, training rows, class values and penalties.</param>
    /// <param name="allowedInputVariables">Input variables used as features.</param>
    /// <param name="svmType">libSVM SVM type constant (svm_parameter.NU_SVC or svm_parameter.C_SVC).</param>
    /// <param name="kernelType">libSVM kernel type constant (LINEAR, POLY, SIGMOID or RBF).</param>
    /// <param name="cost">C parameter (used by C-SVC).</param>
    /// <param name="nu">nu parameter (used by nu-SVC).</param>
    /// <param name="gamma">Kernel gamma parameter.</param>
    /// <param name="degree">Polynomial kernel degree.</param>
    /// <param name="model">The trained model; it is constructed with the range transform computed on the training rows.</param>
    /// <param name="nSv">Number of support vectors of the trained libSVM model.</param>
    public static void Run(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      int svmType, int kernelType, double cost, double nu, double gamma, int degree,
      out ISupportVectorMachineModel model, out int nSv) {
      var dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<int> rows = problemData.TrainingIndices;

      // Solver settings other than the algorithm parameters are fixed here and not user-configurable.
      svm_parameter parameter = new svm_parameter {
        svm_type = svmType,
        kernel_type = kernelType,
        C = cost,
        nu = nu,
        gamma = gamma,
        cache_size = 500,
        probability = 0,
        eps = 0.001,
        degree = degree,
        shrinking = 1,
        coef0 = 0
      };

      // Per-class weight = sum of GetClassificationPenalty(c, other) over all other classes,
      // so the problem's misclassification penalties influence training.
      var weightLabels = new List<int>();
      var weights = new List<double>();
      foreach (double c in problemData.ClassValues) {
        double wSum = 0.0;
        foreach (double otherClass in problemData.ClassValues) {
          if (!c.IsAlmost(otherClass)) {
            wSum += problemData.GetClassificationPenalty(c, otherClass);
          }
        }
        // weight_label is an int[]; the cast truncates, so this assumes integral class values.
        // NOTE(review): confirm class values are always whole numbers.
        weightLabels.Add((int)c);
        weights.Add(wSum);
      }
      parameter.weight_label = weightLabels.ToArray();
      parameter.weight = weights.ToArray();

      // Inputs are scaled with a range transform computed on the training rows; the transform is
      // handed to the model (presumably so that predictions apply the same scaling).
      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
      RangeTransform rangeTransform = RangeTransform.Compute(problem);
      svm_problem scaledProblem = rangeTransform.Scale(problem);
      var svmModel = svm.svm_train(scaledProblem, parameter);
      nSv = svmModel.SV.Length;

      model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);
    }

    /// <summary>Maps an SVM type name ("NU_SVC" or "C_SVC") to the libSVM constant.</summary>
    /// <exception cref="ArgumentException"><paramref name="svmType"/> is not a supported SVM type.</exception>
    private static int GetSvmType(string svmType) {
      switch (svmType) {
        case "NU_SVC": return svm_parameter.NU_SVC;
        case "C_SVC": return svm_parameter.C_SVC;
        default: throw new ArgumentException("Unknown SVM type: " + svmType, "svmType");
      }
    }

    /// <summary>Maps a kernel type name ("LINEAR", "POLY", "SIGMOID" or "RBF") to the libSVM constant.</summary>
    /// <exception cref="ArgumentException"><paramref name="kernelType"/> is not a supported kernel type.</exception>
    private static int GetKernelType(string kernelType) {
      switch (kernelType) {
        case "LINEAR": return svm_parameter.LINEAR;
        case "POLY": return svm_parameter.POLY;
        case "SIGMOID": return svm_parameter.SIGMOID;
        case "RBF": return svm_parameter.RBF;
        default: throw new ArgumentException("Unknown kernel type: " + kernelType, "kernelType");
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.