#region License Information
/* HeuristicLab
* Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using HeuristicLab.Common;
using HeuristicLab.Data;
using HeuristicLab.Problems.DataAnalysis;
using LibSVM;
namespace HeuristicLab.Algorithms.DataAnalysis {
public class SupportVectorMachineUtil {
  /// <summary>
  /// Transforms a subset of a dataset into the sparse data structure needed by libSVM.
  /// </summary>
  /// <param name="dataset">The dataset containing the values to transform.</param>
  /// <param name="targetVariable">Name of the target variable (becomes the y-vector).</param>
  /// <param name="inputVariables">Names of the input variables (become the svm_node columns).</param>
  /// <param name="rowIndices">The rows of the dataset that should be contained in the resulting SVM-problem.</param>
  /// <returns>A problem data type that can be used to train a support vector machine.</returns>
  public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
    double[] targetVector = dataset.GetDoubleValues(targetVariable, rowIndices).ToArray();
    svm_node[][] nodes = new svm_node[targetVector.Length][];
    int svmProblemRowIndex = 0;
    List<string> inputVariablesList = inputVariables.ToList();
    foreach (int row in rowIndices) {
      List<svm_node> tempRow = new List<svm_node>();
      int colIndex = 1; // make sure the smallest node index for SVM = 1
      foreach (var inputVariable in inputVariablesList) {
        double value = dataset.GetDoubleValue(inputVariable, row);
        // SVM also works with missing values
        // => don't add NaN values in the dataset to the sparse SVM matrix representation
        if (!double.IsNaN(value)) {
          // nodes must be sorted in ascending order by column index;
          // iterating the input variables in a fixed order guarantees this
          tempRow.Add(new svm_node() { index = colIndex, value = value });
        }
        colIndex++;
      }
      nodes[svmProblemRowIndex++] = tempRow.ToArray();
    }
    return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes };
  }

  /// <summary>
  /// Instantiate and return a svm_parameter object with default values
  /// (nu-SVR with an RBF kernel).
  /// </summary>
  /// <returns>A svm_parameter object with default values.</returns>
  public static svm_parameter DefaultParameters() {
    svm_parameter parameter = new svm_parameter();
    parameter.svm_type = svm_parameter.NU_SVR;
    parameter.kernel_type = svm_parameter.RBF;
    parameter.C = 1;
    parameter.nu = 0.5;
    parameter.gamma = 1;
    parameter.p = 1;
    parameter.cache_size = 500;
    parameter.probability = 0;
    parameter.eps = 0.001;
    parameter.degree = 3;
    parameter.shrinking = 1;
    parameter.coef0 = 0;
    return parameter;
  }

  /// <summary>
  /// Cross-validates the given parameters on the training partition of the problem data.
  /// </summary>
  /// <param name="problemData">The problem data whose training partition is folded.</param>
  /// <param name="parameters">The SVM parameters to evaluate.</param>
  /// <param name="numberOfFolds">The number of cross-validation folds.</param>
  /// <param name="avgTestMse">The mean squared error on the test folds, averaged over all folds.</param>
  public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
    var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
    CalculateCrossValidationPartitions(partitions, parameters, out avgTestMse);
  }

  /// <summary>
  /// Performs a (optionally parallel) grid search over the Cartesian product of the given
  /// parameter ranges and returns the parameter combination with the lowest average test MSE.
  /// </summary>
  /// <param name="problemData">The problem data used for cross-validation.</param>
  /// <param name="numberOfFolds">The number of cross-validation folds per grid point.</param>
  /// <param name="parameterRanges">Maps svm_parameter field names to the candidate values for that field.</param>
  /// <param name="maxDegreeOfParallelism">Maximum number of grid points evaluated concurrently.</param>
  /// <returns>The best svm_parameter configuration found.</returns>
  public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
    double bestMse = double.MaxValue;
    var bestParam = DefaultParameters();
    var locker = new object(); // guards bestMse and bestParam as one unit
    var crossProduct = parameterRanges.Values.CartesianProduct();
    var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
    Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
      var parameters = DefaultParameters();
      var parameterValues = parameterCombination.ToList();
      for (int i = 0; i < parameterValues.Count; ++i) {
        setters[i](parameters, parameterValues[i]);
      }
      double testMse;
      CalculateCrossValidationPartitions(partitions, parameters, out testMse);
      // Compare-and-update must be atomic: doing the comparison outside the lock (or
      // updating mse and bestParam under two separate locks) lets two threads interleave
      // and record a best MSE that does not belong to the stored parameter set.
      lock (locker) {
        if (testMse < bestMse) {
          bestMse = testMse;
          bestParam = (svm_parameter)parameters.Clone();
        }
      }
    });
    return bestParam;
  }

  /// <summary>
  /// Trains on each training fold and accumulates the mean squared error on the
  /// corresponding test fold; the result is the average over all folds.
  /// </summary>
  private static void CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters, out double avgTestMse) {
    avgTestMse = 0;
    var calc = new OnlineMeanSquaredErrorCalculator();
    foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
      var trainingSvmProblem = tuple.Item1;
      var testSvmProblem = tuple.Item2;
      var model = svm.svm_train(trainingSvmProblem, parameters);
      calc.Reset();
      for (int i = 0; i < testSvmProblem.l; ++i)
        calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
      avgTestMse += calc.MeanSquaredError;
    }
    avgTestMse /= partitions.Length;
  }

  /// <summary>
  /// Builds one (training, test) svm_problem pair per fold; fold i serves as the test
  /// set while the remaining folds form the training set.
  /// </summary>
  private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
    var folds = GenerateFolds(problemData, numberOfFolds).ToList();
    var targetVariable = GetTargetVariableName(problemData);
    var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
    for (int i = 0; i < numberOfFolds; ++i) {
      int p = i; // avoid "access to modified closure" warning below
      var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
      var testRows = folds[i];
      var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
      var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
      partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem);
    }
    return partitions;
  }

  /// <summary>
  /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation).
  /// </summary>
  /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
  /// <param name="problemData">The problem data.</param>
  /// <param name="numberOfFolds">The number of folds to generate.</param>
  /// <returns>A sequence of folds, each being a sequence of row numbers.</returns>
  private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
    int size = problemData.TrainingPartition.Size;
    int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
    int start = 0, end = f;
    for (int i = 0; i < numberOfFolds; ++i) {
      if (r > 0) { ++end; --r; } // distribute the remainder over the first r folds
      yield return problemData.TrainingIndices.Skip(start).Take(end - start);
      start = end;
      end += f;
    }
  }

  /// <summary>
  /// Compiles a setter delegate that assigns a double value to the named public field
  /// of a svm_parameter instance (the value is converted to the field's actual type).
  /// </summary>
  private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
    var targetExp = Expression.Parameter(typeof(svm_parameter));
    var valueExp = Expression.Parameter(typeof(double));
    var fieldExp = Expression.Field(targetExp, fieldName);
    var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
    var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
    return setter;
  }

  /// <summary>
  /// Returns the target variable name for regression or classification problem data.
  /// </summary>
  /// <exception cref="ArgumentException">Thrown when the problem data is of neither kind.</exception>
  private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
    var regressionProblemData = problemData as IRegressionProblemData;
    var classificationProblemData = problemData as IClassificationProblemData;
    if (regressionProblemData != null)
      return regressionProblemData.TargetVariable;
    if (classificationProblemData != null)
      return classificationProblemData.TargetVariable;
    throw new ArgumentException("Problem data is neither regression nor classification problem data.");
  }
}
}