Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2288_HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks/3.3/VariableInteractionNetwork.cs @ 16752

Last change on this file since 16752 was 16497, checked in by jzenisek, 6 years ago

#2288:

  • added possibility to create simple networks (only one input var set per target var, i.e. without junction nodes)
  • fixed minor enumeration bug
  • enabled network view updates
File size: 19.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Diagnostics.Eventing.Reader;
26using System.Linq;
27using System.Text;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
33namespace HeuristicLab.VariableInteractionNetworks {
34  [Item("VariableInteractionNetwork", "A graph representation of variables and their relationships.")]
35  [StorableClass]
36  public class VariableInteractionNetwork : DirectedGraph {
37
38    /// <summary>
39    /// Creates a simple network from a matrix of variable impacts (each row represents a target variable, each column represents an input variable)
40    /// For each target variable not more than one row can be defined (no junction nodes are build, cf. FromNmseAndVariableImpacts(..) for building more complex networks).
41    /// The network is acyclic. Values in the diagonal are ignored.
42    /// The algorithm starts with an empty network and incrementally adds next most relevant input variable for each target variable up to a given threshold
43    /// In each iteration cycles are broken by removing the weakest link.
44    /// </summary>
45    /// <param name="nmse">vector of NMSE values for each target variable</param>
46    /// <param name="variableImpacts">Variable impacts (smaller is lower impact). Row names and columns names should be set</param>
47    /// <param name="nmseThreshold">Threshold for NMSE values. Variables with a NMSE value larger than the threshold are considered as independent variables</param>
48    /// <param name="varImpactThreshold">Threshold for variable impact values. Impacts with a value smaller than the threshold are considered as independent</param>
49    /// <returns></returns>
50    public static VariableInteractionNetwork CreateSimpleNetwork(double[] nmse, DoubleMatrix variableImpacts, double nmseThreshold = 0.2, double varImpactThreshold = 0.0) {
51      if (variableImpacts.Rows != variableImpacts.Columns) throw new ArgumentException();
52      var network = new VariableInteractionNetwork();
53      var targets = new Dictionary<string, double>();
54      string[] varNames = variableImpacts.RowNames.ToArray();
55      if (nmse.Length != varNames.Length) throw new ArgumentException();
56
57      for (int i = 0; i < varNames.Length; i++) {
58        var name = varNames[i];
59        var varVertex = new VariableNetworkNode() {Label = name, Weight = nmse[i]};
60        network.AddVertex(varVertex);
61        if (nmse[i] < nmseThreshold) {
62          targets.Add(name, nmse[i]);
63        }
64      }
65
66      // rel is updated (impacts which are represented in the network are set to zero)
67      var rel = variableImpacts.CloneAsMatrix();
68      // make sure the diagonal is not considered
69      for (int i = 0; i < rel.GetLength(0); i++) rel[i, i] = double.NegativeInfinity;
70
71      var addedArcs = AddArcs(network, rel, varNames, targets, varImpactThreshold);
72      while (addedArcs.Any()) {
73        var cycles = network.FindShortestCycles().ToList();
74        while (cycles.Any()) {
75          // delete weakest link
76          var weakestArc = cycles.SelectMany(cycle => network.ArcsForCycle(cycle)).OrderBy(a => a.Weight).First();
77          network.RemoveArc(weakestArc);
78
79          cycles = network.FindShortestCycles().ToList();
80        }
81
82        addedArcs = AddArcs(network, rel, varNames, targets, varImpactThreshold);
83      }
84
85      return network;
86    }
87
88    private static List<IArc> AddArcs(VariableInteractionNetwork network, double[,] impacts, string[] varNames, Dictionary<string, double> targets, double threshold = 0.0) {
89      var newArcs = new List<IArc>();
90      for (int row = 0; row < impacts.GetLength(0); row++) {
91        if (!targets.ContainsKey(varNames[row])) continue;
92
93        var rowVector = Enumerable.Range(0, impacts.GetLength(0)).Select(col => impacts[row, col]).ToArray();
94        var max = rowVector.Max();
95        if (max > threshold) {
96          var idxOfMax = Array.IndexOf<double>(rowVector, max);
97          impacts[row, idxOfMax] = double.NegativeInfinity;
98          var srcName = varNames[idxOfMax];
99          var dstName = varNames[row];
100          var srcVertex = network.Vertices.Single(v => v.Label == srcName);
101          var dstVertex = network.Vertices.Single(v => v.Label == dstName);
102          var arc = network.AddArc(srcVertex, dstVertex);
103          arc.Weight = max;
104          newArcs.Add(arc);
105        }
106      }
107
108      return newArcs;
109    }
110
111    /// <summary>
112    /// Creates a network from a matrix of variable impacts (each row represents a target variable, each column represents an input variable)
113    /// The network is acyclic. Values in the diagonal are ignored.
114    /// The algorithm starts with an empty network and incrementally adds next most relevant input variable for each target variable up to a given threshold
115    /// In each iteration cycles are broken by removing the weakest link.
116    /// </summary>
117    /// <param name="nmse">vector of NMSE values for each target variable</param>
118    /// <param name="variableImpacts">Variable impacts (smaller is lower impact). Row names and columns names should be set</param>
119    /// <param name="nmseThreshold">Threshold for NMSE values. Variables with a NMSE value larger than the threshold are considered as independent variables</param>
120    /// <param name="varImpactThreshold">Threshold for variable impact values. Impacts with a value smaller than the threshold are considered as independent</param>
121    /// <returns></returns>
122    public static VariableInteractionNetwork FromNmseAndVariableImpacts(double[] nmse, DoubleMatrix variableImpacts, double nmseThreshold = 0.2, double varImpactThreshold = 0.0) {
123      if (variableImpacts.Rows != variableImpacts.Columns) throw new ArgumentException();
124      var network = new VariableInteractionNetwork();
125
126      Dictionary<string, IVertex> name2funVertex = new Dictionary<string, IVertex>(); // store vertexes representing the function for each target so we can easily add incoming arcs later on
127      string[] varNames = variableImpacts.RowNames.ToArray();
128      if (nmse.Length != varNames.Length) throw new ArgumentException();
129
130      for (int i = 0; i < varNames.Length; i++) {
131        var name = varNames[i];
132        var varVertex = new VariableNetworkNode() { Label = name };
133        network.AddVertex(varVertex);
134        if (nmse[i] < nmseThreshold) {
135          var functionVertex = new JunctionNetworkNode() { Label = "f_" + name };
136          name2funVertex.Add(name, functionVertex);
137          network.AddVertex(functionVertex);
138          var predArc = network.AddArc(functionVertex, varVertex);
139          predArc.Weight = double.PositiveInfinity; // never delete arcs from f_x -> x (representing output of a function)
140        }
141      }
142
143      // rel is updated (impacts which are represented in the network are set to zero)
144      var rel = variableImpacts.CloneAsMatrix();
145      // make sure the diagonal is not considered
146      for (int i = 0; i < rel.GetLength(0); i++) rel[i, i] = double.NegativeInfinity;
147
148      var addedArcs = AddArcs(network, rel, varNames, name2funVertex, varImpactThreshold);
149      while (addedArcs.Any()) {
150        var cycles = network.FindShortestCycles().ToList();
151        while (cycles.Any()) {
152          // delete weakest link
153          var weakestArc = cycles.SelectMany(cycle => network.ArcsForCycle(cycle)).OrderBy(a => a.Weight).First();
154          network.RemoveArc(weakestArc);
155
156          cycles = network.FindShortestCycles().ToList();
157        }
158
159        addedArcs = AddArcs(network, rel, varNames, name2funVertex, varImpactThreshold);
160      }
161
162      return network;
163    }
164
165    /// <summary>
166    /// Produces a combined network from two networks.
167    /// The set of nodes of the new network is the union of the node sets of the two input networks.
168    /// The set of edges of the new network is the union of the edge sets of the two input networks.
169    /// Added and removed nodes and edges are marked so that it is possible to visualize the network difference using graphviz
170    /// </summary>
171    /// <returns></returns>
172    public static VariableInteractionNetwork CalculateNetworkDiff(
173      VariableInteractionNetwork from,
174      VariableInteractionNetwork to) {
175      var g = new VariableInteractionNetwork();
176
177      // add nodes which are in both networks
178      foreach (var node in from.Vertices.Intersect(to.Vertices, new VertexLabelComparer())) {
179        g.AddVertex((IVertex)node.Clone());
180      }
181      // add nodes only in from network
182      foreach (var node in from.Vertices.Except(to.Vertices, new VertexLabelComparer())) {
183        var fromVertex = (IVertex)node.Clone();
184        fromVertex.Label += " (removed)";
185        g.AddVertex(fromVertex);
186      }
187      // add nodes only in to network
188      foreach (var node in to.Vertices.Except(from.Vertices, new VertexLabelComparer())) {
189        var fromVertex = (IVertex)node.Clone();
190        fromVertex.Label += " (added)";
191        g.AddVertex(fromVertex);
192      }
193
194      // add edges which are in both networks
195      foreach (var arc in from.Arcs.Intersect(to.Arcs, new ArcComparer())) {
196        g.AddArc(
197          g.Vertices.Single(v => arc.Source.Label == v.Label),
198          g.Vertices.Single(v => arc.Target.Label == v.Label)
199        );
200      }
201      // add edges only in from network
202      foreach (var arc in from.Arcs.Except(to.Arcs, new ArcComparer())) {
203        var fromVertex =
204          g.Vertices.Single(v => v.Label == arc.Source.Label || v.Label == arc.Source.Label + " (removed)");
205        var toVertex =
206          g.Vertices.Single(v => v.Label == arc.Target.Label || v.Label == arc.Target.Label + " (removed)");
207        var newArc = g.AddArc(
208          fromVertex,
209          toVertex
210        );
211        newArc.Label += " (removed)";
212      }
213      // add arcs only in to network
214      foreach (var arc in to.Arcs.Except(from.Arcs, new ArcComparer())) {
215        var fromVertex =
216          g.Vertices.Single(v => v.Label == arc.Source.Label || v.Label == arc.Source.Label + " (added)");
217        var toVertex =
218          g.Vertices.Single(v => v.Label == arc.Target.Label || v.Label == arc.Target.Label + " (added)");
219        var newArc = g.AddArc(
220          fromVertex,
221          toVertex
222        );
223        newArc.Label += " (added)";
224      }
225      return g;
226    }
227
228
229    private static List<IArc> AddArcs(VariableInteractionNetwork network, double[,] impacts, string[] varNames, Dictionary<string, IVertex> name2funVertex, double threshold = 0.0) {
230      var newArcs = new List<IArc>();
231      for (int row = 0; row < impacts.GetLength(0); row++) {
232        if (!name2funVertex.ContainsKey(varNames[row])) continue; // this variable does not have an associated function (considered as independent)
233
234        var rowVector = Enumerable.Range(0, impacts.GetLength(0)).Select(col => impacts[row, col]).ToArray();
235        var max = rowVector.Max();
236        if (max > threshold) {
237          var idxOfMax = Array.IndexOf<double>(rowVector, max);
238          impacts[row, idxOfMax] = double.NegativeInfinity; // edge is not considered anymore
239          var srcName = varNames[idxOfMax];
240          var dstName = varNames[row];
241          var vertex = network.Vertices.Single(v => v.Label == srcName);
242          var arc = network.AddArc(vertex, name2funVertex[dstName]);
243          arc.Weight = max;
244          newArcs.Add(arc);
245        }
246      }
247      return newArcs;
248    }
249
250    [StorableConstructor]
251    public VariableInteractionNetwork(bool deserializing) : base(deserializing) { }
252
253    public VariableInteractionNetwork() { }
254
255    protected VariableInteractionNetwork(VariableInteractionNetwork original, Cloner cloner) : base(original, cloner) { }
256
257    public override IDeepCloneable Clone(Cloner cloner) {
258      return new VariableInteractionNetwork(this, cloner);
259    }
260    private IList<IArc> ArcsForCycle(IList<IVertex> cycle) {
261      var res = new List<IArc>();
262      foreach (var t in cycle.Zip(cycle.Skip(1), Tuple.Create)) {
263        var src = t.Item1;
264        var dst = t.Item2;
265        var arc = Arcs.Single(a => a.Source == src && a.Target == dst);
266        res.Add(arc);
267      }
268      return res;
269    }
270
271
272    // finds the shortest cycles in the graph and returns all sub-graphs containing only the nodes / edges within the cycle
273    public IEnumerable<IList<IVertex>> FindShortestCycles() {
274      foreach (var startVariable in base.Vertices.OfType<VariableNetworkNode>()) {
275        foreach (var cycle in FindShortestCycles(startVariable))
276          yield return cycle;
277      }
278    }
279
280    private IEnumerable<IList<IVertex>> FindShortestCycles(VariableNetworkNode startVariable) {
281      var q = new Queue<List<IVertex>>(); // queue of paths
282      var path = new List<IVertex>();
283      var cycles = new List<List<IVertex>>();
284      var maxPathLength = base.Vertices.Count();
285
286      path.Add(startVariable);
287      q.Enqueue(new List<IVertex>(path));
288
289      FindShortestCycles(q, maxPathLength, cycles);
290      return cycles;
291    }
292
293    // TODO efficiency
294    private void FindShortestCycles(Queue<List<IVertex>> queue, int maxPathLength, List<List<IVertex>> cycles) {
295      while (queue.Any()) {
296        var path = queue.Dequeue();
297        if (path.Count > 1 && path.First() == path.Last()) {
298          cycles.Add(new List<IVertex>(path)); // found a cycle
299        } else if (path.Count >= maxPathLength) {
300          continue;
301        } else {
302          var lastVert = path.Last();
303          var neighbours = base.Arcs.Where(a => a.Source == lastVert).Select(a => a.Target);
304          foreach (var neighbour in neighbours) {
305            queue.Enqueue(new List<IVertex>(path.Concat(new IVertex[] { neighbour })));
306          }
307        }
308      }
309    }
310
311    public DoubleMatrix GetWeightsMatrix() {
312      var names = Vertices.OfType<VariableNetworkNode>()
313        .Select(v => v.Label)
314        .OrderBy(s => s, new NaturalStringComparer()).ToArray();
315      var w = new double[names.Length, names.Length];
316
317      var name2idx = new Dictionary<string, int>();
318      for (int i = 0; i < names.Length; i++) {
319        name2idx.Add(names[i], i);
320      }
321
322      foreach (var arc in Arcs) {
323        // only consider arcs going into a junction node
324        var target = arc.Target as JunctionNetworkNode;
325        if (target != null) {
326          var srcVarName = arc.Source.Label;
327          // each function node must have exactly one outgoing arc
328          var dstVarName = arc.Target.OutArcs.Single().Target.Label;
329
330          w[name2idx[dstVarName], name2idx[srcVarName]] = arc.Weight;
331        }
332      }
333
334
335      return new DoubleMatrix(w, names, names);
336    }
337
338    public DoubleMatrix GetSimpleWeightsMatrix() {
339      var names = Vertices.OfType<VariableNetworkNode>()
340        .Select(v => v.Label)
341        .OrderBy(s => s, new NaturalStringComparer()).ToArray();
342      var w = new double[names.Length, names.Length];
343
344      var name2idx = new Dictionary<string, int>();
345      for (int i = 0; i < names.Length; i++) {
346        name2idx.Add(names[i], i);
347      }
348
349      foreach (var arc in Arcs) {
350        if (arc.Target != null) {
351          var srcVarName = arc.Source.Label;
352          var dstVarName = arc.Target.Label;         
353          w[name2idx[dstVarName], name2idx[srcVarName]] = arc.Weight;
354        }
355      }
356
357      return new DoubleMatrix(w, names, names);
358    }
359
360    public string ToGraphVizString() {
361      Func<string, string> NodeAndEdgeColor = (str) =>
362      {
363        if (string.IsNullOrEmpty(str)) return "black";
364        else if (str.Contains("removed")) return "red";
365        else if (str.Contains("added")) return "green";
366        else return "black";
367      };
368
369      var sb = new StringBuilder();
370      sb.AppendLine("digraph {");
371      sb.AppendLine("rankdir=LR");
372      foreach (var v in Vertices.OfType<VariableNetworkNode>()) {
373        sb.AppendFormat("\"{0}\" [shape=oval, color={1}]", v.Label, NodeAndEdgeColor(v.Label)).AppendLine();
374      }
375      foreach (var v in Vertices.OfType<JunctionNetworkNode>()) {
376        sb.AppendFormat("\"{0}\" [shape=box, color={1}]", v.Label, NodeAndEdgeColor(v.Label)).AppendLine();
377      }
378      foreach (var arc in Arcs) {
379        sb.AppendFormat("\"{0}\"->\"{1}\" [color=\"{3}\"]", arc.Source.Label, arc.Target.Label, arc.Label, NodeAndEdgeColor(arc.Label)).AppendLine();
380      }
381      sb.AppendLine("}");
382      return sb.ToString();
383    }
384  }
385
386  public class VertexLabelComparer : IEqualityComparer<IVertex> {
387    public bool Equals(IVertex x, IVertex y) {
388      if (x == null && y == null) return true;
389      if (x != null && y != null) {
390        return x.Label == y.Label;
391      } else return false;
392    }
393
394    public int GetHashCode(IVertex obj) {
395      return obj.Label.GetHashCode();
396    }
397  }
398
399  public class ArcComparer : IEqualityComparer<IArc> {
400    public bool Equals(IArc x, IArc y) {
401      if (x == null && y == null) return true;
402      if (x != null && y != null) {
403        return x.Source.Label == y.Source.Label && x.Target.Label == y.Target.Label;
404      } else return false;
405    }
406
407    public int GetHashCode(IArc obj) {
408      return obj.Source.Label.GetHashCode() ^ obj.Target.Label.GetHashCode();
409    }
410  }
411
412  [Item("VariableNetworkNode", "A graph vertex which represents a symbolic regression variable.")]
413  [StorableClass]
414  public class VariableNetworkNode : Vertex<IDeepCloneable>, INetworkNode {
415    public VariableNetworkNode() {
416      Id = Guid.NewGuid().ToString();
417    }
418
419    public VariableNetworkNode(VariableNetworkNode original, Cloner cloner) : base(original, cloner) {
420      Id = original.Id;
421      Description = original.Description;
422    }
423
424    public override IDeepCloneable Clone(Cloner cloner) {
425      return new VariableNetworkNode(this, cloner);
426    }
427
428    public string Id { get; }
429    public string Description { get; set; }
430  }
431
432  [Item("FunctionNetworkNode", "A graph vertex representing a junction node.")]
433  [StorableClass]
434  public class JunctionNetworkNode : Vertex<IDeepCloneable>, INetworkNode {
435    public JunctionNetworkNode() {
436      Id = Guid.NewGuid().ToString();
437    }
438
439    public JunctionNetworkNode(JunctionNetworkNode original, Cloner cloner) : base(original, cloner) {
440      Id = original.Id;
441      Description = original.Description;
442    }
443
444    public override IDeepCloneable Clone(Cloner cloner) {
445      return new JunctionNetworkNode(this, cloner);
446    }
447
448    public string Id { get; }
449    public string Description { get; set; }
450  }
451}
Note: See TracBrowser for help on using the repository browser.