Changeset 11337


Ignore:
Timestamp:
09/03/14 15:12:04 (7 years ago)
Author:
bburlacu
Message:

#2234: Refactored SVM grid search, added support for symbolic classification.

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs

    r11171 r11337  
    142142
    143143    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
    144       string svmType, string kernelType, double cost, double nu, double gamma, int degree,
    145       out double trainingAccuracy, out double testAccuracy, out int nSv) {
     144      string svmType, string kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
     145      return CreateSupportVectorClassificationSolution(problemData, allowedInputVariables, GetKernelType(svmType), GetKernelType(kernelType), cost, nu, gamma, degree,
     146        out trainingAccuracy, out testAccuracy, out nSv);
     147    }
     148
     149    public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables,
     150      int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) {
    146151      Dataset dataset = problemData.Dataset;
    147152      string targetVariable = problemData.TargetVariable;
     
    150155      //extract SVM parameters from scope and set them
    151156      svm_parameter parameter = new svm_parameter();
    152       parameter.svm_type = GetSvmType(svmType);
    153       parameter.kernel_type = GetKernelType(kernelType);
     157      parameter.svm_type = svmType;
     158      parameter.kernel_type = kernelType;
    154159      parameter.C = cost;
    155160      parameter.nu = nu;
     
    161166      parameter.shrinking = 1;
    162167      parameter.coef0 = 0;
    163 
    164168
    165169      var weightLabels = new List<int>();
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs

    r11326 r11337  
    3939    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
    4040    public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) {
    41       double[] targetVector =
    42         dataset.GetDoubleValues(targetVariable, rowIndices).ToArray();
    43 
     41      double[] targetVector = dataset.GetDoubleValues(targetVariable, rowIndices).ToArray();
    4442      svm_node[][] nodes = new svm_node[targetVector.Length][];
    45       List<svm_node> tempRow;
    4643      int maxNodeIndex = 0;
    4744      int svmProblemRowIndex = 0;
    4845      List<string> inputVariablesList = inputVariables.ToList();
    4946      foreach (int row in rowIndices) {
    50         tempRow = new List<svm_node>();
     47        List<svm_node> tempRow = new List<svm_node>();
    5148        int colIndex = 1; // make sure the smallest node index for SVM = 1
    5249        foreach (var inputVariable in inputVariablesList) {
     
    6259        nodes[svmProblemRowIndex++] = tempRow.ToArray();
    6360      }
    64 
    65       return new svm_problem() { l = targetVector.Length, y = targetVector, x = nodes };
     61      return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes };
    6662    }
    6763
     
    8985
    9086    /// <summary>
    91     /// Generate a collection of training indices corresponding to folds in the data (used for crossvalidation)
     87    /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation)
    9288    /// </summary>
    9389    /// <remarks>This method is intended to be lightweight and therefore does not clone the dataset.</remarks>
    9490    /// <param name="problemData">The problem data</param>
    95     /// <param name="nFolds">The number of folds to generate</param>
     91    /// <param name="numberOfFolds">The number of folds to generate</param>
    9692    /// <returns>A sequence of folds, each of which is a sequence of row numbers</returns>
    97     public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int nFolds) {
     93    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) {
    9894      int size = problemData.TrainingPartition.Size;
    99 
    100       int foldSize = size / nFolds; // rounding to integer
    101       var trainingIndices = problemData.TrainingIndices;
    102 
    103       for (int i = 0; i < nFolds; ++i) {
    104         int n = i * foldSize;
    105         int s = n + 2 * foldSize > size ? foldSize + size % foldSize : foldSize;
    106         yield return trainingIndices.Skip(n).Take(s);
     95      int f = size / numberOfFolds, r = size % numberOfFolds; // fold size (integer division) and remainder; the remainder is handed out one extra row per fold to the first r folds
     96      int start = 0, end = f;
     97      for (int i = 0; i < numberOfFolds; ++i) {
     98        if (r > 0) { ++end; --r; }
     99        yield return problemData.TrainingIndices.Skip(start).Take(end - start);
     100        start = end;
     101        end += f;
    107102      }
    108103    }
    109104
    110     public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numFolds, out double avgTestMse) {
    111       avgTestMse = 0;
    112       var folds = GenerateFolds(problemData, numFolds).ToList();
    113       var calc = new OnlineMeanSquaredErrorCalculator();
     105    private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds) {
     106      var folds = GenerateFolds(problemData, numberOfFolds).ToList();
    114107      var targetVariable = GetTargetVariableName(problemData);
    115       for (int i = 0; i < numFolds; ++i) {
     108      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
     109      for (int i = 0; i < numberOfFolds; ++i) {
    116110        int p = i; // avoid "access to modified closure" warning below
    117         var training = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
     111        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
    118112        var testRows = folds[i];
    119         var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, training);
     113        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
    120114        var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
     115        partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem);
     116      }
     117      return partitions;
     118    }
    121119
    122         var model = svm.svm_train(trainingSvmProblem, parameters);
    123         calc.Reset();
    124         for (int j = 0; j < testSvmProblem.l; ++j)
    125           calc.Add(testSvmProblem.y[j], svm.svm_predict(model, testSvmProblem.x[j]));
    126         avgTestMse += calc.MeanSquaredError;
    127       }
    128       avgTestMse /= numFolds;
     120    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {
     121      var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
     122      CrossValidate(problemData, parameters, partitions, out avgTestMse);
    129123    }
    130124
     
    156150      DoubleValue mse = new DoubleValue(Double.MaxValue);
    157151      var bestParam = DefaultParameters();
    158 
    159       // search for C, gamma and epsilon parameter combinations
    160152      var pNames = parameterRanges.Keys.ToList();
    161153      var pRanges = pNames.Select(x => parameterRanges[x]);
    162 
    163154      var crossProduct = pRanges.CartesianProduct();
    164155      var setters = pNames.Select(GenerateSetter).ToList();
    165       var folds = GenerateFolds(problemData, numberOfFolds).ToList();
    166 
    167       var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
    168       var targetVariable = GetTargetVariableName(problemData);
    169 
    170       for (int i = 0; i < numberOfFolds; ++i) {
    171         int p = i; // avoid "access to modified closure" warning below
    172         var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
    173         var testRows = folds[i];
    174         var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
    175         var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
    176         partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem);
    177       }
    178 
     156      var partitions = GenerateSvmPartitions(problemData, numberOfFolds);
    179157      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
    180         //  foreach (var nuple in crossProduct) {
    181158        var list = nuple.ToList();
    182159        var parameters = DefaultParameters();
     
    189166        if (testMse < mse.Value) {
    190167          lock (mse) { mse.Value = testMse; }
    191           lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); } // set best parameter values to the best found so far
     168          lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); }
    192169        }
    193170      });
Note: See TracChangeset for help on using the changeset viewer.