Free cookie consent management tool by TermsFeed Policy Generator

source: branches/OptimizationNetworks/HeuristicLab.Networks/3.3/FeatureSelection Network/FeatureSelectionConnector.cs @ 15689

Last change on this file since 15689 was 12327, checked in by gkronber, 10 years ago

#2205: removed hard coded string for the retrieval of the regression solution to allow using other regression algorithms (e.g. RF)

File size: 4.4 KB
Line 
1using System;
2using System.Collections;
3using System.Collections.Generic;
4using System.Linq;
5using System.Threading;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Core.Networks;
9using HeuristicLab.Data;
10using HeuristicLab.Encodings.BinaryVectorEncoding;
11using HeuristicLab.Problems.DataAnalysis;
12
13namespace HeuristicLab.Networks.FeatureSelection_Network {
[Item("FeatureSelectionConnector", "")]
public sealed class FeatureSelectionConnector : Node {
  // Port names used by Initialize, (De)RegisterEvents and the message handler.
  private const string ParametersPortName = "Parameters";
  private const string SelectionPortName = "Selection Connector";
  private const string RegressionPortName = "Regression Connector";

  private FeatureSelectionConnector(FeatureSelectionConnector original, Cloner cloner)
    : base(original, cloner) {
    // The base cloning constructor copies the ports, but the clone's selection port has no
    // handler attached yet; without this the cloned node would silently ignore selections.
    RegisterEvents();
  }

  /// <summary>
  /// Creates the connector and, if no ports exist yet, creates its ports and wires its events.
  /// </summary>
  public FeatureSelectionConnector()
    : base() {
    if (Ports.Count == 0) {
      Initialize();
    }
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new FeatureSelectionConnector(this, cloner);
  }

  /// <summary>
  /// Creates the three message ports of this node and subscribes to the selection port:
  /// "Parameters" (receives the full problem data), "Selection Connector" (receives a feature
  /// selection and answers with its quality) and "Regression Connector" (sends reduced problem
  /// data to a regression algorithm and receives the resulting solution).
  /// </summary>
  public void Initialize() {
    var parameters = new MessagePort(ParametersPortName);
    Ports.Add(parameters);
    parameters.Parameters.Add(new PortParameter<IRegressionProblemData>("ProblemData") { Type = PortParameterType.Input });

    var selectionPort = new MessagePort(SelectionPortName);
    Ports.Add(selectionPort);
    selectionPort.Parameters.Add(new PortParameter<BinaryVector>("Selection") { Type = PortParameterType.Input });
    selectionPort.Parameters.Add(new PortParameter<DoubleValue>("Quality") { Type = PortParameterType.Input | PortParameterType.Output });

    var regressionPort = new MessagePort(RegressionPortName);
    Ports.Add(regressionPort);
    regressionPort.Parameters.Add(new PortParameter<IRegressionProblemData>("ProblemData") { Type = PortParameterType.Output });
    regressionPort.Parameters.Add(new PortParameter<IRegressionSolution>("Linear regression solution") { Type = PortParameterType.Input });
    RegisterEvents();
  }

  /// <summary>Subscribes the selection handler to the "Selection Connector" port.</summary>
  public void RegisterEvents() {
    var selection = (IMessagePort)Ports[SelectionPortName];
    selection.MessageReceived += Selection_MessageReceived;
  }

  /// <summary>Unsubscribes the selection handler from the "Selection Connector" port.</summary>
  public void DeregisterEvents() {
    var selection = (IMessagePort)Ports[SelectionPortName];
    selection.MessageReceived -= Selection_MessageReceived;
  }

  // Maps an evaluated feature selection to its test NMSE.
  // NOTE(review): the key is only the selection, not the problem data; cached qualities become
  // stale if the "ProblemData" input changes between messages — confirm that cannot happen.
  // NOTE(review): Dictionary is not thread-safe; if selection messages can arrive concurrently
  // this needs synchronization — confirm the network's dispatch model.
  private Dictionary<BinaryVector, double> solutionCache = new Dictionary<BinaryVector, double>(new BinaryVectorComparer());

  /// <summary>
  /// Answers a selection message: writes the test NMSE of a regression model trained on the
  /// selected input variables into the message's "Quality" entry, using the cache when possible.
  /// </summary>
  private void Selection_MessageReceived(object sender, EventArgs<IMessage, CancellationToken> e) {
    var cancellationToken = e.Value2;

    // get parameters
    var parametersPort = (IMessagePort)Ports[ParametersPortName];
    var parameters = parametersPort.PrepareMessage();
    parametersPort.SendMessage(parameters, cancellationToken);
    var problemData = (IRegressionProblemData)parameters["ProblemData"];

    var selectionMsg = e.Value;
    var selection = (BinaryVector)selectionMsg["Selection"];

    double quality;
    if (!solutionCache.TryGetValue(selection, out quality)) {
      quality = EvaluateSelection(problemData, selection, cancellationToken);
      // Store a private copy: the incoming vector belongs to the sender and may be mutated later,
      // which would change its hash/equality while it sits in the dictionary.
      solutionCache.Add((BinaryVector)selection.Clone(), quality);
    }
    selectionMsg["Quality"] = new DoubleValue(quality);
  }

  /// <summary>
  /// Builds problem data restricted to the selected input variables (keeping the original
  /// training/test partitions), sends it through the regression port and returns the test NMSE
  /// of the solution that comes back.
  /// </summary>
  private double EvaluateSelection(IRegressionProblemData problemData, BinaryVector selection, CancellationToken cancellationToken) {
    // Bit i of the selection decides whether the i-th allowed input variable is used.
    var selectedInputVariables = selection
      .Zip(problemData.AllowedInputVariables, (isSelected, variable) => new { isSelected, variable })
      .Where(t => t.isSelected)
      .Select(t => t.variable)
      .ToList();

    var reducedProblemData = new RegressionProblemData(problemData.Dataset, selectedInputVariables, problemData.TargetVariable);
    // The constructor uses default partitions; copy the original ones so the returned test
    // error is measured on the same rows the caller configured.
    reducedProblemData.TrainingPartition.Start = problemData.TrainingPartition.Start;
    reducedProblemData.TrainingPartition.End = problemData.TrainingPartition.End;
    reducedProblemData.TestPartition.Start = problemData.TestPartition.Start;
    reducedProblemData.TestPartition.End = problemData.TestPartition.End;

    // solve Regression
    var regressionPort = (IMessagePort)Ports[RegressionPortName];
    var regressionMsg = regressionPort.PrepareMessage();
    regressionMsg["ProblemData"] = reducedProblemData;
    regressionPort.SendMessage(regressionMsg, cancellationToken);

    // Look the solution up by type rather than by parameter name so that different regression
    // algorithms (e.g. linear regression or RF) can be connected.
    var solution = regressionMsg.Values.Select(v => v.Value).OfType<IRegressionSolution>().Single();
    return solution.TestNormalizedMeanSquaredError;
  }
}
92
/// <summary>
/// Value-based equality for <see cref="BinaryVector"/>: two vectors are equal when they have the
/// same length and the same bit at every position.
/// </summary>
internal class BinaryVectorComparer : IEqualityComparer<BinaryVector> {
  public bool Equals(BinaryVector x, BinaryVector y) {
    if (ReferenceEquals(x, y)) {
      return true;
    }
    if (ReferenceEquals(x, null) || ReferenceEquals(y, null)) {
      return false;
    }
    if (x.Length != y.Length) {
      return false;
    }
    // Index loop instead of Zip/Tuple.Create: no per-element allocation.
    for (int i = 0; i < x.Length; i++) {
      if (x[i] != y[i]) {
        return false;
      }
    }
    return true;
  }

  public int GetHashCode(BinaryVector obj) {
    if (ReferenceEquals(obj, null)) {
      return 0;
    }
    // Positional hash: the previous popcount-only hash gave every selection with the same
    // number of set bits the same bucket, degrading dictionary lookups toward linear scans.
    // Equal vectors (same length, same bits) still always produce the same hash.
    unchecked {
      int hash = 17;
      for (int i = 0; i < obj.Length; i++) {
        hash = hash * 31 + (obj[i] ? 1 : 0);
      }
      return hash;
    }
  }
}
104}
Note: See TracBrowser for help on using the repository browser.