#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HEAL.Attic;
using HeuristicLab.Problems.DataAnalysis;
using LibSVM;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Support vector machine classification data analysis algorithm.
  /// </summary>
  [Item("Support Vector Classification (SVM)", "Support vector machine classification data analysis algorithm (wrapper for libSVM).")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 110)]
  [StorableType("F15289E4-B648-4A92-AB01-14D769A33967")]
  public sealed class SupportVectorClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string DegreeParameterName = "Degree";
    private const string CreateSolutionParameterName = "CreateSolution";

    #region parameter properties
    /// <summary>SVM formulation; constrained to "NU_SVC" or "C_SVC".</summary>
    public IConstrainedValueParameter<StringValue> SvmTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    /// <summary>Kernel function; constrained to "LINEAR", "POLY", "SIGMOID" or "RBF".</summary>
    public IConstrainedValueParameter<StringValue> KernelTypeParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    /// <summary>The nu parameter of nu-SVC.</summary>
    public IValueParameter<DoubleValue> NuParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    /// <summary>The C (cost) parameter of C-SVC.</summary>
    public IValueParameter<DoubleValue> CostParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    /// <summary>The gamma parameter of the kernel function.</summary>
    public IValueParameter<DoubleValue> GammaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    /// <summary>The degree of the polynomial kernel.</summary>
    public IValueParameter<IntValue> DegreeParameter {
      get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
    }
    /// <summary>Hidden flag controlling whether a solution object is added to the results.</summary>
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion

    #region properties
    public StringValue SvmType {
      get { return SvmTypeParameter.Value; }
      set { SvmTypeParameter.Value = value; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.Value; }
      set { KernelTypeParameter.Value = value; }
    }
    public DoubleValue Nu {
      get { return NuParameter.Value; }
    }
    public DoubleValue Cost {
      get { return CostParameter.Value; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.Value; }
    }
    public IntValue Degree {
      get { return DegreeParameter.Value; }
    }
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private SupportVectorClassification(StorableConstructorFlag _) : base(_) { }
    private SupportVectorClassification(SupportVectorClassification original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with a new <see cref="ClassificationProblem"/> and default parameters:
    /// SVM type NU_SVC, RBF kernel, nu = 0.5, C = 1.0, gamma = 1.0, degree = 3, and solution creation enabled.
    /// </summary>
    public SupportVectorClassification()
      : base() {
      Problem = new ClassificationProblem();

      List<StringValue> svrTypes = (from type in new List<string> { "NU_SVC", "C_SVC" }
                                    select new StringValue(type).AsReadOnly())
                                   .ToList();
      ItemSet<StringValue> svrTypeSet = new ItemSet<StringValue>(svrTypes);
      List<StringValue> kernelTypes = (from type in new List<string> { "LINEAR", "POLY", "SIGMOID", "RBF" }
                                       select new StringValue(type).AsReadOnly())
                                      .ToList();
      ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);

      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", svrTypeSet, svrTypes[0]));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, kernelTypes[3]));
      Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC.", new DoubleValue(0.5)));
      Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;
    }

    /// <summary>
    /// Adds parameters that are missing from instances persisted by older versions.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region backwards compatibility (change with 3.4)
      if (!Parameters.ContainsKey(DegreeParameterName)) {
        Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName,
          "The degree parameter for the polynomial kernel function.", new IntValue(3)));
      }
      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
          "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
        Parameters[CreateSolutionParameterName].Hidden = true;
      }
      #endregion
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorClassification(this, cloner);
    }

    #region support vector classification
    /// <summary>
    /// Trains an SVM on the training partition of the current problem. If <see cref="CreateSolution"/> is set,
    /// it adds a solution to the results. It always adds training/test accuracy and the number of support vectors.
    /// </summary>
    /// <param name="cancellationToken">Not observed by this implementation.</param>
    protected override void Run(CancellationToken cancellationToken) {
      IClassificationProblemData problemData = Problem.ProblemData;
      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
      int nSv;
      ISupportVectorMachineModel model;

      Run(problemData, selectedInputVariables, GetSvmType(SvmType.Value), GetKernelType(KernelType.Value),
        Cost.Value, Nu.Value, Gamma.Value, Degree.Value, out model, out nSv);

      if (CreateSolution) {
        var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
        Results.Add(new Result("Support vector classification solution", "The support vector classification solution.", solution));
      }

      // Classification metrics are reported even when no solution object is created.
      var ds = problemData.Dataset;
      var trainRows = problemData.TrainingIndices;
      var testRows = problemData.TestIndices;
      var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows);
      var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows);
      var yPredTrain = model.GetEstimatedClassValues(ds, trainRows);
      var yPredTest = model.GetEstimatedClassValues(ds, testRows);

      double trainAccuracy = CalculateAccuracy(yTrain, yPredTrain);
      double testAccuracy = CalculateAccuracy(yTest, yPredTest);

      Results.Add(new Result("Accuracy (training)", "The accuracy of the SVM solution on the training partition.", new DoubleValue(trainAccuracy)));
      Results.Add(new Result("Accuracy (test)", "The accuracy of the SVM solution on the test partition.", new DoubleValue(testAccuracy)));
      Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVM solution.", new IntValue(nSv)));
    }

    /// <summary>
    /// Computes the accuracy of <paramref name="estimatedValues"/> against the target values
    /// <paramref name="originalValues"/> using <see cref="OnlineAccuracyCalculator"/>.
    /// </summary>
    /// <returns>The accuracy, or <see cref="double.NaN"/> if the calculator reports an error.</returns>
    private static double CalculateAccuracy(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
      OnlineCalculatorError error;
      double accuracy = OnlineAccuracyCalculator.Calculate(originalValues, estimatedValues, out error);
      // NaN marks an undefined accuracy. double.MaxValue (the regression MSE "worst" sentinel) would read as an
      // absurdly *good* score for a higher-is-better metric.
      return error == OnlineCalculatorError.None ? accuracy : double.NaN;
    }

    /// <summary>
    /// Trains an SVM and wraps it in a <see cref="SupportVectorClassificationSolution"/>.
    /// The SVM and kernel types are given by name ("NU_SVC"/"C_SVC" and "LINEAR"/"POLY"/"SIGMOID"/"RBF").
    /// </summary>
    /// <exception cref="ArgumentException">An unknown SVM or kernel type name is passed.</exception>
    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma, int degree,
      out double trainingAccuracy, out double testAccuracy, out int nSv) {
      return CreateSupportVectorClassificationSolution(problemData, allowedInputVariables, GetSvmType(svmType), GetKernelType(kernelType), cost, nu, gamma, degree,
        out trainingAccuracy, out testAccuracy, out nSv);
    }

    // BackwardsCompatibility3.4
    #region Backwards compatible code, remove with 3.5
    /// <summary>
    /// Trains an SVM using libSVM integer constants for the SVM and kernel type. It returns the solution
    /// together with its training/test accuracy and the number of support vectors.
    /// </summary>
    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      int svmType, int kernelType, double cost, double nu, double gamma, int degree,
      out double trainingAccuracy, out double testAccuracy, out int nSv) {

      ISupportVectorMachineModel model;
      Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, degree, out model, out nSv);
      var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());

      trainingAccuracy = solution.TrainingAccuracy;
      testAccuracy = solution.TestAccuracy;

      return solution;
    }
    #endregion

    /// <summary>
    /// Trains a libSVM model on the training rows of <paramref name="problemData"/>. The inputs are first scaled
    /// with a <see cref="RangeTransform"/> computed from those rows.
    /// </summary>
    /// <param name="problemData">Problem data; only its training indices are used for fitting.</param>
    /// <param name="allowedInputVariables">The input variables used as features.</param>
    /// <param name="svmType">A libSVM SVM type constant, e.g. <c>svm_parameter.NU_SVC</c>.</param>
    /// <param name="kernelType">A libSVM kernel type constant, e.g. <c>svm_parameter.RBF</c>.</param>
    /// <param name="model">The trained model, which also holds the range transform.</param>
    /// <param name="nSv">The number of support vectors of the trained model.</param>
    public static void Run(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
      int svmType, int kernelType, double cost, double nu, double gamma, int degree,
      out ISupportVectorMachineModel model, out int nSv) {
      var dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<int> rows = problemData.TrainingIndices;

      svm_parameter parameter = new svm_parameter {
        svm_type = svmType,
        kernel_type = kernelType,
        C = cost,
        nu = nu,
        gamma = gamma,
        cache_size = 500,
        probability = 0,
        eps = 0.001,
        degree = degree,
        shrinking = 1,
        coef0 = 0
      };

      // Each class's libSVM weight is the sum of GetClassificationPenalty(c, other) over all other class values.
      // libSVM labels are ints, so class values are truncated; assumes integral class values — TODO confirm.
      var weightLabels = new List<int>();
      var weights = new List<double>();
      foreach (double c in problemData.ClassValues) {
        double wSum = 0.0;
        foreach (double otherClass in problemData.ClassValues) {
          if (!c.IsAlmost(otherClass)) {
            wSum += problemData.GetClassificationPenalty(c, otherClass);
          }
        }
        weightLabels.Add((int)c);
        weights.Add(wSum);
      }
      parameter.weight_label = weightLabels.ToArray();
      parameter.weight = weights.ToArray();

      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
      // The range transform is computed on the training rows only. It is stored in the model (presumably so
      // predictions are scaled the same way).
      RangeTransform rangeTransform = RangeTransform.Compute(problem);
      svm_problem scaledProblem = rangeTransform.Scale(problem);
      var svmModel = svm.svm_train(scaledProblem, parameter);
      nSv = svmModel.SV.Length;

      model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);
    }

    /// <summary>Maps an SVM type name to the corresponding libSVM constant.</summary>
    /// <exception cref="ArgumentException">The name is not "NU_SVC" or "C_SVC".</exception>
    private static int GetSvmType(string svmType) {
      switch (svmType) {
        case "NU_SVC": return svm_parameter.NU_SVC;
        case "C_SVC": return svm_parameter.C_SVC;
        default: throw new ArgumentException("Unknown SVM type");
      }
    }

    /// <summary>Maps a kernel type name to the corresponding libSVM constant.</summary>
    /// <exception cref="ArgumentException">The name is not "LINEAR", "POLY", "SIGMOID" or "RBF".</exception>
    private static int GetKernelType(string kernelType) {
      switch (kernelType) {
        case "LINEAR": return svm_parameter.LINEAR;
        case "POLY": return svm_parameter.POLY;
        case "SIGMOID": return svm_parameter.SIGMOID;
        case "RBF": return svm_parameter.RBF;
        default: throw new ArgumentException("Unknown kernel type");
      }
    }
    #endregion
  }
}