Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/Model.cs @ 5692

Last change on this file since 5692 was 5692, checked in by gkronber, 13 years ago

#1426 merged r5690 from data analysis refactoring branch (see #1418) into trunk to fix persistence problems of SVMs.

File size: 12.5 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20
21using System;
22using System.IO;
23
namespace SVM {
  /// <summary>
  /// Encapsulates an SVM Model.
  /// </summary>
  /// <remarks>
  /// The type is marked [Serializable]. Binary serialization records private field names,
  /// so the backing fields below must not be renamed or replaced by auto-properties
  /// without breaking previously persisted models.
  /// </remarks>
  [Serializable]
  public class Model {
    private Parameter _parameter;
    private int _numberOfClasses;
    private int _supportVectorCount;
    private Node[][] _supportVectors;
    private double[][] _supportVectorCoefficients;
    private double[] _rho;
    private double[] _pairwiseProbabilityA;
    private double[] _pairwiseProbabilityB;

    private int[] _classLabels;
    private int[] _numberOfSVPerClass;

    internal Model() {
    }

    /// <summary>
    /// Parameter object.
    /// </summary>
    public Parameter Parameter {
      get {
        return _parameter;
      }
      set {
        _parameter = value;
      }
    }

    /// <summary>
    /// Number of classes in the model.
    /// </summary>
    public int NumberOfClasses {
      get {
        return _numberOfClasses;
      }
      set {
        _numberOfClasses = value;
      }
    }

    /// <summary>
    /// Total number of support vectors.
    /// </summary>
    public int SupportVectorCount {
      get {
        return _supportVectorCount;
      }
      set {
        _supportVectorCount = value;
      }
    }

    /// <summary>
    /// The support vectors.
    /// </summary>
    public Node[][] SupportVectors {
      get {
        return _supportVectors;
      }
      set {
        _supportVectors = value;
      }
    }

    /// <summary>
    /// The coefficients for the support vectors, indexed as [classifier][supportVector].
    /// </summary>
    public double[][] SupportVectorCoefficients {
      get {
        return _supportVectorCoefficients;
      }
      set {
        _supportVectorCoefficients = value;
      }
    }

    /// <summary>
    /// Rho values (one per class pair).
    /// </summary>
    public double[] Rho {
      get {
        return _rho;
      }
      set {
        _rho = value;
      }
    }

    /// <summary>
    /// First pairwise probability.
    /// </summary>
    public double[] PairwiseProbabilityA {
      get {
        return _pairwiseProbabilityA;
      }
      set {
        _pairwiseProbabilityA = value;
      }
    }

    /// <summary>
    /// Second pairwise probability.
    /// </summary>
    public double[] PairwiseProbabilityB {
      get {
        return _pairwiseProbabilityB;
      }
      set {
        _pairwiseProbabilityB = value;
      }
    }

    // for classification only

    /// <summary>
    /// Class labels.
    /// </summary>
    public int[] ClassLabels {
      get {
        return _classLabels;
      }
      set {
        _classLabels = value;
      }
    }

    /// <summary>
    /// Number of support vectors per class.
    /// </summary>
    public int[] NumberOfSVPerClass {
      get {
        return _numberOfSVPerClass;
      }
      set {
        _numberOfSVPerClass = value;
      }
    }

    /// <summary>
    /// Reads a Model from the provided file.
    /// </summary>
    /// <param name="filename">The name of the file containing the Model</param>
    /// <returns>the Model</returns>
    public static Model Read(string filename) {
      using (FileStream input = File.OpenRead(filename)) {
        return Read(input);
      }
    }

    /// <summary>
    /// Reads a Model from the provided stream. The stream is left open.
    /// </summary>
    /// <param name="stream">The stream from which to read the Model.</param>
    /// <returns>the Model</returns>
    public static Model Read(Stream stream) {
      // The StreamReader is deliberately not disposed: disposing it would also close the
      // caller-owned stream.
      return Read(new StreamReader(stream));
    }

    /// <summary>
    /// Reads a Model from the provided text reader.
    /// </summary>
    /// <param name="input">The reader from which to read the Model.</param>
    /// <returns>the Model</returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    /// <exception cref="InvalidDataException">
    /// The header contains an unknown key, a value list is too short, or the input ends
    /// before the header or all support vectors have been read.
    /// </exception>
    public static Model Read(TextReader input) {
      if (input == null)
        throw new ArgumentNullException("input");

      TemporaryCulture.Start();
      try {
        Model model = ReadHeader(input);
        ReadSupportVectors(input, model);
        return model;
      }
      finally {
        // Restore the culture even if parsing fails part-way through.
        TemporaryCulture.Stop();
      }
    }

