Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Operator Architecture Refactoring/LibSVM/Model.cs @ 2215

Last change on this file since 2215 was 1819, checked in by mkommend, 16 years ago

created new project for LibSVM source files (ticket #619)

File size: 15.0 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20
21using System;
22using System.IO;
23
24namespace SVM
25{
26    /// <remarks>
27    /// Encapsulates an SVM Model.
28    /// </remarks>
29  [Serializable]
30  public class Model
31  {
32        private Parameter _parameter;
33        private int _numberOfClasses;
34        private int _supportVectorCount;
35        private Node[][] _supportVectors;
36        private double[][] _supportVectorCoefficients;
37        private double[] _rho;
38        private double[] _pairwiseProbabilityA;
39        private double[] _pairwiseProbabilityB;
40
41        private int[] _classLabels;
42        private int[] _numberOfSVPerClass;
43
44        internal Model()
45        {
46        }
47
48        /// <summary>
49        /// Parameter object.
50        /// </summary>
51        public Parameter Parameter
52        {
53            get
54            {
55                return _parameter;
56            }
57            set
58            {
59                _parameter = value;
60            }
61        }
62
63        /// <summary>
64        /// Number of classes in the model.
65        /// </summary>
66        public int NumberOfClasses
67        {
68            get
69            {
70                return _numberOfClasses;
71            }
72            set
73            {
74                _numberOfClasses = value;
75            }
76        }
77
78        /// <summary>
79        /// Total number of support vectors.
80        /// </summary>
81        public int SupportVectorCount
82        {
83            get
84            {
85                return _supportVectorCount;
86            }
87            set
88            {
89                _supportVectorCount = value;
90            }
91        }
92
93        /// <summary>
94        /// The support vectors.
95        /// </summary>
96        public Node[][] SupportVectors
97        {
98            get
99            {
100                return _supportVectors;
101            }
102            set
103            {
104                _supportVectors = value;
105            }
106        }
107
108        /// <summary>
109        /// The coefficients for the support vectors.
110        /// </summary>
111        public double[][] SupportVectorCoefficients
112        {
113            get
114            {
115                return _supportVectorCoefficients;
116            }
117            set
118            {
119                _supportVectorCoefficients = value;
120            }
121        }
122
123        /// <summary>
124        /// Rho values.
125        /// </summary>
126        public double[] Rho
127        {
128            get
129            {
130                return _rho;
131            }
132            set
133            {
134                _rho = value;
135            }
136        }
137
138        /// <summary>
139        /// First pairwise probability.
140        /// </summary>
141        public double[] PairwiseProbabilityA
142        {
143            get
144            {
145                return _pairwiseProbabilityA;
146            }
147            set
148            {
149                _pairwiseProbabilityA = value;
150            }
151        }
152
153        /// <summary>
154        /// Second pairwise probability.
155        /// </summary>
156        public double[] PairwiseProbabilityB
157        {
158            get
159            {
160                return _pairwiseProbabilityB;
161            }
162            set
163            {
164                _pairwiseProbabilityB = value;
165            }
166        }
167   
168    // for classification only
169
170        /// <summary>
171        /// Class labels.
172        /// </summary>
173        public int[] ClassLabels
174        {
175            get
176            {
177                return _classLabels;
178            }
179            set
180            {
181                _classLabels = value;
182            }
183        }
184
185        /// <summary>
186        /// Number of support vectors per class.
187        /// </summary>
188        public int[] NumberOfSVPerClass
189        {
190            get
191            {
192                return _numberOfSVPerClass;
193            }
194            set
195            {
196                _numberOfSVPerClass = value;
197            }
198        }
199
200        /// <summary>
201        /// Reads a Model from the provided file.
202        /// </summary>
203        /// <param name="filename">The name of the file containing the Model</param>
204        /// <returns>the Model</returns>
205        public static Model Read(string filename)
206        {
207            FileStream input = File.OpenRead(filename);
208            try
209            {
210                return Read(input);
211            }
212            finally
213            {
214                input.Close();
215            }
216        }
217
218        /// <summary>
219        /// Reads a Model from the provided stream.
220        /// </summary>
221        /// <param name="stream">The stream from which to read the Model.</param>
222        /// <returns>the Model</returns>
223        public static Model Read(Stream stream)
224        {
225            StreamReader input = new StreamReader(stream);
226
227            // read parameters
228
229            Model model = new Model();
230            Parameter param = new Parameter();
231            model.Parameter = param;
232            model.Rho = null;
233            model.PairwiseProbabilityA = null;
234            model.PairwiseProbabilityB = null;
235            model.ClassLabels = null;
236            model.NumberOfSVPerClass = null;
237
238            bool headerFinished = false;
239            while (!headerFinished)
240            {
241                string line = input.ReadLine();
242                string cmd, arg;
243                int splitIndex = line.IndexOf(' ');
244                if (splitIndex >= 0)
245                {
246                    cmd = line.Substring(0, splitIndex);
247                    arg = line.Substring(splitIndex + 1);
248                }
249                else
250                {
251                    cmd = line;
252                    arg = "";
253                }
254                arg = arg.ToLower();
255
256                int i,n;
257                switch(cmd){
258                    case "svm_type":
259                        param.SvmType = (SvmType)Enum.Parse(typeof(SvmType), arg.ToUpper());
260                        break;
261                       
262                    case "kernel_type":
263                        param.KernelType = (KernelType)Enum.Parse(typeof(KernelType), arg.ToUpper());
264                        break;
265
266                    case "degree":
267                        param.Degree = int.Parse(arg);
268                        break;
269
270                    case "gamma":
271                        param.Gamma = double.Parse(arg);
272                        break;
273
274                    case "coef0":
275                        param.Coefficient0 = double.Parse(arg);
276                        break;
277
278                    case "nr_class":
279                        model.NumberOfClasses = int.Parse(arg);
280                        break;
281
282                    case "total_sv":
283                        model.SupportVectorCount = int.Parse(arg);
284                        break;
285
286                    case "rho":
287                        n = model.NumberOfClasses * (model.NumberOfClasses - 1) / 2;
288                        model.Rho = new double[n];
289                        string[] rhoParts = arg.Split();
290                        for(i=0; i<n; i++)
291                            model.Rho[i] = double.Parse(rhoParts[i]);
292                        break;
293
294                    case "label":
295                        n = model.NumberOfClasses;
296                        model.ClassLabels = new int[n];
297                        string[] labelParts = arg.Split();
298                        for (i = 0; i < n; i++)
299                            model.ClassLabels[i] = int.Parse(labelParts[i]);
300                        break;
301
302                    case "probA":
303                        n = model.NumberOfClasses * (model.NumberOfClasses - 1) / 2;
304                        model.PairwiseProbabilityA = new double[n];
305                            string[] probAParts = arg.Split();
306                        for (i = 0; i < n; i++)
307                            model.PairwiseProbabilityA[i] = double.Parse(probAParts[i]);
308                        break;
309
310                    case "probB":
311                        n = model.NumberOfClasses * (model.NumberOfClasses - 1) / 2;
312                        model.PairwiseProbabilityB = new double[n];
313                        string[] probBParts = arg.Split();
314                        for (i = 0; i < n; i++)
315                            model.PairwiseProbabilityB[i] = double.Parse(probBParts[i]);
316                        break;
317
318                    case "nr_sv":
319                        n = model.NumberOfClasses;
320                        model.NumberOfSVPerClass = new int[n];
321                        string[] nrsvParts = arg.Split();
322                        for (i = 0; i < n; i++)
323                            model.NumberOfSVPerClass[i] = int.Parse(nrsvParts[i]);
324                        break;
325
326                    case "SV":
327                        headerFinished = true;
328                        break;
329
330                    default:
331                        throw new Exception("Unknown text in model file"); 
332                }
333            }
334
335            // read sv_coef and SV
336
337            int m = model.NumberOfClasses - 1;
338            int l = model.SupportVectorCount;
339            model.SupportVectorCoefficients = new double[m][];
340            for (int i = 0; i < m; i++)
341            {
342                model.SupportVectorCoefficients[i] = new double[l];
343            }
344            model.SupportVectors = new Node[l][];
345
346            for (int i = 0; i < l; i++)
347            {
348                string[] parts = input.ReadLine().Trim().Split();
349
350                for (int k = 0; k < m; k++)
351                    model.SupportVectorCoefficients[k][i] = double.Parse(parts[k]);
352                int n = parts.Length-m;
353                model.SupportVectors[i] = new Node[n];
354                for (int j = 0; j < n; j++)
355                {
356                    string[] nodeParts = parts[m + j].Split(':');
357                    model.SupportVectors[i][j] = new Node();
358                    model.SupportVectors[i][j].Index = int.Parse(nodeParts[0]);
359                    model.SupportVectors[i][j].Value = double.Parse(nodeParts[1]);
360                }
361            }
362
363            return model;
364        }
365
366        /// <summary>
367        /// Writes a model to the provided filename.  This will overwrite any previous data in the file.
368        /// </summary>
369        /// <param name="filename">The desired file</param>
370        /// <param name="model">The Model to write</param>
371        public static void Write(string filename, Model model)
372        {
373            FileStream stream = File.Open(filename, FileMode.Create);
374            try
375            {
376                Write(stream, model);
377            }
378            finally
379            {
380                stream.Close();
381            }
382        }
383
384        /// <summary>
385        /// Writes a model to the provided stream.
386        /// </summary>
387        /// <param name="stream">The output stream</param>
388        /// <param name="model">The model to write</param>
389        public static void Write(Stream stream, Model model)
390        {
391            StreamWriter output = new StreamWriter(stream);
392
393            Parameter param = model.Parameter;
394
395            output.Write("svm_type " + param.SvmType + "\n");
396            output.Write("kernel_type " + param.KernelType + "\n");
397
398            if (param.KernelType == KernelType.POLY)
399                output.Write("degree " + param.Degree + "\n");
400
401            if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.RBF || param.KernelType == KernelType.SIGMOID)
402                output.Write("gamma " + param.Gamma + "\n");
403
404            if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.SIGMOID)
405                output.Write("coef0 " + param.Coefficient0 + "\n");
406
407            int nr_class = model.NumberOfClasses;
408            int l = model.SupportVectorCount;
409            output.Write("nr_class " + nr_class + "\n");
410            output.Write("total_sv " + l + "\n");
411
412            {
413                output.Write("rho");
414                for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++)
415                    output.Write(" " + model.Rho[i]);
416                output.Write("\n");
417            }
418
419            if (model.ClassLabels != null)
420            {
421                output.Write("label");
422                for (int i = 0; i < nr_class; i++)
423                    output.Write(" " + model.ClassLabels[i]);
424                output.Write("\n");
425            }
426
427            if (model.PairwiseProbabilityA != null)
428            // regression has probA only
429            {
430                output.Write("probA");
431                for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++)
432                    output.Write(" " + model.PairwiseProbabilityA[i]);
433                output.Write("\n");
434            }
435            if (model.PairwiseProbabilityB != null)
436            {
437                output.Write("probB");
438                for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++)
439                    output.Write(" " + model.PairwiseProbabilityB[i]);
440                output.Write("\n");
441            }
442
443            if (model.NumberOfSVPerClass != null)
444            {
445                output.Write("nr_sv");
446                for (int i = 0; i < nr_class; i++)
447                    output.Write(" " + model.NumberOfSVPerClass[i]);
448                output.Write("\n");
449            }
450
451            output.Write("SV\n");
452            double[][] sv_coef = model.SupportVectorCoefficients;
453            Node[][] SV = model.SupportVectors;
454
455            for (int i = 0; i < l; i++)
456            {
457                for (int j = 0; j < nr_class - 1; j++)
458                    output.Write(sv_coef[j][i] + " ");
459
460                Node[] p = SV[i];
461                if (p.Length == 0)
462                {
463                    output.WriteLine();
464                    continue;
465                }
466                if (param.KernelType == KernelType.PRECOMPUTED)
467                    output.Write("0:{0}", (int)p[0].Value);
468                else
469                {
470                    output.Write("{0}:{1}", p[0].Index, p[0].Value);
471                    for (int j = 1; j < p.Length; j++)
472                        output.Write(" {0}:{1}", p[j].Index, p[j].Value);
473                }
474                output.WriteLine();
475            }
476
477            output.Flush();
478        }
479  }
480}
Note: See TracBrowser for help on using the repository browser.