Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Persistence Test/LibSVM/Model.cs @ 4021

Last change on this file since 4021 was 2418, checked in by gkronber, 15 years ago

Fixed bugs in text export/import of SVM models. #772.

File size: 12.6 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20
21using System;
22using System.IO;
23using System.Threading;
24using System.Globalization;
25
namespace SVM {
  /// <summary>
  /// Encapsulates an SVM Model and reads/writes it in a libsvm-style text format
  /// (header lines such as "svm_type", "nr_class", "rho", ... followed by an "SV" section).
  /// </summary>
  /// <remarks>
  /// The class is marked [Serializable], so field-based serializers persist the private
  /// fields by name; they are deliberately kept as explicit backing fields with stable names.
  /// </remarks>
  [Serializable]
  public class Model {
    private Parameter _parameter;
    private int _numberOfClasses;
    private int _supportVectorCount;
    private Node[][] _supportVectors;
    private double[][] _supportVectorCoefficients;
    private double[] _rho;
    private double[] _pairwiseProbabilityA;
    private double[] _pairwiseProbabilityB;

    private int[] _classLabels;
    private int[] _numberOfSVPerClass;

    internal Model() {
    }

    /// <summary>
    /// Parameter object.
    /// </summary>
    public Parameter Parameter {
      get {
        return _parameter;
      }
      set {
        _parameter = value;
      }
    }

    /// <summary>
    /// Number of classes in the model.
    /// </summary>
    public int NumberOfClasses {
      get {
        return _numberOfClasses;
      }
      set {
        _numberOfClasses = value;
      }
    }

    /// <summary>
    /// Total number of support vectors.
    /// </summary>
    public int SupportVectorCount {
      get {
        return _supportVectorCount;
      }
      set {
        _supportVectorCount = value;
      }
    }

    /// <summary>
    /// The support vectors; one sparse array of index/value nodes per support vector.
    /// </summary>
    public Node[][] SupportVectors {
      get {
        return _supportVectors;
      }
      set {
        _supportVectors = value;
      }
    }

    /// <summary>
    /// The coefficients for the support vectors, indexed
    /// [NumberOfClasses - 1][SupportVectorCount] when read from a file.
    /// </summary>
    public double[][] SupportVectorCoefficients {
      get {
        return _supportVectorCoefficients;
      }
      set {
        _supportVectorCoefficients = value;
      }
    }

    /// <summary>
    /// Rho values, one per class pair (NumberOfClasses * (NumberOfClasses - 1) / 2 entries).
    /// </summary>
    public double[] Rho {
      get {
        return _rho;
      }
      set {
        _rho = value;
      }
    }

    /// <summary>
    /// First pairwise probability (null when the model file has no "probA" line).
    /// </summary>
    public double[] PairwiseProbabilityA {
      get {
        return _pairwiseProbabilityA;
      }
      set {
        _pairwiseProbabilityA = value;
      }
    }

    /// <summary>
    /// Second pairwise probability (null when the model file has no "probB" line).
    /// </summary>
    public double[] PairwiseProbabilityB {
      get {
        return _pairwiseProbabilityB;
      }
      set {
        _pairwiseProbabilityB = value;
      }
    }

    // for classification only

    /// <summary>
    /// Class labels (null when the model file has no "label" line).
    /// </summary>
    public int[] ClassLabels {
      get {
        return _classLabels;
      }
      set {
        _classLabels = value;
      }
    }

    /// <summary>
    /// Number of support vectors per class (null when the model file has no "nr_sv" line).
    /// </summary>
    public int[] NumberOfSVPerClass {
      get {
        return _numberOfSVPerClass;
      }
      set {
        _numberOfSVPerClass = value;
      }
    }

    /// <summary>
    /// Reads a Model from the provided file.
    /// </summary>
    /// <param name="filename">The name of the file containing the Model</param>
    /// <returns>the Model</returns>
    public static Model Read(string filename) {
      using (FileStream input = File.OpenRead(filename)) {
        return Read(input);
      }
    }

    /// <summary>
    /// Reads a Model from the provided stream. The stream is left open.
    /// </summary>
    /// <param name="stream">The stream from which to read the Model.</param>
    /// <returns>the Model</returns>
    public static Model Read(Stream stream) {
      // The reader is intentionally not disposed: disposing a StreamReader would also
      // close the underlying stream, which belongs to the caller.
      return Read(new StreamReader(stream));
    }

    /// <summary>
    /// Reads a Model from the provided text reader. The reader is left open.
    /// </summary>
    /// <remarks>
    /// Expected layout: header lines of the form "key value ..." terminated by a line "SV",
    /// followed by SupportVectorCount lines, each holding NumberOfClasses - 1 coefficients
    /// and then whitespace-separated index:value pairs.
    /// </remarks>
    /// <param name="input">The reader from which to read the Model.</param>
    /// <returns>the Model</returns>
    /// <exception cref="FormatException">Unknown header text, too few values on a line, or an unparsable number.</exception>
    /// <exception cref="ArgumentException">svm_type or kernel_type names an unknown enum value.</exception>
    /// <exception cref="EndOfStreamException">The input ends before the model is complete.</exception>
    public static Model Read(TextReader input) {
      // TemporaryCulture is defined elsewhere; presumably it pins the thread culture so numbers
      // parse the same way Write formatted them. Stop() runs in finally so a malformed file
      // cannot leave the culture switched.
      TemporaryCulture.Start();
      try {
        Model model = new Model();
        model.Parameter = new Parameter();
        model.Rho = null;
        model.PairwiseProbabilityA = null;
        model.PairwiseProbabilityB = null;
        model.ClassLabels = null;
        model.NumberOfSVPerClass = null;

        ReadHeader(input, model);
        ReadSupportVectors(input, model);
        return model;
      }
      finally {
        TemporaryCulture.Stop();
      }
    }

