Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Persistence Test/LibSVM/Training.cs @ 4120

Last change on this file since 4120 was 2415, checked in by gkronber, 15 years ago

Updated LibSVM project to latest version. #774

File size: 8.1 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.Collections.Generic;
22
23namespace SVM
24{
25    /// <summary>
26    /// Class containing the routines to train SVM models.
27    /// </summary>
28    public static class Training
29    {
30        /// <summary>
31        /// Whether the system will output information to the console during the training process.
32        /// </summary>
33        public static bool IsVerbose
34        {
35            get
36            {
37                return Procedures.IsVerbose;
38            }
39            set
40            {
41                Procedures.IsVerbose = value;
42            }
43        }
44
45        private static double doCrossValidation(Problem problem, Parameter parameters, int nr_fold)
46        {
47            int i;
48            double[] target = new double[problem.Count];
49            Procedures.svm_cross_validation(problem, parameters, nr_fold, target);
50            int total_correct = 0;
51            double total_error = 0;
52            double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
53            if (parameters.SvmType == SvmType.EPSILON_SVR || parameters.SvmType == SvmType.NU_SVR)
54            {
55                for (i = 0; i < problem.Count; i++)
56                {
57                    double y = problem.Y[i];
58                    double v = target[i];
59                    total_error += (v - y) * (v - y);
60                    sumv += v;
61                    sumy += y;
62                    sumvv += v * v;
63                    sumyy += y * y;
64                    sumvy += v * y;
65                }
66                return(problem.Count * sumvy - sumv * sumy) / (Math.Sqrt(problem.Count * sumvv - sumv * sumv) * Math.Sqrt(problem.Count * sumyy - sumy * sumy));
67            }
68            else
69                for (i = 0; i < problem.Count; i++)
70                    if (target[i] == problem.Y[i])
71                        ++total_correct;
72            return (double)total_correct / problem.Count;
73        }
74        /// <summary>
75        /// Legacy.  Allows use as if this was svm_train.  See libsvm documentation for details on which arguments to pass.
76        /// </summary>
77        /// <param name="args"></param>
78        [Obsolete("Provided only for legacy compatibility, use the other Train() methods")]
79        public static void Train(params string[] args)
80        {
81            Parameter parameters;
82            Problem problem;
83            bool crossValidation;
84            int nrfold;
85            string modelFilename;
86            parseCommandLine(args, out parameters, out problem, out crossValidation, out nrfold, out modelFilename);
87            if (crossValidation)
88                PerformCrossValidation(problem, parameters, nrfold);
89            else Model.Write(modelFilename, Train(problem, parameters));
90        }
91
92        /// <summary>
93        /// Performs cross validation.
94        /// </summary>
95        /// <param name="problem">The training data</param>
96        /// <param name="parameters">The parameters to test</param>
97        /// <param name="nrfold">The number of cross validations to use</param>
98        /// <returns>The cross validation score</returns>
99        public static double PerformCrossValidation(Problem problem, Parameter parameters, int nrfold)
100        {
101            string error = Procedures.svm_check_parameter(problem, parameters);
102            if (error == null)
103                return doCrossValidation(problem, parameters, nrfold);
104            else throw new Exception(error);
105        }
106
107        /// <summary>
108        /// Trains a model using the provided training data and parameters.
109        /// </summary>
110        /// <param name="problem">The training data</param>
111        /// <param name="parameters">The parameters to use</param>
112        /// <returns>A trained SVM Model</returns>
113        public static Model Train(Problem problem, Parameter parameters)
114        {
115            string error = Procedures.svm_check_parameter(problem, parameters);
116
117            if (error == null)
118                return Procedures.svm_train(problem, parameters);
119            else throw new Exception(error);
120        }
121
122        private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename)
123        {
124            int i;
125
126            parameters = new Parameter();
127            // default values
128
129            crossValidation = false;
130            nrfold = 0;
131
132            // parse options
133            for (i = 0; i < args.Length; i++)
134            {
135                if (args[i][0] != '-')
136                    break;
137                ++i;
138                switch (args[i - 1][1])
139                {
140
141                    case 's':
142                        parameters.SvmType = (SvmType)int.Parse(args[i]);
143                        break;
144
145                    case 't':
146                        parameters.KernelType = (KernelType)int.Parse(args[i]);
147                        break;
148
149                    case 'd':
150                        parameters.Degree = int.Parse(args[i]);
151                        break;
152
153                    case 'g':
154                        parameters.Gamma = double.Parse(args[i]);
155                        break;
156
157                    case 'r':
158                        parameters.Coefficient0 = double.Parse(args[i]);
159                        break;
160
161                    case 'n':
162                        parameters.Nu = double.Parse(args[i]);
163                        break;
164
165                    case 'm':
166                        parameters.CacheSize = double.Parse(args[i]);
167                        break;
168
169                    case 'c':
170                        parameters.C = double.Parse(args[i]);
171                        break;
172
173                    case 'e':
174                        parameters.EPS = double.Parse(args[i]);
175                        break;
176
177                    case 'p':
178                        parameters.P = double.Parse(args[i]);
179                        break;
180
181                    case 'h':
182                        parameters.Shrinking = int.Parse(args[i]) == 1;
183                        break;
184
185                    case 'b':
186                        parameters.Probability = int.Parse(args[i]) == 1;
187                        break;
188
189                    case 'v':
190                        crossValidation = true;
191                        nrfold = int.Parse(args[i]);
192                        if (nrfold < 2)
193                        {
194                            throw new ArgumentException("n-fold cross validation: n must >= 2");
195                        }
196                        break;
197
198                    case 'w':
199                        parameters.Weights[int.Parse(args[i - 1].Substring(2))] = double.Parse(args[1]);
200                        break;
201
202                    default:
203                        throw new ArgumentException("Unknown Parameter");
204                }
205            }
206
207            // determine filenames
208
209            if (i >= args.Length)
210                throw new ArgumentException("No input file specified");
211
212            problem = Problem.Read(args[i]);
213
214            if (parameters.Gamma == 0)
215                parameters.Gamma = 1.0 / problem.MaxIndex;
216
217            if (i < args.Length - 1)
218                modelFilename = args[i + 1];
219            else
220            {
221                int p = args[i].LastIndexOf('/') + 1;
222                modelFilename = args[i].Substring(p) + ".model";
223            }
224        }
225    }
226}
Note: See TracBrowser for help on using the repository browser.