Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2990_VariableImpactBasedFeatureSelection/HeuristicLab.Algorithms.DataAnalysis/3.4/FeatureSelection/VariableImpactBasedFeatureSelectionAlgorithm.cs @ 16705

Last change on this file since 16705 was 16705, checked in by pfleck, 5 years ago

#2990: Implemented first version with RF and percentage-based feature elimination.

File size: 7.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Algorithms.DataAnalysis;
27using HeuristicLab.Analysis;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Optimization;
32using HeuristicLab.Parameters;
33using HEAL.Attic;
34
namespace HeuristicLab.Problems.DataAnalysis.FeatureSelection {
  /// <summary>
  /// Backward feature elimination driven by variable impacts: a regression algorithm is run repeatedly,
  /// the input variables are ranked by their calculated impact, and only the highest ranked share
  /// is kept for the next run. The loop ends once no input variable is left.
  /// </summary>
  /// <remarks>
  /// The run publishes four results: "Algorithm" (the cloned algorithm instance), "VariableImpacts"
  /// (one data row per input variable), "SelectedFeatures" (one column per iteration) and
  /// "Qualities" (MAE and R² on training and test partition per iteration).
  /// </remarks>
  [Item("VariableImpactBasedFeatureSelectionAlgorithm", "")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 9999)]
  [StorableType("EB47CA07-6F01-4FC1-9351-54ACB1F2CF24")]
  public class VariableImpactBasedFeatureSelectionAlgorithm : BasicAlgorithm {

    /// <summary>Pausing is not supported by this algorithm.</summary>
    public override bool SupportsPause {
      get { return false; }
    }

    #region Problem Type
    /// <summary>The algorithm operates on regression problems.</summary>
    public override Type ProblemType {
      get { return typeof(IRegressionProblem); }
    }

    /// <summary>Typed view on the base problem.</summary>
    public new IRegressionProblem Problem {
      get { return (IRegressionProblem)base.Problem; }
      set { base.Problem = value; }
    }
    #endregion

    #region Parameter Properties
    // The regression algorithm that is cloned and executed in every elimination iteration.
    private ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>> AlgorithmParameter {
      get { return (ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>)Parameters["Algorithm"]; }
    }

    // Share of the remaining features that is removed in every iteration.
    private FixedValueParameter<PercentValue> FeaturesDropParameter {
      get { return (FixedValueParameter<PercentValue>)Parameters["FeaturesDrop"]; }
    }
    #endregion

    #region Results Parameter

    #endregion

    #region Constructor, Cloning & Persistence
    /// <summary>
    /// Creates the algorithm with a random forest regression as inner algorithm, a feature drop rate
    /// of 0.2 and a fresh <see cref="RegressionProblem"/>.
    /// </summary>
    public VariableImpactBasedFeatureSelectionAlgorithm() {
      Parameters.Add(new ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>("Algorithm", new RandomForestRegression()));
      Parameters.Add(new FixedValueParameter<PercentValue>("FeaturesDrop", new PercentValue(0.2)));

      // ToDo: Use ResultParameters

      Problem = new RegressionProblem();
    }

    [StorableConstructor]
    protected VariableImpactBasedFeatureSelectionAlgorithm(StorableConstructorFlag _)
      : base(_) { }

    public VariableImpactBasedFeatureSelectionAlgorithm(VariableImpactBasedFeatureSelectionAlgorithm original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new VariableImpactBasedFeatureSelectionAlgorithm(this, cloner);
    }
    #endregion

    /// <summary>
    /// Executes the elimination loop on clones of the configured algorithm and problem.
    /// </summary>
    /// <param name="cancellationToken">
    /// Token that is forwarded to the inner algorithm and checked before and after every inner run.
    /// </param>
    /// <exception cref="OperationCanceledException">Thrown when cancellation was requested.</exception>
    protected override void Run(CancellationToken cancellationToken) {
      // Work on clones so the configured algorithm and problem are not modified by the run.
      var clonedAlgorithm = (FixedDataAnalysisAlgorithm<IRegressionProblem>)AlgorithmParameter.Value.Clone();
      var clonedProblem = (IRegressionProblem)Problem.Clone();
      double featureDrop = FeaturesDropParameter.Value.Value;

      SetupAlgorithm(clonedAlgorithm, clonedProblem);
      Results.Add(new Result("Algorithm", clonedAlgorithm));

      var remainingFeatures = clonedProblem.ProblemData.InputVariables.CheckedItems.Select(x => x.Value.Value).ToList();

      // One row per input variable (checked or not); one value is appended per iteration.
      var variableImpactsDataTable = new DataTable("VariableImpacts");
      foreach (var variable in clonedProblem.ProblemData.InputVariables) {
        variableImpactsDataTable.Rows.Add(new DataRow(variable.Value));
      }
      Results.Add(new Result("VariableImpacts", variableImpactsDataTable));

      // The first column holds the initially selected features; every iteration appends a column.
      var selectedVariablesResult = new StringMatrix(remainingFeatures.Count, 1) {
        ColumnNames = new[] { "StartUp" }
      };
      for (int i = 0; i < remainingFeatures.Count; i++) {
        selectedVariablesResult[i, 0] = remainingFeatures[i];
      }
      Results.Add(new Result("SelectedFeatures", selectedVariablesResult));

      var qualitiesResult = new DataTable("SolutionQualities");
      qualitiesResult.Rows.Add(new DataRow("MAE Training"));
      qualitiesResult.Rows.Add(new DataRow("MAE Test"));
      qualitiesResult.Rows.Add(new DataRow("R² Training") { VisualProperties = { SecondYAxis = true } });
      qualitiesResult.Rows.Add(new DataRow("R² Test") { VisualProperties = { SecondYAxis = true } });
      Results.Add(new Result("Qualities", qualitiesResult));

      while (remainingFeatures.Any()) {
        cancellationToken.ThrowIfCancellationRequested();
        clonedAlgorithm.Start(cancellationToken);
        // Report a cancel request as cancellation instead of continuing with the results of a stopped run.
        cancellationToken.ThrowIfCancellationRequested();

        // Truncation keeps the count from rounding back up. The additional cap at Count - 1 guarantees
        // that at least one feature is dropped per iteration, so a drop rate of zero (or below)
        // cannot keep the loop running forever.
        int numberOfRemainingVariables = Math.Min(
          (int)(remainingFeatures.Count * (1.0 - featureDrop)),
          remainingFeatures.Count - 1);
        var variableImpacts = GetVariableImpacts(clonedAlgorithm.Results).ToDictionary(x => x.Item1, x => x.Item2);

        remainingFeatures = variableImpacts
          .OrderByDescending(x => x.Value)
          .Take(numberOfRemainingVariables)
          .Select(x => x.Key)
          .ToList();

        // Variables without a calculated impact in this iteration are recorded as NaN.
        foreach (var row in variableImpactsDataTable.Rows) {
          double impact;
          row.Values.Add(variableImpacts.TryGetValue(row.Name, out impact) ? impact : double.NaN);
        }

        ((IStringConvertibleMatrix)selectedVariablesResult).Columns++;
        // Replaces "Column" by "Iteration" in every column name; presumably targets the default
        // names of the appended columns -- TODO confirm against StringMatrix.
        selectedVariablesResult.ColumnNames = selectedVariablesResult.ColumnNames.Select(c => c.Replace("Column", "Iteration"));
        int lastColumn = selectedVariablesResult.Columns - 1;
        for (int i = 0; i < remainingFeatures.Count; i++) {
          selectedVariablesResult[i, lastColumn] = remainingFeatures[i];
        }

        // Qualities belong to the model that was trained before the features were reduced.
        var solution = clonedAlgorithm.Results.Select(r => r.Value).OfType<IRegressionSolution>().Single();
        qualitiesResult.Rows["MAE Training"].Values.Add(solution.TrainingMeanAbsoluteError);
        qualitiesResult.Rows["MAE Test"].Values.Add(solution.TestMeanAbsoluteError);
        qualitiesResult.Rows["R² Training"].Values.Add(solution.TrainingRSquared);
        qualitiesResult.Rows["R² Test"].Values.Add(solution.TestRSquared);

        UpdateSelectedInputs(clonedProblem, remainingFeatures);
      }
    }

    /// <summary>Assigns the problem to the algorithm and prepares it, clearing previous runs.</summary>
    private static void SetupAlgorithm(FixedDataAnalysisAlgorithm<IRegressionProblem> algorithm, IRegressionProblem problem) {
      algorithm.Problem = problem;
      algorithm.Prepare(clearRuns: true);
    }

    /// <summary>
    /// Calculates the variable impacts of the single regression solution contained in the results,
    /// using shuffling as replacement method on the training partition.
    /// </summary>
    /// <param name="results">Results of the inner algorithm; must contain exactly one regression solution.</param>
    /// <returns>Pairs of variable name and impact.</returns>
    private static IEnumerable<Tuple<string, double>> GetVariableImpacts(ResultCollection results) {
      var solution = results.Select(r => r.Value).OfType<IRegressionSolution>().Single();
      return RegressionSolutionVariableImpactsCalculator.CalculateImpacts(
        solution,
        replacementMethod: RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle,
        factorReplacementMethod: RegressionSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Shuffle,
        dataPartition: RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training);
    }

    /// <summary>
    /// Checks exactly those input variables of the problem that are listed in
    /// <paramref name="remainingFeatures"/> and unchecks all others.
    /// </summary>
    private void UpdateSelectedInputs(IRegressionProblem problem, List<string> remainingFeatures) {
      // Set lookup instead of List.Contains to avoid quadratic cost for many input variables.
      var remaining = new HashSet<string>(remainingFeatures);
      foreach (var inputFeature in problem.ProblemData.InputVariables) {
        bool isRemaining = remaining.Contains(inputFeature.Value);
        problem.ProblemData.InputVariables.SetItemCheckedState(inputFeature, isRemaining);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.