Changeset 12934 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs
 Timestamp:
 09/02/15 17:08:29 (7 years ago)
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs
r12509 r12934 46 46 private const string GammaParameterName = "Gamma"; 47 47 private const string DegreeParameterName = "Degree"; 48 private const string CreateSolutionParameterName = "CreateSolution"; 48 49 49 50 #region parameter properties … … 65 66 public IValueParameter<IntValue> DegreeParameter { 66 67 get { return (IValueParameter<IntValue>)Parameters[DegreeParameterName]; } 68 } 69 public IFixedValueParameter<BoolValue> CreateSolutionParameter { 70 get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; } 67 71 } 68 72 #endregion … … 87 91 public IntValue Degree { 88 92 get { return DegreeParameter.Value; } 93 } 94 public bool CreateSolution { 95 get { return CreateSolutionParameter.Value.Value; } 96 set { CreateSolutionParameter.Value.Value = value; } 89 97 } 90 98 #endregion … … 112 120 Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0))); 113 121 Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3))); 122 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 123 Parameters[CreateSolutionParameterName].Hidden = true; 114 124 } 115 125 [StorableHook(HookType.AfterDeserialization)] 116 126 private void AfterDeserialization() { 117 127 #region backwards compatibility (change with 3.4) 118 if (!Parameters.ContainsKey(DegreeParameterName)) 119 Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, "The degree parameter for the polynomial kernel function.", new IntValue(3))); 128 if (!Parameters.ContainsKey(DegreeParameterName)) { 129 Parameters.Add(new ValueParameter<IntValue>(DegreeParameterName, 130 "The degree parameter for the polynomial kernel function.", new IntValue(3))); 131 } 132 if (!Parameters.ContainsKey(CreateSolutionParameterName)) { 133 Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, 134 "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 135 Parameters[CreateSolutionParameterName].Hidden = true; 136 } 120 137 #endregion 121 138 } … … 129 146 IClassificationProblemData problemData = Problem.ProblemData; 130 147 IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables; 131 double trainingAccuracy, testAccuracy;132 148 int nSv; 133 var solution = CreateSupportVectorClassificationSolution(problemData, selectedInputVariables, 134 SvmType.Value, KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Degree.Value, 135 out trainingAccuracy, out testAccuracy, out nSv); 136 137 Results.Add(new Result("Support vector classification solution", "The support vector classification solution.", solution)); 138 Results.Add(new Result("Training accuracy", "The accuracy of the SVR solution on the training partition.", new DoubleValue(trainingAccuracy))); 139 Results.Add(new Result("Test accuracy", "The accuracy of the SVR solution on the test partition.", new DoubleValue(testAccuracy))); 140 Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVR solution.", new IntValue(nSv))); 149 ISupportVectorMachineModel model; 150 151 Run(problemData, selectedInputVariables, GetSvmType(SvmType.Value), GetKernelType(KernelType.Value), Cost.Value, Nu.Value, Gamma.Value, Degree.Value, out model, out nSv); 152 153 if (CreateSolution) { 154 var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone()); 155 Results.Add(new Result("Support vector classification solution", "The support vector classification solution.", 156 solution)); 157 } 158 159 { 160 // calculate classification metrics 161 // calculate regression model metrics 162 var ds = problemData.Dataset; 163 var trainRows = problemData.TrainingIndices; 164 var testRows = problemData.TestIndices; 165 var yTrain = ds.GetDoubleValues(problemData.TargetVariable, trainRows); 166 var yTest = ds.GetDoubleValues(problemData.TargetVariable, testRows); 167 var yPredTrain = model.GetEstimatedClassValues(ds, trainRows); 168 var yPredTest = model.GetEstimatedClassValues(ds, testRows); 169 170 OnlineCalculatorError error; 171 var trainAccuracy = OnlineAccuracyCalculator.Calculate(yPredTrain, yTrain, out error); 172 if (error != OnlineCalculatorError.None) trainAccuracy = double.MaxValue; 173 var testAccuracy = OnlineAccuracyCalculator.Calculate(yPredTest, yTest, out error); 174 if (error != OnlineCalculatorError.None) testAccuracy = double.MaxValue; 175 176 Results.Add(new Result("Accuracy (training)", "The mean of squared errors of the SVR solution on the training partition.", new DoubleValue(trainAccuracy))); 177 Results.Add(new Result("Accuracy (test)", "The mean of squared errors of the SVR solution on the test partition.", new DoubleValue(testAccuracy))); 178 179 Results.Add(new Result("Number of support vectors", "The number of support vectors of the SVR solution.", 180 new IntValue(nSv))); 181 } 141 182 } 142 183 … … 147 188 } 148 189 190 // BackwardsCompatibility3.4 191 #region Backwards compatible code, remove with 3.5 149 192 public static SupportVectorClassificationSolution CreateSupportVectorClassificationSolution(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables, 150 193 int svmType, int kernelType, double cost, double nu, double gamma, int degree, out double trainingAccuracy, out double testAccuracy, out int nSv) { 194 195 ISupportVectorMachineModel model; 196 Run(problemData, allowedInputVariables, svmType, kernelType, cost, nu, gamma, degree, out model, out nSv); 197 var solution = new SupportVectorClassificationSolution((SupportVectorMachineModel)model, (IClassificationProblemData)problemData.Clone()); 198 199 trainingAccuracy = solution.TrainingAccuracy; 200 testAccuracy = solution.TestAccuracy; 201 202 return solution; 203 } 204 205 #endregion 206 207 public static void Run(IClassificationProblemData problemData, IEnumerable<string> allowedInputVariables, 208 int svmType, int kernelType, double cost, double nu, double gamma, int degree, 209 out ISupportVectorMachineModel model, out int nSv) { 151 210 var dataset = problemData.Dataset; 152 211 string targetVariable = problemData.TargetVariable; … … 154 213 155 214 //extract SVM parameters from scope and set them 156 svm_parameter parameter = new svm_parameter(); 157 parameter.svm_type = svmType; 158 parameter.kernel_type = kernelType; 159 parameter.C = cost; 160 parameter.nu = nu; 161 parameter.gamma = gamma; 162 parameter.cache_size = 500; 163 parameter.probability = 0; 164 parameter.eps = 0.001; 165 parameter.degree = degree; 166 parameter.shrinking = 1; 167 parameter.coef0 = 0; 215 svm_parameter parameter = new svm_parameter { 216 svm_type = svmType, 217 kernel_type = kernelType, 218 C = cost, 219 nu = nu, 220 gamma = gamma, 221 cache_size = 500, 222 probability = 0, 223 eps = 0.001, 224 degree = degree, 225 shrinking = 1, 226 coef0 = 0 227 }; 168 228 169 229 var weightLabels = new List<int>(); … … 182 242 parameter.weight = weights.ToArray(); 183 243 184 185 244 svm_problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows); 186 245 RangeTransform rangeTransform = RangeTransform.Compute(problem); 187 246 svm_problem scaledProblem = rangeTransform.Scale(problem); 188 247 var svmModel = svm.svm_train(scaledProblem, parameter); 189 var model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues);190 var solution = new SupportVectorClassificationSolution(model, (IClassificationProblemData)problemData.Clone());191 192 248 nSv = svmModel.SV.Length; 193 trainingAccuracy = solution.TrainingAccuracy; 194 testAccuracy = solution.TestAccuracy; 195 196 return solution; 249 250 model = new SupportVectorMachineModel(svmModel, rangeTransform, targetVariable, allowedInputVariables, problemData.ClassValues); 197 251 } 198 252
Note: See TracChangeset
for help on using the changeset viewer.