source: branches/HeuristicLab.EvolutionTracking/HeuristicLab.Problems.DataAnalysis.Symbolic.Views/3.4/RunCollectionVariableImpactView.cs @ 11208

Last change on this file since 11208 was 11208, checked in by bburlacu, 6 years ago

#1772: Merged trunk changes.

File size: 16.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using HeuristicLab.Common;
28using HeuristicLab.Data;
29using HeuristicLab.MainForm;
30using HeuristicLab.MainForm.WindowsForms;
31using HeuristicLab.Optimization;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Views {
  /// <summary>
  /// Displays the "Variable impacts" results of the visible runs of a <see cref="RunCollection"/> as a matrix
  /// with one column per run followed by the summary columns "Median Rank", "Mean", "StdDev" and "pValue".
  /// Cross-validation runs (runs that have a "Folds" parameter) can be inspected over all folds ("Overall")
  /// or fold by fold via a combobox; a mixture of cross-validation and normal runs is rejected with a message.
  /// </summary>
  [Content(typeof(RunCollection), false)]
  [View("Variable Impacts")]
  public sealed partial class RunCollectionVariableImpactView : AsynchronousContentView {
    private const string variableImpactResultName = "Variable impacts";
    private const string crossValidationFoldsResultName = "CrossValidation Folds";
    private const string numberOfFoldsParameterName = "Folds";

    // Single label reused by DisplayMessage; it is shown in place of the matrix view.
    private Label statusMessageLabel;

    public RunCollectionVariableImpactView() {
      InitializeComponent();
    }

    public new RunCollection Content {
      get { return (RunCollection)base.Content; }
      set { base.Content = value; }
    }

    #region events
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.UpdateOfRunsInProgressChanged += Content_UpdateOfRunsInProgressChanged;
      Content.ItemsAdded += Content_ItemsAdded;
      Content.ItemsRemoved += Content_ItemsRemoved;
      Content.CollectionReset += Content_CollectionReset;
      RegisterRunEvents(Content);
    }
    protected override void DeregisterContentEvents() {
      // Must call the base *deregistration*; calling base.RegisterContentEvents() here would subscribe
      // the base handlers a second time instead of removing them.
      base.DeregisterContentEvents();
      Content.UpdateOfRunsInProgressChanged -= Content_UpdateOfRunsInProgressChanged;
      Content.ItemsAdded -= Content_ItemsAdded;
      Content.ItemsRemoved -= Content_ItemsRemoved;
      Content.CollectionReset -= Content_CollectionReset;
      DeregisterRunEvents(Content);
    }
    private void RegisterRunEvents(IEnumerable<IRun> runs) {
      foreach (IRun run in runs)
        run.Changed += Run_Changed;
    }
    private void DeregisterRunEvents(IEnumerable<IRun> runs) {
      foreach (IRun run in runs)
        run.Changed -= Run_Changed;
    }
    private void Content_ItemsAdded(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      RegisterRunEvents(e.Items);
      UpdateData();
    }
    private void Content_ItemsRemoved(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      DeregisterRunEvents(e.Items);
      UpdateData();
    }
    private void Content_CollectionReset(object sender, HeuristicLab.Collections.CollectionItemsChangedEventArgs<IRun> e) {
      DeregisterRunEvents(e.OldItems);
      RegisterRunEvents(e.Items);
      UpdateData();
    }
    private void Content_UpdateOfRunsInProgressChanged(object sender, EventArgs e) {
      // batch updates of runs only trigger a single refresh once the batch is finished
      if (!Content.UpdateOfRunsInProgress) UpdateData();
    }
    private void Run_Changed(object sender, EventArgs e) {
      if (!Content.UpdateOfRunsInProgress) UpdateData();
    }
    #endregion

    protected override void OnContentChanged() {
      base.OnContentChanged();
      this.UpdateData();
    }

    /// <summary>
    /// Recomputes the matrix for the fold selection of the combobox (cross-validation runs only).
    /// Index 0 ("Overall") uses every fold of every visible run; index k > 0 uses fold k - 1 of each visible run.
    /// Each matrix column is named "&lt;run name&gt; &lt;fold name&gt;".
    /// </summary>
    private void comboBox_SelectedValueChanged(object sender, EventArgs e) {
      // the combobox keeps its items when Content is set to null, so guard against a missing Content
      if (comboBox.SelectedItem == null || Content == null) return;

      int selectedIndex = comboBox.SelectedIndex;
      var selectedFolds = new List<IRun>();
      var columnNames = new List<string>();
      foreach (IRun cvRun in Content.Where(r => r.Visible)) {
        var foldCollection = (RunCollection)cvRun.Results[crossValidationFoldsResultName];
        if (selectedIndex == 0) {
          foreach (IRun fold in foldCollection) {
            selectedFolds.Add(fold);
            columnNames.Add(cvRun.Name + " " + fold.Name);
          }
        } else {
          IRun fold = foldCollection.ElementAt(selectedIndex - 1);
          selectedFolds.Add(fold);
          columnNames.Add(cvRun.Name + " " + fold.Name);
        }
      }
      matrixView.Content = CalculateVariableImpactMatrix(selectedFolds.ToArray(), columnNames.ToArray());
    }

    /// <summary>
    /// Resets the fold selection controls and rebuilds the display from the visible runs of Content.
    /// Marshals itself onto the UI thread when called from another thread; does nothing when Content is null.
    /// </summary>
    private void UpdateData() {
      if (InvokeRequired) {
        Invoke((Action)UpdateData);
        return;
      }
      if (Content == null) return;

      comboBox.Items.Clear();
      comboBox.Enabled = false;
      comboBox.Visible = false;
      foldsLabel.Visible = false;
      variableImpactsGroupBox.Dock = DockStyle.Fill;

      IRun[] visibleRuns = Content.Where(r => r.Visible).ToArray();
      if (visibleRuns.Length == 0) {
        DisplayMessage("Run collection is empty.");
      } else if (visibleRuns.All(r => r.Parameters.ContainsKey(numberOfFoldsParameterName))) {
        // all runs are cross-validation runs
        CheckAndUpdateCvRuns(visibleRuns);
      } else if (visibleRuns.All(r => !r.Parameters.ContainsKey(numberOfFoldsParameterName))) {
        // all runs are normal runs
        CheckAndUpdateNormalRuns(visibleRuns);
      } else {
        // there is a mix of CV and normal runs => show an error message
        DisplayMessage("The run collection contains a mixture of normal runs and cross-validation runs. Variable impact calculation does not work in this case.");
      }
    }

    /// <summary>
    /// Validates non-empty cross-validation runs (equal fold counts, every fold has variable impacts, and all
    /// folds have the same input variables). On success it fills the fold combobox and shows the matrix view;
    /// otherwise it displays an explanatory message.
    /// </summary>
    private void CheckAndUpdateCvRuns(IRun[] visibleRuns) {
      // make sure all runs have the same number of folds
      int nFolds = ((IntValue)visibleRuns[0].Parameters[numberOfFoldsParameterName]).Value;
      if (!visibleRuns.All(r => ((IntValue)r.Parameters[numberOfFoldsParameterName]).Value == nFolds)) {
        DisplayMessage("At least one of the cross-validation runs has a different number of folds than the rest.");
        return;
      }

      IRun[] allFoldResults = visibleRuns
        .SelectMany(run => (RunCollection)run.Results[crossValidationFoldsResultName])
        .ToArray();

      // make sure each fold contains variable impacts
      if (!allFoldResults.All(r => r.Results.ContainsKey(variableImpactResultName))) {
        DisplayMessage("At least one of the runs does not contain a variable impact result.");
        return;
      }
      // make sure each of the folds has the same input variables
      if (!HaveSameInputVariables(allFoldResults)) {
        DisplayMessage("At least one of the runs has a different input variable set than the rest.");
        return;
      }

      // populate combobox; selecting index 0 triggers comboBox_SelectedValueChanged, which computes the matrix
      comboBox.Items.Add("Overall");
      for (int foldIndex = 0; foldIndex < nFolds; foldIndex++) {
        comboBox.Items.Add("Fold " + foldIndex);
      }
      comboBox.SelectedIndex = 0;
      comboBox.Enabled = true;
      comboBox.Visible = true;
      foldsLabel.Visible = true;

      // shrink the group box so that the fold combobox remains visible above it
      variableImpactsGroupBox.Controls.Clear();
      variableImpactsGroupBox.Dock = DockStyle.None;
      variableImpactsGroupBox.Anchor = AnchorStyles.Left | AnchorStyles.Top | AnchorStyles.Right |
                                       AnchorStyles.Bottom;
      variableImpactsGroupBox.Height = this.Height - comboBox.Height - 12;
      variableImpactsGroupBox.Width = this.Width;
      matrixView.Dock = DockStyle.Fill;
      variableImpactsGroupBox.Controls.Add(matrixView);
    }

    /// <summary>
    /// Validates normal (non cross-validation) runs (every run has variable impacts, all runs have the same
    /// input variables) and shows the resulting matrix; otherwise it displays an explanatory message.
    /// </summary>
    private void CheckAndUpdateNormalRuns(IRun[] visibleRuns) {
      // make sure all runs contain variable impact results
      if (!visibleRuns.All(r => r.Results.ContainsKey(variableImpactResultName))) {
        DisplayMessage("At least one of the runs does not contain a variable impact result.");
        return;
      }
      // make sure each of the runs has the same input variables
      if (!HaveSameInputVariables(visibleRuns)) {
        DisplayMessage("At least one of the runs has a different input variable set than the rest.");
        return;
      }

      if (!variableImpactsGroupBox.Controls.Contains(matrixView)) {
        variableImpactsGroupBox.Controls.Clear();
        matrixView.Dock = DockStyle.Fill;
        variableImpactsGroupBox.Controls.Add(matrixView);
      }
      matrixView.Content = CalculateVariableImpactMatrix(visibleRuns, visibleRuns.Select(r => r.Name).ToArray());
    }

    /// <summary>
    /// Returns true when every row name of the "Variable impacts" matrices occurs exactly as many times as
    /// there are runs, i.e. all runs report the same set of input variables.
    /// Every run must contain a "Variable impacts" result.
    /// </summary>
    private static bool HaveSameInputVariables(IRun[] runs) {
      return runs
        .SelectMany(run => ((DoubleMatrix)run.Results[variableImpactResultName]).RowNames)
        .GroupBy(variableName => variableName)
        .All(group => group.Count() == runs.Length);
    }

    /// <summary>
    /// Builds the variable impact matrix for the given runs.
    /// Rows: the variables that have a non-zero impact in at least one run, sorted ascending by median rank.
    /// Columns: one per run (named by <paramref name="runNames"/>, impacts rounded to 3 digits), followed by
    /// "Median Rank", "Mean", "StdDev" and "pValue". The pValue is the two-tailed Wilcoxon signed-rank p-value
    /// of the per-run impact differences between the variable and a reference variable (the variable with the
    /// lowest median impact).
    /// </summary>
    /// <param name="runs">Runs that all contain a "Variable impacts" <see cref="DoubleMatrix"/> result.</param>
    /// <param name="runNames">Column names, one per run, in the same order as <paramref name="runs"/>.</param>
    private DoubleMatrix CalculateVariableImpactMatrix(IRun[] runs, string[] runNames) {
      // materialized once; the helpers below enumerate it many times
      DoubleMatrix[] allVariableImpacts = runs
        .Select(run => (DoubleMatrix)run.Results[variableImpactResultName])
        .ToArray();

      // only include variables that have at least one non-zero impact value in some run
      List<string> variableNamesList = allVariableImpacts
        .SelectMany(variableImpacts => variableImpacts.RowNames)
        .Distinct()
        .Where(variableName => GetVariableImpacts(variableName, allVariableImpacts).Any(x => !x.IsAlmost(0.0)))
        .ToList();

      var statistics = new List<string> { "Median Rank", "Mean", "StdDev", "pValue" };
      var columnNames = new List<string>(runNames);
      columnNames.AddRange(statistics);
      int numberOfRuns = runs.Length;

      var matrix = new DoubleMatrix(variableNamesList.Count, numberOfRuns + statistics.Count);
      matrix.SortableView = true;
      matrix.ColumnNames = columnNames;

      // calculate statistics
      List<List<double>> variableImpactsOverRuns = variableNamesList
        .Select(variableName => GetVariableImpacts(variableName, allVariableImpacts).ToList())
        .ToList();
      List<List<double>> variableRanks = variableNamesList
        .Select(variableName => GetVariableImpactRanks(variableName, allVariableImpacts).ToList())
        .ToList();
      if (variableImpactsOverRuns.Count > 0) {
        // the variable with the worst median impact value is chosen as the reference variable
        // this is problematic if all variables are relevant, however works often in practice
        List<double> referenceImpacts = variableImpactsOverRuns
          .OrderBy(impacts => impacts.Median())
          .First();
        // for all variables
        for (int row = 0; row < variableImpactsOverRuns.Count; row++) {
          // median rank
          matrix[row, numberOfRuns] = variableRanks[row].Median();
          // also show mean and std.dev. of relative variable impacts to indicate the relative difference in impacts of variables
          matrix[row, numberOfRuns + 1] = Math.Round(variableImpactsOverRuns[row].Average(), 3);
          matrix[row, numberOfRuns + 2] = Math.Round(variableImpactsOverRuns[row].StandardDeviation(), 3);

          double leftTail = 0; double rightTail = 0; double bothTails = 0;
          // calc differences of impacts for current variable and reference variable
          double[] z = new double[referenceImpacts.Count];
          for (int i = 0; i < z.Length; i++) {
            z[i] = variableImpactsOverRuns[row][i] - referenceImpacts[i];
          }
          // wilcoxon signed rank test is used because the impact values of two variables in a single run are not independent
          alglib.wsr.wilcoxonsignedranktest(z, z.Length, 0, ref bothTails, ref leftTail, ref rightTail);
          matrix[row, numberOfRuns + 3] = Math.Round(bothTails, 4);
        }
      }

      // fill matrix with impacts from runs
      for (int i = 0; i < runs.Length; i++) {
        DoubleMatrix runVariableImpacts = allVariableImpacts[i];
        string[] rowNames = runVariableImpacts.RowNames.ToArray();
        // a matrix may carry fewer row names than rows; rows without a name cannot be assigned to a variable
        int namedRows = Math.Min(runVariableImpacts.Rows, rowNames.Length);
        for (int j = 0; j < namedRows; j++) {
          int rowIndex = variableNamesList.IndexOf(rowNames[j]);
          if (rowIndex > -1) {
            matrix[rowIndex, i] = Math.Round(runVariableImpacts[j, 0], 3);
          }
        }
      }

      // sort rows ascending by median rank (OrderBy is stable, so ties keep their original order)
      int[] sortedIndexes = Enumerable.Range(0, matrix.Rows)
        .OrderBy(i => matrix[i, numberOfRuns])
        .ToArray();
      var sortedMatrix = (DoubleMatrix)matrix.Clone();
      for (int targetIndex = 0; targetIndex < sortedIndexes.Length; targetIndex++) {
        int sourceIndex = sortedIndexes[targetIndex];
        for (int c = 0; c < matrix.Columns; c++)
          sortedMatrix[targetIndex, c] = matrix[sourceIndex, c];
      }
      sortedMatrix.RowNames = sortedIndexes.Select(i => variableNamesList[i]).ToList();

      return sortedMatrix;
    }

    /// <summary>
    /// Yields, for every impact matrix that contains <paramref name="variableName"/>, the rank of that variable
    /// within the matrix. Rank 1 is the largest impact (column 0). Values that are almost equal to the first
    /// value of a tie group share that value's rank.
    /// </summary>
    private IEnumerable<double> GetVariableImpactRanks(string variableName, IEnumerable<DoubleMatrix> allVariableImpacts) {
      foreach (DoubleMatrix runVariableImpacts in allVariableImpacts) {
        // certainly not yet very efficient because ranks are computed multiple times for the same run
        string[] variableNames = runVariableImpacts.RowNames.ToArray();
        // negate so that sorting ascending puts the largest impact first
        double[] values = (from row in Enumerable.Range(0, runVariableImpacts.Rows)
                           select runVariableImpacts[row, 0] * -1)
                          .ToArray();
        Array.Sort(values, variableNames);
        // calculate ranks
        double[] ranks = new double[values.Length];
        // check for tied ranks
        int i = 0;
        while (i < values.Length) {
          ranks[i] = i + 1;
          int j = i + 1;
          while (j < values.Length && values[i].IsAlmost(values[j])) {
            ranks[j] = ranks[i];
            j++;
          }
          i = j;
        }
        int rankIndex = 0;
        foreach (string rowVariableName in variableNames) {
          if (rowVariableName == variableName)
            yield return ranks[rankIndex];
          rankIndex++;
        }
      }
    }

    /// <summary>
    /// Yields the impact value (column 0) of every row named <paramref name="variableName"/> across all impact matrices.
    /// </summary>
    private IEnumerable<double> GetVariableImpacts(string variableName, IEnumerable<DoubleMatrix> allVariableImpacts) {
      foreach (DoubleMatrix runVariableImpacts in allVariableImpacts) {
        int row = 0;
        foreach (string rowName in runVariableImpacts.RowNames) {
          if (rowName == variableName)
            yield return runVariableImpacts[row, 0];
          row++;
        }
      }
    }

    /// <summary>
    /// Replaces the matrix view with a centered message. A single label is reused, so repeated calls
    /// neither stack up labels nor leave an older message covering the new one.
    /// </summary>
    private void DisplayMessage(string message) {
      variableImpactsGroupBox.Controls.Remove(matrixView);
      if (statusMessageLabel == null)
        statusMessageLabel = new Label { TextAlign = ContentAlignment.MiddleCenter, Dock = DockStyle.Fill };
      statusMessageLabel.Text = message;
      if (!variableImpactsGroupBox.Controls.Contains(statusMessageLabel))
        variableImpactsGroupBox.Controls.Add(statusMessageLabel);
      statusMessageLabel.BringToFront();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.