Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Persistence Test/HeuristicLab.Modeling/3.2/ProblemInjector.cs @ 3928

Last change on this file since 3928 was 2440, checked in by gkronber, 15 years ago

Fixed #784 (ProblemInjector should be changed to read variable names instead of indexes for input and target variables)

File size: 8.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Text;
25using System.Xml;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.DataAnalysis;
29using System.Linq;
30
31namespace HeuristicLab.Modeling {
32  public class ProblemInjector : OperatorBase {
33    public override string Description {
34      get { return @"Injects the necessary variables for a data-based modeling problem."; }
35    }
36
37    public ProblemInjector()
38      : base() {
39      AddVariableInfo(new VariableInfo("Dataset", "Dataset", typeof(Dataset), VariableKind.New));
40      GetVariableInfo("Dataset").Local = true;
41      AddVariable(new Variable("Dataset", new Dataset()));
42
43      AddVariableInfo(new VariableInfo("TargetVariable", "TargetVariable", typeof(StringData), VariableKind.New));
44      GetVariableInfo("TargetVariable").Local = true;
45      AddVariable(new Variable("TargetVariable", new StringData()));
46
47      AddVariableInfo(new VariableInfo("AllowedFeatures", "Indexes of allowed input variables", typeof(ItemList<StringData>), VariableKind.In));
48      GetVariableInfo("AllowedFeatures").Local = true;
49      AddVariable(new Variable("AllowedFeatures", new ItemList<StringData>()));
50
51      AddVariableInfo(new VariableInfo("TrainingSamplesStart", "TrainingSamplesStart", typeof(IntData), VariableKind.New));
52      GetVariableInfo("TrainingSamplesStart").Local = true;
53      AddVariable(new Variable("TrainingSamplesStart", new IntData()));
54
55      AddVariableInfo(new VariableInfo("TrainingSamplesEnd", "TrainingSamplesEnd", typeof(IntData), VariableKind.New));
56      GetVariableInfo("TrainingSamplesEnd").Local = true;
57      AddVariable(new Variable("TrainingSamplesEnd", new IntData()));
58
59      AddVariableInfo(new VariableInfo("ActualTrainingSamplesStart", "ActualTrainingSamplesStart", typeof(IntData), VariableKind.New));
60      AddVariableInfo(new VariableInfo("ActualTrainingSamplesEnd", "ActualTrainingSamplesEnd", typeof(IntData), VariableKind.New));
61
62      AddVariableInfo(new VariableInfo("ValidationSamplesStart", "ValidationSamplesStart", typeof(IntData), VariableKind.New));
63      GetVariableInfo("ValidationSamplesStart").Local = true;
64      AddVariable(new Variable("ValidationSamplesStart", new IntData()));
65
66      AddVariableInfo(new VariableInfo("ValidationSamplesEnd", "ValidationSamplesEnd", typeof(IntData), VariableKind.New));
67      GetVariableInfo("ValidationSamplesEnd").Local = true;
68      AddVariable(new Variable("ValidationSamplesEnd", new IntData()));
69
70      AddVariableInfo(new VariableInfo("TestSamplesStart", "TestSamplesStart", typeof(IntData), VariableKind.New));
71      GetVariableInfo("TestSamplesStart").Local = true;
72      AddVariable(new Variable("TestSamplesStart", new IntData()));
73
74      AddVariableInfo(new VariableInfo("TestSamplesEnd", "TestSamplesEnd", typeof(IntData), VariableKind.New));
75      GetVariableInfo("TestSamplesEnd").Local = true;
76      AddVariable(new Variable("TestSamplesEnd", new IntData()));
77
78      AddVariableInfo(new VariableInfo("MaxNumberOfTrainingSamples", "Maximal number of training samples to use (optional)", typeof(IntData), VariableKind.In));
79      AddVariableInfo(new VariableInfo("NumberOfInputVariables", "The number of available input variables", typeof(IntData), VariableKind.New));
80      AddVariableInfo(new VariableInfo("InputVariables", "List of input variable names", typeof(ItemList), VariableKind.New));
81    }
82
83    public override IView CreateView() {
84      return new ProblemInjectorView(this);
85    }
86
87    public override IOperation Apply(IScope scope) {
88      AddVariableToScope("TrainingSamplesStart", scope);
89      AddVariableToScope("TrainingSamplesEnd", scope);
90      AddVariableToScope("ValidationSamplesStart", scope);
91      AddVariableToScope("ValidationSamplesEnd", scope);
92      AddVariableToScope("TestSamplesStart", scope);
93      AddVariableToScope("TestSamplesEnd", scope);
94
95      Dataset operatorDataset = (Dataset)GetVariable("Dataset").Value;
96      string targetVariable = ((StringData)GetVariable("TargetVariable").Value).Data;
97      ItemList<StringData> operatorAllowedFeatures = (ItemList<StringData>)GetVariable("AllowedFeatures").Value;
98
99      Dataset scopeDataset = CreateNewDataset(operatorDataset, targetVariable, operatorAllowedFeatures);
100      ItemList inputVariables = new ItemList();
101      for (int i = 1; i < scopeDataset.Columns; i++) {
102        inputVariables.Add(new StringData(scopeDataset.GetVariableName(i)));
103      }
104
105      scope.AddVariable(new Variable(scope.TranslateName("Dataset"), scopeDataset));
106      scope.AddVariable(new Variable(scope.TranslateName("TargetVariable"), new StringData(targetVariable)));
107      scope.AddVariable(new Variable(scope.TranslateName("NumberOfInputVariables"), new IntData(scopeDataset.Columns - 1)));
108      scope.AddVariable(new Variable(scope.TranslateName("InputVariables"), inputVariables));
109
110      int trainingStart = GetVariableValue<IntData>("TrainingSamplesStart", scope, true).Data;
111      int trainingEnd = GetVariableValue<IntData>("TrainingSamplesEnd", scope, true).Data;
112
113      var maxTraining = GetVariableValue<IntData>("MaxNumberOfTrainingSamples", scope, true, false);
114      int nTrainingSamples;
115      if (maxTraining != null) {
116        nTrainingSamples = Math.Min(maxTraining.Data, trainingEnd - trainingStart);
117        if (nTrainingSamples <= 0)
118          throw new ArgumentException("Maximal number of training samples must be larger than 0", "MaxNumberOfTrainingSamples");
119      } else {
120        nTrainingSamples = trainingEnd - trainingStart;
121      }
122      scope.AddVariable(new Variable(scope.TranslateName("ActualTrainingSamplesStart"), new IntData(trainingStart)));
123      scope.AddVariable(new Variable(scope.TranslateName("ActualTrainingSamplesEnd"), new IntData(trainingStart + nTrainingSamples)));
124
125
126      return null;
127    }
128
129    private Dataset CreateNewDataset(Dataset operatorDataset, string targetVariable, ItemList<StringData> operatorAllowedVariables) {
130      int columns = (operatorAllowedVariables.Count() + 1);
131      int rows = operatorDataset.Rows;
132      double[] values = new double[rows * columns];
133      int targetVariableIndex = operatorDataset.GetVariableIndex(targetVariable);
134      for (int row = 0; row < rows; row++) {
135        int column = 0;
136        values[row*columns + column] = operatorDataset.GetValue(row, targetVariableIndex); // set target variable value to column index 0
137        column++; // start input variables at column index 1
138        foreach (var inputVariable in operatorAllowedVariables) {
139          int variableColumnIndex = operatorDataset.GetVariableIndex(inputVariable.Data);
140          values[row * columns + column] = operatorDataset.GetValue(row, variableColumnIndex);
141          column++;
142        }
143      }
144
145      Dataset ds = new Dataset();
146      ds.Columns = columns;
147      ds.Rows = operatorDataset.Rows;
148      ds.Name = operatorDataset.Name;
149      ds.Samples = values;
150      double[] scalingFactor = new double[columns];
151      double[] scalingOffset = new double[columns];
152      ds.SetVariableName(0, targetVariable);
153      scalingFactor[0] = operatorDataset.ScalingFactor[targetVariableIndex];
154      scalingOffset[0] = operatorDataset.ScalingOffset[targetVariableIndex];
155      for (int column = 1; column < columns; column++) {
156        int variableColumnIndex = operatorDataset.GetVariableIndex(operatorAllowedVariables[column - 1].Data);
157        ds.SetVariableName(column, operatorAllowedVariables[column - 1].Data);
158        scalingFactor[column] = operatorDataset.ScalingFactor[variableColumnIndex];
159        scalingOffset[column] = operatorDataset.ScalingOffset[variableColumnIndex];
160      }
161      ds.ScalingOffset = scalingOffset;
162      ds.ScalingFactor = scalingFactor;
163      return ds;
164    }
165
166    private void AddVariableToScope(string variableName, IScope scope) {
167      scope.AddVariable(new Variable(scope.TranslateName(variableName), (IItem)GetVariable(variableName).Value.Clone()));     
168    }
169  }
170}
Note: See TracBrowser for help on using the repository browser.