1  /*


2  * SVM.NET Library


3  * Copyright (C) 2008 Matthew Johnson


4  *


5  * This program is free software: you can redistribute it and/or modify


6  * it under the terms of the GNU General Public License as published by


7  * the Free Software Foundation, either version 3 of the License, or


8  * (at your option) any later version.


9  *


10  * This program is distributed in the hope that it will be useful,


11  * but WITHOUT ANY WARRANTY; without even the implied warranty of


12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


13  * GNU General Public License for more details.


14  *


15  * You should have received a copy of the GNU General Public License


16  * along with this program. If not, see <http://www.gnu.org/licenses/>.


17  */


18 


19 


20 


21  using System;


22  using System.IO;


23 


namespace SVM {
  /// <summary>
  /// Encapsulates an SVM Model: the training parameters, the support vectors with their
  /// coefficients, and the per-class / per-class-pair tables read from or written to a
  /// text model file (header lines of the form "key value..." followed by an "SV" section).
  /// </summary>
  [Serializable]
  public class Model {
    // The backing fields are kept explicit (rather than auto-properties) because the class is
    // [Serializable]: field names are part of the serialized form and must not change.
    private Parameter _parameter;
    private int _numberOfClasses;
    private int _supportVectorCount;
    private Node[][] _supportVectors;
    private double[][] _supportVectorCoefficients;
    private double[] _rho;
    private double[] _pairwiseProbabilityA;
    private double[] _pairwiseProbabilityB;

    private int[] _classLabels;
    private int[] _numberOfSVPerClass;

    internal Model() {
    }

    /// <summary>
    /// Parameter object.
    /// </summary>
    public Parameter Parameter {
      get { return _parameter; }
      set { _parameter = value; }
    }

    /// <summary>
    /// Number of classes in the model.
    /// </summary>
    public int NumberOfClasses {
      get { return _numberOfClasses; }
      set { _numberOfClasses = value; }
    }

    /// <summary>
    /// Total number of support vectors.
    /// </summary>
    public int SupportVectorCount {
      get { return _supportVectorCount; }
      set { _supportVectorCount = value; }
    }

    /// <summary>
    /// The support vectors; one <see cref="Node"/> array per support vector.
    /// </summary>
    public Node[][] SupportVectors {
      get { return _supportVectors; }
      set { _supportVectors = value; }
    }

    /// <summary>
    /// The coefficients for the support vectors, indexed as [coefficient row][support vector].
    /// <see cref="Read(TextReader)"/> allocates NumberOfClasses - 1 rows.
    /// </summary>
    public double[][] SupportVectorCoefficients {
      get { return _supportVectorCoefficients; }
      set { _supportVectorCoefficients = value; }
    }

    /// <summary>
    /// Rho values; read and written with NumberOfClasses * (NumberOfClasses - 1) / 2 entries.
    /// </summary>
    public double[] Rho {
      get { return _rho; }
      set { _rho = value; }
    }

    /// <summary>
    /// First pairwise probability; null when the model file has no "probA" line.
    /// </summary>
    public double[] PairwiseProbabilityA {
      get { return _pairwiseProbabilityA; }
      set { _pairwiseProbabilityA = value; }
    }

    /// <summary>
    /// Second pairwise probability; null when the model file has no "probB" line.
    /// </summary>
    public double[] PairwiseProbabilityB {
      get { return _pairwiseProbabilityB; }
      set { _pairwiseProbabilityB = value; }
    }

    // for classification only

    /// <summary>
    /// Class labels; null when the model file has no "label" line.
    /// </summary>
    public int[] ClassLabels {
      get { return _classLabels; }
      set { _classLabels = value; }
    }

    /// <summary>
    /// Number of support vectors per class; null when the model file has no "nr_sv" line.
    /// </summary>
    public int[] NumberOfSVPerClass {
      get { return _numberOfSVPerClass; }
      set { _numberOfSVPerClass = value; }
    }

    /// <summary>
    /// Reads a Model from the provided file.
    /// </summary>
    /// <param name="filename">The name of the file containing the Model</param>
    /// <returns>the Model</returns>
    public static Model Read(string filename) {
      // The using statement closes the file even when parsing throws.
      using (FileStream input = File.OpenRead(filename)) {
        return Read(input);
      }
    }

