Changeset 11339


Ignore:
Timestamp:
09/04/14 14:16:37 (8 years ago)
Author:
mkommend
Message:

#2237: Minor code changes in SVMUtil to perform cross validation (code reorganization, naming).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs

    r11337 r11339  
    8484    }
    8585
    86     /// <summary>
    87     /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
    88     /// </summary>
    89     /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
    90     /// <param name="problemData">The problem data</param>
    91     /// <param name="numberOfFolds">The number of folds to generate</param>
    92     /// <returns>A sequence of folds, each of which is a sequence of row numbers</returns>
    93     public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
    94       int size = problemData.TrainingPartition.Size;
    95       int f = size / numberOfFolds, r = size % numberOfFolds; // base fold size (integer division) and remainder rows to distribute
    96       int start = 0, end = f;
    97       for (int i = 0; i < numberOfFolds; ++i) {
    98         if (r > 0) { ++end; --r; }
    99         yield return problemData.TrainingIndices.Skip(start).Take(end - start);
    100         start = end;
    101         end += f;
     86    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
     87      var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
     88      CalculateCrossValidationPartitions(partitions, parameters, out avgTestMse);
     89    }
     90
     91    public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
     92      DoubleValue mse = new DoubleValue(Double.MaxValue);
     93      var bestParam = DefaultParameters();
     94      var crossProduct = parameterRanges.Values.CartesianProduct();
     95      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
     96      var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
     97      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
     98        var parameters = DefaultParameters();
     99        var parameterValues = parameterCombination.ToList();
     100        for (int i = 0; i < parameterValues.Count; ++i) {
     101          setters[i](parameters, parameterValues[i]);
     102        }
     103        double testMse;
     104        CalculateCrossValidationPartitions(partitions, parameters, out testMse);
     105        if (testMse < mse.Value) {
     106          lock (mse) { mse.Value = testMse; }
     107          lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); }
     108        }
     109      });
     110      return bestParam;
     111    }
     112
     113    private static void CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters, out double avgTestMse) {
     114      avgTestMse = 0;
     115      var calc = new OnlineMeanSquaredErrorCalculator();
     116      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
     117        var trainingSvmProblem = tuple.Item1;
     118        var testSvmProblem = tuple.Item2;
     119        var model = svm.svm_train(trainingSvmProblem, parameters);
     120        calc.Reset();
     121        for (int i = 0; i < testSvmProblem.l; ++i)
     122          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
     123        avgTestMse += calc.MeanSquaredError;
    102124      }
     125      avgTestMse /= partitions.Length;
    103126    }
     127
    104128
    105129    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
     
    118142    }
    119143
    120     public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
    121       var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
    122       CrossValidate(problemData, parameters, partitions, out avgTestMse);
    123     }
    124 
    125     public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, Tuple<svm_problem, svm_problem>[] partitions, out double avgTestMse) {
    126       avgTestMse = 0;
    127       var calc = new OnlineMeanSquaredErrorCalculator();
    128       foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
    129         var trainingSvmProblem = tuple.Item1;
    130         var testSvmProblem = tuple.Item2;
    131         var model = svm.svm_train(trainingSvmProblem, parameters);
    132         calc.Reset();
    133         for (int i = 0; i < testSvmProblem.l; ++i)
    134           calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
    135         avgTestMse += calc.MeanSquaredError;
     144    /// <summary>
     145    /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
     146    /// </summary>
     147    /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
     148    /// <param name="problemData">The problem data</param>
     149    /// <param name="numberOfFolds">The number of folds to generate</param>
     150    /// <returns>A sequence of folds, each of which is a sequence of row numbers</returns>
     151    private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
     152      int size = problemData.TrainingPartition.Size;
     153      int f = size / numberOfFolds, r = size % numberOfFolds; // base fold size (integer division) and remainder rows to distribute
     154      int start = 0, end = f;
     155      for (int i = 0; i < numberOfFolds; ++i) {
     156        if (r > 0) { ++end; --r; }
     157        yield return problemData.TrainingIndices.Skip(start).Take(end - start);
     158        start = end;
     159        end += f;
    136160      }
    137       avgTestMse /= partitions.Length;
    138161    }
    139162
     
    145168      var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile();
    146169      return setter;
    147     }
    148 
    149     public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
    150       DoubleValue mse = new DoubleValue(Double.MaxValue);
    151       var bestParam = DefaultParameters();
    152       var pNames = parameterRanges.Keys.ToList();
    153       var pRanges = pNames.Select(x => parameterRanges[x]);
    154       var crossProduct = pRanges.CartesianProduct();
    155       var setters = pNames.Select(GenerateSetter).ToList();
    156       var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
    157       Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
    158         var list = nuple.ToList();
    159         var parameters = DefaultParameters();
    160         for (int i = 0; i < pNames.Count; ++i) {
    161           var s = setters[i];
    162           s(parameters, list[i]);
    163         }
    164         double testMse;
    165         CrossValidate(problemData, parameters, partitions, out testMse);
    166         if (testMse < mse.Value) {
    167           lock (mse) { mse.Value = testMse; }
    168           lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); }
    169         }
    170       });
    171       return bestParam;
    172170    }
    173171
     
    183181      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    184182    }
     183
    185184  }
    186185}
Note: See TracChangeset for help on using the changeset viewer.