Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.VariableInteractionNetworks/HeuristicLab.VariableInteractionNetworks/3.3/VariableInteractionNetwork.cs @ 14622

Last change on this file since 14622 was 14622, checked in by gkronber, 7 years ago

#2288: added a static method to create a network (as a DAG) from an NMSE vector and a variable impacts matrix, as well as code for cycle detection and for converting a network to GraphViz format and to its adjacency matrix.

File size: 10.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Linq;
26using System.Text;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.VariableInteractionNetworks {
[Item("VariableInteractionNetwork", "A graph representation of variables and their relationships.")]
[StorableClass]
public class VariableInteractionNetwork : DirectedGraph {
  /// <summary>
  /// Creates a network from a matrix of variable impacts (each row represents a target variable, each column represents an input variable).
  /// The resulting network is acyclic. Values on the diagonal are ignored.
  /// The algorithm starts with an empty network and in each iteration adds, for every dependent target variable, the input variable
  /// with the largest remaining impact (if that impact is strictly larger than <paramref name="varImpactThreshold"/>).
  /// After each iteration all cycles are broken by repeatedly removing the weakest arc that lies on a cycle.
  /// </summary>
  /// <param name="nmse">NMSE value for each target variable, in the same order as the rows of <paramref name="variableImpacts"/>.</param>
  /// <param name="variableImpacts">Square matrix of variable impacts (smaller means lower impact). Row names must be set;
  /// column j is assumed to refer to the same variable as row j.</param>
  /// <param name="nmseThreshold">Only variables whose NMSE is strictly smaller than this threshold get a function node;
  /// all other variables are considered as independent variables.</param>
  /// <param name="varImpactThreshold">Only impacts strictly larger than this threshold are turned into arcs;
  /// smaller or equal impacts are considered as independent.</param>
  /// <returns>A new acyclic network with one <see cref="VariableNetworkNode"/> per variable and one
  /// <see cref="JunctionNetworkNode"/> (labelled "f_" + variable name) per dependent variable.</returns>
  /// <exception cref="ArgumentNullException"><paramref name="nmse"/> or <paramref name="variableImpacts"/> is null.</exception>
  /// <exception cref="ArgumentException">The matrix is not square, its row names are missing or duplicated,
  /// or the number of NMSE values does not match the number of variables.</exception>
  public static VariableInteractionNetwork FromNmseAndVariableImpacts(double[] nmse, DoubleMatrix variableImpacts, double nmseThreshold = 0.2, double varImpactThreshold = 0.0) {
    if(nmse == null) throw new ArgumentNullException(nameof(nmse));
    if(variableImpacts == null) throw new ArgumentNullException(nameof(variableImpacts));
    if(variableImpacts.Rows != variableImpacts.Columns)
      throw new ArgumentException("The variable impacts matrix must be square.", nameof(variableImpacts));

    string[] varNames = variableImpacts.RowNames.ToArray();
    if(varNames.Length != variableImpacts.Rows)
      throw new ArgumentException("The row names of the variable impacts matrix must be set.", nameof(variableImpacts));
    if(nmse.Length != varNames.Length)
      throw new ArgumentException("The number of NMSE values must match the number of variables in the variable impacts matrix.", nameof(nmse));

    var network = new VariableInteractionNetwork();

    // variable vertices by name (sources of arcs into function vertices); looked up by name so that
    // function vertices labelled "f_<name>" can never be confused with a variable of the same label
    var name2varVertex = new Dictionary<string, IVertex>();
    // function vertices for each dependent target, so we can easily add incoming arcs later on
    var name2funVertex = new Dictionary<string, IVertex>();

    for(int i = 0; i < varNames.Length; i++) {
      var name = varNames[i];
      if(name2varVertex.ContainsKey(name))
        throw new ArgumentException(string.Format("Duplicate variable name '{0}' in the variable impacts matrix.", name), nameof(variableImpacts));

      var varVertex = new VariableNetworkNode() { Label = name };
      name2varVertex.Add(name, varVertex);
      network.AddVertex(varVertex);

      if(nmse[i] < nmseThreshold) {
        var functionVertex = new JunctionNetworkNode() { Label = "f_" + name };
        name2funVertex.Add(name, functionVertex);
        network.AddVertex(functionVertex);
        var predArc = network.AddArc(functionVertex, varVertex);
        predArc.Weight = double.PositiveInfinity; // never delete arcs from f_x -> x (representing output of a function)
      }
    }

    // rel is updated (impacts which have been considered are set to -infinity)
    var rel = variableImpacts.CloneAsMatrix();
    // make sure the diagonal is never considered
    for(int i = 0; i < rel.GetLength(0); i++) rel[i, i] = double.NegativeInfinity;

    var addedArcs = AddArcs(network, rel, varNames, name2varVertex, name2funVertex, varImpactThreshold);
    while(addedArcs.Any()) {
      var cycles = network.FindShortestCycles().ToList();
      while(cycles.Any()) {
        // delete the weakest link among all arcs lying on a cycle (f_x -> x arcs have infinite weight and are kept)
        var weakestArc = cycles.SelectMany(cycle => network.ArcsForCycle(cycle)).OrderBy(a => a.Weight).First();
        network.RemoveArc(weakestArc);

        cycles = network.FindShortestCycles().ToList();
      }

      addedArcs = AddArcs(network, rel, varNames, name2varVertex, name2funVertex, varImpactThreshold);
    }

    return network;
  }

  /// <summary>
  /// For every dependent target variable (row with a function vertex), adds an arc from the input variable with the
  /// largest remaining impact to the target's function vertex, provided that impact is strictly larger than <paramref name="threshold"/>.
  /// The used impact entry is set to -infinity so it is not considered again.
  /// </summary>
  /// <returns>The arcs added in this call (empty when no impact above the threshold remains).</returns>
  private static List<IArc> AddArcs(VariableInteractionNetwork network, double[,] impacts, string[] varNames,
    Dictionary<string, IVertex> name2varVertex, Dictionary<string, IVertex> name2funVertex, double threshold = 0.0) {
    var newArcs = new List<IArc>();
    int nCols = impacts.GetLength(1);
    for(int row = 0; row < impacts.GetLength(0); row++) {
      IVertex funVertex;
      if(!name2funVertex.TryGetValue(varNames[row], out funVertex)) continue; // this variable does not have an associated function (considered as independent)

      var rowVector = Enumerable.Range(0, nCols).Select(col => impacts[row, col]).ToArray();
      var max = rowVector.Max();
      if(max > threshold) {
        var idxOfMax = Array.IndexOf<double>(rowVector, max);
        impacts[row, idxOfMax] = double.NegativeInfinity; // edge is not considered anymore
        var srcVertex = name2varVertex[varNames[idxOfMax]];
        var arc = network.AddArc(srcVertex, funVertex);
        arc.Weight = max;
        newArcs.Add(arc);
      }
    }
    return newArcs;
  }

  [StorableConstructor]
  public VariableInteractionNetwork(bool deserializing) : base(deserializing) { }

  public VariableInteractionNetwork() { }

  protected VariableInteractionNetwork(VariableInteractionNetwork original, Cloner cloner) : base(original, cloner) { }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new VariableInteractionNetwork(this, cloner);
  }

  /// <summary>
  /// Maps a cycle given as a vertex sequence (first vertex repeated at the end) to the arcs connecting consecutive vertices.
  /// Throws if a pair of consecutive vertices is not connected by exactly one arc.
  /// </summary>
  private IList<IArc> ArcsForCycle(IList<IVertex> cycle) {
    var res = new List<IArc>();
    for(int i = 0; i + 1 < cycle.Count; i++) {
      var src = cycle[i];
      var dst = cycle[i + 1];
      res.Add(Arcs.Single(a => a.Source == src && a.Target == dst));
    }
    return res;
  }

  /// <summary>
  /// Enumerates the simple cycles passing through variable nodes.
  /// For every <see cref="VariableNetworkNode"/> a breadth-first search over simple paths starting at that node is performed,
  /// and every path returning to its start node is reported. Hence, per start node, cycles are reported in order of
  /// non-decreasing length, and a cycle is reported once for each variable node it contains.
  /// Note: the number of explored paths can grow exponentially with the size of the graph.
  /// </summary>
  /// <returns>Cycles as vertex lists whose first and last element are the same start vertex.</returns>
  public IEnumerable<IList<IVertex>> FindShortestCycles() {
    foreach(var startVariable in Vertices.OfType<VariableNetworkNode>()) {
      foreach(var cycle in FindShortestCycles(startVariable))
        yield return cycle;
    }
  }

  // all simple cycles through startVariable, shortest first
  private IEnumerable<IList<IVertex>> FindShortestCycles(VariableNetworkNode startVariable) {
    var queue = new Queue<List<IVertex>>(); // queue of paths
    var cycles = new List<List<IVertex>>();
    queue.Enqueue(new List<IVertex> { startVariable });

    FindShortestCycles(queue, Vertices.Count(), cycles);
    return cycles;
  }

  // breadth-first extension of the paths in the queue; closed paths are collected in cycles
  private void FindShortestCycles(Queue<List<IVertex>> queue, int maxPathLength, List<List<IVertex>> cycles) {
    while(queue.Count > 0) {
      var path = queue.Dequeue();
      var first = path[0];
      var last = path[path.Count - 1];

      if(path.Count > 1 && first == last) {
        cycles.Add(path); // found a cycle
        continue;
      }

      // A simple cycle visits up to maxPathLength distinct vertices and repeats the start vertex at the end,
      // so a path that already contains maxPathLength vertices must still be extendable by the closing step.
      if(path.Count > maxPathLength) continue;

      foreach(var arc in Arcs.Where(a => a.Source == last)) {
        var next = arc.Target;
        // only extend simple paths: revisiting an intermediate vertex cannot lead to a new simple cycle through the start
        if(next != first && path.Contains(next)) continue;

        var extended = new List<IVertex>(path.Count + 1);
        extended.AddRange(path);
        extended.Add(next);
        queue.Enqueue(extended);
      }
    }
  }

  /// <summary>
  /// Returns the weighted adjacency matrix between variables. Rows and columns are the labels of all
  /// <see cref="VariableNetworkNode"/>s, sorted with <see cref="NaturalStringComparer"/>. Entry [i, j] holds the weight of the
  /// arc from variable j into the function node of variable i, or 0 if there is no such arc.
  /// </summary>
  public DoubleMatrix GetWeightsMatrix() {
    var names = Vertices.OfType<VariableNetworkNode>()
      .Select(v => v.Label)
      .OrderBy(s => s, new NaturalStringComparer()).ToArray();
    var w = new double[names.Length, names.Length];

    var name2idx = new Dictionary<string, int>();
    for(int i = 0; i < names.Length; i++) {
      name2idx.Add(names[i], i);
    }

    foreach(var arc in Arcs) {
      // only consider arcs going into a junction node
      var target = arc.Target as JunctionNetworkNode;
      if(target != null) {
        var srcVarName = arc.Source.Label;
        // each function node must have exactly one outgoing arc (to the variable it computes)
        var dstVarName = target.OutArcs.Single().Target.Label;

        w[name2idx[dstVarName], name2idx[srcVarName]] = arc.Weight;
      }
    }

    return new DoubleMatrix(w, names, names);
  }

  /// <summary>
  /// Returns a GraphViz (DOT) representation of the network: variable nodes are drawn as ovals, junction nodes as boxes.
  /// Arc weights are not included. Double quotes in labels are escaped so that the output stays valid DOT.
  /// </summary>
  public string ToGraphVizString() {
    var sb = new StringBuilder();
    sb.AppendLine("digraph {");
    sb.AppendLine("rankdir=LR");
    foreach(var v in Vertices.OfType<VariableNetworkNode>()) {
      sb.AppendFormat("\"{0}\" [shape=oval]", EscapeDotId(v.Label)).AppendLine();
    }
    foreach(var v in Vertices.OfType<JunctionNetworkNode>()) {
      sb.AppendFormat("\"{0}\" [shape=box]", EscapeDotId(v.Label)).AppendLine();
    }
    foreach(var arc in Arcs) {
      sb.AppendFormat("\"{0}\"->\"{1}\"", EscapeDotId(arc.Source.Label), EscapeDotId(arc.Target.Label)).AppendLine();
    }
    sb.AppendLine("}");
    return sb.ToString();
  }

  // In quoted DOT identifiers the only escape sequence is \" ; a null label is rendered as an empty string.
  private static string EscapeDotId(string id) {
    return (id ?? string.Empty).Replace("\"", "\\\"");
  }
}
217
[Item("VariableNetworkNode", "A graph vertex which represents a symbolic regression variable.")]
[StorableClass]
public class VariableNetworkNode : Vertex<IDeepCloneable>, INetworkNode {
  // Backing fields are [Storable] so that Id and Description survive persistence of this [StorableClass];
  // previously the Id was regenerated and the Description lost when a network was saved and loaded.
  [Storable]
  private string id;
  [Storable]
  private string description;

