Free cookie consent management tool by TermsFeed Policy Generator

source: branches/plugins/HeuristicLab.GP.StructureIdentification.Classification/3.2/CrossValidation.cs @ 3494

Last change on this file since 3494 was 708, checked in by gkronber, 16 years ago

added a check for the value of nFolds to fix #313 (CrossValidation operator doesn't work with fold=1)

File size: 6.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Text;
25using System.Xml;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.DataAnalysis;
29
namespace HeuristicLab.GP.StructureIdentification.Classification
{
    /// <summary>
    /// Prepares an n-fold cross-validation. For every fold a sub-scope is added to the current scope that holds
    /// a rotated copy of the dataset together with the bounds of its training, validation and test partitions.
    /// </summary>
    /// <remarks>
    /// All rows of the original dataset take part in the cross-validation; the original training/validation
    /// bounds are only used to derive the ratio between training and validation samples. The rows are cut into
    /// nFolds contiguous blocks whose sizes differ by at most one, so that every row is a test sample in exactly
    /// one fold. In each sub-scope the fold's test block is moved to the end of the rotated dataset.
    /// </remarks>
    public class CrossValidation : OperatorBase
    {
        private const string DATASET = "Dataset";
        private const string NFOLD = "n-Fold";
        private const string TRAININGSAMPLESSTART = "TrainingSamplesStart";
        private const string TRAININGSAMPLESEND = "TrainingSamplesEnd";
        private const string VALIDATIONSAMPLESSTART = "ValidationSamplesStart";
        private const string VALIDATIONSAMPLESEND = "ValidationSamplesEnd";
        private const string TESTSAMPLESSTART = "TestSamplesStart";
        private const string TESTSAMPLESEND = "TestSamplesEnd";

        /// <summary>Human-readable description of the operator.</summary>
        public override string Description
        {
            get
            {
                return @"Splits the dataset into n folds for cross-validation. For each fold a sub-scope is created that contains a rotated copy of the dataset in which the samples of the fold form the test partition at the end of the dataset. The remaining samples are split into a training and a validation partition with the same ratio as the training and validation partitions of the original dataset.";
            }
        }

        /// <summary>
        /// Creates the operator and declares the variables it reads from the scope and writes into the new sub-scopes.
        /// </summary>
        public CrossValidation()
            : base()
        {
            // NOTE(review): the dataset is also written into the sub-scopes but is declared only as
            // VariableKind.In, unlike the partition bounds (In | New) -- confirm whether this is intended.
            AddVariableInfo(new VariableInfo(DATASET, "The original dataset and the new datasets in the newly created subscopes", typeof(Dataset), VariableKind.In));
            AddVariableInfo(new VariableInfo(NFOLD, "Number of folds for the cross-validation", typeof(IntData), VariableKind.In));
            AddVariableInfo(new VariableInfo(TRAININGSAMPLESSTART, "The start of training samples in the original dataset and starts of training samples in the new datasets", typeof(IntData), VariableKind.In | VariableKind.New));
            AddVariableInfo(new VariableInfo(TRAININGSAMPLESEND, "The end of training samples in the original dataset and ends of training samples in the new datasets", typeof(IntData), VariableKind.In | VariableKind.New));
            AddVariableInfo(new VariableInfo(VALIDATIONSAMPLESSTART, "The start of validation samples in the original dataset and starts of validation samples in the new datasets", typeof(IntData), VariableKind.In | VariableKind.New));
            AddVariableInfo(new VariableInfo(VALIDATIONSAMPLESEND, "The end of validation samples in the original dataset and ends of validation samples in the new datasets", typeof(IntData), VariableKind.In | VariableKind.New));
            AddVariableInfo(new VariableInfo(TESTSAMPLESSTART, "The start of the test samples in the new datasets", typeof(IntData), VariableKind.New));
            AddVariableInfo(new VariableInfo(TESTSAMPLESEND, "The end of the test samples in the new datasets", typeof(IntData), VariableKind.New));
        }

