Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2789_MathNetNumerics-Exploration/HeuristicLab.Algorithms.DataAnalysis.Experimental/ForwardSelection.cs @ 16966

Last change on this file since 16966 was 14998, checked in by gkronber, 7 years ago

#2789 added forward selection algorithm and algorithm to calculate all LR combinations

File size: 7.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Analysis;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis;
34
namespace HeuristicLab.Algorithms.DataAnalysis.Experimental {
  /// <summary>
  /// Forward selection meta-algorithm.
  /// </summary>
  /// <remarks>
  /// Starting from an empty set of inputs, every iteration tentatively adds each remaining allowed
  /// input variable (checked in the problem data and stored as double in the dataset), runs the base
  /// algorithm once per candidate and keeps the candidate whose solution has the lowest training RMSE.
  /// This repeats until <see cref="MaximumInputs"/> variables have been selected or no remaining
  /// candidate produces a solution.
  /// </remarks>
  [Item("Forward Selection", "Meta-algorithm that performs feature selection for a given base algorithm using greedy forward selection.")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 999)]
  [StorableClass]
  public sealed class ForwardsSelectionAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string MaximumInputsParameterName = "Maximum Inputs";
    private const string AlgorithmParameterName = "Algorithm";

    // The problem whose ProblemDataChanged event this instance is currently subscribed to.
    // Deliberately not storable: it is re-established by RegisterEventHandlers() after
    // construction, cloning and deserialization.
    private IRegressionProblem subscribedProblem;

    public IFixedValueParameter<IntValue> MaximumInputsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumInputsParameterName]; }
    }
    /// <summary>
    /// Upper bound for the number of input variables that are selected. It is reset to the number of
    /// allowed inputs whenever the problem or its problem data changes.
    /// </summary>
    public int MaximumInputs {
      get { return MaximumInputsParameter.Value.Value; }
      set { MaximumInputsParameter.Value.Value = value; }
    }

    public IValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>> AlgorithmParameter {
      get { return (IValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>)Parameters[AlgorithmParameterName]; }
    }

    /// <summary>
    /// The base regression algorithm that is run for every candidate input set. It is cloned for each
    /// execution of this meta-algorithm, so the configured instance itself is never modified.
    /// </summary>
    public FixedDataAnalysisAlgorithm<IRegressionProblem> Algorithm {
      get { return AlgorithmParameter.Value; }
      set { AlgorithmParameter.Value = value; }
    }


    [StorableConstructor]
    private ForwardsSelectionAlgorithm(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEventHandlers();
    }

    private ForwardsSelectionAlgorithm(ForwardsSelectionAlgorithm original, Cloner cloner)
      : base(original, cloner) {
      RegisterEventHandlers();
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new ForwardsSelectionAlgorithm(this, cloner);
    }

    public ForwardsSelectionAlgorithm()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumInputsParameterName, "The maximum number of input variables used in the models.", new IntValue(1)));
      Parameters.Add(new ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>(AlgorithmParameterName, "The base algorithm for modeling", new LinearRegression()));

      Problem = new RegressionProblem();
      RegisterEventHandlers();
    }

    /// <summary>
    /// Subscribes to <c>ProblemDataChanged</c> of the current problem. It first detaches from any
    /// previously subscribed problem, so handlers are neither duplicated nor left on a replaced problem.
    /// Safe to call repeatedly and when no problem is set.
    /// </summary>
    private void RegisterEventHandlers() {
      if (subscribedProblem != null)
        subscribedProblem.ProblemDataChanged -= ProblemDataChangedHandler;
      subscribedProblem = Problem;
      if (subscribedProblem != null)
        subscribedProblem.ProblemDataChanged += ProblemDataChangedHandler;
    }

    private void ProblemDataChangedHandler(object sender, EventArgs e) {
      UpdateMaximumInputs();
    }

    protected override void OnProblemChanged() {
      base.OnProblemChanged();
      // The new problem needs its own subscription; the old one must be released.
      RegisterEventHandlers();
      UpdateMaximumInputs();
    }

    /// <summary>
    /// Sets <see cref="MaximumInputs"/> to the number of allowed inputs of the current problem.
    /// Does nothing while no problem or problem data is available.
    /// </summary>
    private void UpdateMaximumInputs() {
      if (Problem == null || Problem.ProblemData == null) return;
      MaximumInputs = GetAllowedInputVariables((IRegressionProblemData)Problem.ProblemData).Count();
    }

    /// <summary>
    /// Returns the checked input variables of <paramref name="problemData"/> whose dataset column is of
    /// type double. The sequence is lazily evaluated.
    /// </summary>
    private static IEnumerable<StringValue> GetAllowedInputVariables(IRegressionProblemData problemData) {
      return problemData.InputVariables.CheckedItems
        .Select(t => t.Value)
        .Where(v => problemData.Dataset.VariableHasType<double>(v.Value));
    }

    protected override void Run(CancellationToken cancellationToken) {
      if (Problem == null)
        throw new InvalidOperationException("A regression problem must be set before running forward selection.");
      if (Algorithm == null)
        throw new InvalidOperationException("A base algorithm must be set before running forward selection.");

      InitResults();

      // Work on copies so that neither the user's problem nor the configured base algorithm is modified.
      var problemClone = (IRegressionProblem)Problem.Clone();
      var problemDataClone = (IRegressionProblemData)problemClone.ProblemData;
      // Materialize the candidates before unchecking everything below.
      var remainingInputVariables = GetAllowedInputVariables(problemDataClone).ToList();
      // Start from an empty input set. Iterate over a snapshot to avoid enumerating while mutating.
      foreach (var variable in problemDataClone.InputVariables.ToList())
        problemDataClone.InputVariables.SetItemCheckedState(variable, false);

      var alg = (FixedDataAnalysisAlgorithm<IRegressionProblem>)Algorithm.Clone();
      alg.Problem = problemClone;

      // Read once so that edits to the parameter during the run do not change the loop bound.
      var maximumInputs = MaximumInputs;
      for (int inputs = 1; inputs <= maximumInputs; inputs++) {
        var bestRMSE = double.MaxValue;
        IRegressionSolution bestSolution = null;
        StringValue bestInput = null;
        foreach (var inputVar in remainingInputVariables) {
          cancellationToken.ThrowIfCancellationRequested();

          // Evaluate the already selected inputs plus this single candidate.
          problemDataClone.InputVariables.SetItemCheckedState(inputVar, true);
          var solution = RunAlg(alg);
          if (solution != null && solution.TrainingRootMeanSquaredError < bestRMSE) {
            bestRMSE = solution.TrainingRootMeanSquaredError;
            bestSolution = solution;
            bestInput = inputVar;
          }
          problemDataClone.InputVariables.SetItemCheckedState(inputVar, false);
        }

        // No candidates left, or none of the base runs produced a solution: nothing more to add.
        if (bestSolution == null) break;

        remainingInputVariables.Remove(bestInput);
        problemDataClone.InputVariables.SetItemCheckedState(bestInput, true);

        bestSolution.Name = inputs.ToString() + " " + bestSolution.Name;
        AddIterationResults(inputs, bestSolution);

        cancellationToken.ThrowIfCancellationRequested();
      }
    }

    /// <summary>
    /// Publishes the best solution of one forward-selection step in the results created by InitResults.
    /// </summary>
    private void AddIterationResults(int inputs, IRegressionSolution bestSolution) {
      Results["Current solution"].Value = bestSolution;
      ((ItemList<IRegressionSolution>)Results["All Solutions"].Value).Add(bestSolution);
      ((IntValue)Results["Number of variables"].Value).Value = inputs;
      var rmseTable = (DataTable)Results["RMSE table"].Value;
      rmseTable.Rows["RMSE (train)"].Values.Add(bestSolution.TrainingRootMeanSquaredError);
      rmseTable.Rows["RMSE (test)"].Values.Add(bestSolution.TestRootMeanSquaredError);
    }

    /// <summary>
    /// Creates the result entries that Run updates after every selection step.
    /// </summary>
    private void InitResults() {
      Results.Add(new Result("Current solution", typeof(IRegressionSolution)));
      Results.Add(new Result("All Solutions", new ItemList<IRegressionSolution>()));
      Results.Add(new Result("Number of variables", new IntValue(0)));
      var rmseTable = new DataTable("RMSE table");
      var trainingRmseRow = new DataRow("RMSE (train)");
      var testRmseRow = new DataRow("RMSE (test)");
      rmseTable.Rows.Add(trainingRmseRow);
      rmseTable.Rows.Add(testRmseRow);
      Results.Add(new Result("RMSE table", rmseTable));
    }

    /// <summary>
    /// Prepares and starts <paramref name="alg"/>, then blocks until it raises either
    /// <c>Stopped</c> or <c>ExceptionOccurred</c>.
    /// </summary>
    /// <returns>
    /// The first <see cref="IRegressionSolution"/> found in the algorithm's results, or null if there is none.
    /// Exceptions raised by the base algorithm are not rethrown; a failed run typically yields null.
    /// </returns>
    private IRegressionSolution RunAlg(FixedDataAnalysisAlgorithm<IRegressionProblem> alg) {
      using (var wh = new AutoResetEvent(false)) {
        EventHandler<EventArgs<Exception>> setWhForException = (sender, args) => { wh.Set(); };
        EventHandler setWh = (sender, args) => { wh.Set(); };
        try {
          alg.ExceptionOccurred += setWhForException;
          alg.Stopped += setWh;
          alg.Prepare(true);
          alg.Start();

          wh.WaitOne();

          return alg.Results.Select(r => r.Value).OfType<IRegressionSolution>().FirstOrDefault();
        } finally {
          // Detach before the wait handle is disposed so that late events cannot signal a disposed handle.
          alg.ExceptionOccurred -= setWhForException;
          alg.Stopped -= setWh;
        }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.