source: branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks.Views/3.3/RunCollectionVariableInteractionNetworkView.cs @ 14275

Last change on this file since 14275 was 14275, checked in by bburlacu, 5 years ago

#2288: Clean up code and add comments in the ConstrainedForceDirectedLayout class. Minor changes to view and directed graph chart. Introduced an INetworkNode interface for more flexibility. Updated cola and adaptagrams dlls with latest changes from upstream.

File size: 22.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.ComponentModel;
25using System.Drawing;
26using System.Globalization;
27using System.Linq;
28using System.Text;
29using System.Windows.Forms;
30using HeuristicLab.Common;
31using HeuristicLab.Core;
32using HeuristicLab.Core.Views;
33using HeuristicLab.Data;
34using HeuristicLab.MainForm;
35using HeuristicLab.Optimization;
36using HeuristicLab.Problems.DataAnalysis;
37using HeuristicLab.Visualization;
38using Ellipse = HeuristicLab.Visualization.Ellipse;
39using Rectangle = HeuristicLab.Visualization.Rectangle;
40
41namespace HeuristicLab.VariableInteractionNetworks.Views {
42  [View("Variable Interaction Network")]
43  [Content(typeof(RunCollection), IsDefaultView = false)]
44
45  public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
/// <summary>
/// Creates the view: initializes its components and then registers the shapes
/// used to draw the network nodes.
/// </summary>
public RunCollectionVariableInteractionNetworkView()
  : base() {
  InitializeComponent();
  ConfigureNodeShapes();
}
50
/// <summary>
/// The run collection shown by this view. Assigning null or the collection that is
/// already displayed leaves the view untouched.
/// </summary>
public new RunCollection Content {
  get { return (RunCollection)base.Content; }
  set {
    if (value == null || value == Content) return;
    base.Content = value;
  }
}
59
60    private VariableInteractionNetwork variableInteractionNetwork;
61
/// <summary>
/// Verifies that all runs were produced on the same data. The first run provides the
/// reference problem data and dataset; every further run must either use the very same
/// problem data or dataset instance, or have the same training partition and identical
/// double variables (names and values).
/// </summary>
/// <param name="runs">The runs to check; each run must contain exactly one regression solution result.</param>
/// <exception cref="InvalidOperationException">Thrown when a run uses different data than the first run.</exception>
private static void AssertSameProblemData(RunCollection runs) {
  const string errorMessage = "The runs must share the same data.";
  IRegressionProblemData referenceProblemData = null;
  IDataset referenceDataset = null;

  foreach (var run in runs) {
    var solution = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
    var problemData = solution.ProblemData;
    var ds = problemData.Dataset;

    if (referenceProblemData == null) {
      // take BOTH references from the first run, so that every later run is compared against it
      referenceProblemData = problemData;
      referenceDataset = ds;
      continue;
    }

    // the same instances trivially share the same data
    if (problemData == referenceProblemData || ds == referenceDataset) continue;

    var partition = problemData.TrainingPartition;
    var referencePartition = referenceProblemData.TrainingPartition;
    if (partition.Start != referencePartition.Start || partition.End != referencePartition.End)
      throw new InvalidOperationException(errorMessage);

    if (!ds.DoubleVariables.SequenceEqual(referenceDataset.DoubleVariables))
      throw new InvalidOperationException(errorMessage);

    foreach (var v in ds.DoubleVariables) {
      var values = (IList<double>)ds.GetReadOnlyDoubleValues(v);
      var referenceValues = (IList<double>)referenceDataset.GetReadOnlyDoubleValues(v);

      // cheap length check first, element-wise comparison only when the lengths agree
      if (values.Count != referenceValues.Count || !values.SequenceEqual(referenceValues))
        throw new InvalidOperationException(errorMessage);
    }
  }
}
98
/// <summary>
/// Builds a regression ensemble solution from the models of the regression solutions
/// found in the given runs. The problem data of the first solution is used for the ensemble.
/// </summary>
/// <param name="runs">The runs; each run must contain exactly one regression solution result.</param>
/// <returns>An ensemble solution combining the models of all runs.</returns>
public static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
  // materialize once: the solutions are needed for the models and for the problem data,
  // a lazy query would look up the solution of every run a second time
  var solutions = runs
    .Select(run => (IRegressionSolution)run.Results.Values.Single(v => v is IRegressionSolution))
    .ToList();
  var ensembleModel = new RegressionEnsembleModel(solutions.Select(x => x.Model));
  return new RegressionEnsembleSolution(ensembleModel, solutions.First().ProblemData);
}
103
/// <summary>
/// Calculates variable impacts by re-evaluating the regression solutions stored in the runs
/// (see CalculateImpacts). Runs are grouped by their allowed input variables and target
/// variable, and the impacts are averaged within each group.
/// </summary>
/// <param name="runs">The runs to analyze; all runs must share the same data.</param>
/// <param name="useBest">
/// Intended to restrict the calculation to the best run per target. This mode is not
/// implemented: when true, an empty dictionary is returned.
/// </param>
/// <returns>
/// For each target variable, the runs of the group and the average impact per input variable.
/// The dictionary is empty when <paramref name="runs"/> contains no runs.
/// </returns>
/// <exception cref="InvalidOperationException">Thrown when the runs do not share the same data.</exception>
public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
  var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
  // nothing to calculate without runs (runs.First() below would throw)
  if (runs.Count == 0) return targetImpacts;

  AssertSameProblemData(runs);
  var firstSolution = (IRegressionSolution)runs.First().Results.Values.Single(x => x is IRegressionSolution);
  var dataset = (Dataset)firstSolution.ProblemData.Dataset;
  var originalValues = dataset.DoubleVariables.ToDictionary(x => x, x => dataset.GetReadOnlyDoubleValues(x).ToList());
  // all solutions are evaluated on one modifiable copy whose columns are temporarily replaced
  var md = dataset.ToModifiable();
  var medians = new Dictionary<string, List<double>>();
  foreach (var v in dataset.DoubleVariables) {
    var median = dataset.GetDoubleValues(v, firstSolution.ProblemData.TrainingIndices).Median();
    medians[v] = Enumerable.Repeat(median, originalValues[v].Count).ToList();
  }

  if (useBest) {
    // NOTE(review): building the network from only the best run per target is not implemented yet
    return targetImpacts;
  }

  var groups = runs.GroupBy(run => {
    var sol = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
    return Concatenate(sol.ProblemData.AllowedInputVariables) + sol.ProblemData.TargetVariable;
  });

  foreach (var group in groups) {
    var summedImpacts = new Dictionary<string, double>();
    var groupSolution = (IRegressionSolution)group.First().Results.Values.Single(x => x is IRegressionSolution);
    var count = 0;
    foreach (var run in group) {
      var sol = (IRegressionSolution)run.Results.Values.Single(v => v is IRegressionSolution);

      DoubleLimit estimationLimits = null;
      IItem limitsItem;
      if (run.Parameters.TryGetValue("EstimationLimits", out limitsItem))
        estimationLimits = (DoubleLimit)limitsItem;

      foreach (var pair in CalculateImpacts(sol, md, originalValues, medians, estimationLimits)) {
        double sum;
        summedImpacts.TryGetValue(pair.Key, out sum); // sum is 0 for a variable seen for the first time
        summedImpacts[pair.Key] = sum + pair.Value;
      }
      count++;
    }

    var averageImpacts = summedImpacts.ToDictionary(x => x.Key, x => x.Value / count);
    // NOTE(review): the result is keyed by target only, so a later group with the same target
    // but different inputs replaces an earlier one
    targetImpacts[groupSolution.ProblemData.TargetVariable] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(group, averageImpacts);
  }
  return targetImpacts;
}
158
/// <summary>
/// Calculates the impact of every allowed input variable of a solution. The impact of a
/// variable is the training R² of the solution minus the squared Pearson correlation between
/// the target values and the model output obtained when the variable is replaced by the
/// given median values.
/// </summary>
/// <param name="solution">The solution whose model is evaluated.</param>
/// <param name="dataset">Modifiable dataset used for the evaluation; each variable is replaced temporarily and restored afterwards.</param>
/// <param name="originalValues">Original values per variable, used to restore the dataset.</param>
/// <param name="medianValues">Replacement values per variable.</param>
/// <param name="estimationLimits">Optional limits the estimated values are clamped to.</param>
/// <returns>The impact per allowed input variable; NaN when the correlation could not be calculated.</returns>
private static Dictionary<string, double> CalculateImpacts(IRegressionSolution solution, ModifiableDataset dataset,
  Dictionary<string, List<double>> originalValues, Dictionary<string, List<double>> medianValues, DoubleLimit estimationLimits = null) {
  var impacts = new Dictionary<string, double>();

  var model = solution.Model;
  var pd = solution.ProblemData;

  var rows = pd.TrainingIndices.ToList();
  var targetValues = pd.Dataset.GetDoubleValues(pd.TargetVariable, rows).ToList();
  // does not depend on the replaced variable, so it is read only once
  var originalQuality = solution.TrainingRSquared;

  foreach (var v in pd.AllowedInputVariables) {
    dataset.ReplaceVariable(v, medianValues[v]);
    try {
      var estimatedValues = model.GetEstimatedValues(dataset, rows);
      if (estimationLimits != null)
        estimatedValues = estimatedValues.LimitToRange(estimationLimits.Lower, estimationLimits.Upper);

      OnlineCalculatorError error;
      var r = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out error);
      var newQuality = error == OnlineCalculatorError.None ? r * r : double.NaN;
      impacts[v] = originalQuality - newQuality;
    } finally {
      // the dataset is shared between evaluations: always restore the original column,
      // even when the model evaluation throws
      dataset.ReplaceVariable(v, originalValues[v]);
    }
  }
  return impacts;
}
187
/// <summary>
/// Collects variable impacts from impact matrices stored in the run results. Runs are
/// grouped by the target variable of their "ProblemData" parameter.
/// </summary>
/// <param name="runs">The runs to analyze.</param>
/// <param name="qualityResultName">Name of the result holding the quality used to rank runs.</param>
/// <param name="maximization">True when a larger quality value is better.</param>
/// <param name="impactsResultName">Name of the result holding the impact matrix; the impacts are read from its first column.</param>
/// <param name="useBestRunsPerTarget">When true only the best run per target is used, otherwise the impacts of all runs per target are averaged.</param>
/// <returns>For each target variable, the runs used and the impact per variable.</returns>
public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs,
  string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
  var runsByTarget = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
  var result = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();

  if (!useBestRunsPerTarget) {
    // average the impacts over all runs of a target
    foreach (var targetRuns in runsByTarget) {
      var averaged = CalculateAverageImpacts(new RunCollection(targetRuns), impactsResultName);
      result[targetRuns.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(targetRuns, averaged);
    }
    return result;
  }

  Func<IRun, double> quality = r => ((DoubleValue)r.Results[qualityResultName]).Value;
  foreach (var targetRuns in runsByTarget) {
    // runs are sorted ascending by quality: the best one is last for maximization and first otherwise
    var ranked = targetRuns.OrderBy(quality);
    var best = maximization ? ranked.Last() : ranked.First();

    var problemData = (IRegressionProblemData)best.Parameters["ProblemData"];
    var targetName = problemData.TargetVariable;
    var matrix = (DoubleMatrix)best.Results[impactsResultName];

    var impactByVariable = new Dictionary<string, double>();
    var row = 0;
    foreach (var name in matrix.RowNames) {
      impactByVariable.Add(name, matrix[row, 0]);
      row++;
    }
    result[targetName] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(new[] { best }, impactByVariable);
  }
  return result;
}
211
/// <summary>
/// Builds a variable interaction network from the impacts per target variable. A target
/// with a single input is connected directly to that input. A target with several inputs
/// gets a junction node: every input is connected to the junction (arc weight = impact)
/// and the junction is connected to the target (arc weight = sum of the incoming weights).
/// The junction stores an ensemble solution of the target's runs and is labeled with its training R².
/// </summary>
/// <param name="targetImpacts">For each target variable, the runs and the impact per input variable. Targets without impacts are skipped.</param>
/// <returns>The network containing one node per variable and one junction per multi-input target.</returns>
public static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
  // holds variable nodes only; junctions are never looked up by label, so a variable whose name
  // equals the concatenation of other names can no longer be mistaken for a junction (or vice versa)
  var variableNodes = new Dictionary<string, IVertex>();
  var vn = new VariableInteractionNetwork();

  foreach (var ti in targetImpacts) {
    var target = ti.Key;
    var targetRuns = ti.Value.Item1;
    var variableImpacts = ti.Value.Item2;
    if (variableImpacts.Count == 0) continue;

    var targetNode = GetOrAddVariableNode(vn, variableNodes, target);

    // every target occurs exactly once in the dictionary, hence it gets its own junction
    IVertex junctionNode = null;
    string qualityLabel = null;
    if (variableImpacts.Count > 1) {
      var solutionsEnsemble = CreateEnsembleSolution(targetRuns);
      qualityLabel = solutionsEnsemble.TrainingRSquared.ToString("N3", CultureInfo.CurrentCulture);
      junctionNode = new JunctionNetworkNode { Label = qualityLabel, Data = solutionsEnsemble };
      vn.AddVertex(junctionNode);
    }

    // inputs point to the junction if there is one, otherwise directly to the target
    var sink = junctionNode ?? targetNode;
    foreach (var pair in variableImpacts) {
      var variableNode = GetOrAddVariableNode(vn, variableNodes, pair.Key);
      var impact = pair.Value;
      vn.AddArc(new Arc(variableNode, sink) { Weight = impact, Label = impact.ToString("N3", CultureInfo.CurrentCulture) });
    }

    if (junctionNode != null) {
      vn.AddArc(new Arc(junctionNode, targetNode) { Weight = junctionNode.InArcs.Sum(x => x.Weight), Label = qualityLabel });
    }
  }
  return vn;
}