    /// <summary>
    /// Parses the "key value" header lines up to and including the "SV" marker line.
    /// </summary>
    private static Model ReadHeader(TextReader input) {
      Model model = new Model();
      Parameter param = new Parameter();
      model.Parameter = param;
      model.Rho = null;
      model.PairwiseProbabilityA = null;
      model.PairwiseProbabilityB = null;
      model.ClassLabels = null;
      model.NumberOfSVPerClass = null;

      while (true) {
        string line = ReadRequiredLine(input);
        string cmd, arg;
        int splitIndex = line.IndexOf(' ');
        if (splitIndex >= 0) {
          cmd = line.Substring(0, splitIndex);
          arg = line.Substring(splitIndex + 1);
        } else {
          cmd = line;
          arg = "";
        }
        // arg is intentionally not lower-cased: that would corrupt NaN/Infinity values [gkronber]

        switch (cmd) {
          case "svm_type":
            param.SvmType = (SvmType)Enum.Parse(typeof(SvmType), arg.ToUpperInvariant());
            break;

          case "kernel_type":
            param.KernelType = (KernelType)Enum.Parse(typeof(KernelType), arg.ToUpperInvariant());
            break;

          case "degree":
            param.Degree = int.Parse(arg);
            break;

          case "gamma":
            param.Gamma = double.Parse(arg);
            break;

          case "coef0":
            param.Coefficient0 = double.Parse(arg);
            break;

          case "nr_class":
            model.NumberOfClasses = int.Parse(arg);
            break;

          case "total_sv":
            model.SupportVectorCount = int.Parse(arg);
            break;

          case "rho":
            model.Rho = ParseDoubles(cmd, arg, PairCount(model.NumberOfClasses));
            break;

          case "label":
            model.ClassLabels = ParseInts(cmd, arg, model.NumberOfClasses);
            break;

          case "probA":
            model.PairwiseProbabilityA = ParseDoubles(cmd, arg, PairCount(model.NumberOfClasses));
            break;

          case "probB":
            model.PairwiseProbabilityB = ParseDoubles(cmd, arg, PairCount(model.NumberOfClasses));
            break;

          case "nr_sv":
            model.NumberOfSVPerClass = ParseInts(cmd, arg, model.NumberOfClasses);
            break;

          case "SV":
            // End of header; the support vector lines follow.
            return model;

          default:
            throw new InvalidDataException("Unknown text in model file");
        }
      }
    }

    /// <summary>
    /// Parses the support vector section: one line per support vector holding
    /// NumberOfClasses - 1 coefficients followed by "index:value" node pairs.
    /// </summary>
    private static void ReadSupportVectors(TextReader input, Model model) {
      int m = model.NumberOfClasses - 1;
      int l = model.SupportVectorCount;
      model.SupportVectorCoefficients = new double[m][];
      for (int k = 0; k < m; k++) {
        model.SupportVectorCoefficients[k] = new double[l];
      }
      model.SupportVectors = new Node[l][];

      for (int i = 0; i < l; i++) {
        string[] parts = ReadRequiredLine(input).Trim().Split();

        for (int k = 0; k < m; k++)
          model.SupportVectorCoefficients[k][i] = double.Parse(parts[k]);

        int n = parts.Length - m;
        Node[] nodes = new Node[n];
        for (int j = 0; j < n; j++) {
          string[] nodeParts = parts[m + j].Split(':');
          Node node = new Node();
          node.Index = int.Parse(nodeParts[0]);
          node.Value = double.Parse(nodeParts[1]);
          nodes[j] = node;
        }
        model.SupportVectors[i] = nodes;
      }
    }

    /// <summary>
    /// Reads one line, throwing instead of returning null when the input is exhausted.
    /// </summary>
    private static string ReadRequiredLine(TextReader input) {
      string line = input.ReadLine();
      if (line == null)
        throw new InvalidDataException("Unexpected end of model file");
      return line;
    }

    /// <summary>
    /// Number of one-vs-one class pairs, k * (k - 1) / 2; the length of the rho/probA/probB lists.
    /// </summary>
    private static int PairCount(int numberOfClasses) {
      return numberOfClasses * (numberOfClasses - 1) / 2;
    }

    /// <summary>
    /// Parses the first <paramref name="count"/> whitespace-separated doubles of a header value.
    /// </summary>
    private static double[] ParseDoubles(string key, string arg, int count) {
      string[] parts = arg.Split();
      if (parts.Length < count)
        throw new InvalidDataException("Too few values for '" + key + "' in model file");
      double[] values = new double[count];
      for (int i = 0; i < count; i++)
        values[i] = double.Parse(parts[i]);
      return values;
    }

    /// <summary>
    /// Parses the first <paramref name="count"/> whitespace-separated integers of a header value.
    /// </summary>
    private static int[] ParseInts(string key, string arg, int count) {
      string[] parts = arg.Split();
      if (parts.Length < count)
        throw new InvalidDataException("Too few values for '" + key + "' in model file");
      int[] values = new int[count];
      for (int i = 0; i < count; i++)
        values[i] = int.Parse(parts[i]);
      return values;
    }