    /// <summary>
    /// Reads a Model from the provided stream.
    /// </summary>
    /// <param name="stream">The stream from which to read the Model. It is left open.</param>
    /// <returns>the Model</returns>
    public static Model Read(Stream stream) {
      // The StreamReader is intentionally not disposed: disposing it would also close the
      // caller's stream.
      return Read(new StreamReader(stream));
    }

    /// <summary>
    /// Reads a Model from the provided text reader.
    /// </summary>
    /// <param name="input">The reader from which to read the Model.</param>
    /// <returns>the Model</returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    /// <exception cref="InvalidDataException">
    /// The text ends prematurely, contains an unknown header keyword, or a line has fewer
    /// values than the header announces.
    /// </exception>
    public static Model Read(TextReader input) {
      if (input == null) {
        throw new ArgumentNullException("input");
      }

      // NOTE(review): TemporaryCulture is defined elsewhere; presumably it switches the thread
      // to a fixed culture so that number parsing is locale independent - confirm.
      TemporaryCulture.Start();
      try {
        Model model = new Model();
        model.Parameter = new Parameter();

        ReadHeader(input, model);
        ReadSupportVectors(input, model);

        return model;
      }
      finally {
        // Always undo the culture switch, including when parsing fails.
        TemporaryCulture.Stop();
      }
    }

    /// <summary>
    /// Consumes the "key value..." header lines up to and including the "SV" marker line and
    /// stores the parsed values in <paramref name="model"/> and its Parameter object.
    /// </summary>
    private static void ReadHeader(TextReader input, Model model) {
      Parameter param = model.Parameter;

      bool headerFinished = false;
      while (!headerFinished) {
        string line = input.ReadLine();
        if (line == null) {
          throw new InvalidDataException("Unexpected end of model file while reading the header");
        }

        // The keyword is everything before the first space; the remainder is its argument.
        string cmd, arg;
        int splitIndex = line.IndexOf(' ');
        if (splitIndex >= 0) {
          cmd = line.Substring(0, splitIndex);
          arg = line.Substring(splitIndex + 1);
        } else {
          cmd = line;
          arg = "";
        }
        // arg is deliberately not lower-cased: that would corrupt NaN / Infinity values [gkronber]

        // One entry per unordered pair of classes; uses the nr_class value seen so far.
        int pairCount = model.NumberOfClasses * (model.NumberOfClasses - 1) / 2;

        switch (cmd) {
          case "svm_type":
            param.SvmType = (SvmType)Enum.Parse(typeof(SvmType), arg.ToUpper());
            break;

          case "kernel_type":
            param.KernelType = (KernelType)Enum.Parse(typeof(KernelType), arg.ToUpper());
            break;

          case "degree":
            param.Degree = int.Parse(arg);
            break;

          case "gamma":
            param.Gamma = double.Parse(arg);
            break;

          case "coef0":
            param.Coefficient0 = double.Parse(arg);
            break;

          case "nr_class":
            model.NumberOfClasses = int.Parse(arg);
            if (model.NumberOfClasses < 0) {
              throw new InvalidDataException("Negative nr_class in model file");
            }
            break;

          case "total_sv":
            model.SupportVectorCount = int.Parse(arg);
            if (model.SupportVectorCount < 0) {
              throw new InvalidDataException("Negative total_sv in model file");
            }
            break;

          case "rho":
            model.Rho = ParseDoubles(arg, pairCount, cmd);
            break;

          case "label":
            model.ClassLabels = ParseInts(arg, model.NumberOfClasses, cmd);
            break;

          case "probA":
            model.PairwiseProbabilityA = ParseDoubles(arg, pairCount, cmd);
            break;

          case "probB":
            model.PairwiseProbabilityB = ParseDoubles(arg, pairCount, cmd);
            break;

          case "nr_sv":
            model.NumberOfSVPerClass = ParseInts(arg, model.NumberOfClasses, cmd);
            break;

          case "SV":
            headerFinished = true;
            break;

          default:
            // InvalidDataException derives from Exception, so callers that caught the
            // previously thrown plain Exception keep working.
            throw new InvalidDataException("Unknown text in model file");
        }
      }
    }

