source: branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks.Views/3.3/RunCollectionVariableInteractionNetworkView.cs @ 13806

Last change on this file since 13806 was 13806, checked in by bburlacu, 5 years ago

#2288: Remove TinySet.cs in favor of a more general method for generating k-combinations. Improve target variation experiment generation. Refactored code and avoided some corner case exceptions.

File size: 20.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.ComponentModel;
25using System.Drawing;
26using System.Linq;
27using System.Text;
28using System.Windows.Forms;
29using HeuristicLab.Common;
30using HeuristicLab.Core;
31using HeuristicLab.Core.Views;
32using HeuristicLab.Data;
33using HeuristicLab.MainForm;
34using HeuristicLab.Optimization;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Visualization;
37using Ellipse = HeuristicLab.Visualization.Ellipse;
38using Rectangle = HeuristicLab.Visualization.Rectangle;
39
40namespace HeuristicLab.VariableInteractionNetworks.Views {
41  [View("Variable Interaction Network")]
42  [Content(typeof(RunCollection), IsDefaultView = false)]
43
44  public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
45    public RunCollectionVariableInteractionNetworkView() {
46      InitializeComponent();
47      ConfigureNodeShapes();
48    }
49
50    public new RunCollection Content {
51      get { return (RunCollection)base.Content; }
52      set {
53        if (value != null && value != Content) {
54          base.Content = value;
55        }
56      }
57    }
58
59    private VariableInteractionNetwork variableInteractionNetwork;
60
61    private static void AssertSameProblemData(RunCollection runs) {
62      IDataset dataset = null;
63      IRegressionProblemData problemData = null;
64      foreach (var run in runs) {
65        var solution = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
66        var ds = solution.ProblemData.Dataset;
67
68        if (solution.ProblemData == problemData) continue;
69        if (ds == dataset) continue;
70        if (problemData == null) {
71          problemData = solution.ProblemData;
72          continue;
73        }
74        if (dataset == null) {
75          dataset = ds;
76          continue;
77        }
78
79        if (problemData.TrainingPartition.Start != solution.ProblemData.TrainingPartition.Start || problemData.TrainingPartition.End != solution.ProblemData.TrainingPartition.End)
80          throw new InvalidOperationException("The runs must share the same data.");
81
82        if (!ds.DoubleVariables.SequenceEqual(dataset.DoubleVariables))
83          throw new InvalidOperationException("The runs must share the same data.");
84
85        foreach (var v in ds.DoubleVariables) {
86          var values1 = (IList<double>)ds.GetReadOnlyDoubleValues(v);
87          var values2 = (IList<double>)dataset.GetReadOnlyDoubleValues(v);
88
89          if (values1.Count != values2.Count)
90            throw new InvalidOperationException("The runs must share the same data.");
91
92          if (!values1.SequenceEqual(values2))
93            throw new InvalidOperationException("The runs must share the same data.");
94        }
95      }
96    }
97
98    private static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
99      var solutions = runs.Select(x => x.Results.Values.Single(v => v is IRegressionSolution)).Cast<IRegressionSolution>();
100      return new RegressionEnsembleSolution(new RegressionEnsembleModel(solutions.Select(x => x.Model)), solutions.First().ProblemData);
101    }
102
103    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
104      AssertSameProblemData(runs);
105      var solution = (IRegressionSolution)runs.First().Results.Values.Single(x => x is IRegressionSolution);
106      var dataset = (Dataset)solution.ProblemData.Dataset;
107      var originalValues = dataset.DoubleVariables.ToDictionary(x => x, x => dataset.GetReadOnlyDoubleValues(x).ToList());
108      var md = dataset.ToModifiable();
109      var medians = new Dictionary<string, List<double>>();
110      foreach (var v in dataset.DoubleVariables) {
111        var median = dataset.GetDoubleValues(v, solution.ProblemData.TrainingIndices).Median();
112        medians[v] = Enumerable.Repeat(median, originalValues[v].Count).ToList();
113      }
114
115      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
116
117      if (useBest) {
118        // build network using only the best run for each target
119      } else {
120        var groups = runs.GroupBy(run => {
121          var sol = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
122          return Concatenate(sol.ProblemData.AllowedInputVariables) + sol.ProblemData.TargetVariable;
123        });
124
125        foreach (var group in groups) {
126          // calculate average impacts
127          var averageImpacts = new Dictionary<string, double>();
128          solution = (IRegressionSolution)group.First().Results.Values.Single(x => x is IRegressionSolution);
129          foreach (var run in group) {
130            var sol = (IRegressionSolution)run.Results.Values.Single(v => v is IRegressionSolution);
131
132            DoubleLimit estimationLimits = null;
133            if (run.Parameters.ContainsKey("EstimationLimits")) {
134              estimationLimits = (DoubleLimit)run.Parameters["EstimationLimits"];
135            }
136            var impacts = CalculateImpacts(sol, md, originalValues, medians, estimationLimits);
137            //            var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(sol).ToDictionary(x => x.Item1, x => x.Item2);
138            foreach (var pair in impacts) {
139              if (averageImpacts.ContainsKey(pair.Key))
140                averageImpacts[pair.Key] += pair.Value;
141              else {
142                averageImpacts[pair.Key] = pair.Value;
143              }
144            }
145          }
146          var count = group.Count();
147          var keys = averageImpacts.Keys.ToList();
148          foreach (var v in keys) {
149            averageImpacts[v] /= count;
150          }
151
152          targetImpacts[solution.ProblemData.TargetVariable] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(group, averageImpacts);
153        }
154      }
155      return targetImpacts;
156    }
157
158    private static Dictionary<string, double> CalculateImpacts(IRegressionSolution solution, ModifiableDataset dataset,
159      Dictionary<string, List<double>> originalValues, Dictionary<string, List<double>> medianValues, DoubleLimit estimationLimits = null) {
160      var impacts = new Dictionary<string, double>();
161
162      var model = solution.Model;
163      var pd = solution.ProblemData;
164
165      var rows = pd.TrainingIndices.ToList();
166      var targetValues = pd.Dataset.GetDoubleValues(pd.TargetVariable, rows).ToList();
167
168
169      foreach (var v in pd.AllowedInputVariables) {
170        dataset.ReplaceVariable(v, medianValues[v]);
171
172        var estimatedValues = model.GetEstimatedValues(dataset, rows);
173        if (estimationLimits != null)
174          estimatedValues = estimatedValues.LimitToRange(estimationLimits.Lower, estimationLimits.Upper);
175
176        OnlineCalculatorError error;
177        var r = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out error);
178        var newQuality = error == OnlineCalculatorError.None ? r * r : double.NaN;
179        var originalQuality = solution.TrainingRSquared;
180        impacts[v] = originalQuality - newQuality;
181
182        dataset.ReplaceVariable(v, originalValues[v]);
183      }
184      return impacts;
185    }
186
187    private static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs,
188      string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
189      var targets = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
190      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
191      if (useBestRunsPerTarget) {
192        var bestRunsPerTarget = maximization
193          ? targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).Last())
194          : targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).First());
195
196        foreach (var run in bestRunsPerTarget) {
197          var pd = (IRegressionProblemData)run.Parameters["ProblemData"];
198          var target = pd.TargetVariable;
199          var impacts = (DoubleMatrix)run.Results[impactsResultName];
200          targetImpacts[target] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(new[] { run }, impacts.RowNames.Select((x, i) => new { Name = x, Index = i }).ToDictionary(x => x.Name, x => impacts[x.Index, 0]));
201        }
202      } else {
203        foreach (var target in targets) {
204          var averageImpacts = CalculateAverageImpacts(new RunCollection(target), impactsResultName);
205          targetImpacts[target.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(target, averageImpacts);
206        }
207      }
208      return targetImpacts;
209    }
210
211    private static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
212      var nodes = new Dictionary<string, IVertex>();
213      var vn = new VariableInteractionNetwork();
214      foreach (var ti in targetImpacts) {
215        var target = ti.Key;
216        var variableImpacts = ti.Value.Item2;
217        var targetRuns = ti.Value.Item1;
218        IVertex targetNode;
219
220        var variables = variableImpacts.Keys.ToList();
221        if (variables.Count == 0) continue;
222
223        if (!nodes.TryGetValue(target, out targetNode)) {
224          targetNode = new VariableNetworkNode { Label = target };
225          vn.AddVertex(targetNode);
226          nodes[target] = targetNode;
227        }
228
229        IVertex variableNode;
230        if (variables.Count > 1) {
231          var variableList = new List<string>(variables) { target };
232          var junctionLabel = Concatenate(variableList);
233          IVertex junctionNode;
234          if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
235            var solutionsEnsemble = CreateEnsembleSolution(targetRuns);
236            junctionNode = new JunctionNetworkNode { Label = string.Empty, Data = solutionsEnsemble };
237            vn.AddVertex(junctionNode);
238            nodes[junctionLabel] = junctionNode;
239            junctionNode.Label = string.Format("Target quality: {0:0.000}", solutionsEnsemble.TrainingRSquared);
240          }
241          IArc arc;
242          foreach (var v in variables) {
243            var impact = variableImpacts[v];
244            if (!nodes.TryGetValue(v, out variableNode)) {
245              variableNode = new VariableNetworkNode { Label = v };
246              vn.AddVertex(variableNode);
247              nodes[v] = variableNode;
248            }
249            arc = new Arc(variableNode, junctionNode) { Weight = impact, Label = string.Format("Impact: {0:0.000}", impact) };
250            vn.AddArc(arc);
251          }
252          var trainingR2 = ((IRegressionSolution)((JunctionNetworkNode)junctionNode).Data).TrainingRSquared;
253          arc = new Arc(junctionNode, targetNode) { Weight = junctionNode.InArcs.Sum(x => x.Weight), Label = string.Format("Quality: {0:0.000}", trainingR2) };
254          vn.AddArc(arc);
255        } else {
256          foreach (var v in variables) {
257            var impact = variableImpacts[v];
258            if (!nodes.TryGetValue(v, out variableNode)) {
259              variableNode = new VariableNetworkNode { Label = v };
260              vn.AddVertex(variableNode);
261              nodes[v] = variableNode;
262            }
263            var arc = new Arc(variableNode, targetNode) { Weight = impact, Label = string.Format("Impact: {0:0.000}", impact) };
264            vn.AddArc(arc);
265          }
266        }
267      }
268      return vn;
269    }
270
271    private static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
272      var arcs = originalNetwork.Arcs.Where(x => x.Weight >= threshold).ToList();
273      if (!arcs.Any()) return originalNetwork;
274      var filteredNetwork = new VariableInteractionNetwork();
275      var cloner = new Cloner();
276      var vertices = arcs.SelectMany(x => new[] { x.Source, x.Target }).Select(cloner.Clone).Distinct(); // arcs are not cloned
277      filteredNetwork.AddVertices(vertices);
278      filteredNetwork.AddArcs(arcs.Select(x => (IArc)x.Clone(cloner)));
279
280      var unusedJunctions = filteredNetwork.Vertices.Where(x => x.InDegree == 0 && x is JunctionNetworkNode).ToList();
281      filteredNetwork.RemoveVertices(unusedJunctions);
282      var orphanedNodes = filteredNetwork.Vertices.Where(x => x.Degree == 0).ToList();
283      filteredNetwork.RemoveVertices(orphanedNodes);
284      return filteredNetwork.Vertices.Any() ? filteredNetwork : originalNetwork;
285    }
286
287    private static double CalculateAverageQuality(RunCollection runs) {
288      var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
289      var target = pd.TargetVariable;
290      var inputs = pd.AllowedInputVariables;
291
292      if (!runs.All(x => {
293        var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
294        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
295      })) {
296        throw new ArgumentException("All runs must have the same target and inputs.");
297      }
298      return runs.Average(x => ((DoubleValue)x.Results["Best training solution quality"]).Value);
299    }
300
301    private static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
302      var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
303      var target = pd.TargetVariable;
304      var inputs = pd.AllowedInputVariables.ToList();
305
306      var impacts = inputs.ToDictionary(x => x, x => 0d);
307
308      // check if all the runs have the same target and same inputs
309      if (!runs.All(x => {
310        var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
311        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
312      })) {
313        throw new ArgumentException("All runs must have the same target and inputs.");
314      }
315
316      foreach (var run in runs) {
317        var impactsMatrix = (DoubleMatrix)run.Results[resultName];
318
319        int i = 0;
320        foreach (var v in impactsMatrix.RowNames) {
321          impacts[v] += impactsMatrix[i, 0];
322          ++i;
323        }
324      }
325
326      foreach (var v in inputs) {
327        impacts[v] /= runs.Count;
328      }
329
330      return impacts;
331    }
332
333    private static string Concatenate(IEnumerable<string> strings) {
334      var sb = new StringBuilder();
335      foreach (var s in strings) {
336        sb.Append(s);
337      }
338      return sb.ToString();
339    }
340
341    private void ConfigureNodeShapes() {
342      graphChart.ClearShapes();
343      var font = new Font(FontFamily.GenericSansSerif, 12);
344      graphChart.AddShape(typeof(VariableNetworkNode), new LabeledPrimitive(new Ellipse(graphChart.Chart, new PointD(0, 0), new PointD(30, 30), Pens.Black, Brushes.White), "", font));
345      graphChart.AddShape(typeof(JunctionNetworkNode), new LabeledPrimitive(new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray), "", font));
346    }
347
348    #region events
349    protected override void OnContentChanged() {
350      base.OnContentChanged();
351      var run = Content.First();
352      var pd = (IRegressionProblemData)run.Parameters["ProblemData"];
353      var variables = new HashSet<string>(new List<string>(pd.Dataset.DoubleVariables));
354      impactResultNameComboBox.Items.Clear();
355      foreach (var result in run.Results.Where(x => x.Value is DoubleMatrix)) {
356        var m = (DoubleMatrix)result.Value;
357        if (m.RowNames.All(x => variables.Contains(x)))
358          impactResultNameComboBox.Items.Add(result.Key);
359      }
360      qualityResultNameComboBox.Items.Clear();
361      foreach (var result in run.Results.Where(x => x.Value is DoubleValue)) {
362        qualityResultNameComboBox.Items.Add(result.Key);
363      }
364      if (impactResultNameComboBox.Items.Count > 0) {
365        impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
366      }
367      if (qualityResultNameComboBox.Items.Count > 0) {
368        qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
369      }
370      if (impactResultNameComboBox.Items.Count > 0 && qualityResultNameComboBox.Items.Count > 0)
371        NetworkConfigurationChanged(this, EventArgs.Empty);
372    }
373
374    private void TextBoxValidating(object sender, CancelEventArgs e) {
375      double v;
376      string errorMsg = "Could not parse the entered value. Please input a real number.";
377      var tb = (TextBox)sender;
378      if (!double.TryParse(tb.Text, out v)) {
379        e.Cancel = true;
380        tb.Select(0, tb.Text.Length);
381
382        // Set the ErrorProvider error with the text to display.
383        this.errorProvider.SetError(tb, errorMsg);
384        errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
385        errorProvider.SetIconPadding(tb, -20);
386      }
387    }
388
389    private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
390      var tb = (TextBox)sender;
391      errorProvider.SetError(tb, string.Empty);
392      double impact;
393      if (!double.TryParse(tb.Text, out impact))
394        impact = 0.1;
395      var network = ApplyThreshold(variableInteractionNetwork, impact);
396      graphChart.Graph = network;
397    }
398
399    private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
400      var tb = (TextBox)sender;
401      errorProvider.SetError(tb, string.Empty);
402      LayoutConfigurationChanged(sender, e);
403    }
404
405    private void NetworkConfigurationChanged(object sender, EventArgs e) {
406      var useBest = impactAggregationComboBox.SelectedIndex <= 0;
407      var threshold = double.Parse(impactThresholdTextBox.Text);
408      var qualityResultName = qualityResultNameComboBox.Text;
409      var impactsResultName = impactResultNameComboBox.Text;
410      if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName))
411        return;
412      var maximization = maximizationCheckBox.Checked;
413      var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
414      variableInteractionNetwork = CreateNetwork(impacts);
415      var network = ApplyThreshold(variableInteractionNetwork, threshold);
416      graphChart.Graph = network;
417    }
418
419    private void LayoutConfigurationChanged(object sender, EventArgs e) {
420      ConstrainedForceDirectedLayout.EdgeRouting routingMode;
421      switch (edgeRoutingComboBox.SelectedIndex) {
422        case 0:
423          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.None;
424          break;
425        case 1:
426          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Polyline;
427          break;
428        case 2:
429          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal;
430          break;
431        default:
432          throw new ArgumentException("Invalid edge routing mode.");
433      }
434      var idealEdgeLength = double.Parse(idealEdgeLengthTextBox.Text);
435      if (routingMode == graphChart.RoutingMode && idealEdgeLength.IsAlmost(graphChart.IdealEdgeLength)) return;
436      graphChart.RoutingMode = routingMode;
437      graphChart.PerformEdgeRouting = routingMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
438      graphChart.IdealEdgeLength = idealEdgeLength;
439      graphChart.Draw();
440    }
441
442    private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
443      var button = (Button)sender;
444      var worker = new BackgroundWorker();
445      worker.DoWork += (o, e) => {
446        button.Enabled = false;
447        var impacts = CalculateVariableImpactsOnline(Content, false);
448        variableInteractionNetwork = CreateNetwork(impacts);
449        var threshold = double.Parse(impactThresholdTextBox.Text);
450        graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
451      };
452      worker.RunWorkerCompleted += (o, e) => button.Enabled = true;
453      worker.RunWorkerAsync();
454    }
455    #endregion
456  }
457}
Note: See TracBrowser for help on using the repository browser.