Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Views/3.4/RunCollectionVariableImpactView.cs @ 6766

Last change on this file since 6766 was 6766, checked in by gkronber, 13 years ago

#1635 added possibility to calculate impacts over all folds for cross validation runs.

File size: 13.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using HeuristicLab.Common;
27using HeuristicLab.Data;
28using HeuristicLab.MainForm;
29using HeuristicLab.MainForm.WindowsForms;
30using HeuristicLab.Optimization;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Views {
33  [Content(typeof(RunCollection), false)]
34  [View("Variable Impacts")]
35  public sealed partial class RunCollectionVariableImpactView : AsynchronousContentView {
// Keys used below to look up entries in a run's Results and Parameters dictionaries.
private const string variableImpactResultName = "Variable impacts";
private const string crossValidationFoldsResultName = "CrossValidation Folds";
private const string numberOfFoldsParameterName = "Folds";

/// <summary>
/// Creates the view and initializes its controls via <c>InitializeComponent</c>.
/// </summary>
public RunCollectionVariableImpactView()
  : base() {
  InitializeComponent();
}
42
/// <summary>
/// Strongly typed accessor for the displayed content; it hides the base
/// <c>Content</c> property and casts it to <see cref="RunCollection"/>.
/// </summary>
public new RunCollection Content {
  get {
    return (RunCollection)base.Content;
  }
  set {
    base.Content = value;
  }
}
47
48    #region events
/// <summary>
/// Subscribes to the collection-level events of the current content and to the
/// <c>Changed</c> event of every run it currently contains.
/// </summary>
protected override void RegisterContentEvents() {
  base.RegisterContentEvents();

  RunCollection runCollection = Content;
  runCollection.UpdateOfRunsInProgressChanged += Content_UpdateOfRunsInProgressChanged;
  runCollection.ItemsAdded += Content_ItemsAdded;
  runCollection.ItemsRemoved += Content_ItemsRemoved;
  runCollection.CollectionReset += Content_CollectionReset;

  RegisterRunEvents(runCollection);
}
/// <summary>
/// Unsubscribes from everything <c>RegisterContentEvents</c> subscribed to: the
/// collection-level events of the current content and the <c>Changed</c> event
/// of every run it contains.
/// </summary>
protected override void DeregisterContentEvents() {
  // Must call the base *de*registration here. Calling base.RegisterContentEvents()
  // instead would attach the base class handlers a second time on every content switch.
  base.DeregisterContentEvents();
  Content.UpdateOfRunsInProgressChanged -= new EventHandler(Content_UpdateOfRunsInProgressChanged);
  Content.ItemsAdded -= new HeuristicLab.Collections.CollectionItemsChangedEventHandler<IRun>(Content_ItemsAdded);
  Content.ItemsRemoved -= new HeuristicLab.Collections.CollectionItemsChangedEventHandler<IRun>(Content_ItemsRemoved);
  Content.CollectionReset -= new HeuristicLab.Collections.CollectionItemsChangedEventHandler<IRun>(Content_CollectionReset);
  DeregisterRunEvents(Content);
}
/// <summary>
/// Attaches <c>Run_Changed</c> to the <c>Changed</c> event of each given run.
/// </summary>
/// <param name="runs">The runs to observe.</param>
private void RegisterRunEvents(IEnumerable<IRun> runs) {
  foreach (IRun observedRun in runs) {
    observedRun.Changed += Run_Changed;
  }
}
/// <summary>
/// Detaches <c>Run_Changed</c> from the <c>Changed</c> event of each given run.
/// </summary>
/// <param name="runs">The runs to stop observing.</param>
private void DeregisterRunEvents(IEnumerable<IRun> runs) {
  foreach (IRun observedRun in runs) {
    observedRun.Changed -= Run_Changed;
  }
}
/// <summary>
/// Starts observing the runs that were added to the collection and refreshes the view.
/// </summary>
private void Content_ItemsAdded(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
  IEnumerable<IRun> addedRuns = e.Items;
  RegisterRunEvents(addedRuns);
  UpdateData();
}
/// <summary>
/// Stops observing the runs that were removed from the collection and refreshes the view.
/// </summary>
private void Content_ItemsRemoved(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
  IEnumerable<IRun> removedRuns = e.Items;
  DeregisterRunEvents(removedRuns);
  UpdateData();
}
/// <summary>
/// Handles a reset of the collection: stops observing the old runs, starts
/// observing the new ones, then refreshes the view.
/// </summary>
private void Content_CollectionReset(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
  IEnumerable<IRun> previousRuns = e.OldItems;
  IEnumerable<IRun> currentRuns = e.Items;
  DeregisterRunEvents(previousRuns);
  RegisterRunEvents(currentRuns);
  UpdateData();
}
/// <summary>
/// Refreshes the view once the collection reports that no update of runs is in progress.
/// </summary>
private void Content_UpdateOfRunsInProgressChanged(object sender, EventArgs e) {
  if (Content.UpdateOfRunsInProgress) {
    return;
  }
  UpdateData();
}
/// <summary>
/// Refreshes the view when a single run changes, unless the collection reports
/// that an update of runs is currently in progress.
/// </summary>
private void Run_Changed(object sender, EventArgs e) {
  if (Content.UpdateOfRunsInProgress) {
    return;
  }
  UpdateData();
}
92    #endregion
93
/// <summary>
/// Runs the base content-changed handling and then recomputes the displayed data.
/// </summary>
protected override void OnContentChanged() {
  base.OnContentChanged();
  UpdateData();
}
98
/// <summary>
/// Recomputes the impact matrix for the selection in the combo box.
/// Index 0 ("Overall") uses every fold of every visible cross-validation run;
/// any other index k uses fold k-1 of each visible cross-validation run, with
/// the columns named after the cross-validation runs.
/// </summary>
private void comboBox_SelectedValueChanged(object sender, EventArgs e) {
  // Content can be null while the combo box still holds items, so guard against it.
  if (Content == null || comboBox.SelectedItem == null) return;

  // Materialize once; the run list is needed for both the folds and the column names.
  IRun[] cvRuns = Content
    .Where(r => r.Visible && r.Parameters.ContainsKey(numberOfFoldsParameterName))
    .ToArray();

  if (comboBox.SelectedIndex == 0) {
    IRun[] allFolds = cvRuns
      .SelectMany(r => (RunCollection)r.Results[crossValidationFoldsResultName])
      .ToArray();
    matrixView.Content = CalculateVariableImpactMatrix(allFolds);
  } else {
    // entry 0 is "Overall", so the fold index is shifted by one
    int foldIndex = comboBox.SelectedIndex - 1;
    IRun[] selectedFolds = (from r in cvRuns
                            let foldCollection = (RunCollection)r.Results[crossValidationFoldsResultName]
                            select (IRun)foldCollection.ElementAt(foldIndex).Clone())
                           .ToArray();
    string[] cvRunNames = cvRuns.Select(r => r.Name).ToArray();
    matrixView.Content = CalculateVariableImpactMatrix(selectedFolds, cvRunNames);
  }
}
116
117
/// <summary>
/// Rebuilds the combo box and the impact matrix from the visible runs of the content.
/// <list type="bullet">
/// <item>No content: the combo box is emptied and the matrix view is cleared.</item>
/// <item>Visible runs with a "Folds" parameter exist and all agree on the number of
/// folds: the combo box offers "Overall" plus one entry per fold, and the matrix
/// is computed over all folds.</item>
/// <item>Such runs exist but disagree on the number of folds: the matrix view is cleared.</item>
/// <item>Otherwise: the matrix is computed from the visible runs that have a
/// "Variable impacts" result.</item>
/// </list>
/// </summary>
private void UpdateData() {
  comboBox.Items.Clear();
  comboBox.Enabled = false;

  if (Content == null) {
    // do not keep showing the matrix of the previous content
    matrixView.Content = null;
    return;
  }

  IRun[] visibleRuns = Content.Where(r => r.Visible).ToArray();
  IRun[] cvRuns = visibleRuns
    .Where(r => r.Parameters.ContainsKey(numberOfFoldsParameterName))
    .ToArray();

  if (cvRuns.Length == 0) {
    IRun[] runsWithVariables = visibleRuns
      .Where(r => r.Results.ContainsKey(variableImpactResultName))
      .ToArray();
    matrixView.Content = CalculateVariableImpactMatrix(runsWithVariables);
    return;
  }

  // make sure all runs have the same number of folds (the first run is the reference)
  int nFolds = ((IntValue)cvRuns[0].Parameters[numberOfFoldsParameterName]).Value;
  bool sameNumberOfFolds = cvRuns.All(r => ((IntValue)r.Parameters[numberOfFoldsParameterName]).Value == nFolds);
  if (!sameNumberOfFolds) {
    matrixView.Content = null;
    return;
  }

  // populate combobox
  comboBox.Items.Add("Overall");
  for (int foldIndex = 0; foldIndex < nFolds; foldIndex++) {
    comboBox.Items.Add("Fold " + foldIndex);
  }
  comboBox.SelectedIndex = 0;
  comboBox.Enabled = true;

  IRun[] selectedFolds = cvRuns
    .SelectMany(r => (RunCollection)r.Results[crossValidationFoldsResultName])
    .ToArray();
  matrixView.Content = CalculateVariableImpactMatrix(selectedFolds);
}
148
/// <summary>
/// Computes the variable impact matrix for the given runs, using each run's
/// <c>Name</c> as the title of its column.
/// </summary>
/// <param name="runs">Runs that provide a "Variable impacts" result.</param>
/// <returns>The matrix produced by the two-argument overload.</returns>
private IStringConvertibleMatrix CalculateVariableImpactMatrix(IRun[] runs) {
  string[] runNames = Array.ConvertAll(runs, run => run.Name);
  return CalculateVariableImpactMatrix(runs, runNames);
}
152
/// <summary>
/// Builds a matrix with one row per variable and one column per run, followed by
/// the statistics columns "Median Rank", "Mean", "StdDev" and "pValue".
/// Variables whose impact is (almost) zero in every run are left out. The returned
/// matrix has its rows, together with their row names, ordered by ascending median rank.
/// </summary>
/// <param name="runs">Runs whose "Variable impacts" result (a <see cref="DoubleMatrix"/>) is read.</param>
/// <param name="runNames">Column titles for the runs, in the same order as <paramref name="runs"/>.</param>
/// <returns>The sorted impact matrix.</returns>
private DoubleMatrix CalculateVariableImpactMatrix(IRun[] runs, string[] runNames) {
  // Materialize once: the helpers below enumerate this sequence once per variable, and a
  // deferred query would repeat the result lookup and the cast every time.
  List<DoubleMatrix> allVariableImpacts = (from run in runs
                                           select run.Results[variableImpactResultName])
                                          .Cast<DoubleMatrix>()
                                          .ToList();
  IEnumerable<string> variableNames = (from variableImpact in allVariableImpacts
                                       from variableName in variableImpact.RowNames
                                       select variableName)
                                      .Distinct();
  // filter variableNames: only include names that have at least one non-zero value in a run
  List<string> variableNamesList = (from variableName in variableNames
                                    where GetVariableImpacts(variableName, allVariableImpacts).Any(x => !x.IsAlmost(0.0))
                                    select variableName)
                                   .ToList();

  List<string> statistics = new List<string> { "Median Rank", "Mean", "StdDev", "pValue" };
  List<string> columnNames = new List<string>(runNames);
  columnNames.AddRange(statistics);
  int numberOfRuns = runs.Length;

  DoubleMatrix matrix = new DoubleMatrix(variableNamesList.Count, numberOfRuns + statistics.Count);
  matrix.SortableView = true;
  matrix.RowNames = variableNamesList;
  matrix.ColumnNames = columnNames;

  // calculate statistics
  List<List<double>> variableImpactsOverRuns = (from variableName in variableNamesList
                                                select GetVariableImpacts(variableName, allVariableImpacts).ToList())
                                               .ToList();
  List<List<double>> variableRanks = (from variableName in variableNamesList
                                      select GetVariableImpactRanks(variableName, allVariableImpacts).ToList())
                                     .ToList();
  if (variableImpactsOverRuns.Count > 0) {
    // the variable with the worst median impact value is chosen as the reference variable
    // this is problematic if all variables are relevant, however works often in practice
    List<double> referenceImpacts = (from impacts in variableImpactsOverRuns
                                     let median = impacts.Median()
                                     orderby median
                                     select impacts)
                                    .First();
    // for all variables
    for (int row = 0; row < variableImpactsOverRuns.Count; row++) {
      List<double> currentImpacts = variableImpactsOverRuns[row];
      // median rank
      matrix[row, numberOfRuns] = variableRanks[row].Median();
      // also show mean and std.dev. of relative variable impacts to indicate the relative difference in impacts of variables
      matrix[row, numberOfRuns + 1] = Math.Round(currentImpacts.Average(), 3);
      matrix[row, numberOfRuns + 2] = Math.Round(currentImpacts.StandardDeviation(), 3);

      double leftTail = 0; double rightTail = 0; double bothTails = 0;
      // calc differences of impacts for current variable and reference variable.
      // A variable that is missing in some runs has fewer values than the reference;
      // only pair up as many values as both lists have, instead of indexing past the end.
      // NOTE(review): in that case the pairs may stem from different runs — confirm whether
      // runs with differing variable sets need a run-aligned pairing here.
      int pairCount = Math.Min(currentImpacts.Count, referenceImpacts.Count);
      double[] z = new double[pairCount];
      for (int i = 0; i < z.Length; i++) {
        z[i] = currentImpacts[i] - referenceImpacts[i];
      }
      // wilcoxon signed rank test is used because the impact values of two variables in a single run are not independent
      alglib.wsr.wilcoxonsignedranktest(z, z.Length, 0, ref bothTails, ref leftTail, ref rightTail);
      matrix[row, numberOfRuns + 3] = Math.Round(bothTails, 4);
    }
  }

  // fill matrix with impacts from runs
  for (int i = 0; i < allVariableImpacts.Count; i++) {
    DoubleMatrix runVariableImpacts = allVariableImpacts[i];
    // copy the row names once instead of walking the sequence again for every comparison
    string[] runRowNames = runVariableImpacts.RowNames.ToArray();
    for (int j = 0; j < runVariableImpacts.Rows && j < runRowNames.Length; j++) {
      int rowIndex = variableNamesList.IndexOf(runRowNames[j]);
      if (rowIndex > -1) {
        matrix[rowIndex, i] = Math.Round(runVariableImpacts[j, 0], 3);
      }
    }
  }

  // sort by median
  DoubleMatrix sortedMatrix = (DoubleMatrix)matrix.Clone();
  int[] sortedIndexes = (from i in Enumerable.Range(0, matrix.Rows)
                         orderby matrix[i, numberOfRuns]
                         select i)
                        .ToArray();

  for (int targetIndex = 0; targetIndex < sortedIndexes.Length; targetIndex++) {
    int sourceIndex = sortedIndexes[targetIndex];
    for (int c = 0; c < matrix.Columns; c++)
      sortedMatrix[targetIndex, c] = matrix[sourceIndex, c];
  }
  // the clone still carries the unsorted row names; reorder them together with the rows
  // so that every label stays attached to its own data
  sortedMatrix.RowNames = sortedIndexes.Select(i => variableNamesList[i]).ToList();
  return sortedMatrix;
}
237
/// <summary>
/// Lazily yields, for each impact matrix in which <paramref name="variableName"/>
/// occurs as a row name, the rank of that variable within the matrix. Ranks are
/// 1-based, rank 1 belongs to the largest value in column 0, and values that are
/// almost equal to the first value of their group share that value's rank.
/// </summary>
/// <param name="variableName">Row name of the variable to rank.</param>
/// <param name="allVariableImpacts">One impact matrix per run; column 0 is read.</param>
private IEnumerable<double> GetVariableImpactRanks(string variableName, IEnumerable<DoubleMatrix> allVariableImpacts) {
  foreach (DoubleMatrix impactMatrix in allVariableImpacts) {
    // certainly not yet very efficient because ranks are computed multiple times for the same run
    string[] sortedNames = impactMatrix.RowNames.ToArray();

    // negate the impacts so that the ascending sort places the largest impact first
    double[] negatedImpacts = new double[impactMatrix.Rows];
    for (int row = 0; row < negatedImpacts.Length; row++) {
      negatedImpacts[row] = impactMatrix[row, 0] * -1;
    }
    // sorts the values and reorders the names alongside them
    Array.Sort(negatedImpacts, sortedNames);

    // calculate ranks; every member of a tie group receives the rank of the group's first element
    double[] ranks = new double[negatedImpacts.Length];
    int groupStart = 0;
    while (groupStart < negatedImpacts.Length) {
      double groupRank = groupStart + 1;
      ranks[groupStart] = groupRank;
      int next = groupStart + 1;
      for (; next < negatedImpacts.Length && negatedImpacts[groupStart].IsAlmost(negatedImpacts[next]); next++) {
        ranks[next] = groupRank;
      }
      groupStart = next;
    }

    for (int k = 0; k < sortedNames.Length; k++) {
      if (sortedNames[k] == variableName) {
        yield return ranks[k];
      }
    }
  }
}
267
/// <summary>
/// Lazily yields the value in column 0 of every row named
/// <paramref name="variableName"/>, walking the impact matrices in order.
/// Matrices without such a row contribute nothing.
/// </summary>
/// <param name="variableName">Row name to look for.</param>
/// <param name="allVariableImpacts">One impact matrix per run.</param>
private IEnumerable<double> GetVariableImpacts(string variableName, IEnumerable<DoubleMatrix> allVariableImpacts) {
  foreach (DoubleMatrix impactMatrix in allVariableImpacts) {
    string[] rowNames = impactMatrix.RowNames.ToArray();
    for (int rowIndex = 0; rowIndex < rowNames.Length; rowIndex++) {
      if (rowNames[rowIndex] != variableName) {
        continue;
      }
      yield return impactMatrix[rowIndex, 0];
    }
  }
}
278
279  }
280}
Note: See TracBrowser for help on using the repository browser.