Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs @ 11342

Last change on this file since 11342 was 11342, checked in by mkommend, 10 years ago

#2234: Corrected locking in SVM Util class.

File size: 8.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Linq.Expressions;
26using System.Threading.Tasks;
27using HeuristicLab.Common;
28using HeuristicLab.Data;
29using HeuristicLab.Problems.DataAnalysis;
30using LibSVM;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Helper routines that convert datasets into libSVM problems and that evaluate
  /// or tune SVM parameters by means of cross-validation.
  /// </summary>
  public class SupportVectorMachineUtil {
    /// <summary>
    /// Transforms the selected rows of <paramref name="dataset"/> into a data structure as needed by libSVM.
    /// </summary>
    /// <param name="dataset">The dataset the values are read from</param>
    /// <param name="targetVariable">The name of the variable that is used as target vector</param>
    /// <param name="inputVariables">The names of the input variables; the position of a variable in this sequence (starting at 1) becomes its node index</param>
    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem</param>
    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
    public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
      // Materialize both sequences once: the rows are needed for the target vector as well as
      // for the feature matrix, and callers may hand in lazily evaluated queries.
      List<int> rows = rowIndices.ToList();
      List<string> inputVariablesList = inputVariables.ToList();

      double[] targetVector = dataset.GetDoubleValues(targetVariable, rows).ToArray();
      svm_node[][] nodes = new svm_node[targetVector.Length][];
      int svmProblemRowIndex = 0;
      foreach (int row in rows) {
        List<svm_node> tempRow = new List<svm_node>(inputVariablesList.Count);
        int colIndex = 1; // make sure the smallest node index for SVM = 1
        foreach (var inputVariable in inputVariablesList) {
          double value = dataset.GetDoubleValue(inputVariable, row);
          // SVM also works with missing values
          // => don't add NaN values in the dataset to the sparse SVM matrix representation
          if (!double.IsNaN(value)) {
            // nodes must be sorted in ascending order by column index
            tempRow.Add(new svm_node() { index = colIndex, value = value });
          }
          colIndex++; // skipped (NaN) columns still consume their index
        }
        nodes[svmProblemRowIndex++] = tempRow.ToArray();
      }
      return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes };
    }

    /// <summary>
    /// Instantiate and return a svm_parameter object with default values
    /// (svm_type NU_SVR, kernel_type RBF).
    /// </summary>
    /// <returns>A svm_parameter object with default values</returns>
    public static svm_parameter DefaultParameters() {
      svm_parameter parameter = new svm_parameter();
      parameter.svm_type = svm_parameter.NU_SVR;
      parameter.kernel_type = svm_parameter.RBF;
      parameter.C = 1;
      parameter.nu = 0.5;
      parameter.gamma = 1;
      parameter.p = 1;
      parameter.cache_size = 500;
      parameter.probability = 0;
      parameter.eps = 0.001;
      parameter.degree = 3;
      parameter.shrinking = 1;
      parameter.coef0 = 0;

      return parameter;
    }

    /// <summary>
    /// Performs a cross-validation of the given SVM parameters on the training rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The problem data that provides dataset, variables and training rows</param>
    /// <param name="parameters">The SVM parameters to evaluate</param>
    /// <param name="numberOfFolds">The number of folds; must be greater than zero</param>
    /// <param name="avgTestMse">Receives the mean squared error on the test folds, averaged over all folds</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="numberOfFolds"/> is not positive.</exception>
    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
      CalculateCrossValidationPartitions(partitions, parameters, out avgTestMse);
    }

    /// <summary>
    /// Evaluates every combination of the given parameter values by cross-validation and
    /// returns the combination with the smallest average test MSE.
    /// </summary>
    /// <param name="problemData">The problem data that provides dataset, variables and training rows</param>
    /// <param name="numberOfFolds">The number of cross-validation folds; must be greater than zero</param>
    /// <param name="parameterRanges">Maps svm_parameter field names to the values that should be tried for that field</param>
    /// <param name="maxDegreeOfParallelism">The maximum number of parameter combinations that are evaluated concurrently</param>
    /// <returns>A copy of the best parameter set found, or the default parameters if no combination yielded an MSE below <see cref="Double.MaxValue"/>.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="numberOfFolds"/> is not positive.</exception>
    public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
      double bestMse = Double.MaxValue;
      var bestParam = DefaultParameters();
      object bestLock = new object(); // guards bestMse and bestParam
      var crossProduct = parameterRanges.Values.CartesianProduct();
      // Keys and Values of a dictionary are enumerated in the same order, so setters[i] belongs to the i-th value of a combination
      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
        var parameters = DefaultParameters();
        var parameterValues = parameterCombination.ToList();
        for (int i = 0; i < parameterValues.Count; ++i) {
          setters[i](parameters, parameterValues[i]);
        }
        double testMse;
        CalculateCrossValidationPartitions(partitions, parameters, out testMse);
        // The comparison has to happen inside the lock: otherwise two workers can both pass the
        // check against a stale value and the worse of the two may overwrite the better result.
        lock (bestLock) {
          if (testMse < bestMse) {
            bestMse = testMse;
            bestParam = (svm_parameter)parameters.Clone();
          }
        }
      });
      return bestParam;
    }

    /// <summary>
    /// Trains a model on the training part of every partition, evaluates it on the
    /// corresponding test part and averages the resulting mean squared errors.
    /// </summary>
    /// <param name="partitions">Pairs of (training problem, test problem)</param>
    /// <param name="parameters">The SVM parameters used for training</param>
    /// <param name="avgTestMse">Receives the test MSE averaged over all partitions</param>
    private static void CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters, out double avgTestMse) {
      avgTestMse = 0;
      var calc = new OnlineMeanSquaredErrorCalculator();
      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
        var trainingSvmProblem = tuple.Item1;
        var testSvmProblem = tuple.Item2;
        var model = svm.svm_train(trainingSvmProblem, parameters);
        calc.Reset(); // the calculator is reused, each fold gets its own MSE
        for (int i = 0; i < testSvmProblem.l; ++i)
          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
        avgTestMse += calc.MeanSquaredError;
      }
      avgTestMse /= partitions.Length;
    }


    /// <summary>
    /// Splits the training rows into folds and creates one (training problem, test problem)
    /// pair per fold, where the fold itself is the test part and all other folds form the training part.
    /// </summary>
    /// <param name="problemData">The problem data that provides dataset, variables and training rows</param>
    /// <param name="numberOfFolds">The number of folds; must be greater than zero</param>
    /// <returns>An array with <paramref name="numberOfFolds"/> partitions</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="numberOfFolds"/> is not positive.</exception>
    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
      // fail early with a descriptive error instead of a division by zero / invalid array size further down
      if (numberOfFolds <= 0)
        throw new ArgumentOutOfRangeException("numberOfFolds", numberOfFolds, "The number of folds must be greater than zero.");

      // materialize every fold once; each fold is reused by all partitions
      var folds = GenerateFolds(problemData, numberOfFolds).Select(fold => fold.ToArray()).ToList();
      var targetVariable = GetTargetVariableName(problemData);
      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
      for (int i = 0; i < numberOfFolds; ++i) {
        int testFold = i; // avoid "access to modified closure" warning below
        var trainingRows = folds.Where((fold, j) => j != testFold).SelectMany(fold => fold);
        var testRows = folds[i];
        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
        var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
        partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem);
      }
      return partitions;
    }

    /// <summary>
    /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
    /// </summary>
    /// <remarks>
    /// This method is aimed to be lightweight and as such does not clone the dataset.
    /// The folds are evaluated lazily; <paramref name="numberOfFolds"/> must be validated by the caller.
    /// </remarks>
    /// <param name="problemData">The problem data</param>
    /// <param name="numberOfFolds">The number of folds to generate</param>
    /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
    private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
      int size = problemData.TrainingPartition.Size;
      int f = size / numberOfFolds, r = size % numberOfFolds; // base fold size (rounded down) and number of leftover rows
      int start = 0, end = f;
      for (int i = 0; i < numberOfFolds; ++i) {
        if (r > 0) { ++end; --r; } // the first r folds take one leftover row each
        yield return problemData.TrainingIndices.Skip(start).Take(end - start);
        start = end;
        end += f;
      }
    }

    /// <summary>
    /// Compiles a delegate that assigns a double value (converted to the field's type)
    /// to the field <paramref name="fieldName"/> of a svm_parameter object.
    /// </summary>
    /// <param name="fieldName">The name of the svm_parameter field to set</param>
    /// <returns>A setter delegate for the given field</returns>
    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
      var targetExp = Expression.Parameter(typeof(svm_parameter));
      var valueExp = Expression.Parameter(typeof(double));
      var fieldExp = Expression.Field(targetExp, fieldName);
      // the grid values are doubles, but some fields have a different type => convert
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }

    /// <summary>
    /// Returns the name of the target variable of regression or classification problem data.
    /// </summary>
    /// <param name="problemData">The problem data</param>
    /// <returns>The name of the target variable</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="problemData"/> is neither regression nor classification problem data.</exception>
    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;

      if (regressionProblemData != null)
        return regressionProblemData.TargetVariable;
      if (classificationProblemData != null)
        return classificationProblemData.TargetVariable;

      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }

  }
}
Note: See TracBrowser for help on using the repository browser.