    /// <summary>
    /// Reads SupportVectorCount lines, each holding NumberOfClasses - 1 coefficients followed
    /// by zero or more "index:value" nodes, into the model's coefficient and vector arrays.
    /// </summary>
    private static void ReadSupportVectors(TextReader input, Model model) {
      if (model.NumberOfClasses < 1) {
        throw new InvalidDataException("Model file does not declare a valid nr_class");
      }

      int m = model.NumberOfClasses - 1;
      int l = model.SupportVectorCount;

      double[][] coefficients = new double[m][];
      for (int i = 0; i < m; i++) {
        coefficients[i] = new double[l];
      }
      Node[][] vectors = new Node[l][];

      for (int i = 0; i < l; i++) {
        string line = input.ReadLine();
        if (line == null) {
          throw new InvalidDataException("Unexpected end of model file while reading support vectors");
        }

        string[] parts = SplitTokens(line);
        if (parts.Length < m) {
          throw new InvalidDataException("Support vector line has fewer coefficients than expected");
        }

        for (int k = 0; k < m; k++) {
          coefficients[k][i] = double.Parse(parts[k]);
        }

        // Everything after the coefficients is an "index:value" node.
        int n = parts.Length - m;
        Node[] vector = new Node[n];
        for (int j = 0; j < n; j++) {
          string[] nodeParts = parts[m + j].Split(':');
          if (nodeParts.Length < 2) {
            throw new InvalidDataException("Malformed support vector node (expected index:value)");
          }
          Node node = new Node();
          node.Index = int.Parse(nodeParts[0]);
          node.Value = double.Parse(nodeParts[1]);
          vector[j] = node;
        }
        vectors[i] = vector;
      }

      model.SupportVectorCoefficients = coefficients;
      model.SupportVectors = vectors;
    }

    /// <summary>
    /// Splits a line on whitespace, dropping empty tokens so that repeated or trailing
    /// whitespace does not produce unparsable entries.
    /// </summary>
    private static string[] SplitTokens(string text) {
      return text.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
    }

    /// <summary>
    /// Parses the first <paramref name="count"/> whitespace-separated doubles of a header argument.
    /// </summary>
    private static double[] ParseDoubles(string arg, int count, string key) {
      string[] tokens = SplitTokens(arg);
      if (tokens.Length < count) {
        throw new InvalidDataException("Too few values for '" + key + "' in model file");
      }
      double[] values = new double[count];
      for (int i = 0; i < count; i++) {
        values[i] = double.Parse(tokens[i]);
      }
      return values;
    }

    /// <summary>
    /// Parses the first <paramref name="count"/> whitespace-separated integers of a header argument.
    /// </summary>
    private static int[] ParseInts(string arg, int count, string key) {
      string[] tokens = SplitTokens(arg);
      if (tokens.Length < count) {
        throw new InvalidDataException("Too few values for '" + key + "' in model file");
      }
      int[] values = new int[count];
      for (int i = 0; i < count; i++) {
        values[i] = int.Parse(tokens[i]);
      }
      return values;
    }

    /// <summary>
    /// Writes a model to the provided filename. This will overwrite any previous data in the file.
    /// </summary>
    /// <param name="filename">The desired file</param>
    /// <param name="model">The Model to write</param>
    public static void Write(string filename, Model model) {
      // FileMode.Create truncates an existing file; the using statement closes it on any exit.
      using (FileStream stream = File.Open(filename, FileMode.Create)) {
        Write(stream, model);
      }
    }

