Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/26/14 16:33:53 (10 years ago)
Author:
bburlacu
Message:

#2234: Implemented SVM grid search in SupportVectorMachineUtil.cs.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs

    r11171 r11308  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
     25using System.Linq.Expressions;
     26using System.Threading.Tasks;
     27using HeuristicLab.Common;
     28using HeuristicLab.Data;
    2429using HeuristicLab.Problems.DataAnalysis;
    2530using LibSVM;
     
    6065      return new svm_problem() { l = targetVector.Length, y = targetVector, x = nodes };
    6166    }
     67
     68    /// <summary>
     69    /// Instantiate and return a svm_parameter object with default values.
     70    /// </summary>
     71    /// <returns>A svm_parameter object with default values</returns>
     72    public static svm_parameter DefaultParameters() {
     73      svm_parameter parameter = new svm_parameter();
     74      parameter.svm_type = svm_parameter.NU_SVR;
     75      parameter.kernel_type = svm_parameter.RBF;
     76      parameter.C = 1;
     77      parameter.nu = 0.5;
     78      parameter.gamma = 1;
     79      parameter.p = 1;
     80      parameter.cache_size = 500;
     81      parameter.probability = 0;
     82      parameter.eps = 0.001;
     83      parameter.degree = 3;
     84      parameter.shrinking = 1;
     85      parameter.coef0 = 0;
     86
     87      return parameter;
     88    }
     89
     90    /// <summary>
     91    /// Generate a collection of training indices corresponding to folds in the data (used for crossvalidation)
     92    /// </summary>
     93    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
     94    /// <param name="problemData">The problem data</param>
     95    /// <param name="nFolds">The number of folds to generate</param>
     96    /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
     97    public static IEnumerable<IEnumerable<int>> GenerateFolds(IRegressionProblemData problemData, int nFolds) {
     98      int size = problemData.TrainingPartition.Size;
     99
     100      int foldSize = size / nFolds; // rounding to integer
     101      var trainingIndices = problemData.TrainingIndices;
     102
     103      for (int i = 0; i < nFolds; ++i) {
     104        int n = i * foldSize;
     105        int s = n + 2 * foldSize > size ? foldSize + size % foldSize : foldSize;
     106        yield return trainingIndices.Skip(n).Take(s);
     107      }
     108    }
     109
     110    /// <summary>
     111    /// Performs crossvalidation
     112    /// </summary>
     113    /// <param name="problemData">The problem data</param>
     114    /// <param name="parameters">The svm parameters</param>
     115    /// <param name="folds">The svm_problem instances for each fold</param>
     116    /// <param name="avgTestMSE">The average test mean squared error (not used atm)</param>
     117    public static void CrossValidate(IRegressionProblemData problemData, svm_parameter parameters, IEnumerable<IEnumerable<int>> folds, out double avgTestMSE) {
     118      avgTestMSE = 0;
     119
     120      var calc = new OnlineMeanSquaredErrorCalculator();
     121      var ds = problemData.Dataset;
     122      var targetVariable = problemData.TargetVariable;
     123      var inputVariables = problemData.AllowedInputVariables;
     124
     125      var svmProblem = CreateSvmProblem(ds, targetVariable, inputVariables, problemData.TrainingIndices);
     126      var partitions = folds.ToList();
     127
     128      for (int i = 0; i < partitions.Count; ++i) {
     129        var test = partitions[i];
     130        var training = new List<int>();
     131        for (int j = 0; j < i; ++j)
     132          training.AddRange(partitions[j]);
     133
     134        for (int j = i + 1; j < partitions.Count; ++j)
     135          training.AddRange(partitions[j]);
     136
     137        var p = CreateSvmProblem(ds, targetVariable, inputVariables, training);
     138        var model = svm.svm_train(p, parameters);
     139        calc.Reset();
     140        foreach (var row in test) {
     141          calc.Add(svmProblem.y[row], svm.svm_predict(model, svmProblem.x[row]));
     142        }
     143        double error = calc.MeanSquaredError;
     144        avgTestMSE += error;
     145      }
     146
     147      avgTestMSE /= partitions.Count;
     148    }
     149
     150    /// <summary>
     151    /// Dynamically generate a setter for svm_parameter fields
     152    /// </summary>
     153    /// <param name="parameters"></param>
     154    /// <param name="fieldName"></param>
     155    /// <returns></returns>
     156    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
     157      var targetExp = Expression.Parameter(typeof(svm_parameter));
     158      var valueExp = Expression.Parameter(typeof(double));
     159
     160      // Expression.Property can be used here as well
     161      var fieldExp = Expression.Field(targetExp, fieldName);
     162      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
     163      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
     164      return setter;
     165    }
     166
     167    public static svm_parameter GridSearch(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
     168      DoubleValue mse = new DoubleValue(Double.MaxValue);
     169      var bestParam = DefaultParameters();
     170
     171      // search for C, gamma and epsilon parameter combinations
     172
     173      var pNames = parameterRanges.Keys.ToList();
     174      var pRanges = pNames.Select(x => parameterRanges[x]);
     175
     176      var crossProduct = pRanges.CartesianProduct();
     177      var setters = pNames.Select(GenerateSetter).ToList();
     178      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
     179        //  foreach (var nuple in crossProduct) {
     180        var list = nuple.ToList();
     181        var parameters = DefaultParameters();
     182        for (int i = 0; i < pNames.Count; ++i) {
     183          var s = setters[i];
     184          s(parameters, list[i]);
     185        }
     186        double testMSE;
     187        CrossValidate(problemData, parameters, folds, out testMSE);
     188        if (testMSE < mse.Value) {
     189          lock (mse) { mse.Value = testMSE; }
     190          lock (bestParam) { // set best parameter values to the best found so far
     191            bestParam = (svm_parameter)parameters.Clone();
     192          }
     193        }
     194      });
     195      return bestParam;
     196    }
    62197  }
    63198}
Note: See TracChangeset for help on using the changeset viewer.