#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Drawing;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Windows.Forms;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Core.Views;
using HeuristicLab.Data;
using HeuristicLab.MainForm;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Visualization;
using Ellipse = HeuristicLab.Visualization.Ellipse;
using Rectangle = HeuristicLab.Visualization.Rectangle;

namespace HeuristicLab.VariableInteractionNetworks.Views {
  /// <summary>
  /// Displays a variable interaction network built from the regression runs of a <see cref="RunCollection"/>.
  /// Variables become <see cref="VariableNetworkNode"/>s; targets with more than one input variable are
  /// reached through a <see cref="JunctionNetworkNode"/>. Arc weights are variable impacts.
  /// </summary>
  [View("Variable Interaction Network")]
  [Content(typeof(RunCollection), IsDefaultView = false)]
  public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
    public RunCollectionVariableInteractionNetworkView() {
      InitializeComponent();
      ConfigureNodeShapes();
    }

    public new RunCollection Content {
      get { return (RunCollection)base.Content; }
      set {
        // null and re-assigning the current collection are ignored by this typed setter
        if (value != null && value != Content) {
          base.Content = value;
        }
      }
    }

    // the unfiltered network; impact thresholds are always applied to this instance
    private VariableInteractionNetwork variableInteractionNetwork;

    /// <summary>Returns the single <see cref="IRegressionSolution"/> stored in the results of <paramref name="run"/>.</summary>
    /// <exception cref="InvalidOperationException">The run does not contain exactly one regression solution.</exception>
    private static IRegressionSolution GetRegressionSolution(IRun run) {
      return (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
    }

    /// <summary>
    /// Verifies that all runs use the same data as the first run: the same training partition and identical
    /// double variables and values. Runs sharing the same problem data or dataset instance are accepted directly.
    /// </summary>
    /// <exception cref="InvalidOperationException">A run uses different data than the first run.</exception>
    private static void AssertSameProblemData(RunCollection runs) {
      IRegressionProblemData reference = null;
      foreach (var run in runs) {
        var problemData = GetRegressionSolution(run).ProblemData;
        if (reference == null) {
          // the first run is the single reference every other run is compared against
          reference = problemData;
          continue;
        }
        if (problemData == reference) continue;

        var dataset = problemData.Dataset;
        var referenceDataset = reference.Dataset;
        if (dataset == referenceDataset) continue;

        if (problemData.TrainingPartition.Start != reference.TrainingPartition.Start || problemData.TrainingPartition.End != reference.TrainingPartition.End)
          throw new InvalidOperationException("The runs must share the same data.");
        if (!dataset.DoubleVariables.SequenceEqual(referenceDataset.DoubleVariables))
          throw new InvalidOperationException("The runs must share the same data.");
        foreach (var v in dataset.DoubleVariables) {
          // SequenceEqual also reports sequences of different length as unequal
          if (!dataset.GetReadOnlyDoubleValues(v).SequenceEqual(referenceDataset.GetReadOnlyDoubleValues(v)))
            throw new InvalidOperationException("The runs must share the same data.");
        }
      }
    }

    /// <summary>
    /// Builds an ensemble from the regression solutions of <paramref name="runs"/>, using the problem data of the first solution.
    /// </summary>
    /// <exception cref="InvalidOperationException"><paramref name="runs"/> is empty or a run has no unique regression solution.</exception>
    public static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
      // materialize once: the sequence is used for both the models and the problem data
      var solutions = runs.Select(GetRegressionSolution).ToList();
      return new RegressionEnsembleSolution(new RegressionEnsembleModel(solutions.Select(x => x.Model)), solutions.First().ProblemData);
    }

    /// <summary>
    /// Computes variable impacts directly from the runs' regression solutions (shuffle replacement, all rows).
    /// Runs are grouped by their input variables and target.
    /// </summary>
    /// <param name="runs">Runs that must all share the same data (see <see cref="AssertSameProblemData"/>).</param>
    /// <param name="useBest">
    /// true: use only the run with the highest training R² per group; false: average the impacts over all runs of a group.
    /// </param>
    /// <returns>Per target variable: the runs used and the impact of each input variable.</returns>
    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
      AssertSameProblemData(runs);

      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
      var groups = runs.GroupBy(run => {
        var sol = GetRegressionSolution(run);
        return Concatenate(sol.ProblemData.AllowedInputVariables) + sol.ProblemData.TargetVariable;
      });

