source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 16035

Last change on this file since 16035 was 16035, checked in by fholzing, 22 months ago

#2904: Better method-ordering, variable-naming and cleaned up some code not necessary anymore.

File size: 17.6 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections;
26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impacts of the input variables of a regression model.
  /// The impact of a variable is the decrease of the squared Pearson correlation (R²) between the
  /// target values and the model estimates on the selected rows when the values of that variable
  /// are replaced according to the chosen replacement method.
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    #region Parameters/Properties
    /// <summary>How the values of a numerical (double) variable are replaced.</summary>
    public enum ReplacementMethodEnum {
      /// <summary>Replace with the median of the selected rows.</summary>
      Median,
      /// <summary>Replace with the average of the selected rows.</summary>
      Average,
      /// <summary>Randomly permute the values of the selected rows (fixed seed).</summary>
      Shuffle,
      /// <summary>Replace with normally distributed noise using the mean and standard deviation of the selected rows.</summary>
      Noise
    }
    /// <summary>How the values of a factor (string) variable are replaced.</summary>
    public enum FactorReplacementMethodEnum {
      /// <summary>Try every value occurring in the selected rows as a constant replacement and report the smallest impact.</summary>
      Best,
      /// <summary>Replace with the most frequent value of the selected rows.</summary>
      Mode,
      /// <summary>Randomly permute the values of the selected rows (fixed seed).</summary>
      Shuffle
    }
    /// <summary>The partition of the problem data whose rows are used for the calculation.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string FactorReplacementParameterName = "Factor Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    // Fixed seed so that the Shuffle and Noise replacements produce the same impacts on every run.
    private const int RandomSeed = 31415;

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Replacement method for numerical variables used by <see cref="Calculate"/>.</summary>
    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Replacement method for factor variables used by <see cref="Calculate"/>.</summary>
    public FactorReplacementMethodEnum FactorReplacementMethod {
      get { return FactorReplacementParameter.Value.Value; }
      set { FactorReplacementParameter.Value.Value = value; }
    }
    /// <summary>Data partition used by <see cref="Calculate"/>.</summary>
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }
    #endregion

    #region Ctor/Cloner
    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }

    /// <summary>
    /// Creates a calculator with the defaults Median (numerical), Best (factor) and Training (partition).
    /// </summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }
    #endregion

    #region Public Methods/Wrappers
    /// <summary>
    /// Calculates the variable impacts of <paramref name="solution"/> using the replacement methods and
    /// data partition configured in this item's parameters.
    /// </summary>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of <paramref name="solution"/> on the given data partition,
    /// using the solution's model, problem data and estimated values.
    /// </summary>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="solution"/> is null.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
      if (solution == null) throw new ArgumentNullException("solution");
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, replacementMethod, factorReplacementMethod, dataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of <paramref name="model"/> on the rows of the given data partition.
    /// </summary>
    /// <param name="estimatedValues">Estimates of the model for every row of the dataset (indexed by row).</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> is null.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
      if (problemData == null) throw new ArgumentNullException("problemData");
      IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData);
      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impacts of all input variables (allowed input variables of the problem data plus the
    /// variables used by the model) on the given rows. The problem data's dataset is cloned, so it is never modified.
    /// </summary>
    /// <param name="estimatedValues">Estimates of the model for every row of the dataset (indexed by row).</param>
    /// <param name="rows">Row indices on which the impacts are evaluated.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    /// <exception cref="ArgumentNullException">Any reference argument is null.</exception>
    /// <exception cref="InvalidOperationException">R² of the original estimates could not be calculated.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     IRegressionModel model,
     IRegressionProblemData problemData,
     IEnumerable<double> estimatedValues,
     IEnumerable<int> rows,
     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      if (model == null) throw new ArgumentNullException("model");
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (estimatedValues == null) throw new ArgumentNullException("estimatedValues");
      if (rows == null) throw new ArgumentNullException("rows");

      // Materialize every sequence once: they are enumerated again for every input variable, and
      // indexing a lazy sequence with ElementAt costs O(n) per access (O(n²) per pass).
      var rowIndices = rows.ToList();
      var allTargetValues = problemData.TargetVariableValues.ToList();
      var allEstimatedValues = estimatedValues.ToList();
      var targetValuesPartition = rowIndices.Select(r => allTargetValues[r]).ToList();
      var estimatedValuesPartition = rowIndices.Select(r => allEstimatedValues[r]).ToList();

      // Quality (R²) of the unmodified estimates; impacts are measured relative to this value.
      OnlineCalculatorError error;
      var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");

      var impacts = new Dictionary<string, double>();
      var inputVariableSet = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      // keep the variable order of the dataset
      var inputVariables = problemData.Dataset.VariableNames.Where(v => inputVariableSet.Contains(v)).ToList();
      // work on a copy so that replacing variable values never touches the problem data's dataset
      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();

      foreach (var inputVariable in inputVariables) {
        impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rowIndices, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable: <paramref name="originalValue"/> minus the R² obtained
    /// after replacing the variable's values. The variable is restored in <paramref name="modifiableDataset"/> afterwards.
    /// </summary>
    /// <param name="targetValues">Target values of <paramref name="rows"/>, in the same order.</param>
    /// <param name="originalValue">R² of the unmodified estimates on <paramref name="rows"/>.</param>
    /// <exception cref="NotSupportedException">The variable is neither of type double nor string.</exception>
    /// <exception cref="InvalidOperationException">R² of the replaced estimates could not be calculated.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // numerical variables are stored as double, factor variables as string
      if (modifiableDataset.VariableHasType<double>(variableName)) {
        return CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
      }
      if (modifiableDataset.VariableHasType<string>(variableName)) {
        return CalculateImpactForFactorVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
      }
      throw new NotSupportedException("Variable not supported");
    }
    #endregion

    /// <summary>Impact of a double variable for the given replacement method.</summary>
    private static double CalculateImpactForNumericalVariables(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod) {
      OnlineCalculatorError error;
      var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod);
      var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
      return originalValue - newValue;
    }

    /// <summary>
    /// Impact of a string (factor) variable. For <see cref="FactorReplacementMethodEnum.Best"/> every value occurring in
    /// <paramref name="rows"/> is tried as a constant replacement and the smallest impact is returned
    /// (positive infinity if there is no candidate value).
    /// </summary>
    private static double CalculateImpactForFactorVariables(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      FactorReplacementMethodEnum factorReplacementMethod) {

      OnlineCalculatorError error;
      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
        // Both lists are materialized before the loop starts replacing the variable in the dataset;
        // the original column is loop-invariant, so one copy suffices for restoring it.
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        var candidateValues = modifiableDataset.GetStringValues(variableName, rows).Distinct().ToList();

        var smallestImpact = double.PositiveInfinity;
        foreach (var repl in candidateValues) {
          var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
          var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

          var curImpact = originalValue - newValue;
          if (curImpact < smallestImpact) smallestImpact = curImpact;
        }
        return smallestImpact;
      }

      // replacement methods Mode and Shuffle
      var replacedEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod);
      var replacedValue = CalculateVariableImpact(targetValues, replacedEstimates, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

      return originalValue - replacedValue;
    }

    /// <summary>
    /// Replaces the double variable <paramref name="variable"/> according to <paramref name="replacement"/>
    /// (statistics are taken over <paramref name="rows"/> only) and returns the model estimates for <paramref name="rows"/>.
    /// </summary>
    /// <exception cref="ArgumentException">Unknown replacement method.</exception>
    private static IEnumerable<double> GetReplacedValuesForNumericalVariables(
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      List<double> replacementValues;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValues = Enumerable.Repeat(rows.Select(r => originalValues[r]).Median(), dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValues = Enumerable.Repeat(rows.Select(r => originalValues[r]).Average(), dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          replacementValues = CreateShuffledColumn(originalValues, rows, dataset.Rows, double.NaN);
          break;
        case ReplacementMethodEnum.Noise:
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          var rand = new FastRandom(RandomSeed);
          // only the selected rows are passed to the model; all other rows are filled with NaN
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;
        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Replaces the string variable <paramref name="variable"/> according to <paramref name="replacement"/>
    /// (Mode or Shuffle) and returns the model estimates for <paramref name="rows"/>.
    /// </summary>
    /// <exception cref="ArgumentException">Unsupported replacement method (including Best).</exception>
    private static IEnumerable<double> GetReplacedValuesForFactorVariables(
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          replacementValues = CreateShuffledColumn(originalValues, rows, dataset.Rows, string.Empty);
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Builds a complete column of <paramref name="rowCount"/> values in which the selected rows hold a
    /// seeded random permutation of their original values and every other row holds <paramref name="fillValue"/>.
    /// </summary>
    private static List<T> CreateShuffledColumn<T>(IList<T> originalValues, IEnumerable<int> rows, int rowCount, T fillValue) {
      var rand = new FastRandom(RandomSeed);
      var column = Enumerable.Repeat(fillValue, rowCount).ToList();
      // shuffle only the selected rows
      var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
      int i = 0;
      foreach (var r in rows) {
        column[r] = shuffledValues[i++];
      }
      return column;
    }

    /// <summary>
    /// Temporarily replaces <paramref name="variable"/> with <paramref name="replacementValues"/>, evaluates the model
    /// on <paramref name="rows"/> and puts <paramref name="originalValues"/> back — also when the model throws.
    /// </summary>
    private static IEnumerable<double> GetReplacedEstimates(
      IList originalValues,
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      IList replacementValues) {
      dataset.ReplaceVariable(variable, replacementValues);
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        // restore the column even on failure so a caller-supplied dataset is never left modified
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    /// <summary>
    /// Calculates the squared Pearson correlation (R²) between <paramref name="originalValues"/> and
    /// <paramref name="estimatedValues"/>. Stops at the first calculator error.
    /// </summary>
    /// <param name="errorState">Error state of the calculator after processing the values.</param>
    /// <exception cref="ArgumentException">The sequences have different lengths (checked only when no calculator error occurred).</exception>
    public static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
      var calculator = new OnlinePearsonsRSquaredCalculator();

      using (IEnumerator<double> firstEnumerator = originalValues.GetEnumerator())
      using (IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator()) {
        bool firstHasNext, secondHasNext;
        while (true) {
          // always move forward both enumerators so their final states can be compared below
          firstHasNext = firstEnumerator.MoveNext();
          secondHasNext = secondEnumerator.MoveNext();
          if (!firstHasNext || !secondHasNext) break;

          calculator.Add(firstEnumerator.Current, secondEnumerator.Current);
          if (calculator.ErrorState != OnlineCalculatorError.None) break;
        }

        // When the loop ended without error, at least one enumeration is exhausted; the lengths match only
        // if both are. (Calling MoveNext again here would miss a difference of exactly one element.)
        if (calculator.ErrorState == OnlineCalculatorError.None && firstHasNext != secondHasNext) {
          throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
        }
      }

      errorState = calculator.ErrorState;
      return calculator.Value;
    }

    /// <summary>Returns the row indices of the requested partition of <paramref name="problemData"/>.</summary>
    /// <exception cref="NotSupportedException">Unknown partition.</exception>
    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IRegressionProblemData problemData) {
      switch (dataPartition) {
        case DataPartitionEnum.All:
          return problemData.AllIndices;
        case DataPartitionEnum.Test:
          return problemData.TestIndices;
        case DataPartitionEnum.Training:
          return problemData.TrainingIndices;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.