Free cookie consent management tool by TermsFeed Policy Generator

source: branches/CEDMA-Exporter-715/sources/HeuristicLab.Modeling/3.2/VariableImpactCalculatorBase.cs @ 4021

Last change on this file since 4021 was 2043, checked in by gkronber, 16 years ago

Added variable impact calculation operators for support vector machines. #644 (Variable impact of CEDMA models should be calculated and stored in the result DB)

File size: 5.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Text;
25using System.Xml;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.DataAnalysis;
29using System.Linq;
30
31namespace HeuristicLab.Modeling {
/// <summary>
/// Base class for operators that calculate the impact of every allowed input variable on a model.
/// A reference value is calculated on the unmodified dataset; then, one variable at a time, the
/// values of that variable in the training range are replaced by their mean, the value is
/// calculated again and both values are combined into an impact by <see cref="CalculateImpact"/>.
/// The result is stored in the scope as a list of (variable name, impact) rows.
/// </summary>
/// <typeparam name="T">Type of the value returned by <see cref="CalculateValue"/> and compared by <see cref="CalculateImpact"/>.</typeparam>
public abstract class VariableImpactCalculatorBase<T> : OperatorBase {
  // Written by Abort() and polled inside the loop of Apply(); volatile so that a write made
  // while Apply() is running (presumably from another thread) is observed by the loop.
  private volatile bool abortRequested = false;

  /// <summary>
  /// Gets a short description of the operator.
  /// </summary>
  public override string Description {
    get { return @"Calculates the impact of all allowed input variables on the model."; }
  }

  /// <summary>
  /// Gets the name of the variable that receives the calculated impacts.
  /// NOTE: this property is read by the base class constructor, i.e. before the constructor of
  /// the derived class has run, so implementations must not depend on derived instance state.
  /// </summary>
  public abstract string OutputVariableName { get; }

  /// <summary>
  /// Requests that a running <see cref="Apply"/> stops after the variable it is currently processing.
  /// </summary>
  public override void Abort() {
    abortRequested = true;
  }

  /// <summary>
  /// Registers the input variables and the output variable (named <see cref="OutputVariableName"/>) of the operator.
  /// </summary>
  public VariableImpactCalculatorBase()
    : base() {
    AddVariableInfo(new VariableInfo("Dataset", "Dataset", typeof(Dataset), VariableKind.In));
    AddVariableInfo(new VariableInfo("TargetVariable", "TargetVariable", typeof(IntData), VariableKind.In));
    AddVariableInfo(new VariableInfo("AllowedFeatures", "Indexes of allowed input variables", typeof(ItemList<IntData>), VariableKind.In));
    AddVariableInfo(new VariableInfo("TrainingSamplesStart", "TrainingSamplesStart", typeof(IntData), VariableKind.In));
    AddVariableInfo(new VariableInfo("TrainingSamplesEnd", "TrainingSamplesEnd", typeof(IntData), VariableKind.In));
    AddVariableInfo(new VariableInfo(OutputVariableName, OutputVariableName, typeof(ItemList), VariableKind.New));
  }

  /// <summary>
  /// Calculates the impact of each allowed input variable and adds the result list to the scope.
  /// </summary>
  /// <param name="scope">Scope from which the inputs are read and to which the result variable is added.</param>
  /// <returns>
  /// <c>null</c> when the calculation completed; otherwise (abort requested) a new operation that
  /// applies this operator to the same scope again, so the calculation can be repeated later.
  /// </returns>
  public override IOperation Apply(IScope scope) {
    // Clear a stale request from a previous run. Without this reset, the operation returned
    // after an abort would skip the loop and return itself again on every subsequent execution.
    abortRequested = false;

    ItemList<IntData> allowedFeatures = GetVariableValue<ItemList<IntData>>("AllowedFeatures", scope, true);
    int targetVariable = GetVariableValue<IntData>("TargetVariable", scope, true).Data;
    Dataset dataset = GetVariableValue<Dataset>("Dataset", scope, true);
    // All manipulations are done on a clone so the dataset stored in the scope is never modified.
    Dataset dirtyDataset = (Dataset)dataset.Clone();
    int start = GetVariableValue<IntData>("TrainingSamplesStart", scope, true).Data;
    int end = GetVariableValue<IntData>("TrainingSamplesEnd", scope, true).Data;

    T referenceValue = CalculateValue(scope, dataset, targetVariable, allowedFeatures, start, end);
    double[] impacts = new double[allowedFeatures.Count];

    for (int i = 0; i < allowedFeatures.Count && !abortRequested; i++) {
      int currentVariable = allowedFeatures[i].Data;
      // Replace the variable by its mean, evaluate, then restore the original values
      // so that only one variable is manipulated at a time.
      var oldValues = ReplaceVariableValues(dirtyDataset, currentVariable, CalculateNewValues(dirtyDataset, currentVariable, start, end), start, end);
      T newValue = CalculateValue(scope, dirtyDataset, targetVariable, allowedFeatures, start, end);
      impacts[i] = CalculateImpact(referenceValue, newValue);
      ReplaceVariableValues(dirtyDataset, currentVariable, oldValues, start, end);
    }

    if (!abortRequested) {
      impacts = PostProcessImpacts(impacts);

      // One row per allowed variable: [variable name, impact].
      ItemList variableImpacts = new ItemList();
      for (int i = 0; i < allowedFeatures.Count; i++) {
        int currentVariable = allowedFeatures[i].Data;
        ItemList row = new ItemList();
        row.Add(new StringData(dataset.GetVariableName(currentVariable)));
        row.Add(new DoubleData(impacts[i]));
        variableImpacts.Add(row);
      }

      scope.AddVariable(new Variable(scope.TranslateName(OutputVariableName), variableImpacts));
      return null;
    } else {
      // Partial results are discarded; the operator is scheduled again for the same scope.
      return new AtomicOperation(this, scope);
    }
  }

  /// <summary>
  /// Calculates the value (of type <typeparamref name="T"/>) for the given dataset;
  /// called once for the unmodified dataset and once per manipulated variable.
  /// </summary>
  protected abstract T CalculateValue(IScope scope, Dataset dataset, int targetVariable, ItemList<IntData> allowedFeatures, int start, int end);

  /// <summary>
  /// Combines the value of the unmodified dataset and the value of a manipulated dataset into an impact.
  /// </summary>
  protected abstract double CalculateImpact(T referenceValue, T newValue);

  /// <summary>
  /// Hook to transform the raw impacts before they are stored; the default returns them unchanged.
  /// </summary>
  protected virtual double[] PostProcessImpacts(double[] impacts) {
    return impacts;
  }

  /// <summary>
  /// Overwrites the values of one variable in rows <paramref name="start"/> (inclusive) to
  /// <paramref name="end"/> (exclusive) and returns the values that were stored there before.
  /// </summary>
  /// <exception cref="ArgumentException">Thrown when <paramref name="newValues"/> does not contain exactly end - start elements.</exception>
  private IEnumerable<double> ReplaceVariableValues(Dataset ds, int variableIndex, IEnumerable<double> newValues, int start, int end) {
    int length = end - start;
    // Materialize once: the sequence may be lazy and is needed for both the length check and the copy.
    double[] replacement = newValues.ToArray();
    if (replacement.Length != length) throw new ArgumentException("The length of the new values sequence doesn't match the required length (number of replaced values)");

    double[] oldValues = new double[length];
    for (int i = 0; i < length; i++) oldValues[i] = ds.GetValue(i + start, variableIndex);

    // Suppress per-value change events while writing; always re-enable them,
    // even when SetValue throws, so the dataset is not left with events switched off.
    ds.FireChangeEvents = false;
    try {
      for (int i = 0; i < length; i++) {
        ds.SetValue(start + i, variableIndex, replacement[i]);
      }
    } finally {
      ds.FireChangeEvents = true;
    }
    ds.FireChanged();
    return oldValues;
  }

  /// <summary>
  /// Creates the replacement values for a variable: the value returned by
  /// <c>ds.GetMean(variableIndex, start, end)</c>, repeated end - start times.
  /// </summary>
  private IEnumerable<double> CalculateNewValues(Dataset ds, int variableIndex, int start, int end) {
    double mean = ds.GetMean(variableIndex, start, end);
    return Enumerable.Repeat(mean, end - start);
  }
}
121}
Note: See TracBrowser for help on using the repository browser.