source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 16034

Last change on this file since 16034 was 16034, checked in by fholzing, 23 months ago

#2904: Removed callback, adapted both view and calculator.

File size: 19.8 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections;
26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of input variables on a regression model.
  /// The impact of a variable is the quality of the original estimates minus the quality of the
  /// estimates obtained after the values of that variable have been replaced; quality is the value
  /// reported by an <c>OnlinePearsonsRSquaredCalculator</c> for (target, estimate) pairs.
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>How the values of a numerical (double) variable are replaced.</summary>
    public enum ReplacementMethodEnum {
      /// <summary>Every row gets the median of the variable over the selected rows.</summary>
      Median,
      /// <summary>Every row gets the average of the variable over the selected rows.</summary>
      Average,
      /// <summary>The values of the selected rows are permuted among the selected rows.</summary>
      Shuffle,
      /// <summary>The selected rows get normally distributed random values (average and standard deviation of the selected rows).</summary>
      Noise
    }
    /// <summary>How the values of a factor (string) variable are replaced.</summary>
    public enum FactorReplacementMethodEnum {
      /// <summary>Every distinct value is tried as a constant replacement; the smallest resulting impact is reported.</summary>
      Best,
      /// <summary>Every row gets the most frequent value of the selected rows.</summary>
      Mode,
      /// <summary>The values of the selected rows are permuted among the selected rows.</summary>
      Shuffle
    }
    /// <summary>The partition of the problem data whose row indices are used for the calculation.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string FactorReplacementParameterName = "Factor Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    // Fixed seed used for the Shuffle and Noise replacements, so repeated calculations give identical results.
    private const int RandomSeed = 31415;

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Replacement method for numerical variables (backed by <see cref="ReplacementParameter"/>).</summary>
    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Replacement method for factor variables (backed by <see cref="FactorReplacementParameter"/>).</summary>
    public FactorReplacementMethodEnum FactorReplacementMethod {
      get { return FactorReplacementParameter.Value.Value; }
      set { FactorReplacementParameter.Value.Value = value; }
    }
    /// <summary>Data partition used for the calculation (backed by <see cref="DataPartitionParameter"/>).</summary>
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator with the defaults Median (numerical), Best (factor) and Training (partition).
    /// </summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    //mkommend: annoying name clash with static method, open to better naming suggestions
    /// <summary>
    /// Calculates the impacts of all input variables of <paramref name="solution"/> using the
    /// replacement methods and data partition configured on this instance.
    /// </summary>
    /// <returns>(variable name, impact) tuples ordered by descending impact.</returns>
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    }

    /// <summary>
    /// Calculates the impacts of all input variables using the model, problem data and estimated values of <paramref name="solution"/>.
    /// </summary>
    /// <returns>(variable name, impact) tuples ordered by descending impact.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, replacementMethod, factorReplacementMethod, data);
    }

    /// <summary>
    /// Calculates the impacts of all input variables on the rows of the partition selected by <paramref name="data"/>.
    /// </summary>
    /// <param name="estimatedValues">Estimates of the model; indexed by dataset row index.</param>
    /// <returns>(variable name, impact) tuples ordered by descending impact.</returns>
    /// <exception cref="NotSupportedException">Thrown when <paramref name="data"/> is not a known partition.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training) {
      IEnumerable<int> rows = GetPartitionRows(problemData, data);
      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of a single variable on the rows of the partition selected by <paramref name="dataPartition"/>.
    /// </summary>
    /// <param name="estimatedValues">Estimates of the model; indexed by dataset row index.</param>
    /// <returns>Quality of the original estimates minus quality of the estimates with the variable replaced.</returns>
    /// <exception cref="NotSupportedException">Thrown for an unknown partition or a variable that is neither double nor string.</exception>
    /// <exception cref="InvalidOperationException">Thrown when the quality calculator reports an error.</exception>
    public static double CalculateImpact(string variableName, IRegressionModel model, IRegressionProblemData problemData, IEnumerable<double> estimatedValues, DataPartitionEnum dataPartition, ReplacementMethodEnum replMethod, FactorReplacementMethodEnum factorReplMethod) {
      // materialize the row indices once; they are enumerated several times below
      List<int> rows = GetPartitionRows(problemData, dataPartition).ToList();

      List<double> targetValuesPartition;
      double originalCalculatorValue = CalculateOriginalValue(problemData, estimatedValues, rows, out targetValuesPartition);

      return CalculateImpact(variableName, model, problemData.Dataset, rows, targetValuesPartition, originalCalculatorValue, replMethod, factorReplMethod);
    }

    /// <summary>
    /// Calculates the impacts of all input variables on the given <paramref name="rows"/>.
    /// The considered variables are the dataset variables that are either allowed inputs of
    /// <paramref name="problemData"/> or used by <paramref name="model"/> for prediction.
    /// </summary>
    /// <param name="estimatedValues">Estimates of the model; indexed by dataset row index.</param>
    /// <param name="rows">Dataset row indices on which the impacts are calculated.</param>
    /// <returns>(variable name, impact) tuples ordered by descending impact.</returns>
    /// <exception cref="InvalidOperationException">Thrown when the quality calculator reports an error.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     IRegressionModel model,
     IRegressionProblemData problemData,
     IEnumerable<double> estimatedValues,
     IEnumerable<int> rows,
     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // materialize the row indices once; they are enumerated for every variable
      List<int> rowList = rows.ToList();

      //Calculate original quality-values (via calculator, default is R²)
      List<double> targetValuesPartition;
      double originalCalculatorValue = CalculateOriginalValue(problemData, estimatedValues, rowList, out targetValuesPartition);

      var impacts = new Dictionary<string, double>();
      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();

      if (allowedInputVariables.Count > 0) {
        // A single modifiable copy is shared by all variables; each replacement is undone
        // before the next variable is processed (see GetReplacedValues).
        var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
        foreach (var inputVariable in allowedInputVariables) {
          impacts[inputVariable] = CalculateImpactOnModifiableDataset(inputVariable, model, problemData.Dataset, modifiableDataset, rowList, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
        }
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable using the model and dataset of <paramref name="solution"/>.
    /// </summary>
    /// <param name="targetValues">Target values of the selected rows, in the order of <paramref name="rows"/>.</param>
    /// <param name="originalValue">Quality of the unmodified estimates on the selected rows.</param>
    /// <param name="data">Not used; the rows are given explicitly by <paramref name="rows"/>. Kept for interface compatibility.</param>
    public static double CalculateImpact(string variableName,
      IRegressionSolution solution,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training) {
      return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of a single variable on a modifiable copy of <paramref name="dataset"/>.
    /// </summary>
    /// <param name="rows">Dataset row indices on which the impact is calculated.</param>
    /// <param name="targetValues">Target values of the selected rows, in the order of <paramref name="rows"/>.</param>
    /// <param name="originalValue">Quality of the unmodified estimates on the selected rows.</param>
    /// <returns><paramref name="originalValue"/> minus the quality of the estimates with the variable replaced.</returns>
    /// <exception cref="NotSupportedException">Thrown when the variable is neither of type double nor string.</exception>
    /// <exception cref="InvalidOperationException">Thrown when the quality calculator reports an error.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      IDataset dataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // both sequences are enumerated several times below; avoid re-evaluating lazy inputs
      IList<int> rowList = rows as IList<int> ?? rows.ToList();
      IList<double> targetList = targetValues as IList<double> ?? targetValues.ToList();

      var modifiableDataset = ((Dataset)(dataset).Clone()).ToModifiable();
      return CalculateImpactOnModifiableDataset(variableName, model, dataset, modifiableDataset, rowList, targetList, originalValue, replacementMethod, factorReplacementMethod);
    }

    // Maps a data partition to the corresponding row indices of the problem data.
    private static IEnumerable<int> GetPartitionRows(IRegressionProblemData problemData, DataPartitionEnum dataPartition) {
      switch (dataPartition) {
        case DataPartitionEnum.All:
          return problemData.AllIndices;
        case DataPartitionEnum.Test:
          return problemData.TestIndices;
        case DataPartitionEnum.Training:
          return problemData.TrainingIndices;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }
    }

    // Calculates the quality of the unmodified estimates on the given rows and returns the
    // target values of these rows (in row order) via targetValuesPartition.
    private static double CalculateOriginalValue(
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      IEnumerable<int> rows,
      out List<double> targetValuesPartition) {
      // Materialize both columns once: indexing a lazy sequence with ElementAt for every row
      // re-enumerates the sequence each time.
      List<double> allTargetValues = problemData.TargetVariableValues.ToList();
      List<double> allEstimatedValues = estimatedValues.ToList();

      targetValuesPartition = rows.Select(r => allTargetValues[r]).ToList();
      List<double> estimatedValuesPartition = rows.Select(r => allEstimatedValues[r]).ToList();

      OnlineCalculatorError error;
      double originalValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
      return originalValue;
    }

    // Dispatches on the variable type; modifiableDataset is left with its original contents.
    private static double CalculateImpactOnModifiableDataset(string variableName,
      IRegressionModel model,
      IDataset dataset,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod,
      FactorReplacementMethodEnum factorReplacementMethod) {
      // calculate impacts for double variables
      if (dataset.VariableHasType<double>(variableName)) {
        return CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
      }
      // calculate impacts for string (factor) variables
      if (dataset.VariableHasType<string>(variableName)) {
        return CalculateImpactForFactorVariables(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
      }
      throw new NotSupportedException("Variable not supported");
    }

    // Impact of a double variable: original quality minus quality with the variable replaced.
    private static double CalculateImpactForNumericalVariables(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod) {
      OnlineCalculatorError error;
      var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod);
      var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
      return originalValue - newValue;
    }

    // Impact of a string variable. For the Best method every distinct value found in the selected
    // rows is tried as a constant replacement and the smallest impact is returned; if the selected
    // rows contain no value at all, no candidate is evaluated and double.PositiveInfinity is returned.
    private static double CalculateImpactForFactorVariables(string variableName,
      IRegressionModel model,
      IDataset dataset,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      FactorReplacementMethodEnum factorReplacementMethod) {

      OnlineCalculatorError error;
      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
        // try replacing with all possible values and find the best replacement value
        var smallestImpact = double.PositiveInfinity;
        // the column is restored after every replacement, so one snapshot serves all candidates
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        foreach (var repl in dataset.GetStringValues(variableName, rows).Distinct()) {
          var newEstimates = GetReplacedValues(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, dataset.Rows).ToList());
          var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

          var curImpact = originalValue - newValue;
          if (curImpact < smallestImpact) smallestImpact = curImpact;
        }
        return smallestImpact;
      } else {
        // for replacement methods shuffle and mode
        // calculate impacts for factor variables
        var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod);
        var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

        return originalValue - newValue;
      }
    }

    // Builds the replacement column for a double variable and returns the model estimates
    // for the selected rows with that column in place.
    private static IEnumerable<double> GetReplacedValuesForNumericalVariables(
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      double replacementValue;
      List<double> replacementValues;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          replacementValues = ShuffleSelectedRows(originalValues, rows, dataset.Rows, double.NaN);
          break;
        case ReplacementMethodEnum.Noise:
          var selectedValues = rows.Select(r => originalValues[r]).ToList();
          var avg = selectedValues.Average();
          var stdDev = selectedValues.StandardDeviation();
          IRandom rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return GetReplacedValues(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Builds the replacement column for a string variable (Mode or Shuffle) and returns the
    // model estimates for the selected rows with that column in place.
    private static IEnumerable<double> GetReplacedValuesForFactorVariables(
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          replacementValues = ShuffleSelectedRows(originalValues, rows, dataset.Rows, string.Empty);
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return GetReplacedValues(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Creates a complete column of rowCount entries in which the values of the selected rows are
    // permuted among the selected rows; all other rows hold the filler value.
    private static List<T> ShuffleSelectedRows<T>(IList<T> originalValues, IEnumerable<int> rows, int rowCount, T filler) {
      IRandom rand = new FastRandom(RandomSeed);
      // prepare a complete column for the dataset
      var column = Enumerable.Repeat(filler, rowCount).ToList();
      // shuffle only the selected rows
      var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
      int i = 0;
      // update column values
      foreach (var r in rows) {
        column[r] = shuffledValues[i++];
      }
      return column;
    }

    // Temporarily replaces the variable in the dataset, evaluates the model on the selected rows
    // and puts the original values back (also when the evaluation throws).
    private static IEnumerable<double> GetReplacedValues(
      IList originalValues,
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      IList replacementValues) {
      dataset.ReplaceVariable(variable, replacementValues);
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    // Feeds (original, estimated) pairs into an OnlinePearsonsRSquaredCalculator and returns its value.
    // Throws an ArgumentException when the two sequences differ in length (unless the calculator
    // already reported an error, in which case that error is returned via errorState).
    private static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
      var calculator = new OnlinePearsonsRSquaredCalculator();

      using (IEnumerator<double> firstEnumerator = originalValues.GetEnumerator())
      using (IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator()) {
        // Always advance both enumerators and remember the results: probing with MoveNext again
        // after the loop would miss a sequence that is longer by exactly one element.
        bool firstHasValue = firstEnumerator.MoveNext();
        bool secondHasValue = secondEnumerator.MoveNext();
        while (firstHasValue && secondHasValue) {
          calculator.Add(firstEnumerator.Current, secondEnumerator.Current);
          if (calculator.ErrorState != OnlineCalculatorError.None) break;
          firstHasValue = firstEnumerator.MoveNext();
          secondHasValue = secondEnumerator.MoveNext();
        }

        // if exactly one enumeration still has elements, the lengths do not match
        if (calculator.ErrorState == OnlineCalculatorError.None && (firstHasValue || secondHasValue)) {
          throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
        }
      }

      errorState = calculator.ErrorState;
      return calculator.Value;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.