Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineCrossValidationEvaluator.cs @ 10552

Last change on this file since 10552 was 5275, checked in by gkronber, 14 years ago

Merged changes from trunk to data analysis exploration branch and added fractional distance metric evaluator. #1142

File size: 12.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Operators;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using SVM;
33
namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
  /// <summary>
  /// Represents an operator that performs SVM cross validation with the given parameters.
  /// </summary>
  /// <remarks>
  /// The operator takes the rows in [SamplesStart, SamplesEnd) that lie outside the test partition of
  /// the problem data, randomly keeps the fraction given by ActualSamples, builds a reduced dataset
  /// from those rows (in ascending row order), runs an n-fold SVM cross validation on it and stores
  /// the result in Quality.
  /// </remarks>
  [StorableClass]
  [Item("SupportVectorMachineCrossValidationEvaluator", "Represents an operator that performs SVM cross validation with the given parameters.")]
  public class SupportVectorMachineCrossValidationEvaluator : SingleSuccessorOperator, ISingleObjectiveEvaluator {
    private const string RandomParameterName = "Random";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string ActualSamplesParameterName = "ActualSamples";
    private const string NumberOfFoldsParameterName = "NumberOfFolds";
    private const string QualityParameterName = "Quality";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public IValueLookupParameter<StringValue> SvmTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    public IValueLookupParameter<StringValue> KernelTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> NuParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> CostParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> GammaParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> EpsilonParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public IValueLookupParameter<PercentValue> ActualSamplesParameter {
      get { return (IValueLookupParameter<PercentValue>)Parameters[ActualSamplesParameterName]; }
    }
    public IValueLookupParameter<IntValue> NumberOfFoldsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[NumberOfFoldsParameterName]; }
    }
    public ILookupParameter<DoubleValue> QualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    #endregion
    #region properties
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public StringValue SvmType {
      get { return SvmTypeParameter.ActualValue; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.ActualValue; }
    }
    public DoubleValue Nu {
      get { return NuParameter.ActualValue; }
    }
    public DoubleValue Cost {
      get { return CostParameter.ActualValue; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.ActualValue; }
    }
    public DoubleValue Epsilon {
      get { return EpsilonParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public IntValue NumberOfFolds {
      get { return NumberOfFoldsParameter.ActualValue; }
    }
    #endregion

    [StorableConstructor]
    protected SupportVectorMachineCrossValidationEvaluator(bool deserializing) : base(deserializing) { }

    protected SupportVectorMachineCrossValidationEvaluator(SupportVectorMachineCrossValidationEvaluator original,
      Cloner cloner)
      : base(original, cloner) { }
    public SupportVectorMachineCrossValidationEvaluator()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
      Parameters.Add(new ValueLookupParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use."));
      Parameters.Add(new ValueLookupParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<PercentValue>(ActualSamplesParameterName, "The percentage of the training set that should be acutally used for cross-validation (samples are picked randomly from the training set)."));
      Parameters.Add(new ValueLookupParameter<IntValue>(NumberOfFoldsParameterName, "The number of folds to use for cross-validation."));
      Parameters.Add(new LookupParameter<DoubleValue>(QualityParameterName, "The cross validation quality reached with the given parameters."));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SupportVectorMachineCrossValidationEvaluator(this, cloner);
    }

    /// <summary>
    /// Runs the cross validation on a randomly reduced copy of the problem data and writes the
    /// resulting value to the Quality parameter.
    /// </summary>
    public override IOperation Apply() {
      // Fraction of the candidate rows to use; all rows when ActualSamples is not available.
      double reductionRatio = 1.0;
      PercentValue actualSamples = ActualSamplesParameter.ActualValue;
      if (actualSamples != null)
        reductionRatio = actualSamples.Value;

      // Resolve the problem data and its test partition once. Previously the lazy Where filter
      // re-resolved the lookup parameter twice for every single row.
      DataAnalysisProblemData problemData = DataAnalysisProblemData;
      int testSamplesStart = problemData.TestSamplesStart.Value;
      int testSamplesEnd = problemData.TestSamplesEnd.Value;
      int samplesStart = SamplesStart.Value;
      int samplesEnd = SamplesEnd.Value;

      // Candidate rows: [SamplesStart, SamplesEnd) without the rows of the test partition.
      IEnumerable<int> rows =
        Enumerable.Range(samplesStart, samplesEnd - samplesStart)
        .Where(i => i < testSamplesStart || testSamplesEnd <= i);

      // Work on a clone so that the shared problem data is not modified.
      DataAnalysisProblemData reducedProblemData = (DataAnalysisProblemData)problemData.Clone();
      reducedProblemData.Dataset =
        CreateReducedDataset(RandomParameter.ActualValue, reducedProblemData.Dataset, rows, reductionRatio);
      // The whole reduced dataset is the training partition: empty test partition, no validation split.
      reducedProblemData.TrainingSamplesStart.Value = 0;
      reducedProblemData.TrainingSamplesEnd.Value = reducedProblemData.Dataset.Rows;
      reducedProblemData.TestSamplesStart.Value = reducedProblemData.Dataset.Rows;
      reducedProblemData.TestSamplesEnd.Value = reducedProblemData.Dataset.Rows;
      reducedProblemData.ValidationPercentage.Value = 0;

      double quality = PerformCrossValidation(reducedProblemData,
                             SvmType.Value, KernelType.Value,
                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, NumberOfFolds.Value);

      QualityParameter.ActualValue = new DoubleValue(quality);
      return base.Apply();
    }

    /// <summary>
    /// Builds a new dataset from a random subset of the given rows. The selected rows keep
    /// their original relative (ascending) order.
    /// </summary>
    /// <param name="random">Random number generator used to select the rows.</param>
    /// <param name="dataset">Source dataset.</param>
    /// <param name="rowIndices">Candidate row indices into <paramref name="dataset"/>.</param>
    /// <param name="reductionRatio">Fraction of the candidate rows to keep. At least one row and
    /// at most all candidate rows are kept.</param>
    /// <returns>A dataset with the same variables that contains only the selected rows.</returns>
    /// <exception cref="ArgumentException">Thrown when there are no candidate rows.</exception>
    private static Dataset CreateReducedDataset(IRandom random, Dataset dataset, IEnumerable<int> rowIndices, double reductionRatio) {
      int[] candidateRows = rowIndices.ToArray();
      if (candidateRows.Length == 0)
        throw new ArgumentException("There are no rows available for cross-validation (the sample range is empty or lies completely inside the test partition).", "rowIndices");

      // Keep at least one row, but never more than are available (ratios above 100% would
      // otherwise index past the end of the selection below).
      int n = (int)Math.Max(1.0, candidateRows.Length * reductionRatio);
      if (n > candidateRows.Length) n = candidateRows.Length;

      // Fisher-Yates shuffle. IRandom.Next(min, max) excludes max, so the bound must be i + 1.
      // With Next(0, i) position i could never keep its own element (Sattolo's algorithm), which
      // produces only cyclic permutations and therefore a biased subset.
      for (int i = candidateRows.Length - 1; i > 0; i--) {
        int j = random.Next(0, i + 1);
        int tmp = candidateRows[i];
        candidateRows[i] = candidateRows[j];
        candidateRows[j] = tmp;
      }

      // The first n shuffled rows form the random subset; sort them to preserve the row order.
      int[] selectedRows = candidateRows.Take(n).OrderBy(row => row).ToArray();

      int columns = dataset.Columns;
      double[,] reducedData = new double[n, columns];
      for (int i = 0; i < n; i++) {
        int sourceRow = selectedRows[i];
        for (int column = 0; column < columns; column++) {
          reducedData[i, column] = dataset[sourceRow, column];
        }
      }
      return new Dataset(dataset.VariableNames, reducedData);
    }

    /// <summary>
    /// Cross-validates on the training rows of <paramref name="problemData"/>.
    /// </summary>
    private static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      return PerformCrossValidation(problemData, problemData.TrainingIndizes, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
    }

    /// <summary>
    /// Performs an n-fold SVM cross validation on the given rows of the problem data.
    /// </summary>
    /// <param name="problemData">Problem data that is passed to SupportVectorMachineUtil.CreateSvmProblem.</param>
    /// <param name="rowIndices">Rows of the dataset to cross-validate on.</param>
    /// <param name="svmType">Name of an <see cref="SVM.SvmType"/> member (case-insensitive).</param>
    /// <param name="kernelType">Name of an <see cref="SVM.KernelType"/> member (case-insensitive).</param>
    /// <param name="cost">The C (cost) parameter.</param>
    /// <param name="nu">The nu parameter.</param>
    /// <param name="gamma">The kernel gamma parameter.</param>
    /// <param name="epsilon">The epsilon parameter (set as <c>SVM.Parameter.P</c>).</param>
    /// <param name="nFolds">Number of cross-validation folds.</param>
    /// <returns>The value computed by <c>SVM.Training.PerformCrossValidation</c> on the range-scaled problem.</returns>
    /// <exception cref="ArgumentException">Thrown by <see cref="Enum.Parse(Type, string, bool)"/> when
    /// <paramref name="svmType"/> or <paramref name="kernelType"/> is not a valid member name.</exception>
    public static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      IEnumerable<int> rowIndices,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      // Translate the method arguments into an SVM parameter object.
      SVM.Parameter parameter = new SVM.Parameter();
      parameter.SvmType = (SVM.SvmType)Enum.Parse(typeof(SVM.SvmType), svmType, true);
      parameter.KernelType = (SVM.KernelType)Enum.Parse(typeof(SVM.KernelType), kernelType, true);
      parameter.C = cost;
      parameter.Nu = nu;
      parameter.Gamma = gamma;
      parameter.P = epsilon;
      parameter.CacheSize = 500;
      parameter.Probability = false;

      // Build the SVM problem from the selected rows and scale the inputs with a range
      // transform computed on that same problem.
      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, rowIndices);
      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);

      return SVM.Training.PerformCrossValidation(scaledProblem, parameter, nFolds, false);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.