- Timestamp:
- 02/05/15 10:51:29 (10 years ago)
- Location:
- stable
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
stable
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Algorithms.DataAnalysis merged: 11308,11326,11337,11339-11340,11342,11361,11427,11464,11542
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs
r11170 r11907 142 142 143 143 public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables, 144 string svmType, string kernelType, double cost, double nu, double gamma, int degree, 145 out double trainingAccuracy, out double testAccuracy, out int nSv) { 144 string svmType, string kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) { 145 return CreateSupportVectorClassificationSolution(problemData, allowedInputVariables, GetSvmType(svmType), GetKernelType(kernelType), cost, nu, gamma, degree, 146 out trainingAccuracy, out testAccuracy, out nSv); 147 } 148 149 public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables, 150 int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) { 146 151 Dataset dataset = problemData.Dataset; 147 152 string targetVariable = problemData.TargetVariable; … … 150 155 //extract SVM parameters from scope and set them 151 156 svm_parameter parameter = new svm_parameter(); 152 parameter.svm_type = GetSvmType(svmType);153 parameter.kernel_type = GetKernelType(kernelType);157 parameter.svm_type = svmType; 158 parameter.kernel_type = kernelType; 154 159 parameter.C = cost; 155 160 parameter.nu = nu; … … 161 166 parameter.shrinking = 1; 162 167 parameter.coef0 = 0; 163 164 168 165 169 var weightLabels = new List<int>(); -
stable/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs
r11170 r11907 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 25 using System.Linq.Expressions; 26 using System.Threading.Tasks; 27 using HeuristicLab.Common; 28 using HeuristicLab.Core; 29 using HeuristicLab.Data; 24 30 using HeuristicLab.Problems.DataAnalysis; 31 using HeuristicLab.Random; 25 32 using LibSVM; 26 33 … … 34 41 /// <returns>A problem data type that can be used to train a support vector machine.</returns> 35 42 public static svm_problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) { 36 double[] targetVector = 37 dataset.GetDoubleValues(targetVariable, rowIndices).ToArray(); 38 43 double[] targetVector = dataset.GetDoubleValues(targetVariable, rowIndices).ToArray(); 39 44 svm_node[][] nodes = new svm_node[targetVector.Length][]; 40 List<svm_node> tempRow;41 45 int maxNodeIndex = 0; 42 46 int svmProblemRowIndex = 0; 43 47 List<string> inputVariablesList = inputVariables.ToList(); 44 48 foreach (int row in rowIndices) { 45 tempRow = new List<svm_node>();49 List<svm_node> tempRow = new List<svm_node>(); 46 50 int colIndex = 1; // make sure the smallest node index for SVM = 1 47 51 foreach (var inputVariable in inputVariablesList) { … … 50 54 // => don't add NaN values in the dataset to the sparse SVM matrix representation 51 55 if (!double.IsNaN(value)) { 52 tempRow.Add(new svm_node() { index = colIndex, value = value }); // nodes must be sorted in ascending ordered by column index 56 tempRow.Add(new svm_node() { index = colIndex, value = value }); 57 // nodes must be sorted in ascending ordered by column index 53 58 if (colIndex > maxNodeIndex) maxNodeIndex = colIndex; 54 59 } … … 57 62 nodes[svmProblemRowIndex++] = tempRow.ToArray(); 58 63 } 59 60 return new svm_problem() { l = targetVector.Length, y = targetVector, x = nodes }; 64 return new svm_problem { l = targetVector.Length, y = targetVector, x = nodes }; 65 } 66 
67 /// <summary> 68 /// Instantiate and return a svm_parameter object with default values. 69 /// </summary> 70 /// <returns>A svm_parameter object with default values</returns> 71 public static svm_parameter DefaultParameters() { 72 svm_parameter parameter = new svm_parameter(); 73 parameter.svm_type = svm_parameter.NU_SVR; 74 parameter.kernel_type = svm_parameter.RBF; 75 parameter.C = 1; 76 parameter.nu = 0.5; 77 parameter.gamma = 1; 78 parameter.p = 1; 79 parameter.cache_size = 500; 80 parameter.probability = 0; 81 parameter.eps = 0.001; 82 parameter.degree = 3; 83 parameter.shrinking = 1; 84 parameter.coef0 = 0; 85 86 return parameter; 87 } 88 89 public static double CrossValidate(IDataAnalysisProblemData problemData, svm_parameter parameters, int numberOfFolds, bool shuffleFolds = true) { 90 var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds); 91 return CalculateCrossValidationPartitions(partitions, parameters); 92 } 93 94 public static svm_parameter GridSearch(out double cvMse, IDataAnalysisProblemData problemData, Dictionary<string, IEnumerable<double>> parameterRanges, int numberOfFolds, bool shuffleFolds = true, int maxDegreeOfParallelism = 1) { 95 DoubleValue mse = new DoubleValue(Double.MaxValue); 96 var bestParam = DefaultParameters(); 97 var crossProduct = parameterRanges.Values.CartesianProduct(); 98 var setters = parameterRanges.Keys.Select(GenerateSetter).ToList(); 99 var partitions = GenerateSvmPartitions(problemData, numberOfFolds, shuffleFolds); 100 101 var locker = new object(); // for thread synchronization 102 Parallel.ForEach(crossProduct, new ParallelOptions { MaxDegreeOfParallelism = maxDegreeOfParallelism }, 103 parameterCombination => { 104 var parameters = DefaultParameters(); 105 var parameterValues = parameterCombination.ToList(); 106 for (int i = 0; i < parameterValues.Count; ++i) 107 setters[i](parameters, parameterValues[i]); 108 109 double testMse = CalculateCrossValidationPartitions(partitions, 
parameters); 110 if (!double.IsNaN(testMse)) { 111 lock (locker) { 112 if (testMse < mse.Value) { 113 mse.Value = testMse; 114 bestParam = (svm_parameter)parameters.Clone(); 115 } 116 } 117 } 118 }); 119 cvMse = mse.Value; 120 return bestParam; 121 } 122 123 private static double CalculateCrossValidationPartitions(Tuple<svm_problem, svm_problem>[] partitions, svm_parameter parameters) { 124 double avgTestMse = 0; 125 var calc = new OnlineMeanSquaredErrorCalculator(); 126 foreach (Tuple<svm_problem, svm_problem> tuple in partitions) { 127 var trainingSvmProblem = tuple.Item1; 128 var testSvmProblem = tuple.Item2; 129 var model = svm.svm_train(trainingSvmProblem, parameters); 130 calc.Reset(); 131 for (int i = 0; i < testSvmProblem.l; ++i) 132 calc.Add(testSvmProblem.y[i], svm.svm_predict(model, testSvmProblem.x[i])); 133 double mse = calc.ErrorState == OnlineCalculatorError.None ? calc.MeanSquaredError : double.NaN; 134 avgTestMse += mse; 135 } 136 avgTestMse /= partitions.Length; 137 return avgTestMse; 138 } 139 140 private static Tuple<svm_problem, svm_problem>[] GenerateSvmPartitions(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) { 141 var folds = GenerateFolds(problemData, numberOfFolds, shuffleFolds).ToList(); 142 var targetVariable = GetTargetVariableName(problemData); 143 var partitions = new Tuple<svm_problem, svm_problem>[numberOfFolds]; 144 for (int i = 0; i < numberOfFolds; ++i) { 145 int p = i; // avoid "access to modified closure" warning below 146 var trainingRows = folds.SelectMany((par, j) => j != p ? 
par : Enumerable.Empty<int>()); 147 var testRows = folds[i]; 148 var trainingSvmProblem = CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, trainingRows); 149 var rangeTransform = RangeTransform.Compute(trainingSvmProblem); 150 var testSvmProblem = rangeTransform.Scale(CreateSvmProblem(problemData.Dataset, targetVariable, problemData.AllowedInputVariables, testRows)); 151 partitions[i] = new Tuple<svm_problem, svm_problem>(rangeTransform.Scale(trainingSvmProblem), testSvmProblem); 152 } 153 return partitions; 154 } 155 156 public static IEnumerable<IEnumerable<int>> GenerateFolds(IDataAnalysisProblemData problemData, int numberOfFolds, bool shuffleFolds = true) { 157 var random = new MersenneTwister((uint)Environment.TickCount); 158 if (problemData is IRegressionProblemData) { 159 var trainingIndices = shuffleFolds ? problemData.TrainingIndices.OrderBy(x => random.Next()) : problemData.TrainingIndices; 160 return GenerateFolds(trainingIndices, problemData.TrainingPartition.Size, numberOfFolds); 161 } 162 if (problemData is IClassificationProblemData) { 163 // when shuffle is enabled do stratified folds generation, some folds may have zero elements 164 // otherwise, generate folds normally 165 return shuffleFolds ? GenerateFoldsStratified(problemData as IClassificationProblemData, numberOfFolds, random) : GenerateFolds(problemData.TrainingIndices, problemData.TrainingPartition.Size, numberOfFolds); 166 } 167 throw new ArgumentException("Problem data is neither regression or classification problem data."); 168 } 169 170 /// <summary> 171 /// Stratified fold generation from classification data. Stratification means that we ensure the same distribution of class labels for each fold. 172 /// The samples are grouped by class label and each group is split into @numberOfFolds parts. The final folds are formed from the joining of 173 /// the corresponding parts from each class label. 
 174 /// </summary> 175 /// <param name="problemData">The classification problem data.</param> 176 /// <param name="numberOfFolds">The number of folds in which to split the data.</param> 177 /// <param name="random">The random generator used to shuffle the folds.</param> 178 /// <returns>An enumerable sequence of folds, where a fold is represented by a sequence of row indices.</returns> 179 private static IEnumerable<IEnumerable<int>> GenerateFoldsStratified(IClassificationProblemData problemData, int numberOfFolds, IRandom random) { 180 var values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 181 var valuesIndices = problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }).ToList(); 182 IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = valuesIndices.GroupBy(x => x.Value, x => x.Index).Select(g => GenerateFolds(g, g.Count(), numberOfFolds)); 183 var enumerators = foldsByClass.Select(f => f.GetEnumerator()).ToList(); 184 while (enumerators.All(e => e.MoveNext())) { 185 yield return enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()).ToList(); 186 } 187 } 188 189 private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(IEnumerable<T> values, int valuesCount, int numberOfFolds) { 190 // if number of folds is greater than the number of values, some empty folds will be returned 191 if (valuesCount < numberOfFolds) { 192 for (int i = 0; i < numberOfFolds; ++i) 193 yield return i < valuesCount ? 
values.Skip(i).Take(1) : Enumerable.Empty<T>(); 194 } else { 195 int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds; // number of folds rounded to integer and remainder 196 int start = 0, end = f; 197 for (int i = 0; i < numberOfFolds; ++i) { 198 if (r > 0) { 199 ++end; 200 --r; 201 } 202 yield return values.Skip(start).Take(end - start); 203 start = end; 204 end += f; 205 } 206 } 207 } 208 209 private static Action<svm_parameter, double> GenerateSetter(string fieldName) { 210 var targetExp = Expression.Parameter(typeof(svm_parameter)); 211 var valueExp = Expression.Parameter(typeof(double)); 212 var fieldExp = Expression.Field(targetExp, fieldName); 213 var assignExp = Expression.Assign(fieldExp, Expression.Convert(valueExp, fieldExp.Type)); 214 var setter = Expression.Lambda<Action<svm_parameter, double>>(assignExp, targetExp, valueExp).Compile(); 215 return setter; 216 } 217 218 private static string GetTargetVariableName(IDataAnalysisProblemData problemData) { 219 var regressionProblemData = problemData as IRegressionProblemData; 220 var classificationProblemData = problemData as IClassificationProblemData; 221 222 if (regressionProblemData != null) 223 return regressionProblemData.TargetVariable; 224 if (classificationProblemData != null) 225 return classificationProblemData.TargetVariable; 226 227 throw new ArgumentException("Problem data is neither regression or classification problem data."); 61 228 } 62 229 }
Note: See TracChangeset
for help on using the changeset viewer.