source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs @ 11308

Last change on this file since 11308 was 11308, checked in by bburlacu, 6 years ago

#2234: Implemented SVM grid search in SupportVectorMachineUtil.cs.

File size: 8.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Linq.Expressions;
26using System.Threading.Tasks;
27using HeuristicLab.Common;
28using HeuristicLab.Data;
29using HeuristicLab.Problems.DataAnalysis;
30using LibSVM;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Helper methods for building libSVM problems, cross-validating svm parameters and
  /// performing a parallel grid search over svm parameter combinations.
  /// </summary>
  public class SupportVectorMachineUtil {
    /// <summary>
    /// Transforms the given rows of <paramref name="dataset"/> into a data structure as needed by libSVM.
    /// </summary>
    /// <param name="dataset">The dataset to read values from</param>
    /// <param name="targetVariable">The name of the variable whose values become the svm targets (svm_problem.y)</param>
    /// <param name="inputVariables">The input variables; the k-th variable (0-based) is stored with svm node index k + 1</param>
    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem</param>
    /// <returns>
    /// A problem data type that can be used to train a support vector machine.
    /// Entry i of x and y corresponds to the i-th element of <paramref name="rowIndices"/>.
    /// </returns>
    public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
      // materialize once so that targets and inputs are guaranteed to be read from the same rows
      var rows = rowIndices.ToList();
      double[] targetVector = dataset.GetDoubleValues(targetVariable, rows).ToArray();

      var inputVariablesList = inputVariables.ToList();
      svm_node[][] nodes = new svm_node[rows.Count][];
      for (int i = 0; i < rows.Count; ++i) {
        int row = rows[i];
        var tempRow = new List<svm_node>(inputVariablesList.Count);
        for (int col = 0; col < inputVariablesList.Count; ++col) {
          double value = dataset.GetDoubleValue(inputVariablesList[col], row);
          // SVM also works with missing values
          // => don't add NaN values in the dataset to the sparse SVM matrix representation
          if (!double.IsNaN(value)) {
            // libSVM node indices start at 1 and must be sorted in ascending order
            tempRow.Add(new svm_node() { index = col + 1, value = value });
          }
        }
        nodes[i] = tempRow.ToArray();
      }

      return new svm_problem() { l = targetVector.Length, y = targetVector, x = nodes };
    }

    /// <summary>
    /// Instantiate and return a svm_parameter object with default values.
    /// </summary>
    /// <returns>A svm_parameter object with default values (nu-SVR with an RBF kernel)</returns>
    public static svm_parameter DefaultParameters() {
      svm_parameter parameter = new svm_parameter();
      parameter.svm_type = svm_parameter.NU_SVR;
      parameter.kernel_type = svm_parameter.RBF;
      parameter.C = 1;
      parameter.nu = 0.5;
      parameter.gamma = 1;
      parameter.p = 1;
      parameter.cache_size = 500;
      parameter.probability = 0;
      parameter.eps = 0.001;
      parameter.degree = 3;
      parameter.shrinking = 1;
      parameter.coef0 = 0;

      return parameter;
    }

    /// <summary>
    /// Partitions the training rows of <paramref name="problemData"/> into <paramref name="nFolds"/> consecutive folds
    /// (used for crossvalidation). Every fold except the last has size (number of training rows / nFolds);
    /// the last fold additionally receives all remaining rows, so every training row belongs to exactly one fold.
    /// </summary>
    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
    /// <param name="problemData">The problem data</param>
    /// <param name="nFolds">The number of folds to generate (must be at least 1)</param>
    /// <returns>A sequence of folds representing each a sequence of dataset row numbers</returns>
    /// <exception cref="ArgumentOutOfRangeException">If <paramref name="nFolds"/> is smaller than 1.</exception>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IRegressionProblemData problemData, int nFolds) {
      if (nFolds < 1) throw new ArgumentOutOfRangeException("nFolds", "The number of folds must be at least 1.");

      // TrainingPartition.Size may differ from the actual number of training rows
      // (rows overlapping the test partition are excluded), so count the indices themselves
      var trainingIndices = problemData.TrainingIndices.ToList();
      int size = trainingIndices.Count;
      int foldSize = size / nFolds; // rounding to integer

      var folds = new List<IEnumerable<int>>(nFolds);
      for (int i = 0; i < nFolds; ++i) {
        int start = i * foldSize;
        // the last fold takes all remaining rows so that no row is dropped
        int count = i == nFolds - 1 ? size - start : foldSize;
        folds.Add(trainingIndices.GetRange(start, count));
      }
      return folds;
    }

    /// <summary>
    /// Performs crossvalidation: for every fold, trains a model on the rows of all other folds and
    /// computes the mean squared error of that model on the rows of the fold.
    /// </summary>
    /// <param name="problemData">The problem data</param>
    /// <param name="parameters">The svm parameters</param>
    /// <param name="folds">The dataset row numbers of each fold</param>
    /// <param name="avgTestMSE">
    /// The test mean squared error averaged over all folds that could be evaluated
    /// (folds with no test rows or no training rows are skipped); NaN if no fold could be evaluated.
    /// </param>
    public static void CrossValidate(IRegressionProblemData problemData, svm_parameter parameters, IEnumerable<IEnumerable<int>> folds, out double avgTestMSE) {
      var calc = new OnlineMeanSquaredErrorCalculator();
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var inputVariables = problemData.AllowedInputVariables.ToList();

      // materialize each fold once; folds may be lazy sequences
      var partitions = folds.Select(f => f.ToList()).ToList();

      double sumTestMSE = 0;
      int evaluatedFolds = 0;
      for (int i = 0; i < partitions.Count; ++i) {
        var test = partitions[i];
        var training = new List<int>();
        for (int j = 0; j < partitions.Count; ++j) {
          if (j != i) training.AddRange(partitions[j]);
        }
        // nothing to evaluate on, or nothing to train on
        if (test.Count == 0 || training.Count == 0) continue;

        var trainingProblem = CreateSvmProblem(ds, targetVariable, inputVariables, training);
        // entry k of testProblem corresponds to test[k]; the dataset row numbers in the fold
        // must not be used as positions into a problem built from other rows
        var testProblem = CreateSvmProblem(ds, targetVariable, inputVariables, test);
        var model = svm.svm_train(trainingProblem, parameters);

        calc.Reset();
        for (int k = 0; k < testProblem.l; ++k) {
          calc.Add(testProblem.y[k], svm.svm_predict(model, testProblem.x[k]));
        }
        sumTestMSE += calc.MeanSquaredError;
        evaluatedFolds++;
      }

      avgTestMSE = evaluatedFolds > 0 ? sumTestMSE / evaluatedFolds : double.NaN;
    }

    /// <summary>
    /// Dynamically generate a setter for svm_parameter fields
    /// </summary>
    /// <param name="fieldName">The name of the svm_parameter field to assign</param>
    /// <returns>A delegate that converts the given double to the field's type and assigns it on the given svm_parameter</returns>
    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
      var targetExp = Expression.Parameter(typeof(svm_parameter));
      var valueExp = Expression.Parameter(typeof(double));

      // Expression.Property can be used here as well
      var fieldExp = Expression.Field(targetExp, fieldName);
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }

    /// <summary>
    /// Evaluates every combination of the given parameter values by crossvalidation and returns the
    /// parameters with the lowest average test mean squared error.
    /// </summary>
    /// <param name="problemData">The problem data</param>
    /// <param name="folds">The dataset row numbers of each crossvalidation fold</param>
    /// <param name="parameterRanges">Maps svm_parameter field names to the values to try for that field</param>
    /// <param name="maxDegreeOfParallelism">The maximum number of parameter combinations evaluated concurrently</param>
    /// <returns>
    /// The best parameters found (default values for all fields not in <paramref name="parameterRanges"/>);
    /// the default parameters if no combination yields a finite error.
    /// </returns>
    public static svm_parameter GridSearch(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
      var bestParam = DefaultParameters();
      double bestMSE = double.MaxValue;
      var locker = new object();

      // search over all combinations of the given parameter values
      var pNames = parameterRanges.Keys.ToList();
      var pRanges = pNames.Select(x => parameterRanges[x]);
      var crossProduct = pRanges.CartesianProduct();
      var setters = pNames.Select(GenerateSetter).ToList();

      // materialize the folds once so that parallel workers do not re-enumerate (and recompute) them
      var partitions = folds.Select(f => f.ToList()).ToList();

      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
        var values = nuple.ToList();
        var parameters = DefaultParameters();
        for (int i = 0; i < setters.Count; ++i) {
          setters[i](parameters, values[i]);
        }

        double testMSE;
        CrossValidate(problemData, parameters, partitions, out testMSE);

        // compare and update under one lock so that a worse result can never overwrite a better one
        lock (locker) {
          if (testMSE < bestMSE) {
            bestMSE = testMSE;
            bestParam = (svm_parameter)parameters.Clone();
          }
        }
      });
      return bestParam;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.