Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineCrossValidationEvaluator.cs @ 4068

Last change on this file since 4068 was 4068, checked in by swagner, 14 years ago

Sorted usings and removed unused usings in entire solution (#1094)

File size: 11.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Operators;
27using HeuristicLab.Optimization;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using SVM;
31
namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
  /// <summary>
  /// Represents an operator that performs SVM cross validation with the given parameters.
  /// </summary>
  [StorableClass]
  [Item("SupportVectorMachineCrossValidationEvaluator", "Represents an operator that performs SVM cross validation with the given parameters.")]
  public class SupportVectorMachineCrossValidationEvaluator : SingleSuccessorOperator, ISingleObjectiveEvaluator {
    private const string RandomParameterName = "Random";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string CostParameterName = "Cost";
    private const string NuParameterName = "Nu";
    private const string GammaParameterName = "Gamma";
    private const string EpsilonParameterName = "Epsilon";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string ActualSamplesParameterName = "ActualSamples";
    private const string NumberOfFoldsParameterName = "NumberOfFolds";
    private const string QualityParameterName = "Quality";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public IValueLookupParameter<StringValue> SvmTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    public IValueLookupParameter<StringValue> KernelTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> NuParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> CostParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> GammaParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> EpsilonParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[EpsilonParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public IValueLookupParameter<PercentValue> ActualSamplesParameter {
      get { return (IValueLookupParameter<PercentValue>)Parameters[ActualSamplesParameterName]; }
    }
    public IValueLookupParameter<IntValue> NumberOfFoldsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[NumberOfFoldsParameterName]; }
    }
    public ILookupParameter<DoubleValue> QualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    #endregion
    #region properties
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public StringValue SvmType {
      get { return SvmTypeParameter.ActualValue; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.ActualValue; }
    }
    public DoubleValue Nu {
      get { return NuParameter.ActualValue; }
    }
    public DoubleValue Cost {
      get { return CostParameter.ActualValue; }
    }
    public DoubleValue Gamma {
      get { return GammaParameter.ActualValue; }
    }
    public DoubleValue Epsilon {
      get { return EpsilonParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public IntValue NumberOfFolds {
      get { return NumberOfFoldsParameter.ActualValue; }
    }
    #endregion

    /// <summary>
    /// Creates the evaluator and registers all of its lookup/value-lookup parameters.
    /// </summary>
    public SupportVectorMachineCrossValidationEvaluator()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
      Parameters.Add(new ValueLookupParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use."));
      Parameters.Add(new ValueLookupParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(EpsilonParameterName, "The value of the epsilon parameter for epsilon-SVR."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last index of the data set partition the support vector machine should use for training."));
      Parameters.Add(new ValueLookupParameter<PercentValue>(ActualSamplesParameterName, "The percentage of the training set that should be actually used for cross-validation (samples are picked randomly from the training set)."));
      Parameters.Add(new ValueLookupParameter<IntValue>(NumberOfFoldsParameterName, "The number of folds to use for cross-validation."));
      Parameters.Add(new LookupParameter<DoubleValue>(QualityParameterName, "The cross validation quality reached with the given parameters."));
    }

    /// <summary>
    /// Runs SVM cross-validation on a (possibly reduced) copy of the problem data and writes the
    /// result into the Quality parameter.
    /// </summary>
    /// <remarks>
    /// When ActualSamples is set, only that fraction of the rows in [SamplesStart, SamplesEnd) is used.
    /// The rows are picked at random and moved to the front of the partition of a cloned data set.
    /// Without ActualSamples the whole partition is used (ratio 1.0).
    /// </remarks>
    /// <returns>The result of <c>base.Apply()</c>.</returns>
    public override IOperation Apply() {
      // ActualSamples is optional; read it once and fall back to using the whole partition.
      PercentValue actualSamples = ActualSamplesParameter.ActualValue;
      double reductionRatio = actualSamples != null ? actualSamples.Value : 1.0;

      int start = SamplesStart.Value;
      int end = SamplesEnd.Value;
      // Same (clamped) row count as used inside CreateReducedDataset, so the cross-validation
      // range never extends past the rows that were actually filled with selected samples.
      int reducedRows = GetReducedRowCount(start, end, reductionRatio);

      DataAnalysisProblemData reducedProblemData = (DataAnalysisProblemData)DataAnalysisProblemData.Clone();
      reducedProblemData.Dataset = CreateReducedDataset(RandomParameter.ActualValue, reducedProblemData.Dataset, reductionRatio, start, end);

      double quality = PerformCrossValidation(reducedProblemData,
                             start, start + reducedRows,
                             SvmType.Value, KernelType.Value,
                             Cost.Value, Nu.Value, Gamma.Value, Epsilon.Value, NumberOfFolds.Value);

      QualityParameter.ActualValue = new DoubleValue(quality);
      return base.Apply();
    }

    /// <summary>
    /// Returns how many rows of the partition [start, end) are kept for the given ratio.
    /// The ratio is clamped to [0, 1], so a percentage above 100% cannot select more rows
    /// than the partition holds.
    /// </summary>
    private static int GetReducedRowCount(int start, int end, double reductionRatio) {
      double ratio = Math.Max(0.0, Math.Min(1.0, reductionRatio));
      return (int)((end - start) * ratio);
    }

    /// <summary>
    /// Builds a copy of <paramref name="dataset"/> in which rows [start, start + n) are overwritten by
    /// n rows sampled without replacement from [start, end). n is the clamped row count for
    /// <paramref name="reductionRatio"/>.
    /// </summary>
    /// <remarks>
    /// The sampled rows are written in ascending order of their original index, so their relative
    /// order is preserved. Rows outside [start, start + n) keep their original values.
    /// </remarks>
    private Dataset CreateReducedDataset(IRandom random, Dataset dataset, double reductionRatio, int start, int end) {
      int n = GetReducedRowCount(start, end, reductionRatio);

      // All candidate row indexes from start..end (end exclusive).
      int[] rowIndexes = Enumerable.Range(start, end - start).ToArray();

      // Fisher-Yates (Knuth) shuffle: j must be drawn from [0, i] inclusive.
      // IRandom.Next(min, max) is assumed to exclude max (System.Random convention, as in
      // HeuristicLab's MersenneTwister), hence i + 1. Drawing from Next(0, i) would be Sattolo's
      // algorithm, which only produces cyclic permutations and biases the selected sample.
      for (int i = rowIndexes.Length - 1; i > 0; i--) {
        int j = random.Next(0, i + 1);
        int tmp = rowIndexes[i];
        rowIndexes[i] = rowIndexes[j];
        rowIndexes[j] = tmp;
      }

      // The first n shuffled indexes are a uniform random subset; sort them to keep original row order.
      int[] orderedRandomIndexes = rowIndexes.Take(n).OrderBy(x => x).ToArray();

      // Copy the selected rows to the front of the partition. Reads come from the untouched
      // source dataset, so overwriting rows of the clone cannot corrupt later reads.
      double[,] reducedData = dataset.GetClonedData();
      for (int i = 0; i < n; i++) {
        for (int column = 0; column < dataset.Columns; column++) {
          reducedData[start + i, column] = dataset[orderedRandomIndexes[i], column];
        }
      }
      return new Dataset(dataset.VariableNames, reducedData);
    }

    /// <summary>
    /// Cross-validates on the training partition
    /// (TrainingSamplesStart..TrainingSamplesEnd) of <paramref name="problemData"/>.
    /// </summary>
    private static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      return PerformCrossValidation(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value, svmType, kernelType, cost, nu, gamma, epsilon, nFolds);
    }

    /// <summary>
    /// Performs n-fold SVM cross-validation on rows [start, end) of <paramref name="problemData"/>.
    /// </summary>
    /// <param name="problemData">Problem data providing the data set and target variable.</param>
    /// <param name="start">First row of the partition.</param>
    /// <param name="end">End of the partition, as passed to <c>SupportVectorMachineUtil.CreateSvmProblem</c>.</param>
    /// <param name="svmType">Name of a <c>SVM.SvmType</c> member (parsed case-insensitively).</param>
    /// <param name="kernelType">Name of a <c>SVM.KernelType</c> member (parsed case-insensitively).</param>
    /// <param name="cost">C (cost) parameter.</param>
    /// <param name="nu">Nu parameter.</param>
    /// <param name="gamma">Kernel gamma parameter.</param>
    /// <param name="epsilon">Epsilon (libsvm P) parameter.</param>
    /// <param name="nFolds">Number of cross-validation folds.</param>
    /// <returns>The value returned by <c>SVM.Training.PerformCrossValidation</c>.</returns>
    /// <exception cref="ArgumentException">If <paramref name="svmType"/> or <paramref name="kernelType"/> is not a valid enum name.</exception>
    public static double PerformCrossValidation(
      DataAnalysisProblemData problemData,
      int start, int end,
      string svmType, string kernelType,
      double cost, double nu, double gamma, double epsilon,
      int nFolds) {
      // Translate the method arguments into an SVM parameter object.
      SVM.Parameter parameter = new SVM.Parameter();
      parameter.SvmType = (SVM.SvmType)Enum.Parse(typeof(SVM.SvmType), svmType, true);
      parameter.KernelType = (SVM.KernelType)Enum.Parse(typeof(SVM.KernelType), kernelType, true);
      parameter.C = cost;
      parameter.Nu = nu;
      parameter.Gamma = gamma;
      parameter.P = epsilon;
      // Kernel cache size; libsvm interprets this in MB — NOTE(review): confirm for this SVM port.
      parameter.CacheSize = 500;
      parameter.Probability = false;

      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, start, end);
      // NOTE(review): the scaling range is computed on the whole partition before it is split into
      // folds, so validation folds influence the scaling of training folds.
      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);

      return SVM.Training.PerformCrossValidation(scaledProblem, parameter, nFolds, false);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.