Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Persistence Test/HeuristicLab.GP.StructureIdentification.Classification/3.3/CrossValidation.cs @ 3703

Last change on this file since 3703 was 2222, checked in by gkronber, 15 years ago

Merged changes from GP-refactoring branch back into the trunk #713.

File size: 6.2 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using HeuristicLab.Core;
24using HeuristicLab.Data;
25using HeuristicLab.DataAnalysis;
26
namespace HeuristicLab.GP.StructureIdentification.Classification {
  /// <summary>
  /// Operator that prepares an n-fold cross-validation run.
  /// For every fold it creates one sub-scope containing a rotated copy of the
  /// original dataset together with new training, validation and test ranges.
  /// </summary>
  /// <remarks>
  /// The test partition always consists of the last <c>Rows / nFolds</c> rows of the
  /// rotated dataset; the remaining rows are split into training and validation
  /// partitions using the same training/validation ratio as the original ranges.
  /// If <c>Rows</c> is not a multiple of the number of folds, <c>Rows % nFolds</c> rows
  /// never end up in any test partition.
  /// </remarks>
  public class CrossValidation : OperatorBase {

    // Formal variable names registered with the operator.
    private const string DATASET = "Dataset";
    private const string NFOLD = "n-Fold";
    private const string TRAININGSAMPLESSTART = "TrainingSamplesStart";
    private const string TRAININGSAMPLESEND = "TrainingSamplesEnd";
    private const string VALIDATIONSAMPLESSTART = "ValidationSamplesStart";
    private const string VALIDATIONSAMPLESEND = "ValidationSamplesEnd";
    private const string TESTSAMPLESSTART = "TestSamplesStart";
    private const string TESTSAMPLESEND = "TestSamplesEnd";

    /// <summary>
    /// Gets the textual description of the operator.
    /// </summary>
    public override string Description {
      // NOTE(review): "TASK" looks like a placeholder that was never filled in; kept
      // unchanged because the text is visible at runtime.
      get { return @"TASK"; }
    }

    /// <summary>
    /// Registers the variable infos for the input dataset, the number of folds and
    /// the sample ranges that are read from the scope and written to the sub-scopes.
    /// </summary>
    public CrossValidation()
      : base() {
      // NOTE(review): the dataset is also added to every new sub-scope in Apply, so
      // VariableKind.In | VariableKind.New (as used for the ranges below) may be the
      // intended kind - confirm before changing.
      AddVariableInfo(new VariableInfo(DATASET, "The original dataset and the new datasets in the newly created subscopes", typeof(Dataset), VariableKind.In));
      AddVariableInfo(new VariableInfo(NFOLD, "Number of folds for the cross-validation", typeof(IntData), VariableKind.In));
      AddVariableInfo(new VariableInfo(TRAININGSAMPLESSTART, "The start of training samples in the original dataset and starts of training samples in the new datasets", typeof(IntData), VariableKind.In | VariableKind.New));
      AddVariableInfo(new VariableInfo(TRAININGSAMPLESEND, "The end of training samples in the original dataset and ends of training samples in the new datasets", typeof(IntData), VariableKind.In | VariableKind.New));
      AddVariableInfo(new VariableInfo(VALIDATIONSAMPLESSTART, "The start of validation samples in the original dataset and starts of validation samples in the new datasets", typeof(IntData), VariableKind.In | VariableKind.New));
      AddVariableInfo(new VariableInfo(VALIDATIONSAMPLESEND, "The end of validation samples in the original dataset and ends of validation samples in the new datasets", typeof(IntData), VariableKind.In | VariableKind.New));
      AddVariableInfo(new VariableInfo(TESTSAMPLESSTART, "The start of the test samples in the new datasets", typeof(IntData), VariableKind.New));
      AddVariableInfo(new VariableInfo(TESTSAMPLESEND, "The end of the test samples in the new datasets", typeof(IntData), VariableKind.New));
    }

