Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs

Last change on this file was 17180, checked in by swagner, 5 years ago

#2875: Removed years in copyrights

File size: 13.0 KB
RevLine 
[5624]1#region License Information
2/* HeuristicLab
[17180]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5624]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[11308]22using System;
[5624]23using System.Collections.Generic;
24using System.Linq;
[11308]25using System.Linq.Expressions;
26using System.Threading.Tasks;
27using HeuristicLab.Common;
[11361]28using HeuristicLab.Core;
[11308]29using HeuristicLab.Data;
[5624]30using HeuristicLab.Problems.DataAnalysis;
[11361]31using HeuristicLab.Random;
[8609]32using LibSVM;
[5624]33
namespace HeuristicLab.Algorithms.DataAnalysis {
  public class SupportVectorMachineUtil {
    /// <summary>
    /// Transforms <paramref name="dataset"/> into a data structure as needed by libSVM.
    /// </summary>
    /// <param name="dataset">The source dataset</param>
    /// <param name="targetVariable">The target variable. If null or empty, a zero target vector is used instead (e.g. for prediction with a trained model).</param>
    /// <param name="inputVariables">The selected input variables to include in the svm_problem.</param>
    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem (in this order).</param>
    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
    public static svm_problem CreateSvmProblem(IDataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
      // Enumerate the rows exactly once. The target vector and the feature rows are built in separate passes,
      // so a lazily evaluated sequence (e.g. a randomly ordered one) could otherwise yield a different row order
      // in each pass and misalign y and x.
      List<int> rows = rowIndices.ToList();
      int nRows = rows.Count;

      double[] targetVector;
      if (string.IsNullOrEmpty(targetVariable)) {
        // if the target variable is not set (e.g. for prediction of a trained model) we just use a zero vector
        targetVector = new double[nRows];
      } else {
        targetVector = dataset.GetDoubleValues(targetVariable, rows).ToArray();
      }

      List<string> inputVariablesList = inputVariables.ToList();
      svm_node[][] nodes = new svm_node[nRows][];
      for (int r = 0; r < nRows; ++r) {
        int row = rows[r];
        List<svm_node> tempRow = new List<svm_node>(inputVariablesList.Count);
        int colIndex = 1; // make sure the smallest node index for SVM = 1
        foreach (var inputVariable in inputVariablesList) {
          double value = dataset.GetDoubleValue(inputVariable, row);
          // SVM also works with missing values
          // => don't add NaN values in the dataset to the sparse SVM matrix representation.
          // Nodes are appended with increasing colIndex, so they stay sorted in ascending order by column index.
          if (!double.IsNaN(value)) {
            tempRow.Add(new svm_node() { index = colIndex, value = value });
          }
          colIndex++;
        }
        nodes[r] = tempRow.ToArray();
      }
      return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes };
    }

    /// <summary>
    /// Transforms <paramref name="dataset"/> into a data structure as needed by libSVM for prediction.
    /// </summary>
    /// <param name="dataset">The dataset to transform</param>
    /// <param name="inputVariables">The selected input variables to include in the svm_problem.</param>
    /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem</param>
    /// <returns>A problem data type that can be used for prediction with a trained support vector machine (all target values are 0).</returns>
    public static svm_problem CreateSvmProblem(IDataset dataset, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
      // for prediction we don't need a target variable
      return CreateSvmProblem(dataset, string.Empty, inputVariables, rowIndices);
    }

    /// <summary>
    /// Instantiate and return a svm_parameter object with default values.
    /// </summary>
    /// <returns>A new svm_parameter configured for nu-SVR with an RBF kernel
    /// (C = 1, nu = 0.5, gamma = 1, p = 1, eps = 0.001, degree = 3, coef0 = 0, cache_size = 500, shrinking on, probability off).</returns>
    public static svm_parameter DefaultParameters() {
      svm_parameter parameter = new svm_parameter();
      parameter.svm_type = svm_parameter.NU_SVR;
      parameter.kernel_type = svm_parameter.RBF;
      parameter.C = 1;
      parameter.nu = 0.5;
      parameter.gamma = 1;
      parameter.p = 1;
      parameter.cache_size = 500;
      parameter.probability = 0;
      parameter.eps = 0.001;
      parameter.degree = 3;
      parameter.shrinking = 1;
      parameter.coef0 = 0;

      return parameter;
    }

    /// <summary>
    /// Estimates the test error of an SVM with the given parameters by k-fold cross-validation on the training rows.
    /// </summary>
    /// <param name="problemData">Regression or classification problem data whose training rows are split into folds.</param>
    /// <param name="parameters">The libSVM parameters used to train a model for each fold.</param>
    /// <param name="numberOfFolds">The number of folds (must be at least 1).</param>
    /// <param name="shuffleFolds">Whether rows are randomly assigned to folds.</param>
    /// <returns>The test mean squared error averaged over all folds; NaN if the error of any fold could not be computed.</returns>
    public static double CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, bool shuffleFolds = true) {
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);
      return CalculateCrossValidationPartitions(partitions, parameters);
    }

    /// <summary>
    /// Evaluates every combination of the given parameter values by cross-validation and returns the best combination found.
    /// </summary>
    /// <param name="cvMse">The cross-validated mean squared error of the returned parameters; double.MaxValue if no combination produced a non-NaN error.</param>
    /// <param name="problemData">Regression or classification problem data used for cross-validation.</param>
    /// <param name="parameterRanges">Maps svm_parameter field names to the values to try for that field.</param>
    /// <param name="numberOfFolds">The number of cross-validation folds (must be at least 1).</param>
    /// <param name="shuffleFolds">Whether rows are randomly assigned to folds. The same folds are reused for all combinations.</param>
    /// <param name="maxDegreeOfParallelism">The maximum number of combinations evaluated concurrently.</param>
    /// <returns>A copy of the best parameters, or the default parameters if no combination produced a non-NaN error.</returns>
    public static svm_parameter GridSearch(out double cvMse, IDataAnalysisProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int numberOfFolds, bool shuffleFolds = true, int maxDegreeOfParallelism = 1) {
      double bestMse = Double.MaxValue;
      var bestParam = DefaultParameters();
      var crossProduct = parameterRanges.Values.CartesianProduct();
      // Keys and Values of the same dictionary enumerate in matching order, so setters[i] belongs to the i-th value of each combination.
      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);

      var locker = new object(); // guards bestMse and bestParam across worker threads
      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism },
      parameterCombination => {
        var parameters = DefaultParameters();
        var parameterValues = parameterCombination.ToList();
        for (int i = 0; i < parameterValues.Count; ++i)
          setters[i](parameters, parameterValues[i]);

        double testMse = CalculateCrossValidationPartitions(partitions, parameters);
        if (!double.IsNaN(testMse)) {
          lock (locker) {
            if (testMse < bestMse) {
              bestMse = testMse;
              bestParam = (svm_parameter)parameters.Clone();
            }
          }
        }
      });
      cvMse = bestMse;
      return bestParam;
    }

    /// <summary>
    /// Trains a model on each partition's training problem and computes its mean squared error on the partition's test problem.
    /// </summary>
    /// <param name="partitions">Pairs of (training problem, test problem), one per fold.</param>
    /// <param name="parameters">The libSVM parameters used for training.</param>
    /// <returns>The mean of the per-fold test MSEs; NaN if the error calculator reports an error for any fold, or if there are no partitions.</returns>
    private static double CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters) {
      double avgTestMse = 0;
      var calc = new OnlineMeanSquaredErrorCalculator();
      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
        var trainingSvmProblem = tuple.Item1;
        var testSvmProblem = tuple.Item2;
        var model = svm.svm_train(trainingSvmProblem, parameters);
        calc.Reset();
        for (int i = 0; i < testSvmProblem.l; ++i)
          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
        // a NaN fold error propagates into the average and marks the whole evaluation as invalid
        double mse = calc.ErrorState == OnlineCalculatorError.None ? calc.MeanSquaredError : double.NaN;
        avgTestMse += mse;
      }
      avgTestMse /= partitions.Length;
      return avgTestMse;
    }

    /// <summary>
    /// Builds one (training, test) pair of scaled svm_problems per cross-validation fold.
    /// </summary>
    /// <param name="problemData">Regression or classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds (must be at least 1).</param>
    /// <param name="shuffleFolds">Whether rows are randomly assigned to folds.</param>
    /// <returns>An array with numberOfFolds entries. For entry i, Item1 is the training problem built from all other folds and
    /// Item2 is the test problem built from fold i. Both are scaled with a RangeTransform computed from the training problem only.</returns>
    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
      // GenerateFolds yields exactly numberOfFolds materialized folds, so indexing folds[i] below is safe and stable.
      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
      var targetVariable = GetTargetVariableName(problemData);
      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
      for (int i = 0; i < numberOfFolds; ++i) {
        int p = i; // avoid "access to modified closure" warning below
        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
        var testRows = folds[i];
        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
        // scaling ranges are computed from the training rows only and then applied to both problems
        var rangeTransform = RangeTransform.Compute(trainingSvmProblem);
        var testSvmProblem = rangeTransform.Scale(CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows));
        partitions[i] = new Tuple<svm_problem, svm_problem>(rangeTransform.Scale(trainingSvmProblem), testSvmProblem);
      }
      return partitions;
    }

    /// <summary>
    /// Splits the training rows of <paramref name="problemData"/> into <paramref name="numberOfFolds"/> folds.
    /// </summary>
    /// <param name="problemData">Regression or classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds (must be at least 1).</param>
    /// <param name="shuffleFolds">If true, rows are assigned to folds randomly (stratified by class label for classification data);
    /// otherwise the training rows are split in their original order.</param>
    /// <returns>Exactly numberOfFolds folds, each materialized so that repeated enumeration yields the same rows.
    /// Some folds may be empty when there are fewer training rows than folds.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numberOfFolds"/> is less than 1.</exception>
    /// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
      if (numberOfFolds < 1) throw new ArgumentOutOfRangeException("numberOfFolds", "The number of folds must be at least 1.");
      var random = new MersenneTwister((uint)Environment.TickCount);
      if (problemData is IRegressionProblemData) {
        // Materialize the (shuffled) indices once: OrderBy with a random key is deferred, so enumerating it repeatedly
        // would produce a different permutation each time and the folds would overlap.
        List<int> trainingIndices = shuffleFolds
          ? problemData.TrainingIndices.OrderBy(x => random.Next()).ToList()
          : problemData.TrainingIndices.ToList();
        return GenerateFolds(trainingIndices, trainingIndices.Count, numberOfFolds);
      }
      var classificationProblemData = problemData as IClassificationProblemData;
      if (classificationProblemData != null) {
        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
        // otherwise, generate folds normally
        if (shuffleFolds) return GenerateFoldsStratified(classificationProblemData, numberOfFolds, random);
        List<int> trainingIndices = classificationProblemData.TrainingIndices.ToList();
        return GenerateFolds(trainingIndices, trainingIndices.Count, numberOfFolds);
      }
      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }

    /// <summary>
    /// Generates stratified folds from classification data. Stratification means that each fold has (as far as possible)
    /// the same distribution of class labels. The rows of each class are shuffled and split into <paramref name="numberOfFolds"/> parts;
    /// fold i is the union of part i of every class.
    /// </summary>
    /// <param name="problemData">The classification problem data.</param>
    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
    /// <param name="random">The random generator used to shuffle the rows.</param>
    /// <returns>Exactly <paramref name="numberOfFolds"/> folds, each a list of row indices (possibly empty).</returns>
    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
      var trainingIndices = problemData.TrainingIndices.ToList();
      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingIndices);
      var valuesIndices = trainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
      // For each class label: shuffle its rows, then split them into exactly numberOfFolds (possibly empty) parts.
      var foldsByClass = valuesIndices
        .GroupBy(x => x.Value, x => x.Index)
        .Select(g => {
          var classIndices = g.OrderBy(x => random.Next()).ToList();
          return GenerateFolds(classIndices, classIndices.Count, numberOfFolds).ToList();
        })
        .ToList();
      // Iterating by fold index (instead of advancing enumerators while all succeed) guarantees exactly numberOfFolds folds,
      // even when there are no training rows and therefore no classes.
      for (int i = 0; i < numberOfFolds; ++i) {
        int foldIndex = i;
        yield return foldsByClass.SelectMany(classFolds => classFolds[foldIndex]).OrderBy(x => random.Next()).ToList();
      }
    }

    /// <summary>
    /// Splits <paramref name="values"/> into exactly <paramref name="numberOfFolds"/> contiguous parts whose sizes differ by at most one.
    /// </summary>
    /// <param name="values">The values to split, in the order in which they are distributed to the folds.</param>
    /// <param name="valuesCount">The number of elements in <paramref name="values"/>.</param>
    /// <param name="numberOfFolds">The number of parts to produce (callers ensure it is at least 1).</param>
    /// <returns>The folds, each materialized as a list. If valuesCount &lt; numberOfFolds, the first valuesCount folds hold one
    /// element each and the remaining folds are empty.</returns>
    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
      // if number of folds is greater than the number of values, some empty folds will be returned
      if (valuesCount < numberOfFolds) {
        for (int i = 0; i < numberOfFolds; ++i)
          yield return i < valuesCount ? values.Skip(i).Take(1).ToList() : new List<T>();
      } else {
        // f = base size of each fold, r = number of leading folds that receive one extra element
        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds;
        int start = 0, end = f;
        for (int i = 0; i < numberOfFolds; ++i) {
          if (r > 0) {
            ++end;
            --r;
          }
          // materialize so the fold's content does not depend on re-enumerating values later
          yield return values.Skip(start).Take(end - start).ToList();
          start = end;
          end += f;
        }
      }
    }

    /// <summary>
    /// Compiles a delegate that assigns a double value to the svm_parameter field named <paramref name="fieldName"/>.
    /// The value is cast to the field's declared type before assignment.
    /// </summary>
    /// <param name="fieldName">The name of a field of svm_parameter.</param>
    /// <returns>A setter taking the target parameter object and the value to assign.</returns>
    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
      var targetExp = Expression.Parameter(typeof(svm_parameter));
      var valueExp = Expression.Parameter(typeof(double));
      var fieldExp = Expression.Field(targetExp, fieldName);
      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
      return setter;
    }

    /// <summary>
    /// Returns the target variable of regression or classification problem data.
    /// </summary>
    /// <param name="problemData">Regression or classification problem data.</param>
    /// <returns>The name of the target variable.</returns>
    /// <exception cref="ArgumentException"><paramref name="problemData"/> is neither regression nor classification problem data.</exception>
    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
      var regressionProblemData = problemData as IRegressionProblemData;
      var classificationProblemData = problemData as IClassificationProblemData;

      if (regressionProblemData != null)
        return regressionProblemData.TargetVariable;
      if (classificationProblemData != null)
        return classificationProblemData.TargetVariable;

      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.