    /// <summary>
    /// Writes a model to the provided filename.  This will overwrite any previous data in the file.
    /// </summary>
    /// <param name="filename">The desired file</param>
    /// <param name="model">The Model to write</param>
    public static void Write(string filename, Model model) {
      using (FileStream stream = File.Open(filename, FileMode.Create)) {
        Write(stream, model);
      }
    }

    /// <summary>
    /// Writes a model to the provided stream. The stream is flushed but left open.
    /// </summary>
    /// <param name="stream">The output stream</param>
    /// <param name="model">The model to write</param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public static void Write(Stream stream, Model model) {
      if (model == null)
        throw new ArgumentNullException("model");

      TemporaryCulture.Start();
      try {
        // Not disposed on purpose: disposing the writer would close the caller's stream.
        StreamWriter output = new StreamWriter(stream);
        WriteHeader(output, model);
        WriteSupportVectors(output, model);
        output.Flush();
      }
      finally {
        // Restore the culture even if writing fails part-way through.
        TemporaryCulture.Stop();
      }
    }

    /// <summary>
    /// Writes the "key value" header lines (everything before the "SV" marker).
    /// </summary>
    private static void WriteHeader(TextWriter output, Model model) {
      Parameter param = model.Parameter;

      output.Write("svm_type " + param.SvmType + Environment.NewLine);
      output.Write("kernel_type " + param.KernelType + Environment.NewLine);

      if (param.KernelType == KernelType.POLY) {
        // The degree is read back with int.Parse. The round-trip specifier "r" is only
        // valid for floating-point types and throws FormatException on an int, so the
        // default format is used here.
        output.Write("degree " + param.Degree.ToString() + Environment.NewLine);
      }

      if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.RBF || param.KernelType == KernelType.SIGMOID)
        output.Write("gamma " + param.Gamma.ToString("r") + Environment.NewLine);

      if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.SIGMOID)
        output.Write("coef0 " + param.Coefficient0.ToString("r") + Environment.NewLine);

      int nr_class = model.NumberOfClasses;
      output.Write("nr_class " + nr_class + Environment.NewLine);
      output.Write("total_sv " + model.SupportVectorCount + Environment.NewLine);

      int pairCount = PairCount(nr_class);
      WriteDoubleLine(output, "rho", model.Rho, pairCount);

      if (model.ClassLabels != null)
        WriteIntLine(output, "label", model.ClassLabels, nr_class);

      // regression has probA only
      if (model.PairwiseProbabilityA != null)
        WriteDoubleLine(output, "probA", model.PairwiseProbabilityA, pairCount);
      if (model.PairwiseProbabilityB != null)
        WriteDoubleLine(output, "probB", model.PairwiseProbabilityB, pairCount);

      if (model.NumberOfSVPerClass != null)
        WriteIntLine(output, "nr_sv", model.NumberOfSVPerClass, nr_class);
    }

    /// <summary>
    /// Writes the "SV" marker followed by one line per support vector: its coefficients,
    /// then its "index:value" nodes (only "0:value" for precomputed kernels).
    /// </summary>
    private static void WriteSupportVectors(TextWriter output, Model model) {
      output.WriteLine("SV");
      double[][] sv_coef = model.SupportVectorCoefficients;
      Node[][] SV = model.SupportVectors;
      int coefficientCount = model.NumberOfClasses - 1;
      bool precomputed = model.Parameter.KernelType == KernelType.PRECOMPUTED;

      for (int i = 0; i < model.SupportVectorCount; i++) {
        for (int j = 0; j < coefficientCount; j++)
          output.Write(sv_coef[j][i].ToString("r") + " ");

        Node[] p = SV[i];
        if (p.Length > 0) {
          if (precomputed) {
            output.Write("0:{0}", (int)p[0].Value);
          } else {
            output.Write("{0}:{1}", p[0].Index, p[0].Value.ToString("r"));
            for (int j = 1; j < p.Length; j++)
              output.Write(" {0}:{1}", p[j].Index, p[j].Value.ToString("r"));
          }
        }
        output.WriteLine();
      }
    }

    /// <summary>
    /// Writes "key v0 v1 ..." for the first <paramref name="count"/> doubles, using round-trip format.
    /// </summary>
    private static void WriteDoubleLine(TextWriter output, string key, double[] values, int count) {
      output.Write(key);
      for (int i = 0; i < count; i++)
        output.Write(" " + values[i].ToString("r"));
      output.Write(Environment.NewLine);
    }

    /// <summary>
    /// Writes "key v0 v1 ..." for the first <paramref name="count"/> integers.
    /// </summary>
    private static void WriteIntLine(TextWriter output, string key, int[] values, int count) {
      output.Write(key);
      for (int i = 0; i < count; i++)
        output.Write(" " + values[i]);
      output.Write(Environment.NewLine);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.