Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MPI/HeuristicLab.Algorithms.DataAnalysis/3.3/SupportVectorMachine.cs @ 6357

Last change on this file since 6357 was 5809, checked in by mkommend, 14 years ago

#1418: Reintegrated branch into trunk.

File size: 10.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using HeuristicLab.Common;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Optimization;
27using HeuristicLab.Parameters;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.PluginInfrastructure;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Problems.DataAnalysis.Evaluators;
32using HeuristicLab.Problems.DataAnalysis.Regression.SupportVectorRegression;
33using HeuristicLab.Problems.DataAnalysis.SupportVectorMachine;
34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// A support vector machine.
  /// </summary>
  /// <remarks>
  /// The algorithm executes a fixed chain of four operators: a model creator,
  /// a model evaluator, a mean-squared-error evaluator and a best-solution
  /// analyzer. The algorithm's own parameters (SVM type, kernel type, cost, nu,
  /// gamma, training range) are exposed to the operators by pointing the
  /// operators' lookup names at the algorithm parameter names.
  /// </remarks>
  [NonDiscoverableType]
  [Item("Support Vector Machine", "Support vector machine data analysis algorithm.")]
  [StorableClass]
  public sealed class SupportVectorMachine : EngineAlgorithm, IStorableContent {
    private const string TrainingSamplesStartParameterName = "Training start";
    private const string TrainingSamplesEndParameterName = "Training end";
    // NOTE(review): DataAnalysisProblemDataParameterName and ModelParameterName are
    // not referenced anywhere in this class; kept to avoid touching anything that
    // might rely on them - confirm whether they can be removed.
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";
    private const string ModelParameterName = "SupportVectorMachineModel";

    /// <summary>
    /// File name associated with this content (member of <see cref="IStorableContent"/>).
    /// </summary>
    public string Filename { get; set; }

    #region Problem Properties
    /// <summary>
    /// The problem type this algorithm works on: <see cref="DataAnalysisProblem"/>.
    /// </summary>
    public override Type ProblemType {
      get { return typeof(DataAnalysisProblem); }
    }
    /// <summary>
    /// Strongly typed view of the base class' problem. May be <c>null</c>.
    /// </summary>
    public new DataAnalysisProblem Problem {
      get { return (DataAnalysisProblem)base.Problem; }
      set { base.Problem = value; }
    }
    #endregion

    #region parameter properties
    /// <summary>The first index of the data set partition to use for training.</summary>
    public IValueParameter<IntValue> TrainingSamplesStartParameter {
      get { return (IValueParameter<IntValue>)Parameters[TrainingSamplesStartParameterName]; }
    }
    /// <summary>The last index of the data set partition to use for training.</summary>
    public IValueParameter<IntValue> TrainingSamplesEndParameter {
      get { return (IValueParameter<IntValue>)Parameters[TrainingSamplesEndParameterName]; }
    }
    /// <summary>The type of SVM to use (one of C_SVC, NU_SVC, EPSILON_SVR, NU_SVR).</summary>
    public IValueParameter<StringValue> SvmTypeParameter {
      get { return (IValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    /// <summary>The kernel type to use for the SVM (one of RBF, LINEAR, POLY, SIGMOID).</summary>
    public IValueParameter<StringValue> KernelTypeParameter {
      get { return (IValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    /// <summary>The value of the nu parameter.</summary>
    public IValueParameter<DoubleValue> NuParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    /// <summary>The value of the C (cost) parameter.</summary>
    public IValueParameter<DoubleValue> CostParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    /// <summary>The value of the gamma parameter in the kernel function.</summary>
    public IValueParameter<DoubleValue> GammaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    /// <summary>The value of the epsilon parameter (only for epsilon-SVR).</summary>
    public IValueParameter<DoubleValue> EpsilonParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    #endregion

    // The four operators of the algorithm's operator chain. They are persisted
    // and cloned together with the algorithm.
    [Storable]
    private SupportVectorMachineModelCreator solutionCreator;
    [Storable]
    private SupportVectorMachineModelEvaluator evaluator;
    [Storable]
    private SimpleMSEEvaluator mseEvaluator;
    [Storable]
    private BestSupportVectorRegressionSolutionAnalyzer analyzer;

    /// <summary>
    /// Creates a new algorithm with default parameter values (NU_SVR, RBF kernel,
    /// nu = 0.5, cost = 1.0, gamma = 1.0, epsilon = 1.0) and builds the operator chain.
    /// </summary>
    public SupportVectorMachine()
      : base() {
      #region svm types
      StringValue cSvcType = new StringValue("C_SVC").AsReadOnly();
      StringValue nuSvcType = new StringValue("NU_SVC").AsReadOnly();
      StringValue eSvrType = new StringValue("EPSILON_SVR").AsReadOnly();
      StringValue nuSvrType = new StringValue("NU_SVR").AsReadOnly();
      ItemSet<StringValue> allowedSvmTypes = new ItemSet<StringValue>();
      allowedSvmTypes.Add(cSvcType);
      allowedSvmTypes.Add(nuSvcType);
      allowedSvmTypes.Add(eSvrType);
      allowedSvmTypes.Add(nuSvrType);
      #endregion
      #region kernel types
      StringValue rbfKernelType = new StringValue("RBF").AsReadOnly();
      StringValue linearKernelType = new StringValue("LINEAR").AsReadOnly();
      StringValue polynomialKernelType = new StringValue("POLY").AsReadOnly();
      StringValue sigmoidKernelType = new StringValue("SIGMOID").AsReadOnly();
      ItemSet<StringValue> allowedKernelTypes = new ItemSet<StringValue>();
      allowedKernelTypes.Add(rbfKernelType);
      allowedKernelTypes.Add(linearKernelType);
      allowedKernelTypes.Add(polynomialKernelType);
      allowedKernelTypes.Add(sigmoidKernelType);
      #endregion
      Parameters.Add(new ValueParameter<IntValue>(TrainingSamplesStartParameterName, "The first index of the data set partition to use for training."));
      Parameters.Add(new ValueParameter<IntValue>(TrainingSamplesEndParameterName, "The last index of the data set partition to use for training."));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", allowedSvmTypes, nuSvrType));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", allowedKernelTypes, rbfKernelType));
      Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR.", new DoubleValue(0.5)));
      Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
      // NOTE(review): Epsilon is the only parameter registered as a ValueLookupParameter
      // and Initialize() does not forward it to any operator - confirm this is intended.
      Parameters.Add(new ValueLookupParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter (only for epsilon-SVR).", new DoubleValue(1.0)));

      solutionCreator = new SupportVectorMachineModelCreator();
      evaluator = new SupportVectorMachineModelEvaluator();
      mseEvaluator = new SimpleMSEEvaluator();
      analyzer = new BestSupportVectorRegressionSolutionAnalyzer();

      // Operator chain: creator, then evaluator, then MSE evaluator, then analyzer.
      OperatorGraph.InitialOperator = solutionCreator;
      solutionCreator.Successor = evaluator;
      evaluator.Successor = mseEvaluator;
      mseEvaluator.Successor = analyzer;

      Initialize();
    }

    /// <summary>
    /// Constructor used by the persistence framework during deserialization.
    /// </summary>
    [StorableConstructor]
    private SupportVectorMachine(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Re-establishes the operator wiring and event subscription after deserialization.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      Initialize();
    }

    /// <summary>
    /// Cloning constructor: clones the four operators through <paramref name="cloner"/>
    /// and re-runs <see cref="Initialize"/> on the copy.
    /// </summary>
    private SupportVectorMachine(SupportVectorMachine original, Cloner cloner)
      : base(original, cloner) {
      solutionCreator = cloner.Clone(original.solutionCreator);
      evaluator = cloner.Clone(original.evaluator);
      mseEvaluator = cloner.Clone(original.mseEvaluator);
      analyzer = cloner.Clone(original.analyzer);
      Initialize();
    }

    /// <summary>
    /// Creates a deep clone of this algorithm.
    /// </summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorMachine(this, cloner);
    }

    /// <summary>
    /// Prepares the algorithm; does nothing while no problem is set.
    /// </summary>
    public override void Prepare() {
      if (Problem != null) base.Prepare();
    }

    /// <summary>
    /// Refreshes the training range from the problem data before delegating to the base handler.
    /// </summary>
    protected override void Problem_Reset(object sender, EventArgs e) {
      UpdateAlgorithmParameters();
      base.Problem_Reset(sender, e);
    }

    #region Events
    /// <summary>
    /// Re-wires the problem-dependent operator parameters, copies the training range
    /// and subscribes to the problem's Reset event whenever the problem changes.
    /// </summary>
    protected override void OnProblemChanged() {
      // Prepare() and Initialize() both treat a null problem as a legal state,
      // so this handler must not dereference Problem unconditionally.
      if (Problem != null) {
        ParameterizeProblemDependentOperators();
        UpdateAlgorithmParameters();
        RegisterProblemResetHandler();
      }
      base.OnProblemChanged();
    }

    #endregion

    #region Helpers
    /// <summary>
    /// Points the operators' lookup names at this algorithm's parameters and,
    /// if a problem is set, at the problem's data parameter.
    /// </summary>
    private void Initialize() {
      solutionCreator.SvmTypeParameter.ActualName = SvmTypeParameter.Name;
      solutionCreator.KernelTypeParameter.ActualName = KernelTypeParameter.Name;
      solutionCreator.CostParameter.ActualName = CostParameter.Name;
      solutionCreator.GammaParameter.ActualName = GammaParameter.Name;
      solutionCreator.NuParameter.ActualName = NuParameter.Name;
      solutionCreator.SamplesStartParameter.ActualName = TrainingSamplesStartParameter.Name;
      solutionCreator.SamplesEndParameter.ActualName = TrainingSamplesEndParameter.Name;

      evaluator.SamplesStartParameter.ActualName = TrainingSamplesStartParameter.Name;
      evaluator.SamplesEndParameter.ActualName = TrainingSamplesEndParameter.Name;
      evaluator.SupportVectorMachineModelParameter.ActualName = solutionCreator.SupportVectorMachineModelParameter.ActualName;
      evaluator.ValuesParameter.ActualName = "Training values";

      // The MSE evaluator reads the values written by the model evaluator.
      mseEvaluator.ValuesParameter.ActualName = "Training values";
      mseEvaluator.MeanSquaredErrorParameter.ActualName = "Training MSE";

      analyzer.SupportVectorRegressionModelParameter.ActualName = solutionCreator.SupportVectorMachineModelParameter.ActualName;
      analyzer.SupportVectorRegressionModelParameter.Depth = 0;
      analyzer.QualityParameter.ActualName = mseEvaluator.MeanSquaredErrorParameter.ActualName;
      analyzer.QualityParameter.Depth = 0;

      if (Problem != null) {
        ParameterizeProblemDependentOperators();
        RegisterProblemResetHandler();
      }
    }

    /// <summary>
    /// Makes the creator, evaluator and analyzer look up the problem data under the
    /// name of the current problem's data parameter. Requires a non-null problem.
    /// </summary>
    private void ParameterizeProblemDependentOperators() {
      string problemDataName = Problem.DataAnalysisProblemDataParameter.Name;
      solutionCreator.DataAnalysisProblemDataParameter.ActualName = problemDataName;
      evaluator.DataAnalysisProblemDataParameter.ActualName = problemDataName;
      analyzer.ProblemDataParameter.ActualName = problemDataName;
    }

    /// <summary>
    /// Subscribes <see cref="Problem_Reset"/> to the current problem's Reset event.
    /// Requires a non-null problem.
    /// </summary>
    private void RegisterProblemResetHandler() {
      // Removing before adding keeps the subscription idempotent: this method is
      // reached from the constructor, the cloning constructor, the deserialization
      // hook and every problem change, and a plain += would stack one more handler
      // on the same problem each time, running Problem_Reset several times per reset.
      // NOTE(review): if the base class also subscribes this same handler, the
      // remove/add pair leaves that subscription count unchanged - confirm.
      Problem.Reset -= new EventHandler(Problem_Reset);
      Problem.Reset += new EventHandler(Problem_Reset);
    }

    /// <summary>
    /// Copies the training range of the problem data into this algorithm's
    /// training start/end parameters. Does nothing while no problem is set.
    /// </summary>
    private void UpdateAlgorithmParameters() {
      if (Problem == null) return;
      TrainingSamplesStartParameter.ActualValue = Problem.DataAnalysisProblemData.TrainingSamplesStart;
      TrainingSamplesEndParameter.ActualValue = Problem.DataAnalysisProblemData.TrainingSamplesEnd;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.