Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs @ 11361

Last change on this file since 11361 was 11361, checked in by bburlacu, 10 years ago

#2234: Added the option to shuffle the crossvalidation folds (this option is on by default since libsvm does it too). Implemented stratified fold generation for classification data (ensures similar label distribution in each fold).

File size: 11.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Linq.Expressions;
26using System.Threading.Tasks;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Random;
32using LibSVM;
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Helper methods for using libSVM with HeuristicLab problem data: conversion of dataset rows into
  /// <see cref="svm_problem"/> instances, default parameters, cross-validation and grid search.
  /// </summary>
  public class SupportVectorMachineUtil {
    /// <summary>
    /// Transforms the selected rows of <paramref name="dataset"/> into a data structure as needed by libSVM.
    /// </summary>
    /// <param name="dataset">The dataset to read the values from.</param>
    /// <param name="targetVariable">The name of the variable whose values form the target vector.</param>
    /// <param name="inputVariables">The names of the variables that form the feature columns (column index = position + 1).</param>
    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem.</param>
    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
    public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
      // The rows are needed twice (targets and features). Materialize them once so that a deferred
      // query (e.g. a randomly ordered one) cannot yield a different order on the second pass and
      // pair target values with the wrong feature rows.
      List<int> rows = rowIndices.ToList();
      double[] targetVector = dataset.GetDoubleValues(targetVariable, rows).ToArray();
      svm_node[][] nodes = new svm_node[targetVector.Length][];
      List<string> inputVariablesList = inputVariables.ToList();
      int svmProblemRowIndex = 0;
      foreach (int row in rows) {
        List<svm_node> tempRow = new List<svm_node>(inputVariablesList.Count);
        int colIndex = 1; // make sure the smallest node index for SVM = 1
        foreach (var inputVariable in inputVariablesList) {
          double value = dataset.GetDoubleValue(inputVariable, row);
          // SVM also works with missing values
          // => don't add NaN values in the dataset to the sparse SVM matrix representation
          if (!double.IsNaN(value)) {
            tempRow.Add(new svm_node() { index = colIndex, value = value });
          }
          // the index advances for skipped values too, so every variable keeps its own column
          // and the nodes of a row stay sorted in ascending order by column index
          colIndex++;
        }
        nodes[svmProblemRowIndex++] = tempRow.ToArray();
      }
      return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes };
    }

    /// <summary>
    /// Instantiate and return a svm_parameter object with default values.
    /// </summary>
    /// <returns>A svm_parameter object with default values (NU_SVR with an RBF kernel).</returns>
    public static svm_parameter DefaultParameters() {
      svm_parameter parameter = new svm_parameter();
      parameter.svm_type = svm_parameter.NU_SVR;
      parameter.kernel_type = svm_parameter.RBF;
      parameter.C = 1;
      parameter.nu = 0.5;
      parameter.gamma = 1;
      parameter.p = 1;
      parameter.cache_size = 500;
      parameter.probability = 0;
      parameter.eps = 0.001;
      parameter.degree = 3;
      parameter.shrinking = 1;
      parameter.coef0 = 0;

      return parameter;
    }

    /// <summary>
    /// Performs a cross-validation on the training rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">The regression or classification problem data.</param>
    /// <param name="parameters">The SVM parameters used to train the model of each fold.</param>
    /// <param name="numberOfFolds">The number of folds in which to split the training rows.</param>
    /// <param name="shuffleFolds">Whether the rows are shuffled before the folds are built.</param>
    /// <returns>The mean squared error on the held-out fold, averaged over all folds.</returns>
    public static double CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, bool shuffleFolds = true) {
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);
      return CalculateCrossValidationPartitions(partitions, parameters);
    }

    /// <summary>
    /// Evaluates every combination of the given parameter values by cross-validation and returns the best one.
    /// All combinations are evaluated on the same folds.
    /// </summary>
    /// <param name="problemData">The regression or classification problem data.</param>
    /// <param name="parameterRanges">Maps the name of a svm_parameter field to the values that should be tried for it.</param>
    /// <param name="numberOfFolds">The number of cross-validation folds.</param>
    /// <param name="shuffleFolds">Whether the rows are shuffled before the folds are built.</param>
    /// <param name="maxDegreeOfParallelism">The maximum number of combinations evaluated concurrently.</param>
    /// <returns>
    /// The parameters with the smallest cross-validation mean squared error, or the default parameters
    /// when no combination produced an error below <see cref="Double.MaxValue"/>.
    /// </returns>
    public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int numberOfFolds, bool shuffleFolds = true, int maxDegreeOfParallelism = 1) {
      double bestMse = Double.MaxValue;
      var bestParam = DefaultParameters();
      var bestLock = new object(); // guards bestMse and bestParam
      var crossProduct = parameterRanges.Values.CartesianProduct();
      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);
      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism },
      parameterCombination => {
        var parameters = DefaultParameters();
        var parameterValues = parameterCombination.ToList();
        for (int i = 0; i < parameterValues.Count; ++i)
          setters[i](parameters, parameterValues[i]);

        double testMse = CalculateCrossValidationPartitions(partitions, parameters);
        // Compare and update under the same lock. Checking outside the lock would allow a worse
        // result to overwrite a better one that another thread stored in between.
        lock (bestLock) {
          if (testMse < bestMse) {
            bestMse = testMse;
            bestParam = (svm_parameter)parameters.Clone();
          }
        }
      });
      return bestParam;
    }

    /// <summary>
    /// Trains a model on the training part of every partition and evaluates it on the corresponding test part.
    /// </summary>
    /// <param name="partitions">Pairs of (training problem, test problem), one pair per fold.</param>
    /// <param name="parameters">The SVM parameters used for training.</param>
    /// <returns>The mean squared error on the test parts, averaged over all partitions.</returns>
    private static double CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters) {
      double avgTestMse = 0;
      var calc = new OnlineMeanSquaredErrorCalculator();
      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
        var trainingSvmProblem = tuple.Item1;
        var testSvmProblem = tuple.Item2;
        var model = svm.svm_train(trainingSvmProblem, parameters);
        calc.Reset(); // the calculator is reused, each fold gets its own error value
        for (int i = 0; i < testSvmProblem.l; ++i)
          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
        avgTestMse += calc.MeanSquaredError;
      }
      avgTestMse /= partitions.Length;
      return avgTestMse;
    }

    /// <summary>
    /// Splits the training rows into folds and creates, for every fold, one SVM problem from the rows
    /// of all other folds (training) and one from the rows of the fold itself (test).
    /// </summary>
    /// <param name="problemData">The regression or classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds in which to split the training rows.</param>
    /// <param name="shuffleFolds">Whether the rows are shuffled before the folds are built.</param>
    /// <returns>An array with one (training problem, test problem) pair per fold.</returns>
    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
      // Materialize every fold once. Each fold is used as test set once and as part of the
      // training set of all other partitions, so the row order must not change between uses.
      List<List<int>> folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).Select(fold => fold.ToList()).ToList();
      var targetVariable = GetTargetVariableName(problemData);
      var inputVariables = problemData.AllowedInputVariables.ToList();
      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
      for (int i = 0; i < numberOfFolds; ++i) {
        int testFoldIndex = i; // avoid "access to modified closure" warning below
        var trainingRows = folds.Where((fold, j) => j != testFoldIndex).SelectMany(fold => fold);
        var testRows = folds[i];
        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, inputVariables, trainingRows);
        var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, inputVariables, testRows);
        partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem);
      }
      return partitions;
    }

    /// <summary>
    /// Splits the training rows of <paramref name="problemData"/> into <paramref name="numberOfFolds"/> folds.
    /// For classification data with <paramref name="shuffleFolds"/> enabled the folds are stratified.
    /// The shuffle is seeded with <see cref="Environment.TickCount"/>.
    /// </summary>
    /// <param name="problemData">The regression or classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds in which to split the data; must be positive.</param>
    /// <param name="shuffleFolds">Whether the rows are shuffled before the folds are built.</param>
    /// <returns>An enumerable sequence of folds, where a fold is represented by a sequence of row indices. Folds may be empty.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="numberOfFolds"/> is not positive.</exception>
    /// <exception cref="ArgumentException">Thrown when the problem data is neither regression nor classification problem data.</exception>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
      if (numberOfFolds <= 0)
        throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be positive.");

      var random = new MersenneTwister((uint)Environment.TickCount);
      if (problemData is IRegressionProblemData) {
        // The shuffled order must be materialized. OrderBy is deferred and draws new random keys on
        // every enumeration, so each fold (a Skip/Take over the sequence) would otherwise see a
        // different permutation and the folds would no longer partition the training rows.
        IEnumerable<int> trainingIndices = shuffleFolds
          ? problemData.TrainingIndices.OrderBy(x => random.Next()).ToList()
          : problemData.TrainingIndices;
        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
      }
      if (problemData is IClassificationProblemData) {
        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
        // otherwise, generate folds normally
        return shuffleFolds
          ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random)
          : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
      }
      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }

    /// <summary>
    /// Stratified fold generation from classification data. Stratification means that we ensure a similar distribution of class labels for each fold.
    /// The samples are grouped by class label and each group is split into <paramref name="numberOfFolds"/> parts. The final folds are formed by joining
    /// the corresponding parts from each class label; the rows inside each final fold are shuffled.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
    /// <param name="random">The random generator used to shuffle the folds.</param>
    /// <returns>An enumerable sequence of exactly <paramref name="numberOfFolds"/> folds, where a fold is represented by a sequence of row indices.</returns>
    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
      // for every class label: the row indices of that class, split into numberOfFolds parts
      List<List<List<int>>> foldsByClass = valuesIndices
        .GroupBy(x => x.Value, x => x.Index)
        .Select(g => g.ToList())
        .Select(members => GenerateFolds(members, members.Count, numberOfFolds).Select(part => part.ToList()).ToList())
        .ToList();
      // A counted loop (instead of advancing one enumerator per class until one is exhausted)
      // also terminates when there are no classes at all, i.e. when there are no training rows;
      // in that case numberOfFolds empty folds are returned.
      for (int i = 0; i < numberOfFolds; ++i) {
        int foldIndex = i;
        yield return foldsByClass.SelectMany(classFolds => classFolds[foldIndex]).OrderBy(x => random.Next()).ToList();
      }
    }

    /// <summary>
    /// Splits <paramref name="values"/> into <paramref name="numberOfFolds"/> consecutive folds whose sizes differ by at most one.
    /// The first (valuesCount % numberOfFolds) folds receive one additional element. If the number of folds is greater
    /// than the number of values, the trailing folds are empty.
    /// </summary>
    /// <param name="values">The values to split. Every returned fold is a deferred Skip/Take view of this sequence.</param>
    /// <param name="valuesCount">The number of values to distribute over the folds.</param>
    /// <param name="numberOfFolds">The number of folds to generate.</param>
    /// <returns>A sequence of <paramref name="numberOfFolds"/> folds.</returns>
    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
      int foldSize = valuesCount / numberOfFolds;  // size of a fold, rounded down
      int remainder = valuesCount % numberOfFolds; // number of folds that get one extra element
      int start = 0;
      for (int i = 0; i < numberOfFolds; ++i) {
        int size = i < remainder ? foldSize + 1 : foldSize;
        yield return values.Skip(start).Take(size);
        start += size;
      }
    }

    /// <summary>
    /// Compiles a delegate that assigns a double value to the field of svm_parameter with the given name.
    /// The value is converted to the type of the field before it is assigned.
    /// </summary>
    /// <param name="fieldName">The name of a field of svm_parameter.</param>
    /// <returns>An action that sets the field on the given svm_parameter instance.</returns>
    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
      var targetExp = Expression.Parameter(typeof(svm_parameter));
      var valueExp = Expression.Parameter(typeof(double));
      var fieldExp = Expression.Field(targetExp, fieldName);
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }

    /// <summary>
    /// Returns the name of the target variable of regression or classification problem data.
    /// </summary>
    /// <param name="problemData">The regression or classification problem data.</param>
    /// <returns>The name of the target variable.</returns>
    /// <exception cref="ArgumentException">Thrown when the problem data is neither regression nor classification problem data.</exception>
    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;

      if (regressionProblemData != null)
        return regressionProblemData.TargetVariable;
      if (classificationProblemData != null)
        return classificationProblemData.TargetVariable;

      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.