#region License Information
/* HeuristicLab
* Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Drawing;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Windows.Forms;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Core.Views;
using HeuristicLab.Data;
using HeuristicLab.MainForm;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Visualization;
using Ellipse = HeuristicLab.Visualization.Ellipse;
using Rectangle = HeuristicLab.Visualization.Rectangle;
namespace HeuristicLab.VariableInteractionNetworks.Views {
[View("Variable Interaction Network")]
[Content(typeof(RunCollection), IsDefaultView = false)]
public sealed partial class RunCollectionVariableInteractionNetworkView : ItemView {
// Creates the view. InitializeComponent() runs first; ConfigureNodeShapes()
// afterwards registers the node shapes on graphChart, so it relies on the chart
// already existing (presumably created by InitializeComponent — confirm in the
// designer part of this partial class).
public RunCollectionVariableInteractionNetworkView() {
InitializeComponent();
ConfigureNodeShapes();
}
/// <summary>
/// The run collection shown by this view, typed as <see cref="RunCollection"/>.
/// </summary>
/// <remarks>
/// Assigning <c>null</c> or the collection that is already displayed is ignored.
/// </remarks>
public new RunCollection Content {
  get { return (RunCollection)base.Content; }
  set {
    // nothing to do for null or for the instance that is already shown
    if (value == null || value == Content) return;
    base.Content = value;
  }
}
// The unfiltered network most recently built by CreateNetwork. The event handlers
// pass it to ApplyThreshold to derive the graph that is actually displayed.
// It stays null until a network has been built for the first time.
private VariableInteractionNetwork variableInteractionNetwork;
/// <summary>
/// Verifies that the regression solutions of all runs were built on the same data.
/// </summary>
/// <remarks>
/// The first run supplies the reference problem data and the reference dataset.
/// A later run is accepted without further checks when it uses the very same
/// problem data instance or the very same dataset instance; otherwise its training
/// partition, its double variable names and all double values must match the reference.
/// </remarks>
/// <param name="runs">Runs that each contain exactly one <see cref="IRegressionSolution"/> result.</param>
/// <exception cref="InvalidOperationException">
/// A run differs from the reference in its training partition, its double variables or their values.
/// </exception>
private static void AssertSameProblemData(RunCollection runs) {
  const string errorMessage = "The runs must share the same data.";
  IDataset dataset = null;
  IRegressionProblemData problemData = null;
  foreach (var run in runs) {
    var solution = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
    var ds = solution.ProblemData.Dataset;
    if (problemData == null) {
      // The first run defines BOTH references. Taking the dataset from a later run
      // would leave the first run's data unchecked.
      problemData = solution.ProblemData;
      dataset = ds;
      continue;
    }
    // identical instances are trivially consistent
    if (solution.ProblemData == problemData) continue;
    if (ds == dataset) continue;

    var referencePartition = problemData.TrainingPartition;
    var partition = solution.ProblemData.TrainingPartition;
    if (referencePartition.Start != partition.Start || referencePartition.End != partition.End)
      throw new InvalidOperationException(errorMessage);

    if (!ds.DoubleVariables.SequenceEqual(dataset.DoubleVariables))
      throw new InvalidOperationException(errorMessage);

    foreach (var v in ds.DoubleVariables) {
      var values1 = (IList<double>)ds.GetReadOnlyDoubleValues(v);
      var values2 = (IList<double>)dataset.GetReadOnlyDoubleValues(v);
      // cheap length check before the element-wise comparison
      if (values1.Count != values2.Count)
        throw new InvalidOperationException(errorMessage);
      if (!values1.SequenceEqual(values2))
        throw new InvalidOperationException(errorMessage);
    }
  }
}
/// <summary>
/// Combines the regression solutions of the given runs into one ensemble solution.
/// </summary>
/// <param name="runs">Runs that each contain exactly one <see cref="IRegressionSolution"/> result.</param>
/// <returns>
/// An ensemble built from the models of all solutions, using the problem data of the first solution.
/// </returns>
/// <exception cref="InvalidOperationException">
/// <paramref name="runs"/> is empty, or a run does not contain exactly one regression solution.
/// </exception>
public static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
  // Materialize once: the sequence is needed both for the models and for the first problem data.
  var solutions = runs
    .Select(x => x.Results.Values.Single(v => v is IRegressionSolution))
    .Cast<IRegressionSolution>()
    .ToList();
  var ensembleModel = new RegressionEnsembleModel(solutions.Select(x => x.Model));
  return new RegressionEnsembleSolution(ensembleModel, solutions.First().ProblemData);
}
/// <summary>
/// Recalculates variable impacts from the regression solutions stored in the runs
/// by replacing each input variable with its training median (see <c>CalculateImpacts</c>).
/// </summary>
/// <remarks>
/// Runs are grouped by the concatenation of their allowed input variables and their
/// target variable; impacts are averaged within each group. The result is keyed by
/// target variable, so a later group with the same target replaces an earlier one.
/// </remarks>
/// <param name="runs">Runs that each contain exactly one <see cref="IRegressionSolution"/> and share the same data.</param>
/// <param name="useBest">
/// When <c>true</c> no impacts are calculated and an empty dictionary is returned
/// (the "best run per target" mode is not implemented here).
/// </param>
/// <returns>
/// For each target variable: the runs of its group and the average impact per input variable.
/// </returns>
/// <exception cref="InvalidOperationException">The runs do not share the same data, or <paramref name="runs"/> is empty.</exception>
public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
  AssertSameProblemData(runs);
  var solution = (IRegressionSolution)runs.First().Results.Values.Single(x => x is IRegressionSolution);
  var dataset = (Dataset)solution.ProblemData.Dataset;
  // Snapshot of the original columns, used to undo each median replacement.
  var originalValues = dataset.DoubleVariables.ToDictionary(x => x, x => dataset.GetReadOnlyDoubleValues(x).ToList());
  var md = dataset.ToModifiable();
  // One constant column per variable, filled with the median over the training rows.
  var medians = new Dictionary<string, List<double>>();
  foreach (var v in dataset.DoubleVariables) {
    var median = dataset.GetDoubleValues(v, solution.ProblemData.TrainingIndices).Median();
    medians[v] = Enumerable.Repeat(median, originalValues[v].Count).ToList();
  }
  var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
  if (useBest) {
    // build network using only the best run for each target
    // (not implemented: the result stays empty)
  } else {
    var groups = runs.GroupBy(run => {
      var sol = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
      return Concatenate(sol.ProblemData.AllowedInputVariables) + sol.ProblemData.TargetVariable;
    });
    foreach (var group in groups) {
      // calculate average impacts
      var averageImpacts = new Dictionary<string, double>();
      var groupSolution = (IRegressionSolution)group.First().Results.Values.Single(x => x is IRegressionSolution);
      foreach (var run in group) {
        var sol = (IRegressionSolution)run.Results.Values.Single(v => v is IRegressionSolution);
        IItem limitsItem;
        var estimationLimits = run.Parameters.TryGetValue("EstimationLimits", out limitsItem)
          ? (DoubleLimit)limitsItem
          : null;
        var impacts = CalculateImpacts(sol, md, originalValues, medians, estimationLimits);
        foreach (var pair in impacts) {
          double sum;
          averageImpacts.TryGetValue(pair.Key, out sum); // sum stays 0 for a new key
          averageImpacts[pair.Key] = sum + pair.Value;
        }
      }
      var count = group.Count();
      // iterate over a copy of the keys because the values are updated in place
      foreach (var v in averageImpacts.Keys.ToList()) {
        averageImpacts[v] /= count;
      }
      targetImpacts[groupSolution.ProblemData.TargetVariable] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(group, averageImpacts);
    }
  }
  return targetImpacts;
}
/// <summary>
/// Calculates the impact of every allowed input variable of a solution as the drop in
/// training R² when that variable is replaced by the supplied replacement values.
/// </summary>
/// <param name="solution">Solution whose model is evaluated on its training rows.</param>
/// <param name="dataset">Modifiable dataset in which variables are temporarily replaced; restored before returning.</param>
/// <param name="originalValues">Original values per variable, used to restore <paramref name="dataset"/>.</param>
/// <param name="medianValues">Replacement values per variable.</param>
/// <param name="estimationLimits">Optional range the estimated values are limited to.</param>
/// <returns>
/// Impact per input variable: <c>solution.TrainingRSquared</c> minus the R² obtained with the
/// variable replaced. The impact is NaN when the correlation could not be calculated.
/// </returns>
private static Dictionary<string, double> CalculateImpacts(IRegressionSolution solution, ModifiableDataset dataset,
  Dictionary<string, List<double>> originalValues, Dictionary<string, List<double>> medianValues, DoubleLimit estimationLimits = null) {
  var impacts = new Dictionary<string, double>();
  var model = solution.Model;
  var pd = solution.ProblemData;
  var rows = pd.TrainingIndices.ToList();
  var targetValues = pd.Dataset.GetDoubleValues(pd.TargetVariable, rows).ToList();
  var originalQuality = solution.TrainingRSquared; // does not depend on the variable being replaced
  foreach (var v in pd.AllowedInputVariables) {
    dataset.ReplaceVariable(v, medianValues[v]);
    try {
      var estimatedValues = model.GetEstimatedValues(dataset, rows);
      if (estimationLimits != null)
        estimatedValues = estimatedValues.LimitToRange(estimationLimits.Lower, estimationLimits.Upper);
      OnlineCalculatorError error;
      var r = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out error);
      var newQuality = error == OnlineCalculatorError.None ? r * r : double.NaN;
      impacts[v] = originalQuality - newQuality;
    } finally {
      // The dataset is shared between calls: always put the original column back,
      // even when the model evaluation throws.
      dataset.ReplaceVariable(v, originalValues[v]);
    }
  }
  return impacts;
}
/// <summary>
/// Collects variable impacts per target variable from results already stored in the runs.
/// </summary>
/// <param name="runs">Runs with a "ProblemData" parameter of type <see cref="IRegressionProblemData"/>.</param>
/// <param name="qualityResultName">Name of the <see cref="DoubleValue"/> result used to rank runs.</param>
/// <param name="maximization">Whether a larger quality value is better.</param>
/// <param name="impactsResultName">Name of the <see cref="DoubleMatrix"/> result holding the impacts (row names = variables, column 0 = impact).</param>
/// <param name="useBestRunsPerTarget">
/// When <c>true</c> the impacts of the best run per target are used;
/// otherwise the impacts of all runs of a target are averaged.
/// </param>
/// <returns>For each target variable: the runs that were used and the impact per variable.</returns>
public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs,
  string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
  var targets = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
  var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
  if (useBestRunsPerTarget) {
    // runs are ordered by ascending quality: the best one is last when maximizing, first otherwise
    var bestRunsPerTarget = maximization
      ? targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).Last())
      : targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).First());
    foreach (var run in bestRunsPerTarget) {
      var pd = (IRegressionProblemData)run.Parameters["ProblemData"];
      var target = pd.TargetVariable;
      var impacts = (DoubleMatrix)run.Results[impactsResultName];
      var impactsByVariable = impacts.RowNames
        .Select((x, i) => new { Name = x, Index = i })
        .ToDictionary(x => x.Name, x => impacts[x.Index, 0]);
      targetImpacts[target] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(new[] { run }, impactsByVariable);
    }
  } else {
    foreach (var target in targets) {
      var averageImpacts = CalculateAverageImpacts(new RunCollection(target), impactsResultName);
      targetImpacts[target.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(target, averageImpacts);
    }
  }
  return targetImpacts;
}
/// <summary>
/// Builds a variable interaction network from per-target variable impacts.
/// </summary>
/// <remarks>
/// Targets without any impact entry are skipped. A target with a single input gets a
/// direct arc input → target. A target with several inputs gets a junction node that
/// carries an ensemble of the target's runs: every input is connected to the junction
/// (arc weight = impact) and the junction is connected to the target (arc weight = sum
/// of the junction's incoming arc weights, label = training R² of the ensemble).
/// </remarks>
/// <param name="targetImpacts">For each target variable: the runs it was modeled by and the impact per input variable.</param>
/// <returns>The network containing one node per variable plus the junction nodes.</returns>
public static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
  // Lookup of all nodes created so far; variable nodes are keyed by variable name,
  // junction nodes by the concatenation of their inputs and target.
  var nodes = new Dictionary<string, IVertex>();
  var vn = new VariableInteractionNetwork();
  foreach (var ti in targetImpacts) {
    var target = ti.Key;
    var targetRuns = ti.Value.Item1;
    var variableImpacts = ti.Value.Item2;
    var variables = variableImpacts.Keys.ToList();
    if (variables.Count == 0) continue;

    var targetNode = GetOrAddVariableNode(vn, nodes, target);

    if (variables.Count > 1) {
      var junctionLabel = Concatenate(new List<string>(variables) { target });
      IVertex junctionNode;
      if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
        var solutionsEnsemble = CreateEnsembleSolution(targetRuns);
        junctionNode = new JunctionNetworkNode { Label = solutionsEnsemble.TrainingRSquared.ToString("N3", CultureInfo.CurrentCulture), Data = solutionsEnsemble };
        vn.AddVertex(junctionNode);
        nodes[junctionLabel] = junctionNode;
      }
      foreach (var v in variables) {
        var impact = variableImpacts[v];
        var variableNode = GetOrAddVariableNode(vn, nodes, v);
        vn.AddArc(new Arc(variableNode, junctionNode) { Weight = impact, Label = impact.ToString("N3", CultureInfo.CurrentCulture) });
      }
      var trainingR2 = ((IRegressionSolution)((JunctionNetworkNode)junctionNode).Data).TrainingRSquared;
      // the junction's in-arcs must already exist here because their weights are summed up
      vn.AddArc(new Arc(junctionNode, targetNode) { Weight = junctionNode.InArcs.Sum(x => x.Weight), Label = trainingR2.ToString("N3", CultureInfo.CurrentCulture) });
    } else {
      foreach (var v in variables) {
        var impact = variableImpacts[v];
        var variableNode = GetOrAddVariableNode(vn, nodes, v);
        vn.AddArc(new Arc(variableNode, targetNode) { Weight = impact, Label = impact.ToString("N3", CultureInfo.CurrentCulture) });
      }
    }
  }
  return vn;
}

// Returns the node registered under the given label, creating a variable node
// (and adding it to the network and the lookup) when none exists yet.
private static IVertex GetOrAddVariableNode(VariableInteractionNetwork network, Dictionary<string, IVertex> nodes, string label) {
  IVertex node;
  if (!nodes.TryGetValue(label, out node)) {
    node = new VariableNetworkNode { Label = label };
    network.AddVertex(node);
    nodes[label] = node;
  }
  return node;
}
/// <summary>
/// Creates a copy of the network that only contains arcs whose weight is at least
/// <paramref name="threshold"/>, together with the vertices those arcs connect.
/// </summary>
/// <param name="originalNetwork">The network to filter; it is not modified.</param>
/// <param name="threshold">Minimum arc weight (inclusive) for an arc to be kept.</param>
/// <returns>
/// The filtered copy, or <paramref name="originalNetwork"/> itself when no arc reaches the
/// threshold or when the filtered copy ends up without vertices.
/// </returns>
public static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
  var keptArcs = originalNetwork.Arcs.Where(a => a.Weight >= threshold).ToList();
  if (keptArcs.Count == 0) return originalNetwork;

  // One cloner for vertices and arcs, so the cloned arcs are attached to the cloned vertices.
  var cloner = new Cloner();
  var result = new VariableInteractionNetwork();
  var endpoints = keptArcs.SelectMany(a => new[] { a.Source, a.Target });
  result.AddVertices(endpoints.Select(cloner.Clone).Distinct()); // arcs are not cloned
  result.AddArcs(keptArcs.Select(a => (IArc)a.Clone(cloner)));

  // First drop junctions that lost all of their inputs ...
  var junctionsWithoutInputs = result.Vertices.Where(n => n is JunctionNetworkNode && n.InDegree == 0).ToList();
  result.RemoveVertices(junctionsWithoutInputs);
  // ... then every vertex that is no longer connected to anything.
  var disconnected = result.Vertices.Where(n => n.Degree == 0).ToList();
  result.RemoveVertices(disconnected);

  if (result.Vertices.Any()) return result;
  return originalNetwork;
}
/// <summary>
/// Averages the "Best training solution quality" result over all runs.
/// </summary>
/// <param name="runs">Runs that all model the same target with the same inputs.</param>
/// <returns>The mean of the quality values.</returns>
/// <exception cref="ArgumentException">A run uses a different target or different inputs than the first run.</exception>
private static double CalculateAverageQuality(RunCollection runs) {
  var reference = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
  var referenceTarget = reference.TargetVariable;
  var referenceInputs = reference.AllowedInputVariables;

  // every run has to agree with the first one on target and inputs
  foreach (var run in runs) {
    var current = (IRegressionProblemData)run.Parameters["ProblemData"];
    if (referenceTarget != current.TargetVariable || !referenceInputs.SequenceEqual(current.AllowedInputVariables)) {
      throw new ArgumentException("All runs must have the same target and inputs.");
    }
  }

  return runs.Average(run => ((DoubleValue)run.Results["Best training solution quality"]).Value);
}
/// <summary>
/// Averages the variable impacts stored in the runs under the given result name.
/// </summary>
/// <param name="runs">Runs that all model the same target with the same allowed inputs.</param>
/// <param name="resultName">
/// Name of a <see cref="DoubleMatrix"/> result whose row names are variable names and whose
/// first column holds the impacts.
/// </param>
/// <returns>The average impact for every allowed input variable of the first run.</returns>
/// <exception cref="ArgumentException">A run uses a different target or different inputs than the first run.</exception>
public static Dictionary<string, double> CalculateAverageImpacts(RunCollection runs, string resultName) {
  var pd = (IRegressionProblemData)runs.First().Parameters["ProblemData"];
  var target = pd.TargetVariable;
  var inputs = pd.AllowedInputVariables.ToList();
  var impacts = inputs.ToDictionary(x => x, x => 0d);
  // check if all the runs have the same target and same inputs
  if (!runs.All(x => {
    var problemData = (IRegressionProblemData)x.Parameters["ProblemData"];
    return target == problemData.TargetVariable && inputs.SequenceEqual(problemData.AllowedInputVariables);
  })) {
    throw new ArgumentException("All runs must have the same target and inputs.");
  }
  // sum up the impacts; row i of the matrix belongs to the i-th row name
  foreach (var run in runs) {
    var impactsMatrix = (DoubleMatrix)run.Results[resultName];
    int i = 0;
    foreach (var v in impactsMatrix.RowNames) {
      impacts[v] += impactsMatrix[i, 0];
      ++i;
    }
  }
  var runCount = runs.Count;
  foreach (var v in inputs) {
    impacts[v] /= runCount;
  }
  return impacts;
}
/// <summary>
/// Joins the given strings in order, without any separator.
/// </summary>
/// <param name="strings">The strings to join; <c>null</c> elements contribute nothing.</param>
/// <returns>The concatenation, or the empty string for an empty sequence.</returns>
private static string Concatenate(IEnumerable<string> strings) {
  return string.Concat(strings);
}
/// <summary>
/// Replaces the shapes registered on the graph chart: variable nodes are drawn as
/// white ellipses, junction nodes as dark gray rectangles, both with an empty label.
/// </summary>
private void ConfigureNodeShapes() {
  graphChart.ClearShapes();
  var labelFont = new Font(FontFamily.GenericSansSerif, 12);
  var origin = new PointD(0, 0);

  var variableShape = new Ellipse(graphChart.Chart, origin, new PointD(30, 30), Pens.Black, Brushes.White);
  graphChart.AddShape(typeof(VariableNetworkNode), new LabeledPrimitive(variableShape, "", labelFont));

  var junctionShape = new Rectangle(graphChart.Chart, new PointD(0, 0), new PointD(15, 15), Pens.Black, Brushes.DarkGray);
  graphChart.AddShape(typeof(JunctionNetworkNode), new LabeledPrimitive(junctionShape, "", labelFont));
}
#region events
/// <summary>
/// Refills the result name combo boxes from the first run of the new content and,
/// when both boxes offer at least one entry, rebuilds the network.
/// </summary>
/// <remarks>
/// The impact box lists every <see cref="DoubleMatrix"/> result whose row names are all
/// double variables of the run's dataset; the quality box lists every
/// <see cref="DoubleValue"/> result. Both boxes are left empty when there is no content
/// or the content has no runs.
/// </remarks>
protected override void OnContentChanged() {
  base.OnContentChanged();
  impactResultNameComboBox.Items.Clear();
  qualityResultNameComboBox.Items.Clear();

  // Without content, or without a single run, there is nothing to read the result names from.
  var run = Content == null ? null : Content.FirstOrDefault();
  if (run == null) return;

  var pd = (IRegressionProblemData)run.Parameters["ProblemData"];
  var variables = new HashSet<string>(pd.Dataset.DoubleVariables);
  foreach (var result in run.Results.Where(x => x.Value is DoubleMatrix)) {
    var m = (DoubleMatrix)result.Value;
    if (m.RowNames.All(x => variables.Contains(x)))
      impactResultNameComboBox.Items.Add(result.Key);
  }
  foreach (var result in run.Results.Where(x => x.Value is DoubleValue)) {
    qualityResultNameComboBox.Items.Add(result.Key);
  }
  // preselect the first entry of each box
  if (impactResultNameComboBox.Items.Count > 0) {
    impactResultNameComboBox.Text = (string)impactResultNameComboBox.Items[0];
  }
  if (qualityResultNameComboBox.Items.Count > 0) {
    qualityResultNameComboBox.Text = (string)qualityResultNameComboBox.Items[0];
  }
  if (impactResultNameComboBox.Items.Count > 0 && qualityResultNameComboBox.Items.Count > 0)
    NetworkConfigurationChanged(this, EventArgs.Empty);
}
/// <summary>
/// Validates that the text of the sending text box parses as a <see cref="double"/>.
/// On failure the validation is cancelled, the whole text is selected and an error
/// icon with an explanation is attached to the text box.
/// </summary>
private void TextBoxValidating(object sender, CancelEventArgs e) {
  var textBox = (TextBox)sender;
  double parsed;
  if (double.TryParse(textBox.Text, out parsed)) return;

  e.Cancel = true;
  textBox.Select(0, textBox.Text.Length);
  // Set the ErrorProvider error with the text to display.
  errorProvider.SetError(textBox, "Could not parse the entered value. Please input a real number.");
  errorProvider.BlinkStyle = ErrorBlinkStyle.NeverBlink;
  errorProvider.SetIconPadding(textBox, -20);
}
/// <summary>
/// Clears the validation error of the sending text box and re-filters the current
/// network with the entered impact threshold (0.2 when the text cannot be parsed).
/// </summary>
private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
  var tb = (TextBox)sender;
  errorProvider.SetError(tb, string.Empty);
  // No network has been built yet, so there is nothing to filter.
  if (variableInteractionNetwork == null) return;
  double impact;
  if (!double.TryParse(tb.Text, out impact)) {
    impact = 0.2;
  }
  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
}
/// <summary>
/// Clears the validation error of the sending text box and applies the layout settings.
/// </summary>
private void LayoutConfigurationBoxValidated(object sender, EventArgs e) {
  var validatedBox = (TextBox)sender;
  errorProvider.SetError(validatedBox, string.Empty);
  LayoutConfigurationChanged(sender, e);
}
/// <summary>
/// Rebuilds the network from the impacts stored in the runs, using the currently
/// selected result names, aggregation mode and optimization direction, and shows
/// it filtered by the current threshold.
/// </summary>
/// <remarks>
/// Nothing happens while there is no content or one of the result names is empty.
/// </remarks>
private void NetworkConfigurationChanged(object sender, EventArgs e) {
  if (Content == null) return;
  var qualityResultName = qualityResultNameComboBox.Text;
  var impactsResultName = impactResultNameComboBox.Text;
  if (string.IsNullOrEmpty(qualityResultName) || string.IsNullOrEmpty(impactsResultName))
    return;
  // index 0 (or no selection) means "use the best run per target"
  var useBest = impactAggregationComboBox.SelectedIndex <= 0;
  var maximization = maximizationCheckBox.Checked;
  // Same mapping from track bar position to threshold as used for the threshold label,
  // so the displayed value and the applied value cannot diverge.
  var threshold = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
  var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
  variableInteractionNetwork = CreateNetwork(impacts);
  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
}
/// <summary>
/// Applies the selected edge routing mode and the entered ideal edge length to the
/// graph chart and redraws it. Does nothing when neither setting changed.
/// </summary>
/// <exception cref="ArgumentException">The edge routing combo box has no valid selection.</exception>
private void LayoutConfigurationChanged(object sender, EventArgs e) {
  // combo box index -> routing mode
  var routingModes = new[] {
    ConstrainedForceDirectedLayout.EdgeRouting.None,
    ConstrainedForceDirectedLayout.EdgeRouting.Polyline,
    ConstrainedForceDirectedLayout.EdgeRouting.Orthogonal
  };
  var selected = edgeRoutingComboBox.SelectedIndex;
  if (selected < 0 || selected >= routingModes.Length) {
    throw new ArgumentException("Invalid edge routing mode.");
  }
  var requestedMode = routingModes[selected];
  var requestedLength = double.Parse(idealEdgeLengthTextBox.Text);

  var modeUnchanged = requestedMode == graphChart.RoutingMode;
  if (modeUnchanged && requestedLength.IsAlmost(graphChart.DefaultEdgeLength)) return;

  graphChart.RoutingMode = requestedMode;
  graphChart.PerformEdgeRouting = requestedMode != ConstrainedForceDirectedLayout.EdgeRouting.None;
  graphChart.DefaultEdgeLength = requestedLength;
  graphChart.Draw();
}
/// <summary>
/// Enables or disables all configuration controls of the view at once.
/// </summary>
/// <param name="enabled">The new <c>Enabled</c> state of the controls.</param>
private void ControlsEnable(bool enabled) {
  maximizationCheckBox.Enabled = enabled;
  idealEdgeLengthTextBox.Enabled = enabled;
  edgeRoutingComboBox.Enabled = enabled;
  onlineImpactCalculationButton.Enabled = enabled;
  impactThresholdTrackBar.Enabled = enabled;
  impactAggregationComboBox.Enabled = enabled;
  impactResultNameComboBox.Enabled = enabled;
  qualityResultNameComboBox.Enabled = enabled;
}
/// <summary>
/// Recalculates the variable impacts from the regression solutions of the runs on a
/// background thread and shows the resulting network, filtered by the current threshold.
/// </summary>
/// <remarks>
/// Only the calculation itself runs on the worker thread. Controls are disabled before
/// the worker starts and every other control access happens in the completion handler,
/// because WinForms controls must not be touched from the worker thread.
/// </remarks>
private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
  var runs = Content;
  if (runs == null) return;

  ControlsEnable(false);
  var worker = new BackgroundWorker();
  worker.DoWork += (o, e) => {
    // no UI access in here
    e.Result = CreateNetwork(CalculateVariableImpactsOnline(runs, false));
  };
  worker.RunWorkerCompleted += (o, e) => {
    worker.Dispose();
    if (IsDisposed) return; // the view was closed while the calculation was running
    ControlsEnable(true);
    if (e.Error != null) {
      // report the failure instead of silently dropping it
      MessageBox.Show(this, e.Error.Message, "Variable impact calculation failed", MessageBoxButtons.OK, MessageBoxIcon.Error);
      return;
    }
    variableInteractionNetwork = (VariableInteractionNetwork)e.Result;
    var threshold = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
    graphChart.Graph = ApplyThreshold(variableInteractionNetwork, threshold);
  };
  worker.RunWorkerAsync();
}
/// <summary>
/// Redraws the graph chart on request of the user.
/// </summary>
private void relayoutGraphButton_Click(object sender, EventArgs e) {
  this.graphChart.Draw();
}
#endregion
/// <summary>
/// Exports the impacts of the displayed graph as a square matrix and shows it as new content.
/// </summary>
/// <remarks>
/// Rows and columns are the alphabetically sorted variable node labels. For every junction,
/// entry [target, input] receives the weight of the arc input → junction. The diagonal is set to 1.
/// Only impacts that are routed through a junction node are exported.
/// </remarks>
private void exportImpactsMatrixButton_Click(object sender, EventArgs e) {
  var graph = graphChart.Graph;
  if (graph == null) return; // nothing is displayed yet
  var labels = graph.Vertices.OfType<VariableNetworkNode>().Select(x => x.Label).ToList();
  labels.Sort(); // sort variables alphabetically
  var matrix = new DoubleMatrix(labels.Count, labels.Count) { RowNames = labels, ColumnNames = labels };
  var indexes = labels.Select((x, i) => new { Label = x, Index = i }).ToDictionary(x => x.Label, x => x.Index);
  foreach (var jn in graph.Vertices.OfType<JunctionNetworkNode>()) {
    // Thresholding can remove the junction → target arc while inputs remain;
    // such a junction has no target to attribute the impacts to.
    var outArc = jn.OutArcs.FirstOrDefault();
    if (outArc == null) continue;
    var targetIndex = indexes[outArc.Target.Label];
    foreach (var input in jn.InArcs) {
      var inputIndex = indexes[input.Source.Label];
      matrix[targetIndex, inputIndex] = input.Weight;
    }
  }
  for (int i = 0; i < labels.Count; ++i) matrix[i, i] = 1;
  MainFormManager.MainForm.ShowContent(matrix);
}
/// <summary>
/// Shows the threshold that corresponds to the track bar position and re-filters the
/// current network with it.
/// </summary>
private void impactThresholdTrackBar_ValueChanged(object sender, EventArgs e) {
  var impact = impactThresholdTrackBar.Minimum + (double)impactThresholdTrackBar.Value / impactThresholdTrackBar.Maximum;
  impactThresholdLabel.Text = impact.ToString("N3", CultureInfo.CurrentCulture);
  // The track bar can be moved before any network has been built; only the label is updated then.
  if (variableInteractionNetwork == null) return;
  graphChart.Graph = ApplyThreshold(variableInteractionNetwork, impact);
}
}
}