Changeset 4068 for trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/Training.cs
- Timestamp:
- 07/22/10 00:44:01 (14 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/Training.cs
r3884 r4068 19 19 20 20 using System; 21 using System.Collections.Generic; 22 23 namespace SVM 24 { 25 /// <summary> 26 /// Class containing the routines to train SVM models. 27 /// </summary> 28 public static class Training 29 { 30 /// <summary> 31 /// Whether the system will output information to the console during the training process. 32 /// </summary> 33 public static bool IsVerbose 34 { 35 get 36 { 37 return Procedures.IsVerbose; 21 22 namespace SVM { 23 /// <summary> 24 /// Class containing the routines to train SVM models. 25 /// </summary> 26 public static class Training { 27 /// <summary> 28 /// Whether the system will output information to the console during the training process. 29 /// </summary> 30 public static bool IsVerbose { 31 get { 32 return Procedures.IsVerbose; 33 } 34 set { 35 Procedures.IsVerbose = value; 36 } 37 } 38 39 private static double doCrossValidation(Problem problem, Parameter parameters, int nr_fold, bool shuffleTraining) { 40 int i; 41 double[] target = new double[problem.Count]; 42 Procedures.svm_cross_validation(problem, parameters, nr_fold, target, shuffleTraining); 43 int total_correct = 0; 44 double total_error = 0; 45 //double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; 46 if (parameters.SvmType == SvmType.EPSILON_SVR || parameters.SvmType == SvmType.NU_SVR) { 47 for (i = 0; i < problem.Count; i++) { 48 double y = problem.Y[i]; 49 double v = target[i]; 50 total_error += (v - y) * (v - y); 51 //sumv += v; 52 //sumy += y; 53 //sumvv += v * v; 54 //sumyy += y * y; 55 //sumvy += v * y; 56 } 57 return total_error / problem.Count; // return MSE 58 // (problem.Count * sumvy - sumv * sumy) / (Math.Sqrt(problem.Count * sumvv - sumv * sumv) * Math.Sqrt(problem.Count * sumyy - sumy * sumy)); 59 } else 60 for (i = 0; i < problem.Count; i++) 61 if (target[i] == problem.Y[i]) 62 ++total_correct; 63 return (double)total_correct / problem.Count; 64 } 65 /// <summary> 66 /// Legacy. Allows use as if this was svm_train. See libsvm documentation for details on which arguments to pass. 67 /// </summary> 68 /// <param name="args"></param> 69 [Obsolete("Provided only for legacy compatibility, use the other Train() methods")] 70 public static void Train(params string[] args) { 71 Parameter parameters; 72 Problem problem; 73 bool crossValidation; 74 int nrfold; 75 string modelFilename; 76 parseCommandLine(args, out parameters, out problem, out crossValidation, out nrfold, out modelFilename); 77 if (crossValidation) 78 PerformCrossValidation(problem, parameters, nrfold, true); 79 else Model.Write(modelFilename, Train(problem, parameters)); 80 } 81 82 /// <summary> 83 /// Performs cross validation. 84 /// </summary> 85 /// <param name="problem">The training data</param> 86 /// <param name="parameters">The parameters to test</param> 87 /// <param name="nrfold">The number of cross validations to use</param> 88 /// <returns>The cross validation score</returns> 89 public static double PerformCrossValidation(Problem problem, Parameter parameters, int nrfold, bool shuffleTraining) { 90 string error = Procedures.svm_check_parameter(problem, parameters); 91 if (error == null) 92 return doCrossValidation(problem, parameters, nrfold, shuffleTraining); 93 else throw new Exception(error); 94 } 95 96 /// <summary> 97 /// Trains a model using the provided training data and parameters. 98 /// </summary> 99 /// <param name="problem">The training data</param> 100 /// <param name="parameters">The parameters to use</param> 101 /// <returns>A trained SVM Model</returns> 102 public static Model Train(Problem problem, Parameter parameters) { 103 string error = Procedures.svm_check_parameter(problem, parameters); 104 105 if (error == null) 106 return Procedures.svm_train(problem, parameters); 107 else throw new Exception(error); 108 } 109 110 private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename) { 111 int i; 112 113 parameters = new Parameter(); 114 // default values 115 116 crossValidation = false; 117 nrfold = 0; 118 119 // parse options 120 for (i = 0; i < args.Length; i++) { 121 if (args[i][0] != '-') 122 break; 123 ++i; 124 switch (args[i - 1][1]) { 125 126 case 's': 127 parameters.SvmType = (SvmType)int.Parse(args[i]); 128 break; 129 130 case 't': 131 parameters.KernelType = (KernelType)int.Parse(args[i]); 132 break; 133 134 case 'd': 135 parameters.Degree = int.Parse(args[i]); 136 break; 137 138 case 'g': 139 parameters.Gamma = double.Parse(args[i]); 140 break; 141 142 case 'r': 143 parameters.Coefficient0 = double.Parse(args[i]); 144 break; 145 146 case 'n': 147 parameters.Nu = double.Parse(args[i]); 148 break; 149 150 case 'm': 151 parameters.CacheSize = double.Parse(args[i]); 152 break; 153 154 case 'c': 155 parameters.C = double.Parse(args[i]); 156 break; 157 158 case 'e': 159 parameters.EPS = double.Parse(args[i]); 160 break; 161 162 case 'p': 163 parameters.P = double.Parse(args[i]); 164 break; 165 166 case 'h': 167 parameters.Shrinking = int.Parse(args[i]) == 1; 168 break; 169 170 case 'b': 171 parameters.Probability = int.Parse(args[i]) == 1; 172 break; 173 174 case 'v': 175 crossValidation = true; 176 nrfold = int.Parse(args[i]); 177 if (nrfold < 2) { 178 throw new ArgumentException("n-fold cross validation: n must >= 2"); 38 179 } 39 set 40 { 41 Procedures.IsVerbose = value; 42 } 180 break; 181 182 case 'w': 183 parameters.Weights[int.Parse(args[i - 1].Substring(2))] = double.Parse(args[1]); 184 break; 185 186 default: 187 throw new ArgumentException("Unknown Parameter"); 43 188 } 44 45 private static double doCrossValidation(Problem problem, Parameter parameters, int nr_fold, bool shuffleTraining) 46 { 47 int i; 48 double[] target = new double[problem.Count]; 49 Procedures.svm_cross_validation(problem, parameters, nr_fold, target, shuffleTraining); 50 int total_correct = 0; 51 double total_error = 0; 52 //double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; 53 if (parameters.SvmType == SvmType.EPSILON_SVR || parameters.SvmType == SvmType.NU_SVR) 54 { 55 for (i = 0; i < problem.Count; i++) 56 { 57 double y = problem.Y[i]; 58 double v = target[i]; 59 total_error += (v - y) * (v - y); 60 //sumv += v; 61 //sumy += y; 62 //sumvv += v * v; 63 //sumyy += y * y; 64 //sumvy += v * y; 65 } 66 return total_error / problem.Count; // return MSE 67 // (problem.Count * sumvy - sumv * sumy) / (Math.Sqrt(problem.Count * sumvv - sumv * sumv) * Math.Sqrt(problem.Count * sumyy - sumy * sumy)); 68 } 69 else 70 for (i = 0; i < problem.Count; i++) 71 if (target[i] == problem.Y[i]) 72 ++total_correct; 73 return (double)total_correct / problem.Count; 74 } 75 /// <summary> 76 /// Legacy. Allows use as if this was svm_train. See libsvm documentation for details on which arguments to pass. 77 /// </summary> 78 /// <param name="args"></param> 79 [Obsolete("Provided only for legacy compatibility, use the other Train() methods")] 80 public static void Train(params string[] args) 81 { 82 Parameter parameters; 83 Problem problem; 84 bool crossValidation; 85 int nrfold; 86 string modelFilename; 87 parseCommandLine(args, out parameters, out problem, out crossValidation, out nrfold, out modelFilename); 88 if (crossValidation) 89 PerformCrossValidation(problem, parameters, nrfold, true); 90 else Model.Write(modelFilename, Train(problem, parameters)); 91 } 92 93 /// <summary> 94 /// Performs cross validation. 95 /// </summary> 96 /// <param name="problem">The training data</param> 97 /// <param name="parameters">The parameters to test</param> 98 /// <param name="nrfold">The number of cross validations to use</param> 99 /// <returns>The cross validation score</returns> 100 public static double PerformCrossValidation(Problem problem, Parameter parameters, int nrfold, bool shuffleTraining) 101 { 102 string error = Procedures.svm_check_parameter(problem, parameters); 103 if (error == null) 104 return doCrossValidation(problem, parameters, nrfold, shuffleTraining); 105 else throw new Exception(error); 106 } 107 108 /// <summary> 109 /// Trains a model using the provided training data and parameters. 110 /// </summary> 111 /// <param name="problem">The training data</param> 112 /// <param name="parameters">The parameters to use</param> 113 /// <returns>A trained SVM Model</returns> 114 public static Model Train(Problem problem, Parameter parameters) 115 { 116 string error = Procedures.svm_check_parameter(problem, parameters); 117 118 if (error == null) 119 return Procedures.svm_train(problem, parameters); 120 else throw new Exception(error); 121 } 122 123 private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename) 124 { 125 int i; 126 127 parameters = new Parameter(); 128 // default values 129 130 crossValidation = false; 131 nrfold = 0; 132 133 // parse options 134 for (i = 0; i < args.Length; i++) 135 { 136 if (args[i][0] != '-') 137 break; 138 ++i; 139 switch (args[i - 1][1]) 140 { 141 142 case 's': 143 parameters.SvmType = (SvmType)int.Parse(args[i]); 144 break; 145 146 case 't': 147 parameters.KernelType = (KernelType)int.Parse(args[i]); 148 break; 149 150 case 'd': 151 parameters.Degree = int.Parse(args[i]); 152 break; 153 154 case 'g': 155 parameters.Gamma = double.Parse(args[i]); 156 break; 157 158 case 'r': 159 parameters.Coefficient0 = double.Parse(args[i]); 160 break; 161 162 case 'n': 163 parameters.Nu = double.Parse(args[i]); 164 break; 165 166 case 'm': 167 parameters.CacheSize = double.Parse(args[i]); 168 break; 169 170 case 'c': 171 parameters.C = double.Parse(args[i]); 172 break; 173 174 case 'e': 175 parameters.EPS = double.Parse(args[i]); 176 break; 177 178 case 'p': 179 parameters.P = double.Parse(args[i]); 180 break; 181 182 case 'h': 183 parameters.Shrinking = int.Parse(args[i]) == 1; 184 break; 185 186 case 'b': 187 parameters.Probability = int.Parse(args[i]) == 1; 188 break; 189 190 case 'v': 191 crossValidation = true; 192 nrfold = int.Parse(args[i]); 193 if (nrfold < 2) 194 { 195 throw new ArgumentException("n-fold cross validation: n must >= 2"); 196 } 197 break; 198 199 case 'w': 200 parameters.Weights[int.Parse(args[i - 1].Substring(2))] = double.Parse(args[1]); 201 break; 202 203 default: 204 throw new ArgumentException("Unknown Parameter"); 205 } 206 } 207 208 // determine filenames 209 210 if (i >= args.Length) 211 throw new ArgumentException("No input file specified"); 212 213 problem = Problem.Read(args[i]); 214 215 if (parameters.Gamma == 0) 216 parameters.Gamma = 1.0 / problem.MaxIndex; 217 218 if (i < args.Length - 1) 219 modelFilename = args[i + 1]; 220 else 221 { 222 int p = args[i].LastIndexOf('/') + 1; 223 modelFilename = args[i].Substring(p) + ".model"; 224 } 225 } 226 } 189 } 190 191 // determine filenames 192 193 if (i >= args.Length) 194 throw new ArgumentException("No input file specified"); 195 196 problem = Problem.Read(args[i]); 197 198 if (parameters.Gamma == 0) 199 parameters.Gamma = 1.0 / problem.MaxIndex; 200 201 if (i < args.Length - 1) 202 modelFilename = args[i + 1]; 203 else { 204 int p = args[i].LastIndexOf('/') + 1; 205 modelFilename = args[i].Substring(p) + ".model"; 206 } 207 } 208 } 227 209 }
Note: See TracChangeset
for help on using the changeset viewer.