#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Windows.Forms;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Core.Views;
using HeuristicLab.Data;
using HeuristicLab.MainForm;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Visualization;
using Ellipse = HeuristicLab.Visualization.Ellipse;
using Rectangle = HeuristicLab.Visualization.Rectangle;

namespace HeuristicLab.VariableInteractionNetworks.Views {
  /// <summary>
  /// View that builds a variable interaction network from the runs of a <see cref="RunCollection"/>
  /// and displays it in a graph chart. Variable impacts are either read from run results or
  /// recalculated from the regression solutions stored in the runs.
  /// </summary>
  [View("Variable Interaction Network")]
  [Content(typeof(RunCollection), IsDefaultView = false)]
  public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
    private const string ProblemDataParameterName = "ProblemData";
    private const string DifferentDataMessage = "The runs must share the same data.";
    private const string DifferentTargetOrInputsMessage = "All runs must have the same target and inputs.";

    public RunCollectionVariableInteractionNetworkView() {
      InitializeComponent();
      ConfigureNodeShapes();
    }

    /// <summary>
    /// The displayed run collection. Assigning null or the current value is ignored.
    /// </summary>
    public new RunCollection Content {
      get { return (RunCollection)base.Content; }
      set {
        if (value != null && value != Content) {
          base.Content = value;
        }
      }
    }

    // The unfiltered network; thresholds are always applied to this instance.
    private VariableInteractionNetwork variableInteractionNetwork;

    // Returns the single regression solution stored in the results of the run.
    private static IRegressionSolution GetRegressionSolution(IRun run) {
      return (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
    }

    // Returns the regression problem data stored in the "ProblemData" parameter of the run.
    private static IRegressionProblemData GetProblemData(IRun run) {
      return (IRegressionProblemData)run.Parameters[ProblemDataParameterName];
    }

    /// <summary>
    /// Verifies that the solutions of all runs were trained on the same data.
    /// Runs that share the problem data or dataset instance of the first run are accepted directly;
    /// otherwise the training partition and all double variable values must match.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when a run uses different data.</exception>
    private static void AssertSameProblemData(RunCollection runs) {
      IDataset dataset = null;
      IRegressionProblemData problemData = null;
      foreach (var run in runs) {
        var pd = GetRegressionSolution(run).ProblemData;
        var ds = pd.Dataset;

        // The first run establishes BOTH references. Previously only the problem data was stored
        // and the dataset was silently taken from the next run without being compared.
        if (problemData == null) {
          problemData = pd;
          dataset = ds;
          continue;
        }

        if (pd == problemData || ds == dataset) continue;

        if (problemData.TrainingPartition.Start != pd.TrainingPartition.Start ||
            problemData.TrainingPartition.End != pd.TrainingPartition.End)
          throw new InvalidOperationException(DifferentDataMessage);

        if (!ds.DoubleVariables.SequenceEqual(dataset.DoubleVariables))
          throw new InvalidOperationException(DifferentDataMessage);

        foreach (var v in ds.DoubleVariables) {
          var values1 = (IList<double>)ds.GetReadOnlyDoubleValues(v);
          var values2 = (IList<double>)dataset.GetReadOnlyDoubleValues(v);
          // cheap length check before the element-wise comparison
          if (values1.Count != values2.Count || !values1.SequenceEqual(values2))
            throw new InvalidOperationException(DifferentDataMessage);
        }
      }
    }

    /// <summary>
    /// Combines the regression solutions of the given runs into one ensemble solution that uses
    /// the problem data of the first solution.
    /// </summary>
    private static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
      // materialized once because the sequence is consumed twice below
      var solutions = runs.Select(GetRegressionSolution).ToList();
      return new RegressionEnsembleSolution(new RegressionEnsembleModel(solutions.Select(x => x.Model)), solutions.First().ProblemData);
    }

    /// <summary>
    /// Recalculates variable impacts from the regression solutions stored in the runs and averages
    /// them over all runs that share the same allowed inputs and target.
    /// </summary>
    /// <param name="runs">Runs that each contain exactly one regression solution result.</param>
    /// <param name="useBest">
    /// Intended to restrict the calculation to the best run per target. This mode is not
    /// implemented; an empty dictionary is returned when it is requested.
    /// </param>
    /// <returns>For each target variable the contributing runs and the averaged impact per input.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="runs"/> is null.</exception>
    /// <exception cref="InvalidOperationException">Thrown when the runs do not share the same data.</exception>
    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
      if (runs == null) throw new ArgumentNullException("runs");

      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
      if (runs.Count == 0) return targetImpacts;

