Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2288_HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks.Views/3.3/RunCollectionVariableInteractionNetworkView.cs @ 16497

Last change on this file since 16497 was 16497, checked in by jzenisek, 5 years ago

#2288:

  • added possibility to create simple networks (only one input var set per target var, i.e. without junction nodes)
  • fixed minor enumeration bug
  • enabled network view updates
File size: 22.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.ComponentModel;
25using System.Drawing;
26using System.Globalization;
27using System.Linq;
28using System.Text;
29using System.Windows.Forms;
30using HeuristicLab.Common;
31using HeuristicLab.Core;
32using HeuristicLab.Core.Views;
33using HeuristicLab.Data;
34using HeuristicLab.MainForm;
35using HeuristicLab.Optimization;
36using HeuristicLab.Problems.DataAnalysis;
37using HeuristicLab.Visualization;
38using Ellipse = HeuristicLab.Visualization.Ellipse;
39using Rectangle = HeuristicLab.Visualization.Rectangle;
40
namespace HeuristicLab.VariableInteractionNetworks.Views {
  /// <summary>
  /// Displays a <see cref="RunCollection"/> of regression runs as a variable interaction network:
  /// each target variable is connected to the input variables of its model(s), weighted by variable impacts.
  /// Models with several inputs are drawn with an intermediate junction node holding the ensemble of the runs.
  /// </summary>
  [View("Variable Interaction Network")]
  [Content(typeof(RunCollection), IsDefaultView = false)]
  public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
    // Name of the run parameter that stores the regression problem data.
    private const string ProblemDataParameterName = "ProblemData";
    // Separator for composite grouping keys; a control character so variable names practically cannot collide.
    private const string KeySeparator = "\u001F";

    public RunCollectionVariableInteractionNetworkView() {
      InitializeComponent();
      ConfigureNodeShapes();
    }

    /// <summary>
    /// The displayed run collection. Assigning <c>null</c> or the current collection is ignored.
    /// </summary>
    public new RunCollection Content {
      get { return (RunCollection)base.Content; }
      set {
        if (value != null && value != Content) {
          base.Content = value;
        }
      }
    }

    // Unfiltered network of the last calculation; the impact threshold is applied to it for display.
    private VariableInteractionNetwork variableInteractionNetwork;

    /// <summary>Returns the single regression solution contained in the results of <paramref name="run"/>.</summary>
    /// <exception cref="InvalidOperationException">The run does not contain exactly one regression solution.</exception>
    private static IRegressionSolution GetRegressionSolution(IRun run) {
      return (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
    }

    /// <summary>Returns the regression problem data stored as run parameter, or null if there is none.</summary>
    private static IRegressionProblemData GetProblemData(IRun run) {
      IItem item;
      return run.Parameters.TryGetValue(ProblemDataParameterName, out item) ? item as IRegressionProblemData : null;
    }

    /// <summary>
    /// Ensures that the regression solutions of all runs were created on the same data
    /// (same training partition, same double variables and identical values).
    /// The first run serves as reference for all others.
    /// </summary>
    /// <exception cref="InvalidOperationException">Some run uses different data than the first run.</exception>
    private static void AssertSameProblemData(RunCollection runs) {
      IRegressionProblemData referenceProblemData = null;
      IDataset referenceDataset = null;
      foreach (var run in runs) {
        var problemData = GetRegressionSolution(run).ProblemData;
        var dataset = problemData.Dataset;

        if (referenceProblemData == null) {
          // reference problem data and dataset must be captured together, otherwise the first runs are never compared
          referenceProblemData = problemData;
          referenceDataset = dataset;
          continue;
        }
        // identical objects trivially share the same data
        if (problemData == referenceProblemData) continue;
        if (dataset == referenceDataset) continue;

        if (problemData.TrainingPartition.Start != referenceProblemData.TrainingPartition.Start
            || problemData.TrainingPartition.End != referenceProblemData.TrainingPartition.End)
          throw new InvalidOperationException("The runs must share the same data.");

        if (!dataset.DoubleVariables.SequenceEqual(referenceDataset.DoubleVariables))
          throw new InvalidOperationException("The runs must share the same data.");

        foreach (var v in dataset.DoubleVariables) {
          var values = (IList<double>)dataset.GetReadOnlyDoubleValues(v);
          var referenceValues = (IList<double>)referenceDataset.GetReadOnlyDoubleValues(v);
          if (values.Count != referenceValues.Count || !values.SequenceEqual(referenceValues))
            throw new InvalidOperationException("The runs must share the same data.");
        }
      }
    }

    /// <summary>
    /// Builds an ensemble of the regression models of <paramref name="runs"/>, using the problem data of the first run.
    /// </summary>
    public static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
      // materialized once: the solutions are enumerated for the models and for the problem data
      var solutions = runs.Select(GetRegressionSolution).ToList();
      return new RegressionEnsembleSolution(new RegressionEnsembleModel(solutions.Select(x => x.Model)), solutions.First().ProblemData);
    }

