Changeset 11361


Ignore:
Timestamp:
09/14/14 13:30:30 (8 years ago)
Author:
bburlacu
Message:

#2234: Added the option to shuffle the cross-validation folds (this option is on by default since libsvm does it too). Implemented stratified fold generation for classification data (ensures a similar label distribution in each fold).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs

    r11342 r11361  
    2626using System.Threading.Tasks;
    2727using HeuristicLab.Common;
     28using HeuristicLab.Core;
    2829using HeuristicLab.Data;
    2930using HeuristicLab.Problems.DataAnalysis;
     31using HeuristicLab.Random;
    3032using LibSVM;
    3133
     
    5254          // => don't add NaN values in the dataset to the sparse SVM matrix representation
    5355          if (!double.IsNaN(value)) {
    54             tempRow.Add(new svm_node() { index = colIndex, value = value }); // nodes must be sorted in ascending ordered by column index
     56            tempRow.Add(new svm_node() { index = colIndex, value = value });
     57            // nodes must be sorted in ascending order by column index
    5558            if (colIndex > maxNodeIndex) maxNodeIndex = colIndex;
    5659          }
     
    8487    }
    8588
    86     public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
    87       var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
    88       CalculateCrossValidationPartitions(partitions, parameters, out avgTestMse);
    89     }
    90 
    91     public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
     89    public static double CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, bool shuffleFolds = true) {
     90      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);
     91      return CalculateCrossValidationPartitions(partitions, parameters);
     92    }
     93
     94    public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int numberOfFolds, bool shuffleFolds = true, int maxDegreeOfParallelism = 1) {
    9295      DoubleValue mse = new DoubleValue(Double.MaxValue);
    9396      var bestParam = DefaultParameters();
    9497      var crossProduct = parameterRanges.Values.CartesianProduct();
    9598      var setters = parameterRanges.Keys.Select(GenerateSetter).ToList();
    96       var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
    97       Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => {
     99      var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds);
     100      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism },
     101      parameterCombination => {
    98102        var parameters = DefaultParameters();
    99103        var parameterValues = parameterCombination.ToList();
    100         for (int i = 0; i < parameterValues.Count; ++i) {
     104        for (int i = 0; i < parameterValues.Count; ++i)
    101105          setters[i](parameters, parameterValues[i]);
    102         }
    103         double testMse;
    104         CalculateCrossValidationPartitions(partitions, parameters, out testMse);
     106
     107        double testMse = CalculateCrossValidationPartitions(partitions, parameters);
    105108        if (testMse < mse.Value) {
    106109          lock (mse) {
     
    113116    }
    114117
    115     private static void CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters, out double avgTestMse) {
    116       avgTestMse = 0;
     118    private static double CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters) {
     119      double avgTestMse = 0;
    117120      var calc = new OnlineMeanSquaredErrorCalculator();
    118121      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
     
    126129      }
    127130      avgTestMse /= partitions.Length;
    128     }
    129 
    130 
    131     private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
    132       var folds = GenerateFolds(problemData, numberOfFolds).ToList();
     131      return avgTestMse;
     132    }
     133
     134    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
     135      var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList();
    133136      var targetVariable = GetTargetVariableName(problemData);
    134137      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
     
    144147    }
    145148
     149    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) {
     150      var random = new MersenneTwister((uint)Environment.TickCount);
     151      if (problemData is IRegressionProblemData) {
     152        var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices;
     153        return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
     154      }
     155      if (problemData is IClassificationProblemData) {
     156        // when shuffle is enabled do stratified folds generation, some folds may have zero elements
     157        // otherwise, generate folds normally
     158        return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds);
     159      }
     160      throw new ArgumentException("Problem data is neither regression or classification problem data.");
     161    }
     162
    146163    /// <summary>
    147     /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
     164    /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold.
     165    /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of
     166    /// the corresponding parts from each class label.
    148167    /// </summary>
    149     /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks>
    150     /// <param name="problemData">The problem data</param>
    151     /// <param name="numberOfFolds">The number of folds to generate</param>
    152     /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
    153     private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
    154       int size = problemData.TrainingPartition.Size;
    155       int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder
    156       int start = 0, end = f;
    157       for (int i = 0; i < numberOfFolds; ++i) {
    158         if (r > 0) { ++end; --r; }
    159         yield return problemData.TrainingIndices.Skip(start).Take(end - start);
    160         start = end;
    161         end += f;
     168    /// <param name="problemData">The classification problem data.</param>
     169    /// <param name="numberOfFolds">The number of folds in which to split the data.</param>
     170    /// <param name="random">The random generator used to shuffle the folds.</param>
     171    /// <returns>An enumerable sequence of folds, where a fold is represented by a sequence of row indices.</returns>
     172    private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) {
     173      var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
     174      var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList();
     175      IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds));
     176      var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList();
     177      while (enumerators.All(e => e.MoveNext())) {
     178        yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList();
     179      }
     180    }
     181
     182    private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) {
     183      // if number of folds is greater than the number of values, some empty folds will be returned
     184      if (valuesCount < numberOfFolds) {
     185        for (int i = 0; i < numberOfFolds; ++i)
     186          yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>();
     187      } else {
     188        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder
     189        int start = 0, end = f;
     190        for (int i = 0; i < numberOfFolds; ++i) {
     191          if (r > 0) {
     192            ++end;
     193            --r;
     194          }
     195          yield return values.Skip(start).Take(end - start);
     196          start = end;
     197          end += f;
     198        }
    162199      }
    163200    }
     
    183220      throw new ArgumentException("Problem data is neither regression or classification problem data.");
    184221    }
    185 
    186222  }
    187223}
Note: See TracChangeset for help on using the changeset viewer.