Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 16031

Last change on this file since 16031 was 16031, checked in by fholzing, 6 years ago

#2904: Changed formatting (adhering to the HL-standard) and renamed variables/methods for better comprehensibility

File size: 19.2 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections;
26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of the input variables of a regression model.
  /// The impact of a variable is the Pearson's R² of the original estimates minus the R² obtained
  /// after the variable's values on the evaluated rows have been replaced
  /// (see <see cref="ReplacementMethodEnum"/> for double variables and
  /// <see cref="FactorReplacementMethodEnum"/> for string/factor variables).
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>
    /// Replacement strategies for double variables:
    /// Median / Average of the evaluated rows, Shuffle (permutation of the evaluated rows, seed 31415),
    /// Noise (normally distributed values with the mean and standard deviation of the evaluated rows, seed 31415).
    /// </summary>
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }
    /// <summary>
    /// Replacement strategies for string (factor) variables:
    /// Best (tries every level occurring in the evaluated rows and keeps the smallest impact),
    /// Mode (most frequent level), Shuffle (permutation of the evaluated rows, seed 31415).
    /// </summary>
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }
    /// <summary>Selects the rows of the problem data on which impacts are calculated.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string FactorReplacementParameterName = "Factor Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Replacement method used by <see cref="Calculate"/> for double variables.</summary>
    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Replacement method used by <see cref="Calculate"/> for string (factor) variables.</summary>
    public FactorReplacementMethodEnum FactorReplacementMethod {
      get { return FactorReplacementParameter.Value.Value; }
      set { FactorReplacementParameter.Value.Value = value; }
    }
    /// <summary>Data partition used by <see cref="Calculate"/>.</summary>
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator with the defaults Median (double variables), Best (factor variables) and Training partition.
    /// </summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    //mkommend: annoying name clash with static method, open to better naming suggestions
    /// <summary>
    /// Calculates the variable impacts of <paramref name="solution"/> using the replacement methods and
    /// data partition configured on this item.
    /// </summary>
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of a regression solution, using its model, problem data and estimated values.
    /// </summary>
    /// <returns>
    /// (variable name, impact) pairs ordered by descending impact, or null if <paramref name="progressCallback"/> requested cancellation.
    /// </returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training,
      Func<double, string, bool> progressCallback = null) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, replacementMethod, factorReplacementMethod, data, progressCallback);
    }

    /// <summary>
    /// Calculates the variable impacts on the rows of the selected data partition.
    /// </summary>
    /// <param name="estimatedValues">Estimated values for every dataset row, indexed by row number.</param>
    /// <exception cref="NotSupportedException">If <paramref name="data"/> is not a known partition.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training,
      Func<double, string, bool> progressCallback = null) {
      IEnumerable<int> rows;

      switch (data) {
        case DataPartitionEnum.All:
          rows = problemData.AllIndices;
          break;
        case DataPartitionEnum.Test:
          rows = problemData.TestIndices;
          break;
        case DataPartitionEnum.Training:
          rows = problemData.TrainingIndices;
          break;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }

      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod, progressCallback);
    }

    /// <summary>
    /// Calculates the impact of every variable that is an allowed input of <paramref name="problemData"/>
    /// or used by <paramref name="model"/>, evaluated on <paramref name="rows"/>.
    /// </summary>
    /// <param name="estimatedValues">Estimated values for every dataset row, indexed by row number.</param>
    /// <param name="progressCallback">
    /// Optional; receives (fraction done, message) before each variable. Returning true cancels the calculation.
    /// </param>
    /// <returns>
    /// (variable name, impact) pairs ordered by descending impact, or null if the calculation was cancelled.
    /// </returns>
    /// <exception cref="InvalidOperationException">If the R² calculation reports an error.</exception>
    /// <exception cref="NotSupportedException">If an input variable is neither double nor string.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     IRegressionModel model,
     IRegressionProblemData problemData,
     IEnumerable<double> estimatedValues,
     IEnumerable<int> rows,
     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     Func<double, string, bool> progressCallback = null) {
      // rows and target values are enumerated once per variable (and per factor level with "Best"),
      // so they are materialized once instead of being re-evaluated lazily every time.
      var rowList = rows.ToList();

      //Calculate original quality-values (R² of target vs. estimated values on the selected rows)
      IEnumerable<double> targetValues;
      double originalValue;
      PrepareData(rowList, problemData, estimatedValues, out targetValues, out originalValue);

      var impacts = new Dictionary<string, double>();
      var inputVariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      // keep the dataset's variable order
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();

      int curIdx = 0;
      int count = allowedInputVariables.Count(v => problemData.Dataset.VariableHasType<double>(v) || problemData.Dataset.VariableHasType<string>(v));

      foreach (var inputVariable in allowedInputVariables) {
        //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
        if (progressCallback != null) {
          curIdx++;
          if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
        }
        impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData.Dataset, rowList, targetValues, originalValue, replacementMethod, factorReplacementMethod);
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable using the model and dataset of <paramref name="solution"/>.
    /// </summary>
    /// <param name="data">Ignored; the evaluated rows are given by <paramref name="rows"/>. Kept for compatibility.</param>
    public static double CalculateImpact(string variableName,
      IRegressionSolution solution,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training) {
      return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of a single variable: <paramref name="originalValue"/> minus the R² obtained after
    /// replacing the variable's values. The replacement is done on a clone of <paramref name="dataset"/>,
    /// so the passed dataset is not modified.
    /// </summary>
    /// <param name="targetValues">Target values of <paramref name="rows"/>, in the same order.</param>
    /// <param name="originalValue">R² of the unmodified estimates on <paramref name="rows"/>.</param>
    /// <exception cref="NotSupportedException">If the variable is neither double nor string.</exception>
    /// <exception cref="InvalidOperationException">If the R² calculation reports an error.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      IDataset dataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {

      // Both sequences are enumerated once per replacement (several times with "Best"): avoid re-evaluating lazy input.
      var rowList = rows as IList<int> ?? rows.ToList();
      var targetValueList = targetValues as IList<double> ?? targetValues.ToList();

      var modifiableDataset = ((Dataset)(dataset).Clone()).ToModifiable();

      // calculate impacts for double variables
      if (dataset.VariableHasType<double>(variableName)) {
        return CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rowList, targetValueList, originalValue, replacementMethod);
      }
      // calculate impacts for factor (string) variables
      if (dataset.VariableHasType<string>(variableName)) {
        return CalculateImpactForFactorVariables(variableName, model, dataset, modifiableDataset, rowList, targetValueList, originalValue, factorReplacementMethod);
      }
      throw new NotSupportedException("Variable not supported");
    }

    // Selects the target and estimated values of the given rows (materialized) and computes their R².
    private static void PrepareData(IEnumerable<int> rows,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      out IEnumerable<double> targetValues,
      out double originalValue) {
      OnlineCalculatorError error;

      targetValues = SelectRows(problemData.TargetVariableValues, rows);
      var estimatedValuesPartition = SelectRows(estimatedValues, rows);
      originalValue = CalculateQuality(targetValues, estimatedValuesPartition, out error);

      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
    }

    // Returns the elements at the given row indices (in row order); enumerates 'values' at most once.
    private static List<double> SelectRows(IEnumerable<double> values, IEnumerable<int> rows) {
      var valueList = values as IList<double> ?? values.ToList();
      return rows.Select(r => valueList[r]).ToList();
    }

    // Impact of a double variable: originalValue minus R² after replacement with 'replacementMethod'.
    private static double CalculateImpactForNumericalVariables(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod) {
      OnlineCalculatorError error;
      var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod);
      var newValue = CalculateQuality(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
      return originalValue - newValue;
    }

    // Impact of a string (factor) variable. With "Best", every level occurring in 'rows' is tried as a
    // constant replacement and the smallest impact is returned; otherwise Mode/Shuffle is applied once.
    private static double CalculateImpactForFactorVariables(string variableName,
      IRegressionModel model,
      IDataset dataset,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      FactorReplacementMethodEnum factorReplacementMethod) {

      OnlineCalculatorError error;
      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
        // try replacing with all possible values and find the best replacement value
        var smallestImpact = double.PositiveInfinity;
        // the column is restored after every replacement, so its original values can be fetched once
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        foreach (var repl in dataset.GetStringValues(variableName, rows).Distinct()) {
          var newEstimates = GetReplacedValues(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, dataset.Rows).ToList());
          var newValue = CalculateQuality(targetValues, newEstimates, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

          var curImpact = originalValue - newValue;
          if (curImpact < smallestImpact) smallestImpact = curImpact;
        }
        return smallestImpact;
      } else {
        // for replacement methods shuffle and mode
        var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod);
        var newValue = CalculateQuality(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

        return originalValue - newValue;
      }
    }

    // Builds a full replacement column for a double variable (only 'rows' carry meaningful values for
    // Shuffle/Noise; other rows are NaN) and returns the model's estimates on 'rows' with that column.
    private static IEnumerable<double> GetReplacedValuesForNumericalVariables(
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      double replacementValue;
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return GetReplacedValues(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Builds a full replacement column for a string variable (Mode or Shuffle; for Shuffle, rows not in
    // 'rows' are string.Empty) and returns the model's estimates on 'rows' with that column.
    private static IEnumerable<double> GetReplacedValuesForFactorVariables(
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;
      IRandom rand;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return GetReplacedValues(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Temporarily swaps the column of 'variable' for 'replacementValues', evaluates the model on 'rows'
    // and restores 'originalValues' afterwards, also when the model evaluation throws.
    private static IEnumerable<double> GetReplacedValues(
      IList originalValues,
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      IList replacementValues) {
      dataset.ReplaceVariable(variable, replacementValues);
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    // Computes Pearson's R² between 'originalValues' (targets) and 'estimatedValues'.
    // Stops at the first calculator error and reports it via 'errorState'.
    // Throws ArgumentException if the two sequences differ in length (checked only when no error occurred).
    private static double CalculateQuality(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
      IEnumerator<double> firstEnumerator = originalValues.GetEnumerator();
      IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator();
      var calculator = new OnlinePearsonsRSquaredCalculator();

      // always move forward both enumerators (do not use short-circuit evaluation!)
      while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) {
        double original = firstEnumerator.Current;
        double estimated = secondEnumerator.Current;
        calculator.Add(original, estimated);
        if (calculator.ErrorState != OnlineCalculatorError.None) break;
      }

      // check if both enumerators are at the end to make sure both enumerations have the same length
      if (calculator.ErrorState == OnlineCalculatorError.None &&
           (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) {
        throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
      } else {
        errorState = calculator.ErrorState;
        return calculator.Value;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.