    /// <summary>
    /// Calculates variable impacts directly from the regression solutions of the runs (shuffling, all rows).
    /// </summary>
    /// <param name="runs">Runs that each contain exactly one regression solution; all must share the same data.</param>
    /// <param name="useBest">
    /// True: per target variable only the run with the highest training R² is used.
    /// False: impacts are averaged over all runs with the same target and input variables.
    /// </param>
    /// <returns>Per target variable: the runs that were used and the impact of each input variable.</returns>
    /// <exception cref="InvalidOperationException">The runs do not share the same data.</exception>
    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
      AssertSameProblemData(runs);
      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();

      if (useBest) {
        // build network using only the best run for each target
        var targetGroups = runs.GroupBy(run => GetRegressionSolution(run).ProblemData.TargetVariable);
        foreach (var group in targetGroups) {
          var best = group.Select(run => Tuple.Create(run, GetRegressionSolution(run)))
                          .OrderBy(x => x.Item2.TrainingRSquared)
                          .Last();
          var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(best.Item2,
              RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.All,
              RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle)
            .ToDictionary(x => x.Item1, x => x.Item2);
          targetImpacts[group.Key] = Tuple.Create(new[] { best.Item1 }.AsEnumerable(), impacts);
        }
        return targetImpacts;
      }

