#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HEAL.Attic;
namespace HeuristicLab.Problems.DataAnalysis.FeatureSelection {
  /// <summary>
  /// Backward feature elimination for regression problems driven by variable impacts.
  /// In each iteration the configured inner algorithm is run on the current feature set,
  /// variable impacts are computed for the produced solution, and the least impactful
  /// fraction (<c>FeaturesDrop</c>) of the remaining features is removed. Iteration
  /// continues until no features remain. Progress is reported via the "VariableImpacts",
  /// "SelectedFeatures" and "Qualities" results.
  /// </summary>
  [Item("VariableImpactBasedFeatureSelectionAlgorithm", "")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 9999)]
  [StorableType("EB47CA07-6F01-4FC1-9351-54ACB1F2CF24")]
  public class VariableImpactBasedFeatureSelectionAlgorithm : BasicAlgorithm {

    // Pausing is not supported: the inner algorithm is driven synchronously within Run.
    public override bool SupportsPause {
      get { return false; }
    }

    #region Problem Type
    public override Type ProblemType {
      get { return typeof(IRegressionProblem); }
    }
    public new IRegressionProblem Problem {
      get { return (IRegressionProblem)base.Problem; }
      set { base.Problem = value; }
    }
    #endregion

    #region Parameter Properties
    // Inner regression algorithm that is re-run on each shrunken feature set.
    private ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>> AlgorithmParameter {
      get { return (ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>)Parameters["Algorithm"]; }
    }

    // Fraction of the remaining features that is dropped per iteration (0..1).
    private FixedValueParameter<PercentValue> FeaturesDropParameter {
      get { return (FixedValueParameter<PercentValue>)Parameters["FeaturesDrop"]; }
    }
    #endregion

    #region Results Parameter

    #endregion

    #region Constructor, Cloning & Persistence
    public VariableImpactBasedFeatureSelectionAlgorithm() {
      Parameters.Add(new ValueParameter<FixedDataAnalysisAlgorithm<IRegressionProblem>>("Algorithm", new RandomForestRegression()));
      Parameters.Add(new FixedValueParameter<PercentValue>("FeaturesDrop", new PercentValue(0.2)));

      // ToDo: Use ResultParameters
      //Parameters.Add(new ResultParameter<>());

      Problem = new RegressionProblem();
    }

    [StorableConstructor]
    protected VariableImpactBasedFeatureSelectionAlgorithm(StorableConstructorFlag _)
      : base(_) { }

    public VariableImpactBasedFeatureSelectionAlgorithm(VariableImpactBasedFeatureSelectionAlgorithm original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new VariableImpactBasedFeatureSelectionAlgorithm(this, cloner);
    }
    #endregion

    /// <summary>
    /// Runs the iterative elimination loop on clones of the configured algorithm and
    /// problem, so the originals stay untouched. Populates the Results collection with
    /// the inner algorithm, per-iteration variable impacts, the surviving feature set,
    /// and training/test qualities of each intermediate solution.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var clonedAlgorithm = (FixedDataAnalysisAlgorithm<IRegressionProblem>)AlgorithmParameter.Value.Clone();
      var clonedProblem = (IRegressionProblem)Problem.Clone();
      double featureDrop = FeaturesDropParameter.Value.Value;

      SetupAlgorithm(clonedAlgorithm, clonedProblem);
      Results.Add(new Result("Algorithm", clonedAlgorithm));

      var remainingFeatures = clonedProblem.ProblemData.InputVariables.CheckedItems.Select(x => x.Value.Value).ToList();

      // One data-table row per input variable; dropped variables get NaN entries later.
      var variableImpactsDataTable = new DataTable("VariableImpacts");
      foreach (var variable in clonedProblem.ProblemData.InputVariables) {
        variableImpactsDataTable.Rows.Add(new DataRow(variable.Value));
      }
      Results.Add(new Result("VariableImpacts", variableImpactsDataTable));

      // Column 0 holds the initial feature set; one column is appended per iteration.
      var selectedVariablesResult = new StringMatrix(remainingFeatures.Count, 1) {
        ColumnNames = new[] { "StartUp" }
      };
      for (int i = 0; i < remainingFeatures.Count; i++) {
        selectedVariablesResult[i, 0] = remainingFeatures[i];
      }
      Results.Add(new Result("SelectedFeatures", selectedVariablesResult));

      var qualitiesResult = new DataTable("SolutionQualities");
      qualitiesResult.Rows.Add(new DataRow("MAE Training"));
      qualitiesResult.Rows.Add(new DataRow("MAE Test"));
      qualitiesResult.Rows.Add(new DataRow("R² Training") { VisualProperties = { SecondYAxis = true } });
      qualitiesResult.Rows.Add(new DataRow("R² Test") { VisualProperties = { SecondYAxis = true } });
      Results.Add(new Result("Qualities", qualitiesResult));


      int iteration = 0;
      while (remainingFeatures.Any()) {
        // Honor cancellation between iterations, not only inside the inner algorithm.
        cancellationToken.ThrowIfCancellationRequested();
        clonedAlgorithm.Start(cancellationToken);

        int numberOfRemainingVariables = (int)(remainingFeatures.Count * (1.0 - featureDrop)); // floor to avoid getting stuck
        var variableImpacts = GetVariableImpacts(clonedAlgorithm.Results).ToDictionary(x => x.Item1, x => x.Item2);

        // Keep only the most impactful features for the next iteration.
        remainingFeatures = variableImpacts
          .OrderByDescending(x => x.Value)
          .Take(numberOfRemainingVariables)
          .Select(x => x.Key)
          .ToList();

        foreach (var row in variableImpactsDataTable.Rows) {
          // TryGetValue avoids the double dictionary lookup of ContainsKey + indexer (CA1854).
          row.Values.Add(variableImpacts.TryGetValue(row.Name, out double impact) ? impact : double.NaN);
        }

        ((IStringConvertibleMatrix)selectedVariablesResult).Columns++;
        // The matrix auto-names new columns "Column N"; rename to "Iteration N" for readability.
        selectedVariablesResult.ColumnNames = selectedVariablesResult.ColumnNames.Select(c => c.Replace("Column", "Iteration"));
        //selectedVariablesResult.ColumnNames = selectedVariablesResult.ColumnNames.Concat(new[] { $"Iteration {iteration}" });
        for (int i = 0; i < remainingFeatures.Count; i++) {
          selectedVariablesResult[i, selectedVariablesResult.Columns - 1] = remainingFeatures[i];
        }

        // Exactly one regression solution is expected in the inner algorithm's results.
        var solution = clonedAlgorithm.Results.Select(r => r.Value).OfType<IRegressionSolution>().Single();
        qualitiesResult.Rows["MAE Training"].Values.Add(solution.TrainingMeanAbsoluteError);
        qualitiesResult.Rows["MAE Test"].Values.Add(solution.TestMeanAbsoluteError);
        qualitiesResult.Rows["R² Training"].Values.Add(solution.TrainingRSquared);
        qualitiesResult.Rows["R² Test"].Values.Add(solution.TestRSquared);


        UpdateSelectedInputs(clonedProblem, remainingFeatures);

        iteration++;
      }
    }

    // Wires the (cloned) problem into the (cloned) algorithm and resets previous runs.
    private static void SetupAlgorithm(FixedDataAnalysisAlgorithm<IRegressionProblem> algorithm, IRegressionProblem problem) {
      algorithm.Problem = problem;
      algorithm.Prepare(clearRuns: true);
    }

    /// <summary>
    /// Extracts the single regression solution from <paramref name="results"/> and
    /// computes shuffle-based variable impacts on the training partition.
    /// Returns (variable name, impact) tuples.
    /// </summary>
    private static IEnumerable<Tuple<string, double>> GetVariableImpacts(ResultCollection results) {
      //var solution = (IRegressionSolution)results["Random forest regression solution"].Value;
      var solution = results.Select(r => r.Value).OfType<IRegressionSolution>().Single();
      return RegressionSolutionVariableImpactsCalculator.CalculateImpacts(
        solution,
        replacementMethod: RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle,
        factorReplacementMethod: RegressionSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Shuffle,
        dataPartition: RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training);
    }

    // Checks exactly the surviving features in the problem's input-variable list.
    private void UpdateSelectedInputs(IRegressionProblem problem, List<string> remainingFeatures) {
      // HashSet gives O(1) membership tests instead of O(n) List.Contains per feature.
      var remaining = new HashSet<string>(remainingFeatures);
      foreach (var inputFeature in problem.ProblemData.InputVariables) {
        problem.ProblemData.InputVariables.SetItemCheckedState(inputFeature, remaining.Contains(inputFeature.Value));
      }
    }
  }
}
---|