Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Modeling/3.2/ProblemInjector.cs @ 3494

Last change on this file since 3494 was 2855, checked in by gkronber, 15 years ago

Fixed a minor bug in the problem injector. #886 (ProblemInjector doesn't work correctly if actual names of local variables are changed)

File size: 8.7 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Text;
25using System.Xml;
[1856]26using HeuristicLab.Core;
[645]27using HeuristicLab.Data;
28using HeuristicLab.DataAnalysis;
[2162]29using System.Linq;
[645]30
namespace HeuristicLab.Modeling {
  /// <summary>
  /// Operator that injects the variables of a data-based modeling problem into a scope:
  /// a dataset reduced to the target and the allowed input variables, the target variable
  /// name, the list of input variable names and the sample partition boundaries.
  /// </summary>
  public class ProblemInjector : OperatorBase {
    public override string Description {
      get { return @"Injects the necessary variables for a data-based modeling problem."; }
    }

    /// <summary>
    /// Registers the variable infos of this operator. Variables whose info is marked
    /// <c>Local</c> are additionally stored in the operator itself with an empty default value.
    /// </summary>
    public ProblemInjector()
      : base() {
      // Problem data held locally by the operator.
      AddVariableInfo(new VariableInfo("Dataset", "Dataset", typeof(Dataset), VariableKind.New));
      GetVariableInfo("Dataset").Local = true;
      AddVariable(new Variable("Dataset", new Dataset()));

      AddVariableInfo(new VariableInfo("TargetVariable", "TargetVariable", typeof(StringData), VariableKind.New));
      GetVariableInfo("TargetVariable").Local = true;
      AddVariable(new Variable("TargetVariable", new StringData()));

      // The entries are variable names: they are resolved through Dataset.GetVariableIndex in CreateNewDataset.
      AddVariableInfo(new VariableInfo("AllowedFeatures", "Names of allowed input variables", typeof(ItemList<StringData>), VariableKind.In));
      GetVariableInfo("AllowedFeatures").Local = true;
      AddVariable(new Variable("AllowedFeatures", new ItemList<StringData>()));

      // Training partition.
      AddVariableInfo(new VariableInfo("TrainingSamplesStart", "TrainingSamplesStart", typeof(IntData), VariableKind.New));
      GetVariableInfo("TrainingSamplesStart").Local = true;
      AddVariable(new Variable("TrainingSamplesStart", new IntData()));

      AddVariableInfo(new VariableInfo("TrainingSamplesEnd", "TrainingSamplesEnd", typeof(IntData), VariableKind.New));
      GetVariableInfo("TrainingSamplesEnd").Local = true;
      AddVariable(new Variable("TrainingSamplesEnd", new IntData()));

      // Computed in Apply; not stored locally in the operator.
      AddVariableInfo(new VariableInfo("ActualTrainingSamplesStart", "ActualTrainingSamplesStart", typeof(IntData), VariableKind.New));
      AddVariableInfo(new VariableInfo("ActualTrainingSamplesEnd", "ActualTrainingSamplesEnd", typeof(IntData), VariableKind.New));

      // Validation partition.
      AddVariableInfo(new VariableInfo("ValidationSamplesStart", "ValidationSamplesStart", typeof(IntData), VariableKind.New));
      GetVariableInfo("ValidationSamplesStart").Local = true;
      AddVariable(new Variable("ValidationSamplesStart", new IntData()));

      AddVariableInfo(new VariableInfo("ValidationSamplesEnd", "ValidationSamplesEnd", typeof(IntData), VariableKind.New));
      GetVariableInfo("ValidationSamplesEnd").Local = true;
      AddVariable(new Variable("ValidationSamplesEnd", new IntData()));

      // Test partition.
      AddVariableInfo(new VariableInfo("TestSamplesStart", "TestSamplesStart", typeof(IntData), VariableKind.New));
      GetVariableInfo("TestSamplesStart").Local = true;
      AddVariable(new Variable("TestSamplesStart", new IntData()));

      AddVariableInfo(new VariableInfo("TestSamplesEnd", "TestSamplesEnd", typeof(IntData), VariableKind.New));
      GetVariableInfo("TestSamplesEnd").Local = true;
      AddVariable(new Variable("TestSamplesEnd", new IntData()));

      // Optional input and values derived in Apply.
      AddVariableInfo(new VariableInfo("MaxNumberOfTrainingSamples", "Maximal number of training samples to use (optional)", typeof(IntData), VariableKind.In));
      AddVariableInfo(new VariableInfo("NumberOfInputVariables", "The number of available input variables", typeof(IntData), VariableKind.New));
      AddVariableInfo(new VariableInfo("InputVariables", "List of input variable names", typeof(ItemList), VariableKind.New));
    }

    /// <summary>Creates the view used to edit this operator.</summary>
    public override IView CreateView() {
      return new ProblemInjectorView(this);
    }

