Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs @ 12934

Last change on this file since 12934 was 12934, checked in by gkronber, 9 years ago

#2385 added a boolean "CreateSolution" parameter for support vector machine algorithms and added model error/accuracy metrics as algorithm results (to allow grid search without creating solutions)

File size: 13.0 KB
RevLine 
[5626]1#region License Information
2/* HeuristicLab
[12012]3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5626]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
[5759]23using System.Collections.Generic;
[5626]24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
[8609]32using LibSVM;
[5626]33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Support vector machine classification data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// Wraps libSVM: the training partition is range-scaled, an SVM is trained on it, and the
  /// accuracy of the resulting model on the training and test partitions is reported as results.
  /// A full <see cref="SupportVectorClassificationSolution"/> is only added to the results when
  /// <see cref="CreateSolution"/> is set (the metrics alone are enough for grid searches).
  /// </remarks>
  [Item("Support Vector Classification", "Support vector machine classification data analysis algorithm (wrapper for libSVM).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 110)]
  [StorableClass]
  public sealed class SupportVectorClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string DegreeParameterName = "Degree";
    private const string CreateSolutionParameterName = "CreateSolution";

    #region parameter properties
    /// <summary>The SVM type ("NU_SVC" or "C_SVC").</summary>
    public IConstrainedValueParameter<StringValue> SvmTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    /// <summary>The kernel type ("LINEAR", "POLY", "SIGMOID" or "RBF").</summary>
    public IConstrainedValueParameter<StringValue> KernelTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    /// <summary>The nu parameter, used by nu-SVC.</summary>
    public IValueParameter<DoubleValue> NuParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    /// <summary>The C (cost) parameter, used by C-SVC.</summary>
    public IValueParameter<DoubleValue> CostParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    /// <summary>The gamma parameter of the kernel function.</summary>
    public IValueParameter<DoubleValue> GammaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    /// <summary>The degree parameter of the polynomial kernel.</summary>
    public IValueParameter<IntValue> DegreeParameter {
      get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
    }
    /// <summary>Hidden flag controlling whether a solution is added to the results.</summary>
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion
    #region properties
    public StringValue SvmType {
      get { return SvmTypeParameter.Value; }
      set { SvmTypeParameter.Value = value; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.Value; }
      set { KernelTypeParameter.Value = value; }
    }
    public DoubleValue Nu {
      get { return NuParameter.Value; }
    }
    public DoubleValue Cost {
      get { return CostParameter.Value; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.Value; }
    }
    public IntValue Degree {
      get { return DegreeParameter.Value; }
    }
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private SupportVectorClassification(bool deserializing) : base(deserializing) { }
    private SupportVectorClassification(SupportVectorClassification original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with a new <see cref="ClassificationProblem"/> and default parameters:
    /// NU_SVC, RBF kernel, nu = 0.5, C = 1.0, gamma = 1.0, degree = 3, CreateSolution = true.
    /// </summary>
    public SupportVectorClassification()
      : base() {
      Problem = new ClassificationProblem();

      List<StringValue> svmTypes = (from type in new List<string> { "NU_SVC", "C_SVC" }
                                    select new StringValue(type).AsReadOnly())
                                   .ToList();
      ItemSet<StringValue> svmTypeSet = new ItemSet<StringValue>(svmTypes);
      List<StringValue> kernelTypes = (from type in new List<string> { "LINEAR", "POLY", "SIGMOID", "RBF" }
                                       select new StringValue(type).AsReadOnly())
                                      .ToList();
      ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", svmTypeSet, svmTypes[0]));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, kernelTypes[3]));
      Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC.", new DoubleValue(0.5)));
      Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
      AddDegreeParameter();
      AddCreateSolutionParameter();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region backwards compatibility (change with 3.4)
      // Older persisted instances predate these parameters; add them with their default values.
      if (!Parameters.ContainsKey(DegreeParameterName)) {
        AddDegreeParameter();
      }
      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
        AddCreateSolutionParameter();
      }
      #endregion
    }

    // Registers the polynomial-kernel degree parameter (default 3); shared by the constructor and deserialization.
    private void AddDegreeParameter() {
      Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName,
        "The degree parameter for the polynomial kernel function.", new IntValue(3)));
    }

    // Registers the hidden CreateSolution flag (default true); shared by the constructor and deserialization.
    private void AddCreateSolutionParameter() {
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
        "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorClassification(this, cloner);
    }

    #region support vector classification
    /// <summary>
    /// Trains the SVM on the training partition of the problem data, optionally adds a solution,
    /// and always adds the training/test accuracy and the number of support vectors as results.
    /// </summary>
    protected override void Run() {
      IClassificationProblemData problemData = Problem.ProblemData;
      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
      int nSv;
      ISupportVectorMachineModel model;

      Run(problemData, selectedInputVariables, GetSvmType(SvmType.Value), GetKernelType(KernelType.Value), Cost.Value, Nu.Value, Gamma.Value, Degree.Value, out model, out nSv);

      if (CreateSolution) {
        var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
        Results.Add(new Result("Support vector classification solution", "The support vector classification solution.",
          solution));
      }

      // Model metrics are reported independently of CreateSolution so that configurations
      // can be compared (e.g. in a grid search) without building full solutions.
      double trainAccuracy = CalculateAccuracy(model, problemData, problemData.TrainingIndices);
      double testAccuracy = CalculateAccuracy(model, problemData, problemData.TestIndices);

      Results.Add(new Result("Accuracy (training)", "The accuracy of the SVM classification model on the training partition.", new DoubleValue(trainAccuracy)));
      Results.Add(new Result("Accuracy (test)", "The accuracy of the SVM classification model on the test partition.", new DoubleValue(testAccuracy)));

      Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVM classification model.",
        new IntValue(nSv)));
    }

    // Accuracy of the model's class predictions against the target variable on the given rows.
    // Returns NaN (undefined) when the calculator reports an error. The previous fallback of
    // double.MaxValue made a failed evaluation look like the best possible accuracy.
    private static double CalculateAccuracy(ISupportVectorMachineModel model, IClassificationProblemData problemData, IEnumerable<int> rows) {
      var dataset = problemData.Dataset;
      var targetValues = dataset.GetDoubleValues(problemData.TargetVariable, rows);
      var estimatedValues = model.GetEstimatedClassValues(dataset, rows);

      OnlineCalculatorError error;
      // The calculator takes the original (target) values first and the estimated values second.
      double accuracy = OnlineAccuracyCalculator.Calculate(targetValues, estimatedValues, out error);
      return error == OnlineCalculatorError.None ? accuracy : double.NaN;
    }

    /// <summary>
    /// Trains an SVM with the given settings (SVM and kernel type given by name) and wraps it in a solution.
    /// </summary>
    /// <exception cref="ArgumentException">If <paramref name="svmType"/> or <paramref name="kernelType"/> is unknown.</exception>
    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
      return CreateSupportVectorClassificationSolution(problemData, allowedInputVariables, GetSvmType(svmType), GetKernelType(kernelType), cost, nu, gamma, degree,
        out trainingAccuracy, out testAccuracy, out nSv);
    }

    // BackwardsCompatibility3.4
    #region Backwards compatible code, remove with 3.5
    /// <summary>
    /// Trains an SVM with the given libSVM type constants and wraps it in a solution; the accuracies
    /// are taken from the created solution.
    /// </summary>
    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {

      ISupportVectorMachineModel model;
      Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, degree, out model, out nSv);
      var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());

      trainingAccuracy = solution.TrainingAccuracy;
      testAccuracy = solution.TestAccuracy;

      return solution;
    }

    #endregion

    /// <summary>
    /// Trains a libSVM model on the training rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="svmType">libSVM SVM type constant (e.g. <c>svm_parameter.NU_SVC</c>).</param>
    /// <param name="kernelType">libSVM kernel type constant (e.g. <c>svm_parameter.RBF</c>).</param>
    /// <param name="model">The trained model, which includes the range transform fitted on the training rows.</param>
    /// <param name="nSv">The number of support vectors of the trained model.</param>
    public static void Run(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      int svmType, int kernelType, double cost, double nu, double gamma, int degree,
      out ISupportVectorMachineModel model, out int nSv) {
      var dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<int> rows = problemData.TrainingIndices;

      // Map the algorithm settings onto libSVM's parameter structure. The remaining values are fixed:
      // kernel cache size, no probability estimates, stopping tolerance, shrinking on, and coef0 = 0.
      svm_parameter parameter = new svm_parameter {
        svm_type = svmType,
        kernel_type = kernelType,
        C = cost,
        nu = nu,
        gamma = gamma,
        cache_size = 500,
        probability = 0,
        eps = 0.001,
        degree = degree,
        shrinking = 1,
        coef0 = 0
      };

      // Per-class weights: each class's weight is the sum of the problem's classification penalties
      // GetClassificationPenalty(c, other) over all other classes.
      // NOTE(review): class values are truncated to int for libSVM's weight labels.
      var weightLabels = new List<int>();
      var weights = new List<double>();
      foreach (double c in problemData.ClassValues) {
        double wSum = 0.0;
        foreach (double otherClass in problemData.ClassValues) {
          if (!c.IsAlmost(otherClass)) {
            wSum += problemData.GetClassificationPenalty(c, otherClass);
          }
        }
        weightLabels.Add((int)c);
        weights.Add(wSum);
      }
      parameter.weight_label = weightLabels.ToArray();
      parameter.weight = weights.ToArray();

      // Scale inputs using a range transform computed on the training rows. The same transform is
      // stored in the model so that later predictions are scaled consistently.
      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
      RangeTransform rangeTransform = RangeTransform.Compute(problem);
      svm_problem scaledProblem = rangeTransform.Scale(problem);
      var svmModel = svm.svm_train(scaledProblem, parameter);
      nSv = svmModel.SV.Length;

      model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);
    }

    // Maps the SVM type name used by the SvmType parameter to the libSVM constant.
    private static int GetSvmType(string svmType) {
      switch (svmType) {
        case "NU_SVC": return svm_parameter.NU_SVC;
        case "C_SVC": return svm_parameter.C_SVC;
        default: throw new ArgumentException("Unknown SVM type");
      }
    }

    // Maps the kernel type name used by the KernelType parameter to the libSVM constant.
    private static int GetKernelType(string kernelType) {
      switch (kernelType) {
        case "LINEAR": return svm_parameter.LINEAR;
        case "POLY": return svm_parameter.POLY;
        case "SIGMOID": return svm_parameter.SIGMOID;
        case "RBF": return svm_parameter.RBF;
        default: throw new ArgumentException("Unknown kernel type");
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.