Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Views/3.3/RunCollectionMonteCarloVariableImpactView.cs @ 11155

Last change on this file since 11155 was 5275, checked in by gkronber, 14 years ago

Merged changes from trunk to data analysis exploration branch and added fractional distance metric evaluator. #1142

File size: 9.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using System.Windows.Forms;
25using HeuristicLab.Common;
26using HeuristicLab.Data;
27using HeuristicLab.MainForm;
28using HeuristicLab.MainForm.WindowsForms;
29using HeuristicLab.Optimization;
30using System;
31using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
33using HeuristicLab.Problems.DataAnalysis.Evaluators;
34
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Displays, for every run of a <see cref="RunCollection"/> that stores a
  /// <see cref="SymbolicRegressionSolution"/> under the result name
  /// "Best solution (on validation set)", a permutation-based impact estimate for each
  /// variable referenced by that solution's model, followed by per-variable summary
  /// statistics (median, mean, standard deviation, Wilcoxon signed-rank p-value).
  /// </summary>
  [Content(typeof(RunCollection), false)]
  [View("RunCollection Monte-Carlo Variable Impact View")]
  public partial class RunCollectionMonteCarloVariableImpactView : AsynchronousContentView {
    private const string validationBestModelResultName = "Best solution (on validation set)";
    /// <summary>Number of random permutations whose MSE is averaged per impact estimate.</summary>
    private const int permutationRepetitions = 10;

    public RunCollectionMonteCarloVariableImpactView() {
      InitializeComponent();
    }

    public new RunCollection Content {
      get { return (RunCollection)base.Content; }
      set { base.Content = value; }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ItemsAdded += Content_ItemsAdded;
      Content.ItemsRemoved += Content_ItemsRemoved;
      Content.CollectionReset += Content_CollectionReset;
    }
    protected override void DeregisterContentEvents() {
      // Must call the base *de*registration; calling base.RegisterContentEvents() here
      // would re-subscribe the base class handlers every time the content is replaced.
      base.DeregisterContentEvents();
      Content.ItemsAdded -= Content_ItemsAdded;
      Content.ItemsRemoved -= Content_ItemsRemoved;
      Content.CollectionReset -= Content_CollectionReset;
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      UpdateData();
    }
    private void Content_ItemsAdded(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      UpdateData();
    }
    private void Content_ItemsRemoved(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      UpdateData();
    }
    private void Content_CollectionReset(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      UpdateData();
    }

    /// <summary>Recomputes the impact matrix and shows it in the embedded matrix view.</summary>
    private void UpdateData() {
      matrixView.Content = CalculateVariableImpactMatrix();
    }

    /// <summary>
    /// Builds the variable impact matrix for the current content.
    /// Rows are the distinct variables referenced by any solution. Columns are one per run
    /// that has a symbolic regression solution, followed by the statistic columns.
    /// A variable not referenced by a run's model has impact 0 in that run's column.
    /// </summary>
    /// <returns>The matrix, or <c>null</c> when <see cref="Content"/> is <c>null</c>.</returns>
    public DoubleMatrix CalculateVariableImpactMatrix() {
      if (Content == null) return null;

      // Pair every run with its solution so that matrix column i always belongs to run i.
      // Runs whose result is missing or of another type are skipped instead of throwing.
      var runsWithSolutions = (from run in Content
                               let solution = GetValidationBestSolution(run)
                               where solution != null
                               select new { Run = run, Solution = solution })
                              .ToList();
      int runCount = runsWithSolutions.Count;

      List<List<string>> variableReferences = (from entry in runsWithSolutions
                                               select GetVariableReferences(entry.Solution).Distinct().ToList())
                                              .ToList();
      List<string> variableNames = variableReferences
                                   .SelectMany(references => references)
                                   .Distinct()
                                   .ToList();

      string[] statisticNames = { "Median Impact", "Mean Impact", "StdDev", "pValue" };
      List<string> columnNames = (from entry in runsWithSolutions
                                  select entry.Run.Name).ToList();
      columnNames.AddRange(statisticNames);

      DoubleMatrix matrix = new DoubleMatrix(variableNames.Count, columnNames.Count);
      matrix.SortableView = true;
      matrix.RowNames = variableNames;
      matrix.ColumnNames = columnNames;

      // Nothing to analyse (e.g. empty run collection): the reference-row selection below
      // would otherwise call First() on an empty sequence.
      if (variableNames.Count == 0) return matrix;

      Random random = new Random();
      for (int columnIndex = 0; columnIndex < runCount; columnIndex++) {
        SymbolicRegressionSolution solution = runsWithSolutions[columnIndex].Solution;
        foreach (string variableName in variableReferences[columnIndex]) {
          int rowIndex = variableNames.IndexOf(variableName);
          matrix[rowIndex, columnIndex] = ApproximatePermutationImpact(random, variableName, solution);
        }
      }

      // Only the run columns hold impacts; the statistic columns are still unset here
      // and must not be folded into the statistics.
      List<List<double>> variableImpactValues = (from row in Enumerable.Range(0, variableNames.Count)
                                                 select GetRowValues(matrix, row, runCount).ToList())
                                                .ToList();
      // The variable with the smallest mean impact serves as the baseline for the significance test.
      List<double> referenceValues = (from variableImpacts in variableImpactValues
                                      orderby variableImpacts.Average()
                                      select variableImpacts)
                                     .First();

      for (int row = 0; row < variableNames.Count; row++) {
        List<double> rowValues = variableImpactValues[row];
        matrix[row, runCount] = rowValues.Median();
        matrix[row, runCount + 1] = rowValues.Average();
        matrix[row, runCount + 2] = rowValues.StandardDeviation();

        // Paired differences to the baseline variable, tested against a median of 0.
        double[] differences = new double[rowValues.Count];
        for (int i = 0; i < differences.Length; i++) {
          differences[i] = rowValues[i] - referenceValues[i];
        }
        double bothTails = 0.0, leftTail = 0.0, rightTail = 0.0;
        alglib.wsr.wilcoxonsignedranktest(differences, differences.Length, 0.0, ref bothTails, ref leftTail, ref rightTail);
        matrix[row, runCount + 3] = bothTails;
      }
      return matrix;
    }

    /// <summary>
    /// Returns the run's validation-best result if it is a <see cref="SymbolicRegressionSolution"/>,
    /// otherwise <c>null</c>.
    /// </summary>
    private static SymbolicRegressionSolution GetValidationBestSolution(IRun run) {
      if (!run.Results.ContainsKey(validationBestModelResultName)) return null;
      return run.Results[validationBestModelResultName] as SymbolicRegressionSolution;
    }

    /// <summary>Reads the first <paramref name="columnCount"/> values of the given matrix row.</summary>
    private static IEnumerable<double> GetRowValues(DoubleMatrix matrix, int row, int columnCount) {
      return from col in Enumerable.Range(0, columnCount)
             select matrix[row, col];
    }

    /// <summary>Names of all variable nodes in the solution's expression tree (may contain duplicates).</summary>
    private IEnumerable<string> GetVariableReferences(SymbolicRegressionSolution solution) {
      return from node in solution.Model.SymbolicExpressionTree.IterateNodesPostfix().OfType<VariableTreeNode>()
             select node.VariableName;
    }

    /// <summary>
    /// Estimates the impact of <paramref name="variableName"/> on <paramref name="solution"/>.
    /// It randomly permutes only that variable's column and measures the MSE between the
    /// original and the resulting model output. The result is averaged over
    /// <see cref="permutationRepetitions"/> permutations.
    /// The solution's dataset is temporarily replaced and always restored, even on failure.
    /// </summary>
    private double ApproximatePermutationImpact(Random random, string variableName, SymbolicRegressionSolution solution) {
      Dataset originalDataset = solution.ProblemData.Dataset;
      int variableIndex = originalDataset.GetVariableIndex(variableName);
      List<double> originalOutput = new List<double>(solution.EstimatedValues);

      int rows = originalDataset.Rows;
      int columns = originalDataset.Columns;
      List<int> rowIndexPermutation = Enumerable.Range(0, rows).ToList();
      double mseSum = 0.0;
      try {
        for (int rep = 0; rep < permutationRepetitions; rep++) {
          Shuffle(random, rowIndexPermutation);
          double[,] manipulatedData = new double[rows, columns];
          for (int row = 0; row < rows; row++) {
            for (int column = 0; column < columns; column++) {
              // Only the analysed variable is shuffled; every other column keeps its row.
              // Always read from the original dataset, not from the previously manipulated one.
              int sourceRow = column == variableIndex ? rowIndexPermutation[row] : row;
              manipulatedData[row, column] = originalDataset[sourceRow, column];
            }
          }

          solution.ProblemData.Dataset = new Dataset(originalDataset.VariableNames, manipulatedData);
          mseSum += SimpleMSEEvaluator.Calculate(originalOutput, solution.EstimatedValues);
        }
      } finally {
        solution.ProblemData.Dataset = originalDataset;
      }
      return mseSum / permutationRepetitions;
    }

    /// <summary>In-place Fisher-Yates shuffle of <paramref name="xs"/>.</summary>
    private static void Shuffle(Random random, List<int> xs) {
      for (int i = xs.Count; i > 1; i--) {
        int j = random.Next(i);
        int tmp = xs[j];
        xs[j] = xs[i - 1];
        xs[i - 1] = tmp;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.