Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/LibSVM/Model.cs @ 2411

Last change on this file since 2411 was 2411, checked in by gkronber, 15 years ago

Implemented #772 (Text export of SVM models)

File size: 12.1 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20
using System;
using System.Globalization;
using System.IO;
23
24namespace SVM {
25  /// <remarks>
26  /// Encapsulates an SVM Model.
27  /// </remarks>
28  [Serializable]
29  public class Model {
30    private Parameter _parameter;
31    private int _numberOfClasses;
32    private int _supportVectorCount;
33    private Node[][] _supportVectors;
34    private double[][] _supportVectorCoefficients;
35    private double[] _rho;
36    private double[] _pairwiseProbabilityA;
37    private double[] _pairwiseProbabilityB;
38
39    private int[] _classLabels;
40    private int[] _numberOfSVPerClass;
41
42    internal Model() {
43    }
44
45    /// <summary>
46    /// Parameter object.
47    /// </summary>
48    public Parameter Parameter {
49      get {
50        return _parameter;
51      }
52      set {
53        _parameter = value;
54      }
55    }
56
57    /// <summary>
58    /// Number of classes in the model.
59    /// </summary>
60    public int NumberOfClasses {
61      get {
62        return _numberOfClasses;
63      }
64      set {
65        _numberOfClasses = value;
66      }
67    }
68
69    /// <summary>
70    /// Total number of support vectors.
71    /// </summary>
72    public int SupportVectorCount {
73      get {
74        return _supportVectorCount;
75      }
76      set {
77        _supportVectorCount = value;
78      }
79    }
80
81    /// <summary>
82    /// The support vectors.
83    /// </summary>
84    public Node[][] SupportVectors {
85      get {
86        return _supportVectors;
87      }
88      set {
89        _supportVectors = value;
90      }
91    }
92
93    /// <summary>
94    /// The coefficients for the support vectors.
95    /// </summary>
96    public double[][] SupportVectorCoefficients {
97      get {
98        return _supportVectorCoefficients;
99      }
100      set {
101        _supportVectorCoefficients = value;
102      }
103    }
104
105    /// <summary>
106    /// Rho values.
107    /// </summary>
108    public double[] Rho {
109      get {
110        return _rho;
111      }
112      set {
113        _rho = value;
114      }
115    }
116
117    /// <summary>
118    /// First pairwise probability.
119    /// </summary>
120    public double[] PairwiseProbabilityA {
121      get {
122        return _pairwiseProbabilityA;
123      }
124      set {
125        _pairwiseProbabilityA = value;
126      }
127    }
128
129    /// <summary>
130    /// Second pairwise probability.
131    /// </summary>
132    public double[] PairwiseProbabilityB {
133      get {
134        return _pairwiseProbabilityB;
135      }
136      set {
137        _pairwiseProbabilityB = value;
138      }
139    }
140
141    // for classification only
142
143    /// <summary>
144    /// Class labels.
145    /// </summary>
146    public int[] ClassLabels {
147      get {
148        return _classLabels;
149      }
150      set {
151        _classLabels = value;
152      }
153    }
154
155    /// <summary>
156    /// Number of support vectors per class.
157    /// </summary>
158    public int[] NumberOfSVPerClass {
159      get {
160        return _numberOfSVPerClass;
161      }
162      set {
163        _numberOfSVPerClass = value;
164      }
165    }
166
167    /// <summary>
168    /// Reads a Model from the provided file.
169    /// </summary>
170    /// <param name="filename">The name of the file containing the Model</param>
171    /// <returns>the Model</returns>
172    public static Model Read(string filename) {
173      FileStream input = File.OpenRead(filename);
174      try {
175        return Read(input);
176      }
177      finally {
178        input.Close();
179      }
180    }
181
182    /// <summary>
183    /// Reads a Model from the provided stream.
184    /// </summary>
185    /// <param name="stream">The stream from which to read the Model.</param>
186    /// <returns>the Model</returns>
187    public static Model Read(Stream stream) {
188      StreamReader input = new StreamReader(stream);
189
190      // read parameters
191
192      Model model = new Model();
193      Parameter param = new Parameter();
194      model.Parameter = param;
195      model.Rho = null;
196      model.PairwiseProbabilityA = null;
197      model.PairwiseProbabilityB = null;
198      model.ClassLabels = null;
199      model.NumberOfSVPerClass = null;
200
201      bool headerFinished = false;
202      while (!headerFinished) {
203        string line = input.ReadLine();
204        string cmd, arg;
205        int splitIndex = line.IndexOf(' ');
206        if (splitIndex >= 0) {
207          cmd = line.Substring(0, splitIndex);
208          arg = line.Substring(splitIndex + 1);
209        } else {
210          cmd = line;
211          arg = "";
212        }
213        arg = arg.ToLower();
214
215        int i, n;
216        switch (cmd) {
217          case "svm_type":
218            param.SvmType = (SvmType)Enum.Parse(typeof(SvmType), arg.ToUpper());
219            break;
220
221          case "kernel_type":
222            param.KernelType = (KernelType)Enum.Parse(typeof(KernelType), arg.ToUpper());
223            break;
224
225          case "degree":
226            param.Degree = int.Parse(arg);
227            break;
228
229          case "gamma":
230            param.Gamma = double.Parse(arg);
231            break;
232
233          case "coef0":
234            param.Coefficient0 = double.Parse(arg);
235            break;
236
237          case "nr_class":
238            model.NumberOfClasses = int.Parse(arg);
239            break;
240
241          case "total_sv":
242            model.SupportVectorCount = int.Parse(arg);
243            break;
244
245          case "rho":
246            n = model.NumberOfClasses * (model.NumberOfClasses - 1) / 2;
247            model.Rho = new double[n];
248            string[] rhoParts = arg.Split();
249            for (i = 0; i < n; i++)
250              model.Rho[i] = double.Parse(rhoParts[i]);
251            break;
252
253          case "label":
254            n = model.NumberOfClasses;
255            model.ClassLabels = new int[n];
256            string[] labelParts = arg.Split();
257            for (i = 0; i < n; i++)
258              model.ClassLabels[i] = int.Parse(labelParts[i]);
259            break;
260
261          case "probA":
262            n = model.NumberOfClasses * (model.NumberOfClasses - 1) / 2;
263            model.PairwiseProbabilityA = new double[n];
264            string[] probAParts = arg.Split();
265            for (i = 0; i < n; i++)
266              model.PairwiseProbabilityA[i] = double.Parse(probAParts[i]);
267            break;
268
269          case "probB":
270            n = model.NumberOfClasses * (model.NumberOfClasses - 1) / 2;
271            model.PairwiseProbabilityB = new double[n];
272            string[] probBParts = arg.Split();
273            for (i = 0; i < n; i++)
274              model.PairwiseProbabilityB[i] = double.Parse(probBParts[i]);
275            break;
276
277          case "nr_sv":
278            n = model.NumberOfClasses;
279            model.NumberOfSVPerClass = new int[n];
280            string[] nrsvParts = arg.Split();
281            for (i = 0; i < n; i++)
282              model.NumberOfSVPerClass[i] = int.Parse(nrsvParts[i]);
283            break;
284
285          case "SV":
286            headerFinished = true;
287            break;
288
289          default:
290            throw new Exception("Unknown text in model file");
291        }
292      }
293
294      // read sv_coef and SV
295
296      int m = model.NumberOfClasses - 1;
297      int l = model.SupportVectorCount;
298      model.SupportVectorCoefficients = new double[m][];
299      for (int i = 0; i < m; i++) {
300        model.SupportVectorCoefficients[i] = new double[l];
301      }
302      model.SupportVectors = new Node[l][];
303
304      for (int i = 0; i < l; i++) {
305        string[] parts = input.ReadLine().Trim().Split();
306
307        for (int k = 0; k < m; k++)
308          model.SupportVectorCoefficients[k][i] = double.Parse(parts[k]);
309        int n = parts.Length - m;
310        model.SupportVectors[i] = new Node[n];
311        for (int j = 0; j < n; j++) {
312          string[] nodeParts = parts[m + j].Split(':');
313          model.SupportVectors[i][j] = new Node();
314          model.SupportVectors[i][j].Index = int.Parse(nodeParts[0]);
315          model.SupportVectors[i][j].Value = double.Parse(nodeParts[1]);
316        }
317      }
318
319      return model;
320    }
321
322    /// <summary>
323    /// Writes a model to the provided filename.  This will overwrite any previous data in the file.
324    /// </summary>
325    /// <param name="filename">The desired file</param>
326    /// <param name="model">The Model to write</param>
327    public static void Write(string filename, Model model) {
328      FileStream stream = File.Open(filename, FileMode.Create);
329      try {
330        Write(stream, model);
331      }
332      finally {
333        stream.Close();
334      }
335    }
336
337    /// <summary>
338    /// Writes a model to the provided stream.
339    /// </summary>
340    /// <param name="stream">The output stream</param>
341    /// <param name="model">The model to write</param>
342    public static void Write(Stream stream, Model model) {
343      StreamWriter output = new StreamWriter(stream);
344
345      Parameter param = model.Parameter;
346
347      output.Write("svm_type " + param.SvmType + Environment.NewLine);
348      output.Write("kernel_type " + param.KernelType + Environment.NewLine);
349
350      if (param.KernelType == KernelType.POLY)
351        output.Write("degree " + param.Degree + Environment.NewLine);
352
353      if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.RBF || param.KernelType == KernelType.SIGMOID)
354        output.Write("gamma " + param.Gamma + Environment.NewLine);
355
356      if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.SIGMOID)
357        output.Write("coef0 " + param.Coefficient0 + Environment.NewLine);
358
359      int nr_class = model.NumberOfClasses;
360      int l = model.SupportVectorCount;
361      output.Write("nr_class " + nr_class + Environment.NewLine);
362      output.Write("total_sv " + l + Environment.NewLine);
363
364      {
365        output.Write("rho");
366        for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++)
367          output.Write(" " + model.Rho[i]);
368        output.Write(Environment.NewLine);
369      }
370
371      if (model.ClassLabels != null) {
372        output.Write("label");
373        for (int i = 0; i < nr_class; i++)
374          output.Write(" " + model.ClassLabels[i]);
375        output.Write(Environment.NewLine);
376      }
377
378      if (model.PairwiseProbabilityA != null)
379      // regression has probA only
380            {
381        output.Write("probA");
382        for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++)
383          output.Write(" " + model.PairwiseProbabilityA[i]);
384        output.Write(Environment.NewLine);
385      }
386      if (model.PairwiseProbabilityB != null) {
387        output.Write("probB");
388        for (int i = 0; i < nr_class * (nr_class - 1) / 2; i++)
389          output.Write(" " + model.PairwiseProbabilityB[i]);
390        output.Write(Environment.NewLine);
391      }
392
393      if (model.NumberOfSVPerClass != null) {
394        output.Write("nr_sv");
395        for (int i = 0; i < nr_class; i++)
396          output.Write(" " + model.NumberOfSVPerClass[i]);
397        output.Write(Environment.NewLine);
398      }
399
400      output.Write("SV\n");
401      double[][] sv_coef = model.SupportVectorCoefficients;
402      Node[][] SV = model.SupportVectors;
403
404      for (int i = 0; i < l; i++) {
405        for (int j = 0; j < nr_class - 1; j++)
406          output.Write(sv_coef[j][i] + " ");
407
408        Node[] p = SV[i];
409        if (p.Length == 0) {
410          output.WriteLine();
411          continue;
412        }
413        if (param.KernelType == KernelType.PRECOMPUTED)
414          output.Write("0:{0}", (int)p[0].Value);
415        else {
416          output.Write("{0}:{1}", p[0].Index, p[0].Value);
417          for (int j = 1; j < p.Length; j++)
418            output.Write(" {0}:{1}", p[j].Index, p[j].Value);
419        }
420        output.WriteLine();
421      }
422
423      output.Flush();
424    }
425  }
426}
Note: See TracBrowser for help on using the repository browser.