source: branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks.Views/3.3/RunCollectionVariableInteractionNetworkView.cs @ 13789

Last change on this file since 13789 was 13789, checked in by bburlacu, 5 years ago

#2288:

  • Refactor RunCollectionVariableInteractionNetworkView improving functionality, modularity and code organisation.
  • Small tweaks to the DirectedGraphChart and DirectedGraphChartMode
File size: 20.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.ComponentModel;
25using System.Drawing;
26using System.Linq;
27using System.Text;
28using System.Windows.Forms;
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Core.Views;
32using HeuristicLab.Data;
33using HeuristicLab.MainForm;
34using HeuristicLab.Optimization;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Visualization;
37using Ellipse = HeuristicLab.Visualization.Ellipse;
38using Rectangle = HeuristicLab.Visualization.Rectangle;
39
40namespace HeuristicLab.VariableInteractionNetworks.Views {
41  [View("Variable Interaction Network")]
42  [Content(typeof(RunCollection), IsDefaultView = false)]
43
44  public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
45    public RunCollectionVariableInteractionNetworkView() {
46      InitializeComponent();
47      ConfigureNodeShapes();
48    }
49
50    public new RunCollection Content {
51      get { return (RunCollection)base.Content; }
52      set {
53        if (value != null && value != Content) {
54          base.Content = value;
55        }
56      }
57    }
58
59    private VariableInteractionNetwork variableInteractionNetwork;
60
61    private static void AssertSameProblemData(RunCollection runs) {
62      IDataset dataset = null;
63      IRegressionProblemData problemData = null;
64      foreach (var run in runs) {
65        var solution = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
66        var ds = solution.ProblemData.Dataset;
67
68        if (solution.ProblemData == problemData) continue;
69        if (ds == dataset) continue;
70        if (problemData == null) {
71          problemData = solution.ProblemData;
72          continue;
73        }
74        if (dataset == null) {
75          dataset = ds;
76          continue;
77        }
78
79        if (problemData.TrainingPartition.Start != solution.ProblemData.TrainingPartition.Start || problemData.TrainingPartition.End != solution.ProblemData.TrainingPartition.End)
80          throw new InvalidOperationException("The runs must share the same data.");
81
82        if (!ds.DoubleVariables.SequenceEqual(dataset.DoubleVariables))
83          throw new InvalidOperationException("The runs must share the same data.");
84
85        foreach (var v in ds.DoubleVariables) {
86          var values1 = (IList<double>)ds.GetReadOnlyDoubleValues(v);
87          var values2 = (IList<double>)dataset.GetReadOnlyDoubleValues(v);
88
89          if (values1.Count != values2.Count)
90            throw new InvalidOperationException("The runs must share the same data.");
91
92          if (!values1.SequenceEqual(values2))
93            throw new InvalidOperationException("The runs must share the same data.");
94        }
95      }
96    }
97
98    private static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
99      var solutions = runs.Select(x => x.Results.Values.Single(v => v is IRegressionSolution)).Cast<IRegressionSolution>();
100      return new RegressionEnsembleSolution(new RegressionEnsembleModel(solutions.Select(x => x.Model)), solutions.First().ProblemData);
101    }
102
103    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
104      AssertSameProblemData(runs);
105      var solution = (IRegressionSolution)runs.First().Results.Values.Single(x => x is IRegressionSolution);
106      var dataset = (Dataset)solution.ProblemData.Dataset;
107      var originalValues = dataset.DoubleVariables.ToDictionary(x => x, x => dataset.GetReadOnlyDoubleValues(x).ToList());
108      var md = dataset.ToModifiable();
109      var medians = new Dictionary<string, List<double>>();
110      foreach (var v in dataset.DoubleVariables) {
111        var median = dataset.GetDoubleValues(v, solution.ProblemData.TrainingIndices).Median();
112        medians[v] = Enumerable.Repeat(median, originalValues[v].Count).ToList();
113      }
114
115      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
116
117      if (useBest) {
118        // build network using only the best run for each target
119      } else {
120        var groups = runs.GroupBy(run => {
121          var sol = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
122          return Concatenate(sol.ProblemData.AllowedInputVariables) + sol.ProblemData.TargetVariable;
123        });
124
125        foreach (var group in groups) {
126          // calculate average impacts
127          var averageImpacts = new Dictionary<string, double>();
128          solution = (IRegressionSolution)group.First().Results.Values.Single(x => x is IRegressionSolution);
129          foreach (var run in group) {
130            var sol = (IRegressionSolution)run.Results.Values.Single(v => v is IRegressionSolution);
131
132            DoubleLimit estimationLimits = null;
133            if (run.Parameters.ContainsKey("EstimationLimits")) {
134              estimationLimits = (DoubleLimit)run.Parameters["EstimationLimits"];
135            }
136            var impacts = CalculateImpacts(sol, md, originalValues, medians, estimationLimits);
137            //            var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(sol).ToDictionary(x => x.Item1, x => x.Item2);
138            foreach (var pair in impacts) {
139              if (averageImpacts.ContainsKey(pair.Key))
140                averageImpacts[pair.Key] += pair.Value;
141              else {
142                averageImpacts[pair.Key] = pair.Value;
143              }
144            }
145          }
146          var count = group.Count();
147          var keys = averageImpacts.Keys.ToList();
148          foreach (var v in keys) {
149            averageImpacts[v] /= count;
150          }
151
152          targetImpacts[solution.ProblemData.TargetVariable] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(group, averageImpacts);
153        }
154      }
155      return targetImpacts;
156    }
157
158    private static Dictionary<string, double> CalculateImpacts(IRegressionSolution solution, ModifiableDataset dataset,
159      Dictionary<string, List<double>> originalValues, Dictionary<string, List<double>> medianValues, DoubleLimit estimationLimits = null) {
160      var impacts = new Dictionary<string, double>();
161
162      var model = solution.Model;
163      var pd = solution.ProblemData;
164
165      var rows = pd.TrainingIndices.ToList();
166      var targetValues = pd.Dataset.GetDoubleValues(pd.TargetVariable, rows).ToList();
167
168
169      foreach (var v in pd.AllowedInputVariables) {
170        dataset.ReplaceVariable(v, medianValues[v]);
171
172        var estimatedValues = model.GetEstimatedValues(dataset, rows);
173        if (estimationLimits != null)
174          estimatedValues = estimatedValues.LimitToRange(estimationLimits.Lower, estimationLimits.Upper);
175
176        OnlineCalculatorError error;
177        var r = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out error);
178        var newQuality = error == OnlineCalculatorError.None ? r * r : double.NaN;
179        var originalQuality = solution.TrainingRSquared;
180        impacts[v] = originalQuality - newQuality;
181
182        dataset.ReplaceVariable(v, originalValues[v]);
183      }
184      return impacts;
185    }
186
187    private static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs,
188      string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
189      var targets = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
190      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
191      if (useBestRunsPerTarget) {
192        var bestRunsPerTarget = maximization
193          ? targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).Last())
194          : targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).First());
195
196        foreach (var run in bestRunsPerTarget) {
197          var pd = (IRegressionProblemData)run.Parameters["ProblemData"];
198          var target = pd.TargetVariable;
199          var impacts = (DoubleMatrix)run.Results[impactsResultName];
200          targetImpacts[target] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(new[] { run }, impacts.RowNames.Select((x, i) => new { Name = x, Index = i }).ToDictionary(x => x.Name, x => impacts[x.Index, 0]));
201        }
202      } else {
203        foreach (var target in targets) {
204          var averageImpacts = CalculateAverageImpacts(new RunCollection(target), impactsResultName);
205          targetImpacts[target.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(target, averageImpacts);
206        }
207      }
208      return targetImpacts;
209    }
210
211    private static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
212      var nodes = new Dictionary<string, IVertex>();
213      var vn = new VariableInteractionNetwork();
214      foreach (var ti in targetImpacts) {
215        var target = ti.Key;
216        var variableImpacts = ti.Value.Item2;
217        var targetRuns = ti.Value.Item1;
218        IVertex targetNode;
219
220        var variables = variableImpacts.Keys.ToList();
221        if (variables.Count == 0) continue;
222
223        if (!nodes.TryGetValue(target, out targetNode)) {
224          targetNode = new VariableNetworkNode { Label = target };
225          vn.AddVertex(targetNode);
226          nodes[target] = targetNode;
227        }
228
229        IVertex variableNode;
230        if (variables.Count > 1) {
231          var variableList = new List<string>(variables) { target };
232          var junctionLabel = Concatenate(variableList);
233          IVertex junctionNode;
234          if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
235            var solutionsEnsemble = CreateEnsembleSolution(targetRuns);
236            junctionNode = new JunctionNetworkNode { Label = string.Empty, Data = solutionsEnsemble };
237            vn.AddVertex(junctionNode);
238            nodes[junctionLabel] = junctionNode;
239            junctionNode.Label = string.Format("Target quality: {0:0.000}", solutionsEnsemble.TrainingRSquared);
240          }
241          IArc arc;
242          foreach (var v in variables) {
243            var impact = variableImpacts[v];
244            if (!nodes.TryGetValue(v, out variableNode)) {
245              variableNode = new VariableNetworkNode { Label = v };
246              vn.AddVertex(variableNode);
247              nodes[v] = variableNode;
248            }
249            arc = new Arc(variableNode, junctionNode) { Weight = impact, Label = string.Format("Impact: {0:0.000}", impact) };
250            vn.AddArc(arc);
251          }
252          var trainingR2 = ((IRegressionSolution)((JunctionNetworkNode)junctionNode).Data).TrainingRSquared;
253          arc = new Arc(junctionNode, targetNode) { Weight = junctionNode.InArcs.Sum(x => x.Weight), Label = string.Format("Quality: {0:0.000}", trainingR2) };
254          vn.AddArc(arc);
255        } else {
256          foreach (var v in variables) {
257            var impact = variableImpacts[v];
258            if (!nodes.TryGetValue(v, out variableNode)) {
259              variableNode = new VariableNetworkNode { Label = v };
260              vn.AddVertex(variableNode);
261              nodes[v] = variableNode;
262            }
263            var arc = new Arc(variableNode, targetNode) { Weight = impact, Label = string.Format("Impact: {0:0.000}", impact) };
264            vn.AddArc(arc);
265          }
266        }
267      }
268      return vn;
269    }
270
271    private static double CalculateAverageQuality(RunCollection runs) {
272      var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
273      var target = pd.TargetVariable;
274      var inputs = pd.AllowedInputVariables;
275
276      if (!runs.All(x => {
277        var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
278        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
279      })) {
280        throw new ArgumentException("All runs must have the same target and inputs.");
281      }
282      return runs.Average(x => ((DoubleValue)x.Results["Best training solution quality"]).Value);
283    }
284
285    private static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
286      var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
287      var target = pd.TargetVariable;
288      var inputs = pd.AllowedInputVariables.ToList();
289
290      var impacts = inputs.ToDictionary(x => x, x => 0d);
291
292      // check if all the runs have the same target and same inputs
293      if (!runs.All(x => {
294        var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
295        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
296      })) {
297        throw new ArgumentException("All runs must have the same target and inputs.");
298      }
299
300      foreach (var run in runs) {
301        var impactsMatrix = (DoubleMatrix)run.Results[resultName];
302
303        int i = 0;
304        foreach (var v in impactsMatrix.RowNames) {
305          impacts[v] += impactsMatrix[i, 0];
306          ++i;
307        }
308      }
309
310      foreach (var v in inputs) {
311        impacts[v] /= runs.Count;
312      }
313
314      return impacts;
315    }
316
317    private static string Concatenate(IEnumerable<string> strings) {
318      var sb = new StringBuilder();
319      foreach (var s in strings) {
320        sb.Append(s);
321      }
322      return sb.ToString();
323    }
324
325    private void ConfigureNodeShapes() {
326      graphChart.ClearShapes();
327      var font = new Font(FontFamily.GenericSansSerif, 12);
328      graphChart.AddShape(typeof(VariableNetworkNode), new LabeledPrimitive(new Ellipse(graphChart.Chart, new PointD(0, 0), new PointD(30, 30), Pens.Black, Brushes.White), "", font));
329      graphChart.AddShape(typeof(JunctionNetworkNode), new LabeledPrimitive(new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray), "", font));
330    }
331
332    #region events
333    protected override void OnContentChanged() {
334      base.OnContentChanged();
335      var run = Content.First();
336      var pd = (IRegressionProblemData)run.Parameters["ProblemData"];
337      var variables = new HashSet<string>(new List<string>(pd.Dataset.DoubleVariables));
338      impactResultNameComboBox.Items.Clear();
339      foreach (var result in run.Results.Where(x => x.Value is DoubleMatrix)) {
340        var m = (DoubleMatrix)result.Value;
341        if (m.RowNames.All(x => variables.Contains(x)))
342          impactResultNameComboBox.Items.Add(result.Key);
343      }
344      qualityResultNameComboBox.Items.Clear();
345      foreach (var result in run.Results.Where(x => x.Value is DoubleValue)) {
346        qualityResultNameComboBox.Items.Add(result.Key);
347      }
348      if (impactResultNameComboBox.Items.Count > 0) {
349        impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
350      }
351      if (qualityResultNameComboBox.Items.Count > 0) {
352        qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
353      }
354      if (impactResultNameComboBox.Items.Count > 0 && qualityResultNameComboBox.Items.Count > 0)
355        NetworkConfigurationChanged(this, EventArgs.Empty);
356    }
357
358    private void TextBoxValidating(object sender, CancelEventArgs e) {
359      double v;
360      string errorMsg = "Could not parse the entered value. Please input a real number.";
361      var tb = (TextBox)sender;
362      if (!double.TryParse(tb.Text, out v)) {
363        e.Cancel = true;
364        tb.Select(0, tb.Text.Length);
365
366        // Set the ErrorProvider error with the text to display.
367        this.errorProvider.SetError(tb, errorMsg);
368        errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
369        errorProvider.SetIconPadding(tb, -20);
370      }
371    }
372
373    private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
374      var tb = (TextBox)sender;
375      errorProvider.SetError(tb, string.Empty);
376      var network = ApplyThreshold(variableInteractionNetwork, double.Parse(tb.Text));
377      graphChart.Graph = network;
378    }
379
380    private static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
381      var arcs = originalNetwork.Arcs.Where(x => x.Weight >= threshold).ToList();
382      if (!arcs.Any()) return originalNetwork;
383      var filteredNetwork = new VariableInteractionNetwork();
384      var cloner = new Cloner();
385      var vertices = arcs.SelectMany(x => new[] { x.Source, x.Target }).Select(cloner.Clone).Distinct(); // arcs are not cloned
386      filteredNetwork.AddVertices(vertices);
387      filteredNetwork.AddArcs(arcs.Select(x => (IArc)x.Clone(cloner)));
388
389      var unusedJunctions = filteredNetwork.Vertices.Where(x => x.InDegree == 0 && x is JunctionNetworkNode).ToList();
390      filteredNetwork.RemoveVertices(unusedJunctions);
391      var orphanedNodes = filteredNetwork.Vertices.Where(x => x.Degree == 0).ToList();
392      filteredNetwork.RemoveVertices(orphanedNodes);
393      return filteredNetwork;
394    }
395
396    private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
397      var tb = (TextBox)sender;
398      errorProvider.SetError(tb, string.Empty);
399      LayoutConfigurationChanged(sender, e);
400    }
401
402    private void NetworkConfigurationChanged(object sender, EventArgs e) {
403      var useBest = impactAggregationComboBox.SelectedIndex <= 0;
404      var threshold = double.Parse(impactThresholdTextBox.Text);
405      var qualityResultName = qualityResultNameComboBox.Text;
406      var impactsResultName = impactResultNameComboBox.Text;
407      if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName))
408        return;
409      var maximization = maximizationCheckBox.Checked;
410      var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
411      variableInteractionNetwork = CreateNetwork(impacts);
412      var network = ApplyThreshold(variableInteractionNetwork, threshold);
413      graphChart.Graph = network;
414    }
415
416    private void LayoutConfigurationChanged(object sender, EventArgs e) {
417      ConstrainedForceDirectedLayout.EdgeRouting routingMode;
418      switch (edgeRoutingComboBox.SelectedIndex) {
419        case 0:
420          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.None;
421          break;
422        case 1:
423          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Polyline;
424          break;
425        case 2:
426          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal;
427          break;
428        default:
429          throw new ArgumentException("Invalid edge routing mode.");
430      }
431      var idealEdgeLength = double.Parse(idealEdgeLengthTextBox.Text);
432      if (routingMode == graphChart.RoutingMode && idealEdgeLength.IsAlmost(graphChart.IdealEdgeLength)) return;
433      graphChart.RoutingMode = routingMode;
434      graphChart.PerformEdgeRouting = routingMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
435      graphChart.IdealEdgeLength = idealEdgeLength;
436      graphChart.Draw();
437    }
438
439    private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
440      var button = (Button)sender;
441      var worker = new BackgroundWorker();
442      worker.DoWork += (o, e) => {
443        button.Enabled = false;
444        var impacts = CalculateVariableImpactsOnline(Content, false);
445        variableInteractionNetwork = CreateNetwork(impacts);
446        var threshold = double.Parse(impactThresholdTextBox.Text);
447        graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
448      };
449      worker.RunWorkerCompleted += (o, e) => button.Enabled = true;
450      worker.RunWorkerAsync();
451    }
452    #endregion
453  }
454}
Note: See TracBrowser for help on using the repository browser.