// Returns the variable node registered under the label, creating it and adding it to the network first if necessary.
private static IVertex GetOrAddVariableNode(VariableInteractionNetwork network, Dictionary<string, IVertex> nodes, string label) {
  IVertex node;
  if (!nodes.TryGetValue(label, out node)) {
    node = new VariableNetworkNode { Label = label };
    network.AddVertex(node);
    nodes[label] = node;
  }
  return node;
}
276
/// <summary>
/// Creates a copy of the network that contains only the arcs whose weight is at least the
/// threshold, together with their end points. Junction nodes left without incoming arcs and
/// nodes left without any arcs are removed afterwards.
/// </summary>
/// <param name="originalNetwork">The network to filter; it is not modified.</param>
/// <param name="threshold">Minimum weight an arc needs to be kept.</param>
/// <returns>The filtered copy, or the original network when no arc or no vertex would remain.</returns>
public static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
  var survivingArcs = originalNetwork.Arcs.Where(a => a.Weight >= threshold).ToList();
  if (survivingArcs.Count == 0) return originalNetwork;

  var cloner = new Cloner();
  var result = new VariableInteractionNetwork();

  // the end points are cloned (and added) before the arcs, using the same cloner for both
  var clonedEndPoints = survivingArcs
    .SelectMany(a => new[] { a.Source, a.Target })
    .Select(v => cloner.Clone(v))
    .Distinct();
  result.AddVertices(clonedEndPoints);
  result.AddArcs(survivingArcs.Select(a => (IArc)a.Clone(cloner)));

  // junctions that lost all of their inputs go first, then whatever ended up without any arc
  var junctionsWithoutInputs = result.Vertices.Where(v => v is JunctionNetworkNode && v.InDegree == 0).ToList();
  result.RemoveVertices(junctionsWithoutInputs);
  var isolated = result.Vertices.Where(v => v.Degree == 0).ToList();
  result.RemoveVertices(isolated);

  if (result.Vertices.Any()) return result;
  return originalNetwork;
}
292
/// <summary>
/// Averages the "Best training solution quality" result over the given runs.
/// </summary>
/// <param name="runs">Runs that all have the same target variable and allowed input variables.</param>
/// <returns>The average quality.</returns>
/// <exception cref="ArgumentException">Thrown when a run differs from the first run in target or inputs.</exception>
private static double CalculateAverageQuality(RunCollection runs) {
  var reference = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
  var target = reference.TargetVariable;
  var inputs = reference.AllowedInputVariables;

  // every run is compared against the problem data of the first run
  foreach (var run in runs) {
    var problemData = (IRegressionProblemData)run.Parameters["ProblemData"];
    if (target != problemData.TargetVariable || !inputs.SequenceEqual(problemData.AllowedInputVariables))
      throw new ArgumentException("All runs must have the same target and inputs.");
  }

  return runs.Average(x => ((DoubleValue)x.Results["Best training solution quality"]).Value);
}
306
/// <summary>
/// Averages the variable impacts stored in the impact matrices of the given runs. The
/// impacts are read from the first column of the matrix, one row per variable.
/// </summary>
/// <param name="runs">Runs that all have the same target variable and allowed input variables.</param>
/// <param name="resultName">Name of the result holding the impact matrix.</param>
/// <returns>
/// The average impact per variable. Every allowed input variable is contained (with impact 0
/// when no matrix mentions it); variables that only occur as matrix row names are contained as well.
/// </returns>
/// <exception cref="ArgumentException">Thrown when a run differs from the first run in target or inputs.</exception>
public static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
  var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
  var target = pd.TargetVariable;
  var inputs = pd.AllowedInputVariables.ToList();

  var impacts = inputs.ToDictionary(x => x, x => 0d);

  // check if all the runs have the same target and same inputs
  if (!runs.All(x => {
    var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
    return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
  })) {
    throw new ArgumentException("All runs must have the same target and inputs.");
  }

  foreach (var run in runs) {
    var impactsMatrix = (DoubleMatrix)run.Results[resultName];
    int i = 0;
    foreach (var v in impactsMatrix.RowNames) {
      // a row name is not guaranteed to be an allowed input: start at 0 instead of failing the lookup
      double sum;
      impacts.TryGetValue(v, out sum);
      impacts[v] = sum + impactsMatrix[i, 0];
      ++i;
    }
  }

  // iterate over a copy of the keys, the dictionary is modified inside the loop
  var count = runs.Count;
  foreach (var v in impacts.Keys.ToList()) {
    impacts[v] /= count;
  }

  return impacts;
}
337
/// <summary>
/// Joins the given strings without any separator, in enumeration order.
/// Null elements contribute nothing to the result.
/// </summary>
/// <param name="strings">The strings to join.</param>
/// <returns>The concatenated string; empty when there are no strings.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="strings"/> is null.</exception>
private static string Concatenate(IEnumerable<string> strings) {
  return string.Concat(strings);
}
345
/// <summary>
/// Registers the shapes of the graph chart: a white ellipse for variable nodes and a
/// smaller dark gray rectangle for junction nodes, both labeled in a 12 pt sans serif font.
/// </summary>
private void ConfigureNodeShapes() {
  graphChart.ClearShapes();
  var labelFont = new Font(FontFamily.GenericSansSerif, 12);

  var variableShape = new Ellipse(graphChart.Chart, new PointD(0, 0), new PointD(30, 30), Pens.Black, Brushes.White);
  graphChart.AddShape(typeof(VariableNetworkNode), new LabeledPrimitive(variableShape, "", labelFont));

  var junctionShape = new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray);
  graphChart.AddShape(typeof(JunctionNetworkNode), new LabeledPrimitive(junctionShape, "", labelFont));
}
352
353    #region events
/// <summary>
/// Refills the result name combo boxes from the first run of the content and rebuilds the
/// network when both an impact result and a quality result are available. Impact results are
/// matrices whose row names are all double variables of the dataset; quality results are
/// double values. The combo boxes stay empty when there is no content, no run, or the first
/// run has no regression problem data.
/// </summary>
protected override void OnContentChanged() {
  base.OnContentChanged();
  impactResultNameComboBox.Items.Clear();
  qualityResultNameComboBox.Items.Clear();

  // the view is registered for any run collection, so none of the following can be taken for granted
  var run = Content == null ? null : Content.FirstOrDefault();
  if (run == null) return;

  IItem problemDataItem;
  if (!run.Parameters.TryGetValue("ProblemData", out problemDataItem)) return;
  var pd = problemDataItem as IRegressionProblemData;
  if (pd == null) return;

  var variables = new HashSet<string>(pd.Dataset.DoubleVariables);
  foreach (var result in run.Results) {
    var matrix = result.Value as DoubleMatrix;
    if (matrix != null) {
      if (matrix.RowNames.All(x => variables.Contains(x)))
        impactResultNameComboBox.Items.Add(result.Key);
    } else if (result.Value is DoubleValue) {
      qualityResultNameComboBox.Items.Add(result.Key);
    }
  }

  // preselect the first entry of each box
  if (impactResultNameComboBox.Items.Count > 0) {
    impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
  }
  if (qualityResultNameComboBox.Items.Count > 0) {
    qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
  }
  if (impactResultNameComboBox.Items.Count > 0 && qualityResultNameComboBox.Items.Count > 0)
    NetworkConfigurationChanged(this, EventArgs.Empty);
}
378
/// <summary>
/// Validates that the text box that raised the event contains a real number. Otherwise the
/// validation is cancelled, the text is selected and an error icon is shown at the text box.
/// </summary>
private void TextBoxValidating(object sender, CancelEventArgs e) {
  var textBox = (TextBox)sender;
  double parsed;
  if (double.TryParse(textBox.Text, out parsed)) return;

  e.Cancel = true;
  textBox.Select(0, textBox.Text.Length);

  // show the error at the text box
  errorProvider.SetError(textBox, "Could not parse the entered value. Please input a real number.");
  errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
  errorProvider.SetIconPadding(textBox, -20);
}
393
/// <summary>
/// Clears the error of the text box and shows the network filtered by the entered impact
/// threshold. A threshold of 0.2 is used when the text cannot be parsed. Nothing is shown
/// as long as no network has been created.
/// </summary>
private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
  const double fallbackThreshold = 0.2;

  var tb = (TextBox)sender;
  errorProvider.SetError(tb, string.Empty);

  // the event can occur before any network was built; there is nothing to filter then
  if (variableInteractionNetwork == null) return;

  double impact;
  if (!double.TryParse(tb.Text, out impact)) {
    impact = fallbackThreshold;
  }
  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
}
404
/// <summary>
/// Clears the error of the validated text box and applies the layout configuration.
/// </summary>
private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
  var validatedBox = (TextBox)sender;
  errorProvider.SetError(validatedBox, string.Empty);
  LayoutConfigurationChanged(sender, e);
}
410
/// <summary>
/// Rebuilds the network from the impacts stored in the run results, using the selected
/// result names, aggregation mode and optimization direction, and shows it filtered by the
/// threshold of the track bar. Does nothing without content or when a result name is missing.
/// </summary>
private void NetworkConfigurationChanged(object sender, EventArgs e) {
  // configuration events can occur before any content is set
  if (Content == null) return;

  var qualityResultName = qualityResultNameComboBox.Text;
  var impactsResultName = impactResultNameComboBox.Text;
  if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName))
    return;

  // the first entry (or no selection) means: use only the best run per target
  var useBest = impactAggregationComboBox.SelectedIndex <= 0;
  var maximization = maximizationCheckBox.Checked;
  // same formula as in the other threshold handlers, so that the applied threshold
  // always equals the value displayed next to the track bar
  var threshold = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;

  var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
  variableInteractionNetwork = CreateNetwork(impacts);
  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
}
424
/// <summary>
/// Applies the selected edge routing mode and the entered ideal edge length to the graph
/// chart and redraws it. Nothing happens when neither value changed or when the edge
/// length text is not a number.
/// </summary>
/// <exception cref="ArgumentException">Thrown when the selected routing index is not 0, 1 or 2.</exception>
private void LayoutConfigurationChanged(object sender, EventArgs e) {
  ConstrainedForceDirectedLayout.EdgeRouting routingMode;
  switch (edgeRoutingComboBox.SelectedIndex) {
    case 0:
      routingMode = ConstrainedForceDirectedLayout.EdgeRouting.None;
      break;
    case 1:
      routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Polyline;
      break;
    case 2:
      routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal;
      break;
    default:
      throw new ArgumentException("Invalid edge routing mode.");
  }

  // this handler is also reached without a preceding text box validation,
  // so an unparsable edge length must not raise a FormatException here
  double idealEdgeLength;
  if (!double.TryParse(idealEdgeLengthTextBox.Text, out idealEdgeLength)) return;

  if (routingMode == graphChart.RoutingMode && idealEdgeLength.IsAlmost(graphChart.DefaultEdgeLength)) return;
  graphChart.RoutingMode = routingMode;
  graphChart.PerformEdgeRouting = routingMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
  graphChart.DefaultEdgeLength = idealEdgeLength;
  graphChart.Draw();
}
447
/// <summary>
/// Enables or disables all configuration controls of the view at once.
/// </summary>
/// <param name="enabled">The new enabled state of the controls.</param>
private void ControlsEnable(bool enabled) {
  maximizationCheckBox.Enabled = enabled;
  idealEdgeLengthTextBox.Enabled = enabled;
  edgeRoutingComboBox.Enabled = enabled;
  onlineImpactCalculationButton.Enabled = enabled;
  impactThresholdTrackBar.Enabled = enabled;
  impactAggregationComboBox.Enabled = enabled;
  impactResultNameComboBox.Enabled = enabled;
  qualityResultNameComboBox.Enabled = enabled;
}
458
/// <summary>
/// Calculates the variable impacts of the content in the background and shows the resulting
/// network filtered by the threshold of the track bar. The configuration controls are
/// disabled while the calculation is running. A failed calculation is reported in a message box.
/// </summary>
private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
  var runs = Content;
  if (runs == null) return;

  var worker = new BackgroundWorker();
  // only the calculation runs on the worker thread; controls must not be touched from there
  worker.DoWork += (o, e) => {
    var impacts = CalculateVariableImpactsOnline(runs, false);
    e.Result = CreateNetwork(impacts);
  };
  // raised back on the thread that started the worker, so the view can be updated here
  worker.RunWorkerCompleted += (o, e) => {
    try {
      if (IsDisposed) return;
      if (e.Error != null) {
        MessageBox.Show(this, e.Error.Message, "Variable impact calculation failed", MessageBoxButtons.OK, MessageBoxIcon.Error);
        return;
      }
      variableInteractionNetwork = (VariableInteractionNetwork)e.Result;
      var threshold = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
    } finally {
      if (!IsDisposed) ControlsEnable(true);
      worker.Dispose();
    }
  };

  ControlsEnable(false);
  worker.RunWorkerAsync();
}
471
/// <summary>
/// Redraws the graph chart when the relayout button is clicked.
/// </summary>
private void relayoutGraphButton_Click(object sender, EventArgs e) {
  this.graphChart.Draw();
}
475    #endregion
476
/// <summary>
/// Exports the displayed network as a square impact matrix and shows it in the main form.
/// Rows and columns are the variable labels in sorted order; the entry [target, input] holds
/// the weight of the arc leading from the input to the target, either through a junction or
/// directly. The diagonal is set to 1. Does nothing when no graph is displayed.
/// </summary>
private void exportImpactsMatrixButton_Click(object sender, EventArgs e) {
  var graph = graphChart.Graph;
  if (graph == null) return;

  var labels = graph.Vertices.OfType<VariableNetworkNode>().Select(x => x.Label).ToList();
  labels.Sort(); // sort variables alphabetically
  var matrix = new DoubleMatrix(labels.Count, labels.Count) { RowNames = labels, ColumnNames = labels };
  var indexes = labels.Select((x, i) => new { Label = x, Index = i }).ToDictionary(x => x.Label, x => x.Index);

  // inputs connected to their target through a junction
  foreach (var jn in graph.Vertices.OfType<JunctionNetworkNode>()) {
    var outArc = jn.OutArcs.FirstOrDefault();
    if (outArc == null) continue; // a junction without target has nothing to contribute
    var targetIndex = indexes[outArc.Target.Label];
    foreach (var input in jn.InArcs) {
      matrix[targetIndex, indexes[input.Source.Label]] = input.Weight;
    }
  }

  // inputs connected directly to their target (targets with a single input have no junction)
  foreach (var arc in graph.Arcs) {
    if (arc.Source is VariableNetworkNode && arc.Target is VariableNetworkNode)
      matrix[indexes[arc.Target.Label], indexes[arc.Source.Label]] = arc.Weight;
  }

  for (int i = 0; i < labels.Count; ++i) matrix[i, i] = 1;
  MainFormManager.MainForm.ShowContent(matrix);
}
496
/// <summary>
/// Displays the impact threshold selected with the track bar and shows the network filtered
/// by it. Only the displayed value is updated as long as no network has been created.
/// </summary>
private void impactThresholdTrackBar_ValueChanged(object sender, EventArgs e) {
  var impact = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
  impactThresholdLabel.Text = impact.ToString("N3", CultureInfo.CurrentCulture);

  // the track bar can change before any network was built; there is nothing to filter then
  if (variableInteractionNetwork == null) return;
  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
}
503
504
505  }
506}
Note: See TracBrowser for help on using the repository browser.