/*
 * SVM.NET Library
 * Copyright (C) 2008 Matthew Johnson
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

using System;
using System.Globalization;

namespace SVM
{
    /// <summary>
    /// Class containing the routines to train SVM models.
    /// </summary>
    public static class Training
    {
        /// <summary>
        /// Whether the system will output information to the console during the training process.
        /// Reads and writes <c>Procedures.IsVerbose</c>.
        /// </summary>
        public static bool IsVerbose
        {
            get
            {
                return Procedures.IsVerbose;
            }
            set
            {
                Procedures.IsVerbose = value;
            }
        }

        /// <summary>
        /// True when <paramref name="parameters"/> selects a regression SVM
        /// (EPSILON_SVR or NU_SVR), whose cross validation score is an error rather than an accuracy.
        /// </summary>
        private static bool isRegression(Parameter parameters)
        {
            return parameters.SvmType == SvmType.EPSILON_SVR || parameters.SvmType == SvmType.NU_SVR;
        }

        /// <summary>
        /// Runs <c>Procedures.svm_cross_validation</c> and scores the held-out predictions it writes.
        /// </summary>
        /// <param name="problem">The training data.</param>
        /// <param name="parameters">The parameters to test.</param>
        /// <param name="nr_fold">The number of folds, forwarded unchanged.</param>
        /// <param name="shuffleTraining">Forwarded unchanged to <c>Procedures.svm_cross_validation</c>.</param>
        /// <returns>
        /// For EPSILON_SVR / NU_SVR, the mean squared error of the predictions;
        /// for every other SVM type, the fraction (0..1) of predictions equal to the true label.
        /// </returns>
        private static double doCrossValidation(Problem problem, Parameter parameters, int nr_fold, bool shuffleTraining)
        {
            double[] target = new double[problem.Count];
            Procedures.svm_cross_validation(problem, parameters, nr_fold, target, shuffleTraining);

            if (isRegression(parameters))
            {
                double totalError = 0;
                for (int i = 0; i < problem.Count; i++)
                {
                    double diff = target[i] - problem.Y[i];
                    totalError += diff * diff;
                }
                return totalError / problem.Count; // mean squared error
            }

            int totalCorrect = 0;
            for (int i = 0; i < problem.Count; i++)
            {
                // Exact comparison: a classification prediction counts only if it equals the true label.
                if (target[i] == problem.Y[i])
                    ++totalCorrect;
            }
            return (double)totalCorrect / problem.Count;
        }

        /// <summary>
        /// Legacy. Allows use as if this was svm_train. See libsvm documentation for details on which arguments to pass.
        /// </summary>
        /// <param name="args">
        /// svm_train-style arguments: options of the form <c>-x value</c>, then the training data file,
        /// then an optional model file name.
        /// </param>
        /// <exception cref="ArgumentException">The arguments or the resulting parameters are invalid.</exception>
        [Obsolete("Provided only for legacy compatibility, use the other Train() methods")]
        public static void Train(params string[] args)
        {
            Parameter parameters;
            Problem problem;
            bool crossValidation;
            int nrfold;
            string modelFilename;
            parseCommandLine(args, out parameters, out problem, out crossValidation, out nrfold, out modelFilename);

            if (crossValidation)
            {
                double score = PerformCrossValidation(problem, parameters, nrfold, true);
                // This entry point returns void, so printing is the only way to surface the score.
                if (isRegression(parameters))
                    Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "Cross Validation Mean squared error = {0}", score));
                else
                    Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "Cross Validation Accuracy = {0}%", 100.0 * score));
            }
            else
            {
                Model.Write(modelFilename, Train(problem, parameters));
            }
        }

        /// <summary>
        /// Performs cross validation.
        /// </summary>
        /// <param name="problem">The training data</param>
        /// <param name="parameters">The parameters to test</param>
        /// <param name="nrfold">The number of cross validations to use</param>
        /// <param name="shuffleTraining">
        /// Forwarded to <c>Procedures.svm_cross_validation</c>; presumably controls shuffling
        /// before the data is split into folds (TODO confirm).
        /// </param>
        /// <returns>
        /// The cross validation score: mean squared error for EPSILON_SVR / NU_SVR,
        /// otherwise the fraction of correctly predicted labels.
        /// </returns>
        /// <exception cref="ArgumentException"><c>Procedures.svm_check_parameter</c> reported an error.</exception>
        public static double PerformCrossValidation(Problem problem, Parameter parameters, int nrfold, bool shuffleTraining)
        {
            string error = Procedures.svm_check_parameter(problem, parameters);
            if (error != null)
                throw new ArgumentException(error);

            return doCrossValidation(problem, parameters, nrfold, shuffleTraining);
        }

        /// <summary>
        /// Trains a model using the provided training data and parameters.
        /// </summary>
        /// <param name="problem">The training data</param>
        /// <param name="parameters">The parameters to use</param>
        /// <returns>A trained SVM Model</returns>
        /// <exception cref="ArgumentException"><c>Procedures.svm_check_parameter</c> reported an error.</exception>
        public static Model Train(Problem problem, Parameter parameters)
        {
            string error = Procedures.svm_check_parameter(problem, parameters);
            if (error != null)
                throw new ArgumentException(error);

            return Procedures.svm_train(problem, parameters);
        }

        /// <summary>Parses a command-line integer independently of the current culture.</summary>
        private static int parseInt(string text)
        {
            return int.Parse(text, CultureInfo.InvariantCulture);
        }

        /// <summary>Parses a command-line real number independently of the current culture.</summary>
        private static double parseDouble(string text)
        {
            return double.Parse(text, CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Parses svm_train-style arguments. Each option is a dash, a letter, and a value
        /// in the next argument. The first argument not starting with '-' is the training file,
        /// which is read with <c>Problem.Read</c>. An optional argument after it names the model file.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// Unknown option, option without a value, invalid fold count, or no input file.
        /// </exception>
        private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename)
        {
            parameters = new Parameter();
            // default values
            crossValidation = false;
            nrfold = 0;

            int i;
            for (i = 0; i < args.Length; i++)
            {
                string option = args[i];
                // StartsWith (rather than indexing [0]) also treats an empty argument as a non-option.
                if (!option.StartsWith("-", StringComparison.Ordinal))
                    break;
                if (option.Length < 2)
                    throw new ArgumentException("Unknown Parameter");

                // Every option consumes the following argument as its value.
                if (++i >= args.Length)
                    throw new ArgumentException("Missing value for option " + option);
                string value = args[i];

                switch (option[1])
                {
                    case 's':
                        parameters.SvmType = (SvmType)parseInt(value);
                        break;
                    case 't':
                        parameters.KernelType = (KernelType)parseInt(value);
                        break;
                    case 'd':
                        parameters.Degree = parseInt(value);
                        break;
                    case 'g':
                        parameters.Gamma = parseDouble(value);
                        break;
                    case 'r':
                        parameters.Coefficient0 = parseDouble(value);
                        break;
                    case 'n':
                        parameters.Nu = parseDouble(value);
                        break;
                    case 'm':
                        parameters.CacheSize = parseDouble(value);
                        break;
                    case 'c':
                        parameters.C = parseDouble(value);
                        break;
                    case 'e':
                        parameters.EPS = parseDouble(value);
                        break;
                    case 'p':
                        parameters.P = parseDouble(value);
                        break;
                    case 'h':
                        parameters.Shrinking = parseInt(value) == 1;
                        break;
                    case 'b':
                        parameters.Probability = parseInt(value) == 1;
                        break;
                    case 'v':
                        crossValidation = true;
                        nrfold = parseInt(value);
                        if (nrfold < 2)
                        {
                            throw new ArgumentException("n-fold cross validation: n must >= 2");
                        }
                        break;
                    case 'w':
                        // "-wN value": the key is the integer after "-w", the weight is the option's value.
                        parameters.Weights[parseInt(option.Substring(2))] = parseDouble(value);
                        break;
                    default:
                        throw new ArgumentException("Unknown Parameter");
                }
            }

            // determine filenames
            if (i >= args.Length)
                throw new ArgumentException("No input file specified");

            problem = Problem.Read(args[i]);

            // Gamma == 0 is the "not set" sentinel; skip the default when MaxIndex would make it infinite.
            if (parameters.Gamma == 0 && problem.MaxIndex > 0)
                parameters.Gamma = 1.0 / problem.MaxIndex;

            if (i < args.Length - 1)
            {
                modelFilename = args[i + 1];
            }
            else
            {
                // Default: the input file's name with any directory prefix ('/' or '\') removed, plus ".model".
                string input = args[i];
                int nameStart = input.LastIndexOfAny(new char[] { '/', '\\' }) + 1;
                modelFilename = input.Substring(nameStart) + ".model";
            }
        }
    }
}