Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/Training.cs @ 3884

Last change on this file since 3884 was 3884, checked in by gkronber, 14 years ago

Worked on support vector regression operators and views. #1009

File size: 8.3 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.Collections.Generic;
22
23namespace SVM
24{
25    /// <summary>
26    /// Class containing the routines to train SVM models.
27    /// </summary>
28    public static class Training
29    {
30        /// <summary>
31        /// Whether the system will output information to the console during the training process.
32        /// </summary>
33        public static bool IsVerbose
34        {
35            get
36            {
37                return Procedures.IsVerbose;
38            }
39            set
40            {
41                Procedures.IsVerbose = value;
42            }
43        }
44
45        private static double doCrossValidation(Problem problem, Parameter parameters, int nr_fold, bool shuffleTraining)
46        {
47            int i;
48            double[] target = new double[problem.Count];
49            Procedures.svm_cross_validation(problem, parameters, nr_fold, target, shuffleTraining);
50            int total_correct = 0;
51            double total_error = 0;
52            //double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
53            if (parameters.SvmType == SvmType.EPSILON_SVR || parameters.SvmType == SvmType.NU_SVR)
54            {
55                for (i = 0; i < problem.Count; i++)
56                {
57                    double y = problem.Y[i];
58                    double v = target[i];
59                    total_error += (v - y) * (v - y);
60                    //sumv += v;
61                    //sumy += y;
62                    //sumvv += v * v;
63                    //sumyy += y * y;
64                    //sumvy += v * y;
65                }
66                return total_error / problem.Count; // return MSE
67              // (problem.Count * sumvy - sumv * sumy) / (Math.Sqrt(problem.Count * sumvv - sumv * sumv) * Math.Sqrt(problem.Count * sumyy - sumy * sumy));
68            }
69            else
70                for (i = 0; i < problem.Count; i++)
71                    if (target[i] == problem.Y[i])
72                        ++total_correct;
73            return (double)total_correct / problem.Count;
74        }
75        /// <summary>
76        /// Legacy.  Allows use as if this was svm_train.  See libsvm documentation for details on which arguments to pass.
77        /// </summary>
78        /// <param name="args"></param>
79        [Obsolete("Provided only for legacy compatibility, use the other Train() methods")]
80        public static void Train(params string[] args)
81        {
82            Parameter parameters;
83            Problem problem;
84            bool crossValidation;
85            int nrfold;
86            string modelFilename;
87            parseCommandLine(args, out parameters, out problem, out crossValidation, out nrfold, out modelFilename);
88            if (crossValidation)
89                PerformCrossValidation(problem, parameters, nrfold, true);
90            else Model.Write(modelFilename, Train(problem, parameters));
91        }
92
93        /// <summary>
94        /// Performs cross validation.
95        /// </summary>
96        /// <param name="problem">The training data</param>
97        /// <param name="parameters">The parameters to test</param>
98        /// <param name="nrfold">The number of cross validations to use</param>
99        /// <returns>The cross validation score</returns>
100        public static double PerformCrossValidation(Problem problem, Parameter parameters, int nrfold, bool shuffleTraining)
101        {
102            string error = Procedures.svm_check_parameter(problem, parameters);
103            if (error == null)
104                return doCrossValidation(problem, parameters, nrfold, shuffleTraining);
105            else throw new Exception(error);
106        }
107
108        /// <summary>
109        /// Trains a model using the provided training data and parameters.
110        /// </summary>
111        /// <param name="problem">The training data</param>
112        /// <param name="parameters">The parameters to use</param>
113        /// <returns>A trained SVM Model</returns>
114        public static Model Train(Problem problem, Parameter parameters)
115        {
116            string error = Procedures.svm_check_parameter(problem, parameters);
117
118            if (error == null)
119                return Procedures.svm_train(problem, parameters);
120            else throw new Exception(error);
121        }
122
123        private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename)
124        {
125            int i;
126
127            parameters = new Parameter();
128            // default values
129
130            crossValidation = false;
131            nrfold = 0;
132
133            // parse options
134            for (i = 0; i < args.Length; i++)
135            {
136                if (args[i][0] != '-')
137                    break;
138                ++i;
139                switch (args[i - 1][1])
140                {
141
142                    case 's':
143                        parameters.SvmType = (SvmType)int.Parse(args[i]);
144                        break;
145
146                    case 't':
147                        parameters.KernelType = (KernelType)int.Parse(args[i]);
148                        break;
149
150                    case 'd':
151                        parameters.Degree = int.Parse(args[i]);
152                        break;
153
154                    case 'g':
155                        parameters.Gamma = double.Parse(args[i]);
156                        break;
157
158                    case 'r':
159                        parameters.Coefficient0 = double.Parse(args[i]);
160                        break;
161
162                    case 'n':
163                        parameters.Nu = double.Parse(args[i]);
164                        break;
165
166                    case 'm':
167                        parameters.CacheSize = double.Parse(args[i]);
168                        break;
169
170                    case 'c':
171                        parameters.C = double.Parse(args[i]);
172                        break;
173
174                    case 'e':
175                        parameters.EPS = double.Parse(args[i]);
176                        break;
177
178                    case 'p':
179                        parameters.P = double.Parse(args[i]);
180                        break;
181
182                    case 'h':
183                        parameters.Shrinking = int.Parse(args[i]) == 1;
184                        break;
185
186                    case 'b':
187                        parameters.Probability = int.Parse(args[i]) == 1;
188                        break;
189
190                    case 'v':
191                        crossValidation = true;
192                        nrfold = int.Parse(args[i]);
193                        if (nrfold < 2)
194                        {
195                            throw new ArgumentException("n-fold cross validation: n must >= 2");
196                        }
197                        break;
198
199                    case 'w':
200                        parameters.Weights[int.Parse(args[i - 1].Substring(2))] = double.Parse(args[1]);
201                        break;
202
203                    default:
204                        throw new ArgumentException("Unknown Parameter");
205                }
206            }
207
208            // determine filenames
209
210            if (i >= args.Length)
211                throw new ArgumentException("No input file specified");
212
213            problem = Problem.Read(args[i]);
214
215            if (parameters.Gamma == 0)
216                parameters.Gamma = 1.0 / problem.MaxIndex;
217
218            if (i < args.Length - 1)
219                modelFilename = args[i + 1];
220            else
221            {
222                int p = args[i].LastIndexOf('/') + 1;
223                modelFilename = args[i].Substring(p) + ".model";
224            }
225        }
226    }
227}
Note: See TracBrowser for help on using the repository browser.