Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.SupportVectorMachines/3.2/Predictor.cs @ 2290

Last change on this file since 2290 was 2290, checked in by gkronber, 15 years ago
  • introduced a variable-name-to-index mapping for SVM models (to make sure the model can be used for prediction in the model analyzer)
  • added support to enable and disable algorithms in the dispatcher and removed DispatcherBase
  • fixed bugs when calculating variable impacts and reading the final model of GP algorithms

#722 (IModel should provide a Predict() method to get predicted values for an input vector)

File size: 4.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Text;
25using System.Xml;
26using HeuristicLab.Core;
27using System.Globalization;
28using System.IO;
29using HeuristicLab.Modeling;
30using SVM;
31using HeuristicLab.DataAnalysis;
32
namespace HeuristicLab.SupportVectorMachines {
  /// <summary>
  /// Wraps a trained <see cref="SVMModel"/> together with the name of the target variable and a
  /// mapping from input-variable names to the column indices those variables had during training.
  /// The mapping lets <see cref="Predict"/> be applied to a dataset whose columns are not in the
  /// same order as the training data.
  /// </summary>
  public class Predictor : ItemBase, IPredictor {
    private SVMModel svmModel;
    // variable name -> column index the variable had in the data the model was trained on
    private Dictionary<string, int> variableNames = new Dictionary<string, int>();
    private string targetVariable;

    /// <summary>
    /// Parameterless constructor used by the persistence mechanism; the state is restored
    /// afterwards by <see cref="Populate"/>.
    /// </summary>
    public Predictor() : base() { } // for persistence

    /// <summary>
    /// Creates a predictor for a trained SVM model.
    /// </summary>
    /// <param name="model">The trained SVM model (including its range transform).</param>
    /// <param name="targetVariable">Name of the variable the model predicts.</param>
    /// <param name="variableNames">
    /// Maps each input-variable name to the column index it had during training.
    /// The dictionary is stored by reference, not copied.
    /// </param>
    public Predictor(SVMModel model, string targetVariable, Dictionary<string, int> variableNames)
      : base() {
      this.svmModel = model;
      this.targetVariable = targetVariable;
      this.variableNames = variableNames;
    }

    /// <summary>
    /// Computes model predictions for the rows in the half-open range [start, end) of <paramref name="input"/>.
    /// </summary>
    /// <param name="input">
    /// Dataset to predict on; its variables are looked up by name, so it is expected to contain the
    /// training input variables and the target variable (column order may differ from training).
    /// </param>
    /// <param name="start">Index of the first row to predict (inclusive, must be &gt;= 0).</param>
    /// <param name="end">Index one past the last row to predict (exclusive, must be &lt;= input.Rows).</param>
    /// <returns>An array of length <c>end - start</c> with one prediction per row.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="start"/> is negative or not strictly smaller than <paramref name="end"/>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="end"/> exceeds the number of rows of <paramref name="input"/>.</exception>
    public double[] Predict(Dataset input, int start, int end) {
      if (input == null) throw new ArgumentNullException("input");
      if (start < 0 || end <= start)
        throw new ArgumentException("start must be larger than or equal to zero and strictly smaller than end", "start");
      // Note: ArgumentOutOfRangeException(string) takes the parameter *name*; the message goes second.
      if (end > input.Rows)
        throw new ArgumentOutOfRangeException("end", "number of rows in input is smaller than end");

      RangeTransform transform = svmModel.RangeTransform;
      Model model = svmModel.Model;

      // maps columns of the current input dataset to the columns that were originally used in training
      Dictionary<int, int> newIndex = new Dictionary<int, int>();
      foreach (var pair in variableNames) {
        newIndex[input.GetVariableIndex(pair.Key)] = pair.Value;
      }

      Problem p = SVMHelper.CreateSVMProblem(input, input.GetVariableIndex(targetVariable), newIndex, start, end);
      // scale the inputs with the range transform stored in the model before predicting
      Problem scaledProblem = SVM.Scaling.Scale(p, transform);

      int rows = end - start;
      double[] result = new double[rows];
      for (int row = 0; row < rows; row++) {
        result[row] = SVM.Prediction.Predict(model, scaledProblem.X[row]);
      }
      return result;
    }

    /// <summary>
    /// Returns the view of the wrapped SVM model.
    /// </summary>
    public override IView CreateView() {
      return svmModel.CreateView();
    }

    /// <summary>
    /// Clones the predictor: the SVM model is cloned through <c>Auxiliary.Clone</c> (honouring
    /// <paramref name="clonedObjects"/>), and the variable mapping is copied into a new dictionary.
    /// </summary>
    public override object Clone(IDictionary<Guid, object> clonedObjects) {
      Predictor clone = (Predictor)base.Clone(clonedObjects);
      clone.svmModel = (SVMModel)Auxiliary.Clone(svmModel, clonedObjects);
      clone.targetVariable = targetVariable;
      clone.variableNames = new Dictionary<string, int>(variableNames);
      return clone;
    }

    /// <summary>
    /// Serializes the predictor. Layout of the produced node:
    /// a <c>TargetVariable</c> attribute; first child = the persisted SVM model;
    /// second child = a <c>Variables</c> element holding one <c>Variable</c> element
    /// (attributes <c>Name</c> and <c>Index</c>) per entry of the variable mapping.
    /// <see cref="Populate"/> depends on this child order.
    /// </summary>
    public override XmlNode GetXmlNode(string name, XmlDocument document, IDictionary<Guid, IStorable> persistedObjects) {
      XmlNode node = base.GetXmlNode(name, document, persistedObjects);
      XmlAttribute targetVarAttr = document.CreateAttribute("TargetVariable");
      targetVarAttr.Value = targetVariable;
      node.Attributes.Append(targetVarAttr);
      node.AppendChild(PersistenceManager.Persist(svmModel, document, persistedObjects));
      XmlNode variablesNode = document.CreateElement("Variables");
      foreach (var pair in variableNames) {
        XmlNode pairNode = document.CreateElement("Variable");
        XmlAttribute nameAttr = document.CreateAttribute("Name");
        XmlAttribute indexAttr = document.CreateAttribute("Index");
        nameAttr.Value = pair.Key;
        // XmlConvert keeps the integer culture-invariant
        indexAttr.Value = XmlConvert.ToString(pair.Value);
        pairNode.Attributes.Append(nameAttr);
        pairNode.Attributes.Append(indexAttr);
        variablesNode.AppendChild(pairNode);
      }
      node.AppendChild(variablesNode);
      return node;
    }

    /// <summary>
    /// Restores the predictor from a node written by <see cref="GetXmlNode"/>; the SVM model is read
    /// from the first child node and the variable mapping from the second.
    /// </summary>
    public override void Populate(XmlNode node, IDictionary<Guid, IStorable> restoredObjects) {
      base.Populate(node, restoredObjects);
      targetVariable = node.Attributes["TargetVariable"].Value;
      svmModel = (SVMModel)PersistenceManager.Restore(node.ChildNodes[0], restoredObjects);

      variableNames = new Dictionary<string, int>();
      XmlNode variablesNode = node.ChildNodes[1];
      foreach (XmlNode pairNode in variablesNode.ChildNodes) {
        variableNames[pairNode.Attributes["Name"].Value] = XmlConvert.ToInt32(pairNode.Attributes["Index"].Value);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.