        /// <summary>
        /// Adds one sub-scope per fold to <paramref name="scope"/>. Each sub-scope is named after the fold index and
        /// contains a rotated copy of the dataset plus the training, validation and test bounds for that fold.
        /// </summary>
        /// <param name="scope">The scope holding the dataset, the number of folds and the original partition bounds.</param>
        /// <returns>Always <c>null</c>.</returns>
        /// <exception cref="ArgumentException">
        /// Thrown when nFolds is smaller than 2 or larger than the number of rows, or when the original training and
        /// validation partitions have a negative size or are both empty.
        /// </exception>
        public override IOperation Apply(IScope scope)
        {
            Dataset origDataset = GetVariableValue<Dataset>(DATASET, scope, true);
            int nFolds = GetVariableValue<IntData>(NFOLD, scope, true).Data;
            if (nFolds < 2) throw new ArgumentException("The number of folds (nFolds) has to be >=2 for cross validation");
            int origTrainingSamplesStart = GetVariableValue<IntData>(TRAININGSAMPLESSTART, scope, true).Data;
            int origTrainingSamplesEnd = GetVariableValue<IntData>(TRAININGSAMPLESEND, scope, true).Data;
            int origValidationSamplesStart = GetVariableValue<IntData>(VALIDATIONSAMPLESSTART, scope, true).Data;
            int origValidationSamplesEnd = GetVariableValue<IntData>(VALIDATIONSAMPLESEND, scope, true).Data;

            int n = origDataset.Rows;
            // With more folds than rows some folds would get an empty test partition.
            if (nFolds > n) throw new ArgumentException("The number of folds (nFolds) must not exceed the number of rows of the dataset");

            int origTrainingSamples = origTrainingSamplesEnd - origTrainingSamplesStart;
            int origValidationSamples = origValidationSamplesEnd - origValidationSamplesStart;
            // Otherwise the training ratio below is NaN or meaningless and the computed bounds are garbage.
            if (origTrainingSamples < 0 || origValidationSamples < 0 || origTrainingSamples + origValidationSamples == 0)
                throw new ArgumentException("The training and validation partitions must not have a negative size and must not both be empty");

            double percentTrainingSamples = origTrainingSamples / (double)(origValidationSamples + origTrainingSamples);

            // The n rows are cut into nFolds contiguous blocks; the first (n % nFolds) blocks get one extra row so
            // that every row belongs to exactly one block and is therefore tested in exactly one fold.
            int baseFoldSize = n / nFolds;
            int remainder = n % nFolds;
            int columns = origDataset.Columns;

            for (int i = 0; i < nFolds; i++)
            {
                // Fold 0 tests the last block and fold i > 0 tests block i - 1.
                int block = (i + nFolds - 1) % nFolds;
                int blockStart = block * baseFoldSize + Math.Min(block, remainder);
                int nTestSamples = baseFoldSize + (block < remainder ? 1 : 0);
                int blockEnd = blockStart + nTestSamples;

                // Rotating left by blockEnd rows moves the test block to the end of the dataset:
                // rotated row (n - nTestSamples) is original row (n - nTestSamples + blockEnd) mod n == blockStart.
                // For the last block blockEnd == n, i.e. no rotation.
                int rotatedRows = blockEnd % n;
                double[] samples = CopyRotatedLeft(origDataset.Samples, rotatedRows * columns);

                Dataset rotatedSet = new Dataset();
                rotatedSet.Rows = origDataset.Rows;
                rotatedSet.Columns = origDataset.Columns;
                rotatedSet.Samples = samples;

                // Rows [0, n - nTestSamples) are split into training and validation in the original ratio;
                // the test block occupies [n - nTestSamples, n).
                int newTestSamplesStart = n - nTestSamples;
                int newTestSamplesEnd = n;
                int newTrainingSamplesStart = 0;
                int newTrainingSamplesEnd = (int)(newTestSamplesStart * percentTrainingSamples);
                int newValidationSamplesStart = newTrainingSamplesEnd;
                int newValidationSamplesEnd = newTestSamplesStart;

                Scope childScope = new Scope(i.ToString());
                childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(DATASET), rotatedSet));
                childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(TRAININGSAMPLESSTART), new IntData(newTrainingSamplesStart)));
                childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(TRAININGSAMPLESEND), new IntData(newTrainingSamplesEnd)));
                childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(VALIDATIONSAMPLESSTART), new IntData(newValidationSamplesStart)));
                childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(VALIDATIONSAMPLESEND), new IntData(newValidationSamplesEnd)));
                childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(TESTSAMPLESSTART), new IntData(newTestSamplesStart)));
                childScope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(TESTSAMPLESEND), new IntData(newTestSamplesEnd)));

                scope.AddSubScope(childScope);
            }
            return null;
        }

        /// <summary>
        /// Returns a new array containing <paramref name="source"/> rotated left by <paramref name="offset"/> elements,
        /// i.e. <c>result[j] == source[(j + offset) % source.Length]</c>. The source array is not modified.
        /// </summary>
        /// <param name="source">The array to copy.</param>
        /// <param name="offset">Number of elements to rotate by; must be in [0, source.Length].</param>
        private static double[] CopyRotatedLeft(double[] source, int offset)
        {
            double[] result = new double[source.Length];
            int tailLength = source.Length - offset;
            Array.Copy(source, offset, result, 0, tailLength);
            Array.Copy(source, 0, result, tailLength, offset);
            return result;
        }
    }
}
Note: See TracBrowser for help on using the repository browser.