Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/Training.cs @ 4893

Last change on this file since 4893 was 4068, checked in by swagner, 15 years ago

Sorted usings and removed unused usings in entire solution (#1094)

File size: 6.9 KB
RevLine 
/*
 * SVM.NET Library
 * Copyright (C) 2008 Matthew Johnson
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
18
19
20using System;
21
namespace SVM {
  /// <summary>
  /// Class containing the routines to train SVM models.
  /// </summary>
  public static class Training {
    /// <summary>
    /// Whether the system will output information to the console during the training process.
    /// Reads and writes <see cref="Procedures.IsVerbose"/>; no state is kept in this class.
    /// </summary>
    public static bool IsVerbose {
      get {
        return Procedures.IsVerbose;
      }
      set {
        Procedures.IsVerbose = value;
      }
    }

    /// <summary>
    /// Runs cross validation through <c>Procedures.svm_cross_validation</c> and condenses the
    /// per-sample predictions into a single score.
    /// </summary>
    /// <param name="problem">The training data</param>
    /// <param name="parameters">The parameters to test</param>
    /// <param name="nr_fold">The number of folds</param>
    /// <param name="shuffleTraining">Forwarded unchanged to <c>Procedures.svm_cross_validation</c></param>
    /// <returns>
    /// The mean squared error for EPSILON_SVR and NU_SVR, otherwise the fraction of samples
    /// whose predicted value equals the stored target value.
    /// </returns>
    private static double doCrossValidation(Problem problem, Parameter parameters, int nr_fold, bool shuffleTraining) {
      // svm_cross_validation fills this array with one predicted value per sample.
      double[] target = new double[problem.Count];
      Procedures.svm_cross_validation(problem, parameters, nr_fold, target, shuffleTraining);

      bool isRegression = parameters.SvmType == SvmType.EPSILON_SVR || parameters.SvmType == SvmType.NU_SVR;
      if (isRegression) {
        double total_error = 0;
        for (int i = 0; i < problem.Count; i++) {
          double diff = target[i] - problem.Y[i];
          total_error += diff * diff;
        }
        return total_error / problem.Count; // return MSE
      }

      int total_correct = 0;
      for (int i = 0; i < problem.Count; i++) {
        // Exact comparison is intended: both sides are compared as stored values.
        if (target[i] == problem.Y[i]) {
          ++total_correct;
        }
      }
      return (double)total_correct / problem.Count;
    }

    /// <summary>
    /// Legacy.  Allows use as if this was svm_train.  See libsvm documentation for details on which arguments to pass.
    /// </summary>
    /// <param name="args">
    /// Option/value pairs (each option starts with '-'), followed by the input file name and,
    /// optionally, the model file name.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
    /// <exception cref="ArgumentException">The arguments could not be parsed.</exception>
    [Obsolete("Provided only for legacy compatibility, use the other Train() methods")]
    public static void Train(params string[] args) {
      if (args == null) {
        throw new ArgumentNullException("args");
      }

      Parameter parameters;
      Problem problem;
      bool crossValidation;
      int nrfold;
      string modelFilename;
      parseCommandLine(args, out parameters, out problem, out crossValidation, out nrfold, out modelFilename);

      if (crossValidation) {
        // NOTE(review): the returned cross-validation score is discarded here; this method
        // returns void, so callers who need the score should call PerformCrossValidation directly.
        PerformCrossValidation(problem, parameters, nrfold, true);
      } else {
        Model.Write(modelFilename, Train(problem, parameters));
      }
    }

    /// <summary>
    /// Performs cross validation.
    /// </summary>
    /// <param name="problem">The training data</param>
    /// <param name="parameters">The parameters to test</param>
    /// <param name="nrfold">The number of cross validations to use</param>
    /// <param name="shuffleTraining">
    /// Forwarded to the cross-validation routine; presumably controls whether the training data
    /// is shuffled before it is split into folds (confirm against Procedures.svm_cross_validation).
    /// </param>
    /// <returns>The cross validation score</returns>
    /// <exception cref="Exception">
    /// <c>Procedures.svm_check_parameter</c> reported an error; its text is the exception message.
    /// </exception>
    public static double PerformCrossValidation(Problem problem, Parameter parameters, int nrfold, bool shuffleTraining) {
      string error = Procedures.svm_check_parameter(problem, parameters);
      if (error != null) {
        // The general Exception type is kept so existing callers keep catching the same type.
        throw new Exception(error);
      }
      return doCrossValidation(problem, parameters, nrfold, shuffleTraining);
    }

    /// <summary>
    /// Trains a model using the provided training data and parameters.
    /// </summary>
    /// <param name="problem">The training data</param>
    /// <param name="parameters">The parameters to use</param>
    /// <returns>A trained SVM Model</returns>
    /// <exception cref="Exception">
    /// <c>Procedures.svm_check_parameter</c> reported an error; its text is the exception message.
    /// </exception>
    public static Model Train(Problem problem, Parameter parameters) {
      string error = Procedures.svm_check_parameter(problem, parameters);
      if (error != null) {
        // The general Exception type is kept so existing callers keep catching the same type.
        throw new Exception(error);
      }
      return Procedures.svm_train(problem, parameters);
    }

    /// <summary>
    /// Parses svm_train style arguments: zero or more "-x value" option pairs, then the input
    /// file name, then an optional model file name.
    /// </summary>
    /// <remarks>
    /// Recognised option letters and the setting each one writes:
    /// s (SvmType), t (KernelType), d (Degree), g (Gamma), r (Coefficient0), n (Nu),
    /// m (CacheSize), c (C), e (EPS), p (P), h (Shrinking, 1 = on), b (Probability, 1 = on),
    /// v (enables cross validation with the given fold count), wN (Weights[N]).
    /// Numeric values are parsed with the current culture.
    /// </remarks>
    /// <param name="args">The raw arguments</param>
    /// <param name="parameters">Receives a new Parameter with the parsed options applied</param>
    /// <param name="problem">Receives the problem read from the input file</param>
    /// <param name="crossValidation">Receives true when the v option was given</param>
    /// <param name="nrfold">Receives the fold count of the v option, or 0</param>
    /// <param name="modelFilename">
    /// Receives the explicit model file name, or the input file name (text after the last '/')
    /// with ".model" appended.
    /// </param>
    /// <exception cref="ArgumentException">
    /// An option is unknown or has no value, the fold count is below 2, or no input file was given.
    /// </exception>
    private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename) {
      parameters = new Parameter();

      // default values
      crossValidation = false;
      nrfold = 0;

      // parse options; every option consumes two entries (the switch and its value)
      int i = 0;
      while (i < args.Length) {
        string option = args[i];

        // The first argument that is not an option starts the file names.
        if (string.IsNullOrEmpty(option) || option[0] != '-') {
          break;
        }
        if (option.Length < 2) {
          throw new ArgumentException("Unknown Parameter");
        }
        if (i + 1 >= args.Length) {
          // Previously this ran past the end of the array.
          throw new ArgumentException("Missing value for option " + option);
        }
        string value = args[i + 1];

        switch (option[1]) {
          case 's':
            parameters.SvmType = (SvmType)int.Parse(value);
            break;

          case 't':
            parameters.KernelType = (KernelType)int.Parse(value);
            break;

          case 'd':
            parameters.Degree = int.Parse(value);
            break;

          case 'g':
            parameters.Gamma = double.Parse(value);
            break;

          case 'r':
            parameters.Coefficient0 = double.Parse(value);
            break;

          case 'n':
            parameters.Nu = double.Parse(value);
            break;

          case 'm':
            parameters.CacheSize = double.Parse(value);
            break;

          case 'c':
            parameters.C = double.Parse(value);
            break;

          case 'e':
            parameters.EPS = double.Parse(value);
            break;

          case 'p':
            parameters.P = double.Parse(value);
            break;

          case 'h':
            parameters.Shrinking = int.Parse(value) == 1;
            break;

          case 'b':
            parameters.Probability = int.Parse(value) == 1;
            break;

          case 'v':
            crossValidation = true;
            nrfold = int.Parse(value);
            if (nrfold < 2) {
              throw new ArgumentException("n-fold cross validation: n must >= 2");
            }
            break;

          case 'w':
            // "-wN value": N is the text after "-w". The value is the argument that follows
            // the option (the previous code mistakenly read args[1] here).
            parameters.Weights[int.Parse(option.Substring(2))] = double.Parse(value);
            break;

          default:
            throw new ArgumentException("Unknown Parameter");
        }

        i += 2;
      }

      // determine filenames
      if (i >= args.Length) {
        throw new ArgumentException("No input file specified");
      }

      string inputFilename = args[i];
      problem = Problem.Read(inputFilename);

      // A gamma of 0 means "not set": fall back to 1 / MaxIndex of the problem.
      if (parameters.Gamma == 0) {
        parameters.Gamma = 1.0 / problem.MaxIndex;
      }

      if (i < args.Length - 1) {
        modelFilename = args[i + 1];
      } else {
        // Strip everything up to the last '/' and append ".model".
        int p = inputFilename.LastIndexOf('/') + 1;
        modelFilename = inputFilename.Substring(p) + ".model";
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.