    /// <summary>
    /// Parses header lines into <paramref name="model"/> until the "SV" line has been consumed.
    /// </summary>
    private static void ReadHeader(TextReader input, Model model) {
      Parameter param = model.Parameter;
      while (true) {
        string line = ReadRequiredLine(input);
        string cmd, arg;
        int splitIndex = line.IndexOf(' ');
        if (splitIndex >= 0) {
          cmd = line.Substring(0, splitIndex);
          arg = line.Substring(splitIndex + 1);
        } else {
          cmd = line;
          arg = "";
        }
        // Do not lower-case arg: that turns NaN/Infinity values into an unparsable form.

        switch (cmd) {
          case "svm_type":
            param.SvmType = (SvmType)Enum.Parse(typeof(SvmType), arg.ToUpperInvariant());
            break;

          case "kernel_type":
            param.KernelType = (KernelType)Enum.Parse(typeof(KernelType), arg.ToUpperInvariant());
            break;

          case "degree":
            param.Degree = int.Parse(arg);
            break;

          case "gamma":
            param.Gamma = double.Parse(arg);
            break;

          case "coef0":
            param.Coefficient0 = double.Parse(arg);
            break;

          case "nr_class":
            model.NumberOfClasses = int.Parse(arg);
            break;

          case "total_sv":
            model.SupportVectorCount = int.Parse(arg);
            break;

          case "rho":
            model.Rho = ParseDoubles(arg, PairCount(model.NumberOfClasses));
            break;

          case "label":
            model.ClassLabels = ParseInts(arg, model.NumberOfClasses);
            break;

          case "probA":
            model.PairwiseProbabilityA = ParseDoubles(arg, PairCount(model.NumberOfClasses));
            break;

          case "probB":
            model.PairwiseProbabilityB = ParseDoubles(arg, PairCount(model.NumberOfClasses));
            break;

          case "nr_sv":
            model.NumberOfSVPerClass = ParseInts(arg, model.NumberOfClasses);
            break;

          case "SV":
            return;

          default:
            throw new FormatException("Unknown text in model file");
        }
      }
    }

    /// <summary>
    /// Reads the SupportVectorCount lines after "SV" into the coefficient and support-vector arrays.
    /// </summary>
    private static void ReadSupportVectors(TextReader input, Model model) {
      int m = model.NumberOfClasses - 1;
      int l = model.SupportVectorCount;
      model.SupportVectorCoefficients = new double[m][];
      for (int k = 0; k < m; k++) {
        model.SupportVectorCoefficients[k] = new double[l];
      }
      model.SupportVectors = new Node[l][];

      for (int i = 0; i < l; i++) {
        // The first m fields are coefficients; the rest are index:value pairs.
        string[] parts = SplitValues(ReadRequiredLine(input), m);
        for (int k = 0; k < m; k++) {
          model.SupportVectorCoefficients[k][i] = double.Parse(parts[k]);
        }

        int n = parts.Length - m;
        Node[] nodes = new Node[n];
        for (int j = 0; j < n; j++) {
          string[] nodeParts = parts[m + j].Split(':');
          Node node = new Node();
          node.Index = int.Parse(nodeParts[0]);
          node.Value = double.Parse(nodeParts[1]);
          nodes[j] = node;
        }
        model.SupportVectors[i] = nodes;
      }
    }

    /// <summary>
    /// Returns the next line, or throws if the input ended prematurely
    /// (TextReader.ReadLine returns null at end of input).
    /// </summary>
    private static string ReadRequiredLine(TextReader input) {
      string line = input.ReadLine();
      if (line == null)
        throw new EndOfStreamException("Unexpected end of model file");
      return line;
    }

    /// <summary>
    /// Splits <paramref name="text"/> on whitespace (ignoring repeated blanks) and checks that
    /// at least <paramref name="count"/> fields are present.
    /// </summary>
    private static string[] SplitValues(string text, int count) {
      string[] parts = text.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
      if (parts.Length < count)
        throw new FormatException("Too few values in model file: expected " + count + ", found " + parts.Length);
      return parts;
    }

    /// <summary>Parses the first <paramref name="count"/> whitespace-separated doubles of <paramref name="text"/>.</summary>
    private static double[] ParseDoubles(string text, int count) {
      string[] parts = SplitValues(text, count);
      double[] values = new double[count];
      for (int i = 0; i < count; i++) {
        values[i] = double.Parse(parts[i]);
      }
      return values;
    }