  /// <summary>Creates a new variable node with a fresh unique (GUID-based) Id.</summary>
  public VariableNetworkNode() {
    id = Guid.NewGuid().ToString();
  }

  /// <summary>Cloning constructor: keeps the Id and Description of <paramref name="original"/>.</summary>
  public VariableNetworkNode(VariableNetworkNode original, Cloner cloner) : base(original, cloner) {
    id = original.id;
    description = original.description;
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new VariableNetworkNode(this, cloner);
  }

  /// <summary>Unique identifier of this node (preserved by cloning).</summary>
  public string Id { get { return id; } }

  /// <summary>Free-text description of this node.</summary>
  public string Description {
    get { return description; }
    set { description = value; }
  }
}
237
// NOTE(review): the item name "FunctionNetworkNode" differs from the class name; kept unchanged because it is user-visible text.
[Item("FunctionNetworkNode", "A graph vertex representing a junction node.")]
[StorableClass]
public class JunctionNetworkNode : Vertex<IDeepCloneable>, INetworkNode {
  // Backing fields are [Storable] so that Id and Description survive persistence of this [StorableClass];
  // previously the Id was regenerated and the Description lost when a network was saved and loaded.
  [Storable]
  private string id;
  [Storable]
  private string description;

  /// <summary>Creates a new junction node with a fresh unique (GUID-based) Id.</summary>
  public JunctionNetworkNode() {
    id = Guid.NewGuid().ToString();
  }

  /// <summary>Cloning constructor: keeps the Id and Description of <paramref name="original"/>.</summary>
  public JunctionNetworkNode(JunctionNetworkNode original, Cloner cloner) : base(original, cloner) {
    id = original.id;
    description = original.description;
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new JunctionNetworkNode(this, cloner);
  }

  /// <summary>Unique identifier of this node (preserved by cloning).</summary>
  public string Id { get { return id; } }

  /// <summary>Free-text description of this node.</summary>
  public string Description {
    get { return description; }
    set { description = value; }
  }
}
257}
Note: See TracBrowser for help on using the repository browser.