Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineCrossValidationEvaluator.cs @ 3933

Last change on this file since 3933 was 3933, checked in by mkommend, 14 years ago

removed cloning of dataset and made it readonly (ticket #938)

File size: 10.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using System.Threading;
29using HeuristicLab.LibSVM;
30using HeuristicLab.Operators;
31using HeuristicLab.Parameters;
32using SVM;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
34using HeuristicLab.Optimization;
35
namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
  /// <summary>
  /// Represents an operator that performs SVM cross validation with the given parameters.
  /// </summary>
  /// <remarks>
  /// <see cref="Apply"/> draws a random subset of the configured sample partition, runs an
  /// n-fold cross validation on that subset and writes the result to the quality parameter.
  /// </remarks>
  [StorableClass]
  [Item("SupportVectorMachineCrossValidationEvaluator", "Represents an operator that performs SVM cross validation with the given parameters.")]
  public class SupportVectorMachineCrossValidationEvaluator : SingleSuccessorOperator, ISingleObjectiveEvaluator {
    // Names under which the parameters are registered in the Parameters collection
    // (see the constructor) and looked up again by the parameter properties below.
    private const string RandomParameterName = "Random";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string ActualSamplesParameterName = "ActualSamples";
    private const string NumberOfFoldsParameterName = "NumberOfFolds";
    private const string QualityParameterName = "Quality";

    #region parameter properties
    // Typed accessors for the parameter objects registered in the constructor.
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public IValueLookupParameter<StringValue> SvmTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    public IValueLookupParameter<StringValue> KernelTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> NuParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> CostParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> GammaParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> EpsilonParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public IValueLookupParameter<PercentValue> ActualSamplesParameter {
      get { return (IValueLookupParameter<PercentValue>)Parameters[ActualSamplesParameterName]; }
    }
    public IValueLookupParameter<IntValue> NumberOfFoldsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[NumberOfFoldsParameterName]; }
    }
    public ILookupParameter<DoubleValue> QualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    #endregion
    #region properties
    // Shortcuts to the ActualValue of the corresponding parameter.
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public StringValue SvmType {
      get { return SvmTypeParameter.ActualValue; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.ActualValue; }
    }
    public DoubleValue Nu {
      get { return NuParameter.ActualValue; }
    }
    public DoubleValue Cost {
      get { return CostParameter.ActualValue; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.ActualValue; }
    }
    public DoubleValue Epsilon {
      get { return EpsilonParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public IntValue NumberOfFolds {
      get { return NumberOfFoldsParameter.ActualValue; }
    }
    #endregion

    /// <summary>
    /// Creates the operator and registers all of its parameters.
    /// </summary>
    public SupportVectorMachineCrossValidationEvaluator()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
      Parameters.Add(new ValueLookupParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use."));
      Parameters.Add(new ValueLookupParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<PercentValue>(ActualSamplesParameterName, "The percentage of the training set that should be acutally used for cross-validation (samples are picked randomly from the training set)."));
      Parameters.Add(new ValueLookupParameter<IntValue>(NumberOfFoldsParameterName, "The number of folds to use for cross-validation."));
      Parameters.Add(new LookupParameter<DoubleValue>(QualityParameterName, "The cross validation quality reached with the given parameters."));
    }

    /// <summary>
    /// Samples a random subset of the rows in [SamplesStart, SamplesEnd), cross-validates an SVM
    /// on that subset and stores the result in <see cref="QualityParameter"/>.
    /// </summary>
    /// <returns>The operation returned by the base implementation.</returns>
    public override IOperation Apply() {
      // ActualSamples is optional; without a value the ratio stays at 1.0.
      double reductionRatio = 1.0;
      if (ActualSamplesParameter.ActualValue != null)
        reductionRatio = ActualSamplesParameter.ActualValue.Value;

      int start = SamplesStart.Value;
      int end = SamplesEnd.Value;
      // Same helper as CreateReducedDataset so both agree on how many rows were sampled.
      int reducedRows = CalculateReducedRows(start, end, reductionRatio);

      // Work on a clone so the dataset of the original problem data is not replaced.
      DataAnalysisProblemData reducedProblemData = (DataAnalysisProblemData)DataAnalysisProblemData.Clone();
      reducedProblemData.Dataset = CreateReducedDataset(RandomParameter.ActualValue, reducedProblemData.Dataset, reductionRatio, start, end);

      // The sampled rows occupy [start, start + reducedRows) of the reduced dataset.
      double quality = PerformCrossValidation(reducedProblemData,
                             start, start + reducedRows,
                             SvmType.Value, KernelType.Value,
                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, NumberOfFolds.Value);

      QualityParameter.ActualValue = new DoubleValue(quality);
      return base.Apply();
    }

    /// <summary>
    /// Calculates how many rows of the partition [start, end) are kept for the given ratio.
    /// </summary>
    /// <param name="start">First row of the partition (inclusive).</param>
    /// <param name="end">End of the partition (exclusive).</param>
    /// <param name="reductionRatio">Fraction of the partition that should be kept.</param>
    /// <returns>The truncated row count, limited to the range [0, end - start].</returns>
    private static int CalculateReducedRows(int start, int end, double reductionRatio) {
      int partitionSize = end - start;
      int reducedRows = (int)(partitionSize * reductionRatio);
      // A ratio above 1 (or below 0) must not ask for more rows than the partition holds.
      return Math.Max(0, Math.Min(reducedRows, partitionSize));
    }

    /// <summary>
    /// Creates a copy of <paramref name="dataset"/> in which the rows starting at <paramref name="start"/>
    /// are overwritten with rows drawn randomly, without replacement, from the partition [start, end).
    /// </summary>
    /// <param name="random">The random generator used to pick the rows.</param>
    /// <param name="dataset">The source dataset; it is only read.</param>
    /// <param name="reductionRatio">Fraction of the partition that should be sampled.</param>
    /// <param name="start">First row of the partition (inclusive).</param>
    /// <param name="end">End of the partition (exclusive).</param>
    /// <returns>
    /// A new dataset with the same variable names whose rows [start, start + sampled count) hold the sample.
    /// All other rows keep the values of the source dataset.
    /// </returns>
    private Dataset CreateReducedDataset(IRandom random, Dataset dataset, double reductionRatio, int start, int end) {
      int reducedRows = CalculateReducedRows(start, end, reductionRatio);
      double[,] reducedData = dataset.GetClonedData();
      // Candidate source rows that have not been drawn yet (absolute row indices of the partition).
      List<int> remainingRows = Enumerable.Range(start, end - start).ToList();
      for (int row = 0; row < reducedRows; row++) {
        // Pick a position in the candidate list, not a row number, so every remaining row
        // is equally likely and no row can be drawn twice.
        // NOTE(review): assumes the upper bound of IRandom.Next is exclusive, as the original code did - confirm.
        int pick = random.Next(0, remainingRows.Count);
        int sourceRow = remainingRows[pick];
        // Remove the drawn candidate in O(1) by moving the last entry into its slot.
        int last = remainingRows.Count - 1;
        remainingRows[pick] = remainingRows[last];
        remainingRows.RemoveAt(last);
        // Read from the untouched source and write behind 'start', which is where Apply evaluates.
        for (int column = 0; column < dataset.Columns; column++)
          reducedData[start + row, column] = dataset[sourceRow, column];
      }
      return new Dataset(dataset.VariableNames, reducedData);
    }

    /// <summary>
    /// Performs the cross validation on the training partition stored in <paramref name="problemData"/>
    /// (TrainingSamplesStart to TrainingSamplesEnd).
    /// </summary>
    private static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      return PerformCrossValidation(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
    }

    /// <summary>
    /// Builds an SVM problem from the rows between <paramref name="start"/> and <paramref name="end"/>,
    /// scales it and runs a cross validation with the given SVM settings.
    /// </summary>
    /// <param name="problemData">The problem data providing the dataset and the target variable.</param>
    /// <param name="start">Start of the row range passed to the SVM problem creation.</param>
    /// <param name="end">End of the row range passed to the SVM problem creation.</param>
    /// <param name="svmType">Name of a <c>SVM.SvmType</c> member (parsed case-insensitively).</param>
    /// <param name="kernelType">Name of a <c>SVM.KernelType</c> member (parsed case-insensitively).</param>
    /// <param name="cost">Value assigned to the C parameter.</param>
    /// <param name="nu">Value assigned to the Nu parameter.</param>
    /// <param name="gamma">Value assigned to the Gamma parameter.</param>
    /// <param name="epsilon">Value assigned to the P parameter.</param>
    /// <param name="nFolds">The number of folds to use for cross-validation.</param>
    /// <returns>The value returned by <c>SVM.Training.PerformCrossValidation</c>.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown by <see cref="Enum.Parse(Type, string, bool)"/> when <paramref name="svmType"/> or
    /// <paramref name="kernelType"/> is not a member name of the respective enumeration.
    /// </exception>
    public static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      int start, int end,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      // NOTE(review): the index is not used below; the lookup is kept because it may be what
      // rejects an unknown target variable before any SVM work starts - confirm before removing.
      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);

      // transfer the SVM settings from the method arguments
      SVM.Parameter parameter = new SVM.Parameter();
      parameter.SvmType = (SVM.SvmType)Enum.Parse(typeof(SVM.SvmType), svmType, true);
      parameter.KernelType = (SVM.KernelType)Enum.Parse(typeof(SVM.KernelType), kernelType, true);
      parameter.C = cost;
      parameter.Nu = nu;
      parameter.Gamma = gamma;
      parameter.P = epsilon;
      parameter.CacheSize = 500;
      parameter.Probability = false;

      // the range transform is computed from the same rows that are cross-validated
      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, start, end);
      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);

      return SVM.Training.PerformCrossValidation(scaledProblem, parameter, nFolds, false);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.