Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/29/15 18:39:53 (8 years ago)
Author:
gkronber
Message:

#1998: merged changesets r10551:13084 (only on HeuristicLab.Algorithms.DataAnalysis) from trunk to branch

Location:
branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis

  • branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorRegression.cs

    r10556 r13085  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    3737  /// </summary>
    3838  [Item("Support Vector Regression", "Support vector machine regression data analysis algorithm (wrapper for libSVM).")]
    39   [Creatable("Data Analysis")]
     39  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 110)]
    4040  [StorableClass]
    4141  public sealed class SupportVectorRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
     
    4747    private const string EpsilonParameterName = "Epsilon";
    4848    private const string DegreeParameterName = "Degree";
     49    private const string CreateSolutionParameterName = "CreateSolution";
    4950
    5051    #region parameter properties
     
    6970    public IValueParameter<IntValue> DegreeParameter {
    7071      get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; }
     72    }
     73    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
     74      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    7175    }
    7276    #endregion
     
    9498    public IntValue Degree {
    9599      get { return DegreeParameter.Value; }
     100    }
     101    public bool CreateSolution {
     102      get { return CreateSolutionParameter.Value.Value; }
     103      set { CreateSolutionParameter.Value.Value = value; }
    96104    }
    97105    #endregion
     
    120128      Parameters.Add(new ValueParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR.", new DoubleValue(0.1)));
    121129      Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     130      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     131      Parameters[CreateSolutionParameterName].Hidden = true;
    122132    }
    123133    [StorableHook(HookType.AfterDeserialization)]
    124134    private void AfterDeserialization() {
    125135      #region backwards compatibility (change with 3.4)
    126       if (!Parameters.ContainsKey(DegreeParameterName))
    127         Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     136
     137      if (!Parameters.ContainsKey(DegreeParameterName)) {
     138        Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName,
     139          "The degree parameter for the polynomial kernel function.", new IntValue(3)));
     140      }
     141      if (!Parameters.ContainsKey(CreateSolutionParameterName)) {
     142        Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
     143        Parameters[CreateSolutionParameterName].Hidden = true;
     144      }
    128145      #endregion
    129146    }
     
    137154      IRegressionProblemData problemData = Problem.ProblemData;
    138155      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
    139       double trainR2, testR2;
    140156      int nSv;
    141       var solution = CreateSupportVectorRegressionSolution(problemData, selectedInputVariables, SvmType.Value,
    142         KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, Degree.Value,
    143         out trainR2, out testR2, out nSv);
    144 
    145       Results.Add(new Result("Support vector regression solution", "The support vector regression solution.", solution));
    146       Results.Add(new Result("Training R²", "The Pearson's R² of the SVR solution on the training partition.", new DoubleValue(trainR2)));
    147       Results.Add(new Result("Test R²", "The Pearson's R² of the SVR solution on the test partition.", new DoubleValue(testR2)));
     157      ISupportVectorMachineModel model;
     158      Run(problemData, selectedInputVariables, SvmType.Value, KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, Degree.Value, out model, out nSv);
     159
     160      if (CreateSolution) {
     161        var solution = new SupportVectorRegressionSolution((SupportVectorMachineModel)model, (IRegressionProblemData)problemData.Clone());
     162        Results.Add(new Result("Support vector regression solution", "The support vector regression solution.", solution));
     163      }
     164
    148165      Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVR solution.", new IntValue(nSv)));
    149     }
    150 
    151     public static SupportVectorRegressionSolution CreateSupportVectorRegressionSolution(IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
     166
     167
     168      {
     169        // calculate regression model metrics
     170        var ds = problemData.Dataset;
     171        var trainRows = problemData.TrainingIndices;
     172        var testRows = problemData.TestIndices;
     173        var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows);
     174        var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows);
     175        var yPredTrain = model.GetEstimatedValues(ds, trainRows).ToArray();
     176        var yPredTest = model.GetEstimatedValues(ds, testRows).ToArray();
     177
     178        OnlineCalculatorError error;
     179        var trainMse = OnlineMeanSquaredErrorCalculator.Calculate(yPredTrain, yTrain, out error);
     180        if (error != OnlineCalculatorError.None) trainMse = double.MaxValue;
     181        var testMse = OnlineMeanSquaredErrorCalculator.Calculate(yPredTest, yTest, out error);
     182        if (error != OnlineCalculatorError.None) testMse = double.MaxValue;
     183
     184        Results.Add(new Result("Mean squared error (training)", "The mean of squared errors of the SVR solution on the training partition.", new DoubleValue(trainMse)));
     185        Results.Add(new Result("Mean squared error (test)", "The mean of squared errors of the SVR solution on the test partition.", new DoubleValue(testMse)));
     186
     187
     188        var trainMae = OnlineMeanAbsoluteErrorCalculator.Calculate(yPredTrain, yTrain, out error);
     189        if (error != OnlineCalculatorError.None) trainMae = double.MaxValue;
     190        var testMae = OnlineMeanAbsoluteErrorCalculator.Calculate(yPredTest, yTest, out error);
     191        if (error != OnlineCalculatorError.None) testMae = double.MaxValue;
     192
     193        Results.Add(new Result("Mean absolute error (training)", "The mean of absolute errors of the SVR solution on the training partition.", new DoubleValue(trainMae)));
     194        Results.Add(new Result("Mean absolute error (test)", "The mean of absolute errors of the SVR solution on the test partition.", new DoubleValue(testMae)));
     195
     196
     197        var trainRelErr = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(yPredTrain, yTrain, out error);
     198        if (error != OnlineCalculatorError.None) trainRelErr = double.MaxValue;
     199        var testRelErr = OnlineMeanAbsolutePercentageErrorCalculator.Calculate(yPredTest, yTest, out error);
     200        if (error != OnlineCalculatorError.None) testRelErr = double.MaxValue;
     201
     202        Results.Add(new Result("Average relative error (training)", "The mean of relative errors of the SVR solution on the training partition.", new DoubleValue(trainRelErr)));
     203        Results.Add(new Result("Average relative error (test)", "The mean of relative errors of the SVR solution on the test partition.", new DoubleValue(testRelErr)));
     204      }
     205    }
     206
     207    // BackwardsCompatibility3.4
     208    #region Backwards compatible code, remove with 3.5
     209    // for compatibility with old API
     210    public static SupportVectorRegressionSolution CreateSupportVectorRegressionSolution(
     211      IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
    152212      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon, int degree,
    153213      out double trainingR2, out double testR2, out int nSv) {
    154       Dataset dataset = problemData.Dataset;
     214      ISupportVectorMachineModel model;
     215      Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, epsilon, degree, out model, out nSv);
     216
     217      var solution = new SupportVectorRegressionSolution((SupportVectorMachineModel)model, (IRegressionProblemData)problemData.Clone());
     218      trainingR2 = solution.TrainingRSquared;
     219      testR2 = solution.TestRSquared;
     220      return solution;
     221    }
     222    #endregion
     223
     224    public static void Run(IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
     225      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon, int degree,
     226      out ISupportVectorMachineModel model, out int nSv) {
     227      var dataset = problemData.Dataset;
    155228      string targetVariable = problemData.TargetVariable;
    156229      IEnumerable<int> rows = problemData.TrainingIndices;
    157230
    158231      //extract SVM parameters from scope and set them
    159       svm_parameter parameter = new svm_parameter();
    160       parameter.svm_type = GetSvmType(svmType);
    161       parameter.kernel_type = GetKernelType(kernelType);
    162       parameter.C = cost;
    163       parameter.nu = nu;
    164       parameter.gamma = gamma;
    165       parameter.p = epsilon;
    166       parameter.cache_size = 500;
    167       parameter.probability = 0;
    168       parameter.eps = 0.001;
    169       parameter.degree = degree;
    170       parameter.shrinking = 1;
    171       parameter.coef0 = 0;
    172 
    173 
     232      svm_parameter parameter = new svm_parameter {
     233        svm_type = GetSvmType(svmType),
     234        kernel_type = GetKernelType(kernelType),
     235        C = cost,
     236        nu = nu,
     237        gamma = gamma,
     238        p = epsilon,
     239        cache_size = 500,
     240        probability = 0,
     241        eps = 0.001,
     242        degree = degree,
     243        shrinking = 1,
     244        coef0 = 0
     245      };
    174246
    175247      svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows);
     
    178250      var svmModel = svm.svm_train(scaledProblem, parameter);
    179251      nSv = svmModel.SV.Length;
    180       var model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables);
    181       var solution = new SupportVectorRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    182       trainingR2 = solution.TrainingRSquared;
    183       testR2 = solution.TestRSquared;
    184       return solution;
     252
     253      model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables);
    185254    }
    186255
Note: See TracChangeset for help on using the changeset viewer.