Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/Model.cs @ 3884

Last change on this file since 3884 was 3884, checked in by gkronber, 14 years ago

Worked on support vector regression operators and views. #1009

File size: 12.9 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20
21using System;
22using System.IO;
23using System.Threading;
24using System.Globalization;
25
namespace SVM {
  /// <summary>
  /// Encapsulates an SVM Model.
  /// </summary>
  /// <remarks>
  /// A model can be persisted as line-oriented text via <see cref="Write(string, Model)"/> and loaded again
  /// via <see cref="Read(string)"/>. The text consists of "keyword values" header lines (svm_type, kernel_type,
  /// rho, ...), an "SV" line, and then one line per support vector.
  /// The class is marked <see cref="SerializableAttribute"/>, so the names of the private fields below are part
  /// of its binary-serialized form and must not be renamed.
  /// </remarks>
  [Serializable]
  public class Model {
    private Parameter _parameter;
    private int _numberOfClasses;
    private int _supportVectorCount;
    private int[] _supportVectorIndizes;
    private Node[][] _supportVectors;
    private double[][] _supportVectorCoefficients;
    private double[] _rho;
    private double[] _pairwiseProbabilityA;
    private double[] _pairwiseProbabilityB;

    private int[] _classLabels;
    private int[] _numberOfSVPerClass;

    internal Model() {
    }

    /// <summary>
    /// Parameter object.
    /// </summary>
    public Parameter Parameter {
      get {
        return _parameter;
      }
      set {
        _parameter = value;
      }
    }

    /// <summary>
    /// Number of classes in the model.
    /// </summary>
    public int NumberOfClasses {
      get {
        return _numberOfClasses;
      }
      set {
        _numberOfClasses = value;
      }
    }

    /// <summary>
    /// Total number of support vectors.
    /// </summary>
    public int SupportVectorCount {
      get {
        return _supportVectorCount;
      }
      set {
        _supportVectorCount = value;
      }
    }

    /// <summary>
    /// Indices of the support vectors identified in the training (the property keeps its historical spelling).
    /// Not stored in the text format; <see cref="Read(TextReader)"/> sets it to an empty array.
    /// </summary>
    public int[] SupportVectorIndizes {
      get {
        return _supportVectorIndizes;
      }
      set {
        _supportVectorIndizes = value;
      }
    }

    /// <summary>
    /// The support vectors, one array of nodes per support vector.
    /// </summary>
    public Node[][] SupportVectors {
      get {
        return _supportVectors;
      }
      set {
        _supportVectors = value;
      }
    }

    /// <summary>
    /// The coefficients for the support vectors, indexed as [k][i]:
    /// <c>NumberOfClasses - 1</c> rows, each holding <c>SupportVectorCount</c> entries.
    /// </summary>
    public double[][] SupportVectorCoefficients {
      get {
        return _supportVectorCoefficients;
      }
      set {
        _supportVectorCoefficients = value;
      }
    }

    /// <summary>
    /// Rho values, one per class pair (<c>NumberOfClasses * (NumberOfClasses - 1) / 2</c> entries).
    /// </summary>
    public double[] Rho {
      get {
        return _rho;
      }
      set {
        _rho = value;
      }
    }

    /// <summary>
    /// First pairwise probability parameter ("probA" in the text format), one per class pair; may be null.
    /// </summary>
    public double[] PairwiseProbabilityA {
      get {
        return _pairwiseProbabilityA;
      }
      set {
        _pairwiseProbabilityA = value;
      }
    }

    /// <summary>
    /// Second pairwise probability parameter ("probB" in the text format), one per class pair; may be null.
    /// </summary>
    public double[] PairwiseProbabilityB {
      get {
        return _pairwiseProbabilityB;
      }
      set {
        _pairwiseProbabilityB = value;
      }
    }

    // for classification only

    /// <summary>
    /// Class labels, one per class; may be null.
    /// </summary>
    public int[] ClassLabels {
      get {
        return _classLabels;
      }
      set {
        _classLabels = value;
      }
    }

    /// <summary>
    /// Number of support vectors per class, one entry per class; may be null.
    /// </summary>
    public int[] NumberOfSVPerClass {
      get {
        return _numberOfSVPerClass;
      }
      set {
        _numberOfSVPerClass = value;
      }
    }

