Changeset 11361
 Timestamp:
 09/14/14 13:30:30 (8 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs
r11342 r11361 26 26 using System.Threading.Tasks; 27 27 using HeuristicLab.Common; 28 using HeuristicLab.Core; 28 29 using HeuristicLab.Data; 29 30 using HeuristicLab.Problems.DataAnalysis; 31 using HeuristicLab.Random; 30 32 using LibSVM; 31 33 … … 52 54 // => don't add NaN values in the dataset to the sparse SVM matrix representation 53 55 if (!double.IsNaN(value)) { 54 tempRow.Add(new svm_node() { index = colIndex, value = value }); // nodes must be sorted in ascending ordered by column index 56 tempRow.Add(new svm_node() { index = colIndex, value = value }); 57 // nodes must be sorted in ascending ordered by column index 55 58 if (colIndex > maxNodeIndex) maxNodeIndex = colIndex; 56 59 } … … 84 87 } 85 88 86 public static void CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, out double avgTestMse) {87 var partitions = GenerateSvmPartitions(problemData, numberOfFolds );88 CalculateCrossValidationPartitions(partitions, parameters, out avgTestMse);89 } 90 91 public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, int numberOfFolds, Dictionary<string, IEnumerable<double>> parameterRanges, int maxDegreeOfParallelism = 1) {89 public static double CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, bool shuffleFolds = true) { 90 var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds); 91 return CalculateCrossValidationPartitions(partitions, parameters); 92 } 93 94 public static svm_parameter GridSearch(IDataAnalysisProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int numberOfFolds, bool shuffleFolds = true, int maxDegreeOfParallelism = 1) { 92 95 DoubleValue mse = new DoubleValue(Double.MaxValue); 93 96 var bestParam = DefaultParameters(); 94 97 var crossProduct = parameterRanges.Values.CartesianProduct(); 95 98 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 96 var partitions = GenerateSvmPartitions(problemData, numberOfFolds); 97 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, parameterCombination => { 99 var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds); 100 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, 101 parameterCombination => { 98 102 var parameters = DefaultParameters(); 99 103 var parameterValues = parameterCombination.ToList(); 100 for (int i = 0; i < parameterValues.Count; ++i) {104 for (int i = 0; i < parameterValues.Count; ++i) 101 105 setters[i](parameters, parameterValues[i]); 102 } 103 double testMse; 104 CalculateCrossValidationPartitions(partitions, parameters, out testMse); 106 107 double testMse = CalculateCrossValidationPartitions(partitions, parameters); 105 108 if (testMse < mse.Value) { 106 109 lock (mse) { … … 113 116 } 114 117 115 private static void CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters, out double avgTestMse) {116 avgTestMse = 0;118 private static double CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters) { 119 double avgTestMse = 0; 117 120 var calc = new OnlineMeanSquaredErrorCalculator(); 118 121 foreach (Tuple<svm_problem, svm_problem> tuple in partitions) { … … 126 129 } 127 130 avgTestMse /= partitions.Length; 128 }129 130 131 private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds ) {132 var folds = GenerateFolds(problemData, numberOfFolds ).ToList();131 return avgTestMse; 132 } 133 134 private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) { 135 var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList(); 133 136 var targetVariable = GetTargetVariableName(problemData); 134 137 var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds]; … … 144 147 } 145 148 149 public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) { 150 var random = new MersenneTwister((uint)Environment.TickCount); 151 if (problemData is IRegressionProblemData) { 152 var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices; 153 return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds); 154 } 155 if (problemData is IClassificationProblemData) { 156 // when shuffle is enabled do stratified folds generation, some folds may have zero elements 157 // otherwise, generate folds normally 158 return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds); 159 } 160 throw new ArgumentException("Problem data is neither regression or classification problem data."); 161 } 162 146 163 /// <summary> 147 /// Generate a collection of row sequences corresponding to folds in the data (used for crossvalidation) 164 /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold. 165 /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of 166 /// the corresponding parts from each class label. 148 167 /// </summary> 149 /// <remarks>This method is aimed to be lightweight and as such does not clone the dataset.</remarks> 150 /// <param name="problemData">The problem data</param> 151 /// <param name="numberOfFolds">The number of folds to generate</param> 152 /// <returns>A sequence of folds representing each a sequence of row numbers</returns> 153 private static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds) { 154 int size = problemData.TrainingPartition.Size; 155 int f = size / numberOfFolds, r = size % numberOfFolds; // number of folds rounded to integer and remainder 156 int start = 0, end = f; 157 for (int i = 0; i < numberOfFolds; ++i) { 158 if (r > 0) { ++end; r; } 159 yield return problemData.TrainingIndices.Skip(start).Take(end  start); 160 start = end; 161 end += f; 168 /// <param name="problemData">The classification problem data.</param> 169 /// <param name="numberOfFolds">The number of folds in which to split the data.</param> 170 /// <param name="random">The random generator used to shuffle the folds.</param> 171 /// <returns>An enumerable sequece of folds, where a fold is represented by a sequence of row indices.</returns> 172 private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) { 173 var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 174 var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList(); 175 IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds)); 176 var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList(); 177 while (enumerators.All(e => e.MoveNext())) { 178 yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList(); 179 } 180 } 181 182 private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) { 183 // if number of folds is greater than the number of values, some empty folds will be returned 184 if (valuesCount < numberOfFolds) { 185 for (int i = 0; i < numberOfFolds; ++i) 186 yield return i < valuesCount ? values.Skip(i).Take(1) : Enumerable.Empty<T>(); 187 } else { 188 int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder 189 int start = 0, end = f; 190 for (int i = 0; i < numberOfFolds; ++i) { 191 if (r > 0) { 192 ++end; 193 r; 194 } 195 yield return values.Skip(start).Take(end  start); 196 start = end; 197 end += f; 198 } 162 199 } 163 200 } … … 183 220 throw new ArgumentException("Problem data is neither regression or classification problem data."); 184 221 } 185 186 222 } 187 223 }
Note: See TracChangeset
for help on using the changeset viewer.