1 | using System;
2 | using System.Collections;
3 | using System.Collections.Generic;
4 | using System.Linq;
5 | using System.Threading;
6 | using HeuristicLab.Common;
7 | using HeuristicLab.Core;
8 | using HeuristicLab.Core.Networks;
9 | using HeuristicLab.Data;
10 | using HeuristicLab.Encodings.BinaryVectorEncoding;
11 | using HeuristicLab.Problems.DataAnalysis;
12 |
13 | namespace HeuristicLab.Networks.FeatureSelection_Network {
/// <summary>
/// Network node that evaluates feature selections for a regression problem.
/// A binary selection vector received on the "Selection Connector" port is used to restrict
/// the allowed input variables of the problem data obtained via the "Parameters" port; the
/// restricted problem is sent out on the "Regression Connector" port and the test normalized
/// mean squared error of the returned solution is written back as "Quality".
/// Qualities are cached per selection vector.
/// </summary>
[Item("FeatureSelectionConnector", "")]
public sealed class FeatureSelectionConnector : Node {
  private const string ParametersPortName = "Parameters";
  private const string SelectionPortName = "Selection Connector";
  private const string RegressionPortName = "Regression Connector";

  // Maps an already evaluated selection to its quality; keys are compared bit by bit.
  // NOTE(review): the cache is never invalidated when the problem data delivered by the
  // "Parameters" port changes — confirm that the problem data is fixed for the node's lifetime.
  private readonly Dictionary<BinaryVector, double> solutionCache = new Dictionary<BinaryVector, double>(new BinaryVectorComparer());

  /// <summary>Cloning constructor. The clone starts with an empty quality cache.</summary>
  private FeatureSelectionConnector(FeatureSelectionConnector original, Cloner cloner)
    : base(original, cloner) {
    // NOTE(review): presumably the base cloning constructor copies the ports but not the
    // handlers subscribed to their events — confirm. RegisterEvents never subscribes twice,
    // so calling it here is safe either way.
    RegisterEvents();
  }

  /// <summary>Creates a new connector and sets up its ports if none exist yet.</summary>
  public FeatureSelectionConnector()
    : base() {
    if (Ports.Count == 0) {
      Initialize();
    }
  }

  /// <summary>Creates a clone of this node using the given cloner.</summary>
  public override IDeepCloneable Clone(Cloner cloner) {
    return new FeatureSelectionConnector(this, cloner);
  }

  /// <summary>
  /// Creates the "Parameters", "Selection Connector" and "Regression Connector" ports with
  /// their port parameters and subscribes to incoming selection messages.
  /// </summary>
  public void Initialize() {
    var parameters = new MessagePort(ParametersPortName);
    Ports.Add(parameters);
    parameters.Parameters.Add(new PortParameter<IRegressionProblemData>("ProblemData") { Type = PortParameterType.Input });

    var selectionPort = new MessagePort(SelectionPortName);
    Ports.Add(selectionPort);
    selectionPort.Parameters.Add(new PortParameter<BinaryVector>("Selection") { Type = PortParameterType.Input });
    selectionPort.Parameters.Add(new PortParameter<DoubleValue>("Quality") { Type = PortParameterType.Input | PortParameterType.Output });

    var regressionPort = new MessagePort(RegressionPortName);
    Ports.Add(regressionPort);
    regressionPort.Parameters.Add(new PortParameter<IRegressionProblemData>("ProblemData") { Type = PortParameterType.Output });
    regressionPort.Parameters.Add(new PortParameter<IRegressionSolution>("Linear regression solution") { Type = PortParameterType.Input });

    RegisterEvents();
  }

  /// <summary>Subscribes to the MessageReceived event of the "Selection Connector" port.</summary>
  public void RegisterEvents() {
    var selection = (IMessagePort)Ports[SelectionPortName];
    // Unsubscribe first so that repeated calls never attach the handler twice.
    selection.MessageReceived -= Selection_MessageReceived;
    selection.MessageReceived += Selection_MessageReceived;
  }

  /// <summary>Unsubscribes from the MessageReceived event of the "Selection Connector" port.</summary>
  public void DeregisterEvents() {
    var selection = (IMessagePort)Ports[SelectionPortName];
    selection.MessageReceived -= Selection_MessageReceived;
  }

  /// <summary>
  /// Handles an incoming selection message: looks up or computes the quality of the received
  /// selection and stores it in the message's "Quality" entry.
  /// </summary>
  private void Selection_MessageReceived(object sender, EventArgs<IMessage, CancellationToken> e) {
    var cancellationToken = e.Value2;

    // get parameters
    var parametersPort = (IMessagePort)Ports[ParametersPortName];
    var parameters = parametersPort.PrepareMessage();
    parametersPort.SendMessage(parameters, cancellationToken);
    var problemData = (IRegressionProblemData)parameters["ProblemData"];

    var selectionMsg = e.Value;
    var selection = (BinaryVector)selectionMsg["Selection"];

    // if possible return the cached answer, otherwise solve the restricted regression problem
    double quality;
    if (!solutionCache.TryGetValue(selection, out quality)) {
      quality = EvaluateSelection(problemData, selection, cancellationToken);
      // Store a private copy as key: a dictionary key must not change while it is stored,
      // and the vector inside the message is owned by the sender.
      solutionCache[(BinaryVector)selection.Clone()] = quality;
    }
    selectionMsg["Quality"] = new DoubleValue(quality);
  }

  /// <summary>
  /// Sends the problem data restricted to the selected input variables over the
  /// "Regression Connector" port and returns the test normalized mean squared error of the
  /// single regression solution found in the message afterwards.
  /// </summary>
  /// <exception cref="InvalidOperationException">
  /// The regression message does not contain exactly one regression solution.
  /// </exception>
  private double EvaluateSelection(IRegressionProblemData problemData, BinaryVector selection, CancellationToken cancellationToken) {
    // filter allowed variables: keep those whose position is set in the selection.
    // Zip stops at the shorter sequence, so surplus bits or variables are ignored.
    var selectedInputVariables = selection
      .Zip(problemData.AllowedInputVariables, Tuple.Create)
      .Where(t => t.Item1)
      .Select(t => t.Item2)
      .ToList();

    var restrictedProblemData = new RegressionProblemData(problemData.Dataset, selectedInputVariables, problemData.TargetVariable);

    // solve Regression
    var regressionConPort = (IMessagePort)Ports[RegressionPortName];
    var regressionMsg = regressionConPort.PrepareMessage();
    regressionMsg["ProblemData"] = restrictedProblemData;
    regressionConPort.SendMessage(regressionMsg, cancellationToken);

    var solution = regressionMsg.Values.Select(v => v.Value).OfType<IRegressionSolution>().Single();
    return solution.TestNormalizedMeanSquaredError;
  }
}
92 |
/// <summary>
/// Equality comparer that treats two binary vectors as equal when they have the same length
/// and the same value at every position.
/// </summary>
internal class BinaryVectorComparer : IEqualityComparer<BinaryVector> {
  /// <summary>
  /// Returns true if both vectors are the same instance, both are null, or they have equal
  /// length and equal elements; returns false if exactly one of them is null.
  /// </summary>
  public bool Equals(BinaryVector x, BinaryVector y) {
    if (ReferenceEquals(x, y)) {
      return true;
    }
    if (ReferenceEquals(x, null) || ReferenceEquals(y, null)) {
      return false;
    }
    return x.Length == y.Length && x.SequenceEqual(y);
  }

  /// <summary>
  /// Computes a hash code from all elements in order. Vectors that are equal according to
  /// <see cref="Equals(BinaryVector, BinaryVector)"/> yield the same hash code.
  /// </summary>
  /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
  public int GetHashCode(BinaryVector obj) {
    if (ReferenceEquals(obj, null)) {
      throw new ArgumentNullException("obj");
    }
    // Order-sensitive polynomial hash: unlike the number of set bits, it separates vectors
    // that have the same count of set bits at different positions.
    unchecked {
      int hash = 17;
      foreach (bool bit in obj) {
        hash = hash * 31 + (bit ? 1 : 0);
      }
      return hash;
    }
  }
}
104 | }