    /// <summary>
    /// Reads a Model from the provided file.
    /// </summary>
    /// <param name="filename">The name of the file containing the Model</param>
    /// <returns>the Model</returns>
    public static Model Read(string filename) {
      using (FileStream input = File.OpenRead(filename)) {
        return Read(input);
      }
    }

    /// <summary>
    /// Reads a Model from the provided stream.
    /// </summary>
    /// <param name="stream">The stream from which to read the Model. The stream is left open.</param>
    /// <returns>the Model</returns>
    public static Model Read(Stream stream) {
      // The reader is deliberately not disposed: disposing a StreamReader would also close the caller's stream.
      return Read(new StreamReader(stream));
    }

    /// <summary>
    /// Reads a Model from the provided text reader.
    /// </summary>
    /// <param name="input">The reader from which to read the Model. The reader is left open.</param>
    /// <returns>the Model</returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    /// <exception cref="ArgumentException">svm_type or kernel_type names an unknown enum value.</exception>
    /// <exception cref="FormatException">
    /// The header contains an unknown keyword, a line holds too few values, or a number is malformed.
    /// </exception>
    /// <exception cref="EndOfStreamException">
    /// The input ends before the header or the support vector section is complete.
    /// </exception>
    public static Model Read(TextReader input) {
      if (input == null)
        throw new ArgumentNullException("input");

      // NOTE(review): TemporaryCulture presumably pins the thread culture so numbers parse
      // culture-independently; confirm. Stop() must run even when parsing fails.
      TemporaryCulture.Start();
      try {
        Model model = new Model();
        Parameter param = new Parameter();
        model.Parameter = param;
        model.Rho = null;
        model.PairwiseProbabilityA = null;
        model.PairwiseProbabilityB = null;
        model.ClassLabels = null;
        model.NumberOfSVPerClass = null;
        model.SupportVectorIndizes = new int[0];

        ReadHeader(input, model, param);
        ReadSupportVectors(input, model);
        return model;
      }
      finally {
        TemporaryCulture.Stop();
      }
    }

    /// <summary>
    /// Parses "keyword arguments" header lines into <paramref name="model"/> and <paramref name="param"/>
    /// until the "SV" line that starts the support vector section has been consumed.
    /// </summary>
    private static void ReadHeader(TextReader input, Model model, Parameter param) {
      while (true) {
        string line = ReadRequiredLine(input);
        string cmd, arg;
        int splitIndex = line.IndexOf(' ');
        if (splitIndex >= 0) {
          cmd = line.Substring(0, splitIndex);
          arg = line.Substring(splitIndex + 1);
        } else {
          cmd = line;
          arg = "";
        }
        // arg is intentionally not lower-cased: arg.ToLower() transforms double NaN or Infinity values
        // to an incorrect format [gkronber].

        switch (cmd) {
          case "svm_type":
            // Invariant upper-casing: culture-sensitive ToUpper() mangles names containing 'i'
            // under cultures such as tr-TR, which would make Enum.Parse fail.
            param.SvmType = (SvmType)Enum.Parse(typeof(SvmType), arg.ToUpperInvariant());
            break;

          case "kernel_type":
            param.KernelType = (KernelType)Enum.Parse(typeof(KernelType), arg.ToUpperInvariant());
            break;

          case "degree":
            param.Degree = int.Parse(arg);
            break;

          case "gamma":
            param.Gamma = double.Parse(arg);
            break;

          case "coef0":
            param.Coefficient0 = double.Parse(arg);
            break;

          case "nr_class":
            model.NumberOfClasses = int.Parse(arg);
            break;

          case "total_sv":
            model.SupportVectorCount = int.Parse(arg);
            break;

          // The per-pair / per-class lines rely on nr_class having appeared earlier in the header.
          case "rho":
            model.Rho = ParseDoubles(arg, GetPairCount(model.NumberOfClasses));
            break;

          case "label":
            model.ClassLabels = ParseInts(arg, model.NumberOfClasses);
            break;

          case "probA":
            model.PairwiseProbabilityA = ParseDoubles(arg, GetPairCount(model.NumberOfClasses));
            break;

          case "probB":
            model.PairwiseProbabilityB = ParseDoubles(arg, GetPairCount(model.NumberOfClasses));
            break;

          case "nr_sv":
            model.NumberOfSVPerClass = ParseInts(arg, model.NumberOfClasses);
            break;

          case "SV":
            return;

          default:
            throw new FormatException("Unknown text in model file");
        }
      }
    }