    /// <summary>
    /// Creates one sub-scope per fold. Each sub-scope receives a rotated copy of the
    /// dataset and the new training, validation and test sample ranges.
    /// </summary>
    /// <param name="scope">The scope from which the inputs are read and to which the sub-scopes are added.</param>
    /// <returns>Always <c>null</c>.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown when the number of folds is smaller than 2 or larger than the number of
    /// rows, or when the original training/validation ranges are inverted or both empty.
    /// </exception>
    public override IOperation Apply(IScope scope) {
      Dataset origDataset = GetVariableValue<Dataset>(DATASET, scope, true);
      int nFolds = GetVariableValue<IntData>(NFOLD, scope, true).Data;
      if (nFolds < 2) throw new ArgumentException("The number of folds (nFolds) has to be >=2 for cross validation");
      int origTrainingSamplesStart = GetVariableValue<IntData>(TRAININGSAMPLESSTART, scope, true).Data;
      int origTrainingSamplesEnd = GetVariableValue<IntData>(TRAININGSAMPLESEND, scope, true).Data;
      int origValidationSamplesStart = GetVariableValue<IntData>(VALIDATIONSAMPLESSTART, scope, true).Data;
      int origValidationSamplesEnd = GetVariableValue<IntData>(VALIDATIONSAMPLESEND, scope, true).Data;
      int n = origDataset.Rows;
      // With more folds than rows the test partition would be empty and all folds identical.
      if (nFolds > n) throw new ArgumentException("The number of folds (" + nFolds + ") must not exceed the number of rows in the dataset (" + n + ")");

      int origTrainingSamples = (origTrainingSamplesEnd - origTrainingSamplesStart);
      int origValidationSamples = (origValidationSamplesEnd - origValidationSamplesStart);
      if (origTrainingSamples < 0 || origValidationSamples < 0)
        throw new ArgumentException("The end of the training and validation samples must not lie before their start");
      // Guards the division below: 0 / 0.0 yields NaN and casting NaN to int gives an unspecified value.
      if (origTrainingSamples + origValidationSamples == 0)
        throw new ArgumentException("The original training and validation partitions must not both be empty");

      // Fraction of the non-test rows that is used for training (the rest is used for validation).
      double percentTrainingSamples = origTrainingSamples / (double)(origValidationSamples + origTrainingSamples);
      int nTestSamples = n / nFolds;

      // The new ranges are identical for all folds; only the dataset rotation differs.
      int newTrainingSamplesStart = 0;
      int newTrainingSamplesEnd = (int)((n - nTestSamples) * percentTrainingSamples);
      int newValidationSamplesStart = newTrainingSamplesEnd;
      int newValidationSamplesEnd = n - nTestSamples;
      int newTestSamplesStart = n - nTestSamples;
      int newTestSamplesEnd = n;

      // Number of sample values moved per fold.
      // NOTE(review): assumes Samples stores the rows contiguously (Rows * Columns values) - confirm.
      int valuesPerFold = nTestSamples * origDataset.Columns;

      for (int i = 0; i < nFolds; i++) {
        Scope childScope = new Scope(i.ToString());
        Dataset rotatedSet = new Dataset();

        // Work on a copy so that the original dataset is left untouched.
        double[] samples = new double[origDataset.Samples.Length];
        Array.Copy(origDataset.Samples, samples, samples.Length);
        RotateArray(samples, i * valuesPerFold);

        rotatedSet.Rows = origDataset.Rows;
        rotatedSet.Columns = origDataset.Columns;
        rotatedSet.Samples = samples;
        childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(DATASET), rotatedSet));
        childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(TRAININGSAMPLESSTART), new IntData(newTrainingSamplesStart)));
        childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(TRAININGSAMPLESEND), new IntData(newTrainingSamplesEnd)));
        childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(VALIDATIONSAMPLESSTART), new IntData(newValidationSamplesStart)));
        childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(VALIDATIONSAMPLESEND), new IntData(newValidationSamplesEnd)));
        childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(TESTSAMPLESSTART), new IntData(newTestSamplesStart)));
        childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(TESTSAMPLESEND), new IntData(newTestSamplesEnd)));

        scope.AddSubScope(childScope);
      }
      return null;
    }

    /// <summary>
    /// Rotates <paramref name="samples"/> in place to the left by <paramref name="p"/>
    /// elements, i.e. the element at index <paramref name="p"/> becomes the first element.
    /// </summary>
    /// <param name="samples">The array that is rotated in place.</param>
    /// <param name="p">The number of positions to rotate; must lie in [0, samples.Length].</param>
    private static void RotateArray(double[] samples, int p) {
      // A rotation by zero positions leaves the array unchanged.
      if (p == 0) return;
      // Reversing both parts separately and then the whole array yields the rotation.
      Array.Reverse(samples, 0, p);
      Array.Reverse(samples, p, samples.Length - p);
      Array.Reverse(samples);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.