    /// <summary>Parses the first <paramref name="count"/> whitespace-separated ints of <paramref name="text"/>.</summary>
    private static int[] ParseInts(string text, int count) {
      string[] parts = SplitValues(text, count);
      int[] values = new int[count];
      for (int i = 0; i < count; i++) {
        values[i] = int.Parse(parts[i]);
      }
      return values;
    }

    /// <summary>Number of unordered class pairs, i.e. the length of rho/probA/probB.</summary>
    private static int PairCount(int numberOfClasses) {
      return numberOfClasses * (numberOfClasses - 1) / 2;
    }

    /// <summary>
    /// Writes a model to the provided filename.  This will overwrite any previous data in the file.
    /// </summary>
    /// <param name="filename">The desired file</param>
    /// <param name="model">The Model to write</param>
    public static void Write(string filename, Model model) {
      using (FileStream stream = File.Open(filename, FileMode.Create)) {
        Write(stream, model);
      }
    }

    /// <summary>
    /// Writes a model to the provided stream. The stream is flushed but left open.
    /// </summary>
    /// <param name="stream">The output stream</param>
    /// <param name="model">The model to write</param>
    public static void Write(Stream stream, Model model) {
      // See Read(TextReader): Stop() must run even if writing fails.
      TemporaryCulture.Start();
      try {
        // Not disposed, because that would close the caller's stream; flushed explicitly instead.
        StreamWriter output = new StreamWriter(stream);
        WriteHeader(output, model);
        WriteSupportVectors(output, model);
        output.Flush();
      }
      finally {
        TemporaryCulture.Stop();
      }
    }

    /// <summary>
    /// Writes every header line, up to but not including the "SV" marker.
    /// </summary>
    private static void WriteHeader(TextWriter output, Model model) {
      Parameter param = model.Parameter;

      output.WriteLine("svm_type " + param.SvmType);
      output.WriteLine("kernel_type " + param.KernelType);

      // Integers are written with plain ToString(): the round-trip "r" specifier is only
      // supported for floating-point types and throws FormatException on Int32.
      if (param.KernelType == KernelType.POLY)
        output.WriteLine("degree " + param.Degree);

      if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.RBF || param.KernelType == KernelType.SIGMOID)
        output.WriteLine("gamma " + param.Gamma.ToString("r"));

      if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.SIGMOID)
        output.WriteLine("coef0 " + param.Coefficient0.ToString("r"));

      int nrClass = model.NumberOfClasses;
      int pairCount = PairCount(nrClass);
      output.WriteLine("nr_class " + nrClass);
      output.WriteLine("total_sv " + model.SupportVectorCount);

      WriteDoubles(output, "rho", model.Rho, pairCount);

      if (model.ClassLabels != null)
        WriteInts(output, "label", model.ClassLabels, nrClass);

      // regression has probA only
      if (model.PairwiseProbabilityA != null)
        WriteDoubles(output, "probA", model.PairwiseProbabilityA, pairCount);

      if (model.PairwiseProbabilityB != null)
        WriteDoubles(output, "probB", model.PairwiseProbabilityB, pairCount);

      if (model.NumberOfSVPerClass != null)
        WriteInts(output, "nr_sv", model.NumberOfSVPerClass, nrClass);
    }

    /// <summary>
    /// Writes the "SV" marker followed by one line per support vector:
    /// its coefficients, then its index:value nodes.
    /// </summary>
    private static void WriteSupportVectors(TextWriter output, Model model) {
      output.WriteLine("SV");
      double[][] svCoef = model.SupportVectorCoefficients;
      Node[][] supportVectors = model.SupportVectors;
      int coefCount = model.NumberOfClasses - 1;
      bool precomputed = model.Parameter.KernelType == KernelType.PRECOMPUTED;

      for (int i = 0; i < model.SupportVectorCount; i++) {
        for (int j = 0; j < coefCount; j++)
          output.Write(svCoef[j][i].ToString("r") + " ");

        Node[] p = supportVectors[i];
        if (p.Length > 0) {
          if (precomputed) {
            // Precomputed kernels store only the first node's value, written as an integer id.
            output.Write("0:{0}", (int)p[0].Value);
          } else {
            output.Write("{0}:{1}", p[0].Index, p[0].Value.ToString("r"));
            for (int j = 1; j < p.Length; j++)
              output.Write(" {0}:{1}", p[j].Index, p[j].Value.ToString("r"));
          }
        }
        output.WriteLine();
      }
    }

    /// <summary>Writes "key v0 v1 ..." using round-trip formatting for the first <paramref name="count"/> doubles.</summary>
    private static void WriteDoubles(TextWriter output, string key, double[] values, int count) {
      output.Write(key);
      for (int i = 0; i < count; i++)
        output.Write(" " + values[i].ToString("r"));
      output.WriteLine();
    }

    /// <summary>Writes "key v0 v1 ..." for the first <paramref name="count"/> ints.</summary>
    private static void WriteInts(TextWriter output, string key, int[] values, int count) {
      output.Write(key);
      for (int i = 0; i < count; i++)
        output.Write(" " + values[i]);
      output.WriteLine();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.