Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 16020

Last change on this file since 16020 was 16020, checked in by fholzing, 6 years ago

#2904: Removed the static calculator variable, changed the default ReplacementMethod from Median to Shuffle, adapted the calculation method to follow the OnlineCalculators, and re-added the condition for counting the input parameters

File size: 18.8 KB
RevLine 
[13766]1#region License Information
2
3/* HeuristicLab
[15583]4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[13766]5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
[15831]25using System.Collections;
[13766]26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[13986]33using HeuristicLab.Random;
[13766]34
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impacts of the input variables of a regression model.
  /// The impact of a variable is the decrease of the Pearson R² between target and estimated values
  /// that occurs when the values of that variable are replaced (see <see cref="ReplacementMethodEnum"/>
  /// for double variables and <see cref="FactorReplacementMethodEnum"/> for string/factor variables).
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>How the values of a double-valued input variable are replaced.</summary>
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }
    /// <summary>How the values of a string-valued (factor) input variable are replaced.</summary>
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }
    /// <summary>The partition of the problem data on which the impacts are calculated.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string FactorReplacementParameterName = "Factor Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    public FactorReplacementMethodEnum FactorReplacementMethod {
      get { return FactorReplacementParameter.Value.Value; }
      set { FactorReplacementParameter.Value.Value = value; }
    }
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator with default settings: Shuffle for double variables, Best for factor variables,
    /// and the training partition.
    /// </summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      // Shuffle is the default replacement method, consistent with the defaults of the static CalculateImpacts overloads.
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle)));
      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    //mkommend: annoying name clash with static method, open to better naming suggestions
    /// <summary>
    /// Calculates the impacts of the input variables of <paramref name="solution"/> using the replacement
    /// methods and data partition configured in this item's parameters.
    /// </summary>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    }

    /// <summary>
    /// Calculates the impacts of the input variables of a regression solution on the given data partition.
    /// </summary>
    /// <param name="solution">The solution whose model, problem data and estimated values are used.</param>
    /// <param name="replacementMethod">Replacement method for double variables.</param>
    /// <param name="factorReplacementMethod">Replacement method for string (factor) variables.</param>
    /// <param name="data">The partition on which the impacts are calculated.</param>
    /// <param name="progressCallback">
    /// Optional; called with (progress in [0,1], message) before each variable. Returning true cancels the calculation.
    /// </param>
    /// <returns>(variable name, impact) pairs ordered by descending impact, or null if cancelled.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training,
      Func<double, string, bool> progressCallback = null) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, replacementMethod, factorReplacementMethod, data, progressCallback);
    }

    /// <summary>
    /// Calculates the impacts of the input variables of <paramref name="model"/> on the given data partition.
    /// </summary>
    /// <param name="estimatedValues">Estimated values of the model, indexed by dataset row (all rows).</param>
    /// <exception cref="NotSupportedException">If <paramref name="data"/> is not a known partition.</exception>
    /// <returns>(variable name, impact) pairs ordered by descending impact, or null if cancelled.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training,
      Func<double, string, bool> progressCallback = null) {
      IEnumerable<int> rows;

      switch (data) {
        case DataPartitionEnum.All:
          rows = problemData.AllIndices;
          break;
        case DataPartitionEnum.Test:
          rows = problemData.TestIndices;
          break;
        case DataPartitionEnum.Training:
          rows = problemData.TrainingIndices;
          break;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }

      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod, progressCallback);
    }

    /// <summary>
    /// Calculates the impacts of all allowed input variables and all variables used by <paramref name="model"/>
    /// on the given <paramref name="rows"/>.
    /// </summary>
    /// <param name="estimatedValues">Estimated values of the model, indexed by dataset row (all rows).</param>
    /// <param name="rows">The dataset rows on which the impacts are calculated.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact, or null if cancelled.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     IRegressionModel model,
     IRegressionProblemData problemData,
     IEnumerable<double> estimatedValues,
     IEnumerable<int> rows,
     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     Func<double, string, bool> progressCallback = null) {
      // rows are enumerated repeatedly (for every variable and every factor value), so materialize them once
      var rowList = rows.ToList();

      IEnumerable<double> targetValues;
      double originalCalculatorValue;
      PrepareData(rowList, problemData, estimatedValues, out targetValues, out originalCalculatorValue);

      var impacts = new Dictionary<string, double>();
      var inputVariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      // keep the dataset's variable order
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();

      int curIdx = 0;
      int count = allowedInputVariables.Count(v => problemData.Dataset.VariableHasType<double>(v) || problemData.Dataset.VariableHasType<string>(v));

      foreach (var inputVariable in allowedInputVariables) {
        //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
        if (progressCallback != null) {
          curIdx++;
          if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
        }
        impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData.Dataset, rowList, targetValues, originalCalculatorValue, replacementMethod, factorReplacementMethod);
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable using the model and dataset of <paramref name="solution"/>.
    /// </summary>
    /// <param name="targetValues">Target values for <paramref name="rows"/>, in the same order.</param>
    /// <param name="originalValue">R² of the unmodified model on <paramref name="rows"/>.</param>
    /// <param name="data">Not used by this overload; the evaluated rows are given by <paramref name="rows"/>.</param>
    public static double CalculateImpact(string variableName,
      IRegressionSolution solution,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training) {
      return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of a single variable: the original R² minus the R² obtained after replacing the
    /// variable's values. The replacement is done on a modifiable clone, so <paramref name="dataset"/> is not changed.
    /// </summary>
    /// <exception cref="NotSupportedException">If the variable is neither double- nor string-valued.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      IDataset dataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {

      var modifiableDataset = ((Dataset)(dataset).Clone()).ToModifiable();

      if (dataset.VariableHasType<double>(variableName)) {
        return CalculateImpactForDouble(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
      }
      if (dataset.VariableHasType<string>(variableName)) {
        return CalculateImpactForString(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
      }
      throw new NotSupportedException("Variable not supported");
    }

    /// <summary>
    /// Extracts the target values of <paramref name="rows"/> and computes the R² of the unmodified estimates on them.
    /// </summary>
    /// <exception cref="InvalidOperationException">If the R² calculator reports an error.</exception>
    private static void PrepareData(IEnumerable<int> rows,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      out IEnumerable<double> targetValues,
      out double originalValue) {
      OnlineCalculatorError error;

      // Index-based access needs lists: ElementAt on a non-list enumerable is O(n) per call.
      var targetVariableValueList = problemData.TargetVariableValues.ToList();
      var estimatedValueList = estimatedValues as IList<double> ?? estimatedValues.ToList();

      // Materialized so the target partition is not recomputed for every impact calculation.
      targetValues = rows.Select(r => targetVariableValueList[r]).ToList();
      var estimatedValuesPartition = rows.Select(r => estimatedValueList[r]).ToList();
      originalValue = CalculateValue(targetValues, estimatedValuesPartition, out error);

      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
    }

    /// <summary>Impact of a double variable: original R² minus R² after replacement.</summary>
    private static double CalculateImpactForDouble(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod) {
      OnlineCalculatorError error;
      var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, replacementMethod);
      var newValue = CalculateValue(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
      return originalValue - newValue;
    }

    /// <summary>
    /// Impact of a string (factor) variable. For <see cref="FactorReplacementMethodEnum.Best"/> every distinct
    /// value occurring in <paramref name="rows"/> is tried as a constant replacement and the smallest resulting
    /// impact is returned; otherwise the Mode/Shuffle replacement is used.
    /// </summary>
    private static double CalculateImpactForString(string variableName,
      IRegressionModel model,
      IDataset dataset,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      FactorReplacementMethodEnum factorReplacementMethod) {

      OnlineCalculatorError error;
      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
        // try replacing with all possible values and find the best replacement value
        var smallestImpact = double.PositiveInfinity;
        // loop-invariant: the column is restored after each evaluation
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        foreach (var repl in dataset.GetStringValues(variableName, rows).Distinct()) {
          var newEstimates = EvaluateModelWithReplacedVariable(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, dataset.Rows).ToList());
          var newValue = CalculateValue(targetValues, newEstimates, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

          var curImpact = originalValue - newValue;
          if (curImpact < smallestImpact) smallestImpact = curImpact;
        }
        return smallestImpact;
      } else {
        // for replacement methods shuffle and mode
        var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, factorReplacementMethod);
        var newValue = CalculateValue(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

        return originalValue - newValue;
      }
    }

    /// <summary>
    /// Evaluates the model on <paramref name="rows"/> after replacing the double variable according to
    /// <paramref name="replacement"/>. Shuffle and Noise use a fixed seed (31415), so results are reproducible.
    /// </summary>
    /// <exception cref="ArgumentException">If the replacement method is unknown.</exception>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      double replacementValue;
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Evaluates the model on <paramref name="rows"/> after replacing the string variable by its most common
    /// value (Mode) or by a seeded shuffle of its values (Shuffle).
    /// </summary>
    /// <exception cref="ArgumentException">If the replacement method is not Mode or Shuffle.</exception>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
      IRegressionModel model, string variable, ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Best) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;
      IRandom rand;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Temporarily replaces the column <paramref name="variable"/> with <paramref name="replacementValues"/>,
    /// evaluates the model on <paramref name="rows"/> and restores <paramref name="originalValues"/>.
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IList originalValues, IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IList replacementValues) {
      dataset.ReplaceVariable(variable, replacementValues);
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        // restore the original column even if the model evaluation fails
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    /// <summary>
    /// Computes the Pearson R² between <paramref name="originalValues"/> and <paramref name="estimatedValues"/>.
    /// </summary>
    /// <param name="errorState">The calculator's error state; the returned value is only meaningful if it is None.</param>
    /// <exception cref="ArgumentException">If the two enumerations have different lengths.</exception>
    private static double CalculateValue(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
      var calculator = new OnlinePearsonsRSquaredCalculator();

      using (IEnumerator<double> firstEnumerator = originalValues.GetEnumerator())
      using (IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator()) {
        bool firstHasValue = false;
        bool secondHasValue = false;
        while (true) {
          // always move forward both enumerators and remember both results, so that a length
          // mismatch of any size (including a single element) is detected below
          firstHasValue = firstEnumerator.MoveNext();
          secondHasValue = secondEnumerator.MoveNext();
          if (!firstHasValue || !secondHasValue) break;

          calculator.Add(firstEnumerator.Current, secondEnumerator.Current);
          if (calculator.ErrorState != OnlineCalculatorError.None) break;
        }

        // if the loop ended regularly, both enumerations must have ended at the same time
        if (calculator.ErrorState == OnlineCalculatorError.None && firstHasValue != secondHasValue) {
          throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
        }
      }

      errorState = calculator.ErrorState;
      return calculator.Value;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.