Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2288_HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks.Views/3.3/RunCollectionVariableInteractionNetworkView.cs @ 16674

Last change on this file since 16674 was 16498, checked in by jzenisek, 6 years ago

#2288: adapted online calculation of variable impacts within VIN-view according to new interface of RegressionSolutionVariableImpactsCalculator.CalculateImpacts(..)

File size: 22.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.ComponentModel;
25using System.Drawing;
26using System.Globalization;
27using System.Linq;
28using System.Text;
29using System.Windows.Forms;
30using HeuristicLab.Common;
31using HeuristicLab.Core;
32using HeuristicLab.Core.Views;
33using HeuristicLab.Data;
34using HeuristicLab.MainForm;
35using HeuristicLab.Optimization;
36using HeuristicLab.Problems.DataAnalysis;
37using HeuristicLab.Visualization;
38using Ellipse = HeuristicLab.Visualization.Ellipse;
39using Rectangle = HeuristicLab.Visualization.Rectangle;
40
41namespace HeuristicLab.VariableInteractionNetworks.Views {
42  [View("Variable Interaction Network")]
43  [Content(typeof(RunCollection), IsDefaultView = false)]
44
45  public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
46    public RunCollectionVariableInteractionNetworkView() {
47      InitializeComponent();
48      ConfigureNodeShapes();
49    }
50
51    public new RunCollection Content {
52      get { return (RunCollection)base.Content; }
53      set {
54        if (value != null && value != Content) {
55          base.Content = value;
56        }
57      }
58    }
59
60    private VariableInteractionNetwork variableInteractionNetwork;
61
62    private static void AssertSameProblemData(RunCollection runs) {
63      IDataset dataset = null;
64      IRegressionProblemData problemData = null;
65      foreach (var run in runs) {
66        var solution = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
67        var ds = solution.ProblemData.Dataset;
68
69        if (solution.ProblemData == problemData) continue;
70        if (ds == dataset) continue;
71        if (problemData == null) {
72          problemData = solution.ProblemData;
73          continue;
74        }
75        if (dataset == null) {
76          dataset = ds;
77          continue;
78        }
79
80        if (problemData.TrainingPartition.Start != solution.ProblemData.TrainingPartition.Start || problemData.TrainingPartition.End != solution.ProblemData.TrainingPartition.End)
81          throw new InvalidOperationException("The runs must share the same data.");
82
83        if (!ds.DoubleVariables.SequenceEqual(dataset.DoubleVariables))
84          throw new InvalidOperationException("The runs must share the same data.");
85
86        foreach (var v in ds.DoubleVariables) {
87          var values1 = (IList<double>)ds.GetReadOnlyDoubleValues(v);
88          var values2 = (IList<double>)dataset.GetReadOnlyDoubleValues(v);
89
90          if (values1.Count != values2.Count)
91            throw new InvalidOperationException("The runs must share the same data.");
92
93          if (!values1.SequenceEqual(values2))
94            throw new InvalidOperationException("The runs must share the same data.");
95        }
96      }
97    }
98
99    public static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
100      var solutions = runs.Select(x => x.Results.Values.Single(v => v is IRegressionSolution)).Cast<IRegressionSolution>();
101      return new RegressionEnsembleSolution(new RegressionEnsembleModel(solutions.Select(x => x.Model)), solutions.First().ProblemData);
102    }
103
104    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
105      AssertSameProblemData(runs);
106      var solution = (IRegressionSolution)runs.First().Results.Values.Single(x => x is IRegressionSolution);
107      var dataset = (Dataset)solution.ProblemData.Dataset;
108      var originalValues = dataset.DoubleVariables.ToDictionary(x => x, x => dataset.GetReadOnlyDoubleValues(x).ToList());
109      var medians = dataset.DoubleVariables.ToDictionary(x => x, x => Enumerable.Repeat(originalValues[x].Median(), originalValues[x].Count).ToList());
110
111      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
112
113      var groups = runs.GroupBy(run => {
114        var sol = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
115        return Concatenate(sol.ProblemData.AllowedInputVariables) + sol.ProblemData.TargetVariable;
116      });
117
118      if (useBest) {
119        // build network using only the best run for each target
120        foreach (var group in groups) {
121          var solutions = group.Select(run => Tuple.Create(run, (IRegressionSolution)run.Results.Values.Single(sol => sol is IRegressionSolution)));
122          var best = solutions.OrderBy(x => x.Item2.TrainingRSquared).Last();
123          var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(best.Item2,
124            RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle,
125            RegressionSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Best,
126            RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.All).ToDictionary(x => x.Item1, x => x.Item2);
127
128          targetImpacts[best.Item2.ProblemData.TargetVariable] = Tuple.Create(new[] { best.Item1 }.AsEnumerable(), impacts);
129        }
130      } else {
131        foreach (var group in groups) {
132          // calculate average impacts
133          var averageImpacts = new Dictionary<string, double>();
134          solution = (IRegressionSolution)group.First().Results.Values.Single(x => x is IRegressionSolution);
135          foreach (var run in group) {
136            var sol = (IRegressionSolution)run.Results.Values.Single(v => v is IRegressionSolution);
137
138            DoubleLimit estimationLimits = null;
139            if (run.Parameters.ContainsKey("EstimationLimits")) {
140              estimationLimits = (DoubleLimit)run.Parameters["EstimationLimits"];
141            }
142            var md = dataset.ToModifiable();
143
144            var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(sol,
145              RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle,
146              RegressionSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Best,
147              RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.All);           
148            foreach (var t in impacts) {
149              if (averageImpacts.ContainsKey(t.Item1))
150                averageImpacts[t.Item1] += t.Item2;
151              else {
152                averageImpacts[t.Item1] = t.Item2;
153              }
154            }
155          }
156
157          var count = group.Count();
158          foreach (var v in averageImpacts.Keys.ToList()) {
159            averageImpacts[v] /= count;
160          }
161
162          targetImpacts[solution.ProblemData.TargetVariable] = Tuple.Create(group.AsEnumerable(), averageImpacts);
163        }
164      }
165      return targetImpacts;
166    }
167
168    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs,
169      string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
170
171      Func<IRun, double> getQuality = run => ((DoubleValue)run.Results[qualityResultName]).Value;
172      var targetGroups = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
173      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
174
175      if (useBestRunsPerTarget) {
176        foreach (var group in targetGroups) {
177          var ordered = group.OrderBy(getQuality);
178          var best = maximization ? ordered.Last() : ordered.First();
179          var pd = (IRegressionProblemData)best.Parameters["ProblemData"];
180          var target = group.Key;
181          var impacts = (DoubleMatrix)best.Results[impactsResultName];
182          targetImpacts[target] = Tuple.Create((IEnumerable<IRun>)new[] { best }, impacts.RowNames.Select((x, i) => new { x, i }).ToDictionary(x => x.x, x => impacts[x.i, 0]));
183        }
184      } else {
185        foreach (var target in targetGroups) {
186          var averageImpacts = CalculateAverageImpacts(new RunCollection(target), impactsResultName);
187          targetImpacts[target.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(target, averageImpacts);
188        }
189      }
190      return targetImpacts;
191    }
192
193    public static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
194      var nodes = new Dictionary<string, IVertex>();
195      var vn = new VariableInteractionNetwork();
196      foreach (var ti in targetImpacts) {
197        var target = ti.Key;
198        var variableImpacts = ti.Value.Item2;
199        var targetRuns = ti.Value.Item1;
200        IVertex targetNode;
201
202        var variables = variableImpacts.Keys.ToList();
203        if (variables.Count == 0) continue;
204
205        if (!nodes.TryGetValue(target, out targetNode)) {
206          targetNode = new VariableNetworkNode { Label = target };
207          vn.AddVertex(targetNode);
208          nodes[target] = targetNode;
209        }
210
211        IVertex variableNode;
212        if (variables.Count > 1) {
213          var variableList = new List<string>(variables) { target };
214          var junctionLabel = Concatenate(variableList);
215          IVertex junctionNode;
216          var sb = new StringBuilder();
217          if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
218            var solutionsEnsemble = CreateEnsembleSolution(targetRuns);
219            junctionNode = new JunctionNetworkNode { Label = solutionsEnsemble.TrainingRSquared.ToString("N3", CultureInfo.CurrentCulture), Data = solutionsEnsemble };
220            vn.AddVertex(junctionNode);
221            nodes[junctionLabel] = junctionNode;
222            sb.AppendLine(junctionNode.Label);
223          }
224          IArc arc;
225          foreach (var v in variables) {
226            var impact = variableImpacts[v];
227            if (!nodes.TryGetValue(v, out variableNode)) {
228              variableNode = new VariableNetworkNode { Label = v };
229              vn.AddVertex(variableNode);
230              nodes[v] = variableNode;
231            }
232            arc = new Arc(variableNode, junctionNode) { Weight = impact, Label = impact.ToString("N3", CultureInfo.CurrentCulture) };
233            sb.AppendLine(v + ": " + arc.Label);
234            vn.AddArc(arc);
235          }
236          var jcnNode = (JunctionNetworkNode)junctionNode;
237          var trainingR2 = ((IRegressionSolution)jcnNode.Data).TrainingRSquared;
238          arc = new Arc(junctionNode, targetNode) { Weight = junctionNode.InArcs.Sum(x => x.Weight), Label = trainingR2.ToString("N3", CultureInfo.CurrentCulture) };
239          vn.AddArc(arc);
240        } else {
241          foreach (var v in variables) {
242            var impact = variableImpacts[v];
243            if (!nodes.TryGetValue(v, out variableNode)) {
244              variableNode = new VariableNetworkNode { Label = v };
245              vn.AddVertex(variableNode);
246              nodes[v] = variableNode;
247            }
248            var arc = new Arc(variableNode, targetNode) {
249              Weight = impact,
250              Label = impact.ToString("N3", CultureInfo.CurrentCulture)
251            };
252            vn.AddArc(arc);
253          }
254        }
255      }
256      return vn;
257    }
258
259    public static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
260      var arcs = originalNetwork.Arcs.Where(x => x.Weight >= threshold).ToList();
261      if (!arcs.Any()) return originalNetwork;
262      var filteredNetwork = new VariableInteractionNetwork();
263      var cloner = new Cloner();
264      var vertices = arcs.SelectMany(x => new[] { x.Source, x.Target }).Select(cloner.Clone).Distinct(); // arcs are not cloned
265      filteredNetwork.AddVertices(vertices);
266      filteredNetwork.AddArcs(arcs.Select(x => (IArc)x.Clone(cloner)));
267
268      var unusedJunctions = filteredNetwork.Vertices.Where(x => x.InDegree == 0 && x is JunctionNetworkNode).ToList();
269      filteredNetwork.RemoveVertices(unusedJunctions);
270      var orphanedNodes = filteredNetwork.Vertices.Where(x => x.Degree == 0).ToList();
271      filteredNetwork.RemoveVertices(orphanedNodes);
272      return filteredNetwork.Vertices.Any() ? filteredNetwork : originalNetwork;
273    }
274
275    private static double CalculateAverageQuality(RunCollection runs) {
276      var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
277      var target = pd.TargetVariable;
278      var inputs = pd.AllowedInputVariables;
279
280      if (!runs.All(x => {
281        var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
282        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
283      })) {
284        throw new ArgumentException("All runs must have the same target and inputs.");
285      }
286      return runs.Average(x => ((DoubleValue)x.Results["Best training solution quality"]).Value);
287    }
288
289    public static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
290      var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
291      var target = pd.TargetVariable;
292      var inputs = pd.AllowedInputVariables.ToList();
293
294      var impacts = inputs.ToDictionary(x => x, x => 0d);
295
296      // check if all the runs have the same target and same inputs
297      if (!runs.All(x => {
298        var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
299        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
300      })) {
301        throw new ArgumentException("All runs must have the same target and inputs.");
302      }
303
304      foreach (var run in runs) {
305        var impactsMatrix = (DoubleMatrix)run.Results[resultName];
306        int i = 0;
307        foreach (var v in impactsMatrix.RowNames) {
308          impacts[v] += impactsMatrix[i, 0];
309          ++i;
310        }
311      }
312
313      foreach (var v in inputs) {
314        impacts[v] /= runs.Count;
315      }
316
317      return impacts;
318    }
319
320    private static string Concatenate(IEnumerable<string> strings) {
321      var sb = new StringBuilder();
322      foreach (var s in strings) {
323        sb.Append(s);
324      }
325      return sb.ToString();
326    }
327
328    private void ConfigureNodeShapes() {
329      graphChart.ClearShapes();
330      var font = new Font(FontFamily.GenericSansSerif, 12);
331      graphChart.AddShape(typeof(VariableNetworkNode), new LabeledPrimitive(new Ellipse(graphChart.Chart, new PointD(0, 0), new PointD(30, 30), Pens.Black, Brushes.White), "", font));
332      graphChart.AddShape(typeof(JunctionNetworkNode), new LabeledPrimitive(new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray), "", font));
333    }
334
335    public void UpdateNetwork(VariableInteractionNetwork network) {
336      if (InvokeRequired) {
337        Invoke((Action<VariableInteractionNetwork>)UpdateNetwork, network);
338      } else {
339        graphChart.Graph = network;
340      }     
341    }
342
343    #region events
344    protected override void OnContentChanged() {
345      base.OnContentChanged();
346      var run = Content.First();
347      var pd = (IRegressionProblemData)run.Parameters["ProblemData"];
348      var variables = new HashSet<string>(new List<string>(pd.Dataset.DoubleVariables));
349      impactResultNameComboBox.Items.Clear();
350      foreach (var result in run.Results.Where(x => x.Value is DoubleMatrix)) {
351        var m = (DoubleMatrix)result.Value;
352        if (m.RowNames.All(x => variables.Contains(x)))
353          impactResultNameComboBox.Items.Add(result.Key);
354      }
355      qualityResultNameComboBox.Items.Clear();
356      foreach (var result in run.Results.Where(x => x.Value is DoubleValue)) {
357        qualityResultNameComboBox.Items.Add(result.Key);
358      }
359      if (impactResultNameComboBox.Items.Count > 0) {
360        impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
361      }
362      if (qualityResultNameComboBox.Items.Count > 0) {
363        qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
364      }
365      if (impactResultNameComboBox.Items.Count > 0 && qualityResultNameComboBox.Items.Count > 0)
366        NetworkConfigurationChanged(this, EventArgs.Empty);
367    }
368
369    private void TextBoxValidating(object sender, CancelEventArgs e) {
370      double v;
371      string errorMsg = "Could not parse the entered value. Please input a real number.";
372      var tb = (TextBox)sender;
373      if (!double.TryParse(tb.Text, out v)) {
374        e.Cancel = true;
375        tb.Select(0, tb.Text.Length);
376
377        // Set the ErrorProvider error with the text to display.
378        this.errorProvider.SetError(tb, errorMsg);
379        errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
380        errorProvider.SetIconPadding(tb, -20);
381      }
382    }
383
384    private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
385      var tb = (TextBox)sender;
386      errorProvider.SetError(tb, string.Empty);
387      double impact;
388      if (!double.TryParse(tb.Text, out impact)) {
389        impact = 0.2;
390      }
391      var network = ApplyThreshold(variableInteractionNetwork, impact);
392      graphChart.Graph = network;
393    }
394
395    private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
396      var tb = (TextBox)sender;
397      errorProvider.SetError(tb, string.Empty);
398      LayoutConfigurationChanged(sender, e);
399    }
400
401    private void NetworkConfigurationChanged(object sender, EventArgs e) {
402      var useBest = impactAggregationComboBox.SelectedIndex <= 0;
403      var threshold = impactThresholdTrackBar.Value / 100.0;
404      var qualityResultName = qualityResultNameComboBox.Text;
405      var impactsResultName = impactResultNameComboBox.Text;
406      if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName))
407        return;
408      var maximization = maximizationCheckBox.Checked;
409      var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
410      variableInteractionNetwork = CreateNetwork(impacts);
411      var network = ApplyThreshold(variableInteractionNetwork, threshold);
412      graphChart.Graph = network;
413    }
414
415    private void LayoutConfigurationChanged(object sender, EventArgs e) {
416      ConstrainedForceDirectedLayout.EdgeRouting routingMode;
417      switch (edgeRoutingComboBox.SelectedIndex) {
418        case 0:
419          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.None;
420          break;
421        case 1:
422          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Polyline;
423          break;
424        case 2:
425          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal;
426          break;
427        default:
428          throw new ArgumentException("Invalid edge routing mode.");
429      }
430      var idealEdgeLength = double.Parse(idealEdgeLengthTextBox.Text);
431      if (routingMode == graphChart.RoutingMode && idealEdgeLength.IsAlmost(graphChart.DefaultEdgeLength)) return;
432      graphChart.RoutingMode = routingMode;
433      graphChart.PerformEdgeRouting = routingMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
434      graphChart.DefaultEdgeLength = idealEdgeLength;
435      graphChart.Draw();
436    }
437
438    private void ControlsEnable(bool enabled) {
439      qualityResultNameComboBox.Enabled
440        = impactResultNameComboBox.Enabled
441        = impactAggregationComboBox.Enabled
442        = impactThresholdTrackBar.Enabled
443        = onlineImpactCalculationButton.Enabled
444        = edgeRoutingComboBox.Enabled
445        = idealEdgeLengthTextBox.Enabled
446        = maximizationCheckBox.Enabled = enabled;
447    }
448
449    private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
450      var worker = new BackgroundWorker();
451      worker.DoWork += (o, e) => {
452        ControlsEnable(false);
453        var impacts = CalculateVariableImpactsOnline(Content, impactAggregationComboBox.SelectedIndex == 0);
454        variableInteractionNetwork = CreateNetwork(impacts);
455        var threshold = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
456        graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
457      };
458      worker.RunWorkerCompleted += (o, e) => ControlsEnable(true);
459      worker.RunWorkerAsync();
460    }
461
462    private void relayoutGraphButton_Click(object sender, EventArgs e) {
463      graphChart.Draw();
464    }
465    #endregion
466
467    private void exportImpactsMatrixButton_Click(object sender, EventArgs e) {
468      var graph = graphChart.Graph;
469      var labels = graph.Vertices.OfType<VariableNetworkNode>().Select(x => x.Label).ToList();
470      labels.Sort(); // sort variables alphabetically
471      var matrix = new DoubleMatrix(labels.Count, labels.Count) { RowNames = labels, ColumnNames = labels };
472      var indexes = labels.Select((x, i) => new { Label = x, Index = i }).ToDictionary(x => x.Label, x => x.Index);
473      var junctions = graph.Vertices.OfType<JunctionNetworkNode>().ToList();
474      foreach (var jn in junctions) {
475        var target = jn.OutArcs.First().Target.Label;
476        var targetIndex = indexes[target];
477        foreach (var input in jn.InArcs) {
478          var inputIndex = indexes[input.Source.Label];
479          var inputImpact = input.Weight;
480          matrix[targetIndex, inputIndex] = inputImpact;
481        }
482      }
483      for (int i = 0; i < labels.Count; ++i) matrix[i, i] = 1;
484      MainFormManager.MainForm.ShowContent(matrix);
485    }
486
487    private void impactThresholdTrackBar_ValueChanged(object sender, EventArgs e) {
488      var impact = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
489      impactThresholdLabel.Text = impact.ToString("N3", CultureInfo.CurrentCulture);
490      var network = ApplyThreshold(variableInteractionNetwork, impact);
491      graphChart.Graph = network;
492    }
493
494
495  }
496}
Note: See TracBrowser for help on using the repository browser.