source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs @ 11464

Last change on this file since 11464 was 11464, checked in by bburlacu, 6 years ago

#2234: Moved lock object inside the GridSearch method. Added scaling for the svm partitions.

File size: 11.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Linq.Expressions;
26using System.Threading.Tasks;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Random;
32using LibSVM;
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Helper routines for libSVM: conversion of HeuristicLab datasets into <see cref="svm_problem"/> instances,
  /// default SVM parameters, k-fold cross-validation and grid search over SVM parameters.
  /// </summary>
  public class SupportVectorMachineUtil {
    /// <summary>
    /// Transforms the given rows of <paramref name="dataset"/> into a data structure as needed by libSVM.
    /// </summary>
    /// <param name="dataset">The dataset holding the target and input values.</param>
    /// <param name="targetVariable">The variable whose values become the labels (<c>y</c>) of the SVM problem.</param>
    /// <param name="inputVariables">The input variables; the i-th variable (0-based) becomes SVM node index i + 1.</param>
    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem.</param>
    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
    public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
      // Materialize the rows once: they are read twice below (targets, then inputs). A lazily evaluated
      // (e.g. randomly ordered) sequence could otherwise produce different rows on each pass and pair
      // target values with the inputs of other rows.
      var rows = rowIndices.ToList();
      double[] targetVector = dataset.GetDoubleValues(targetVariable, rows).ToArray();
      var inputVariablesList = inputVariables.ToList();
      svm_node[][] nodes = new svm_node[targetVector.Length][];
      int svmProblemRowIndex = 0;
      foreach (int row in rows) {
        var tempRow = new List<svm_node>(inputVariablesList.Count);
        int colIndex = 1; // make sure the smallest node index for SVM = 1
        foreach (var inputVariable in inputVariablesList) {
          double value = dataset.GetDoubleValue(inputVariable, row);
          // SVM also works with missing values
          // => don't add NaN values in the dataset to the sparse SVM matrix representation.
          // Nodes are appended with strictly increasing colIndex, so they stay sorted by index.
          if (!double.IsNaN(value))
            tempRow.Add(new svm_node() { index = colIndex, value = value });
          colIndex++;
        }
        nodes[svmProblemRowIndex++] = tempRow.ToArray();
      }
      return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes };
    }

    /// <summary>
    /// Instantiate and return a svm_parameter object with default values
    /// (nu-SVR with an RBF kernel, C = 1, nu = 0.5, gamma = 1, eps = 0.001, shrinking enabled).
    /// </summary>
    /// <returns>A svm_parameter object with default values</returns>
    public static svm_parameter DefaultParameters() {
      svm_parameter parameter = new svm_parameter();
      parameter.svm_type = svm_parameter.NU_SVR;
      parameter.kernel_type = svm_parameter.RBF;
      parameter.C = 1;
      parameter.nu = 0.5;
      parameter.gamma = 1;
      parameter.p = 1;
      parameter.cache_size = 500;
      parameter.probability = 0;
      parameter.eps = 0.001;
      parameter.degree = 3;
      parameter.shrinking = 1;
      parameter.coef0 = 0;

      return parameter;
    }

    /// <summary>
    /// Estimates the test error of <paramref name="parameters"/> by k-fold cross-validation on the training rows.
    /// </summary>
    /// <param name="problemData">Regression or classification problem data.</param>
    /// <param name="parameters">The SVM parameters to evaluate.</param>
    /// <param name="numberOfFolds">The number of folds (at least one).</param>
    /// <param name="shuffleFolds">Whether the rows are randomly assigned to folds.</param>
    /// <returns>The mean squared error on the test folds, averaged over all folds.</returns>
    public static double CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, bool shuffleFolds = true) {
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);
      return CalculateCrossValidationPartitions(partitions, parameters);
    }

    /// <summary>
    /// Evaluates every combination in the Cartesian product of <paramref name="parameterRanges"/> by cross-validation
    /// and returns a copy of the parameters with the lowest average test MSE. Parameters that are not listed in
    /// <paramref name="parameterRanges"/> keep the values of <see cref="DefaultParameters"/>. If no combination yields
    /// an MSE below <see cref="double.MaxValue"/>, <see cref="DefaultParameters"/> is returned.
    /// </summary>
    /// <param name="problemData">Regression or classification problem data.</param>
    /// <param name="parameterRanges">Maps public field names of <see cref="svm_parameter"/> to the values to try.</param>
    /// <param name="numberOfFolds">The number of cross-validation folds (at least one).</param>
    /// <param name="shuffleFolds">Whether the rows are randomly assigned to folds.</param>
    /// <param name="maxDegreeOfParallelism">The maximum number of combinations evaluated concurrently.</param>
    /// <returns>The best parameter combination found.</returns>
    public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int numberOfFolds, bool shuffleFolds = true, int maxDegreeOfParallelism = 1) {
      double bestMse = double.MaxValue;
      var bestParam = DefaultParameters();
      // Dictionary guarantees that Keys and Values enumerate in matching order, so setters[i] belongs to
      // the i-th value of every combination.
      var crossProduct = parameterRanges.Values.CartesianProduct();
      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      // The folds are generated once and shared, so all combinations are compared on identical partitions.
      // NOTE(review): this assumes svm_train/svm_predict do not mutate the shared svm_problem instances — confirm.
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);

      var locker = new object(); // guards bestMse and bestParam
      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism },
      parameterCombination => {
        var parameters = DefaultParameters();
        var parameterValues = parameterCombination.ToList();
        for (int i = 0; i < parameterValues.Count; ++i)
          setters[i](parameters, parameterValues[i]);

        double testMse = CalculateCrossValidationPartitions(partitions, parameters);
        lock (locker) {
          if (testMse < bestMse) {
            bestMse = testMse;
            bestParam = (svm_parameter)parameters.Clone();
          }
        }
      });
      return bestParam;
    }

    /// <summary>
    /// Trains one model per partition on its training problem, computes the MSE of that model on the partition's
    /// test problem, and returns the average of these per-fold MSEs.
    /// </summary>
    private static double CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters) {
      double avgTestMse = 0;
      var calc = new OnlineMeanSquaredErrorCalculator();
      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
        var trainingSvmProblem = tuple.Item1;
        var testSvmProblem = tuple.Item2;
        var model = svm.svm_train(trainingSvmProblem, parameters);
        calc.Reset();
        for (int i = 0; i < testSvmProblem.l; ++i)
          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
        avgTestMse += calc.MeanSquaredError;
      }
      avgTestMse /= partitions.Length;
      return avgTestMse;
    }

    /// <summary>
    /// Splits the training rows into folds and builds one (training, test) pair of SVM problems per fold,
    /// where fold i is the test set and all other folds form the training set.
    /// </summary>
    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
      // Materialize every fold so that its row set is fixed before it is used several times below.
      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).Select(fold => fold.ToList()).ToList();
      var targetVariable = GetTargetVariableName(problemData);
      var partitions = new Tuple<svm_problem, svm_problem>[folds.Count];
      for (int i = 0; i < folds.Count; ++i) {
        int testFoldIndex = i; // copy for the lambda below
        var trainingRows = folds.Where((fold, j) => j != testFoldIndex).SelectMany(fold => fold).ToList();
        var testRows = folds[i];
        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
        // The range transform is computed from the training problem only and then applied to both problems.
        var rangeTransform = RangeTransform.Compute(trainingSvmProblem);
        var testSvmProblem = rangeTransform.Scale(CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows));
        partitions[i] = new Tuple<svm_problem, svm_problem>(rangeTransform.Scale(trainingSvmProblem), testSvmProblem);
      }
      return partitions;
    }

    /// <summary>
    /// Splits the training rows of <paramref name="problemData"/> into <paramref name="numberOfFolds"/> disjoint folds.
    /// For classification data with <paramref name="shuffleFolds"/> enabled the folds are stratified by class label.
    /// If there are fewer training rows than folds, some folds are empty.
    /// </summary>
    /// <param name="problemData">Regression or classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds (at least one).</param>
    /// <param name="shuffleFolds">Whether the rows are randomly assigned to folds.</param>
    /// <returns>The folds, each given as a materialized sequence of row indices.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numberOfFolds"/> is less than one.</exception>
    /// <exception cref="ArgumentException">The problem data is neither regression nor classification problem data.</exception>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
      if (numberOfFolds < 1)
        throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be at least one.");
      var random = new MersenneTwister((uint)Environment.TickCount);
      if (problemData is IRegressionProblemData) {
        // Materialize the (possibly shuffled) order once: OrderBy with random keys is re-evaluated on every
        // enumeration, which would give each fold a different permutation and make the folds overlap.
        var trainingIndices = (shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices).ToList();
        return GenerateFolds(trainingIndices, trainingIndices.Count, numberOfFolds);
      }
      if (problemData is IClassificationProblemData) {
        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
        // otherwise, generate folds normally
        if (shuffleFolds)
          return GenerateFoldsStratified((IClassificationProblemData)problemData, numberOfFolds, random);
        var trainingIndices = problemData.TrainingIndices.ToList();
        return GenerateFolds(trainingIndices, trainingIndices.Count, numberOfFolds);
      }
      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }

    /// <summary>
    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
    /// The samples are grouped by class label, the members of each group are shuffled, and each group is split into @numberOfFolds parts.
    /// The final folds are formed from the joining of the corresponding parts from each class label.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
    /// <param name="random">The random generator used to shuffle the class members and the folds.</param>
    /// <returns>An enumerable sequence of exactly <paramref name="numberOfFolds"/> folds, where a fold is represented by a sequence of row indices.</returns>
    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
      var trainingIndices = problemData.TrainingIndices.ToList();
      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingIndices);
      var valuesIndices = trainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
      // For every class: shuffle its rows (so fold membership is random), then split them into numberOfFolds parts.
      var foldsByClass = valuesIndices
        .GroupBy(x => x.Value, x => x.Index)
        .Select(g => {
          var members = g.OrderBy(x => random.Next()).ToList();
          return GenerateFolds(members, members.Count, numberOfFolds).ToList();
        })
        .ToList();
      // GenerateFolds yields exactly numberOfFolds parts per class, so fold i joins the i-th part of every class.
      // Iterating a fixed number of times also avoids an endless loop when there are no training rows (no classes).
      for (int i = 0; i < numberOfFolds; ++i) {
        int foldIndex = i; // copy for the lambda below
        yield return foldsByClass.SelectMany(classFolds => classFolds[foldIndex]).OrderBy(x => random.Next()).ToList();
      }
    }

    /// <summary>
    /// Splits <paramref name="values"/> into exactly <paramref name="numberOfFolds"/> consecutive, materialized folds.
    /// The first (valuesCount % numberOfFolds) folds get one extra element. If there are fewer values than folds,
    /// the first folds hold one value each and the remaining folds are empty.
    /// </summary>
    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
      // Materialize once so that every fold is a fixed slice, even if the source is lazily (re)evaluated.
      var items = values.ToList();
      // if number of folds is greater than the number of values, some empty folds will be returned
      if (valuesCount < numberOfFolds) {
        for (int i = 0; i < numberOfFolds; ++i)
          yield return i < valuesCount ? items.Skip(i).Take(1).ToList() : new List<T>();
      } else {
        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // fold size rounded down and remainder
        int start = 0, end = f;
        for (int i = 0; i < numberOfFolds; ++i) {
          if (r > 0) {
            ++end;
            --r;
          }
          yield return items.Skip(start).Take(end - start).ToList();
          start = end;
          end += f;
        }
      }
    }

    /// <summary>
    /// Compiles a setter that assigns a double to the public field <paramref name="fieldName"/> of an
    /// <see cref="svm_parameter"/>, converting the value to the field's type (e.g. int).
    /// </summary>
    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
      var targetExp = Expression.Parameter(typeof(svm_parameter));
      var valueExp = Expression.Parameter(typeof(double));
      var fieldExp = Expression.Field(targetExp, fieldName);
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }

    /// <summary>
    /// Returns the target variable of regression or classification problem data.
    /// </summary>
    /// <exception cref="ArgumentException">The problem data is neither regression nor classification problem data.</exception>
    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;

      if (regressionProblemData != null)
        return regressionProblemData.TargetVariable;
      if (classificationProblemData != null)
        return classificationProblemData.TargetVariable;

      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.