Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModelCreator.cs @ 11987

Last change on this file since 11987 was 5275, checked in by gkronber, 13 years ago

Merged changes from trunk to data analysis exploration branch and added fractional distance metric evaluator. #1142

File size: 9.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Operators;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using SVM;
32
33namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
34  /// <summary>
35  /// Represents an operator that creates a support vector machine model.
36  /// </summary>
37  [StorableClass]
38  [Item("SupportVectorMachineModelCreator", "Represents an operator that creates a support vector machine model.")]
39  public sealed class SupportVectorMachineModelCreator : SingleSuccessorOperator {
40    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
41    private const string SvmTypeParameterName = "SvmType";
42    private const string KernelTypeParameterName = "KernelType";
43    private const string CostParameterName = "Cost";
44    private const string NuParameterName = "Nu";
45    private const string GammaParameterName = "Gamma";
46    private const string EpsilonParameterName = "Epsilon";
47    private const string SamplesStartParameterName = "SamplesStart";
48    private const string SamplesEndParameterName = "SamplesEnd";
49    private const string ModelParameterName = "SupportVectorMachineModel";
50
51    #region parameter properties
52    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
53      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
54    }
55    public IValueLookupParameter<StringValue> SvmTypeParameter {
56      get { return (IValueLookupParameter<StringValue>)Parameters[SvmTypeParameterName]; }
57    }
58    public IValueLookupParameter<StringValue> KernelTypeParameter {
59      get { return (IValueLookupParameter<StringValue>)Parameters[KernelTypeParameterName]; }
60    }
61    public IValueLookupParameter<DoubleValue> NuParameter {
62      get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; }
63    }
64    public IValueLookupParameter<DoubleValue> CostParameter {
65      get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; }
66    }
67    public IValueLookupParameter<DoubleValue> GammaParameter {
68      get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; }
69    }
70    public IValueLookupParameter<DoubleValue> EpsilonParameter {
71      get { return (IValueLookupParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
72    }
73    public IValueLookupParameter<IntValue> SamplesStartParameter {
74      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
75    }
76    public IValueLookupParameter<IntValue> SamplesEndParameter {
77      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
78    }
79    public ILookupParameter<SupportVectorMachineModel> SupportVectorMachineModelParameter {
80      get { return (ILookupParameter<SupportVectorMachineModel>)Parameters[ModelParameterName]; }
81    }
82    #endregion
83    #region properties
84    public DataAnalysisProblemData DataAnalysisProblemData {
85      get { return DataAnalysisProblemDataParameter.ActualValue; }
86    }
87    public StringValue SvmType {
88      get { return SvmTypeParameter.Value; }
89    }
90    public StringValue KernelType {
91      get { return KernelTypeParameter.Value; }
92    }
93    public DoubleValue Nu {
94      get { return NuParameter.ActualValue; }
95    }
96    public DoubleValue Cost {
97      get { return CostParameter.ActualValue; }
98    }
99    public DoubleValue Gamma {
100      get { return GammaParameter.ActualValue; }
101    }
102    public DoubleValue Epsilon {
103      get { return EpsilonParameter.ActualValue; }
104    }
105    public IntValue SamplesStart {
106      get { return SamplesStartParameter.ActualValue; }
107    }
108    public IntValue SamplesEnd {
109      get { return SamplesEndParameter.ActualValue; }
110    }
111    #endregion
112
113    [StorableConstructor]
114    private SupportVectorMachineModelCreator(bool deserializing) : base(deserializing) { }
115    private SupportVectorMachineModelCreator(SupportVectorMachineModelCreator original, Cloner cloner) : base(original, cloner) { }
116    public override IDeepCloneable Clone(Cloner cloner) {
117      return new SupportVectorMachineModelCreator(this, cloner);
118    }
119    public SupportVectorMachineModelCreator()
120      : base() {
121      StringValue nuSvrType = new StringValue("NU_SVR").AsReadOnly();
122      StringValue rbfKernelType = new StringValue("RBF").AsReadOnly();
123      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
124      Parameters.Add(new ValueLookupParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", nuSvrType));
125      Parameters.Add(new ValueLookupParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", rbfKernelType));
126      Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR."));
127      Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR."));
128      Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function."));
129      Parameters.Add(new ValueLookupParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR."));
130      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first index of the data set partition the support vector machine should use for training."));
131      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last index of the data set partition the support vector machine should use for training."));
132      Parameters.Add(new LookupParameter<SupportVectorMachineModel>(ModelParameterName, "The result model generated by the SVM."));
133    }
134
135    public override IOperation Apply() {
136      int start = SamplesStart.Value;
137      int end = SamplesEnd.Value;
138      IEnumerable<int> rows =
139        Enumerable.Range(start, end - start)
140        .Where(i => i < DataAnalysisProblemData.TestSamplesStart.Value || DataAnalysisProblemData.TestSamplesEnd.Value <= i);
141
142      SupportVectorMachineModel model = TrainModel(DataAnalysisProblemData,
143                             rows,
144                             SvmType.Value, KernelType.Value,
145                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value);
146      SupportVectorMachineModelParameter.ActualValue = model;
147
148      return base.Apply();
149    }
150
151    private static SupportVectorMachineModel TrainModel(
152      DataAnalysisProblemData problemData,
153      string svmType, string kernelType,
154      double cost, double nu, double gamma, double epsilon) {
155      return TrainModel(problemData, problemData.TrainingIndizes, svmType, kernelType, cost, nu, gamma, epsilon);
156    }
157
158    public static SupportVectorMachineModel TrainModel(
159      DataAnalysisProblemData problemData,
160      IEnumerable<int> trainingIndizes,
161      string svmType, string kernelType,
162      double cost, double nu, double gamma, double epsilon) {
163      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);
164
165      //extract SVM parameters from scope and set them
166      SVM.Parameter parameter = new SVM.Parameter();
167      parameter.SvmType = (SVM.SvmType)Enum.Parse(typeof(SVM.SvmType), svmType, true);
168      parameter.KernelType = (SVM.KernelType)Enum.Parse(typeof(SVM.KernelType), kernelType, true);
169      parameter.C = cost;
170      parameter.Nu = nu;
171      parameter.Gamma = gamma;
172      parameter.P = epsilon;
173      parameter.CacheSize = 500;
174      parameter.Probability = false;
175
176
177      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, trainingIndizes);
178      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
179      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);
180      var model = new SupportVectorMachineModel();
181      model.Model = SVM.Training.Train(scaledProblem, parameter);
182      model.RangeTransform = rangeTransform;
183
184      return model;
185    }
186  }
187}
Note: See TracBrowser for help on using the repository browser.