Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HiveHiveEngine/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/Prediction.cs @ 7317

Last change on this file since 7317 was 4068, checked in by swagner, 14 years ago

Sorted usings and removed unused usings in entire solution (#1094)

File size: 6.2 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.Diagnostics;
22using System.IO;
23
namespace SVM {
  /// <summary>
  /// Class containing the routines to perform class membership prediction using a trained SVM.
  /// </summary>
  public static class Prediction {
    // Format provider for everything written to prediction output files and for numeric
    // command-line values. Using the invariant culture keeps the output in the "0.5" form
    // regardless of the machine locale (a culture with a decimal comma would otherwise
    // produce "0,5").
    private static readonly IFormatProvider FileFormat = System.Globalization.CultureInfo.InvariantCulture;

    /// <summary>
    /// Predicts the class memberships of all the vectors in the problem.
    /// </summary>
    /// <remarks>
    /// When <paramref name="predict_probability"/> is set:
    /// for EPSILON_SVR / NU_SVR models, the probability model parameter (sigma) is printed
    /// to the console. For all other model types, a "labels" header line is written to the
    /// output file. For C_SVC / NU_SVC models, each output line is followed by the
    /// per-class probability estimates.
    /// </remarks>
    /// <param name="problem">The SVM Problem to solve</param>
    /// <param name="outputFile">File for result output, or null to skip writing predictions</param>
    /// <param name="model">The Model to use</param>
    /// <param name="predict_probability">Whether to output a distribution over the classes</param>
    /// <returns>
    /// Fraction (between 0 and 1) of vectors whose predicted value exactly equals the target;
    /// NaN when the problem contains no vectors.
    /// </returns>
    public static double Predict(
        Problem problem,
        string outputFile,
        Model model,
        bool predict_probability) {
      int correct = 0;
      int total = 0;

      // A null resource is permitted in a using statement (Dispose is then skipped), so the
      // "no output file" case needs no special handling. The writer is now disposed even if a
      // prediction call throws.
      using (StreamWriter output = outputFile != null ? new StreamWriter(outputFile) : null) {
        SvmType svm_type = Procedures.svm_get_svm_type(model);
        int nr_class = Procedures.svm_get_nr_class(model);
        double[] prob_estimates = null;

        if (predict_probability) {
          if (svm_type == SvmType.EPSILON_SVR || svm_type == SvmType.NU_SVR) {
            Console.WriteLine("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=" + Procedures.svm_get_svr_probability(model));
          } else {
            int[] labels = new int[nr_class];
            Procedures.svm_get_labels(model, labels);
            prob_estimates = new double[nr_class];
            if (output != null) {
              output.Write("labels");
              for (int j = 0; j < nr_class; j++) {
                output.Write(" " + labels[j].ToString(FileFormat));
              }
              output.Write("\n");
            }
          }
        }

        for (int i = 0; i < problem.Count; i++) {
          double target = problem.Y[i];
          Node[] x = problem.X[i];

          double v;
          if (predict_probability && (svm_type == SvmType.C_SVC || svm_type == SvmType.NU_SVC)) {
            v = Procedures.svm_predict_probability(model, x, prob_estimates);
            if (output != null) {
              output.Write(v.ToString(FileFormat) + " ");
              for (int j = 0; j < nr_class; j++) {
                output.Write(prob_estimates[j].ToString(FileFormat) + " ");
              }
              output.Write("\n");
            }
          } else {
            v = Procedures.svm_predict(model, x);
            if (output != null)
              output.Write(v.ToString(FileFormat) + "\n");
          }

          // Exact equality: presumably meaningful for classification, where predictions are
          // class labels; for regression models this count is rarely informative.
          if (v == target)
            ++correct;
          ++total;
        }
      }
      return (double)correct / total;
    }

    /// <summary>
    /// Predict the class for a single input vector.
    /// </summary>
    /// <param name="model">The Model to use for prediction</param>
    /// <param name="x">The vector for which to predict class</param>
    /// <returns>The value returned by <c>Procedures.svm_predict</c> for <paramref name="x"/></returns>
    public static double Predict(Model model, Node[] x) {
      return Procedures.svm_predict(model, x);
    }

    /// <summary>
    /// Predicts a class distribution for the single input vector.
    /// </summary>
    /// <param name="model">Model to use for prediction; must be a C_SVC or NU_SVC model</param>
    /// <param name="x">The vector for which to predict the class distribution</param>
    /// <returns>A probability distribution over classes, one entry per class of the model</returns>
    /// <exception cref="ArgumentException">The model type cannot produce probability estimates.</exception>
    public static double[] PredictProbability(Model model, Node[] x) {
      SvmType svm_type = Procedures.svm_get_svm_type(model);
      if (svm_type != SvmType.C_SVC && svm_type != SvmType.NU_SVC)
        throw new ArgumentException("Model type " + svm_type + " unable to predict probabilities.", "model");
      int nr_class = Procedures.svm_get_nr_class(model);
      double[] probEstimates = new double[nr_class];
      Procedures.svm_predict_probability(model, x, probEstimates);
      return probEstimates;
    }

    /// <summary>
    /// Legacy method, provided to allow usage as though this were the command line version of libsvm.
    /// </summary>
    /// <remarks>
    /// Expected form: <c>[-b 0|1] test_file model_file output_file</c>. Options are
    /// "-x value" pairs, and the first argument that does not start with '-' ends option
    /// parsing. Only the character after '-' is inspected, so "-b" and "-bx" are treated alike.
    /// </remarks>
    /// <param name="args">Standard arguments passed to the svm_predict executable.  See libsvm documentation for details.</param>
    /// <exception cref="ArgumentException">
    /// An option is unknown or lacks its value, or the three file names are not all provided.
    /// </exception>
    [Obsolete("Use the other version of Predict() instead")]
    public static void Predict(params string[] args) {
      if (args == null)
        throw new ArgumentNullException("args");

      bool predictProbability = false;
      int i;

      // parse options
      for (i = 0; i < args.Length; i++) {
        string option = args[i];
        if (string.IsNullOrEmpty(option) || option[0] != '-')
          break;
        if (option.Length < 2)
          throw new ArgumentException("Unknown option");
        ++i; // advance to the option's value
        if (i >= args.Length)
          throw new ArgumentException("Missing value for option " + option);

        switch (option[1]) {
          case 'b':
            predictProbability = int.Parse(args[i], FileFormat) == 1;
            break;

          default:
            throw new ArgumentException("Unknown option");
        }
      }

      // Exactly three positional arguments are consumed below; previously only the first
      // was checked, so one or two file names crashed with IndexOutOfRangeException.
      if (args.Length - i < 3)
        throw new ArgumentException("No input, model and output files provided");

      Problem problem = Problem.Read(args[i]);
      Model model = Model.Read(args[i + 1]);
      Predict(problem, args[i + 2], model, predictProbability);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.