Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/18/16 13:35:50 (8 years ago)
Author:
bburlacu
Message:

#2288: Refactor the RunCollectionVariableInteractionNetworkView and add online impact calculation (optimized method inside the view performs faster than the equivalent method provided by the RegressionSolutionVariableImpactsCalculator).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks.Views/3.3/RunCollectionVariableInteractionNetworkView.cs

    r13727 r13773  
    3535using HeuristicLab.Problems.DataAnalysis;
    3636using HeuristicLab.Visualization;
     37using Ellipse = HeuristicLab.Visualization.Ellipse;
    3738using Rectangle = HeuristicLab.Visualization.Rectangle;
    3839
     
    5657    }
    5758
    58     private static VariableInteractionNetwork BuildNetworkFromSolutionQualities(RunCollection runs, double threshold, bool useBestRunsPerTarget = false) {
    59       var nodes = new Dictionary<string, IVertex>();
    60       var vn = new VariableInteractionNetwork();
     59    private VariableInteractionNetwork variableInteractionNetwork;
     60
     61    private static void AssertSameProblemData(RunCollection runs) {
     62      IDataset dataset = null;
     63      IRegressionProblemData problemData = null;
     64      foreach (var run in runs) {
     65        var solution = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
     66        var ds = solution.ProblemData.Dataset;
     67
     68        if (solution.ProblemData == problemData) continue;
     69        if (ds == dataset) continue;
     70        if (problemData == null) {
     71          problemData = solution.ProblemData;
     72          continue;
     73        }
     74        if (dataset == null) {
     75          dataset = ds;
     76          continue;
     77        }
     78
     79        if (problemData.TrainingPartition.Start != solution.ProblemData.TrainingPartition.Start || problemData.TrainingPartition.End != solution.ProblemData.TrainingPartition.End)
     80          throw new InvalidOperationException("The runs must share the same data.");
     81
     82        if (!ds.DoubleVariables.SequenceEqual(dataset.DoubleVariables))
     83          throw new InvalidOperationException("The runs must share the same data.");
     84
     85        foreach (var v in ds.DoubleVariables) {
     86          var values1 = (IList<double>)ds.GetReadOnlyDoubleValues(v);
     87          var values2 = (IList<double>)dataset.GetReadOnlyDoubleValues(v);
     88
     89          if (values1.Count != values2.Count)
     90            throw new InvalidOperationException("The runs must share the same data.");
     91
     92          if (!values1.SequenceEqual(values2))
     93            throw new InvalidOperationException("The runs must share the same data.");
     94        }
     95      }
     96    }
     97
     98    private static RegressionEnsembleSolution CreateEnsembleSolution(IEnumerable<IRun> runs) {
     99      var solutions = runs.Select(x => x.Results.Values.Single(v => v is IRegressionSolution)).Cast<IRegressionSolution>();
     100      return new RegressionEnsembleSolution(new RegressionEnsembleModel(solutions.Select(x => x.Model)), solutions.First().ProblemData);
     101    }
     102
     103    public static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsOnline(RunCollection runs, bool useBest) {
     104      AssertSameProblemData(runs);
     105      var solution = (IRegressionSolution)runs.First().Results.Values.Single(x => x is IRegressionSolution);
     106      var dataset = (Dataset)solution.ProblemData.Dataset;
     107      var originalValues = dataset.DoubleVariables.ToDictionary(x => x, x => dataset.GetReadOnlyDoubleValues(x).ToList());
     108      var md = dataset.ToModifiable();
     109      var medians = new Dictionary<string, List<double>>();
     110      foreach (var v in dataset.DoubleVariables) {
     111        var median = dataset.GetDoubleValues(v, solution.ProblemData.TrainingIndices).Median();
     112        medians[v] = Enumerable.Repeat(median, originalValues[v].Count).ToList();
     113      }
     114
     115      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
     116
     117      if (useBest) {
     118        // build network using only the best run for each target
     119      } else {
     120        var groups = runs.GroupBy(run => {
     121          var sol = (IRegressionSolution)run.Results.Values.Single(x => x is IRegressionSolution);
     122          return Concatenate(sol.ProblemData.AllowedInputVariables) + sol.ProblemData.TargetVariable;
     123        });
     124
     125        foreach (var group in groups) {
     126          // calculate average impacts
     127          var averageImpacts = new Dictionary<string, double>();
     128          solution = (IRegressionSolution)group.First().Results.Values.Single(x => x is IRegressionSolution);
     129          foreach (var run in group) {
     130            var sol = (IRegressionSolution)run.Results.Values.Single(v => v is IRegressionSolution);
     131
     132            DoubleLimit estimationLimits = null;
     133            if (run.Parameters.ContainsKey("EstimationLimits")) {
     134              estimationLimits = (DoubleLimit)run.Parameters["EstimationLimits"];
     135            }
     136            var impacts = CalculateImpacts(sol, md, originalValues, medians, estimationLimits);
     137            //            var impacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(sol).ToDictionary(x => x.Item1, x => x.Item2);
     138            foreach (var pair in impacts) {
     139              if (averageImpacts.ContainsKey(pair.Key))
     140                averageImpacts[pair.Key] += pair.Value;
     141              else {
     142                averageImpacts[pair.Key] = pair.Value;
     143              }
     144            }
     145          }
     146          var count = group.Count();
     147          var keys = averageImpacts.Keys.ToList();
     148          foreach (var v in keys) {
     149            averageImpacts[v] /= count;
     150          }
     151
     152          targetImpacts[solution.ProblemData.TargetVariable] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(group, averageImpacts);
     153        }
     154      }
     155      return targetImpacts;
     156    }
     157
     158    private static Dictionary<string, double> CalculateImpacts(IRegressionSolution solution, ModifiableDataset dataset,
     159      Dictionary<string, List<double>> originalValues, Dictionary<string, List<double>> medianValues, DoubleLimit estimationLimits = null) {
     160      var impacts = new Dictionary<string, double>();
     161
     162      var model = solution.Model;
     163      var pd = solution.ProblemData;
     164
     165      var rows = pd.TrainingIndices.ToList();
     166      var targetValues = pd.Dataset.GetDoubleValues(pd.TargetVariable, rows).ToList();
     167
     168
     169      foreach (var v in pd.AllowedInputVariables) {
     170        dataset.ReplaceVariable(v, medianValues[v]);
     171
     172        var estimatedValues = model.GetEstimatedValues(dataset, rows);
     173        if (estimationLimits != null)
     174          estimatedValues = estimatedValues.LimitToRange(estimationLimits.Lower, estimationLimits.Upper);
     175
     176        OnlineCalculatorError error;
     177        var r = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out error);
     178        var newQuality = error == OnlineCalculatorError.None ? r * r : double.NaN;
     179        var originalQuality = solution.TrainingRSquared;
     180        impacts[v] = originalQuality - newQuality;
     181
     182        dataset.ReplaceVariable(v, originalValues[v]);
     183      }
     184      return impacts;
     185    }
     186
     187    private static Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> CalculateVariableImpactsFromRunResults(RunCollection runs,
     188      string qualityResultName, bool maximization, string impactsResultName, bool useBestRunsPerTarget = false) {
    61189      var targets = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
    62       var targetQualities = new Dictionary<string, double>();
    63       var targetInputs = new Dictionary<string, List<string>>();
    64 
     190      var targetImpacts = new Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>>();
    65191      if (useBestRunsPerTarget) {
    66         foreach (var target in targets) {
    67           var bestRun = target.OrderBy(x => ((DoubleValue)x.Results["Best training solution quality"]).Value).First();
    68           var bestQuality = ((DoubleValue)bestRun.Results["Best training solution quality"]).Value;
    69           var pd = (IRegressionProblemData)bestRun.Parameters["ProblemData"];
    70           if (threshold > bestQuality) continue; // skip if quality is below the treshold
    71           targetQualities[target.Key] = bestQuality;
    72           targetInputs[target.Key] = pd.AllowedInputVariables.ToList();
    73         }
    74       } else {
    75         foreach (var target in targets) {
    76           var avgQuality = CalculateAverageQuality(new RunCollection(target));
    77           if (threshold > avgQuality) continue;
    78           targetQualities[target.Key] = avgQuality;
    79           var pd = (IRegressionProblemData)target.First().Parameters["ProblemData"];
    80           targetInputs[target.Key] = pd.AllowedInputVariables.ToList();
    81         }
    82       }
    83 
    84       foreach (var ti in targetQualities) {
    85         var target = ti.Key;
    86         var variables = targetInputs[ti.Key];
    87         var quality = ti.Value;
    88         IVertex targetNode;
    89 
    90         if (!nodes.TryGetValue(target, out targetNode)) {
    91           targetNode = new VariableNetworkNode { Label = target };
    92           vn.AddVertex(targetNode);
    93           nodes[target] = targetNode;
    94         }
    95 
    96         IVertex variableNode;
    97         if (variables.Count > 0) {
    98           var variableList = new List<string>(variables);
    99           variableList.Add(target);
    100           var junctionLabel = Concatenate(variableList);
    101           IVertex junctionNode;
    102           if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
    103             junctionNode = new JunctionNetworkNode { Label = string.Empty };
    104             vn.AddVertex(junctionNode);
    105             nodes[junctionLabel] = junctionNode;
    106           }
    107           IArc arc;
    108           foreach (var v in variables) {
    109             var impact = quality;
    110             if (!nodes.TryGetValue(v, out variableNode)) {
    111               variableNode = new VariableNetworkNode { Label = v };
    112               vn.AddVertex(variableNode);
    113               nodes[v] = variableNode;
    114             }
    115             arc = new Arc(variableNode, junctionNode) { Weight = impact };
    116             vn.AddArc(arc);
    117           }
    118           arc = new Arc(junctionNode, targetNode) { Weight = junctionNode.InArcs.Sum(x => x.Weight) };
    119           vn.AddArc(arc);
    120         } else {
    121           foreach (var v in variables) {
    122             var impact = quality;
    123             if (!nodes.TryGetValue(v, out variableNode)) {
    124               variableNode = new VariableNetworkNode { Label = v };
    125               vn.AddVertex(variableNode);
    126               nodes[v] = variableNode;
    127             }
    128             var arc = new Arc(variableNode, targetNode) { Weight = impact };
    129             vn.AddArc(arc);
    130           }
    131         }
    132       }
    133 
    134       return vn;
    135     }
    136 
    137     private static VariableInteractionNetwork BuildNetworkFromVariableImpacts(RunCollection runs, string qualityResultName, bool maximization, string impactsResultName, double threshold, bool useBestRunsPerTarget = false) {
    138       var nodes = new Dictionary<string, IVertex>();
    139       var vn = new VariableInteractionNetwork();
    140       var targets = runs.GroupBy(x => ((IRegressionProblemData)x.Parameters["ProblemData"]).TargetVariable).ToList();
    141 
    142       var targetImpacts = new Dictionary<string, Dictionary<string, double>>();
    143 
    144       if (useBestRunsPerTarget) {
    145         var bestRunsPerTarget = maximization ?
    146           targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).Last()) :
    147           targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).First());
     192        var bestRunsPerTarget = maximization
     193          ? targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).Last())
     194          : targets.Select(x => x.OrderBy(y => ((DoubleValue)y.Results[qualityResultName]).Value).First());
    148195
    149196        foreach (var run in bestRunsPerTarget) {
     
    151198          var target = pd.TargetVariable;
    152199          var impacts = (DoubleMatrix)run.Results[impactsResultName];
    153           targetImpacts[target] = impacts.RowNames.Select((x, i) => new { Name = x, Index = i }).ToDictionary(x => x.Name, x => impacts[x.Index, 0]);
     200          targetImpacts[target] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(new[] { run }, impacts.RowNames.Select((x, i) => new { Name = x, Index = i }).ToDictionary(x => x.Name, x => impacts[x.Index, 0]));
    154201        }
    155202      } else {
    156203        foreach (var target in targets) {
    157204          var averageImpacts = CalculateAverageImpacts(new RunCollection(target), impactsResultName);
    158           targetImpacts[target.Key] = averageImpacts;
    159         }
    160       }
    161 
     205          targetImpacts[target.Key] = new Tuple<IEnumerable<IRun>, Dictionary<string, double>>(target, averageImpacts);
     206        }
     207      }
     208      return targetImpacts;
     209    }
     210
     211    private static VariableInteractionNetwork CreateNetwork(Dictionary<string, Tuple<IEnumerable<IRun>, Dictionary<string, double>>> targetImpacts) {
     212      var nodes = new Dictionary<string, IVertex>();
     213      var vn = new VariableInteractionNetwork();
    162214      foreach (var ti in targetImpacts) {
    163215        var target = ti.Key;
    164         var variableImpacts = ti.Value;
     216        var variableImpacts = ti.Value.Item2;
     217        var targetRuns = ti.Value.Item1;
    165218        IVertex targetNode;
    166219
    167         var variables = variableImpacts.Keys.Where(x => variableImpacts[x] >= threshold).ToList();
     220        var variables = variableImpacts.Keys.ToList();
    168221        if (variables.Count == 0) continue;
    169222
     
    180233          IVertex junctionNode;
    181234          if (!nodes.TryGetValue(junctionLabel, out junctionNode)) {
    182             junctionNode = new JunctionNetworkNode { Label = string.Empty };
     235            junctionNode = new JunctionNetworkNode { Label = string.Empty, Data = CreateEnsembleSolution(targetRuns) };
    183236            vn.AddVertex(junctionNode);
    184237            nodes[junctionLabel] = junctionNode;
     
    315368    }
    316369
    317     private void NetworkConfigurationBoxValidated(object sender, EventArgs e) {
     370    private void ImpactThresholdTextBoxValidated(object sender, EventArgs e) {
    318371      var tb = (TextBox)sender;
    319372      errorProvider.SetError(tb, string.Empty);
    320       NetworkConfigurationChanged(sender, e);
     373      var network = ApplyThreshold(variableInteractionNetwork, double.Parse(tb.Text));
     374      graphChart.Graph = network;
     375    }
     376
     377    private static VariableInteractionNetwork ApplyThreshold(VariableInteractionNetwork originalNetwork, double threshold) {
     378      var arcs = originalNetwork.Arcs.Where(x => x.Weight >= threshold).ToList();
     379      if (!arcs.Any()) return originalNetwork;
     380      var filteredNetwork = new VariableInteractionNetwork();
     381      var cloner = new Cloner();
     382      var vertices = arcs.SelectMany(x => new[] { x.Source, x.Target }).Select(cloner.Clone); // arcs are not cloned
     383      filteredNetwork.AddVertices(vertices);
     384      foreach (var arc in arcs) {
     385        var source = cloner.Clone(arc.Source);
     386        var target = cloner.Clone(arc.Target);
     387        filteredNetwork.AddArc(source, target);
     388      }
     389      var unusedJunctions = filteredNetwork.Vertices.Where(x => x.InDegree == 0 && x is JunctionNetworkNode).ToList();
     390      filteredNetwork.RemoveVertices(unusedJunctions);
     391      var orphanedNodes = filteredNetwork.Vertices.Where(x => x.Degree == 0).ToList();
     392      filteredNetwork.RemoveVertices(orphanedNodes);
     393      return filteredNetwork;
    321394    }
    322395
     
    335408        return;
    336409      var maximization = maximizationCheckBox.Checked;
    337       var network = BuildNetworkFromVariableImpacts(Content, qualityResultName, maximization, impactsResultName, threshold, useBest);
    338       if (network.Vertices.Any())
    339         graphChart.Graph = network;
     410      var impacts = CalculateVariableImpactsFromRunResults(Content, qualityResultName, maximization, impactsResultName, useBest);
     411      variableInteractionNetwork = CreateNetwork(impacts);
     412      var network = ApplyThreshold(variableInteractionNetwork, threshold);
     413      graphChart.Graph = network;
    340414    }
    341415
     
    362436      graphChart.Draw();
    363437    }
     438
     439    private void onlineImpactCalculationButton_Click(object sender, EventArgs args) {
     440      var button = (Button)sender;
     441      var worker = new BackgroundWorker();
     442      worker.DoWork += (o, e) => {
     443        button.Enabled = false;
     444        var impacts = CalculateVariableImpactsOnline(Content, false);
     445        var network = CreateNetwork(impacts);
     446        var threshold = double.Parse(impactThresholdTextBox.Text);
     447        graphChart.Graph = ApplyThreshold(network, threshold);
     448      };
     449      worker.RunWorkerCompleted += (o, e) => button.Enabled = true;
     450      worker.RunWorkerAsync();
     451    }
    364452    #endregion
    365453  }
Note: See TracChangeset for help on using the changeset viewer.