    /// <summary>
    /// Reads <c>SupportVectorCount</c> lines, each holding <c>NumberOfClasses - 1</c> coefficients followed by
    /// "index:value" nodes, into <see cref="SupportVectorCoefficients"/> and <see cref="SupportVectors"/>.
    /// </summary>
    private static void ReadSupportVectors(TextReader input, Model model) {
      int m = model.NumberOfClasses - 1;
      int l = model.SupportVectorCount;
      model.SupportVectorCoefficients = new double[m][];
      for (int k = 0; k < m; k++) {
        model.SupportVectorCoefficients[k] = new double[l];
      }
      model.SupportVectors = new Node[l][];

      for (int i = 0; i < l; i++) {
        string[] parts = SplitFields(ReadRequiredLine(input));
        if (parts.Length < m)
          throw new FormatException("Too few coefficients in support vector line " + (i + 1) + " of model file");

        for (int k = 0; k < m; k++)
          model.SupportVectorCoefficients[k][i] = double.Parse(parts[k]);

        int n = parts.Length - m;
        Node[] nodes = new Node[n];
        for (int j = 0; j < n; j++) {
          string[] nodeParts = parts[m + j].Split(':');
          if (nodeParts.Length < 2)
            throw new FormatException("Malformed index:value pair '" + parts[m + j] + "' in model file");
          Node node = new Node();
          node.Index = int.Parse(nodeParts[0]);
          node.Value = double.Parse(nodeParts[1]);
          nodes[j] = node;
        }
        model.SupportVectors[i] = nodes;
      }
    }

    /// <summary>
    /// Reads the next line, failing with a descriptive error instead of returning null at end of input.
    /// </summary>
    private static string ReadRequiredLine(TextReader input) {
      string line = input.ReadLine();
      if (line == null)
        throw new EndOfStreamException("Unexpected end of model file");
      return line;
    }

    /// <summary>
    /// Splits <paramref name="text"/> on whitespace, dropping empty entries so that repeated,
    /// leading or trailing separators are tolerated.
    /// </summary>
    private static string[] SplitFields(string text) {
      return text.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
    }

    /// <summary>
    /// Parses the first <paramref name="count"/> whitespace-separated doubles of <paramref name="text"/>.
    /// </summary>
    private static double[] ParseDoubles(string text, int count) {
      string[] parts = SplitFields(text);
      if (parts.Length < count)
        throw new FormatException("Expected " + count + " values but found " + parts.Length + " in model file");
      double[] values = new double[count];
      for (int i = 0; i < count; i++)
        values[i] = double.Parse(parts[i]);
      return values;
    }

    /// <summary>
    /// Parses the first <paramref name="count"/> whitespace-separated integers of <paramref name="text"/>.
    /// </summary>
    private static int[] ParseInts(string text, int count) {
      string[] parts = SplitFields(text);
      if (parts.Length < count)
        throw new FormatException("Expected " + count + " values but found " + parts.Length + " in model file");
      int[] values = new int[count];
      for (int i = 0; i < count; i++)
        values[i] = int.Parse(parts[i]);
      return values;
    }

    /// <summary>
    /// Number of unordered class pairs, i.e. the number of entries on the rho/probA/probB lines.
    /// </summary>
    private static int GetPairCount(int numberOfClasses) {
      return numberOfClasses * (numberOfClasses - 1) / 2;
    }

    /// <summary>
    /// Writes a model to the provided filename.  This will overwrite any previous data in the file.
    /// </summary>
    /// <param name="filename">The desired file</param>
    /// <param name="model">The Model to write</param>
    public static void Write(string filename, Model model) {
      using (FileStream stream = File.Open(filename, FileMode.Create)) {
        Write(stream, model);
      }
    }