    /// <summary>
    /// Writes a model to the provided stream. The stream is flushed but left open.
    /// </summary>
    /// <param name="stream">The output stream</param>
    /// <param name="model">The model to write</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="stream"/> or <paramref name="model"/> is null.
    /// </exception>
    public static void Write(Stream stream, Model model) {
      // Validate before touching the culture so a bad argument cannot leave it switched.
      if (stream == null) {
        throw new ArgumentNullException("stream");
      }
      if (model == null) {
        throw new ArgumentNullException("model");
      }

      // NOTE(review): TemporaryCulture is defined elsewhere; presumably it switches the thread
      // to a fixed culture so that number formatting is locale independent - confirm.
      TemporaryCulture.Start();
      try {
        // Not disposed on purpose: disposing the writer would close the caller's stream.
        StreamWriter output = new StreamWriter(stream);

        Parameter param = model.Parameter;

        output.WriteLine("svm_type " + param.SvmType);
        output.WriteLine("kernel_type " + param.KernelType);

        bool isPoly = param.KernelType == KernelType.POLY;
        bool isSigmoid = param.KernelType == KernelType.SIGMOID;

        if (isPoly) {
          // The round-trip ("r") specifier is only defined for floating-point types; calling it
          // directly on an integral Degree would throw FormatException. Going through double
          // yields the same text for a double Degree and a plain integer for an int Degree.
          output.WriteLine("degree " + ((double)param.Degree).ToString("r"));
        }

        if (isPoly || param.KernelType == KernelType.RBF || isSigmoid) {
          output.WriteLine("gamma " + param.Gamma.ToString("r"));
        }

        if (isPoly || isSigmoid) {
          output.WriteLine("coef0 " + param.Coefficient0.ToString("r"));
        }

        int nr_class = model.NumberOfClasses;
        int l = model.SupportVectorCount;
        int pairCount = nr_class * (nr_class - 1) / 2;

        output.WriteLine("nr_class " + nr_class);
        output.WriteLine("total_sv " + l);

        // Each optional table is skipped when absent, mirroring what Read leaves as null.
        if (model.Rho != null) {
          WriteDoubleLine(output, "rho", model.Rho, pairCount);
        }

        if (model.ClassLabels != null) {
          WriteIntLine(output, "label", model.ClassLabels, nr_class);
        }

        // regression has probA only
        if (model.PairwiseProbabilityA != null) {
          WriteDoubleLine(output, "probA", model.PairwiseProbabilityA, pairCount);
        }

        if (model.PairwiseProbabilityB != null) {
          WriteDoubleLine(output, "probB", model.PairwiseProbabilityB, pairCount);
        }

        if (model.NumberOfSVPerClass != null) {
          WriteIntLine(output, "nr_sv", model.NumberOfSVPerClass, nr_class);
        }

        output.WriteLine("SV");
        double[][] sv_coef = model.SupportVectorCoefficients;
        Node[][] SV = model.SupportVectors;

        for (int i = 0; i < l; i++) {
          // Coefficients first, each followed by a single space.
          for (int j = 0; j < nr_class - 1; j++) {
            output.Write(sv_coef[j][i].ToString("r") + " ");
          }

          Node[] p = SV[i];
          if (p.Length == 0) {
            output.WriteLine();
            continue;
          }

          if (param.KernelType == KernelType.PRECOMPUTED) {
            // Precomputed kernels store only a reference in the first node's value.
            output.Write("0:{0}", (int)p[0].Value);
          } else {
            output.Write("{0}:{1}", p[0].Index, p[0].Value.ToString("r"));
            for (int j = 1; j < p.Length; j++) {
              output.Write(" {0}:{1}", p[j].Index, p[j].Value.ToString("r"));
            }
          }
          output.WriteLine();
        }

        // Push buffered text into the underlying stream; the stream itself stays open.
        output.Flush();
      }
      finally {
        // Always undo the culture switch, including when writing fails.
        TemporaryCulture.Stop();
      }
    }

    /// <summary>
    /// Writes "key v0 v1 ..." using round-trip formatting for the first <paramref name="count"/> values.
    /// </summary>
    private static void WriteDoubleLine(TextWriter output, string key, double[] values, int count) {
      output.Write(key);
      for (int i = 0; i < count; i++) {
        output.Write(" " + values[i].ToString("r"));
      }
      output.WriteLine();
    }

    /// <summary>
    /// Writes "key v0 v1 ..." for the first <paramref name="count"/> integer values.
    /// </summary>
    private static void WriteIntLine(TextWriter output, string key, int[] values, int count) {
      output.Write(key);
      for (int i = 0; i < count; i++) {
        output.Write(" " + values[i]);
      }
      output.WriteLine();
    }
  }
}
