Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs @ 8609

Last change on this file since 8609 was 8609, checked in by gkronber, 12 years ago

#1944 changed SVR and SVC algorithms in HeuristicLab to use most recent LibSVM version.

File size: 9.1 KB
RevLine 
[5626]1#region License Information
2/* HeuristicLab
[7259]3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5626]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
[5759]23using System.Collections.Generic;
[5626]24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
[8609]32using LibSVM;
[5626]33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Support vector machine classification data analysis algorithm.
  /// Trains a LibSVM classifier (C-SVC or nu-SVC) on the training partition of the
  /// problem data. It reports the resulting classification solution, its training and
  /// test accuracy, and its number of support vectors.
  /// </summary>
  [Item("Support Vector Classification", "Support vector machine classification data analysis algorithm (wrapper for libSVM).")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class SupportVectorClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";

    #region parameter properties
    /// <summary>The SVM variant to train; a new instance offers "NU_SVC" and "C_SVC".</summary>
    public IConstrainedValueParameter<StringValue> SvmTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    /// <summary>The kernel function; a new instance offers "LINEAR", "POLY", "SIGMOID" and "RBF".</summary>
    public IConstrainedValueParameter<StringValue> KernelTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    /// <summary>The nu parameter (relevant for nu-SVC).</summary>
    public IValueParameter<DoubleValue> NuParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    /// <summary>The C (cost) parameter (relevant for C-SVC).</summary>
    public IValueParameter<DoubleValue> CostParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    /// <summary>The gamma parameter of the kernel function.</summary>
    public IValueParameter<DoubleValue> GammaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    #endregion
    #region properties
    public StringValue SvmType {
      get { return SvmTypeParameter.Value; }
      set { SvmTypeParameter.Value = value; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.Value; }
      set { KernelTypeParameter.Value = value; }
    }
    public DoubleValue Nu {
      get { return NuParameter.Value; }
    }
    public DoubleValue Cost {
      get { return CostParameter.Value; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.Value; }
    }
    #endregion
    [StorableConstructor]
    private SupportVectorClassification(bool deserializing) : base(deserializing) { }
    private SupportVectorClassification(SupportVectorClassification original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with an empty <see cref="ClassificationProblem"/> and default
    /// parameters: nu-SVC, RBF kernel, nu = 0.5, C = 1.0, gamma = 1.0.
    /// </summary>
    public SupportVectorClassification()
      : base() {
      Problem = new ClassificationProblem();

      // The offered names must match the names understood by GetSvmType / GetKernelType.
      List<StringValue> svcTypes = (from type in new List<string> { "NU_SVC", "C_SVC" }
                                    select new StringValue(type).AsReadOnly())
                                   .ToList();
      ItemSet<StringValue> svcTypeSet = new ItemSet<StringValue>(svcTypes);
      List<StringValue> kernelTypes = (from type in new List<string> { "LINEAR", "POLY", "SIGMOID", "RBF" }
                                       select new StringValue(type).AsReadOnly())
                                      .ToList();
      ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);

      // Select the defaults by name instead of by list index, so reordering the lists
      // above cannot silently change them. The selected instances are the same objects
      // that were put into the valid-value sets.
      StringValue defaultSvcType = svcTypes.First(t => t.Value == "NU_SVC");
      StringValue defaultKernelType = kernelTypes.First(t => t.Value == "RBF");

      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", svcTypeSet, defaultSvcType));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, defaultKernelType));
      Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter of nu-SVC.", new DoubleValue(0.5)));
      Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
    }

    // Nothing needs to be restored after deserialization at the moment.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorClassification(this, cloner);
    }

    #region support vector classification
    /// <summary>
    /// Trains the classifier on the current problem data with the current parameter values.
    /// Adds the solution, training/test accuracy and support vector count to the results.
    /// </summary>
    protected override void Run() {
      IClassificationProblemData problemData = Problem.ProblemData;
      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
      double trainingAccuracy, testAccuracy;
      int nSv;
      var solution = CreateSupportVectorClassificationSolution(problemData, selectedInputVariables,
        SvmType.Value, KernelType.Value, Cost.Value, Nu.Value, Gamma.Value,
        out trainingAccuracy, out testAccuracy, out nSv);

      Results.Add(new Result("Support vector classification solution", "The support vector classification solution.", solution));
      Results.Add(new Result("Training accuracy", "The accuracy of the SVC solution on the training partition.", new DoubleValue(trainingAccuracy)));
      Results.Add(new Result("Test accuracy", "The accuracy of the SVC solution on the test partition.", new DoubleValue(testAccuracy)));
      Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVC solution.", new IntValue(nSv)));
    }

    /// <summary>
    /// Trains a LibSVM classifier on the training rows of <paramref name="problemData"/>
    /// and wraps it in a <see cref="SupportVectorClassificationSolution"/>.
    /// </summary>
    /// <param name="problemData">Problem data; its training indices are used for training.
    /// The solution receives a clone of it.</param>
    /// <param name="allowedInputVariables">Input variables used as features.</param>
    /// <param name="svmType">"NU_SVC" or "C_SVC".</param>
    /// <param name="kernelType">"LINEAR", "POLY", "SIGMOID" or "RBF".</param>
    /// <param name="cost">C parameter (C-SVC).</param>
    /// <param name="nu">nu parameter (nu-SVC).</param>
    /// <param name="gamma">Kernel gamma parameter.</param>
    /// <param name="trainingAccuracy">Accuracy of the solution on the training partition.</param>
    /// <param name="testAccuracy">Accuracy of the solution on the test partition.</param>
    /// <param name="nSv">Number of support vectors of the trained model.</param>
    /// <returns>The trained classification solution.</returns>
    /// <exception cref="ArgumentException">If <paramref name="svmType"/> or <paramref name="kernelType"/> is not recognized.</exception>
    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma,
      out double trainingAccuracy, out double testAccuracy, out int nSv) {
      Dataset dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<int> rows = problemData.TrainingIndices;

      svm_parameter parameter = new svm_parameter();
      parameter.svm_type = GetSvmType(svmType);
      parameter.kernel_type = GetKernelType(kernelType);
      parameter.C = cost;
      parameter.nu = nu;
      parameter.gamma = gamma;
      // The remaining settings are fixed and not exposed as algorithm parameters.
      // (In LibSVM, degree is only used by the POLY kernel, coef0 by POLY/SIGMOID.)
      parameter.cache_size = 500;
      parameter.probability = 0;
      parameter.eps = 0.001;
      parameter.degree = 3;
      parameter.shrinking = 1;
      parameter.coef0 = 0;

      SetClassWeights(parameter, problemData);

      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
      // Scale the inputs before training. The transform is handed to the model,
      // presumably so the same scaling is applied when predicting.
      RangeTransform rangeTransform = RangeTransform.Compute(problem);
      svm_problem scaledProblem = rangeTransform.Scale(problem);
      var svmModel = svm.svm_train(scaledProblem, parameter);
      var model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);
      var solution = new SupportVectorClassificationSolution(model, (IClassificationProblemData)problemData.Clone());

      nSv = svmModel.SV.Length;
      trainingAccuracy = solution.TrainingAccuracy;
      testAccuracy = solution.TestAccuracy;

      return solution;
    }

    /// <summary>
    /// Sets one LibSVM class weight per class value. Each weight is the sum of the
    /// penalties for misclassifying that class as any of the other classes.
    /// </summary>
    private static void SetClassWeights(svm_parameter parameter, IClassificationProblemData problemData) {
      var weightLabels = new List<int>();
      var weights = new List<double>();
      foreach (double c in problemData.ClassValues) {
        double wSum = 0.0;
        foreach (double otherClass in problemData.ClassValues) {
          if (!c.IsAlmost(otherClass)) {
            wSum += problemData.GetClassificationPenalty(c, otherClass);
          }
        }
        // weight_label holds int labels, so the (double) class value is truncated here.
        weightLabels.Add((int)c);
        weights.Add(wSum);
      }
      // NOTE(review): reference LibSVM implementations apply only the first nr_weight
      // entries of weight_label/weight. Confirm this port does not need nr_weight set
      // explicitly, otherwise these weights are ignored.
      parameter.weight_label = weightLabels.ToArray();
      parameter.weight = weights.ToArray();
    }

    /// <summary>Maps an SVM type name to the corresponding LibSVM constant.</summary>
    /// <exception cref="ArgumentException">If the name is not "NU_SVC" or "C_SVC".</exception>
    private static int GetSvmType(string svmType) {
      switch (svmType) {
        case "NU_SVC": return svm_parameter.NU_SVC;
        case "C_SVC": return svm_parameter.C_SVC;
        default: throw new ArgumentException("Unknown SVM type: " + svmType, "svmType");
      }
    }

    /// <summary>Maps a kernel type name to the corresponding LibSVM constant.</summary>
    /// <exception cref="ArgumentException">If the name is not a supported kernel type.</exception>
    private static int GetKernelType(string kernelType) {
      switch (kernelType) {
        case "LINEAR": return svm_parameter.LINEAR;
        case "POLY": return svm_parameter.POLY;
        case "SIGMOID": return svm_parameter.SIGMOID;
        case "RBF": return svm_parameter.RBF;
        default: throw new ArgumentException("Unknown kernel type: " + kernelType, "kernelType");
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.