1 | using System;
|
---|
2 | using System.Collections;
|
---|
3 | using System.Collections.Generic;
|
---|
4 | using System.Linq;
|
---|
5 | using System.Threading;
|
---|
6 | using HeuristicLab.Common;
|
---|
7 | using HeuristicLab.Core;
|
---|
8 | using HeuristicLab.Core.Networks;
|
---|
9 | using HeuristicLab.Data;
|
---|
10 | using HeuristicLab.Encodings.BinaryVectorEncoding;
|
---|
11 | using HeuristicLab.Problems.DataAnalysis;
|
---|
12 |
|
---|
13 | namespace HeuristicLab.Networks.FeatureSelection_Network {
|
---|
/// <summary>
/// Network node that evaluates feature selections. For every message on the "Selection Connector"
/// port it takes the "Selection" bit vector and builds a regression problem restricted to the
/// selected input variables. It requests a solution through the "Regression Connector" port and
/// writes the solution's test NMSE back into the message as "Quality". Qualities are cached per
/// selection.
/// </summary>
[Item("FeatureSelectionConnector", "")]
public sealed class FeatureSelectionConnector : Node {
  // Quality per already evaluated selection; vectors are compared by value via BinaryVectorComparer.
  // NOTE(review): entries are keyed only by the selection, so they are not invalidated if the problem
  // data delivered on the "Parameters" port changes between messages - confirm this is intended.
  private readonly Dictionary<BinaryVector, double> solutionCache = new Dictionary<BinaryVector, double>(new BinaryVectorComparer());

  private FeatureSelectionConnector(FeatureSelectionConnector original, Cloner cloner)
    : base(original, cloner) {
    // The message handler is attached to a specific port instance, so the clone's ports need their
    // own subscription (assumes the base cloning constructor clones Ports - confirm).
    RegisterEvents();
  }

  /// <summary>
  /// Creates the node and, if no ports exist yet, adds its ports via <see cref="Initialize"/>.
  /// </summary>
  public FeatureSelectionConnector()
    : base() {
    if (Ports.Count == 0)
      Initialize();
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new FeatureSelectionConnector(this, cloner);
  }

  /// <summary>
  /// Adds the three message ports used by this node and subscribes to incoming selections.
  /// "Parameters" receives the problem data. "Selection Connector" receives a selection and returns
  /// its quality. "Regression Connector" sends a reduced problem and receives a regression solution.
  /// </summary>
  public void Initialize() {
    var parameters = new MessagePort("Parameters");
    Ports.Add(parameters);
    parameters.Parameters.Add(new PortParameter<IRegressionProblemData>("ProblemData") { Type = PortParameterType.Input });

    var selectionPort = new MessagePort("Selection Connector");
    Ports.Add(selectionPort);
    selectionPort.Parameters.Add(new PortParameter<BinaryVector>("Selection") { Type = PortParameterType.Input });
    selectionPort.Parameters.Add(new PortParameter<DoubleValue>("Quality") { Type = PortParameterType.Input | PortParameterType.Output });

    var regressionPort = new MessagePort("Regression Connector");
    Ports.Add(regressionPort);
    regressionPort.Parameters.Add(new PortParameter<IRegressionProblemData>("ProblemData") { Type = PortParameterType.Output });
    regressionPort.Parameters.Add(new PortParameter<IRegressionSolution>("Linear regression solution") { Type = PortParameterType.Input });
    RegisterEvents();
  }

  /// <summary>
  /// Subscribes the selection handler to the "Selection Connector" port.
  /// NOTE(review): no deserialization hook is visible here; if this node is persisted, this likely
  /// also has to run after deserialization - confirm.
  /// </summary>
  public void RegisterEvents() {
    var selection = (IMessagePort)Ports["Selection Connector"];
    selection.MessageReceived += Selection_MessageReceived;
  }

  /// <summary>Unsubscribes the selection handler from the "Selection Connector" port.</summary>
  public void DeregisterEvents() {
    var selection = (IMessagePort)Ports["Selection Connector"];
    selection.MessageReceived -= Selection_MessageReceived;
  }

  // Answers a selection message with the (possibly cached) quality of the selected feature subset.
  private void Selection_MessageReceived(object sender, EventArgs<IMessage, CancellationToken> e) {
    var cancellationToken = e.Value2;

    // Fetch the current problem data from whatever is connected to the "Parameters" port.
    var parametersPort = (IMessagePort)Ports["Parameters"];
    var parameters = parametersPort.PrepareMessage();
    parametersPort.SendMessage(parameters, cancellationToken);
    var problemData = (IRegressionProblemData)parameters["ProblemData"];

    var selectionMsg = e.Value;
    var selection = (BinaryVector)selectionMsg["Selection"];

    double quality;
    if (!solutionCache.TryGetValue(selection, out quality)) {
      quality = EvaluateSelection(problemData, selection, cancellationToken);
      // Store a copy as the key. The received vector may later be modified in place by its owner,
      // which would change its hash while it sits in the dictionary and corrupt the cache.
      solutionCache[(BinaryVector)selection.Clone()] = quality;
    }
    selectionMsg["Quality"] = new DoubleValue(quality);
  }

  // Solves the regression problem restricted to the selected inputs via the "Regression Connector"
  // port and returns the solution's test normalized mean squared error.
  private double EvaluateSelection(IRegressionProblemData problemData, BinaryVector selection, CancellationToken cancellationToken) {
    if (problemData == null)
      throw new InvalidOperationException("No problem data was received on the \"Parameters\" port.");

    // Bit i of the selection toggles the i-th allowed input variable; Zip stops at the shorter sequence.
    var selectedInputVariables = selection
      .Zip(problemData.AllowedInputVariables, (isSelected, variable) => new { IsSelected = isSelected, Variable = variable })
      .Where(pair => pair.IsSelected)
      .Select(pair => pair.Variable)
      .ToList();

    var reducedProblemData = new RegressionProblemData(problemData.Dataset, selectedInputVariables, problemData.TargetVariable);

    var regressionConPort = (IMessagePort)Ports["Regression Connector"];
    var regressionMsg = regressionConPort.PrepareMessage();
    regressionMsg["ProblemData"] = reducedProblemData;
    regressionConPort.SendMessage(regressionMsg, cancellationToken);
    // The connected node is expected to put exactly one regression solution into the reply.
    var solution = regressionMsg.Values.Select(v => v.Value).OfType<IRegressionSolution>().Single();

    return solution.TestNormalizedMeanSquaredError;
  }
}
|
---|
92 |
|
---|
/// <summary>
/// Value-based equality for <see cref="BinaryVector"/>. Two vectors are equal when they have the
/// same length and the same bit at every position, so they can serve as dictionary keys.
/// </summary>
internal class BinaryVectorComparer : IEqualityComparer<BinaryVector> {
  /// <summary>Returns true if both vectors are null, or both have equal length and identical bits.</summary>
  public bool Equals(BinaryVector x, BinaryVector y) {
    if (ReferenceEquals(x, y)) return true;
    if (ReferenceEquals(x, null) || ReferenceEquals(y, null)) return false;
    if (x.Length != y.Length) return false;

    for (int i = 0; i < x.Length; i++) {
      if (x[i] != y[i]) return false;
    }
    return true;
  }

  /// <summary>
  /// Position-sensitive hash over all bits (0 for null). Vectors equal under <see cref="Equals"/>
  /// produce equal hashes. Unlike a bit count, this hash also separates vectors that have the same
  /// number of set bits in different positions.
  /// </summary>
  public int GetHashCode(BinaryVector obj) {
    if (ReferenceEquals(obj, null)) return 0;

    unchecked {
      int hash = 17;
      for (int i = 0; i < obj.Length; i++) {
        hash = hash * 31 + (obj[i] ? 1 : 0);
      }
      return hash;
    }
  }
}
|
---|
104 | }
|
---|