Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/LibSVM/Training.cs @ 1824

Last change on this file since 1824 was 1819, checked in by mkommend, 16 years ago

created new project for LibSVM source files (ticket #619)

File size: 8.8 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.Collections.Generic;
22
23namespace SVM
24{
25    /// <remarks>
26    /// Class containing the routines to train SVM models.
27    /// </remarks>
28    public static class Training
29    {
30        private static double doCrossValidation(Problem problem, Parameter parameters, int nr_fold)
31        {
32            int i;
33            double[] target = new double[problem.Count];
34            Dictionary<int, double>[] confidence = new Dictionary<int, double>[problem.Count];
35            Procedures.svm_cross_validation(problem, parameters, nr_fold, target, confidence);
36            if (parameters.Probability)
37            {
38                List<RankPair> ranks = new List<RankPair>();
39                for (i = 0; i < target.Length; i++)
40                    ranks.Add(new RankPair(confidence[i][1], problem.Y[i]));
41                PerformanceEvaluator eval = new PerformanceEvaluator(ranks);
42                return eval.AuC*eval.AP;
43            }
44            else
45            {
46                int total_correct = 0;
47                double total_error = 0;
48                double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
49                if (parameters.SvmType == SvmType.EPSILON_SVR || parameters.SvmType == SvmType.NU_SVR)
50                {
51                    for (i = 0; i < problem.Count; i++)
52                    {
53                        double y = problem.Y[i];
54                        double v = target[i];
55                        total_error += (v - y) * (v - y);
56                        sumv += v;
57                        sumy += y;
58                        sumvv += v * v;
59                        sumyy += y * y;
60                        sumvy += v * y;
61                    }
62                }
63                else
64                    for (i = 0; i < problem.Count; i++)
65                        if (target[i] == problem.Y[i])
66                            ++total_correct;
67                return (double)total_correct / problem.Count;
68            }
69
70        }
71        /// <summary>
72        /// Legacy.  Allows use as if this was svm_train.  See libsvm documentation for details on which arguments to pass.
73        /// </summary>
74        /// <param name="args"></param>
75        [Obsolete("Provided only for legacy compatibility, use the other Train() methods")]
76        public static void Train(params string[] args)
77        {
78            Parameter parameters;
79            Problem problem;
80            bool crossValidation;
81            int nrfold;
82            string modelFilename;
83            parseCommandLine(args, out parameters, out problem, out crossValidation, out nrfold, out modelFilename);
84            if (crossValidation)
85                PerformCrossValidation(problem, parameters, nrfold);
86            else Model.Write(modelFilename, Train(problem, parameters));
87        }
88
89        /// <summary>
90        /// Performs cross validation.
91        /// </summary>
92        /// <param name="problem">The training data</param>
93        /// <param name="parameters">The parameters to test</param>
94        /// <param name="nrfold">The number of cross validations to use</param>
95        /// <returns>The cross validation score</returns>
96        public static double PerformCrossValidation(Problem problem, Parameter parameters, int nrfold)
97        {
98            Procedures.svm_check_parameter(problem, parameters);
99            return doCrossValidation(problem, parameters, nrfold);
100        }
101
102        /// <summary>
103        /// Trains a model using the provided training data and parameters.
104        /// </summary>
105        /// <param name="problem">The training data</param>
106        /// <param name="parameters">The parameters to use</param>
107        /// <returns>A trained SVM Model</returns>
108        public static Model Train(Problem problem, Parameter parameters)
109        {
110            Procedures.svm_check_parameter(problem, parameters);
111
112            return Procedures.svm_train(problem, parameters);
113        }
114
115        private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename)
116        {
117            int i;
118
119            parameters = new Parameter();
120            // default values
121
122            crossValidation = false;
123            nrfold = 0;
124
125            // parse options
126            for (i = 0; i < args.Length; i++)
127            {
128                if (args[i][0] != '-')
129                    break;
130                ++i;
131                switch (args[i - 1][1])
132                {
133
134                    case 's':
135                        parameters.SvmType = (SvmType)int.Parse(args[i]);
136                        break;
137
138                    case 't':
139                        parameters.KernelType = (KernelType)int.Parse(args[i]);
140                        break;
141
142                    case 'd':
143                        parameters.Degree = int.Parse(args[i]);
144                        break;
145
146                    case 'g':
147                        parameters.Gamma = double.Parse(args[i]);
148                        break;
149
150                    case 'r':
151                        parameters.Coefficient0 = double.Parse(args[i]);
152                        break;
153
154                    case 'n':
155                        parameters.Nu = double.Parse(args[i]);
156                        break;
157
158                    case 'm':
159                        parameters.CacheSize = double.Parse(args[i]);
160                        break;
161
162                    case 'c':
163                        parameters.C = double.Parse(args[i]);
164                        break;
165
166                    case 'e':
167                        parameters.EPS = double.Parse(args[i]);
168                        break;
169
170                    case 'p':
171                        parameters.P = double.Parse(args[i]);
172                        break;
173
174                    case 'h':
175                        parameters.Shrinking = int.Parse(args[i]) == 1;
176                        break;
177
178                    case 'b':
179                        parameters.Probability = int.Parse(args[i]) == 1;
180                        break;
181
182                    case 'v':
183                        crossValidation = true;
184                        nrfold = int.Parse(args[i]);
185                        if (nrfold < 2)
186                        {
187                            throw new ArgumentException("n-fold cross validation: n must >= 2");
188                        }
189                        break;
190
191                    case 'w':
192                        ++parameters.WeightCount;
193                        {
194                            int[] old = parameters.WeightLabels;
195                            parameters.WeightLabels = new int[parameters.WeightCount];
196                            Array.Copy(old, 0, parameters.WeightLabels, 0, parameters.WeightCount - 1);
197                        }
198                        {
199                            double[] old = parameters.Weights;
200                            parameters.Weights = new double[parameters.WeightCount];
201                            Array.Copy(old, 0, parameters.Weights, 0, parameters.WeightCount - 1);
202                        }
203
204                        parameters.WeightLabels[parameters.WeightCount - 1] = int.Parse(args[i - 1].Substring(2));
205                        parameters.Weights[parameters.WeightCount - 1] = double.Parse(args[i]);
206                        break;
207
208                    default:
209                        throw new ArgumentException("Unknown Parameter");
210                }
211            }
212
213            // determine filenames
214
215            if (i >= args.Length)
216                throw new ArgumentException("No input file specified");
217
218            problem = Problem.Read(args[i]);
219
220            if (parameters.Gamma == 0)
221                parameters.Gamma = 1.0 / problem.MaxIndex;
222
223            if (i < args.Length - 1)
224                modelFilename = args[i + 1];
225            else
226            {
227                int p = args[i].LastIndexOf('/') + 1;
228                modelFilename = args[i].Substring(p) + ".model";
229            }
230        }
231    }
232}
Note: See TracBrowser for help on using the repository browser.