source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 15665

Last change on this file since 15665 was 15665, checked in by fholzing, 4 years ago

#2871: Implemented review-issues

File size: 13.7 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of each input variable on the quality (squared Pearson's R) of a regression solution.
  /// The impact of a variable is the difference between the R² of the solution on the original data and the R²
  /// obtained when the values of that variable are replaced (e.g. by the median, by shuffled values, ...).
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>Replacement strategies for double-valued variables.</summary>
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }
    /// <summary>Replacement strategies for string-valued (factor) variables.</summary>
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }
    /// <summary>Data partition on which the impacts are calculated.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }
    /// <summary>Sorting criteria for variable impacts (not used within this class).</summary>
    public enum SortingCriteria {
      ImpactValue,
      Occurrence,
      VariableName
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    // Fixed seed so that the Shuffle and Noise replacement methods produce reproducible impacts.
    private const int RandomSeed = 31415;

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Replacement method for double variables, backed by <see cref="ReplacementParameter"/>.</summary>
    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Data partition used for the calculation, backed by <see cref="DataPartitionParameter"/>.</summary>
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    //mkommend: annoying name clash with static method, open to better naming suggestions
    /// <summary>
    /// Calculates the variable impacts for <paramref name="solution"/> using the configured
    /// <see cref="DataPartition"/> and <see cref="ReplacementMethod"/>. Factor variables use the default
    /// factor replacement method of <see cref="CalculateImpacts"/>.
    /// </summary>
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, DataPartition, ReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of every allowed or used input variable of <paramref name="solution"/>.
    /// </summary>
    /// <param name="solution">The regression solution to analyze.</param>
    /// <param name="data">The data partition on which the impacts are calculated.</param>
    /// <param name="replacementMethod">The replacement method for double variables.</param>
    /// <param name="factorReplacementMethod">The replacement method for string (factor) variables.</param>
    /// <returns>Tuples of variable name and impact (original R² minus R² with replaced variable), sorted by descending impact.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="solution"/> is null.</exception>
    /// <exception cref="ArgumentException">Thrown when a partition or replacement method cannot be handled.</exception>
    /// <exception cref="InvalidOperationException">Thrown when the R² calculation reports an error.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      if (solution == null) throw new ArgumentNullException("solution");

      var problemData = solution.ProblemData;
      var dataset = problemData.Dataset;

      // rows and target values are materialized once because they are enumerated for every input variable
      List<int> rows;
      List<double> targetValues;
      double originalR2;

      switch (data) {
        case DataPartitionEnum.All:
          rows = problemData.AllIndices.ToList();
          targetValues = problemData.TargetVariableValues.ToList();
          OnlineCalculatorError error;
          var originalR = OnlinePearsonsRCalculator.Calculate(targetValues, solution.EstimatedValues, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation.");
          originalR2 = originalR * originalR;
          break;
        case DataPartitionEnum.Training:
          rows = problemData.TrainingIndices.ToList();
          targetValues = problemData.TargetVariableTrainingValues.ToList();
          originalR2 = solution.TrainingRSquared;
          break;
        case DataPartitionEnum.Test:
          rows = problemData.TestIndices.ToList();
          targetValues = problemData.TargetVariableTestValues.ToList();
          originalR2 = solution.TestRSquared;
          break;
        default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
      }

      var impacts = new Dictionary<string, double>();
      // variables are replaced on a modifiable copy so that the dataset of the problem data is never changed
      var modifiableDataset = ((Dataset)dataset).ToModifiable();

      var inputVariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
      // keep the order in which the variables occur in the dataset
      var allowedInputVariables = dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();

      // calculate impacts for double variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<double>)) {
        var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);
        impacts[inputVariable] = CalculateImpact(originalR2, targetValues, newEstimates);
      }

      // calculate impacts for string variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<string>)) {
        if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
          // try replacing with all possible values and find the best replacement value
          // (stays at positive infinity if the selected rows contain no value at all)
          var smallestImpact = double.PositiveInfinity;
          var candidates = dataset.GetStringValues(inputVariable, rows).Distinct().ToList();
          foreach (var repl in candidates) {
            var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
              Enumerable.Repeat(repl, dataset.Rows));
            var impact = CalculateImpact(originalR2, targetValues, newEstimates);
            if (impact < smallestImpact) smallestImpact = impact;
          }
          impacts[inputVariable] = smallestImpact;
        } else {
          // for replacement methods shuffle and mode
          var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
            factorReplacementMethod);
          impacts[inputVariable] = CalculateImpact(originalR2, targetValues, newEstimates);
        }
      }

      // materialized so that callers do not re-sort the impacts on every enumeration
      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)).ToList();
    }

    // Returns originalR2 minus the squared Pearson's R between the target values and the new estimates.
    private static double CalculateImpact(double originalR2, IEnumerable<double> targetValues, IEnumerable<double> newEstimates) {
      OnlineCalculatorError error;
      var newR = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None)
        throw new InvalidOperationException("Error during R² calculation with replaced inputs.");

      return originalR2 - newR * newR;
    }

    // Builds a complete column in which the values of the selected rows are permuted among themselves.
    // All remaining rows receive fillValue; only the selected rows are evaluated afterwards.
    private static List<T> ShuffleSelectedRows<T>(IList<T> originalValues, IEnumerable<int> rows, int totalRows, T fillValue) {
      // new var has same empirical distribution but the relation to y is broken
      var rand = new FastRandom(RandomSeed);
      var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
      var column = Enumerable.Repeat(fillValue, totalRows).ToList();
      int i = 0;
      foreach (var r in rows) {
        column[r] = shuffledValues[i++];
      }
      return column;
    }

    // Evaluates the model on the given rows after replacing a double variable according to the replacement method.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable);
      List<double> replacementValues;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          var median = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(median, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          var average = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(average, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          replacementValues = ShuffleSelectedRows(originalValues, rows, dataset.Rows, double.NaN);
          break;
        case ReplacementMethodEnum.Noise:
          // normally distributed values with the mean and standard deviation of the selected rows
          var selectedValues = rows.Select(r => originalValues[r]).ToList();
          var avg = selectedValues.Average();
          var stdDev = selectedValues.StandardDeviation();
          var rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    // Evaluates the model on the given rows after replacing a string (factor) variable according to the replacement method.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
      IRegressionModel model, string variable, ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement) {
      var originalValues = dataset.GetReadOnlyStringValues(variable);
      List<string> replacementValues;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          replacementValues = ShuffleSelectedRows(originalValues, rows, dataset.Rows, string.Empty);
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    // Temporarily replaces the double variable with the given values, evaluates the model and restores the original values.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        // restore the original column even if the model evaluation fails
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    // Temporarily replaces the string variable with the given values, evaluates the model and restores the original values.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        // restore the original column even if the model evaluation fails
        dataset.ReplaceVariable(variable, originalValues);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.