Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/25/15 14:39:59 (9 years ago)
Author:
gkronber
Message:

#2478 merged all changes from trunk to branch before trunk-reintegration

Location:
branches/gteufl
Files:
8 edited

Legend:

Unmodified
Added
Removed
  • branches/gteufl

  • branches/gteufl/HeuristicLab.Algorithms.DataAnalysis

  • branches/gteufl/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs

    r9456 r12969  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    3737  /// </summary>
    3838  [Item("Support Vector Classification", "Support vector machine classification data analysis algorithm (wrapper for libSVM).")]
    39   [Creatable("Data Analysis")]
     39  [Creatable(CreatableAttribute.Categories.DataAnalysisClassification, Priority = 110)]
    4040  [StorableClass]
    4141  public sealed class SupportVectorClassification : FixedDataAnalysisAlgorithm<IClassificationProblem> {
     
    4646    private const string GammaParameterName = "Gamma";
    4747    private const string DegreeParameterName = "Degree";
     48    private const string CreateSolutionParameterName = "CreateSolution";
    4849
    4950    #region parameter properties
     
    6566    public IValueParameter<IntValue> DegreeParameter {
    6667      get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
     68    }
     69    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
     70      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    6771    }
    6872    #endregion
     
    8791    public IntValue Degree {
    8892      get { return DegreeParameter.Value; }
     93    }
     94    public bool CreateSolution {
     95      get { return CreateSolutionParameter.Value.Value; }
     96      set { CreateSolutionParameter.Value.Value = value; }
    8997    }
    9098    #endregion
     
    112120      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
    113121      Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     122      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     123      Parameters[CreateSolutionParameterName].Hidden = true;
    114124    }
    115125    [StorableHook(HookType.AfterDeserialization)]
    116126    private void AfterDeserialization() {
    117127      #region backwards compatibility (change with 3.4)
    118       if (!Parameters.ContainsKey(DegreeParameterName))
    119         Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     128      if (!Parameters.ContainsKey(DegreeParameterName)) {
     129        Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName,
     130          "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     131      }
     132      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
     133        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
     134          "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     135        Parameters[CreateSolutionParameterName].Hidden = true;
     136      }
    120137      #endregion
    121138    }
     
    129146      IClassificationProblemData problemData = Problem.ProblemData;
    130147      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
    131       double trainingAccuracy, testAccuracy;
    132148      int nSv;
    133       var solution = CreateSupportVectorClassificationSolution(problemData, selectedInputVariables,
    134         SvmType.Value, KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Degree.Value,
     149      ISupportVectorMachineModel model;
     150
     151      Run(problemData, selectedInputVariables, GetSvmType(SvmType.Value), GetKernelType(KernelType.Value), Cost.Value, Nu.Value, Gamma.Value, Degree.Value, out model, out nSv);
     152
     153      if (CreateSolution) {
     154        var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
     155        Results.Add(new Result("Support vector classification solution", "The support vector classification solution.",
     156          solution));
     157      }
     158
     159      {
     160        // calculate classification metrics (accuracy on the training and test partitions)
     162        var ds = problemData.Dataset;
     163        var trainRows = problemData.TrainingIndices;
     164        var testRows = problemData.TestIndices;
     165        var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows);
     166        var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows);
     167        var yPredTrain = model.GetEstimatedClassValues(ds, trainRows);
     168        var yPredTest = model.GetEstimatedClassValues(ds, testRows);
     169
     170        OnlineCalculatorError error;
     171        var trainAccuracy = OnlineAccuracyCalculator.Calculate(yPredTrain, yTrain, out error);
     172        if (error != OnlineCalculatorError.None) trainAccuracy = double.MaxValue;
     173        var testAccuracy = OnlineAccuracyCalculator.Calculate(yPredTest, yTest, out error);
     174        if (error != OnlineCalculatorError.None) testAccuracy = double.MaxValue;
     175
     176        Results.Add(new Result("Accuracy (training)", "The accuracy of the SVM classification solution on the training partition.", new DoubleValue(trainAccuracy)));
     177        Results.Add(new Result("Accuracy (test)", "The accuracy of the SVM classification solution on the test partition.", new DoubleValue(testAccuracy)));
     178
     179        Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVM classification solution.",
     180          new IntValue(nSv)));
     181      }
     182    }
     183
     184    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
     185      string svmType, string kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
     186      return CreateSupportVectorClassificationSolution(problemData, allowedInputVariables, GetSvmType(svmType), GetKernelType(kernelType), cost, nu, gamma, degree,
    135187        out trainingAccuracy, out testAccuracy, out nSv);
    136 
    137       Results.Add(new Result("Support vector classification solution", "The support vector classification solution.", solution));
    138       Results.Add(new Result("Training accuracy", "The accuracy of the SVR solution on the training partition.", new DoubleValue(trainingAccuracy)));
    139       Results.Add(new Result("Test accuracy", "The accuracy of the SVR solution on the test partition.", new DoubleValue(testAccuracy)));
    140       Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVR solution.", new IntValue(nSv)));
    141     }
    142 
     188    }
     189
     190    // BackwardsCompatibility3.4
     191    #region Backwards compatible code, remove with 3.5
    143192    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
    144       string svmType, string kernelType, double cost, double nu, double gamma, int degree,
    145       out double trainingAccuracy, out double testAccuracy, out int nSv) {
    146       Dataset dataset = problemData.Dataset;
     193      int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
     194
     195      ISupportVectorMachineModel model;
     196      Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, degree, out model, out nSv);
     197      var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone());
     198
     199      trainingAccuracy = solution.TrainingAccuracy;
     200      testAccuracy = solution.TestAccuracy;
     201
     202      return solution;
     203    }
     204
     205    #endregion
     206
     207    public static void Run(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
     208      int svmType, int kernelType, double cost, double nu, double gamma, int degree,
     209      out ISupportVectorMachineModel model, out int nSv) {
     210      var dataset = problemData.Dataset;
    147211      string targetVariable = problemData.TargetVariable;
    148212      IEnumerable<int> rows = problemData.TrainingIndices;
    149213
    150214      //extract SVM parameters from scope and set them
    151       svm_parameter parameter = new svm_parameter();
    152       parameter.svm_type = GetSvmType(svmType);
    153       parameter.kernel_type = GetKernelType(kernelType);
    154       parameter.C = cost;
    155       parameter.nu = nu;
    156       parameter.gamma = gamma;
    157       parameter.cache_size = 500;
    158       parameter.probability = 0;
    159       parameter.eps = 0.001;
    160       parameter.degree = degree;
    161       parameter.shrinking = 1;
    162       parameter.coef0 = 0;
    163 
     215      svm_parameter parameter = new svm_parameter {
     216        svm_type = svmType,
     217        kernel_type = kernelType,
     218        C = cost,
     219        nu = nu,
     220        gamma = gamma,
     221        cache_size = 500,
     222        probability = 0,
     223        eps = 0.001,
     224        degree = degree,
     225        shrinking = 1,
     226        coef0 = 0
     227      };
    164228
    165229      var weightLabels = new List<int>();
     
    178242      parameter.weight = weights.ToArray();
    179243
    180 
    181244      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
    182245      RangeTransform rangeTransform = RangeTransform.Compute(problem);
    183246      svm_problem scaledProblem = rangeTransform.Scale(problem);
    184247      var svmModel = svm.svm_train(scaledProblem, parameter);
    185       var model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);
    186       var solution = new SupportVectorClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    187 
    188248      nSv = svmModel.SV.Length;
    189       trainingAccuracy = solution.TrainingAccuracy;
    190       testAccuracy = solution.TestAccuracy;
    191 
    192       return solution;
     249
     250      model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);
    193251    }
    194252
  • branches/gteufl/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassificationSolution.cs

    r9456 r12969  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
  • branches/gteufl/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineModel.cs

    r9456 r12969  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    120120
    121121    #region IRegressionModel Members
    122     public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
     122    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    123123      return GetEstimatedValuesHelper(dataset, rows);
    124124    }
     
    132132
    133133    #region IClassificationModel Members
    134     public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     134    public IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    135135      if (classValues == null) throw new NotSupportedException();
    136136      // return the original class value instead of the predicted value of the model
     
    159159    }
    160160    #endregion
    161     private IEnumerable<double> GetEstimatedValuesHelper(Dataset dataset, IEnumerable<int> rows) {
     161    private IEnumerable<double> GetEstimatedValuesHelper(IDataset dataset, IEnumerable<int> rows) {
    162162      // calculate predictions for the currently requested rows
    163163      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
  • branches/gteufl/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs

    r9456 r12969  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using System.Linq.Expressions;
     26using System.Threading.Tasks;
     27using HeuristicLab.Common;
     28using HeuristicLab.Core;
     29using HeuristicLab.Data;
    2430using HeuristicLab.Problems.DataAnalysis;
     31using HeuristicLab.Random;
    2532using LibSVM;
    2633
     
    3340    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem</param>
    3441    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
    35     public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
    36       double[] targetVector =
    37         dataset.GetDoubleValues(targetVariable, rowIndices).ToArray();
    38 
     42    public static svm_problem CreateSvmProblem(IDataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
     43      double[] targetVector = dataset.GetDoubleValues(targetVariable, rowIndices).ToArray();
    3944      svm_node[][] nodes = new svm_node[targetVector.Length][];
    40       List<svm_node> tempRow;
    4145      int maxNodeIndex = 0;
    4246      int svmProblemRowIndex = 0;
    4347      List<string> inputVariablesList = inputVariables.ToList();
    4448      foreach (int row in rowIndices) {
    45         tempRow = new List<svm_node>();
     49        List<svm_node> tempRow = new List<svm_node>();
    4650        int colIndex = 1; // make sure the smallest node index for SVM = 1
    4751        foreach (var inputVariable in inputVariablesList) {
     
    5054          // => don't add NaN values in the dataset to the sparse SVM matrix representation
    5155          if (!double.IsNaN(value)) {
    52             tempRow.Add(new svm_node() { index = colIndex, value = value }); // nodes must be sorted in ascending ordered by column index
     56            tempRow.Add(new svm_node() { index = colIndex, value = value });
      57            // nodes must be sorted in ascending order by column index
    5358            if (colIndex > maxNodeIndex) maxNodeIndex = colIndex;
    5459          }
     
    5762        nodes[svmProblemRowIndex++] = tempRow.ToArray();
    5863      }
    59 
    60       return new svm_problem() { l = targetVector.Length, y = targetVector, x = nodes };
     64      return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes };
     65    }
     66
     67    /// <summary>
     68    /// Instantiate and return a svm_parameter object with default values.
     69    /// </summary>
     70    /// <returns>A svm_parameter object with default values</returns>
     71    public static svm_parameter DefaultParameters() {
     72      svm_parameter parameter = new svm_parameter();
     73      parameter.svm_type = svm_parameter.NU_SVR;
     74      parameter.kernel_type = svm_parameter.RBF;
     75      parameter.C = 1;
     76      parameter.nu = 0.5;
     77      parameter.gamma = 1;
     78      parameter.p = 1;
     79      parameter.cache_size = 500;
     80      parameter.probability = 0;
     81      parameter.eps = 0.001;
     82      parameter.degree = 3;
     83      parameter.shrinking = 1;
     84      parameter.coef0 = 0;
     85
     86      return parameter;
     87    }
     88
     89    public static double CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, bool shuffleFolds = true) {
     90      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);
     91      return CalculateCrossValidationPartitions(partitions, parameters);
     92    }
     93
     94    public static svm_parameter GridSearch(out double cvMse, IDataAnalysisProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int numberOfFolds, bool shuffleFolds = true, int maxDegreeOfParallelism = 1) {
     95      DoubleValue mse = new DoubleValue(Double.MaxValue);
     96      var bestParam = DefaultParameters();
     97      var crossProduct = parameterRanges.Values.CartesianProduct();
     98      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     99      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);
     100
     101      var locker = new object(); // for thread synchronization
     102      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism },
     103      parameterCombination => {
     104        var parameters = DefaultParameters();
     105        var parameterValues = parameterCombination.ToList();
     106        for (int i = 0; i < parameterValues.Count; ++i)
     107          setters[i](parameters, parameterValues[i]);
     108
     109        double testMse = CalculateCrossValidationPartitions(partitions, parameters);
     110        if (!double.IsNaN(testMse)) {
     111          lock (locker) {
     112            if (testMse < mse.Value) {
     113              mse.Value = testMse;
     114              bestParam = (svm_parameter)parameters.Clone();
     115            }
     116          }
     117        }
     118      });
     119      cvMse = mse.Value;
     120      return bestParam;
     121    }
     122
     123    private static double CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters) {
     124      double avgTestMse = 0;
     125      var calc = new OnlineMeanSquaredErrorCalculator();
     126      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
     127        var trainingSvmProblem = tuple.Item1;
     128        var testSvmProblem = tuple.Item2;
     129        var model = svm.svm_train(trainingSvmProblem, parameters);
     130        calc.Reset();
     131        for (int i = 0; i < testSvmProblem.l; ++i)
     132          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
     133        double mse = calc.ErrorState == OnlineCalculatorError.None ? calc.MeanSquaredError : double.NaN;
     134        avgTestMse += mse;
     135      }
     136      avgTestMse /= partitions.Length;
     137      return avgTestMse;
     138    }
     139
     140    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
     141      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
     142      var targetVariable = GetTargetVariableName(problemData);
     143      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
     144      for (int i = 0; i < numberOfFolds; ++i) {
     145        int p = i; // avoid "access to modified closure" warning below
     146        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
     147        var testRows = folds[i];
     148        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
     149        var rangeTransform = RangeTransform.Compute(trainingSvmProblem);
     150        var testSvmProblem = rangeTransform.Scale(CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows));
     151        partitions[i] = new Tuple<svm_problem, svm_problem>(rangeTransform.Scale(trainingSvmProblem), testSvmProblem);
     152      }
     153      return partitions;
     154    }
     155
     156    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
     157      var random = new MersenneTwister((uint)Environment.TickCount);
     158      if (problemData is IRegressionProblemData) {
     159        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
     160        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
     161      }
     162      if (problemData is IClassificationProblemData) {
     163        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
     164        // otherwise, generate folds normally
     165        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
     166      }
     167      throw new ArgumentException("Problem data is neither regression or classification problem data.");
     168    }
     169
     170    /// <summary>
     171    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
     172    /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
     173    /// the corresponding parts from each class label.
     174    /// </summary>
     175    /// <param name="problemData">The classification problem data.</param>
     176    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
     177    /// <param name="random">The random generator used to shuffle the folds.</param>
      178    /// <returns>An enumerable sequence of folds, where a fold is represented by a sequence of row indices.</returns>
     179    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
     180      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     181      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
     182      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
     183      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
     184      while (enumerators.All(e => e.MoveNext())) {
     185        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
     186      }
     187    }
     188
     189    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
     190      // if number of folds is greater than the number of values, some empty folds will be returned
     191      if (valuesCount < numberOfFolds) {
     192        for (int i = 0; i < numberOfFolds; ++i)
     193          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
     194      } else {
     195        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder
     196        int start = 0, end = f;
     197        for (int i = 0; i < numberOfFolds; ++i) {
     198          if (r > 0) {
     199            ++end;
     200            --r;
     201          }
     202          yield return values.Skip(start).Take(end - start);
     203          start = end;
     204          end += f;
     205        }
     206      }
     207    }
     208
     209    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
     210      var targetExp = Expression.Parameter(typeof(svm_parameter));
     211      var valueExp = Expression.Parameter(typeof(double));
     212      var fieldExp = Expression.Field(targetExp, fieldName);
     213      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
     214      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
     215      return setter;
     216    }
     217
     218    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
     219      var regressionProblemData = problemData as IRegressionProblemData;
     220      var classificationProblemData = problemData as IClassificationProblemData;
     221
     222      if (regressionProblemData != null)
     223        return regressionProblemData.TargetVariable;
     224      if (classificationProblemData != null)
     225        return classificationProblemData.TargetVariable;
     226
     227      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    61228    }
    62229  }
  • branches/gteufl/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorRegression.cs

    r9456 r12969  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    3737  /// </summary>
    3838  [Item("Support Vector Regression", "Support vector machine regression data analysis algorithm (wrapper for libSVM).")]
    39   [Creatable("Data Analysis")]
     39  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 110)]
    4040  [StorableClass]
    4141  public sealed class SupportVectorRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
     
    4747    private const string EpsilonParameterName = "Epsilon";
    4848    private const string DegreeParameterName = "Degree";
     49    private const string CreateSolutionParameterName = "CreateSolution";
    4950
    5051    #region parameter properties
     
    6970    public IValueParameter<IntValue> DegreeParameter {
    7071      get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
     72    }
     73    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
     74      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    7175    }
    7276    #endregion
     
    9498    public IntValue Degree {
    9599      get { return DegreeParameter.Value; }
     100    }
     101    public bool CreateSolution {
     102      get { return CreateSolutionParameter.Value.Value; }
     103      set { CreateSolutionParameter.Value.Value = value; }
    96104    }
    97105    #endregion
     
    120128      Parameters.Add(new ValueParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR.", new DoubleValue(0.1)));
    121129      Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     130      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     131      Parameters[CreateSolutionParameterName].Hidden = true;
    122132    }
    123133    [StorableHook(HookType.AfterDeserialization)]
    124134    private void AfterDeserialization() {
    125135      #region backwards compatibility (change with 3.4)
    126       if (!Parameters.ContainsKey(DegreeParameterName))
    127         Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     136
     137      if (!Parameters.ContainsKey(DegreeParameterName)) {
     138        Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName,
     139          "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     140      }
     141      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
     142        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     143        Parameters[CreateSolutionParameterName].Hidden = true;
     144      }
    128145      #endregion
    129146    }
     
    137154      IRegressionProblemData problemData = Problem.ProblemData;
    138155      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
    139       double trainR2, testR2;
    140156      int nSv;
    141       var solution = CreateSupportVectorRegressionSolution(problemData, selectedInputVariables, SvmType.Value,
    142         KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, Degree.Value,
    143         out trainR2, out testR2, out nSv);
    144 
    145       Results.Add(new Result("Support vector regression solution", "The support vector regression solution.", solution));
    146       Results.Add(new Result("Training R²", "The Pearson's R² of the SVR solution on the training partition.", new DoubleValue(trainR2)));
    147       Results.Add(new Result("Test R²", "The Pearson's R² of the SVR solution on the test partition.", new DoubleValue(testR2)));
     157      ISupportVectorMachineModel model;
     158      Run(problemData, selectedInputVariables, SvmType.Value, KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, Degree.Value, out model, out nSv);
     159
     160      if (CreateSolution) {
     161        var solution = new SupportVectorRegressionSolution((SupportVectorMachineModel)model, (IRegressionProblemData)problemData.Clone());
     162        Results.Add(new Result("Support vector regression solution", "The support vector regression solution.", solution));
     163      }
     164
    148165      Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVR solution.", new IntValue(nSv)));
    149     }
    150 
    151     public static SupportVectorRegressionSolution CreateSupportVectorRegressionSolution(IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
     166
     167
     168      {
     169        // calculate regression model metrics
     170        var ds = problemData.Dataset;
     171        var trainRows = problemData.TrainingIndices;
     172        var testRows = problemData.TestIndices;
     173        var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows);
     174        var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows);
     175        var yPredTrain = model.GetEstimatedValues(ds, trainRows).ToArray();
     176        var yPredTest = model.GetEstimatedValues(ds, testRows).ToArray();
     177
     178        OnlineCalculatorError error;
     179        var trainMse = OnlineMeanSquaredErrorCalculator.Calculate(yPredTrain, yTrain, out error);
     180        if (error != OnlineCalculatorError.None) trainMse = double.MaxValue;
     181        var testMse = OnlineMeanSquaredErrorCalculator.Calculate(yPredTest, yTest, out error);
     182        if (error != OnlineCalculatorError.None) testMse = double.MaxValue;
     183
     184        Results.Add(new Result("Mean squared error (training)", "The mean of squared errors of the SVR solution on the training partition.", new DoubleValue(trainMse)));
     185        Results.Add(new Result("Mean squared error (test)", "The mean of squared errors of the SVR solution on the test partition.", new DoubleValue(testMse)));
     186
     187
     188        var trainMae = OnlineMeanAbsoluteErrorCalculator.Calculate(yPredTrain, yTrain, out error);
     189        if (error != OnlineCalculatorError.None) trainMae = double.MaxValue;
     190        var testMae = OnlineMeanAbsoluteErrorCalculator.Calculate(yPredTest, yTest, out error);
     191        if (error != OnlineCalculatorError.None) testMae = double.MaxValue;
     192
     193        Results.Add(new Result("Mean absolute error (training)", "The mean of absolute errors of the SVR solution on the training partition.", new DoubleValue(trainMae)));
     194        Results.Add(new Result("Mean absolute error (test)", "The mean of absolute errors of the SVR solution on the test partition.", new DoubleValue(testMae)));
     195
     196
     197        var trainRelErr = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(yPredTrain, yTrain, out error);
     198        if (error != OnlineCalculatorError.None) trainRelErr = double.MaxValue;
     199        var testRelErr = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(yPredTest, yTest, out error);
     200        if (error != OnlineCalculatorError.None) testRelErr = double.MaxValue;
     201
     202        Results.Add(new Result("Average relative error (training)", "The mean of relative errors of the SVR solution on the training partition.", new DoubleValue(trainRelErr)));
     203        Results.Add(new Result("Average relative error (test)", "The mean of relative errors of the SVR solution on the test partition.", new DoubleValue(testRelErr)));
     204      }
     205    }
     206
     207    // BackwardsCompatibility3.4
     208    #region Backwards compatible code, remove with 3.5
     209    // for compatibility with old API
     210    public static SupportVectorRegressionSolution CreateSupportVectorRegressionSolution(
     211      IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
    152212      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon, int degree,
    153213      out double trainingR2, out double testR2, out int nSv) {
    154       Dataset dataset = problemData.Dataset;
     214      ISupportVectorMachineModel model;
     215      Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, epsilon, degree, out model, out nSv);
     216
     217      var solution = new SupportVectorRegressionSolution((SupportVectorMachineModel)model, (IRegressionProblemData)problemData.Clone());
     218      trainingR2 = solution.TrainingRSquared;
     219      testR2 = solution.TestRSquared;
     220      return solution;
     221    }
     222    #endregion
     223
     224    public static void Run(IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
     225      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon, int degree,
     226      out ISupportVectorMachineModel model, out int nSv) {
     227      var dataset = problemData.Dataset;
    155228      string targetVariable = problemData.TargetVariable;
    156229      IEnumerable<int> rows = problemData.TrainingIndices;
    157230
    158231      //extract SVM parameters from scope and set them
    159       svm_parameter parameter = new svm_parameter();
    160       parameter.svm_type = GetSvmType(svmType);
    161       parameter.kernel_type = GetKernelType(kernelType);
    162       parameter.C = cost;
    163       parameter.nu = nu;
    164       parameter.gamma = gamma;
    165       parameter.p = epsilon;
    166       parameter.cache_size = 500;
    167       parameter.probability = 0;
    168       parameter.eps = 0.001;
    169       parameter.degree = degree;
    170       parameter.shrinking = 1;
    171       parameter.coef0 = 0;
    172 
    173 
     232      svm_parameter parameter = new svm_parameter {
     233        svm_type = GetSvmType(svmType),
     234        kernel_type = GetKernelType(kernelType),
     235        C = cost,
     236        nu = nu,
     237        gamma = gamma,
     238        p = epsilon,
     239        cache_size = 500,
     240        probability = 0,
     241        eps = 0.001,
     242        degree = degree,
     243        shrinking = 1,
     244        coef0 = 0
     245      };
    174246
    175247      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
     
    178250      var svmModel = svm.svm_train(scaledProblem, parameter);
    179251      nSv = svmModel.SV.Length;
    180       var model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables);
    181       var solution = new SupportVectorRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    182       trainingR2 = solution.TrainingRSquared;
    183       testR2 = solution.TestRSquared;
    184       return solution;
     252
     253      model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables);
    185254    }
    186255
  • branches/gteufl/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorRegressionSolution.cs

    r9456 r12969  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
Note: See TracChangeset for help on using the changeset viewer.