Free cookie consent management tool by TermsFeed Policy Generator

Changeset 11326


Ignore:
Timestamp:
09/02/14 09:16:52 (10 years ago)
Author:
bburlacu
Message:

#2234: Refactored CrossValidate and GridSearch methods.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs

    r11308 r11326  
    9595    /// <param name="nFolds">The number of folds to generate</param>
    9696    /// <returns>A sequence of folds representing each a sequence of row numbers</returns>
    97     public static IEnumerable<IEnumerable<int>> GenerateFolds(IRegressionProblemData problemData, int nFolds) {
     97    public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int nFolds) {
    9898      int size = problemData.TrainingPartition.Size;
    9999
     
    108108    }
    109109
    110     /// <summary>
    111     /// Performs crossvalidation
    112     /// </summary>
    113     /// <param name="problemData">The problem data</param>
    114     /// <param name="parameters">The svm parameters</param>
    115     /// <param name="folds">The svm_problem instances for each fold</param>
    116     /// <param name="avgTestMSE">The average test mean squared error (not used atm)</param>
    117     public static void CrossValidate(IRegressionProblemData problemData, svm_parameter parameters, IEnumerable<IEnumerable<int>> folds, out double avgTestMSE) {
    118       avgTestMSE = 0;
    119 
     110    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numFolds, out double avgTestMse) {
     111      avgTestMse = 0;
     112      var folds = GenerateFolds(problemData, numFolds).ToList();
    120113      var calc = new OnlineMeanSquaredErrorCalculator();
    121       var ds = problemData.Dataset;
    122       var targetVariable = problemData.TargetVariable;
    123       var inputVariables = problemData.AllowedInputVariables;
    124 
    125       var svmProblem = CreateSvmProblem(ds, targetVariable, inputVariables, problemData.TrainingIndices);
    126       var partitions = folds.ToList();
    127 
    128       for (int i = 0; i < partitions.Count; ++i) {
    129         var test = partitions[i];
    130         var training = new List<int>();
    131         for (int j = 0; j < i; ++j)
    132           training.AddRange(partitions[j]);
    133 
    134         for (int j = i + 1; j < partitions.Count; ++j)
    135           training.AddRange(partitions[j]);
    136 
    137         var p = CreateSvmProblem(ds, targetVariable, inputVariables, training);
    138         var model = svm.svm_train(p, parameters);
     114      var targetVariable = GetTargetVariableName(problemData);
     115      for (int i = 0; i < numFolds; ++i) {
     116        int p = i; // avoid "access to modified closure" warning below
     117        var training = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
     118        var testRows = folds[i];
     119        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, training);
     120        var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
     121
     122        var model = svm.svm_train(trainingSvmProblem, parameters);
    139123        calc.Reset();
    140         foreach (var row in test) {
    141           calc.Add(svmProblem.y[row], svm.svm_predict(model, svmProblem.x[row]));
    142         }
    143         double error = calc.MeanSquaredError;
    144         avgTestMSE += error;
    145       }
    146 
    147       avgTestMSE /= partitions.Count;
    148     }
    149 
    150     /// <summary>
    151     /// Dynamically generate a setter for svm_parameter fields
    152     /// </summary>
    153     /// <param name="parameters"></param>
    154     /// <param name="fieldName"></param>
    155     /// <returns></returns>
     124        for (int j = 0; j < testSvmProblem.l; ++j)
     125          calc.Add(testSvmProblem.y[j], svm.svm_predict(model, testSvmProblem.x[j]));
     126        avgTestMse += calc.MeanSquaredError;
     127      }
     128      avgTestMse /= numFolds;
     129    }
     130
     131    public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, Tuple<svm_problem, svm_problem>[] partitions, out double avgTestMse) {
     132      avgTestMse = 0;
     133      var calc = new OnlineMeanSquaredErrorCalculator();
     134      foreach (Tuple<svm_problem, svm_problem> tuple in partitions) {
     135        var trainingSvmProblem = tuple.Item1;
     136        var testSvmProblem = tuple.Item2;
     137        var model = svm.svm_train(trainingSvmProblem, parameters);
     138        calc.Reset();
     139        for (int i = 0; i < testSvmProblem.l; ++i)
     140          calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i]));
     141        avgTestMse += calc.MeanSquaredError;
     142      }
     143      avgTestMse /= partitions.Length;
     144    }
     145
    156146    private static Action<svm_parameter, double> GenerateSetter(string fieldName) {
    157147      var targetExp = Expression.Parameter(typeof(svm_parameter));
    158148      var valueExp = Expression.Parameter(typeof(double));
    159 
    160       // Expression.Property can be used here as well
    161149      var fieldExp = Expression.Field(targetExp, fieldName);
    162150      var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type));
     
    165153    }
    166154
    167     public static svm_parameter GridSearch(IRegressionProblemData problemData, IEnumerable<IEnumerable<int>> folds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
     155    public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {
    168156      DoubleValue mse = new DoubleValue(Double.MaxValue);
    169157      var bestParam = DefaultParameters();
    170158
    171159      // search for C, gamma and epsilon parameter combinations
    172 
    173160      var pNames = parameterRanges.Keys.ToList();
    174161      var pRanges = pNames.Select(x => parameterRanges[x]);
     
    176163      var crossProduct = pRanges.CartesianProduct();
    177164      var setters = pNames.Select(GenerateSetter).ToList();
     165      var folds = GenerateFolds(problemData, numberOfFolds).ToList();
     166
     167      var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds];
     168      var targetVariable = GetTargetVariableName(problemData);
     169
     170      for (int i = 0; i < numberOfFolds; ++i) {
     171        int p = i; // avoid "access to modified closure" warning below
     172        var trainingRows = folds.SelectMany((par, j) => j != p ? par : Enumerable.Empty<int>());
     173        var testRows = folds[i];
     174        var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows);
     175        var testSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows);
     176        partitions[i] = new Tuple<svm_problem, svm_problem>(trainingSvmProblem, testSvmProblem);
     177      }
     178
    178179      Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, nuple => {
    179180        //  foreach (var nuple in crossProduct) {
     
    184185          s(parameters, list[i]);
    185186        }
    186         double testMSE;
    187         CrossValidate(problemData, parameters, folds, out testMSE);
    188         if (testMSE < mse.Value) {
    189           lock (mse) { mse.Value = testMSE; }
    190           lock (bestParam) { // set best parameter values to the best found so far
    191             bestParam = (svm_parameter)parameters.Clone();
    192           }
     187        double testMse;
     188        CrossValidate(problemData, parameters, partitions, out testMse);
     189        if (testMse < mse.Value) {
     190          lock (mse) { mse.Value = testMse; }
     191          lock (bestParam) { bestParam = (svm_parameter)parameters.Clone(); } // set best parameter values to the best found so far
    193192        }
    194193      });
    195194      return bestParam;
    196195    }
     196
     197    private static string GetTargetVariableName(IDataAnalysisProblemData problemData) {
     198      var regressionProblemData = problemData as IRegressionProblemData;
     199      var classificationProblemData = problemData as IClassificationProblemData;
     200
     201      if (regressionProblemData != null)
     202        return regressionProblemData.TargetVariable;
     203      if (classificationProblemData != null)
     204        return classificationProblemData.TargetVariable;
     205
     206      throw new ArgumentException("Problem data is neither regression or classification problem data.");
     207    }
    197208  }
    198209}
Note: See TracChangeset for help on using the changeset viewer.