#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Core.Networks;
using HeuristicLab.Data;
using HeuristicLab.Encodings.BinaryVectorEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Networks.IntegratedOptimization.MachineLearning {
  /// <summary>
  /// Orchestrator that evaluates feature subsets. A binary vector received on the feature selection
  /// evaluation port selects which of the currently checked input variables of
  /// <see cref="RegressionProblemData"/> are used. A regression problem restricted to those variables is
  /// prepared and started through the regression orchestration port. The test mean absolute error of the
  /// returned regression solution is written back as the vector's "Quality".
  /// </summary>
  [StorableClass]
  public sealed class FeatureSelectionOrchestrator : OrchestratorNode {
    private const string REGRESSION_ORCHESTRATION_PORT_NAME = "Regression algorithm orchestration port";
    private const string FEATURE_SELECTION_EVALUATION_PORT_NAME = "Feature selection evaluation port";
    private const string REGRESSION_PROBLEM_PARAMETER_NAME = "Regression Problem Data";

    private const string BEST_SOLUTION_VECTOR_RESULT_NAME = "Best Solution Vector";
    private const string BEST_SOLUTION_RESULT_NAME = "Best Symbolic Solution";

    /// <summary>Port used to prepare and start the regression algorithm.</summary>
    public IMessagePort RegressionOrchestrationPort {
      get { return (IMessagePort)Ports[REGRESSION_ORCHESTRATION_PORT_NAME]; }
    }

    /// <summary>Port on which feature subsets (binary vectors) arrive for evaluation.</summary>
    public IMessagePort FeatureSelectionEvaluationPort {
      get { return (IMessagePort)Ports[FEATURE_SELECTION_EVALUATION_PORT_NAME]; }
    }

    /// <summary>Parameter holding the regression problem data whose input variables are selected.</summary>
    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[REGRESSION_PROBLEM_PARAMETER_NAME]; }
    }

    /// <summary>Convenience accessor for the value of <see cref="ProblemDataParameter"/>.</summary>
    public IRegressionProblemData RegressionProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }

    [StorableConstructor]
    private FeatureSelectionOrchestrator(bool deserializing) : base(deserializing) { }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Event subscriptions are not persisted, so they are re-established after loading.
      RegisterPortEvents();
    }

    private FeatureSelectionOrchestrator(FeatureSelectionOrchestrator original, Cloner cloner)
      : base(original, cloner) {
      network = cloner.Clone(original.network);
      RegisterPortEvents();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FeatureSelectionOrchestrator(this, cloner);
    }

    //TODO remove network reference;
    //TODO move regression problem to network
    [Storable]
    private readonly FeatureSelectionNetwork network;

    /// <summary>
    /// Creates the orchestrator with its evaluation and orchestration ports and an empty regression problem data parameter.
    /// </summary>
    /// <param name="network">Network whose Pause/Prepare/Start/Stop are forwarded by this orchestrator.</param>
    public FeatureSelectionOrchestrator(FeatureSelectionNetwork network)
      : base() {
      var featureSelectionPort = CreateEvaluationPort(FEATURE_SELECTION_EVALUATION_PORT_NAME, "BinaryVector", "Quality");
      Ports.Add(featureSelectionPort);

      var regressionPort = CreateOrchestrationPort(REGRESSION_ORCHESTRATION_PORT_NAME);
      Ports.Add(regressionPort);

      this.network = network;

      Parameters.Add(new ValueParameter<IRegressionProblemData>(REGRESSION_PROBLEM_PARAMETER_NAME, "", new RegressionProblemData()));
      RegisterPortEvents();
    }

    private void RegisterPortEvents() {
      FeatureSelectionEvaluationPort.MessageReceived += (s, e) => FeatureSelectionEvaluationPort_MessageReceived(e.Value, e.Value2);
    }

    /// <summary>
    /// Evaluates one feature subset. The problem data is cloned, only the selected variables are kept
    /// checked, and the regression algorithm is prepared and started via the orchestration port.
    /// The test mean absolute error of the first regression solution in the returned results is stored
    /// as "Quality" in <paramref name="evaluationMessage"/>.
    /// </summary>
    /// <param name="evaluationMessage">Message carrying the "BinaryVector" to evaluate; receives "Quality".</param>
    /// <param name="token">Cancellation token forwarded to the orchestration port.</param>
    /// <exception cref="System.InvalidOperationException">No regression solution was found in the results.</exception>
    private void FeatureSelectionEvaluationPort_MessageReceived(IMessage evaluationMessage, CancellationToken token) {
      // Work on a private copy so the parameter's problem data is never modified by an evaluation.
      var problemData = (IRegressionProblemData)RegressionProblemData.Clone();
      var binaryVector = (BinaryVector)evaluationMessage["BinaryVector"];

      // Materialize the candidates once: the loop below changes the checked state of the very list that
      // CheckedItems is computed from, so the sequence must not be re-evaluated while it is modified.
      var candidateVariables = problemData.InputVariables.CheckedItems.Select(item => item.Value).ToArray();
      binaryVector.ElementNames = candidateVariables.Select(variable => variable.Value);

      // Bit i decides whether the i-th originally checked input variable stays checked.
      var selections = candidateVariables.Zip(binaryVector, (variable, selected) => new { Variable = variable, Selected = selected });
      foreach (var selection in selections)
        problemData.InputVariables.SetItemCheckedState(selection.Variable, selection.Selected);

      var prepareMessage = RegressionOrchestrationPort.PrepareMessage();
      prepareMessage["Problem"] = new RegressionProblem() { ProblemData = problemData };
      prepareMessage["OrchestrationMessage"] = new EnumValue<OrchestrationMessage>(OrchestrationMessage.Prepare);
      RegressionOrchestrationPort.SendMessage(prepareMessage, token);

      var startMessage = RegressionOrchestrationPort.PrepareMessage();
      startMessage["OrchestrationMessage"] = new EnumValue<OrchestrationMessage>(OrchestrationMessage.Start);
      RegressionOrchestrationPort.SendMessage(startMessage, token);

      // NOTE(review): results are read from the start message after SendMessage returns; this relies on the
      // orchestration port filling "Results" in synchronously — confirm against the port contract.
      var results = (ResultCollection)startMessage["Results"];
      var regressionSolution = results.Select(r => r.Value).OfType<IRegressionSolution>().FirstOrDefault();
      if (regressionSolution == null)
        throw new System.InvalidOperationException("The regression algorithm did not produce a regression solution.");

      UpdateResults(binaryVector, regressionSolution);

      double quality = regressionSolution.TestMeanAbsoluteError;
      evaluationMessage["Quality"] = new DoubleValue(quality);
    }

    /// <summary>
    /// Records <paramref name="solution"/> and a copy of <paramref name="binaryVector"/> as the best results,
    /// but only if the solution's test mean absolute error is lower than that of the stored best solution.
    /// </summary>
    private void UpdateResults(BinaryVector binaryVector, IRegressionSolution solution) {
      if (!Results.ContainsKey(BEST_SOLUTION_VECTOR_RESULT_NAME)) {
        Results.Add(new Result(BEST_SOLUTION_VECTOR_RESULT_NAME, typeof(BinaryVector)));
        Results.Add(new Result(BEST_SOLUTION_RESULT_NAME, typeof(IRegressionSolution)));
      }

      var incumbent = Results[BEST_SOLUTION_RESULT_NAME].Value as IRegressionSolution;
      if (incumbent != null && !IsImprovement(solution.TestMeanAbsoluteError, incumbent.TestMeanAbsoluteError))
        return;

      // Store a copy: the vector instance belongs to the evaluation message and may be changed by its owner.
      Results[BEST_SOLUTION_VECTOR_RESULT_NAME].Value = (BinaryVector)binaryVector.Clone();
      Results[BEST_SOLUTION_RESULT_NAME].Value = solution;
      // TODO: also report the coefficients of the best solution's model per selected variable ("Best Solution Variables").
    }

    /// <summary>
    /// Returns true when <paramref name="candidateError"/> is strictly lower than <paramref name="incumbentError"/>.
    /// A NaN candidate is never an improvement; a NaN incumbent is always replaced by a non-NaN candidate.
    /// </summary>
    private static bool IsImprovement(double candidateError, double incumbentError) {
      if (double.IsNaN(candidateError)) return false;
      return double.IsNaN(incumbentError) || candidateError < incumbentError;
    }

    /// <summary>Clears the stored best solution so a new run does not compete against a previous run's incumbent.</summary>
    private void ResetBestResults() {
      if (Results.ContainsKey(BEST_SOLUTION_VECTOR_RESULT_NAME))
        Results[BEST_SOLUTION_VECTOR_RESULT_NAME].Value = null;
      if (Results.ContainsKey(BEST_SOLUTION_RESULT_NAME))
        Results[BEST_SOLUTION_RESULT_NAME].Value = null;
    }

    //TODO Remove methods
    public override void Pause() { network.Pause(); }

    public override void Prepare(bool clearRuns = false) {
      ResetBestResults();
      network.Prepare(clearRuns);
    }

    public override void Start() { network.Start(); }

    public override void Stop() { network.Stop(); }
  }
}