      AssertSameProblemData(runs);

      if (useBest) {
        // TODO: build the network using only the best run for each target (not implemented yet)
        return targetImpacts;
      }

      var referenceSolution = GetRegressionSolution(runs.First());
      var dataset = (Dataset)referenceSolution.ProblemData.Dataset;
      var originalValues = dataset.DoubleVariables.ToDictionary(x => x, x => dataset.GetReadOnlyDoubleValues(x).ToList());
      var md = dataset.ToModifiable();

      // Every variable gets a constant column holding its median over the training rows of the first solution.
      var trainingIndices = referenceSolution.ProblemData.TrainingIndices.ToList();
      var medians = new Dictionary<string, List<double>>();
      foreach (var v in dataset.DoubleVariables) {
        var median = dataset.GetDoubleValues(v, trainingIndices).Median();
        medians[v] = Enumerable.Repeat(median, originalValues[v].Count).ToList();
      }

      var groups = runs.GroupBy(r => {
        var pd = GetRegressionSolution(r).ProblemData;
        return Concatenate(pd.AllowedInputVariables) + pd.TargetVariable;
      });

      foreach (var group in groups) {
        // sum the impacts of all runs in the group, then divide by the number of runs
        var averageImpacts = new Dictionary<string, double>();
        foreach (var run in group) {
          IItem limitsItem;
          var estimationLimits = run.Parameters.TryGetValue("EstimationLimits", out limitsItem)
            ? (DoubleLimit)limitsItem
            : null;

          var impacts = CalculateImpacts(GetRegressionSolution(run), md, originalValues, medians, estimationLimits);
          foreach (var pair in impacts) {
            double sum;
            averageImpacts.TryGetValue(pair.Key, out sum); // sum stays 0 for unseen keys
            averageImpacts[pair.Key] = sum + pair.Value;
          }
        }

        var count = group.Count();
        foreach (var v in averageImpacts.Keys.ToList()) {
          averageImpacts[v] /= count;
        }

        // NOTE: the result is keyed by target only, so a later group with the same target
        // but different inputs replaces the entry of an earlier group.
        var target = GetRegressionSolution(group.First()).ProblemData.TargetVariable;
        targetImpacts[target] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(group, averageImpacts);
      }
      return targetImpacts;
    }

    /// <summary>
    /// Calculates the impact of every allowed input variable of the solution as the difference between
    /// the training R² of the solution and the R² obtained after replacing the variable with
    /// <paramref name="medianValues"/>.
    /// </summary>
    /// <param name="solution">The solution whose model is evaluated.</param>
    /// <param name="dataset">Scratch dataset that is modified temporarily and restored for every variable.</param>
    /// <param name="originalValues">Original values per variable, used to restore the dataset.</param>
    /// <param name="medianValues">Replacement values per variable.</param>
    /// <param name="estimationLimits">Optional limits applied to the estimated values.</param>
    private static Dictionary<string, double> CalculateImpacts(IRegressionSolution solution, ModifiableDataset dataset, Dictionary<string, List<double>> originalValues, Dictionary<string, List<double>> medianValues, DoubleLimit estimationLimits = null) {
      var impacts = new Dictionary<string, double>();
      var model = solution.Model;
      var pd = solution.ProblemData;
      var rows = pd.TrainingIndices.ToList();
      var targetValues = pd.Dataset.GetDoubleValues(pd.TargetVariable, rows).ToList();
      var originalQuality = solution.TrainingRSquared;

      foreach (var v in pd.AllowedInputVariables) {
        dataset.ReplaceVariable(v, medianValues[v]);
        try {
          var estimatedValues = model.GetEstimatedValues(dataset, rows);
          if (estimationLimits != null)
            estimatedValues = estimatedValues.LimitToRange(estimationLimits.Lower, estimationLimits.Upper);

          OnlineCalculatorError error;
          var r = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out error);
          var newQuality = error == OnlineCalculatorError.None ? r * r : double.NaN;
          impacts[v] = originalQuality - newQuality;
        } finally {
          // the dataset is shared between calls, so it must be restored even if the model throws
          dataset.ReplaceVariable(v, originalValues[v]);
        }
      }
      return impacts;
    }

    /// <summary>
    /// Reads variable impacts from the run results, grouped by the target variable of the runs.
    /// </summary>
    /// <param name="runs">Runs that have a "ProblemData" parameter and the named results.</param>
    /// <param name="qualityResultName">Name of the <see cref="DoubleValue"/> result used to rank runs.</param>
    /// <param name="maximization">Whether a larger quality value is better.</param>
    /// <param name="impactsResultName">Name of the <see cref="DoubleMatrix"/> result holding the impacts in column 0.</param>
    /// <param name="useBestRunsPerTarget">
    /// If true only the best run per target is used, otherwise the impacts of all runs per target are averaged.
    /// </param>
    private static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs, string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
      var targets = runs.GroupBy(x => GetProblemData(x).TargetVariable).ToList();
      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();

      foreach (var target in targets) {
        if (useBestRunsPerTarget) {
          var ordered = target.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value);
          var run = maximization ? ordered.Last() : ordered.First();
          var impacts = (DoubleMatrix)run.Results[impactsResultName];
          var impactsByName = impacts.RowNames
            .Select((x, i) => new { Name = x, Index = i })
            .ToDictionary(x => x.Name, x => impacts[x.Index, 0]);
          targetImpacts[GetProblemData(run).TargetVariable] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(new[] { run }, impactsByName);
        } else {
          var averageImpacts = CalculateAverageImpacts(new RunCollection(target), impactsResultName);
          targetImpacts[target.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(target, averageImpacts);
        }
      }
      return targetImpacts;
    }

    // Returns the variable node registered under the label, creating and adding it to the network if necessary.
    private static IVertex GetOrAddVariableNode(VariableInteractionNetwork network, Dictionary<string, IVertex> nodes, string label) {
      IVertex node;
      if (!nodes.TryGetValue(label, out node)) {
        node = new VariableNetworkNode { Label = label };
        network.AddVertex(node);
        nodes[label] = node;
      }
      return node;
    }

    /// <summary>
    /// Builds the network from the impacts per target. Targets with more than one input are connected
    /// through a junction node that carries an ensemble of the target's solutions; targets with a single
    /// input are connected to it directly. Targets without inputs are skipped.
    /// </summary>
    private static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
      var nodes = new Dictionary<string, IVertex>();
      var vn = new VariableInteractionNetwork();

      foreach (var ti in targetImpacts) {
        var target = ti.Key;
        var targetRuns = ti.Value.Item1;
        var variableImpacts = ti.Value.Item2;
        var variables = variableImpacts.Keys.ToList();
        if (variables.Count == 0) continue;

        var targetNode = GetOrAddVariableNode(vn, nodes, target);

        if (variables.Count > 1) {
          // junctions share the node dictionary; their key is the concatenation of all inputs and the target
          var junctionLabel = Concatenate(variables.Concat(new[] { target }));
          IVertex junctionNode;
          if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
            var solutionsEnsemble = CreateEnsembleSolution(targetRuns);
            junctionNode = new JunctionNetworkNode {
              Label = string.Format("Target quality: {0:0.000}", solutionsEnsemble.TrainingRSquared),
              Data = solutionsEnsemble
            };
            vn.AddVertex(junctionNode);
            nodes[junctionLabel] = junctionNode;
          }

          foreach (var v in variables) {
            var impact = variableImpacts[v];
            var variableNode = GetOrAddVariableNode(vn, nodes, v);
            vn.AddArc(new Arc(variableNode, junctionNode) { Weight = impact, Label = string.Format("Impact: {0:0.000}", impact) });
          }

          var trainingR2 = ((IRegressionSolution)((JunctionNetworkNode)junctionNode).Data).TrainingRSquared;
          vn.AddArc(new Arc(junctionNode, targetNode) {
            Weight = junctionNode.InArcs.Sum(x => x.Weight),
            Label = string.Format("Quality: {0:0.000}", trainingR2)
          });
        } else {
          foreach (var v in variables) {
            var impact = variableImpacts[v];
            var variableNode = GetOrAddVariableNode(vn, nodes, v);
            vn.AddArc(new Arc(variableNode, targetNode) { Weight = impact, Label = string.Format("Impact: {0:0.000}", impact) });
          }
        }
      }
      return vn;
    }

    // Returns the problem data of the first run after checking that all runs use the same target and inputs.
    private static IRegressionProblemData GetCommonProblemData(RunCollection runs) {
      var pd = GetProblemData(runs.First());
      var target = pd.TargetVariable;
      var inputs = pd.AllowedInputVariables.ToList();
      var consistent = runs.All(x => {
        var problemData = GetProblemData(x);
        return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
      });
      if (!consistent) throw new ArgumentException(DifferentTargetOrInputsMessage);
      return pd;
    }

    /// <summary>
    /// Averages the "Best training solution quality" result over all runs.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when the runs differ in target or inputs.</exception>
    private static double CalculateAverageQuality(RunCollection runs) {
      GetCommonProblemData(runs);
      return runs.Average(x => ((DoubleValue)x.Results["Best training solution quality"]).Value);
    }

    /// <summary>
    /// Averages column 0 of the named impact matrix over all runs, per input variable.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when the runs differ in target or inputs.</exception>
    private static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
      var inputs = GetCommonProblemData(runs).AllowedInputVariables.ToList();
      var impacts = inputs.ToDictionary(x => x, x => 0d);

      foreach (var run in runs) {
        var impactsMatrix = (DoubleMatrix)run.Results[resultName];
        int i = 0;
        foreach (var v in impactsMatrix.RowNames) {
          impacts[v] += impactsMatrix[i, 0];
          ++i;
        }
      }

      foreach (var v in inputs) {
        impacts[v] /= runs.Count;
      }
      return impacts;
    }

    // Joins the strings without a separator.
    private static string Concatenate(IEnumerable<string> strings) {
      return string.Concat(strings);
    }

    // Registers the shapes used to draw variable nodes (ellipse) and junction nodes (rectangle).
    private void ConfigureNodeShapes() {
      graphChart.ClearShapes();
      var font = new Font(FontFamily.GenericSansSerif, 12);
      graphChart.AddShape(typeof(VariableNetworkNode), new LabeledPrimitive(new Ellipse(graphChart.Chart, new PointD(0, 0), new PointD(30, 30), Pens.Black, Brushes.White), "", font));
      graphChart.AddShape(typeof(JunctionNetworkNode), new LabeledPrimitive(new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray), "", font));
    }

    #region events
    /// <summary>
    /// Fills the result name combo boxes from the first run of the content and rebuilds the network
    /// when both an impact result and a quality result are available.
    /// </summary>
    protected override void OnContentChanged() {
      base.OnContentChanged();

      impactResultNameComboBox.Items.Clear();
      qualityResultNameComboBox.Items.Clear();

      // nothing to offer without at least one run
      if (Content == null || Content.Count == 0) return;

      var run = Content.First();
      var pd = GetProblemData(run);
      var variables = new HashSet<string>(pd.Dataset.DoubleVariables);

      // only matrices whose row names are all dataset variables qualify as impact results
      foreach (var result in run.Results.Where(x => x.Value is DoubleMatrix)) {
        var m = (DoubleMatrix)result.Value;
        if (m.RowNames.All(x => variables.Contains(x)))
          impactResultNameComboBox.Items.Add(result.Key);
      }

      foreach (var result in run.Results.Where(x => x.Value is DoubleValue)) {
        qualityResultNameComboBox.Items.Add(result.Key);
      }

      var hasImpacts = impactResultNameComboBox.Items.Count > 0;
      var hasQualities = qualityResultNameComboBox.Items.Count > 0;
      if (hasImpacts) {
        impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
      }
      if (hasQualities) {
        qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
      }
      if (hasImpacts && hasQualities)
        NetworkConfigurationChanged(this, EventArgs.Empty);
    }

    // Cancels validation and shows an error icon when the text box does not contain a number.
    private void TextBoxValidating(object sender, CancelEventArgs e) {
      var tb = (TextBox)sender;
      double v;
      if (double.TryParse(tb.Text, out v)) return;

      e.Cancel = true;
      tb.Select(0, tb.Text.Length);
      errorProvider.SetError(tb, "Could not parse the entered value. Please input a real number.");
      errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
      errorProvider.SetIconPadding(tb, -20);
    }

    // Clears the validation error and re-applies the threshold to the current network.
    private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
      var tb = (TextBox)sender;
      errorProvider.SetError(tb, string.Empty);

      // no network has been built yet, so there is nothing to filter
      if (variableInteractionNetwork == null) return;

      double threshold;
      if (!double.TryParse(tb.Text, out threshold)) return;
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
    }

    /// <summary>
    /// Returns a network that contains clones of the arcs whose weight is at least
    /// <paramref name="threshold"/> together with their end points. Junction nodes without incoming
    /// arcs and nodes without any arcs are removed afterwards.
    /// If no arc reaches the threshold the original network is returned unchanged.
    /// </summary>
    private static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
      var arcs = originalNetwork.Arcs.Where(x => x.Weight >= threshold).ToList();
      if (arcs.Count == 0) return originalNetwork;

      var filteredNetwork = new VariableInteractionNetwork();
      var cloner = new Cloner();
      // arcs are not cloned with the vertices; they are cloned with the same cloner below
      var vertices = arcs.SelectMany(x => new[] { x.Source, x.Target }).Select(cloner.Clone).Distinct();
      filteredNetwork.AddVertices(vertices);
      filteredNetwork.AddArcs(arcs.Select(x => (IArc)x.Clone(cloner)));

      var unusedJunctions = filteredNetwork.Vertices.Where(x => x.InDegree == 0 && x is JunctionNetworkNode).ToList();
      filteredNetwork.RemoveVertices(unusedJunctions);

      var orphanedNodes = filteredNetwork.Vertices.Where(x => x.Degree == 0).ToList();
      filteredNetwork.RemoveVertices(orphanedNodes);
      return filteredNetwork;
    }

    // Clears the validation error and applies the layout settings.
    private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
      var tb = (TextBox)sender;
      errorProvider.SetError(tb, string.Empty);
      LayoutConfigurationChanged(sender, e);
    }

    // Rebuilds the network from the run results using the current settings of the controls.
    private void NetworkConfigurationChanged(object sender, EventArgs e) {
      if (Content == null) return;

      double threshold;
      // the handler may fire while the box holds unvalidated text; keep the current graph in that case
      if (!double.TryParse(impactThresholdTextBox.Text, out threshold)) return;

      var qualityResultName = qualityResultNameComboBox.Text;
      var impactsResultName = impactResultNameComboBox.Text;
      if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName)) return;

      var useBest = impactAggregationComboBox.SelectedIndex <= 0;
      var maximization = maximizationCheckBox.Checked;
      var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
      variableInteractionNetwork = CreateNetwork(impacts);
      graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
    }

    // Applies edge routing mode and ideal edge length to the chart and redraws it if either changed.
    private void LayoutConfigurationChanged(object sender, EventArgs e) {
      ConstrainedForceDirectedLayout.EdgeRouting routingMode;
      switch (edgeRoutingComboBox.SelectedIndex) {
        case 0:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.None;
          break;
        case 1:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Polyline;
          break;
        case 2:
          routingMode = ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal;
          break;
        default:
          throw new ArgumentException("Invalid edge routing mode.");
      }

      double idealEdgeLength;
      if (!double.TryParse(idealEdgeLengthTextBox.Text, out idealEdgeLength)) return;

      if (routingMode == graphChart.RoutingMode && idealEdgeLength.IsAlmost(graphChart.IdealEdgeLength)) return;

      graphChart.RoutingMode = routingMode;
      graphChart.PerformEdgeRouting = routingMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
      graphChart.IdealEdgeLength = idealEdgeLength;
      graphChart.Draw();
    }

    // Recalculates the impacts from the solutions on a background thread and shows the resulting network.
    private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
      var button = (Button)sender;

      // Everything that touches controls is read here, on the UI thread, and handed to the worker.
      var runs = Content;
      if (runs == null) return;
      double threshold;
      if (!double.TryParse(impactThresholdTextBox.Text, out threshold)) return;

      button.Enabled = false;
      var worker = new BackgroundWorker();
      worker.DoWork += (o, e) => {
        // worker thread: computation only, no control access
        e.Result = CreateNetwork(CalculateVariableImpactsOnline(runs, false));
      };
      worker.RunWorkerCompleted += (o, e) => {
        // raised on the thread that started the worker, so controls may be updated here
        try {
          if (IsDisposed) return;
          if (e.Error != null) {
            MessageBox.Show(this, e.Error.Message, "Variable impact calculation failed", MessageBoxButtons.OK, MessageBoxIcon.Error);
            return;
          }
          variableInteractionNetwork = (VariableInteractionNetwork)e.Result;
          graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
        } finally {
          if (!button.IsDisposed) button.Enabled = true;
          worker.Dispose();
        }
      };
      worker.RunWorkerAsync();
    }
    #endregion
  }
}