Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks.Views/3.3/RunCollectionVariableInteractionNetworkView.cs @ 13893

Last change on this file since 13893 was 13893, checked in by bburlacu, 8 years ago

#2288: Simplify and optimize code for cluster identification in ConstrainedForceDirectedLayout.cs. Introduce a TrackBar for adjusting network threshold in the RunCollectionVariableInteractionNetworkView. Minor improvements to the DirectedGraphChart (work in progress).

File size: 22.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.ComponentModel;
25using System.Drawing;
26using System.Globalization;
27using System.Linq;
28using System.Text;
29using System.Windows.Forms;
30using HeuristicLab.Common;
31using HeuristicLab.Core;
32using HeuristicLab.Core.Views;
33using HeuristicLab.Data;
34using HeuristicLab.MainForm;
35using HeuristicLab.Optimization;
36using HeuristicLab.Problems.DataAnalysis;
37using HeuristicLab.Visualization;
38using Ellipse = HeuristicLab.Visualization.Ellipse;
39using Rectangle = HeuristicLab.Visualization.Rectangle;
40
41namespace HeuristicLab.VariableInteractionNetworks.Views {
42  [View("Variable Interaction Network")]
43  [Content(typeof(RunCollection), IsDefaultView = false)]
44
45  public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
/// <summary>
/// Creates the view: first the designer-generated controls are built, then the node shapes
/// used by the graph chart are registered.
/// </summary>
public RunCollectionVariableInteractionNetworkView() {
  this.InitializeComponent();
  this.ConfigureNodeShapes();
}
50
/// <summary>
/// The run collection shown by this view. Assigning null, or the collection that is
/// already shown, is ignored by this setter.
/// </summary>
public new RunCollection Content {
  get { return (RunCollection)base.Content; }
  set {
    if (value == null || value == Content) return;
    base.Content = value;
  }
}
59
60    private VariableInteractionNetwork variableInteractionNetwork;
61
/// <summary>
/// Verifies that the regression solutions of all runs were built on the same data.
/// </summary>
/// <remarks>
/// The first run is the reference. A later run is accepted immediately when it shares the reference
/// problem data instance or the reference dataset instance. Otherwise its training partition, the
/// names of its double variables and the values of every double variable are compared with the reference.
/// </remarks>
/// <param name="runs">The runs to check; each run must contain exactly one regression solution in its results.</param>
/// <exception cref="InvalidOperationException">
/// A run uses a different training partition, different double variables or different values
/// (also raised by <c>Single</c> when a run does not contain exactly one regression solution).
/// </exception>
private static void AssertSameProblemData(RunCollection runs) {
  const string mismatchMessage = "The runs must share the same data.";
  IRegressionProblemData problemData = null;
  IDataset dataset = null;

  foreach (var run in runs) {
    var solution = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
    var pd = solution.ProblemData;
    var ds = pd.Dataset;

    if (problemData == null) {
      // Remember problem data AND dataset of the first run. Previously only the problem data was
      // stored here and the reference dataset was taken from the second run, so the second run
      // was never compared against anything.
      problemData = pd;
      dataset = ds;
      continue;
    }

    // identical instances cannot differ, no need to compare values
    if (pd == problemData || ds == dataset) continue;

    var partition = pd.TrainingPartition;
    var referencePartition = problemData.TrainingPartition;
    if (referencePartition.Start != partition.Start || referencePartition.End != partition.End)
      throw new InvalidOperationException(mismatchMessage);

    if (!ds.DoubleVariables.SequenceEqual(dataset.DoubleVariables))
      throw new InvalidOperationException(mismatchMessage);

    foreach (var v in ds.DoubleVariables) {
      var values1 = (IList<double>)ds.GetReadOnlyDoubleValues(v);
      var values2 = (IList<double>)dataset.GetReadOnlyDoubleValues(v);

      // cheap length check first, element-wise comparison only when lengths agree
      if (values1.Count != values2.Count || !values1.SequenceEqual(values2))
        throw new InvalidOperationException(mismatchMessage);
    }
  }
}
98
/// <summary>
/// Builds an ensemble solution from the regression solutions stored in the results of the given runs.
/// </summary>
/// <param name="runs">Runs that each contain exactly one regression solution in their results.</param>
/// <returns>An ensemble over the models of all runs, using the problem data of the first run's solution.</returns>
/// <exception cref="InvalidOperationException">
/// Raised by <c>Single</c> when a run does not contain exactly one regression solution,
/// or by <c>First</c> when <paramref name="runs"/> is empty.
/// </exception>
private static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
  // Materialize once: the query is consumed for the models and again for the problem data,
  // and a deferred query would repeat the Single(...) search over every run's results each time.
  var solutions = runs
    .Select(run => run.Results.Values.Single(v => v is IRegressionSolution))
    .Cast<IRegressionSolution>()
    .ToList();
  var ensembleModel = new RegressionEnsembleModel(solutions.Select(s => s.Model));
  return new RegressionEnsembleSolution(ensembleModel, solutions.First().ProblemData);
}
103
/// <summary>
/// Calculates variable impacts by re-evaluating the models of the runs on a copy of the dataset in which
/// one input at a time is replaced by its training median, and averages the impacts per group of runs.
/// </summary>
/// <param name="runs">The runs to analyse; all runs must share the same data (see AssertSameProblemData).</param>
/// <param name="useBest">
/// Intended to restrict the calculation to the best run per target. This mode is not implemented:
/// an empty dictionary is returned when it is true.
/// </param>
/// <returns>
/// For each target variable, the runs of the group and the average impact of each input variable.
/// Empty when <paramref name="runs"/> contains no runs.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="runs"/> is null.</exception>
/// <exception cref="InvalidOperationException">The runs do not share the same data.</exception>
public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
  if (runs == null) throw new ArgumentNullException("runs");

  var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
  // runs.First() below would throw on an empty collection; there is simply nothing to calculate
  if (!runs.Any()) return targetImpacts;

  AssertSameProblemData(runs);

  if (useBest) {
    // build network using only the best run for each target
    // NOTE(review): this mode was never implemented; the result stays empty.
    return targetImpacts;
  }

  var firstSolution = (IRegressionSolution)runs.First().Results.Values.Single(x => x is IRegressionSolution);
  var dataset = (Dataset)firstSolution.ProblemData.Dataset;
  var originalValues = dataset.DoubleVariables.ToDictionary(x => x, x => dataset.GetReadOnlyDoubleValues(x).ToList());
  var md = dataset.ToModifiable();

  // the training rows are the same for every variable, read them once
  var trainingIndices = firstSolution.ProblemData.TrainingIndices.ToList();
  var medians = new Dictionary<string, List<double>>();
  foreach (var v in dataset.DoubleVariables) {
    var median = dataset.GetDoubleValues(v, trainingIndices).Median();
    // a full-length column holding only the median, used as replacement for the variable
    medians[v] = Enumerable.Repeat(median, originalValues[v].Count).ToList();
  }

  // runs belong together when they use the same inputs and the same target
  var groups = runs.GroupBy(run => {
    var sol = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
    return Concatenate(sol.ProblemData.AllowedInputVariables) + sol.ProblemData.TargetVariable;
  });

  foreach (var group in groups) {
    // calculate average impacts
    var averageImpacts = new Dictionary<string, double>();
    var groupSolution = (IRegressionSolution)group.First().Results.Values.Single(x => x is IRegressionSolution);

    foreach (var run in group) {
      var sol = (IRegressionSolution)run.Results.Values.Single(v => v is IRegressionSolution);

      DoubleLimit estimationLimits = null;
      IItem limitsItem;
      if (run.Parameters.TryGetValue("EstimationLimits", out limitsItem)) {
        estimationLimits = (DoubleLimit)limitsItem;
      }

      var impacts = CalculateImpacts(sol, md, originalValues, medians, estimationLimits);
      foreach (var pair in impacts) {
        double sum;
        averageImpacts.TryGetValue(pair.Key, out sum); // sum stays 0 for a new key
        averageImpacts[pair.Key] = sum + pair.Value;
      }
    }

    var count = group.Count();
    foreach (var v in averageImpacts.Keys.ToList()) {
      averageImpacts[v] /= count;
    }

    // NOTE(review): groups are formed by inputs + target but stored by target only, so a later
    // group with the same target and different inputs replaces an earlier one - confirm this is intended.
    targetImpacts[groupSolution.ProblemData.TargetVariable] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(group, averageImpacts);
  }
  return targetImpacts;
}
158
/// <summary>
/// Calculates the impact of every allowed input variable of a solution as the drop in training R²
/// that results from replacing the variable with the given median values.
/// </summary>
/// <param name="solution">The solution whose model is evaluated.</param>
/// <param name="dataset">A modifiable dataset; it is changed temporarily and restored before the method returns.</param>
/// <param name="originalValues">The original values of each variable, used to restore <paramref name="dataset"/>.</param>
/// <param name="medianValues">The replacement values of each variable.</param>
/// <param name="estimationLimits">Optional limits applied to the estimated values; ignored when null.</param>
/// <returns>
/// The impact (original training R² minus R² with the variable replaced) of each allowed input variable;
/// NaN when the correlation could not be calculated.
/// </returns>
private static Dictionary<string, double> CalculateImpacts(IRegressionSolution solution, ModifiableDataset dataset,
  Dictionary<string, List<double>> originalValues, Dictionary<string, List<double>> medianValues, DoubleLimit estimationLimits = null) {
  var impacts = new Dictionary<string, double>();

  var model = solution.Model;
  var pd = solution.ProblemData;

  var rows = pd.TrainingIndices.ToList();
  var targetValues = pd.Dataset.GetDoubleValues(pd.TargetVariable, rows).ToList();
  // the reference quality does not depend on the replaced variable
  var originalQuality = solution.TrainingRSquared;

  foreach (var v in pd.AllowedInputVariables) {
    dataset.ReplaceVariable(v, medianValues[v]);
    try {
      var estimatedValues = model.GetEstimatedValues(dataset, rows);
      if (estimationLimits != null)
        estimatedValues = estimatedValues.LimitToRange(estimationLimits.Lower, estimationLimits.Upper);

      // the estimated values must be consumed while the variable is still replaced
      OnlineCalculatorError error;
      var r = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out error);
      var newQuality = error == OnlineCalculatorError.None ? r * r : double.NaN;
      impacts[v] = originalQuality - newQuality;
    } finally {
      // the dataset is shared between calls; restore it even if the evaluation throws
      dataset.ReplaceVariable(v, originalValues[v]);
    }
  }
  return impacts;
}
187
/// <summary>
/// Collects variable impacts per target variable from impact matrices stored in the run results.
/// </summary>
/// <param name="runs">The runs; each must have a "ProblemData" parameter holding regression problem data.</param>
/// <param name="qualityResultName">Name of the result holding the quality used to pick the best run.</param>
/// <param name="maximization">Whether a larger quality value is better.</param>
/// <param name="impactsResultName">Name of the result holding the impacts matrix (impacts are read from column 0).</param>
/// <param name="useBestRunsPerTarget">
/// When true the impacts of the best run of each target are used, otherwise the impacts are averaged over all runs of the target.
/// </param>
/// <returns>For each target variable, the runs that were used and the impact of each variable.</returns>
private static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs,
  string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
  var runsByTarget = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
  var result = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();

  if (!useBestRunsPerTarget) {
    // average over all runs that share a target
    foreach (var targetRuns in runsByTarget) {
      var averaged = CalculateAverageImpacts(new RunCollection(targetRuns), impactsResultName);
      result[targetRuns.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(targetRuns, averaged);
    }
    return result;
  }

  foreach (var targetRuns in runsByTarget) {
    // ascending by quality: the best run is the last one when maximizing, the first one otherwise
    var byQuality = targetRuns.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value);
    var bestRun = maximization ? byQuality.Last() : byQuality.First();

    var problemData = (IRegressionProblemData)bestRun.Parameters["ProblemData"];
    var targetName = problemData.TargetVariable;
    var impactsMatrix = (DoubleMatrix)bestRun.Results[impactsResultName];

    // row name -> value in the first column of that row
    var impactByVariable = new Dictionary<string, double>();
    var row = 0;
    foreach (var rowName in impactsMatrix.RowNames) {
      impactByVariable.Add(rowName, impactsMatrix[row, 0]);
      row++;
    }
    result[targetName] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(new[] { bestRun }, impactByVariable);
  }
  return result;
}
211
/// <summary>
/// Builds the variable interaction network from the impacts calculated for each target.
/// </summary>
/// <remarks>
/// A target with a single input gets a direct arc from the input to the target. A target with several
/// inputs gets a junction node: every input is connected to the junction (arc weight = impact) and the
/// junction is connected to the target (arc weight = sum of the incoming weights). The junction carries
/// an ensemble solution of the target's runs and is labelled with its training R².
/// Targets without any impacts are skipped.
/// </remarks>
/// <param name="targetImpacts">For each target variable, the runs and the impact of each input variable.</param>
/// <returns>The network containing variable nodes, junction nodes and weighted arcs.</returns>
private static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
  var nodes = new Dictionary<string, IVertex>();
  var vn = new VariableInteractionNetwork();

  foreach (var ti in targetImpacts) {
    var target = ti.Key;
    var targetRuns = ti.Value.Item1;
    var variableImpacts = ti.Value.Item2;

    var variables = variableImpacts.Keys.ToList();
    if (variables.Count == 0) continue;

    var targetNode = GetOrAddVariableNode(vn, nodes, target);

    if (variables.Count == 1) {
      // a single input is connected to its target directly, without a junction
      var v = variables[0];
      var impact = variableImpacts[v];
      var variableNode = GetOrAddVariableNode(vn, nodes, v);
      vn.AddArc(new Arc(variableNode, targetNode) {
        Weight = impact, Label = impact.ToString("N3", CultureInfo.CurrentCulture)
      });
      continue;
    }

    // junctions are stored in the same lookup, keyed by the concatenated input names followed by the target
    var junctionLabel = Concatenate(new List<string>(variables) { target });
    IVertex junctionNode;
    if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
      var solutionsEnsemble = CreateEnsembleSolution(targetRuns);
      junctionNode = new JunctionNetworkNode { Label = solutionsEnsemble.TrainingRSquared.ToString("N3", CultureInfo.CurrentCulture), Data = solutionsEnsemble };
      vn.AddVertex(junctionNode);
      nodes[junctionLabel] = junctionNode;
    }

    foreach (var v in variables) {
      var impact = variableImpacts[v];
      var variableNode = GetOrAddVariableNode(vn, nodes, v);
      vn.AddArc(new Arc(variableNode, junctionNode) { Weight = impact, Label = impact.ToString("N3", CultureInfo.CurrentCulture) });
    }

    // the junction -> target arc is added last because its weight sums the incoming arcs added above
    var trainingR2 = ((IRegressionSolution)((JunctionNetworkNode)junctionNode).Data).TrainingRSquared;
    vn.AddArc(new Arc(junctionNode, targetNode) {
      Weight = junctionNode.InArcs.Sum(x => x.Weight), Label = trainingR2.ToString("N3", CultureInfo.CurrentCulture)
    });
  }
  return vn;
}

