Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineCrossValidationEvaluator.cs @ 4543

Last change on this file since 4543 was 4543, checked in by gkronber, 14 years ago

Adapted SVM classes to work correctly for overlapping training / test partitions. #1226

File size: 11.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Operators;
27using HeuristicLab.Optimization;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using SVM;
31using System.Collections.Generic;
32
namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
  /// <summary>
  /// Represents an operator that performs SVM cross validation with the given parameters.
  /// </summary>
  /// <remarks>
  /// Rows in [SamplesStart, SamplesEnd) that do not lie in the test partition of the problem data
  /// are (optionally) sub-sampled according to ActualSamples, copied into a reduced problem data
  /// instance, and cross-validated; the resulting quality is written to the Quality parameter.
  /// </remarks>
  [StorableClass]
  [Item("SupportVectorMachineCrossValidationEvaluator", "Represents an operator that performs SVM cross validation with the given parameters.")]
  public class SupportVectorMachineCrossValidationEvaluator : SingleSuccessorOperator, ISingleObjectiveEvaluator {
    private const string RandomParameterName = "Random";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string ActualSamplesParameterName = "ActualSamples";
    private const string NumberOfFoldsParameterName = "NumberOfFolds";
    private const string QualityParameterName = "Quality";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public IValueLookupParameter<StringValue> SvmTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    public IValueLookupParameter<StringValue> KernelTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> NuParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> CostParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> GammaParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> EpsilonParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public IValueLookupParameter<PercentValue> ActualSamplesParameter {
      get { return (IValueLookupParameter<PercentValue>)Parameters[ActualSamplesParameterName]; }
    }
    public IValueLookupParameter<IntValue> NumberOfFoldsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[NumberOfFoldsParameterName]; }
    }
    public ILookupParameter<DoubleValue> QualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    #endregion
    #region properties
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public StringValue SvmType {
      get { return SvmTypeParameter.ActualValue; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.ActualValue; }
    }
    public DoubleValue Nu {
      get { return NuParameter.ActualValue; }
    }
    public DoubleValue Cost {
      get { return CostParameter.ActualValue; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.ActualValue; }
    }
    public DoubleValue Epsilon {
      get { return EpsilonParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public IntValue NumberOfFolds {
      get { return NumberOfFoldsParameter.ActualValue; }
    }
    #endregion

    public SupportVectorMachineCrossValidationEvaluator()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
      Parameters.Add(new ValueLookupParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use."));
      Parameters.Add(new ValueLookupParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<PercentValue>(ActualSamplesParameterName, "The percentage of the training set that should be actually used for cross-validation (samples are picked randomly from the training set)."));
      Parameters.Add(new ValueLookupParameter<IntValue>(NumberOfFoldsParameterName, "The number of folds to use for cross-validation."));
      Parameters.Add(new LookupParameter<DoubleValue>(QualityParameterName, "The cross validation quality reached with the given parameters."));
    }

    /// <summary>
    /// Builds a reduced copy of the problem data that contains only the selected training rows
    /// (rows of [SamplesStart, SamplesEnd) outside the test partition, optionally sub-sampled by
    /// ActualSamples), runs SVM cross validation on it and stores the result in the Quality parameter.
    /// </summary>
    public override IOperation Apply() {
      // Use all candidate rows unless ActualSamples restricts the fraction of rows to use.
      double reductionRatio = 1.0;
      PercentValue actualSamples = ActualSamplesParameter.ActualValue;
      if (actualSamples != null)
        reductionRatio = actualSamples.Value;

      // Resolve the problem data and partition bounds once, instead of once per row inside the filter.
      DataAnalysisProblemData problemData = DataAnalysisProblemDataParameter.ActualValue;
      int samplesStart = SamplesStart.Value;
      int samplesEnd = SamplesEnd.Value;
      int testSamplesStart = problemData.TestSamplesStart.Value;
      int testSamplesEnd = problemData.TestSamplesEnd.Value;

      // Training and test partitions may overlap; rows inside the test partition are excluded.
      IEnumerable<int> rows =
        Enumerable.Range(samplesStart, samplesEnd - samplesStart)
        .Where(i => i < testSamplesStart || testSamplesEnd <= i);

      // Create a new DataAnalysisProblemData instance whose dataset holds only the selected rows,
      // all of which form the training partition (no test rows, no validation split).
      DataAnalysisProblemData reducedProblemData = (DataAnalysisProblemData)problemData.Clone();
      reducedProblemData.Dataset =
        CreateReducedDataset(RandomParameter.ActualValue, reducedProblemData.Dataset, rows, reductionRatio);
      reducedProblemData.TrainingSamplesStart.Value = 0;
      reducedProblemData.TrainingSamplesEnd.Value = reducedProblemData.Dataset.Rows;
      reducedProblemData.TestSamplesStart.Value = reducedProblemData.Dataset.Rows;
      reducedProblemData.TestSamplesEnd.Value = reducedProblemData.Dataset.Rows;
      reducedProblemData.ValidationPercentage.Value = 0;

      double quality = PerformCrossValidation(reducedProblemData,
                             SvmType.Value, KernelType.Value,
                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, NumberOfFolds.Value);

      QualityParameter.ActualValue = new DoubleValue(quality);
      return base.Apply();
    }

    /// <summary>
    /// Creates a new dataset containing a random subset of the given rows of <paramref name="dataset"/>,
    /// kept in ascending row order.
    /// </summary>
    /// <param name="random">Random number generator used to choose the subset. It is only used when
    /// fewer rows than available are requested, so it may be null when all rows are kept.</param>
    /// <param name="dataset">The dataset to copy rows from.</param>
    /// <param name="rowIndices">The candidate row indices.</param>
    /// <param name="reductionRatio">Fraction of the candidate rows to keep. The result is truncated,
    /// and at least one and at most all candidate rows are kept.</param>
    /// <returns>A dataset with the same variables that contains only the selected rows.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="rowIndices"/> is empty.</exception>
    private static Dataset CreateReducedDataset(IRandom random, Dataset dataset, IEnumerable<int> rowIndices, double reductionRatio) {
      int[] rowIndexArr = rowIndices.ToArray();
      if (rowIndexArr.Length == 0)
        throw new ArgumentException("There are no training rows outside of the test partition that could be used for cross-validation.", "rowIndices");

      // Number of rows to keep. Clamp to [1, Length] before casting so that ratios > 1 (or huge
      // ratios that would overflow the int cast) cannot index past the available rows.
      double requestedRows = Math.Min((double)rowIndexArr.Length, Math.Max(1.0, rowIndexArr.Length * reductionRatio));
      int n = (int)requestedRows;

      // Only a proper subset needs a random choice. When all rows are kept, the order is
      // irrelevant because the selection is sorted below.
      if (n < rowIndexArr.Length) {
        // Fisher-Yates (Knuth) shuffle. IRandom.Next(min, max) excludes max, so i + 1 is required.
        // Without it an element can never stay in place, which biases the sample.
        for (int i = rowIndexArr.Length - 1; i > 0; i--) {
          int j = random.Next(0, i + 1);
          int tmp = rowIndexArr[i];
          rowIndexArr[i] = rowIndexArr[j];
          rowIndexArr[j] = tmp;
        }
      }

      // Take the first n (randomly chosen) indexes and restore ascending row order.
      int[] orderedRandomIndexes =
        rowIndexArr.Take(n)
        .OrderBy(x => x)
        .ToArray();

      // Copy the selected rows (all columns) into a new, compact data matrix.
      double[,] reducedData = new double[n, dataset.Columns];
      for (int i = 0; i < n; i++) {
        for (int column = 0; column < dataset.Columns; column++) {
          reducedData[i, column] = dataset[orderedRandomIndexes[i], column];
        }
      }
      return new Dataset(dataset.VariableNames, reducedData);
    }

    /// <summary>
    /// Performs SVM cross validation on the training rows of <paramref name="problemData"/>.
    /// </summary>
    private static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      return PerformCrossValidation(problemData, problemData.TrainingIndizes, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
    }

    /// <summary>
    /// Performs SVM cross validation on the given rows of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Problem data providing the dataset and the target variable.</param>
    /// <param name="rowIndices">The rows to build the SVM problem from.</param>
    /// <param name="svmType">Name of an <c>SVM.SvmType</c> member (parsed case-insensitively).</param>
    /// <param name="kernelType">Name of an <c>SVM.KernelType</c> member (parsed case-insensitively).</param>
    /// <param name="cost">The C (cost) parameter.</param>
    /// <param name="nu">The nu parameter.</param>
    /// <param name="gamma">The kernel gamma parameter.</param>
    /// <param name="epsilon">The epsilon parameter (passed to the SVM as P).</param>
    /// <param name="nFolds">Number of cross-validation folds.</param>
    /// <returns>The value returned by <c>SVM.Training.PerformCrossValidation</c> for the range-scaled problem.</returns>
    public static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      IEnumerable<int> rowIndices,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);

      // Build the SVM parameter set from the method arguments.
      SVM.Parameter parameter = new SVM.Parameter();
      parameter.SvmType = (SVM.SvmType)Enum.Parse(typeof(SVM.SvmType), svmType, true);
      parameter.KernelType = (SVM.KernelType)Enum.Parse(typeof(SVM.KernelType), kernelType, true);
      parameter.C = cost;
      parameter.Nu = nu;
      parameter.Gamma = gamma;
      parameter.P = epsilon;
      parameter.CacheSize = 500;
      parameter.Probability = false;

      // Scale the input ranges before cross-validating.
      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, rowIndices);
      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);

      return SVM.Training.PerformCrossValidation(scaledProblem, parameter, nFolds, false);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.