Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2288_HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks/3.3/VariableInteractionNetwork.cs @ 16864

Last change on this file since 16864 was 16864, checked in by gkronber, 5 years ago

#2288: updated to .NET 4.6.1 and new persistence backend for compatibility with current trunk

File size: 19.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Diagnostics.Eventing.Reader;
26using System.Linq;
27using System.Text;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HEAL.Attic;
33
34namespace HeuristicLab.VariableInteractionNetworks {
35  [Item("VariableInteractionNetwork", "A graph representation of variables and their relationships.")]
36  [StorableType("CD0A2622-6872-4E61-BE01-E6AA9E41A60D")]
37  public class VariableInteractionNetwork : DirectedGraph {
38
/// <summary>
/// Creates a simple network from a matrix of variable impacts (each row represents a target variable, each column represents an input variable).
/// At most one row can be defined per target variable (no junction nodes are built, cf. FromNmseAndVariableImpacts(..) for building more complex networks).
/// The network is intended to be acyclic. Values on the diagonal are ignored.
/// The algorithm starts with a network without arcs and, per round, adds the next most relevant input variable for each target variable
/// as long as its impact is larger than the given threshold. After each round cycles are broken by removing the weakest link.
/// </summary>
/// <param name="nmse">Vector of NMSE values, one per variable, in the same order as the rows of <paramref name="variableImpacts"/>. The value is also stored as the weight of the variable's vertex.</param>
/// <param name="variableImpacts">Square matrix of variable impacts (smaller is lower impact). The row names are used as vertex labels, so their number must match the length of <paramref name="nmse"/>.</param>
/// <param name="nmseThreshold">Threshold for NMSE values. Only variables with an NMSE value smaller than the threshold receive incoming arcs; all other variables are considered independent.</param>
/// <param name="varImpactThreshold">Threshold for variable impact values. Only impacts larger than the threshold are represented as arcs.</param>
/// <returns>A new network containing one vertex per variable and weighted arcs from input variables to target variables.</returns>
/// <exception cref="ArgumentNullException"><paramref name="nmse"/> or <paramref name="variableImpacts"/> is null.</exception>
/// <exception cref="ArgumentException">The impact matrix is not square, or the number of NMSE values does not match the number of row names.</exception>
public static VariableInteractionNetwork CreateSimpleNetwork(double[] nmse, DoubleMatrix variableImpacts, double nmseThreshold = 0.2, double varImpactThreshold = 0.0) {
  if (nmse == null) throw new ArgumentNullException(nameof(nmse));
  if (variableImpacts == null) throw new ArgumentNullException(nameof(variableImpacts));
  if (variableImpacts.Rows != variableImpacts.Columns)
    throw new ArgumentException("The matrix of variable impacts must be square.", nameof(variableImpacts));

  string[] varNames = variableImpacts.RowNames.ToArray();
  if (nmse.Length != varNames.Length)
    throw new ArgumentException("The number of NMSE values must match the number of row names of the variable impacts matrix.", nameof(nmse));

  var network = new VariableInteractionNetwork();
  var targets = new Dictionary<string, double>();

  for (int i = 0; i < varNames.Length; i++) {
    var name = varNames[i];
    var varVertex = new VariableNetworkNode() { Label = name, Weight = nmse[i] };
    network.AddVertex(varVertex);
    // only sufficiently accurate models turn their variable into a target that receives incoming arcs
    if (nmse[i] < nmseThreshold) {
      targets.Add(name, nmse[i]);
    }
  }

  // rel is a working copy: AddArcs overwrites every impact it has turned into an arc with negative infinity
  var rel = variableImpacts.CloneAsMatrix();
  // make sure the diagonal is not considered
  for (int i = 0; i < rel.GetLength(0); i++) rel[i, i] = double.NegativeInfinity;

  var addedArcs = AddArcs(network, rel, varNames, targets, varImpactThreshold);
  while (addedArcs.Any()) {
    var cycles = network.FindShortestCycles().ToList();
    while (cycles.Any()) {
      // delete the weakest link over all cycles found, then search again
      var weakestArc = cycles.SelectMany(cycle => network.ArcsForCycle(cycle)).OrderBy(a => a.Weight).First();
      network.RemoveArc(weakestArc);

      cycles = network.FindShortestCycles().ToList();
    }

    addedArcs = AddArcs(network, rel, varNames, targets, varImpactThreshold);
  }

  return network;
}
88
// Performs one round of arc insertion for the simple network: for each row whose variable is a target,
// the largest remaining impact in that row is turned into an arc (input variable -> target variable)
// if it is larger than the threshold. The consumed matrix entry is overwritten with negative infinity
// so that it is not picked again in a later round. Returns the arcs created in this round.
private static List<IArc> AddArcs(VariableInteractionNetwork network, double[,] impacts, string[] varNames, Dictionary<string, double> targets, double threshold = 0.0) {
  var size = impacts.GetLength(0);
  var created = new List<IArc>();

  for (int target = 0; target < size; target++) {
    var targetName = varNames[target];
    if (!targets.ContainsKey(targetName)) continue;

    // copy the row (the matrix is square, so the row count is also the column count)
    var candidates = new double[size];
    for (int input = 0; input < size; input++) candidates[input] = impacts[target, input];

    var strongest = candidates.Max();
    if (!(strongest > threshold)) continue;

    var strongestIdx = Array.IndexOf<double>(candidates, strongest);
    impacts[target, strongestIdx] = double.NegativeInfinity;

    var inputName = varNames[strongestIdx];
    var source = network.Vertices.Single(v => v.Label == inputName);
    var destination = network.Vertices.Single(v => v.Label == targetName);

    var arc = network.AddArc(source, destination);
    arc.Weight = strongest;
    created.Add(arc);
  }

  return created;
}
111
/// <summary>
/// Creates a network from a matrix of variable impacts (each row represents a target variable, each column represents an input variable).
/// Every variable whose NMSE is below the threshold gets a junction node "f_&lt;name&gt;" which collects the incoming arcs and has a single arc to the variable itself.
/// The network is intended to be acyclic. Values on the diagonal are ignored.
/// The algorithm starts with only the function arcs and, per round, adds the next most relevant input variable for each target variable
/// as long as its impact is larger than the given threshold. After each round cycles are broken by removing the weakest link.
/// </summary>
/// <param name="nmse">Vector of NMSE values, one per variable, in the same order as the rows of <paramref name="variableImpacts"/>.</param>
/// <param name="variableImpacts">Square matrix of variable impacts (smaller is lower impact). The row names are used as vertex labels, so their number must match the length of <paramref name="nmse"/>.</param>
/// <param name="nmseThreshold">Threshold for NMSE values. Only variables with an NMSE value smaller than the threshold get a junction node; all other variables are considered independent.</param>
/// <param name="varImpactThreshold">Threshold for variable impact values. Only impacts larger than the threshold are represented as arcs.</param>
/// <returns>A new network containing variable nodes, junction nodes and weighted arcs.</returns>
/// <exception cref="ArgumentNullException"><paramref name="nmse"/> or <paramref name="variableImpacts"/> is null.</exception>
/// <exception cref="ArgumentException">The impact matrix is not square, or the number of NMSE values does not match the number of row names.</exception>
public static VariableInteractionNetwork FromNmseAndVariableImpacts(double[] nmse, DoubleMatrix variableImpacts, double nmseThreshold = 0.2, double varImpactThreshold = 0.0) {
  if (nmse == null) throw new ArgumentNullException(nameof(nmse));
  if (variableImpacts == null) throw new ArgumentNullException(nameof(variableImpacts));
  if (variableImpacts.Rows != variableImpacts.Columns)
    throw new ArgumentException("The matrix of variable impacts must be square.", nameof(variableImpacts));

  string[] varNames = variableImpacts.RowNames.ToArray();
  if (nmse.Length != varNames.Length)
    throw new ArgumentException("The number of NMSE values must match the number of row names of the variable impacts matrix.", nameof(nmse));

  var network = new VariableInteractionNetwork();
  // store vertexes representing the function for each target so we can easily add incoming arcs later on
  var name2funVertex = new Dictionary<string, IVertex>();

  for (int i = 0; i < varNames.Length; i++) {
    var name = varNames[i];
    var varVertex = new VariableNetworkNode() { Label = name };
    network.AddVertex(varVertex);
    if (nmse[i] < nmseThreshold) {
      var functionVertex = new JunctionNetworkNode() { Label = "f_" + name };
      name2funVertex.Add(name, functionVertex);
      network.AddVertex(functionVertex);
      var predArc = network.AddArc(functionVertex, varVertex);
      predArc.Weight = double.PositiveInfinity; // never delete arcs from f_x -> x (representing output of a function)
    }
  }

  // rel is a working copy: AddArcs overwrites every impact it has turned into an arc with negative infinity
  var rel = variableImpacts.CloneAsMatrix();
  // make sure the diagonal is not considered
  for (int i = 0; i < rel.GetLength(0); i++) rel[i, i] = double.NegativeInfinity;

  var addedArcs = AddArcs(network, rel, varNames, name2funVertex, varImpactThreshold);
  while (addedArcs.Any()) {
    var cycles = network.FindShortestCycles().ToList();
    while (cycles.Any()) {
      // delete the weakest link over all cycles found, then search again
      var weakestArc = cycles.SelectMany(cycle => network.ArcsForCycle(cycle)).OrderBy(a => a.Weight).First();
      network.RemoveArc(weakestArc);

      cycles = network.FindShortestCycles().ToList();
    }

    addedArcs = AddArcs(network, rel, varNames, name2funVertex, varImpactThreshold);
  }

  return network;
}
165
/// <summary>
/// Produces a combined network from two networks.
/// The node set of the result is the union of the node sets of both inputs, the edge set is the union of both edge sets.
/// Nodes are matched by label; edges are matched by the labels of their source and target.
/// Nodes and edges found only in <paramref name="from"/> get the label suffix " (removed)", those found only in
/// <paramref name="to"/> get the suffix " (added)", so that the difference can be visualized using graphviz.
/// </summary>
/// <param name="from">The network representing the old state.</param>
/// <param name="to">The network representing the new state.</param>
/// <returns>A new network containing clones of the nodes and newly created edges.</returns>
public static VariableInteractionNetwork CalculateNetworkDiff(
  VariableInteractionNetwork from,
  VariableInteractionNetwork to) {
  var byLabel = new VertexLabelComparer();
  var byEndpoints = new ArcComparer();
  var diff = new VariableInteractionNetwork();

  // finds the node of the diff network carrying either the plain label or the label with the given marker
  Func<string, string, IVertex> findMarked = (label, marker) =>
    diff.Vertices.Single(v => v.Label == label || v.Label == label + marker);

  // nodes present in both networks are taken over unchanged
  foreach (var shared in from.Vertices.Intersect(to.Vertices, byLabel)) {
    diff.AddVertex((IVertex)shared.Clone());
  }

  // nodes which exist only in the from network
  foreach (var dropped in from.Vertices.Except(to.Vertices, byLabel)) {
    var droppedClone = (IVertex)dropped.Clone();
    droppedClone.Label += " (removed)";
    diff.AddVertex(droppedClone);
  }

  // nodes which exist only in the to network
  foreach (var introduced in to.Vertices.Except(from.Vertices, byLabel)) {
    var introducedClone = (IVertex)introduced.Clone();
    introducedClone.Label += " (added)";
    diff.AddVertex(introducedClone);
  }

  // edges present in both networks connect nodes which carry their plain label
  foreach (var sharedArc in from.Arcs.Intersect(to.Arcs, byEndpoints)) {
    var source = diff.Vertices.Single(v => sharedArc.Source.Label == v.Label);
    var target = diff.Vertices.Single(v => sharedArc.Target.Label == v.Label);
    diff.AddArc(source, target);
  }

  // edges which exist only in the from network
  foreach (var droppedArc in from.Arcs.Except(to.Arcs, byEndpoints)) {
    var source = findMarked(droppedArc.Source.Label, " (removed)");
    var target = findMarked(droppedArc.Target.Label, " (removed)");
    var marked = diff.AddArc(source, target);
    marked.Label += " (removed)";
  }

  // edges which exist only in the to network
  foreach (var introducedArc in to.Arcs.Except(from.Arcs, byEndpoints)) {
    var source = findMarked(introducedArc.Source.Label, " (added)");
    var target = findMarked(introducedArc.Target.Label, " (added)");
    var marked = diff.AddArc(source, target);
    marked.Label += " (added)";
  }

  return diff;
}
228
229
// Performs one round of arc insertion for the network with junction nodes: for each row whose variable has an
// associated function (junction) node, the largest remaining impact in that row is turned into an arc
// (input variable -> junction node of the target) if it is larger than the threshold. The consumed matrix
// entry is overwritten with negative infinity so that it is not picked again. Returns the arcs created in this round.
private static List<IArc> AddArcs(VariableInteractionNetwork network, double[,] impacts, string[] varNames, Dictionary<string, IVertex> name2funVertex, double threshold = 0.0) {
  var newArcs = new List<IArc>();
  var size = impacts.GetLength(0);
  for (int row = 0; row < size; row++) {
    IVertex funVertex;
    // this variable does not have an associated function (considered as independent)
    if (!name2funVertex.TryGetValue(varNames[row], out funVertex)) continue;

    // the matrix is square, so the row count is also the column count
    var rowVector = Enumerable.Range(0, size).Select(col => impacts[row, col]).ToArray();
    var max = rowVector.Max();
    if (max > threshold) {
      var idxOfMax = Array.IndexOf<double>(rowVector, max);
      impacts[row, idxOfMax] = double.NegativeInfinity; // edge is not considered anymore
      var srcName = varNames[idxOfMax];
      // Only variable nodes can be the source of an impact. Restricting the lookup prevents a clash
      // between a variable that happens to be named "f_x" and the junction node "f_x" of variable "x".
      var vertex = network.Vertices.OfType<VariableNetworkNode>().Single(v => v.Label == srcName);
      var arc = network.AddArc(vertex, funVertex);
      arc.Weight = max;
      newArcs.Add(arc);
    }
  }
  return newArcs;
}
250
// Constructor used by the persistence backend when an instance is restored.
[StorableConstructor]
public VariableInteractionNetwork(StorableConstructorFlag _) : base(_) { }

// Creates an empty network.
public VariableInteractionNetwork() { }

// Copy constructor used for cloning; all copying is delegated to the base class.
protected VariableInteractionNetwork(VariableInteractionNetwork original, Cloner cloner) : base(original, cloner) { }

// Creates a clone of this network via the copy constructor.
public override IDeepCloneable Clone(Cloner cloner) => new VariableInteractionNetwork(this, cloner);
// Translates a path given as a vertex sequence into the arcs connecting each pair of consecutive vertices.
// For every such pair exactly one matching arc must exist in this network, otherwise Single() throws.
private IList<IArc> ArcsForCycle(IList<IVertex> cycle) {
  return cycle
    .Zip(cycle.Skip(1), (src, dst) => Arcs.Single(a => a.Source == src && a.Target == dst))
    .ToList();
}
271
272
// Searches for cycles starting from every variable node of the graph (junction nodes are not used as start nodes).
// Each returned list is the vertex sequence of one cycle; its first and last element are the start vertex.
// The result is produced lazily while the caller enumerates it.
public IEnumerable<IList<IVertex>> FindShortestCycles() {
  var cyclesOfAllStarts = base.Vertices
    .OfType<VariableNetworkNode>()
    .SelectMany(start => FindShortestCycles(start));

  foreach (var cycle in cyclesOfAllStarts) {
    yield return cycle;
  }
}
280
// Collects the cycles which start and end in the given variable node using a breadth-first expansion of paths.
// Returns the vertex sequences of the cycles found (first and last element are the start vertex).
private IEnumerable<IList<IVertex>> FindShortestCycles(VariableNetworkNode startVariable) {
  var cycles = new List<List<IVertex>>();

  // A cycle visiting k distinct vertices is stored as a path with k + 1 entries because the start vertex
  // is repeated at the end. The longest possible simple cycle visits every vertex once, so paths must be
  // allowed to grow to Count + 1 entries. With a bound of Count such a cycle (e.g. x -> y -> x in a graph
  // with two vertices) was silently missed and therefore never broken.
  var maxPathLength = base.Vertices.Count() + 1;

  var q = new Queue<List<IVertex>>(); // queue of paths
  q.Enqueue(new List<IVertex> { startVariable });

  FindShortestCycles(q, maxPathLength, cycles);
  return cycles;
}
293
// TODO efficiency
// Processes the queued paths in first-in-first-out order. A path with more than one entry whose first and last
// vertex are the same object is stored as a cycle and not extended further. A path that has reached
// maxPathLength entries is dropped. Every other path is extended by each successor of its last vertex.
private void FindShortestCycles(Queue<List<IVertex>> queue, int maxPathLength, List<List<IVertex>> cycles) {
  while (queue.Count > 0) {
    var current = queue.Dequeue();

    if (current.Count > 1 && current[0] == current[current.Count - 1]) {
      cycles.Add(new List<IVertex>(current)); // found a cycle
      continue;
    }

    if (current.Count >= maxPathLength) continue;

    var tail = current.Last();
    foreach (var outgoing in base.Arcs.Where(a => a.Source == tail)) {
      var extended = new List<IVertex>(current);
      extended.Add(outgoing.Target);
      queue.Enqueue(extended);
    }
  }
}
311
// Builds a square matrix of arc weights for a network with junction nodes. Rows and columns are the labels of
// the variable nodes in natural string order. Entry [target, source] holds the weight of the arc leading from
// the source variable into the junction node of the target variable; all other entries stay 0.
public DoubleMatrix GetWeightsMatrix() {
  var labels = Vertices
    .OfType<VariableNetworkNode>()
    .Select(v => v.Label)
    .OrderBy(s => s, new NaturalStringComparer())
    .ToArray();

  var indexOf = labels
    .Select((label, position) => new { label, position })
    .ToDictionary(entry => entry.label, entry => entry.position);

  var weights = new double[labels.Length, labels.Length];

  // only consider arcs going into a junction node
  foreach (var arc in Arcs.Where(a => a.Target is JunctionNetworkNode)) {
    var sourceLabel = arc.Source.Label;
    // each function node must have exactly one outgoing arc (which leads to the target variable)
    var targetLabel = arc.Target.OutArcs.Single().Target.Label;

    weights[indexOf[targetLabel], indexOf[sourceLabel]] = arc.Weight;
  }

  return new DoubleMatrix(weights, labels, labels);
}
338
// Builds a square matrix of arc weights for a network without junction nodes. Rows and columns are the labels
// of the variable nodes in natural string order. Entry [target, source] holds the weight of the arc
// source -> target; all other entries stay 0. Source and target labels of every arc must be labels of
// variable nodes, otherwise the index lookup throws.
public DoubleMatrix GetSimpleWeightsMatrix() {
  var labels = Vertices
    .OfType<VariableNetworkNode>()
    .Select(v => v.Label)
    .OrderBy(s => s, new NaturalStringComparer())
    .ToArray();

  var indexOf = labels
    .Select((label, position) => new { label, position })
    .ToDictionary(entry => entry.label, entry => entry.position);

  var weights = new double[labels.Length, labels.Length];

  foreach (var arc in Arcs) {
    if (arc.Target == null) continue;

    var sourceLabel = arc.Source.Label;
    var targetLabel = arc.Target.Label;
    weights[indexOf[targetLabel], indexOf[sourceLabel]] = arc.Weight;
  }

  return new DoubleMatrix(weights, labels, labels);
}
360
// Renders the network in the graphviz DOT language (left-to-right layout). Variable nodes are drawn as ovals,
// junction nodes as boxes. Elements whose label contains "removed" are colored red, those containing "added"
// green (cf. CalculateNetworkDiff), everything else black.
public string ToGraphVizString() {
  Func<string, string> nodeAndEdgeColor = (str) => {
    if (string.IsNullOrEmpty(str)) return "black";
    else if (str.Contains("removed")) return "red";
    else if (str.Contains("added")) return "green";
    else return "black";
  };

  // Labels are emitted as double-quoted DOT identifiers, in which an embedded double quote must be escaped;
  // otherwise a label containing '"' produces a syntactically invalid graph description.
  Func<string, string> quoteSafe = (str) => str == null ? string.Empty : str.Replace("\"", "\\\"");

  var sb = new StringBuilder();
  sb.AppendLine("digraph {");
  sb.AppendLine("rankdir=LR");
  foreach (var v in Vertices.OfType<VariableNetworkNode>()) {
    sb.AppendFormat("\"{0}\" [shape=oval, color={1}]", quoteSafe(v.Label), nodeAndEdgeColor(v.Label)).AppendLine();
  }
  foreach (var v in Vertices.OfType<JunctionNetworkNode>()) {
    sb.AppendFormat("\"{0}\" [shape=box, color={1}]", quoteSafe(v.Label), nodeAndEdgeColor(v.Label)).AppendLine();
  }
  foreach (var arc in Arcs) {
    // the arc label itself is not written, it only determines the color
    sb.AppendFormat("\"{0}\"->\"{1}\" [color=\"{2}\"]", quoteSafe(arc.Source.Label), quoteSafe(arc.Target.Label), nodeAndEdgeColor(arc.Label)).AppendLine();
  }
  sb.AppendLine("}");
  return sb.ToString();
}
385  }
386
/// <summary>
/// Equality comparer which treats two vertices as equal when their labels are equal.
/// Two null references are equal; a null reference never equals a vertex.
/// </summary>
public class VertexLabelComparer : IEqualityComparer<IVertex> {
  public bool Equals(IVertex x, IVertex y) {
    if (x == null && y == null) return true;
    if (x != null && y != null) {
      return x.Label == y.Label;
    } else return false;
  }

  /// <summary>
  /// Returns the hash code of the label. A null vertex or a null label hashes to 0, which keeps the hash
  /// consistent with <see cref="Equals(IVertex, IVertex)"/> (which accepts both) instead of throwing a
  /// NullReferenceException.
  /// </summary>
  public int GetHashCode(IVertex obj) {
    return obj?.Label?.GetHashCode() ?? 0;
  }
}
399
/// <summary>
/// Equality comparer which treats two arcs as equal when the labels of their source vertices are equal and
/// the labels of their target vertices are equal. Two null references are equal; a null reference never equals an arc.
/// </summary>
public class ArcComparer : IEqualityComparer<IArc> {
  public bool Equals(IArc x, IArc y) {
    if (x == null && y == null) return true;
    if (x != null && y != null) {
      return x.Source.Label == y.Source.Label && x.Target.Label == y.Target.Label;
    } else return false;
  }

  /// <summary>
  /// Combines the hash codes of the source and the target label. Missing parts (null arc, vertex or label)
  /// contribute 0 instead of causing a NullReferenceException. The combination is order sensitive: with a plain
  /// XOR the arcs a->b and b->a would share a hash code and every self-loop would hash to 0.
  /// </summary>
  public int GetHashCode(IArc obj) {
    if (obj == null) return 0;
    int sourceHash = obj.Source?.Label?.GetHashCode() ?? 0;
    int targetHash = obj.Target?.Label?.GetHashCode() ?? 0;
    unchecked {
      return (sourceHash * 397) ^ targetHash;
    }
  }
}
412
/// <summary>
/// Vertex type used for the variables of a variable interaction network.
/// </summary>
// NOTE(review): Id and Description carry no persistence attribute in this file - confirm whether they are meant to be stored.
[Item("VariableNetworkNode", "A graph vertex which represents a symbolic regression variable.")]
[StorableType("95E27B45-DD4B-4C32-AC5E-40A4714EA6F7")]
public class VariableNetworkNode : Vertex<IDeepCloneable>, INetworkNode {
  // Identifier assigned on construction; clones keep the identifier of their original.
  public string Id { get; }

  public string Description { get; set; }

  // Creates a node with a freshly generated GUID as identifier.
  public VariableNetworkNode() {
    Id = Guid.NewGuid().ToString();
  }

  // Copy constructor used for cloning; takes over identifier and description of the original.
  public VariableNetworkNode(VariableNetworkNode original, Cloner cloner) : base(original, cloner) {
    Id = original.Id;
    Description = original.Description;
  }

  public override IDeepCloneable Clone(Cloner cloner) => new VariableNetworkNode(this, cloner);
}
432
/// <summary>
/// Vertex type used for junction (function) nodes of a variable interaction network.
/// </summary>
// NOTE(review): Id and Description carry no persistence attribute in this file - confirm whether they are meant to be stored.
[Item("FunctionNetworkNode", "A graph vertex representing a junction node.")]
[StorableType("8D4E55AC-EBF8-49BA-8361-FC51EF3BE990")]
public class JunctionNetworkNode : Vertex<IDeepCloneable>, INetworkNode {
  // Identifier assigned on construction; clones keep the identifier of their original.
  public string Id { get; }

  public string Description { get; set; }

  // Creates a node with a freshly generated GUID as identifier.
  public JunctionNetworkNode() {
    Id = Guid.NewGuid().ToString();
  }

  // Copy constructor used for cloning; takes over identifier and description of the original.
  public JunctionNetworkNode(JunctionNetworkNode original, Cloner cloner) : base(original, cloner) {
    Id = original.Id;
    Description = original.Description;
  }

  public override IDeepCloneable Clone(Cloner cloner) => new JunctionNetworkNode(this, cloner);
}
452}
Note: See TracBrowser for help on using the repository browser.