      // average the impacts over all runs sharing the same target and input variables
      // NOTE(review): several input sets for the same target overwrite each other; the last group wins
      var modelGroups = runs.GroupBy(run => CreateModelSignature(GetRegressionSolution(run).ProblemData));
      foreach (var group in modelGroups) {
        var averageImpacts = new Dictionary<string, double>();
        foreach (var run in group) {
          var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(GetRegressionSolution(run),
              RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.All,
              RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle);
          foreach (var t in impacts) {
            double sum;
            averageImpacts.TryGetValue(t.Item1, out sum); // sum stays 0 for a new variable
            averageImpacts[t.Item1] = sum + t.Item2;
          }
        }

        var count = group.Count();
        foreach (var v in averageImpacts.Keys.ToList()) {
          averageImpacts[v] /= count;
        }

        var target = GetRegressionSolution(group.First()).ProblemData.TargetVariable;
        targetImpacts[target] = Tuple.Create(group.AsEnumerable(), averageImpacts);
      }
      return targetImpacts;
    }

    // Key identifying a model configuration: target variable followed by the ordered input variables.
    private static string CreateModelSignature(IRegressionProblemData problemData) {
      return problemData.TargetVariable + KeySeparator + string.Join(KeySeparator, problemData.AllowedInputVariables);
    }

    /// <summary>
    /// Reads precomputed impacts (result <paramref name="impactsResultName"/>, a matrix whose row names are the
    /// input variables and whose first column holds the impacts) from the runs, grouped by target variable.
    /// </summary>
    /// <param name="qualityResultName">DoubleValue result used to rank runs when <paramref name="useBestRunsPerTarget"/> is set.</param>
    /// <param name="maximization">Whether a higher quality value is better.</param>
    /// <param name="useBestRunsPerTarget">True: only the best run per target; false: impacts averaged over all runs of a target.</param>
    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs,
      string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {

      Func<IRun, double> getQuality = run => ((DoubleValue)run.Results[qualityResultName]).Value;
      var targetGroups = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters[ProblemDataParameterName]).TargetVariable).ToList();
      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();

      if (useBestRunsPerTarget) {
        foreach (var group in targetGroups) {
          var ordered = group.OrderBy(getQuality);
          var best = maximization ? ordered.Last() : ordered.First();
          var target = group.Key;
          var impacts = (DoubleMatrix)best.Results[impactsResultName];
          targetImpacts[target] = Tuple.Create((IEnumerable<IRun>)new[] { best }, impacts.RowNames.Select((x, i) => new { x, i }).ToDictionary(x => x.x, x => impacts[x.i, 0]));
        }
      } else {
        foreach (var target in targetGroups) {
          var averageImpacts = CalculateAverageImpacts(new RunCollection(target), impactsResultName);
          targetImpacts[target.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(target, averageImpacts);
        }
      }
      return targetImpacts;
    }

    /// <summary>
    /// Creates the interaction network from per-target impacts. A target with a single input is connected directly;
    /// a target with several inputs is connected through a junction node that stores the ensemble of the target's
    /// runs (labelled with its training R²). Targets without impacts are skipped.
    /// </summary>
    public static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
      // variable nodes are shared: the target of one model can be an input of another
      var variableNodes = new Dictionary<string, IVertex>();
      var network = new VariableInteractionNetwork();
      foreach (var ti in targetImpacts) {
        var target = ti.Key;
        var targetRuns = ti.Value.Item1;
        var variableImpacts = ti.Value.Item2;
        if (variableImpacts.Count == 0) continue;

        var targetNode = GetOrAddVariableNode(network, variableNodes, target);

        if (variableImpacts.Count > 1) {
          // each target occurs once, so its junction is always new (never shared with another target)
          var ensemble = CreateEnsembleSolution(targetRuns);
          var trainingR2 = ensemble.TrainingRSquared;
          var junctionNode = new JunctionNetworkNode { Label = trainingR2.ToString("N3", CultureInfo.CurrentCulture), Data = ensemble };
          network.AddVertex(junctionNode);

          foreach (var pair in variableImpacts) {
            var variableNode = GetOrAddVariableNode(network, variableNodes, pair.Key);
            var inArc = new Arc(variableNode, junctionNode) { Weight = pair.Value, Label = pair.Value.ToString("N3", CultureInfo.CurrentCulture) };
            network.AddArc(inArc);
          }
          // the junction->target arc carries the summed input impacts and is labelled with the ensemble quality
          var outArc = new Arc(junctionNode, targetNode) { Weight = junctionNode.InArcs.Sum(x => x.Weight), Label = trainingR2.ToString("N3", CultureInfo.CurrentCulture) };
          network.AddArc(outArc);
        } else {
          foreach (var pair in variableImpacts) {
            var variableNode = GetOrAddVariableNode(network, variableNodes, pair.Key);
            var arc = new Arc(variableNode, targetNode) {
              Weight = pair.Value,
              Label = pair.Value.ToString("N3", CultureInfo.CurrentCulture)
            };
            network.AddArc(arc);
          }
        }
      }
      return network;
    }

    // Returns the node of the variable with the given label, creating and adding it to the network if necessary.
    private static IVertex GetOrAddVariableNode(VariableInteractionNetwork network, Dictionary<string, IVertex> nodes, string label) {
      IVertex node;
      if (!nodes.TryGetValue(label, out node)) {
        node = new VariableNetworkNode { Label = label };
        network.AddVertex(node);
        nodes[label] = node;
      }
      return node;
    }

    /// <summary>
    /// Returns a copy of <paramref name="originalNetwork"/> that only contains arcs with weight &gt;= <paramref name="threshold"/>
    /// (junctions without inputs and orphaned nodes removed). Returns the original network if nothing would remain.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="originalNetwork"/> is null.</exception>
    public static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
      if (originalNetwork == null) throw new ArgumentNullException("originalNetwork");
      var arcs = originalNetwork.Arcs.Where(x => x.Weight >= threshold).ToList();
      if (!arcs.Any()) return originalNetwork;
      var filteredNetwork = new VariableInteractionNetwork();
      var cloner = new Cloner();
      var vertices = arcs.SelectMany(x => new[] { x.Source, x.Target }).Select(cloner.Clone).Distinct(); // arcs are not cloned
      filteredNetwork.AddVertices(vertices);
      filteredNetwork.AddArcs(arcs.Select(x => (IArc)x.Clone(cloner)));

      var unusedJunctions = filteredNetwork.Vertices.Where(x => x.InDegree == 0 && x is JunctionNetworkNode).ToList();
      filteredNetwork.RemoveVertices(unusedJunctions);
      var orphanedNodes = filteredNetwork.Vertices.Where(x => x.Degree == 0).ToList();
      filteredNetwork.RemoveVertices(orphanedNodes);
      return filteredNetwork.Vertices.Any() ? filteredNetwork : originalNetwork;
    }

    // Average of the "Best training solution quality" result; all runs must share target and inputs.
    private static double CalculateAverageQuality(RunCollection runs) {
      var pd = (IRegressionProblemData)runs.First().Parameters[ProblemDataParameterName];
      var target = pd.TargetVariable;
      var inputs = pd.AllowedInputVariables;

      if (!runs.All(x => {
        var problemData = (IRegressionProblemData)x.Parameters[ProblemDataParameterName];
        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
      })) {
        throw new ArgumentException("All runs must have the same target and inputs.");
      }
      return runs.Average(x => ((DoubleValue)x.Results["Best training solution quality"]).Value);
    }

    /// <summary>
    /// Averages the impacts stored in result <paramref name="resultName"/> (row names = variables, column 0 = impact)
    /// over all runs.
    /// </summary>
    /// <exception cref="ArgumentException">The runs do not all have the same target and (ordered) inputs.</exception>
    public static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
      var pd = (IRegressionProblemData)runs.First().Parameters[ProblemDataParameterName];
      var target = pd.TargetVariable;
      var inputs = pd.AllowedInputVariables.ToList();

      var impacts = inputs.ToDictionary(x => x, x => 0d);

      // check if all the runs have the same target and same inputs
      if (!runs.All(x => {
        var problemData = (IRegressionProblemData)x.Parameters[ProblemDataParameterName];
        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
      })) {
        throw new ArgumentException("All runs must have the same target and inputs.");
      }

      foreach (var run in runs) {
        var impactsMatrix = (DoubleMatrix)run.Results[resultName];
        int i = 0;
        foreach (var v in impactsMatrix.RowNames) {
          impacts[v] += impactsMatrix[i, 0];
          ++i;
        }
      }

      foreach (var v in inputs) {
        impacts[v] /= runs.Count;
      }

      return impacts;
    }

    private void ConfigureNodeShapes() {
      graphChart.ClearShapes();
      var font = new Font(FontFamily.GenericSansSerif, 12);
      graphChart.AddShape(typeof(VariableNetworkNode), new LabeledPrimitive(new Ellipse(graphChart.Chart, new PointD(0, 0), new PointD(30, 30), Pens.Black, Brushes.White), "", font));
      graphChart.AddShape(typeof(JunctionNetworkNode), new LabeledPrimitive(new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray), "", font));
    }

    /// <summary>Displays <paramref name="network"/>; marshals to the UI thread when called from another thread.</summary>
    public void UpdateNetwork(VariableInteractionNetwork network) {
      if (InvokeRequired) {
        Invoke((Action<VariableInteractionNetwork>)UpdateNetwork, network);
      } else {
        graphChart.Graph = network;
      }
    }

    // Impact threshold selected with the track bar (the same value is shown in impactThresholdLabel).
    private double GetImpactThreshold() {
      return impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
    }

    // The first aggregation entry (or no selection) means "use the best run per target".
    private bool UseBestRunPerTarget() {
      return impactAggregationComboBox.SelectedIndex <= 0;
    }

    #region events
    protected override void OnContentChanged() {
      base.OnContentChanged();
      impactResultNameComboBox.Items.Clear();
      qualityResultNameComboBox.Items.Clear();

      // nothing to offer for empty content or runs without regression problem data
      var run = Content != null ? Content.FirstOrDefault() : null;
      if (run == null) return;
      var pd = GetProblemData(run);
      if (pd == null) return;

      var variables = new HashSet<string>(pd.Dataset.DoubleVariables);
      // impact results: matrices whose row names are all dataset variables
      foreach (var result in run.Results.Where(x => x.Value is DoubleMatrix)) {
        var m = (DoubleMatrix)result.Value;
        if (m.RowNames.All(x => variables.Contains(x)))
          impactResultNameComboBox.Items.Add(result.Key);
      }
      foreach (var result in run.Results.Where(x => x.Value is DoubleValue)) {
        qualityResultNameComboBox.Items.Add(result.Key);
      }
      if (impactResultNameComboBox.Items.Count > 0) {
        impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
      }
      if (qualityResultNameComboBox.Items.Count > 0) {
        qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
      }
      if (impactResultNameComboBox.Items.Count > 0 && qualityResultNameComboBox.Items.Count > 0)
        NetworkConfigurationChanged(this, EventArgs.Empty);
    }

    private void TextBoxValidating(object sender, CancelEventArgs e) {
      double v;
      string errorMsg = "Could not parse the entered value. Please input a real number.";
      var tb = (TextBox)sender;
      if (!double.TryParse(tb.Text, out v)) {
        e.Cancel = true;
        tb.Select(0, tb.Text.Length);

        // Set the ErrorProvider error with the text to display.
        errorProvider.SetError(tb, errorMsg);
        errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
        errorProvider.SetIconPadding(tb, -20);
      }
    }

    private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
      var tb = (TextBox)sender;
      errorProvider.SetError(tb, string.Empty);
      if (variableInteractionNetwork == null) return; // nothing calculated yet
      double impact;
      if (!double.TryParse(tb.Text, out impact)) {
        impact = 0.2;
      }
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
    }

    private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
      var tb = (TextBox)sender;
      errorProvider.SetError(tb, string.Empty);
      LayoutConfigurationChanged(sender, e);
    }

    private void NetworkConfigurationChanged(object sender, EventArgs e) {
      // may be raised by the configuration controls before any content is shown
      if (Content == null || Content.Count == 0) return;
      var qualityResultName = qualityResultNameComboBox.Text;
      var impactsResultName = impactResultNameComboBox.Text;
      if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName))
        return;
      var useBest = UseBestRunPerTarget();
      var maximization = maximizationCheckBox.Checked;
      var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
      variableInteractionNetwork = CreateNetwork(impacts);
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, GetImpactThreshold());
    }

    private void LayoutConfigurationChanged(object sender, EventArgs e) {
      ConstrainedForceDirectedLayout.EdgeRouting routingMode;
      switch (edgeRoutingComboBox.SelectedIndex) {
        case 0:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.None;
          break;
        case 1:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Polyline;
          break;
        case 2:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal;
          break;
        default:
          throw new ArgumentException("Invalid edge routing mode.");
      }
      double idealEdgeLength;
      // the text box can hold an invalid value while another layout control changes; its validation reports that
      if (!double.TryParse(idealEdgeLengthTextBox.Text, out idealEdgeLength)) return;
      if (routingMode == graphChart.RoutingMode && idealEdgeLength.IsAlmost(graphChart.DefaultEdgeLength)) return;
      graphChart.RoutingMode = routingMode;
      graphChart.PerformEdgeRouting = routingMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
      graphChart.DefaultEdgeLength = idealEdgeLength;
      graphChart.Draw();
    }

    private void ControlsEnable(bool enabled) {
      qualityResultNameComboBox.Enabled
        = impactResultNameComboBox.Enabled
        = impactAggregationComboBox.Enabled
        = impactThresholdTrackBar.Enabled
        = onlineImpactCalculationButton.Enabled
        = edgeRoutingComboBox.Enabled
        = idealEdgeLengthTextBox.Enabled
        = maximizationCheckBox.Enabled = enabled;
    }

    private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
      var runs = Content;
      if (runs == null || runs.Count == 0) return;
      // read all UI state here on the UI thread; the worker itself must not touch any control
      var useBest = UseBestRunPerTarget();
      ControlsEnable(false);

      var worker = new BackgroundWorker();
      worker.DoWork += (o, e) => {
        var impacts = CalculateVariableImpactsOnline(runs, useBest);
        e.Result = CreateNetwork(impacts);
      };
      // raised on the UI thread because the worker was started from it
      worker.RunWorkerCompleted += (o, e) => {
        try {
          if (e.Error != null) {
            MessageBox.Show(this, e.Error.Message, "Online impact calculation failed", MessageBoxButtons.OK, MessageBoxIcon.Error);
            return;
          }
          variableInteractionNetwork = (VariableInteractionNetwork)e.Result;
          graphChart.Graph = ApplyThreshold(variableInteractionNetwork, GetImpactThreshold());
        } finally {
          ControlsEnable(true);
          worker.Dispose();
        }
      };
      worker.RunWorkerAsync();
    }

    private void relayoutGraphButton_Click(object sender, EventArgs e) {
      graphChart.Draw();
    }
    #endregion

    // Shows a target x input matrix of the impacts in the displayed network (diagonal set to 1).
    private void exportImpactsMatrixButton_Click(object sender, EventArgs e) {
      var graph = graphChart.Graph;
      if (graph == null) return;
      var labels = graph.Vertices.OfType<VariableNetworkNode>().Select(x => x.Label).ToList();
      labels.Sort(); // sort variables alphabetically
      var matrix = new DoubleMatrix(labels.Count, labels.Count) { RowNames = labels, ColumnNames = labels };
      var indexes = labels.Select((x, i) => new { Label = x, Index = i }).ToDictionary(x => x.Label, x => x.Index);

      // inputs connected through a junction node
      foreach (var jn in graph.Vertices.OfType<JunctionNetworkNode>()) {
        var outArc = jn.OutArcs.FirstOrDefault();
        if (outArc == null) continue; // junction whose output arc was filtered away
        var targetIndex = indexes[outArc.Target.Label];
        foreach (var input in jn.InArcs) {
          matrix[targetIndex, indexes[input.Source.Label]] = input.Weight;
        }
      }
      // inputs connected directly to their target (single-input models)
      foreach (var node in graph.Vertices.OfType<VariableNetworkNode>()) {
        foreach (var input in node.InArcs.Where(a => a.Source is VariableNetworkNode)) {
          matrix[indexes[node.Label], indexes[input.Source.Label]] = input.Weight;
        }
      }
      for (int i = 0; i < labels.Count; ++i) matrix[i, i] = 1;
      MainFormManager.MainForm.ShowContent(matrix);
    }

    private void impactThresholdTrackBar_ValueChanged(object sender, EventArgs e) {
      var impact = GetImpactThreshold();
      impactThresholdLabel.Text = impact.ToString("N3", CultureInfo.CurrentCulture);
      if (variableInteractionNetwork == null) return; // nothing calculated yet
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.