      if (useBest) {
        // build network using only the best run for each target
        foreach (var group in groups) {
          // OrderBy is stable, so among ties the last run of the group wins
          var best = group.Select(run => Tuple.Create(run, GetRegressionSolution(run)))
                          .OrderBy(x => x.Item2.TrainingRSquared)
                          .Last();
          var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(best.Item2,
            RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle,
            RegressionSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Best,
            RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.All).ToDictionary(x => x.Item1, x => x.Item2);
          targetImpacts[best.Item2.ProblemData.TargetVariable] = Tuple.Create(new[] { best.Item1 }.AsEnumerable(), impacts);
        }
      } else {
        foreach (var group in groups) {
          // sum the impacts of every run, then divide by the number of runs in the group
          var averageImpacts = new Dictionary<string, double>();
          foreach (var run in group) {
            var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(GetRegressionSolution(run),
              RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle,
              RegressionSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Best,
              RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.All);
            foreach (var t in impacts) {
              double sum;
              averageImpacts[t.Item1] = averageImpacts.TryGetValue(t.Item1, out sum) ? sum + t.Item2 : t.Item2;
            }
          }
          var count = group.Count();
          foreach (var v in averageImpacts.Keys.ToList()) {
            averageImpacts[v] /= count;
          }
          var target = GetRegressionSolution(group.First()).ProblemData.TargetVariable;
          targetImpacts[target] = Tuple.Create(group.AsEnumerable(), averageImpacts);
        }
      }
      return targetImpacts;
    }

    /// <summary>
    /// Reads precomputed variable impacts from the run results. Runs are grouped by the target variable of their "ProblemData" parameter.
    /// </summary>
    /// <param name="qualityResultName">Name of a <see cref="DoubleValue"/> result used to rank runs (only used when <paramref name="useBestRunsPerTarget"/> is true).</param>
    /// <param name="maximization">Whether a higher quality value is better.</param>
    /// <param name="impactsResultName">Name of a <see cref="DoubleMatrix"/> result whose row names are variables and whose first column holds impacts.</param>
    /// <param name="useBestRunsPerTarget">true: use only the best run per target; false: average impacts over all runs of the target.</param>
    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs, string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
      Func<IRun, double> getQuality = run => ((DoubleValue)run.Results[qualityResultName]).Value;
      var targetGroups = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();

      if (useBestRunsPerTarget) {
        foreach (var group in targetGroups) {
          var ordered = group.OrderBy(getQuality);
          var best = maximization ? ordered.Last() : ordered.First();
          var impacts = (DoubleMatrix)best.Results[impactsResultName];
          targetImpacts[group.Key] = Tuple.Create((IEnumerable<IRun>)new[] { best },
            impacts.RowNames.Select((x, i) => new { x, i }).ToDictionary(x => x.x, x => impacts[x.i, 0]));
        }
      } else {
        foreach (var target in targetGroups) {
          var averageImpacts = CalculateAverageImpacts(new RunCollection(target), impactsResultName);
          targetImpacts[target.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(target, averageImpacts);
        }
      }
      return targetImpacts;
    }

    /// <summary>
    /// Creates the interaction network. For a target with one input variable the variable is connected directly to the target;
    /// with several input variables they are connected through a junction node that holds an ensemble of the target's runs
    /// and is labeled with the ensemble's training R².
    /// </summary>
    public static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
      var nodes = new Dictionary<string, IVertex>();
      var vn = new VariableInteractionNetwork();
      foreach (var ti in targetImpacts) {
        var target = ti.Key;
        var variableImpacts = ti.Value.Item2;
        var targetRuns = ti.Value.Item1;
        var variables = variableImpacts.Keys.ToList();
        if (variables.Count == 0) continue;

        var targetNode = GetOrCreateVariableNode(vn, nodes, target);

        if (variables.Count > 1) {
          // junction key: all inputs followed by the target
          var junctionLabel = Concatenate(new List<string>(variables) { target });
          IVertex junctionNode;
          if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
            var solutionsEnsemble = CreateEnsembleSolution(targetRuns);
            junctionNode = new JunctionNetworkNode {
              Label = solutionsEnsemble.TrainingRSquared.ToString("N3", CultureInfo.CurrentCulture),
              Data = solutionsEnsemble
            };
            vn.AddVertex(junctionNode);
            nodes[junctionLabel] = junctionNode;
          }
          foreach (var v in variables) {
            var impact = variableImpacts[v];
            var variableNode = GetOrCreateVariableNode(vn, nodes, v);
            vn.AddArc(new Arc(variableNode, junctionNode) { Weight = impact, Label = impact.ToString("N3", CultureInfo.CurrentCulture) });
          }
          var trainingR2 = ((IRegressionSolution)((JunctionNetworkNode)junctionNode).Data).TrainingRSquared;
          // the junction -> target arc weighs the summed impacts of the junction's inputs
          vn.AddArc(new Arc(junctionNode, targetNode) {
            Weight = junctionNode.InArcs.Sum(x => x.Weight),
            Label = trainingR2.ToString("N3", CultureInfo.CurrentCulture)
          });
        } else {
          foreach (var v in variables) {
            var impact = variableImpacts[v];
            var variableNode = GetOrCreateVariableNode(vn, nodes, v);
            vn.AddArc(new Arc(variableNode, targetNode) { Weight = impact, Label = impact.ToString("N3", CultureInfo.CurrentCulture) });
          }
        }
      }
      return vn;
    }

    // Returns the variable node registered under label, creating and adding it to the network on first use.
    private static IVertex GetOrCreateVariableNode(VariableInteractionNetwork network, Dictionary<string, IVertex> nodes, string label) {
      IVertex node;
      if (!nodes.TryGetValue(label, out node)) {
        node = new VariableNetworkNode { Label = label };
        network.AddVertex(node);
        nodes[label] = node;
      }
      return node;
    }

    /// <summary>
    /// Returns a cloned network that keeps only arcs with weight &gt;= <paramref name="threshold"/>, without junctions
    /// that lost all inputs and without isolated vertices. Returns <paramref name="originalNetwork"/> itself when nothing would remain.
    /// </summary>
    public static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
      var arcs = originalNetwork.Arcs.Where(x => x.Weight >= threshold).ToList();
      if (!arcs.Any()) return originalNetwork;
      var filteredNetwork = new VariableInteractionNetwork();
      var cloner = new Cloner();
      // the cloner maps each original vertex to a single clone, so Distinct removes duplicates
      var vertices = arcs.SelectMany(x => new[] { x.Source, x.Target }).Select(cloner.Clone).Distinct(); // arcs are not cloned
      filteredNetwork.AddVertices(vertices);
      filteredNetwork.AddArcs(arcs.Select(x => (IArc)x.Clone(cloner)));

      var unusedJunctions = filteredNetwork.Vertices.Where(x => x.InDegree == 0 && x is JunctionNetworkNode).ToList();
      filteredNetwork.RemoveVertices(unusedJunctions);
      var orphanedNodes = filteredNetwork.Vertices.Where(x => x.Degree == 0).ToList();
      filteredNetwork.RemoveVertices(orphanedNodes);
      return filteredNetwork.Vertices.Any() ? filteredNetwork : originalNetwork;
    }

    // Average of the "Best training solution quality" results; all runs must share target and inputs.
    private static double CalculateAverageQuality(RunCollection runs) {
      var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
      var target = pd.TargetVariable;
      var inputs = pd.AllowedInputVariables;
      if (!runs.All(x => {
        var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
      })) {
        throw new ArgumentException("All runs must have the same target and inputs.");
      }
      return runs.Average(x => ((DoubleValue)x.Results["Best training solution quality"]).Value);
    }

    /// <summary>
    /// Averages the first column of the <see cref="DoubleMatrix"/> result <paramref name="resultName"/> over all runs.
    /// </summary>
    /// <exception cref="ArgumentException">The runs do not share the same target and input variables.</exception>
    public static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
      var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
      var target = pd.TargetVariable;
      var inputs = pd.AllowedInputVariables.ToList();

      var impacts = inputs.ToDictionary(x => x, x => 0d);

      // check if all the runs have the same target and same inputs
      if (!runs.All(x => {
        var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
      })) {
        throw new ArgumentException("All runs must have the same target and inputs.");
      }

      foreach (var run in runs) {
        var impactsMatrix = (DoubleMatrix)run.Results[resultName];
        int i = 0;
        foreach (var v in impactsMatrix.RowNames) {
          impacts[v] += impactsMatrix[i, 0];
          ++i;
        }
      }

      foreach (var v in inputs) {
        impacts[v] /= runs.Count;
      }

      return impacts;
    }

    // Joins the strings without a separator; used to build grouping/junction keys.
    private static string Concatenate(IEnumerable<string> strings) {
      return string.Concat(strings);
    }

    private void ConfigureNodeShapes() {
      graphChart.ClearShapes();
      var font = new Font(FontFamily.GenericSansSerif, 12);
      graphChart.AddShape(typeof(VariableNetworkNode), new LabeledPrimitive(new Ellipse(graphChart.Chart, new PointD(0, 0), new PointD(30, 30), Pens.Black, Brushes.White), "", font));
      graphChart.AddShape(typeof(JunctionNetworkNode), new LabeledPrimitive(new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray), "", font));
    }

    /// <summary>Shows <paramref name="network"/> in the chart, marshalling to the UI thread when necessary.</summary>
    public void UpdateNetwork(VariableInteractionNetwork network) {
      if (InvokeRequired) {
        Invoke((Action<VariableInteractionNetwork>)UpdateNetwork, network);
      } else {
        graphChart.Graph = network;
      }
    }

    // Threshold selected with the track bar; the same value is displayed in impactThresholdLabel.
    private double GetImpactThreshold() {
      return impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
    }

    // Index 0 (or no selection) of the aggregation combo box means "use the best run per target".
    private bool UseBestRunPerTarget() {
      return impactAggregationComboBox.SelectedIndex <= 0;
    }

    #region events
    protected override void OnContentChanged() {
      base.OnContentChanged();
      // Content can be null (e.g. set through the untyped base property) or an empty collection
      if (Content == null || !Content.Any()) return;

      var run = Content.First();
      var pd = (IRegressionProblemData)run.Parameters["ProblemData"];
      var variables = new HashSet<string>(pd.Dataset.DoubleVariables);

      // offer only matrices whose row names are all dataset variables as impact results
      impactResultNameComboBox.Items.Clear();
      foreach (var result in run.Results.Where(x => x.Value is DoubleMatrix)) {
        var m = (DoubleMatrix)result.Value;
        if (m.RowNames.All(variables.Contains))
          impactResultNameComboBox.Items.Add(result.Key);
      }
      qualityResultNameComboBox.Items.Clear();
      foreach (var result in run.Results.Where(x => x.Value is DoubleValue)) {
        qualityResultNameComboBox.Items.Add(result.Key);
      }
      if (impactResultNameComboBox.Items.Count > 0) {
        impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
      }
      if (qualityResultNameComboBox.Items.Count > 0) {
        qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
      }
      if (impactResultNameComboBox.Items.Count > 0 && qualityResultNameComboBox.Items.Count > 0)
        NetworkConfigurationChanged(this, EventArgs.Empty);
    }

    private void TextBoxValidating(object sender, CancelEventArgs e) {
      double v;
      string errorMsg = "Could not parse the entered value. Please input a real number.";
      var tb = (TextBox)sender;
      if (!double.TryParse(tb.Text, out v)) {
        e.Cancel = true;
        tb.Select(0, tb.Text.Length);

        // Set the ErrorProvider error with the text to display.
        this.errorProvider.SetError(tb, errorMsg);
        errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
        errorProvider.SetIconPadding(tb, -20);
      }
    }

    private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
      var tb = (TextBox)sender;
      errorProvider.SetError(tb, string.Empty);
      if (variableInteractionNetwork == null) return; // nothing to filter yet
      double impact;
      if (!double.TryParse(tb.Text, out impact)) {
        impact = 0.2;
      }
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
    }

    private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
      var tb = (TextBox)sender;
      errorProvider.SetError(tb, string.Empty);
      LayoutConfigurationChanged(sender, e);
    }

    // Rebuilds the network from precomputed run results using the current UI configuration.
    private void NetworkConfigurationChanged(object sender, EventArgs e) {
      if (Content == null) return;
      var useBest = UseBestRunPerTarget();
      var threshold = GetImpactThreshold();
      var qualityResultName = qualityResultNameComboBox.Text;
      var impactsResultName = impactResultNameComboBox.Text;
      if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName))
        return;
      var maximization = maximizationCheckBox.Checked;
      var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
      variableInteractionNetwork = CreateNetwork(impacts);
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
    }

    private void LayoutConfigurationChanged(object sender, EventArgs e) {
      ConstrainedForceDirectedLayout.EdgeRouting routingMode;
      switch (edgeRoutingComboBox.SelectedIndex) {
        case 0:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.None;
          break;
        case 1:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Polyline;
          break;
        case 2:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal;
          break;
        default:
          throw new ArgumentException("Invalid edge routing mode.");
      }
      double idealEdgeLength;
      // ignore unparsable input instead of throwing from an event handler
      if (!double.TryParse(idealEdgeLengthTextBox.Text, out idealEdgeLength)) return;
      if (routingMode == graphChart.RoutingMode && idealEdgeLength.IsAlmost(graphChart.DefaultEdgeLength)) return;
      graphChart.RoutingMode = routingMode;
      graphChart.PerformEdgeRouting = routingMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
      graphChart.DefaultEdgeLength = idealEdgeLength;
      graphChart.Draw();
    }

    private void ControlsEnable(bool enabled) {
      qualityResultNameComboBox.Enabled
        = impactResultNameComboBox.Enabled
        = impactAggregationComboBox.Enabled
        = impactThresholdTrackBar.Enabled
        = onlineImpactCalculationButton.Enabled
        = edgeRoutingComboBox.Enabled
        = idealEdgeLengthTextBox.Enabled
        = maximizationCheckBox.Enabled = enabled;
    }

    private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
      var runs = Content;
      if (runs == null) return;
      // DoWork runs on a thread-pool thread: read all UI state and toggle controls here, on the UI thread
      var useBest = UseBestRunPerTarget();
      ControlsEnable(false);

      var worker = new BackgroundWorker();
      worker.DoWork += (o, e) => {
        var impacts = CalculateVariableImpactsOnline(runs, useBest);
        e.Result = CreateNetwork(impacts);
      };
      // started from the UI thread, so RunWorkerCompleted is raised on the UI thread
      worker.RunWorkerCompleted += (o, e) => {
        try {
          if (e.Error != null) {
            MessageBox.Show(this, e.Error.Message, "Online impact calculation failed", MessageBoxButtons.OK, MessageBoxIcon.Error);
            return;
          }
          variableInteractionNetwork = (VariableInteractionNetwork)e.Result;
          graphChart.Graph = ApplyThreshold(variableInteractionNetwork, GetImpactThreshold());
        } finally {
          ControlsEnable(true);
          worker.Dispose();
        }
      };
      worker.RunWorkerAsync();
    }

    private void relayoutGraphButton_Click(object sender, EventArgs e) {
      graphChart.Draw();
    }
    #endregion

    // Shows the impacts of the displayed graph as a matrix: row = target, column = input, diagonal = 1.
    private void exportImpactsMatrixButton_Click(object sender, EventArgs e) {
      var graph = graphChart.Graph;
      if (graph == null) return;
      var labels = graph.Vertices.OfType<VariableNetworkNode>().Select(x => x.Label).ToList();
      labels.Sort(); // sort variables alphabetically
      var matrix = new DoubleMatrix(labels.Count, labels.Count) { RowNames = labels, ColumnNames = labels };
      var indexes = labels.Select((x, i) => new { Label = x, Index = i }).ToDictionary(x => x.Label, x => x.Index);

      // multi-input targets: inputs -> junction -> target
      foreach (var jn in graph.Vertices.OfType<JunctionNetworkNode>()) {
        var outArc = jn.OutArcs.FirstOrDefault();
        if (outArc == null) continue; // junction without a target (e.g. after filtering)
        var targetIndex = indexes[outArc.Target.Label];
        foreach (var input in jn.InArcs) {
          matrix[targetIndex, indexes[input.Source.Label]] = input.Weight;
        }
      }
      // single-input targets: the input variable is connected directly to the target
      foreach (var node in graph.Vertices.OfType<VariableNetworkNode>()) {
        foreach (var arc in node.OutArcs.Where(a => a.Target is VariableNetworkNode)) {
          matrix[indexes[arc.Target.Label], indexes[node.Label]] = arc.Weight;
        }
      }
      for (int i = 0; i < labels.Count; ++i) matrix[i, i] = 1;

      MainFormManager.MainForm.ShowContent(matrix);
    }

    private void impactThresholdTrackBar_ValueChanged(object sender, EventArgs e) {
      var impact = GetImpactThreshold();
      impactThresholdLabel.Text = impact.ToString("N3", CultureInfo.CurrentCulture);
      if (variableInteractionNetwork == null) return; // no network has been built yet
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
    }
  }
}