/* * SVM.NET Library * Copyright (C) 2008 Matthew Johnson * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ using System; using System.IO; using System.Diagnostics; namespace SVM { /// /// Class containing the routines to perform class membership prediction using a trained SVM. /// public static class Prediction { /// /// Predicts the class memberships of all the vectors in the problem. /// /// The SVM Problem to solve /// File for result output /// The Model to use /// Whether to output a distribution over the classes /// Percentage correctly labelled public static double Predict( Problem problem, string outputFile, Model model, bool predict_probability) { int correct = 0; int total = 0; double error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; StreamWriter output = outputFile != null ? new StreamWriter(outputFile) : null; SvmType svm_type = Procedures.svm_get_svm_type(model); int nr_class = Procedures.svm_get_nr_class(model); int[] labels = new int[nr_class]; double[] prob_estimates = null; if (predict_probability) { if (svm_type == SvmType.EPSILON_SVR || svm_type == SvmType.NU_SVR) { Console.WriteLine("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=" + Procedures.svm_get_svr_probability(model)); } else { Procedures.svm_get_labels(model, labels); prob_estimates = new double[nr_class]; if (output != null) { output.Write("labels"); for (int j = 0; j < nr_class; j++) { output.Write(" " + labels[j]); } output.Write("\n"); } } } for (int i = 0; i < problem.Count; i++) { double target = problem.Y[i]; Node[] x = problem.X[i]; double v; if (predict_probability && (svm_type == SvmType.C_SVC || svm_type == SvmType.NU_SVC)) { v = Procedures.svm_predict_probability(model, x, prob_estimates); if (output != null) { output.Write(v + " "); for (int j = 0; j < nr_class; j++) { output.Write(prob_estimates[j] + " "); } output.Write("\n"); } } else { v = Procedures.svm_predict(model, x); if(output != null) output.Write(v + "\n"); } if (v == target) ++correct; error += (v - target) * (v - target); sumv += v; sumy += target; sumvv += v * v; sumyy += target * target; sumvy += v * target; ++total; } if(output != null) output.Close(); return (double)correct / total; } /// /// Predict the class for a single input vector. /// /// The Model to use for prediction /// The vector for which to predict class /// The result public static double Predict(Model model, Node[] x) { return Procedures.svm_predict(model, x); } /// /// Predicts a class distribution for the single input vector. /// /// Model to use for prediction /// The vector for which to predict the class distribution /// A probability distribtion over classes public static double[] PredictProbability(Model model, Node[] x) { SvmType svm_type = Procedures.svm_get_svm_type(model); if (svm_type != SvmType.C_SVC && svm_type != SvmType.NU_SVC) throw new Exception("Model type " + svm_type + " unable to predict probabilities."); int nr_class = Procedures.svm_get_nr_class(model); double[] probEstimates = new double[nr_class]; Procedures.svm_predict_probability(model, x, probEstimates); return probEstimates; } private static void exit_with_help() { Debug.Write("usage: svm_predict [options] test_file model_file output_file\n" + "options:\n" + "-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"); Environment.Exit(1); } /// /// Legacy method, provided to allow usage as though this were the command line version of libsvm. /// /// Standard arguments passed to the svm_predict exectutable. See libsvm documentation for details. [Obsolete("Use the other version of Predict() instead")] public static void Predict(params string[] args) { int i = 0; bool predictProbability = false; // parse options for (i = 0; i < args.Length; i++) { if (args[i][0] != '-') break; ++i; switch (args[i - 1][1]) { case 'b': predictProbability = int.Parse(args[i]) == 1; break; default: throw new ArgumentException("Unknown option"); } } if (i >= args.Length) throw new ArgumentException("No input, model and output files provided"); Problem problem = Problem.Read(args[i]); Model model = Model.Read(args[i + 1]); Predict(problem, args[i + 2], model, predictProbability); } } }