#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.DataAnalysis.Experimental {
  /// <summary>
  /// Forward selection meta-algorithm.
  /// Starting from an empty input set, each round runs the configured base algorithm once per
  /// remaining candidate variable and permanently adds the variable whose model has the lowest
  /// training RMSE. One solution per round is written to the results.
  /// </summary>
  [Item("Forward Selection", "Meta-algorithm that performs feature selection for a given base algorithm using greedy forward selection.")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
  [StorableClass]
  public sealed class ForwardsSelectionAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string MaximumInputsParameterName = "Maximum Inputs";
    private const string AlgorithmParameterName = "Algorithm";

    private const string CurrentSolutionResultName = "Current solution";
    private const string AllSolutionsResultName = "All Solutions";
    private const string NumberOfVariablesResultName = "Number of variables";
    private const string RmseTableResultName = "RMSE table";
    private const string TrainingRmseRowName = "RMSE (train)";
    private const string TestRmseRowName = "RMSE (test)";

    // The problem instance whose ProblemDataChanged event we are currently subscribed to.
    // Tracked so the handler can be detached when the problem is replaced.
    // Deliberately not marked [Storable]; it is re-established by RegisterEventHandlers().
    private IRegressionProblem subscribedProblem;

    /// <summary>Parameter holding the upper bound on the number of selected input variables.</summary>
    public IFixedValueParameter<IntValue> MaximumInputsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumInputsParameterName]; }
    }

    /// <summary>Maximum number of input variables (i.e. selection rounds) to perform.</summary>
    public int MaximumInputs {
      get { return MaximumInputsParameter.Value.Value; }
      set { MaximumInputsParameter.Value.Value = value; }
    }

    /// <summary>Parameter holding the base algorithm that is run for every candidate input set.</summary>
    public IValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>> AlgorithmParameter {
      get { return (IValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>)Parameters[AlgorithmParameterName]; }
    }

    /// <summary>The base algorithm that is run for every candidate input set.</summary>
    public FixedDataAnalysisAlgorithm<IRegressionProblem> Algorithm {
      get { return AlgorithmParameter.Value; }
      set { AlgorithmParameter.Value = value; }
    }

    [StorableConstructor]
    private ForwardsSelectionAlgorithm(bool deserializing) : base(deserializing) { }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEventHandlers();
    }

    private ForwardsSelectionAlgorithm(ForwardsSelectionAlgorithm original, Cloner cloner)
      : base(original, cloner) {
      RegisterEventHandlers();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ForwardsSelectionAlgorithm(this, cloner);
    }

    /// <summary>
    /// Creates the algorithm with a linear regression base algorithm and a fresh regression problem.
    /// </summary>
    public ForwardsSelectionAlgorithm()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumInputsParameterName, "The maximum number of input variables used in the models.", new IntValue(1)));
      Parameters.Add(new ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>(AlgorithmParameterName, "The base algorithm for modeling", new LinearRegression()));
      Problem = new RegressionProblem();
      RegisterEventHandlers();
    }

    /// <summary>
    /// Subscribes to <c>ProblemDataChanged</c> of the current problem, detaching from any
    /// previously subscribed problem first. Safe to call repeatedly and when no problem is set.
    /// </summary>
    private void RegisterEventHandlers() {
      if (subscribedProblem != null) {
        subscribedProblem.ProblemDataChanged -= Problem_ProblemDataChanged;
        subscribedProblem = null;
      }
      if (Problem != null) {
        Problem.ProblemDataChanged += Problem_ProblemDataChanged;
        subscribedProblem = Problem;
      }
    }

    private void Problem_ProblemDataChanged(object sender, EventArgs e) {
      UpdateMaximumInputs();
    }

    protected override void OnProblemChanged() {
      base.OnProblemChanged();
      // The handler was attached to the previous problem instance; move it to the new one,
      // otherwise later data changes on the new problem would go unnoticed.
      RegisterEventHandlers();
      UpdateMaximumInputs();
    }

    /// <summary>
    /// Resets <see cref="MaximumInputs"/> to the number of checked double-typed input variables
    /// of the current problem data. Does nothing while no problem or problem data is available.
    /// </summary>
    private void UpdateMaximumInputs() {
      if (Problem == null || Problem.ProblemData == null) return;
      MaximumInputs = GetCheckedDoubleInputs(Problem.ProblemData).Count;
    }

    /// <summary>
    /// Returns the checked input variables of <paramref name="problemData"/> that the dataset
    /// reports as double-typed.
    /// </summary>
    private static List<StringValue> GetCheckedDoubleInputs(IRegressionProblemData problemData) {
      return problemData.InputVariables.CheckedItems
        .Select(t => t.Value)
        .Where(v => problemData.Dataset.VariableHasType<double>(v.Value))
        .ToList();
    }

    /// <summary>
    /// Performs the greedy forward selection on a clone of the problem.
    /// </summary>
    /// <param name="cancellationToken">Checked before every base-algorithm run and after every round.</param>
    /// <exception cref="OperationCanceledException">Cancellation was requested.</exception>
    /// <exception cref="InvalidOperationException">
    /// In some round the base algorithm produced no usable solution for any remaining candidate.
    /// </exception>
    protected override void Run(CancellationToken cancellationToken) {
      InitResults();

      // Work on a clone so the user's variable selection is left untouched.
      var problemClone = (IRegressionProblem)Problem.Clone();
      var problemDataClone = (IRegressionProblemData)problemClone.ProblemData;

      // Candidates must be collected before everything gets unchecked below.
      var allowedInputVariables = GetCheckedDoubleInputs(problemDataClone);
      foreach (var variable in problemDataClone.InputVariables)
        problemDataClone.InputVariables.SetItemCheckedState(variable, false);

      var alg = Algorithm;
      alg.Problem = problemClone;

      var allSolutions = (ItemList<IRegressionSolution>)Results[AllSolutionsResultName].Value;
      var numberOfVariables = (IntValue)Results[NumberOfVariablesResultName].Value;
      var rmseTable = (DataTable)Results[RmseTableResultName].Value;
      var trainingRmseRow = rmseTable.Rows[TrainingRmseRowName];
      var testRmseRow = rmseTable.Rows[TestRmseRowName];

      // Stop early when the candidates are exhausted, even if MaximumInputs asks for more rounds.
      for (int inputs = 1; inputs <= MaximumInputs && allowedInputVariables.Count > 0; inputs++) {
        var bestRMSE = double.MaxValue;
        IRegressionSolution bestSolution = null;
        StringValue bestInput = null;

        foreach (var inputVar in allowedInputVariables) {
          cancellationToken.ThrowIfCancellationRequested();

          // Temporarily add the candidate on top of the variables selected so far.
          problemDataClone.InputVariables.SetItemCheckedState(inputVar, true);
          var solution = RunAlg(alg);
          if (solution != null && solution.TrainingRootMeanSquaredError < bestRMSE) {
            bestRMSE = solution.TrainingRootMeanSquaredError;
            bestSolution = solution;
            bestInput = inputVar;
          }
          problemDataClone.InputVariables.SetItemCheckedState(inputVar, false);
        }

        if (bestSolution == null) {
          // Every run failed or yielded a non-comparable (e.g. NaN) error; there is nothing to select.
          throw new InvalidOperationException(
            "Forward selection could not proceed: the base algorithm did not produce a valid solution for any of the remaining input variables.");
        }

        // Make the winner a permanent member of the input set.
        allowedInputVariables.Remove(bestInput);
        problemDataClone.InputVariables.SetItemCheckedState(bestInput, true);

        bestSolution.Name = inputs.ToString() + " " + bestSolution.Name;

        Results[CurrentSolutionResultName].Value = bestSolution;
        allSolutions.Add(bestSolution);
        numberOfVariables.Value = inputs;
        trainingRmseRow.Values.Add(bestSolution.TrainingRootMeanSquaredError);
        testRmseRow.Values.Add(bestSolution.TestRootMeanSquaredError);

        cancellationToken.ThrowIfCancellationRequested();
      }
    }

    /// <summary>Creates the result entries that <see cref="Run"/> fills in.</summary>
    private void InitResults() {
      Results.Add(new Result(CurrentSolutionResultName, typeof(IRegressionSolution)));
      Results.Add(new Result(AllSolutionsResultName, new ItemList<IRegressionSolution>()));
      Results.Add(new Result(NumberOfVariablesResultName, new IntValue(0)));

      var rmseTable = new DataTable(RmseTableResultName);
      var trainingRmseRow = new DataRow(TrainingRmseRowName);
      var testRmseRow = new DataRow(TestRmseRowName);

      rmseTable.Rows.Add(trainingRmseRow);
      rmseTable.Rows.Add(testRmseRow);
      Results.Add(new Result(RmseTableResultName, rmseTable));
    }

    /// <summary>
    /// Prepares and starts <paramref name="alg"/>, then blocks until it signals either
    /// <c>Stopped</c> or <c>ExceptionOccurred</c>.
    /// </summary>
    /// <returns>
    /// The first regression solution found in the algorithm's results, or <c>null</c> if there is
    /// none. An exception reported by the algorithm is not rethrown here; callers must handle
    /// the <c>null</c> case.
    /// </returns>
    private IRegressionSolution RunAlg(FixedDataAnalysisAlgorithm<IRegressionProblem> alg) {
      using (var wh = new AutoResetEvent(false)) {
        EventHandler<EventArgs<Exception>> setWhForException = (sender, args) => { wh.Set(); };
        EventHandler setWh = (sender, args) => { wh.Set(); };
        try {
          alg.ExceptionOccurred += setWhForException;
          alg.Stopped += setWh;
          alg.Prepare(true);
          alg.Start();
          wh.WaitOne();
          return alg.Results.Select(r => r.Value).OfType<IRegressionSolution>().FirstOrDefault();
        } finally {
          // Always detach so the (reused) algorithm does not keep references to the disposed handle.
          alg.ExceptionOccurred -= setWhForException;
          alg.Stopped -= setWh;
        }
      }
    }
  }
}