Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorRegression.cs @ 5649

Last change on this file since 5649 was 5649, checked in by gkronber, 13 years ago

#1418 Implemented classes for classification based on a discriminant function and thresholds and implemented interfaces and base classes for clustering.

File size: 7.6 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Support vector machine regression data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// The algorithm exposes six parameters (SVM type, kernel type, nu, cost, gamma and epsilon),
  /// trains a support vector machine on the training partition of the configured regression
  /// problem and publishes the resulting solution as a single result.
  /// </remarks>
  [Item("Support Vector Regression", "Support vector machine regression data analysis algorithm.")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class SupportVectorRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    // Keys under which the parameters are registered in the Parameters collection.
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";

    #region parameter properties
    /// <summary>Parameter selecting the type of SVM to use ("NU_SVR" or "EPSILON_SVR").</summary>
    public IValueParameter<StringValue> SvmTypeParameter {
      get { return (IValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    /// <summary>Parameter selecting the kernel type ("LINEAR", "POLY", "SIGMOID" or "RBF").</summary>
    public IValueParameter<StringValue> KernelTypeParameter {
      get { return (IValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    /// <summary>Parameter holding the nu value of the nu-SVR.</summary>
    public IValueParameter<DoubleValue> NuParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    /// <summary>Parameter holding the C (cost) value of epsilon-SVR and nu-SVR.</summary>
    public IValueParameter<DoubleValue> CostParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    /// <summary>Parameter holding the gamma value of the kernel function.</summary>
    public IValueParameter<DoubleValue> GammaParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    /// <summary>Parameter holding the epsilon value for epsilon-SVR.</summary>
    public IValueParameter<DoubleValue> EpsilonParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Current value of <see cref="SvmTypeParameter"/>.</summary>
    public StringValue SvmType {
      get { return SvmTypeParameter.Value; }
    }
    /// <summary>Current value of <see cref="KernelTypeParameter"/>.</summary>
    public StringValue KernelType {
      get { return KernelTypeParameter.Value; }
    }
    /// <summary>Current value of <see cref="NuParameter"/>.</summary>
    public DoubleValue Nu {
      get { return NuParameter.Value; }
    }
    /// <summary>Current value of <see cref="CostParameter"/>.</summary>
    public DoubleValue Cost {
      get { return CostParameter.Value; }
    }
    /// <summary>Current value of <see cref="GammaParameter"/>.</summary>
    public DoubleValue Gamma {
      get { return GammaParameter.Value; }
    }
    /// <summary>Current value of <see cref="EpsilonParameter"/>.</summary>
    public DoubleValue Epsilon {
      get { return EpsilonParameter.Value; }
    }
    #endregion

    /// <summary>Constructor used by the persistence framework when deserializing.</summary>
    [StorableConstructor]
    private SupportVectorRegression(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor; all state is copied by the base class.</summary>
    private SupportVectorRegression(SupportVectorRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with a fresh <see cref="RegressionProblem"/> and the default
    /// parameter values: SVM type "NU_SVR", kernel "RBF", nu 0.5, cost 1.0, gamma 1.0, epsilon 0.1.
    /// </summary>
    public SupportVectorRegression()
      : base() {
      Problem = new RegressionProblem();

      // The selectable values are wrapped as read-only items and offered through constrained parameters.
      List<StringValue> svrTypes = (from type in new List<string> { "NU_SVR", "EPSILON_SVR" }
                                    select new StringValue(type).AsReadOnly())
                                   .ToList();
      ItemSet<StringValue> svrTypeSet = new ItemSet<StringValue>(svrTypes);

      List<StringValue> kernelTypes = (from type in new List<string> { "LINEAR", "POLY", "SIGMOID", "RBF" }
                                       select new StringValue(type).AsReadOnly())
                                      .ToList();
      ItemSet<StringValue> kernelTypeSet = new ItemSet<StringValue>(kernelTypes);

      // Defaults: svrTypes[0] is "NU_SVR", kernelTypes[3] is "RBF".
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", svrTypeSet, svrTypes[0]));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", kernelTypeSet, kernelTypes[3]));
      Parameters.Add(new ValueParameter<DoubleValue>(NuParameterName, "The value of the nu parameter of the nu-SVR.", new DoubleValue(0.5)));
      Parameters.Add(new ValueParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of epsilon-SVR and nu-SVR.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR.", new DoubleValue(0.1)));
    }

    /// <summary>Post-deserialization hook; currently nothing needs to be restored.</summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    /// <summary>Creates a copy of this algorithm through the cloning constructor.</summary>
    /// <param name="cloner">The cloner passed on to the cloning constructor.</param>
    /// <returns>A new <see cref="SupportVectorRegression"/> instance.</returns>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorRegression(this, cloner);
    }

    #region support vector regression
    /// <summary>
    /// Trains a support vector regression model on the problem data of <c>Problem</c>, using its
    /// allowed input variables and the current parameter values, and adds the solution to
    /// <c>Results</c> under the name "Support vector regression solution".
    /// </summary>
    protected override void Run() {
      IRegressionProblemData problemData = Problem.ProblemData;
      IEnumerable<string> selectedInputVariables = problemData.AllowedInputVariables;
      var solution = CreateSupportVectorRegressionSolution(problemData, selectedInputVariables, SvmType.Value, KernelType.Value, Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value);

      Results.Add(new Result("Support vector regression solution", "The support vector regression solution.", solution));
    }

    /// <summary>
    /// Trains a support vector machine on the training partition of <paramref name="problemData"/>
    /// and wraps the trained model in a regression solution.
    /// </summary>
    /// <param name="problemData">Supplies the dataset, the target variable and the training partition bounds.</param>
    /// <param name="allowedInputVariables">Names of the variables used as model inputs.</param>
    /// <param name="svmType">Name of an <c>SVM.SvmType</c> member; matched case-insensitively.</param>
    /// <param name="kernelType">Name of an <c>SVM.KernelType</c> member; matched case-insensitively.</param>
    /// <param name="cost">Value assigned to the SVM parameter <c>C</c>.</param>
    /// <param name="nu">Value assigned to the SVM parameter <c>Nu</c>.</param>
    /// <param name="gamma">Value assigned to the SVM parameter <c>Gamma</c>.</param>
    /// <param name="epsilon">Value assigned to the SVM parameter <c>P</c>.</param>
    /// <returns>A solution holding the trained model together with <paramref name="problemData"/>.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="problemData"/>, <paramref name="allowedInputVariables"/>,
    /// <paramref name="svmType"/> or <paramref name="kernelType"/> is null.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="svmType"/> or <paramref name="kernelType"/> does not name a member of the respective enumeration.
    /// </exception>
    public static SupportVectorRegressionSolution CreateSupportVectorRegressionSolution(IRegressionProblemData problemData, IEnumerable<string> allowedInputVariables,
      string svmType, string kernelType, double cost, double nu, double gamma, double epsilon) {
      // Fail early with the offending parameter name instead of a NullReferenceException further down.
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");
      if (svmType == null) throw new ArgumentNullException("svmType");
      if (kernelType == null) throw new ArgumentNullException("kernelType");

      // The sequence is handed to both the problem builder and the model; materialize it once so
      // a lazily evaluated source cannot yield different variables to the two consumers.
      IEnumerable<string> inputVariables = allowedInputVariables.ToList();

      Dataset dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      int start = problemData.TrainingPartitionStart.Value;
      int end = problemData.TrainingPartitionEnd.Value;
      // Rows start .. end - 1, i.e. the training partition end is treated as exclusive.
      IEnumerable<int> rows = Enumerable.Range(start, end - start);

      // Translate the arguments into the parameter object of the SVM library.
      SVM.Parameter parameter = new SVM.Parameter();
      parameter.SvmType = (SVM.SvmType)Enum.Parse(typeof(SVM.SvmType), svmType, true);
      parameter.KernelType = (SVM.KernelType)Enum.Parse(typeof(SVM.KernelType), kernelType, true);
      parameter.C = cost;
      parameter.Nu = nu;
      parameter.Gamma = gamma;
      parameter.P = epsilon;
      // NOTE(review): libsvm documents cache_size in MB; confirm the wrapper uses the same unit.
      parameter.CacheSize = 500;
      parameter.Probability = false;

      // The range transform is computed from the training rows and stored with the model,
      // which is trained on the scaled problem.
      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, inputVariables, rows);
      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
      SVM.Problem scaledProblem = SVM.Scaling.Scale(rangeTransform, problem);
      var model = new SupportVectorMachineModel(SVM.Training.Train(scaledProblem, parameter), rangeTransform, targetVariable, inputVariables);
      return new SupportVectorRegressionSolution(model, problemData);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.