// Returns the variable node registered under the given label, creating and adding it to the network first if necessary.
private static IVertex GetOrAddVariableNode(VariableInteractionNetwork network, Dictionary<string, IVertex> nodes, string label) {
  IVertex node;
  if (!nodes.TryGetValue(label, out node)) {
    node = new VariableNetworkNode { Label = label };
    network.AddVertex(node);
    nodes[label] = node;
  }
  return node;
}
276
/// <summary>
/// Creates a copy of the network that only contains the arcs whose weight reaches the threshold.
/// </summary>
/// <param name="originalNetwork">The complete network; it is not modified.</param>
/// <param name="threshold">The minimum weight an arc needs to be kept.</param>
/// <returns>
/// The filtered copy, or <paramref name="originalNetwork"/> itself when no arc reaches the threshold
/// or when nothing is left after removing junctions without inputs and unconnected nodes.
/// </returns>
private static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
  var keptArcs = originalNetwork.Arcs.Where(arc => arc.Weight >= threshold).ToList();
  if (keptArcs.Count == 0) return originalNetwork;

  var cloner = new Cloner();
  var result = new VariableInteractionNetwork();

  // Clone the end points of the kept arcs through one shared cloner. According to the original note
  // the vertex clones come without arcs; the kept arcs are cloned separately with the same cloner.
  var endPoints = keptArcs.SelectMany(arc => new[] { arc.Source, arc.Target }).Select(cloner.Clone).Distinct();
  result.AddVertices(endPoints);
  result.AddArcs(keptArcs.Select(arc => (IArc)arc.Clone(cloner)));

  // junctions that lost all their inputs are dropped first ...
  var junctionsWithoutInputs = result.Vertices.Where(v => v.InDegree == 0 && v is JunctionNetworkNode).ToList();
  result.RemoveVertices(junctionsWithoutInputs);
  // ... then every node that is no longer connected to anything
  var unconnected = result.Vertices.Where(v => v.Degree == 0).ToList();
  result.RemoveVertices(unconnected);

  if (result.Vertices.Any()) return result;
  return originalNetwork;
}
292
/// <summary>
/// Averages the "Best training solution quality" result over all runs.
/// </summary>
/// <param name="runs">Runs that must all have the same target variable and the same allowed input variables.</param>
/// <returns>The mean of the quality values.</returns>
/// <exception cref="ArgumentException">A run has a different target or different inputs than the first run.</exception>
private static double CalculateAverageQuality(RunCollection runs) {
  var reference = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
  var target = reference.TargetVariable;
  var inputs = reference.AllowedInputVariables;

  // every run is compared with the first one; the first mismatch aborts
  foreach (var run in runs) {
    var problemData = (IRegressionProblemData)run.Parameters["ProblemData"];
    if (target != problemData.TargetVariable || !inputs.SequenceEqual(problemData.AllowedInputVariables)) {
      throw new ArgumentException("All runs must have the same target and inputs.");
    }
  }

  return runs.Average(run => ((DoubleValue)run.Results["Best training solution quality"]).Value);
}
306
/// <summary>
/// Averages the variable impacts stored in the given result over all runs.
/// </summary>
/// <param name="runs">Runs that must all have the same target variable and the same allowed input variables.</param>
/// <param name="resultName">Name of the result holding the impacts matrix; impacts are read from column 0, one row per variable.</param>
/// <returns>The mean impact of each allowed input variable of the first run.</returns>
/// <exception cref="ArgumentException">A run has a different target or different inputs than the first run.</exception>
private static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
  var reference = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
  var target = reference.TargetVariable;
  var inputs = reference.AllowedInputVariables.ToList();

  // accumulator: every input starts with an impact of zero
  var impacts = inputs.ToDictionary(x => x, x => 0d);

  // check if all the runs have the same target and same inputs
  foreach (var run in runs) {
    var problemData = (IRegressionProblemData)run.Parameters["ProblemData"];
    if (target != problemData.TargetVariable || !inputs.SequenceEqual(problemData.AllowedInputVariables)) {
      throw new ArgumentException("All runs must have the same target and inputs.");
    }
  }

  // sum up the first column of every run's matrix, row by row
  foreach (var run in runs) {
    var impactsMatrix = (DoubleMatrix)run.Results[resultName];
    foreach (var entry in impactsMatrix.RowNames.Select((name, row) => new { Name = name, Row = row })) {
      impacts[entry.Name] += impactsMatrix[entry.Row, 0];
    }
  }

  var runCount = runs.Count;
  foreach (var input in inputs) {
    impacts[input] /= runCount;
  }

  return impacts;
}
337
/// <summary>
/// Joins the given strings in order, without any separator. Null elements contribute nothing.
/// </summary>
/// <param name="strings">The strings to join.</param>
/// <returns>The concatenation, or an empty string when <paramref name="strings"/> is empty.</returns>
/// <exception cref="ArgumentNullException"><paramref name="strings"/> is null.</exception>
private static string Concatenate(IEnumerable<string> strings) {
  // the framework call replaces the hand-written StringBuilder loop
  return string.Concat(strings);
}
345
/// <summary>
/// Registers the shapes the graph chart uses for the two node types: a white ellipse for
/// variable nodes and a smaller dark gray rectangle for junction nodes. Existing shapes are cleared first.
/// </summary>
private void ConfigureNodeShapes() {
  graphChart.ClearShapes();
  var labelFont = new Font(FontFamily.GenericSansSerif, 12);

  var variableOutline = new Ellipse(graphChart.Chart, new PointD(0, 0), new PointD(30, 30), Pens.Black, Brushes.White);
  var variableShape = new LabeledPrimitive(variableOutline, "", labelFont);
  graphChart.AddShape(typeof(VariableNetworkNode), variableShape);

  var junctionOutline = new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray);
  var junctionShape = new LabeledPrimitive(junctionOutline, "", labelFont);
  graphChart.AddShape(typeof(JunctionNetworkNode), junctionShape);
}
352
353    #region events
/// <summary>
/// Refills the result name combo boxes from the first run of the new content and rebuilds the network
/// when both an impacts result and a quality result are available.
/// </summary>
/// <remarks>
/// Impact candidates are matrix results whose row names are all double variables of the dataset;
/// quality candidates are all double value results. The combo boxes are simply emptied when there is no
/// content, when the collection is empty or when the first run carries no regression problem data.
/// </remarks>
protected override void OnContentChanged() {
  base.OnContentChanged();

  // The base class can set a null content and the view is offered for every run collection, so none
  // of the following can be taken for granted: content, a first run, regression problem data.
  var run = Content != null ? Content.FirstOrDefault() : null;
  IRegressionProblemData pd = null;
  if (run != null) {
    IItem problemDataItem;
    if (run.Parameters.TryGetValue("ProblemData", out problemDataItem))
      pd = problemDataItem as IRegressionProblemData;
  }
  if (pd == null) {
    impactResultNameComboBox.Items.Clear();
    qualityResultNameComboBox.Items.Clear();
    return;
  }

  var variables = new HashSet<string>(pd.Dataset.DoubleVariables);
  impactResultNameComboBox.Items.Clear();
  foreach (var result in run.Results.Where(x => x.Value is DoubleMatrix)) {
    var m = (DoubleMatrix)result.Value;
    // only matrices with one row per dataset variable can hold variable impacts
    if (m.RowNames.All(x => variables.Contains(x)))
      impactResultNameComboBox.Items.Add(result.Key);
  }
  qualityResultNameComboBox.Items.Clear();
  foreach (var result in run.Results.Where(x => x.Value is DoubleValue)) {
    qualityResultNameComboBox.Items.Add(result.Key);
  }

  // preselect the first candidate of each list
  if (impactResultNameComboBox.Items.Count > 0) {
    impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
  }
  if (qualityResultNameComboBox.Items.Count > 0) {
    qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
  }
  if (impactResultNameComboBox.Items.Count > 0 && qualityResultNameComboBox.Items.Count > 0)
    NetworkConfigurationChanged(this, EventArgs.Empty);
}
378
/// <summary>
/// Validating handler for text boxes that must contain a real number. When the text cannot be parsed
/// the validation is cancelled, the whole text is selected and an error icon is shown at the text box.
/// </summary>
private void TextBoxValidating(object sender, CancelEventArgs e) {
  var textBox = (TextBox)sender;
  double parsed;
  if (double.TryParse(textBox.Text, out parsed)) return;

  // keep the focus in the text box and make it easy to overwrite the invalid input
  e.Cancel = true;
  textBox.Select(0, textBox.Text.Length);

  // Set the ErrorProvider error with the text to display.
  errorProvider.SetError(textBox, "Could not parse the entered value. Please input a real number.");
  errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
  errorProvider.SetIconPadding(textBox, -20);
}
393
/// <summary>
/// Validated handler of an impact threshold text box: clears the error icon and shows the network
/// filtered by the entered threshold. A threshold of 0.2 is used when the text cannot be parsed.
/// </summary>
private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
  var tb = (TextBox)sender;
  errorProvider.SetError(tb, string.Empty);

  // nothing to filter as long as no network has been built
  if (variableInteractionNetwork == null) return;

  double impact;
  if (!double.TryParse(tb.Text, out impact)) {
    impact = 0.2;
  }
  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
}
404
/// <summary>
/// Validated handler of the layout text boxes: removes the error icon of the text box
/// and applies the layout configuration.
/// </summary>
private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
  var validatedBox = (TextBox)sender;
  this.errorProvider.SetError(validatedBox, string.Empty);
  this.LayoutConfigurationChanged(sender, e);
}
410
/// <summary>
/// Rebuilds the network from the impacts stored in the run results, using the result names,
/// aggregation mode and optimization direction currently selected, and shows it filtered by the
/// current threshold. Does nothing while there is no content or a result name is missing.
/// </summary>
private void NetworkConfigurationChanged(object sender, EventArgs e) {
  if (Content == null) return;

  var qualityResultName = qualityResultNameComboBox.Text;
  var impactsResultName = impactResultNameComboBox.Text;
  if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName))
    return;

  // index 0 (or no selection) means: use the best run per target instead of averaging
  var useBest = impactAggregationComboBox.SelectedIndex <= 0;
  var maximization = maximizationCheckBox.Checked;
  // Same formula as in impactThresholdTrackBar_ValueChanged, so that the applied threshold always
  // matches the value shown in the threshold label (the fixed divisor of 100 used before only
  // agreed with it for a track bar range of 0..100).
  var threshold = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;

  var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
  variableInteractionNetwork = CreateNetwork(impacts);
  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
}
424
/// <summary>
/// Applies the selected edge routing mode and the entered ideal edge length to the graph chart and
/// redraws it. Nothing happens when neither value changed or when the edge length is not a number.
/// </summary>
/// <exception cref="ArgumentException">The edge routing combo box has no valid selection.</exception>
private void LayoutConfigurationChanged(object sender, EventArgs e) {
  ConstrainedForceDirectedLayout.EdgeRouting routingMode;
  switch (edgeRoutingComboBox.SelectedIndex) {
    case 0:
      routingMode = ConstrainedForceDirectedLayout.EdgeRouting.None;
      break;
    case 1:
      routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Polyline;
      break;
    case 2:
      routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal;
      break;
    default:
      throw new ArgumentException("Invalid edge routing mode.");
  }

  // This handler can also run while the text box holds unparsable text; keep the current layout
  // then instead of failing with a FormatException. TextBoxValidating reports the invalid input.
  double idealEdgeLength;
  if (!double.TryParse(idealEdgeLengthTextBox.Text, out idealEdgeLength)) return;

  // avoid a costly redraw when nothing changed
  if (routingMode == graphChart.RoutingMode && idealEdgeLength.IsAlmost(graphChart.IdealEdgeLength)) return;

  graphChart.RoutingMode = routingMode;
  graphChart.PerformEdgeRouting = routingMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
  graphChart.IdealEdgeLength = idealEdgeLength;
  graphChart.Draw();
}
447
/// <summary>
/// Enables or disables all controls that configure the network and its layout.
/// </summary>
/// <param name="enabled">The new enabled state of the controls.</param>
private void ControlsEnable(bool enabled) {
  // one statement per control, in the order in which the former chained assignment set them
  maximizationCheckBox.Enabled = enabled;
  idealEdgeLengthTextBox.Enabled = enabled;
  edgeRoutingComboBox.Enabled = enabled;
  onlineImpactCalculationButton.Enabled = enabled;
  impactThresholdTrackBar.Enabled = enabled;
  impactAggregationComboBox.Enabled = enabled;
  impactResultNameComboBox.Enabled = enabled;
  qualityResultNameComboBox.Enabled = enabled;
}
458
/// <summary>
/// Calculates the variable impacts of the shown runs in the background and shows the resulting
/// network, filtered by the current threshold, once the calculation has finished.
/// </summary>
/// <remarks>
/// Only the calculation itself runs on the worker thread. All access to controls happens on the UI
/// thread: before the worker is started and in the completion handler.
/// </remarks>
private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
  var runs = Content;
  if (runs == null) return;

  // read UI state and disable the controls here, on the UI thread, not inside DoWork
  var threshold = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
  ControlsEnable(false);

  var worker = new BackgroundWorker();
  worker.DoWork += (o, e) => {
    // worker thread: pure calculation, the result is handed over through e.Result
    var impacts = CalculateVariableImpactsOnline(runs, false);
    e.Result = CreateNetwork(impacts);
  };
  worker.RunWorkerCompleted += (o, e) => {
    worker.Dispose();
    // the view may have been closed while the calculation was running
    if (IsDisposed) return;
    try {
      if (e.Error != null) {
        // an exception thrown in DoWork ends up here; report it instead of dropping it silently
        MessageBox.Show(this, e.Error.Message, "Variable impact calculation failed", MessageBoxButtons.OK, MessageBoxIcon.Error);
        return;
      }
      variableInteractionNetwork = (VariableInteractionNetwork)e.Result;
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
    } finally {
      ControlsEnable(true);
    }
  };
  worker.RunWorkerAsync();
}
471
/// <summary>
/// Click handler of the relayout button: asks the graph chart to draw itself again.
/// </summary>
private void relayoutGraphButton_Click(object sender, EventArgs e) {
  this.graphChart.Draw();
}
475    #endregion
476
/// <summary>
/// Exports the impacts of the currently shown graph as a square matrix and opens it in the main form.
/// </summary>
/// <remarks>
/// Rows and columns are the variable names in alphabetical order. The entry [target, input] holds the
/// weight of the arc that leads from the input to the target, either through a junction node or
/// directly. The diagonal is set to 1. Nothing happens when no graph is shown.
/// </remarks>
private void exportImpactsMatrixButton_Click(object sender, EventArgs e) {
  var graph = graphChart.Graph;
  if (graph == null) return;

  var variableNodes = graph.Vertices.OfType<VariableNetworkNode>().ToList();
  var labels = variableNodes.Select(x => x.Label).ToList();
  labels.Sort(); // sort variables alphabetically
  var matrix = new DoubleMatrix(labels.Count, labels.Count) { RowNames = labels, ColumnNames = labels };
  var indexes = labels.Select((x, i) => new { Label = x, Index = i }).ToDictionary(x => x.Label, x => x.Index);

  // inputs that reach their target through a junction node
  foreach (var jn in graph.Vertices.OfType<JunctionNetworkNode>()) {
    // a junction without an outgoing arc has no target to attribute its inputs to
    var outArc = jn.OutArcs.FirstOrDefault();
    if (outArc == null) continue;
    var targetIndex = indexes[outArc.Target.Label];
    foreach (var input in jn.InArcs) {
      matrix[targetIndex, indexes[input.Source.Label]] = input.Weight;
    }
  }

  // inputs that are connected to their target directly (targets with a single input have no
  // junction node); these impacts were missing from the exported matrix before
  foreach (var targetNode in variableNodes) {
    var targetIndex = indexes[targetNode.Label];
    foreach (var input in targetNode.InArcs.Where(a => a.Source is VariableNetworkNode)) {
      matrix[targetIndex, indexes[input.Source.Label]] = input.Weight;
    }
  }

  for (int i = 0; i < labels.Count; ++i) matrix[i, i] = 1;
  MainFormManager.MainForm.ShowContent(matrix);
}
496
/// <summary>
/// Shows the threshold selected with the track bar in the threshold label and, if a network has
/// already been built, shows that network filtered by the new threshold.
/// </summary>
private void impactThresholdTrackBar_ValueChanged(object sender, EventArgs e) {
  var impact = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
  impactThresholdLabel.Text = impact.ToString("N3", CultureInfo.CurrentCulture);

  // the track bar can change before any network exists; only the label is updated then
  if (variableInteractionNetwork == null) return;

  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
}
503
504
505  }
506}
Note: See TracBrowser for help on using the repository browser.