Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineCrossValidationEvaluator.cs @ 3884

Last change on this file since 3884 was 3884, checked in by gkronber, 14 years ago

Worked on support vector regression operators and views. #1009

File size: 10.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using System.Threading;
29using HeuristicLab.LibSVM;
30using HeuristicLab.Operators;
31using HeuristicLab.Parameters;
32using SVM;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
34using HeuristicLab.Optimization;
35
namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
  /// <summary>
  /// Represents an operator that performs SVM cross validation with the given parameters.
  /// </summary>
  /// <remarks>
  /// On each application the operator clones the problem data, shuffles the rows of the
  /// configured sample partition inside that clone, runs an n-fold cross validation on the
  /// leading part of the shuffled partition and stores the result in the quality parameter.
  /// </remarks>
  [StorableClass]
  [Item("SupportVectorMachineCrossValidationEvaluator", "Represents an operator that performs SVM cross validation with the given parameters.")]
  public class SupportVectorMachineCrossValidationEvaluator : SingleSuccessorOperator, ISingleObjectiveEvaluator {
    // Names under which the parameters are registered in the Parameters collection.
    private const string RandomParameterName = "Random";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string ActualSamplesParameterName = "ActualSamples";
    private const string NumberOfFoldsParameterName = "NumberOfFolds";
    private const string QualityParameterName = "Quality";

    #region parameter properties
    /// <summary>Parameter providing the random generator used to shuffle the rows.</summary>
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    /// <summary>Parameter providing the problem data the cross validation is run on.</summary>
    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    /// <summary>Parameter providing the SVM type as a string (parsed into <c>SVM.SvmType</c>).</summary>
    public IValueLookupParameter<StringValue> SvmTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    /// <summary>Parameter providing the kernel type as a string (parsed into <c>SVM.KernelType</c>).</summary>
    public IValueLookupParameter<StringValue> KernelTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    /// <summary>Parameter providing the nu value of the SVM.</summary>
    public IValueLookupParameter<DoubleValue> NuParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    /// <summary>Parameter providing the C (cost) value of the SVM.</summary>
    public IValueLookupParameter<DoubleValue> CostParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    /// <summary>Parameter providing the gamma value of the kernel function.</summary>
    public IValueLookupParameter<DoubleValue> GammaParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    /// <summary>Parameter providing the epsilon value of the SVM.</summary>
    public IValueLookupParameter<DoubleValue> EpsilonParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    /// <summary>Parameter providing the first row index of the partition to use.</summary>
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    /// <summary>Parameter providing the end row index of the partition to use.</summary>
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    /// <summary>
    /// Parameter providing the fraction of the partition that is actually used.
    /// When its actual value is null the whole partition is used (see <see cref="Apply"/>).
    /// </summary>
    public IValueLookupParameter<PercentValue> ActualSamplesParameter {
      get { return (IValueLookupParameter<PercentValue>)Parameters[ActualSamplesParameterName]; }
    }
    /// <summary>Parameter providing the number of cross-validation folds.</summary>
    public IValueLookupParameter<IntValue> NumberOfFoldsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[NumberOfFoldsParameterName]; }
    }
    /// <summary>Parameter that receives the cross-validation result.</summary>
    public ILookupParameter<DoubleValue> QualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    #endregion
    #region properties
    // Shortcuts to the actual values of the corresponding parameters.
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public StringValue SvmType {
      get { return SvmTypeParameter.ActualValue; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.ActualValue; }
    }
    public DoubleValue Nu {
      get { return NuParameter.ActualValue; }
    }
    public DoubleValue Cost {
      get { return CostParameter.ActualValue; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.ActualValue; }
    }
    public DoubleValue Epsilon {
      get { return EpsilonParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public IntValue NumberOfFolds {
      get { return NumberOfFoldsParameter.ActualValue; }
    }
    #endregion

    /// <summary>
    /// Creates the operator and registers all of its parameters.
    /// </summary>
    public SupportVectorMachineCrossValidationEvaluator()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
      Parameters.Add(new ValueLookupParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use."));
      Parameters.Add(new ValueLookupParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<PercentValue>(ActualSamplesParameterName, "The percentage of the training set that should be actually used for cross-validation (samples are picked randomly from the training set)."));
      Parameters.Add(new ValueLookupParameter<IntValue>(NumberOfFoldsParameterName, "The number of folds to use for cross-validation."));
      Parameters.Add(new LookupParameter<DoubleValue>(QualityParameterName, "The cross validation quality reached with the given parameters."));
    }

    /// <summary>
    /// Runs the cross validation on a shuffled copy of the problem data and writes the
    /// result into <see cref="QualityParameter"/>.
    /// </summary>
    /// <returns>The operation returned by the base implementation.</returns>
    public override IOperation Apply() {
      // Use the whole partition unless a reduction ratio has been supplied.
      PercentValue actualSamples = ActualSamplesParameter.ActualValue;
      double reductionRatio = actualSamples != null ? actualSamples.Value : 1.0;

      int samplesStart = SamplesStart.Value;
      int samplesEnd = SamplesEnd.Value;
      int reducedRows = (int)((samplesEnd - samplesStart) * reductionRatio);

      // Shuffle a clone so that the rows of the original problem data stay untouched.
      var reducedProblemData = (DataAnalysisProblemData)DataAnalysisProblemData.Clone();
      ShuffleRows(RandomParameter.ActualValue, reducedProblemData.Dataset, samplesStart, samplesEnd);

      // After shuffling, the first reducedRows rows of the partition form the random subset.
      double quality = PerformCrossValidation(reducedProblemData,
                             samplesStart, samplesStart + reducedRows,
                             SvmType.Value, KernelType.Value,
                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, NumberOfFolds.Value);

      QualityParameter.ActualValue = new DoubleValue(quality);
      return base.Apply();
    }

    /// <summary>
    /// Shuffles the rows <paramref name="start"/> (inclusive) to <paramref name="end"/>
    /// (exclusive) of <paramref name="dataset"/> in place by swapping whole rows
    /// (Fisher-Yates shuffle).
    /// </summary>
    /// <param name="random">The random generator that supplies the swap partners.</param>
    /// <param name="dataset">The dataset whose rows are permuted.</param>
    /// <param name="start">Index of the first row that takes part in the shuffle.</param>
    /// <param name="end">Index one past the last row that takes part in the shuffle.</param>
    private void ShuffleRows(IRandom random, Dataset dataset, int start, int end) {
      for (int row = end - 1; row > start; row--) {
        // The swap partner must be drawn from [start, row] *including* row itself;
        // excluding it yields only cyclic permutations (Sattolo's algorithm), i.e. a
        // biased shuffle in which no row can ever keep its position.
        // NOTE(review): assumes IRandom.Next(min, max) excludes max, like System.Random
        // - confirm against the IRandom implementation in use.
        int otherRow = random.Next(start, row + 1);
        if (otherRow == row) continue; // swapping a row with itself is a no-op
        for (int column = 0; column < dataset.Columns; column++) {
          double tmp = dataset[otherRow, column];
          dataset[otherRow, column] = dataset[row, column];
          dataset[row, column] = tmp;
        }
      }
    }

    /// <summary>
    /// Convenience overload that cross validates on the training partition stored in
    /// <paramref name="problemData"/> (TrainingSamplesStart to TrainingSamplesEnd).
    /// Not called from within this class.
    /// </summary>
    private static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      return PerformCrossValidation(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
    }

    /// <summary>
    /// Builds an SVM problem from the rows <paramref name="start"/> to <paramref name="end"/>
    /// of <paramref name="problemData"/>, scales it and performs an n-fold cross validation.
    /// </summary>
    /// <param name="problemData">The problem data to take the samples from.</param>
    /// <param name="start">Start row index passed to SupportVectorMachineUtil.CreateSvmProblem.</param>
    /// <param name="end">End row index passed to SupportVectorMachineUtil.CreateSvmProblem.</param>
    /// <param name="svmType">Name of an <c>SVM.SvmType</c> member (case-insensitive).</param>
    /// <param name="kernelType">Name of an <c>SVM.KernelType</c> member (case-insensitive).</param>
    /// <param name="cost">Value assigned to the C parameter.</param>
    /// <param name="nu">Value assigned to the Nu parameter.</param>
    /// <param name="gamma">Value assigned to the Gamma parameter.</param>
    /// <param name="epsilon">Value assigned to the P parameter.</param>
    /// <param name="nFolds">The number of folds.</param>
    /// <returns>The value returned by <c>SVM.Training.PerformCrossValidation</c>.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown by <see cref="Enum.Parse(Type, string, bool)"/> when <paramref name="svmType"/> or
    /// <paramref name="kernelType"/> does not name a member of the respective enumeration.
    /// </exception>
    public static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      int start, int end,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      // NOTE(review): the index is not used below; the lookup is kept because it may
      // fail early for an unknown target variable - confirm before removing.
      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);

      // Transfer the given settings into the parameter object of the SVM library.
      SVM.Parameter parameter = new SVM.Parameter();
      parameter.SvmType = (SVM.SvmType)Enum.Parse(typeof(SVM.SvmType), svmType, true);
      parameter.KernelType = (SVM.KernelType)Enum.Parse(typeof(SVM.KernelType), kernelType, true);
      parameter.C = cost;
      parameter.Nu = nu;
      parameter.Gamma = gamma;
      parameter.P = epsilon;
      parameter.CacheSize = 500; // presumably the kernel cache size in MB - TODO confirm
      parameter.Probability = false;

      // The range transform is computed from exactly the samples that are cross validated.
      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, start, end);
      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);

      return SVM.Training.PerformCrossValidation(scaledProblem, parameter, nFolds, false);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.