    /// <summary>
    /// Writes a model to the provided stream. The stream is flushed but left open.
    /// </summary>
    /// <param name="stream">The output stream</param>
    /// <param name="model">The model to write</param>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> or <paramref name="model"/> is null.</exception>
    public static void Write(Stream stream, Model model) {
      if (model == null)
        throw new ArgumentNullException("model");

      TemporaryCulture.Start();
      try {
        // Not disposed on purpose: disposing the writer would close the caller's stream, so flush instead.
        StreamWriter output = new StreamWriter(stream);
        WriteHeader(output, model);
        WriteSupportVectors(output, model);
        output.Flush();
      }
      finally {
        // Restore the culture even if writing fails.
        TemporaryCulture.Stop();
      }
    }

    /// <summary>
    /// Writes the header lines, including the terminating "SV" line, in the order <see cref="ReadHeader"/> accepts.
    /// </summary>
    private static void WriteHeader(TextWriter output, Model model) {
      Parameter param = model.Parameter;

      output.Write("svm_type " + param.SvmType + Environment.NewLine);
      output.Write("kernel_type " + param.KernelType + Environment.NewLine);

      // degree is read back with int.Parse, so it is written in plain form. The round-trip "r" specifier
      // is only valid for floating-point types and throws FormatException for integers.
      if (param.KernelType == KernelType.POLY)
        output.Write("degree " + param.Degree + Environment.NewLine);

      if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.RBF || param.KernelType == KernelType.SIGMOID)
        output.Write("gamma " + param.Gamma.ToString("r") + Environment.NewLine);

      if (param.KernelType == KernelType.POLY || param.KernelType == KernelType.SIGMOID)
        output.Write("coef0 " + param.Coefficient0.ToString("r") + Environment.NewLine);

      int nr_class = model.NumberOfClasses;
      int pairCount = GetPairCount(nr_class);
      output.Write("nr_class " + nr_class + Environment.NewLine);
      output.Write("total_sv " + model.SupportVectorCount + Environment.NewLine);

      output.Write("rho");
      for (int i = 0; i < pairCount; i++)
        output.Write(" " + model.Rho[i].ToString("r"));
      output.Write(Environment.NewLine);

      if (model.ClassLabels != null) {
        output.Write("label");
        for (int i = 0; i < nr_class; i++)
          output.Write(" " + model.ClassLabels[i]);
        output.Write(Environment.NewLine);
      }

      // regression has probA only
      if (model.PairwiseProbabilityA != null) {
        output.Write("probA");
        for (int i = 0; i < pairCount; i++)
          output.Write(" " + model.PairwiseProbabilityA[i].ToString("r"));
        output.Write(Environment.NewLine);
      }
      if (model.PairwiseProbabilityB != null) {
        output.Write("probB");
        for (int i = 0; i < pairCount; i++)
          output.Write(" " + model.PairwiseProbabilityB[i].ToString("r"));
        output.Write(Environment.NewLine);
      }

      if (model.NumberOfSVPerClass != null) {
        output.Write("nr_sv");
        // NumberOfSVPerClass is int[]: written plainly, because "r" would throw FormatException here.
        for (int i = 0; i < nr_class; i++)
          output.Write(" " + model.NumberOfSVPerClass[i]);
        output.Write(Environment.NewLine);
      }

      output.WriteLine("SV");
    }

    /// <summary>
    /// Writes one line per support vector: its <c>NumberOfClasses - 1</c> coefficients, then its
    /// "index:value" nodes.
    /// </summary>
    private static void WriteSupportVectors(TextWriter output, Model model) {
      int nr_class = model.NumberOfClasses;
      int l = model.SupportVectorCount;
      double[][] sv_coef = model.SupportVectorCoefficients;
      Node[][] SV = model.SupportVectors;
      bool precomputed = model.Parameter.KernelType == KernelType.PRECOMPUTED;

      for (int i = 0; i < l; i++) {
        for (int j = 0; j < nr_class - 1; j++)
          output.Write(sv_coef[j][i].ToString("r") + " ");

        Node[] p = SV[i];
        if (p.Length > 0) {
          if (precomputed) {
            // Only the first node is written, as "0:<value truncated to int>".
            output.Write("0:{0}", (int)p[0].Value);
          } else {
            output.Write("{0}:{1}", p[0].Index, p[0].Value.ToString("r"));
            for (int j = 1; j < p.Length; j++)
              output.Write(" {0}:{1}", p[j].Index, p[j].Value.ToString("r"));
          }
        }
        output.WriteLine();
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.