Free cookie consent management tool by TermsFeed Policy Generator

source: branches/CEDMA-Exporter-715/sources/LibSVM/Prediction.cs @ 3026

Last change on this file since 3026 was 1819, checked in by mkommend, 16 years ago

created new project for LibSVM source files (ticket #619)

File size: 7.4 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.IO;
22using System.Diagnostics;
23
24namespace SVM
25{
26    /// <remarks>
27    /// Class containing the routines to perform class membership prediction using a trained SVM.
28    /// </remarks>
29    public static class Prediction
30    {
31        /// <summary>
32        /// Predicts the class memberships of all the vectors in the problem.
33        /// </summary>
34        /// <param name="problem">The SVM Problem to solve</param>
35        /// <param name="outputFile">File for result output</param>
36        /// <param name="model">The Model to use</param>
37        /// <param name="predict_probability">Whether to output a distribution over the classes</param>
38        /// <returns>Percentage correctly labelled</returns>
39        public static double Predict(
40            Problem problem,
41            string outputFile,
42            Model model,
43            bool predict_probability)
44        {
45            int correct = 0;
46            int total = 0;
47            double error = 0;
48            double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
49            StreamWriter output = outputFile != null ? new StreamWriter(outputFile) : null;
50
51            SvmType svm_type = Procedures.svm_get_svm_type(model);
52            int nr_class = Procedures.svm_get_nr_class(model);
53            int[] labels = new int[nr_class];
54            double[] prob_estimates = null;
55
56            if (predict_probability)
57            {
58                if (svm_type == SvmType.EPSILON_SVR || svm_type == SvmType.NU_SVR)
59                {
60                    Console.WriteLine("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=" + Procedures.svm_get_svr_probability(model));
61                }
62                else
63                {
64                    Procedures.svm_get_labels(model, labels);
65                    prob_estimates = new double[nr_class];
66                    if (output != null)
67                    {
68                        output.Write("labels");
69                        for (int j = 0; j < nr_class; j++)
70                        {
71                            output.Write(" " + labels[j]);
72                        }
73                        output.Write("\n");
74                    }
75                }
76            }
77            for (int i = 0; i < problem.Count; i++)
78            {
79                double target = problem.Y[i];
80                Node[] x = problem.X[i];
81
82                double v;
83                if (predict_probability && (svm_type == SvmType.C_SVC || svm_type == SvmType.NU_SVC))
84                {
85                    v = Procedures.svm_predict_probability(model, x, prob_estimates);
86                    if (output != null)
87                    {
88                        output.Write(v + " ");
89                        for (int j = 0; j < nr_class; j++)
90                        {
91                            output.Write(prob_estimates[j] + " ");
92                        }
93                        output.Write("\n");
94                    }
95                }
96                else
97                {
98                    v = Procedures.svm_predict(model, x);
99                    if(output != null)
100                        output.Write(v + "\n");
101                }
102
103                if (v == target)
104                    ++correct;
105                error += (v - target) * (v - target);
106                sumv += v;
107                sumy += target;
108                sumvv += v * v;
109                sumyy += target * target;
110                sumvy += v * target;
111                ++total;
112            }
113            if(output != null)
114                output.Close();
115            return (double)correct / total;
116        }
117
118        /// <summary>
119        /// Predict the class for a single input vector.
120        /// </summary>
121        /// <param name="model">The Model to use for prediction</param>
122        /// <param name="x">The vector for which to predict class</param>
123        /// <returns>The result</returns>
124        public static double Predict(Model model, Node[] x)
125        {
126            return Procedures.svm_predict(model, x);
127        }
128
129        /// <summary>
130        /// Predicts a class distribution for the single input vector.
131        /// </summary>
132        /// <param name="model">Model to use for prediction</param>
133        /// <param name="x">The vector for which to predict the class distribution</param>
134        /// <returns>A probability distribtion over classes</returns>
135        public static double[] PredictProbability(Model model, Node[] x)
136        {
137            SvmType svm_type = Procedures.svm_get_svm_type(model);
138            if (svm_type != SvmType.C_SVC && svm_type != SvmType.NU_SVC)
139                throw new Exception("Model type " + svm_type + " unable to predict probabilities.");
140            int nr_class = Procedures.svm_get_nr_class(model);
141            double[] probEstimates = new double[nr_class];
142            Procedures.svm_predict_probability(model, x, probEstimates);
143            return probEstimates;
144        }
145
146        private static void exit_with_help()
147        {
148            Debug.Write("usage: svm_predict [options] test_file model_file output_file\n" + "options:\n" + "-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n");
149            Environment.Exit(1);
150        }
151
152        /// <summary>
153        /// Legacy method, provided to allow usage as though this were the command line version of libsvm.
154        /// </summary>
155        /// <param name="args">Standard arguments passed to the svm_predict exectutable.  See libsvm documentation for details.</param>
156        [Obsolete("Use the other version of Predict() instead")]
157        public static void Predict(params string[] args)
158        {
159            int i = 0;
160            bool predictProbability = false;
161
162            // parse options
163            for (i = 0; i < args.Length; i++)
164            {
165                if (args[i][0] != '-')
166                    break;
167                ++i;
168                switch (args[i - 1][1])
169                {
170
171                    case 'b':
172                        predictProbability = int.Parse(args[i]) == 1;
173                        break;
174
175                    default:
176                        throw new ArgumentException("Unknown option");
177
178                }
179            }
180            if (i >= args.Length)
181                throw new ArgumentException("No input, model and output files provided");
182
183            Problem problem = Problem.Read(args[i]);
184            Model model = Model.Read(args[i + 1]);
185            Predict(problem, args[i + 2], model, predictProbability);
186        }
187    }
188}
Note: See TracBrowser for help on using the repository browser.