Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Views/3.3/RunCollectionMonteCarloVariableImpactView.cs @ 4892

Last change on this file since 4892 was 4475, checked in by gkronber, 14 years ago

Fixed bugs in time series prognosis classes #1142.

File size: 9.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using System.Windows.Forms;
25using alglib;
26using HeuristicLab.Common;
27using HeuristicLab.Data;
28using HeuristicLab.MainForm;
29using HeuristicLab.MainForm.WindowsForms;
30using HeuristicLab.Optimization;
31using System;
32using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
34using HeuristicLab.Problems.DataAnalysis.Evaluators;
35
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Shows, for every run of a <see cref="RunCollection"/> that stores a result named
  /// "Best solution (on validation set)", an approximate permutation impact of each variable
  /// referenced by that solution's model. The per-run columns are followed by per-variable
  /// summary columns (median, mean, standard deviation and a Wilcoxon signed-rank p-value).
  /// </summary>
  [Content(typeof(RunCollection), false)]
  [View("RunCollection Monte-Carlo Variable Impact View")]
  public partial class RunCollectionMonteCarloVariableImpactView : AsynchronousContentView {
    private const string validationBestModelResultName = "Best solution (on validation set)";

    public RunCollectionMonteCarloVariableImpactView() {
      InitializeComponent();
    }

    public new RunCollection Content {
      get { return (RunCollection)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>Subscribes to collection changes so the matrix is recomputed when runs change.</summary>
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      this.Content.ItemsAdded += Content_ItemsAdded;
      this.Content.ItemsRemoved += Content_ItemsRemoved;
      this.Content.CollectionReset += Content_CollectionReset;
    }

    /// <summary>Removes the subscriptions made in <see cref="RegisterContentEvents"/>.</summary>
    protected override void DeregisterContentEvents() {
      // Must call the base *deregistration*; calling base.RegisterContentEvents() here would
      // re-subscribe the base handlers every time the content is replaced.
      base.DeregisterContentEvents();
      this.Content.ItemsAdded -= Content_ItemsAdded;
      this.Content.ItemsRemoved -= Content_ItemsRemoved;
      this.Content.CollectionReset -= Content_CollectionReset;
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      this.UpdateData();
    }
    private void Content_ItemsAdded(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      this.UpdateData();
    }
    private void Content_ItemsRemoved(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      this.UpdateData();
    }
    private void Content_CollectionReset(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      this.UpdateData();
    }

    /// <summary>Recomputes the impact matrix and hands it to the embedded matrix view (null when there is no content).</summary>
    private void UpdateData() {
      matrixView.Content = CalculateVariableImpactMatrix();
    }

    /// <summary>
    /// Builds the variable impact matrix.
    /// Rows: distinct variable names referenced by any of the solutions.
    /// Columns: one per run that contains a best-on-validation solution (in run order), followed by
    /// "Median Impact", "Mean Impact", "StdDev" and "pValue".
    /// </summary>
    /// <returns>The matrix, or null when <see cref="Content"/> is null.</returns>
    private DoubleMatrix CalculateVariableImpactMatrix() {
      if (Content == null) return null;

      // Keep runs and their solutions in two parallel, ordered lists so that matrix column i
      // always belongs to run i (dictionary key order is not guaranteed and would also merge runs
      // that share a solution instance).
      List<IRun> runsWithSolutions = (from run in Content
                                      where run.Results.ContainsKey(validationBestModelResultName)
                                      select run).ToList();
      List<SymbolicRegressionSolution> solutions = (from run in runsWithSolutions
                                                    select (SymbolicRegressionSolution)run.Results[validationBestModelResultName])
                                                   .ToList();
      List<string[]> variableReferences = (from solution in solutions
                                           select GetVariableReferences(solution).Distinct().ToArray())
                                          .ToList();

      List<string> variableNames = (from modelVarRefs in variableReferences
                                    from variableName in modelVarRefs
                                    select variableName)
                                   .Distinct()
                                   .ToList();

      List<string> statistics = new List<string> { "Median Impact", "Mean Impact", "StdDev", "pValue" };
      List<string> columnNames = (from run in runsWithSolutions
                                  select run.Name).ToList();
      columnNames.AddRange(statistics);

      DoubleMatrix matrix = new DoubleMatrix(variableNames.Count, columnNames.Count);
      matrix.SortableView = true;
      matrix.RowNames = variableNames;
      matrix.ColumnNames = columnNames;

      int runCount = solutions.Count;
      Random random = new Random();
      for (int columnIndex = 0; columnIndex < runCount; columnIndex++) {
        // Every name in variableReferences is contained in variableNames by construction.
        // Variables not referenced by a model keep the default impact 0.
        foreach (string variableName in variableReferences[columnIndex]) {
          int rowIndex = variableNames.IndexOf(variableName);
          matrix[rowIndex, columnIndex] = ApproximatePermutationImpact(random, variableName, solutions[columnIndex]);
        }
      }

      // Without any variable (e.g. an empty run collection) there is nothing to summarize, and
      // selecting the reference row below would throw on an empty sequence.
      if (variableNames.Count == 0) return matrix;

      // Only the per-run columns feed the statistics; the trailing statistics columns are still
      // zero at this point and must not be included.
      List<List<double>> variableImpactValues = (from row in Enumerable.Range(0, variableNames.Count)
                                                 select GetRowValues(matrix, row).Take(runCount).ToList())
                                                .ToList();
      // The variable with the lowest mean impact serves as the reference for the paired test.
      List<double> referenceValues = (from variableImpacts in variableImpactValues
                                      orderby variableImpacts.Average()
                                      select variableImpacts)
                                     .First();

      int statisticsColumn = runCount;
      for (int row = 0; row < variableNames.Count; row++) {
        List<double> rowValues = variableImpactValues[row];
        matrix[row, statisticsColumn] = rowValues.Median();
        matrix[row, statisticsColumn + 1] = rowValues.Average();
        matrix[row, statisticsColumn + 2] = rowValues.StandardDeviation();

        double bothTails, leftTail, rightTail;
        bothTails = leftTail = rightTail = 0.0;
        // Paired differences to the reference variable, one per run.
        double[] z = new double[rowValues.Count];
        for (int i = 0; i < z.Length; i++) {
          z[i] = rowValues[i] - referenceValues[i];
        }
        alglib.wsr.wilcoxonsignedranktest(z, z.Length, 0.0, ref bothTails, ref leftTail, ref rightTail);
        matrix[row, statisticsColumn + 3] = bothTails;
      }
      return matrix;
    }

    /// <summary>Lazily enumerates all values of one matrix row, from column 0 to the last column.</summary>
    private IEnumerable<double> GetRowValues(DoubleMatrix matrix, int row) {
      return from col in Enumerable.Range(0, matrix.Columns)
             select matrix[row, col];
    }

    /// <summary>Returns the variable name of every variable node in the model tree (duplicates included).</summary>
    private IEnumerable<string> GetVariableReferences(SymbolicRegressionSolution solution) {
      return from node in solution.Model.SymbolicExpressionTree.IterateNodesPostfix().OfType<VariableTreeNode>()
             select node.VariableName;
    }

    /// <summary>
    /// Approximates the impact of <paramref name="variableName"/> on <paramref name="solution"/> as the
    /// mean squared difference between the original model output and the output obtained after
    /// randomly permuting the rows of that variable only, averaged over a fixed number of permutations.
    /// </summary>
    /// <remarks>
    /// Temporarily assigns manipulated datasets to <c>solution.ProblemData.Dataset</c> and always
    /// restores the original one afterwards, even if evaluation throws. Relies on
    /// <c>solution.EstimatedValues</c> reflecting the currently assigned dataset — NOTE(review): confirm.
    /// </remarks>
    private double ApproximatePermutationImpact(Random random, string variableName, SymbolicRegressionSolution solution) {
      const int permutations = 10;
      // Every permutation is built from this untouched dataset; reading from
      // solution.ProblemData.Dataset inside the loop would compound the previous permutation.
      Dataset originalDataset = solution.ProblemData.Dataset;
      int variableIndex = originalDataset.GetVariableIndex(variableName);
      List<double> originalOutput = new List<double>(solution.EstimatedValues);

      int rows = originalDataset.Rows;
      int columns = originalDataset.Columns;
      List<int> rowIndexPermutation = Enumerable.Range(0, rows).ToList();
      double mseSum = 0.0;
      try {
        for (int rep = 0; rep < permutations; rep++) {
          double[,] manipulatedData = new double[rows, columns];
          Shuffle(random, rowIndexPermutation);
          for (int row = 0; row < rows; row++) {
            for (int column = 0; column < columns; column++) {
              // Shuffle only the analysed variable; every other column keeps its original order.
              int sourceRow = column == variableIndex ? rowIndexPermutation[row] : row;
              manipulatedData[row, column] = originalDataset[sourceRow, column];
            }
          }

          solution.ProblemData.Dataset = new Dataset(originalDataset.VariableNames, manipulatedData);
          mseSum += SimpleMSEEvaluator.Calculate(originalOutput, solution.EstimatedValues);
        }
      }
      finally {
        solution.ProblemData.Dataset = originalDataset;
      }
      return mseSum / permutations;
    }

    /// <summary>In-place Fisher–Yates shuffle of <paramref name="xs"/> using <paramref name="random"/>.</summary>
    private void Shuffle(Random random, List<int> xs) {
      for (int i = xs.Count; i > 1; i--) {
        int j = random.Next(i);
        int tmp = xs[j];
        xs[j] = xs[i - 1];
        xs[i - 1] = tmp;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.