Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 15816

Last change on this file since 15816 was 15816, checked in by fholzing, 6 years ago

#2904: Cleaned up the method signature (some optional parameters aren't so optional)

File size: 15.9 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impacts of the input variables of a regression solution.
  /// The impact of a variable is the original R² (squared Pearson correlation between target values and
  /// model estimates) minus the R² obtained after the variable's values have been replaced according to
  /// a <see cref="ReplacementMethodEnum"/> (double variables) or <see cref="FactorReplacementMethodEnum"/>
  /// (string/factor variables).
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>Replacement strategies for double-valued input variables.</summary>
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }
    /// <summary>Replacement strategies for string-valued (factor) input variables.</summary>
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }
    /// <summary>The partition of the problem data whose rows are used for the impact calculation.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    // Fixed seed so that the Shuffle and Noise replacements are reproducible between runs.
    private const int RandomSeed = 31415;

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter
    {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter
    {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Replacement method used for double variables by <see cref="Calculate"/>.</summary>
    public ReplacementMethodEnum ReplacementMethod
    {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Data partition used by <see cref="Calculate"/>.</summary>
    public DataPartitionEnum DataPartition
    {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator whose parameters default to the Median replacement method and the Training partition.
    /// </summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    /// <summary>
    /// Calculates the impacts of all input variables of <paramref name="solution"/> using the
    /// <see cref="DataPartition"/> and <see cref="ReplacementMethod"/> configured on this instance.
    /// Factor variables use the default <see cref="FactorReplacementMethodEnum.Best"/> strategy.
    /// </summary>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, DataPartition, ReplacementMethod);
    }

    /// <summary>
    /// Selects the rows, target values and original R² of <paramref name="solution"/> for the given partition.
    /// </summary>
    /// <exception cref="InvalidOperationException">R² cannot be calculated for the All partition.</exception>
    /// <exception cref="ArgumentException"><paramref name="partition"/> is not a known value.</exception>
    private static void PrepareData(DataPartitionEnum partition,
      IRegressionSolution solution,
      out IEnumerable<int> rows,
      out IEnumerable<double> targetValues,
      out double originalR2) {
      OnlineCalculatorError error;

      switch (partition) {
        case DataPartitionEnum.All:
          rows = solution.ProblemData.AllIndices;
          targetValues = solution.ProblemData.TargetVariableValues.ToList();
          // the solution only exposes training/test R², so the R² on all rows is computed here
          originalR2 = OnlinePearsonsRCalculator.Calculate(solution.ProblemData.TargetVariableValues, solution.EstimatedValues, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation.");
          originalR2 = originalR2 * originalR2;
          break;
        case DataPartitionEnum.Training:
          rows = solution.ProblemData.TrainingIndices;
          targetValues = solution.ProblemData.TargetVariableTrainingValues.ToList();
          originalR2 = solution.TrainingRSquared;
          break;
        case DataPartitionEnum.Test:
          rows = solution.ProblemData.TestIndices;
          targetValues = solution.ProblemData.TargetVariableTestValues.ToList();
          originalR2 = solution.TestRSquared;
          break;
        default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", partition));
      }
    }

    /// <summary>Impact of a double variable: original R² minus the R² after replacing the variable.</summary>
    private static double CalculateImpactForDouble(string variableName,
      IRegressionSolution solution,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalR2,
      ReplacementMethodEnum replacementMethod) {
      OnlineCalculatorError error;
      var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, replacementMethod);
      var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during R² calculation with replaced inputs."); }
      return originalR2 - (newR2 * newR2);
    }

    /// <summary>
    /// Impact of a string (factor) variable. For <see cref="FactorReplacementMethodEnum.Best"/> every distinct
    /// value occurring in <paramref name="rows"/> is tried as a constant replacement and the smallest resulting
    /// impact is returned; otherwise the Mode/Shuffle replacement is applied once.
    /// </summary>
    private static double CalculateImpactForString(string variableName,
      IRegressionSolution solution,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalR2,
      FactorReplacementMethodEnum factorReplacementMethod) {

      OnlineCalculatorError error;
      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
        // try replacing with all possible values and find the best replacement value
        var smallestImpact = double.PositiveInfinity;
        foreach (var repl in solution.ProblemData.Dataset.GetStringValues(variableName, rows).Distinct()) {
          var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, solution.ProblemData.Dataset.Rows));
          var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");

          var curImpact = originalR2 - (newR2 * newR2);
          if (curImpact < smallestImpact) smallestImpact = curImpact;
        }
        return smallestImpact;
      } else {
        // for replacement methods shuffle and mode
        // calculate impacts for factor variables
        var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, factorReplacementMethod);
        var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");

        return originalR2 - (newR2 * newR2);
      }
    }

    /// <summary>
    /// Calculates the impact of a single variable.
    /// </summary>
    /// <param name="variableName">The double or string variable whose impact is calculated.</param>
    /// <param name="solution">The regression solution whose model is evaluated.</param>
    /// <param name="rows">Rows on which the model is evaluated.</param>
    /// <param name="targetValues">Target values for <paramref name="rows"/>.</param>
    /// <param name="originalR2">R² of the unmodified model on <paramref name="rows"/>.</param>
    /// <param name="data">Not used; the partition is already described by <paramref name="rows"/>,
    /// <paramref name="targetValues"/> and <paramref name="originalR2"/>. Kept for signature compatibility.</param>
    /// <param name="replacementMethod">Replacement method for double variables.</param>
    /// <param name="factorReplacementMethod">Replacement method for string variables.</param>
    /// <returns><paramref name="originalR2"/> minus the R² obtained with the replaced variable.</returns>
    /// <exception cref="NotSupportedException">The variable is neither of type double nor string.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionSolution solution,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalR2,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      var modifiableDataset = ((Dataset)solution.ProblemData.Dataset).ToModifiable();
      return CalculateImpactOnDataset(variableName, solution, modifiableDataset, rows, targetValues, originalR2, replacementMethod, factorReplacementMethod);
    }

    // Dispatches to the double/string impact calculation using an already prepared modifiable dataset,
    // so that callers processing many variables can share one dataset copy.
    private static double CalculateImpactOnDataset(string variableName,
      IRegressionSolution solution,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalR2,
      ReplacementMethodEnum replacementMethod,
      FactorReplacementMethodEnum factorReplacementMethod) {
      var dataset = solution.ProblemData.Dataset;
      if (dataset.VariableHasType<double>(variableName)) {
        return CalculateImpactForDouble(variableName, solution, modifiableDataset, rows, targetValues, originalR2, replacementMethod);
      }
      if (dataset.VariableHasType<string>(variableName)) {
        return CalculateImpactForString(variableName, solution, modifiableDataset, rows, targetValues, originalR2, factorReplacementMethod);
      }
      throw new NotSupportedException("Variable not supported");
    }

    /// <summary>
    /// Calculates the impacts of all allowed input variables and of all variables used by the model.
    /// </summary>
    /// <param name="solution">The regression solution to analyze.</param>
    /// <param name="data">The partition whose rows are used.</param>
    /// <param name="replacementMethod">Replacement method for double variables.</param>
    /// <param name="factorReplacementMethod">Replacement method for string variables.</param>
    /// <param name="progressCallback">Optional callback receiving the progress in [0, 1] and a message before each
    /// variable; returning true cancels the calculation.</param>
    /// <returns>Variable names with their impacts, sorted by descending impact, or null if cancelled.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      Func<double, string, bool> progressCallback = null) {

      IEnumerable<int> rows;
      IEnumerable<double> targetValues;
      double originalR2;

      PrepareData(data, solution, out rows, out targetValues, out originalR2);

      var impacts = new Dictionary<string, double>();
      var inputvariables = new HashSet<string>(solution.ProblemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
      var allowedInputVariables = solution.ProblemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();

      // One shared copy: every evaluation restores the replaced column afterwards, so it can be reused.
      var modifiableDataset = ((Dataset)solution.ProblemData.Dataset).ToModifiable();

      int curIdx = 0;
      // every variable (double and string) is processed, so all of them count towards the progress
      int count = allowedInputVariables.Count;
      // calculate impacts for double and string variables
      foreach (var inputVariable in allowedInputVariables) {
        //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
        if (progressCallback != null) {
          curIdx++;
          if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
        }
        impacts[inputVariable] = CalculateImpactOnDataset(inputVariable, solution, modifiableDataset, rows, targetValues, originalR2, replacementMethod, factorReplacementMethod);
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    // Builds replacement values for a double variable and evaluates the model with them.
    // Rows outside 'rows' are set to NaN for Shuffle/Noise; they are not evaluated.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      double replacementValue;
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    // Builds replacement values for a string (factor) variable and evaluates the model with them.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
      IRegressionModel model, string variable, ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;
      IRandom rand;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    // Temporarily replaces a double column, evaluates the model on 'rows' and always restores the column,
    // even if the model throws, so the (possibly shared) dataset is never left modified.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    // Temporarily replaces a string column, evaluates the model on 'rows' and always restores the column,
    // even if the model throws, so the (possibly shared) dataset is never left modified.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.