source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs @ 11326

Last change on this file since 11326 was 11326, checked in by bburlacu, 6 years ago

#2234: Refactored CrossValidate and GridSearch methods.

File size: 9.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Linq.Expressions;
26using System.Threading.Tasks;
27using HeuristicLab.Common;
28using HeuristicLab.Data;
29using HeuristicLab.Problems.DataAnalysis;
30using LibSVM;
31
32namespace HeuristicLab.Algorithms.DataAnalysis {
33  public class SupportVectorMachineUtil {
34    /// <summary>
35    /// Transforms <paramref name="problemData"/> into a data structure as needed by libSVM.
36    /// </summary>
37    /// <param name="problemData">The problem data to transform</param>
38    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem</param>
39    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
40    public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
41      double[] targetVector =
42        dataset.GetDoubleValues(targetVariable, rowIndices).ToArray();
43
44      svm_node[][] nodes = new svm_node[targetVector.Length][];
45      List<svm_node> tempRow;
46      int maxNodeIndex = 0;
47      int svmProblemRowIndex = 0;
48      List<string> inputVariablesList = inputVariables.ToList();
49      foreach (int row in rowIndices) {
50        tempRow = new List<svm_node>();
51        int colIndex = 1; // make sure the smallest node index for SVM = 1
52        foreach (var inputVariable in inputVariablesList) {
53          double value = dataset.GetDoubleValue(inputVariable, row);
54          // SVM also works with missing values
55          // => don't add NaN values in the dataset to the sparse SVM matrix representation
56          if (!double.IsNaN(value)) {
57            tempRow.Add(new svm_node() { index = colIndex, value = value }); // nodes must be sorted in ascending ordered by column index
58            if (colIndex > maxNodeIndex) maxNodeIndex = colIndex;
59          }
60          colIndex++;
61        }
62        nodes[svmProblemRowIndex++] = tempRow.ToArray();
63      }
64
65      return new svm_problem() { l = targetVector.Length, y = targetVector, x = nodes };
66    }
67
68    /// <summary>
69    /// Instantiate and return a svm_parameter object with default values.
70    /// </summary>
71    /// <returns>A svm_parameter object with default values</returns>
72    public static svm_parameter DefaultParameters() {
73      svm_parameter parameter = new svm_parameter();
74      parameter.svm_type = svm_parameter.NU_SVR;
75      parameter.kernel_type = svm_parameter.RBF;
76      parameter.C = 1;
77      parameter.nu = 0.5;
78      parameter.gamma = 1;
79      parameter.p = 1;
80      parameter.cache_size = 500;
81      parameter.probability = 0;
82      parameter.eps = 0.001;
83      parameter.degree = 3;
84      parameter.shrinking = 1;
85      parameter.coef0 = 0;
86
87      return parameter;
88    }
89
90    /// <summary>
91    /// Generate a collection of training indices corresponding to folds in the data (used for crossvalidation)
92    /// </summary>
93    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
94    /// <param name="problemData">The problem data</param>
95    /// <param name="nFolds">The number of folds to generate</param>
96    /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
97    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int nFolds) {
98      int size = problemData.TrainingPartition.Size;
99
100      int foldSize = size / nFolds; // rounding to integer
101      var trainingIndices = problemData.TrainingIndices;
102
103      for (int i = 0; i < nFolds; ++i) {
104        int n = i * foldSize;
105        int s = n + 2 * foldSize > size ? foldSize + size % foldSize : foldSize;
106        yield return trainingIndices.Skip(n).Take(s);
107      }
108    }
109
110    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numFolds, out double avgTestMse) {
111      avgTestMse = 0;
112      var folds = GenerateFolds(problemData, numFolds).ToList();
113      var calc = new OnlineMeanSquaredErrorCalculator();
114      var targetVariable = GetTargetVariableName(problemData);
115      for (int i = 0; i < numFolds; ++i) {
116        int p = i; // avoid "access to modified closure" warning below
117        var training = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
118        var testRows = folds[i];
119        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, training);
120        var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
121
122        var model = svm.svm_train(trainingSvmProblem, parameters);
123        calc.Reset();
124        for (int j = 0; j < testSvmProblem.l; ++j)
125          calc.Add(testSvmProblem.y[j], svm.svm_predict(model, testSvmProblem.x[j]));
126        avgTestMse += calc.MeanSquaredError;
127      }
128      avgTestMse /= numFolds;
129    }
130
131    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, Tuple<svm_problem, svm_problem>[] partitions, out double avgTestMse) {
132      avgTestMse = 0;
133      var calc = new OnlineMeanSquaredErrorCalculator();
134      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
135        var trainingSvmProblem = tuple.Item1;
136        var testSvmProblem = tuple.Item2;
137        var model = svm.svm_train(trainingSvmProblem, parameters);
138        calc.Reset();
139        for (int i = 0; i < testSvmProblem.l; ++i)
140          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
141        avgTestMse += calc.MeanSquaredError;
142      }
143      avgTestMse /= partitions.Length;
144    }
145
146    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
147      var targetExp = Expression.Parameter(typeof(svm_parameter));
148      var valueExp = Expression.Parameter(typeof(double));
149      var fieldExp = Expression.Field(targetExp, fieldName);
150      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
151      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
152      return setter;
153    }
154
155    public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
156      DoubleValue mse = new DoubleValue(Double.MaxValue);
157      var bestParam = DefaultParameters();
158
159      // search for C, gamma and epsilon parameter combinations
160      var pNames = parameterRanges.Keys.ToList();
161      var pRanges = pNames.Select(x => parameterRanges[x]);
162
163      var crossProduct = pRanges.CartesianProduct();
164      var setters = pNames.Select(GenerateSetter).ToList();
165      var folds = GenerateFolds(problemData, numberOfFolds).ToList();
166
167      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
168      var targetVariable = GetTargetVariableName(problemData);
169
170      for (int i = 0; i < numberOfFolds; ++i) {
171        int p = i; // avoid "access to modified closure" warning below
172        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
173        var testRows = folds[i];
174        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
175        var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
176        partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem);
177      }
178
179      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
180        //  foreach (var nuple in crossProduct) {
181        var list = nuple.ToList();
182        var parameters = DefaultParameters();
183        for (int i = 0; i < pNames.Count; ++i) {
184          var s = setters[i];
185          s(parameters, list[i]);
186        }
187        double testMse;
188        CrossValidate(problemData, parameters, partitions, out testMse);
189        if (testMse < mse.Value) {
190          lock (mse) { mse.Value = testMse; }
191          lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); } // set best parameter values to the best found so far
192        }
193      });
194      return bestParam;
195    }
196
197    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
198      var regressionProblemData = problemData as IRegressionProblemData;
199      var classificationProblemData = problemData as IClassificationProblemData;
200
201      if (regressionProblemData != null)
202        return regressionProblemData.TargetVariable;
203      if (classificationProblemData != null)
204        return classificationProblemData.TargetVariable;
205
206      throw new ArgumentException("Problem data is neither regression or classification problem data.");
207    }
208  }
209}
Note: See TracBrowser for help on using the repository browser.