Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs @ 11337

Last change on this file since 11337 was 11337, checked in by bburlacu, 10 years ago

#2234: Refactored SVM grid search, added support for symbolic classification.

File size: 8.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Linq.Expressions;
26using System.Threading.Tasks;
27using HeuristicLab.Common;
28using HeuristicLab.Data;
29using HeuristicLab.Problems.DataAnalysis;
30using LibSVM;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Static helpers for converting HeuristicLab data-analysis problems into libSVM data structures,
  /// cross-validating libSVM models and grid-searching libSVM parameters.
  /// </summary>
  public class SupportVectorMachineUtil {
    /// <summary>
    /// Transforms the given rows of <paramref name="dataset"/> into a data structure as needed by libSVM.
    /// </summary>
    /// <remarks>
    /// Input values that are NaN are left out of the sparse row representation (treated as missing).
    /// Node indices are 1-based and follow the order of <paramref name="inputVariables"/>.
    /// </remarks>
    /// <param name="dataset">The dataset providing target and input values.</param>
    /// <param name="targetVariable">The variable whose values become the SVM labels (<c>y</c>).</param>
    /// <param name="inputVariables">The variables whose values become the SVM features (<c>x</c>), in column order.</param>
    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem.</param>
    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
    public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
      // Materialize both sequences exactly once: callers pass lazily evaluated LINQ queries, and the
      // rows are needed twice (once for the targets, once for the inputs).
      int[] rows = rowIndices.ToArray();
      string[] variables = inputVariables.ToArray();
      double[] targetVector = dataset.GetDoubleValues(targetVariable, rows).ToArray();

      var nodes = new svm_node[rows.Length][];
      var rowNodes = new List<svm_node>(variables.Length);
      for (int r = 0; r < rows.Length; ++r) {
        rowNodes.Clear();
        for (int c = 0; c < variables.Length; ++c) {
          double value = dataset.GetDoubleValue(variables[c], rows[r]);
          // SVM also works with missing values
          // => don't add NaN values in the dataset to the sparse SVM matrix representation
          if (double.IsNaN(value)) continue;
          // libSVM node indices start at 1; nodes are appended in ascending index order as libSVM requires.
          rowNodes.Add(new svm_node { index = c + 1, value = value });
        }
        nodes[r] = rowNodes.ToArray();
      }
      return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes };
    }

    /// <summary>
    /// Instantiate and return a svm_parameter object with default values
    /// (svm_type = NU_SVR, kernel_type = RBF, C = 1, nu = 0.5, gamma = 1, p = 1, cache_size = 500,
    /// probability = 0, eps = 0.001, degree = 3, shrinking = 1, coef0 = 0).
    /// </summary>
    /// <returns>A new svm_parameter object with default values.</returns>
    public static svm_parameter DefaultParameters() {
      svm_parameter parameter = new svm_parameter();
      parameter.svm_type = svm_parameter.NU_SVR;
      parameter.kernel_type = svm_parameter.RBF;
      parameter.C = 1;
      parameter.nu = 0.5;
      parameter.gamma = 1;
      parameter.p = 1;
      parameter.cache_size = 500;
      parameter.probability = 0;
      parameter.eps = 0.001;
      parameter.degree = 3;
      parameter.shrinking = 1;
      parameter.coef0 = 0;

      return parameter;
    }

    /// <summary>
    /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation).
    /// The training indices are split, in their original order, into <paramref name="numberOfFolds"/>
    /// consecutive non-overlapping folds whose sizes differ by at most one: the first
    /// (count % numberOfFolds) folds receive one extra row.
    /// </summary>
    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
    /// <param name="problemData">The problem data whose training indices are partitioned.</param>
    /// <param name="numberOfFolds">The number of folds to generate (must be at least 1).</param>
    /// <returns>A sequence of folds representing each a sequence of row numbers.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numberOfFolds"/> is less than 1.</exception>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
      if (numberOfFolds < 1)
        throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be at least 1.");

      // Enumerate the training indices once and size the folds from the rows actually returned, so the
      // folds always cover exactly TrainingIndices. NOTE(review): the previous implementation used
      // TrainingPartition.Size, which need not equal the number of TrainingIndices.
      int[] trainingRows = problemData.TrainingIndices.ToArray();
      int foldSize = trainingRows.Length / numberOfFolds;
      int remainder = trainingRows.Length % numberOfFolds;

      var folds = new List<int[]>(numberOfFolds);
      int start = 0;
      for (int i = 0; i < numberOfFolds; ++i) {
        int length = foldSize + (i < remainder ? 1 : 0);
        var fold = new int[length];
        Array.Copy(trainingRows, start, fold, 0, length);
        folds.Add(fold);
        start += length;
      }
      return folds;
    }

    /// <summary>
    /// Builds one (training, test) pair of libSVM problems per fold: fold i is the test set and the
    /// union of all other folds is the training set.
    /// </summary>
    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
      var folds = GenerateFolds(problemData, numberOfFolds).ToList();
      var targetVariable = GetTargetVariableName(problemData);
      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
      for (int i = 0; i < numberOfFolds; ++i) {
        int p = i; // avoid "access to modified closure" warning below
        var trainingRows = folds.Where((fold, j) => j != p).SelectMany(fold => fold);
        var testRows = folds[i];
        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
        var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
        partitions[i] = Tuple.Create(trainingSvmProblem, testSvmProblem);
      }
      return partitions;
    }

    /// <summary>
    /// Performs <paramref name="numberOfFolds"/>-fold cross-validation of an SVM with the given parameters
    /// on the training partition of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The problem data (regression or classification).</param>
    /// <param name="parameters">The libSVM parameters used to train each fold's model.</param>
    /// <param name="numberOfFolds">The number of folds.</param>
    /// <param name="avgTestMse">The mean over all folds of the mean squared error on each fold's test rows.</param>
    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
      CrossValidate(problemData, parameters, partitions, out avgTestMse);
    }

    /// <summary>
    /// Cross-validates an SVM on precomputed (training, test) partitions: for each partition a model is
    /// trained on <c>Item1</c> and its mean squared error is measured on <c>Item2</c>.
    /// </summary>
    /// <param name="problemData">Not used by this overload; kept for signature compatibility.</param>
    /// <param name="parameters">The libSVM parameters used to train each partition's model.</param>
    /// <param name="partitions">The (training, test) problem pairs.</param>
    /// <param name="avgTestMse">The mean of the per-partition test MSEs (NaN if <paramref name="partitions"/> is empty).</param>
    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, Tuple<svm_problem, svm_problem>[] partitions, out double avgTestMse) {
      avgTestMse = 0;
      var calc = new OnlineMeanSquaredErrorCalculator();
      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
        var trainingSvmProblem = tuple.Item1;
        var testSvmProblem = tuple.Item2;
        var model = svm.svm_train(trainingSvmProblem, parameters);
        calc.Reset();
        for (int i = 0; i < testSvmProblem.l; ++i)
          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
        avgTestMse += calc.MeanSquaredError;
      }
      avgTestMse /= partitions.Length;
    }

    /// <summary>
    /// Compiles a delegate that assigns a double to the field <paramref name="fieldName"/> of an
    /// svm_parameter. The value goes through an explicit conversion to the field's type, so fractional
    /// values are truncated for integer-typed fields.
    /// </summary>
    /// <param name="fieldName">The name of an instance field of svm_parameter.</param>
    /// <exception cref="ArgumentException">svm_parameter has no such field (thrown by Expression.Field).</exception>
    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
      var targetExp = Expression.Parameter(typeof(svm_parameter));
      var valueExp = Expression.Parameter(typeof(double));
      var fieldExp = Expression.Field(targetExp, fieldName);
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }

    /// <summary>
    /// Exhaustive grid search: every combination in the Cartesian product of the value ranges in
    /// <paramref name="parameterRanges"/> is applied on top of <see cref="DefaultParameters"/> and evaluated
    /// with <paramref name="numberOfFolds"/>-fold cross-validation.
    /// </summary>
    /// <param name="problemData">The problem data (regression or classification).</param>
    /// <param name="numberOfFolds">The number of cross-validation folds.</param>
    /// <param name="parameterRanges">Maps svm_parameter field names to the candidate values for that field.</param>
    /// <param name="maxDegreeOfParallelism">The maximum number of combinations evaluated concurrently.</param>
    /// <returns>
    /// A copy of the parameter set with the lowest average test MSE, or the default parameters if no
    /// evaluated combination achieved an MSE below <see cref="double.MaxValue"/>.
    /// </returns>
    public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
      var bestParam = DefaultParameters();
      double bestMse = double.MaxValue;
      // A single dedicated lock makes the compare-and-update of (bestMse, bestParam) atomic; locking on
      // the reassigned bestParam reference would let different threads hold different locks.
      var bestLock = new object();
      var pNames = parameterRanges.Keys.ToList();
      var pRanges = pNames.Select(x => parameterRanges[x]);
      var crossProduct = pRanges.CartesianProduct();
      var setters = pNames.Select(GenerateSetter).ToList();
      // Partitions are built once and shared (read-only) by all parallel evaluations.
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
        var values = nuple.ToList();
        var parameters = DefaultParameters();
        for (int i = 0; i < setters.Count; ++i)
          setters[i](parameters, values[i]);
        double testMse;
        CrossValidate(problemData, parameters, partitions, out testMse);
        lock (bestLock) {
          if (testMse < bestMse) {
            bestMse = testMse;
            bestParam = (svm_parameter)parameters.Clone();
          }
        }
      });
      return bestParam;
    }

    /// <summary>
    /// Returns the target variable name of regression or classification problem data.
    /// </summary>
    /// <exception cref="ArgumentException">The problem data is neither regression nor classification problem data.</exception>
    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;

      if (regressionProblemData != null)
        return regressionProblemData.TargetVariable;
      if (classificationProblemData != null)
        return classificationProblemData.TargetVariable;

      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.