    /// <summary>
    /// Adds the problem variables to <paramref name="scope"/>: clones of the six partition
    /// boundaries, the reduced dataset, the target variable name, the number and names of the
    /// input variables and the actual training range.
    /// </summary>
    /// <param name="scope">Scope that receives the injected variables.</param>
    /// <returns>Always <c>null</c>.</returns>
    /// <exception cref="ArgumentException">
    /// MaxNumberOfTrainingSamples is available and either it or the training range
    /// (TrainingSamplesEnd - TrainingSamplesStart) is not larger than 0.
    /// </exception>
    public override IOperation Apply(IScope scope) {
      AddVariableToScope("TrainingSamplesStart", scope);
      AddVariableToScope("TrainingSamplesEnd", scope);
      AddVariableToScope("ValidationSamplesStart", scope);
      AddVariableToScope("ValidationSamplesEnd", scope);
      AddVariableToScope("TestSamplesStart", scope);
      AddVariableToScope("TestSamplesEnd", scope);

      Dataset operatorDataset = (Dataset)GetVariable("Dataset").Value;
      string targetVariable = ((StringData)GetVariable("TargetVariable").Value).Data;
      ItemList<StringData> operatorAllowedFeatures = (ItemList<StringData>)GetVariable("AllowedFeatures").Value;

      Dataset scopeDataset = CreateNewDataset(operatorDataset, targetVariable, operatorAllowedFeatures);

      // Column 0 of the reduced dataset is the target variable; the inputs start at column 1.
      ItemList inputVariables = new ItemList();
      for (int i = 1; i < scopeDataset.Columns; i++) {
        inputVariables.Add(new StringData(scopeDataset.GetVariableName(i)));
      }

      scope.AddVariable(new Variable(scope.TranslateName("Dataset"), scopeDataset));
      scope.AddVariable(new Variable(scope.TranslateName("TargetVariable"), new StringData(targetVariable)));
      scope.AddVariable(new Variable(scope.TranslateName("NumberOfInputVariables"), new IntData(scopeDataset.Columns - 1)));
      scope.AddVariable(new Variable(scope.TranslateName("InputVariables"), inputVariables));

      int trainingStart = ((IntData)GetVariable("TrainingSamplesStart").Value).Data;
      int trainingEnd = ((IntData)GetVariable("TrainingSamplesEnd").Value).Data;
      int availableSamples = trainingEnd - trainingStart;

      // MaxNumberOfTrainingSamples is optional: a null result means no limit was supplied.
      var maxTraining = GetVariableValue<IntData>("MaxNumberOfTrainingSamples", scope, true, false);
      int nTrainingSamples;
      if (maxTraining != null) {
        // Report which of the two inputs is actually invalid instead of always blaming the limit.
        if (maxTraining.Data <= 0) {
          throw new ArgumentException("Maximal number of training samples must be larger than 0", "MaxNumberOfTrainingSamples");
        }
        if (availableSamples <= 0) {
          throw new ArgumentException("TrainingSamplesEnd must be larger than TrainingSamplesStart", "TrainingSamplesEnd");
        }
        nTrainingSamples = Math.Min(maxTraining.Data, availableSamples);
      } else {
        nTrainingSamples = availableSamples;
      }
      scope.AddVariable(new Variable(scope.TranslateName("ActualTrainingSamplesStart"), new IntData(trainingStart)));
      scope.AddVariable(new Variable(scope.TranslateName("ActualTrainingSamplesEnd"), new IntData(trainingStart + nTrainingSamples)));

      return null;
    }

    /// <summary>
    /// Builds a new dataset that contains the target variable in column 0 followed by the
    /// allowed input variables in list order, together with their names, scaling factors
    /// and scaling offsets. All rows of <paramref name="operatorDataset"/> are copied.
    /// </summary>
    /// <param name="operatorDataset">Dataset the values are copied from.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="operatorAllowedVariables">Names of the input variables to copy.</param>
    /// <returns>The reduced dataset.</returns>
    private Dataset CreateNewDataset(Dataset operatorDataset, string targetVariable, ItemList<StringData> operatorAllowedVariables) {
      int columns = (operatorAllowedVariables.Count() + 1);
      int rows = operatorDataset.Rows;

      // Resolve every variable name to its source column once, instead of once per row
      // and again for the scaling parameters.
      string[] variableNames = new string[columns];
      int[] sourceColumns = new int[columns];
      variableNames[0] = targetVariable;
      sourceColumns[0] = operatorDataset.GetVariableIndex(targetVariable);
      int nextColumn = 1; // input variables start at column index 1
      foreach (var inputVariable in operatorAllowedVariables) {
        variableNames[nextColumn] = inputVariable.Data;
        sourceColumns[nextColumn] = operatorDataset.GetVariableIndex(inputVariable.Data);
        nextColumn++;
      }

      // Values are stored row by row in one flat array.
      double[] values = new double[rows * columns];
      for (int row = 0; row < rows; row++) {
        for (int column = 0; column < columns; column++) {
          values[row * columns + column] = operatorDataset.GetValue(row, sourceColumns[column]);
        }
      }

      Dataset ds = new Dataset();
      ds.Columns = columns;
      ds.Rows = rows;
      ds.Name = operatorDataset.Name;
      ds.Samples = values;
      double[] scalingFactor = new double[columns];
      double[] scalingOffset = new double[columns];
      for (int column = 0; column < columns; column++) {
        ds.SetVariableName(column, variableNames[column]);
        scalingFactor[column] = operatorDataset.ScalingFactor[sourceColumns[column]];
        scalingOffset[column] = operatorDataset.ScalingOffset[sourceColumns[column]];
      }
      ds.ScalingOffset = scalingOffset;
      ds.ScalingFactor = scalingFactor;
      return ds;
    }

    /// <summary>
    /// Adds a clone of the value of the operator's variable <paramref name="variableName"/>
    /// to <paramref name="scope"/> under the name returned by <c>scope.TranslateName</c>.
    /// </summary>
    private void AddVariableToScope(string variableName, IScope scope) {
      scope.AddVariable(new Variable(scope.TranslateName(variableName), (IItem)GetVariable(variableName).Value.Clone()));
    }
  }
}
Note: See TracBrowser for help on using the repository browser.