Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 16016

Last change on this file since 16016 was 16016, checked in by fholzing, 6 years ago

#2904: Removed unnecessary where-condition (.Where(...).Count()); Cloned the dataset before the .ToModifiable() call; Adapted the CalculateValue-Method for a single loop through the IEnumerables for the target and estimated values

File size: 18.0 KB
RevLine 
[13766]1#region License Information
2
3/* HeuristicLab
[15583]4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[13766]5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
[15831]25using System.Collections;
[13766]26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[13986]33using HeuristicLab.Random;
[13766]34
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of the input variables of a regression model. For every input variable the
  /// variable's values are replaced (see <see cref="ReplacementMethodEnum"/> and
  /// <see cref="FactorReplacementMethodEnum"/>), the model is re-evaluated, and the impact is the
  /// difference between the original value of <see cref="Calculator"/> and the value obtained with the
  /// replaced variable.
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>Replacement strategies for double-valued variables (statistics are computed over the selected rows).</summary>
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }
    /// <summary>Replacement strategies for string-valued (factor) variables.</summary>
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }
    /// <summary>The data partition whose rows are used for the impact calculation.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    // The Pearson's R² calculator is the default, but it can be overwritten with any other IOnlineCalculator.
    // Just remember to reset it after you're done.
    // NOTE: this instance is static and shared by all impact calculations; CalculateValue resets and fills it
    // without any synchronization, so impact calculations must not run concurrently.
    private static IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator();
    public static IOnlineCalculator Calculator
    {
      get { return calculator; }
      set { calculator = value; }
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter
    {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter
    {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    public ReplacementMethodEnum ReplacementMethod
    {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    public DataPartitionEnum DataPartition
    {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    /// <summary>
    /// Calculates the impacts for <paramref name="solution"/> using the configured
    /// <see cref="DataPartition"/> and <see cref="ReplacementMethod"/> (factor variables use the default
    /// <see cref="FactorReplacementMethodEnum.Best"/>).
    /// </summary>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, DataPartition, ReplacementMethod);
    }

    /// <summary>
    /// Calculates the impacts of all input variables of <paramref name="solution"/> on the given partition.
    /// </summary>
    /// <returns>(variable name, impact) pairs ordered by descending impact, or null if cancelled via <paramref name="progressCallback"/>.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      Func<double, string, bool> progressCallback = null) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, data, replacementMethod, factorReplacementMethod, progressCallback);
    }

    /// <summary>
    /// Resolves <paramref name="data"/> to the corresponding row indices of <paramref name="problemData"/>
    /// and calculates the impacts on those rows.
    /// </summary>
    /// <exception cref="NotSupportedException">If <paramref name="data"/> is not a known partition.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      Func<double, string, bool> progressCallback = null) {
      IEnumerable<int> rows;

      switch (data) {
        case DataPartitionEnum.All:
          rows = problemData.AllIndices;
          break;
        case DataPartitionEnum.Test:
          rows = problemData.TestIndices;
          break;
        case DataPartitionEnum.Training:
          rows = problemData.TrainingIndices;
          break;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }

      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod, progressCallback);
    }

    /// <summary>
    /// Calculates the impact of every variable that is an allowed input variable of
    /// <paramref name="problemData"/> or used for prediction by <paramref name="model"/>.
    /// </summary>
    /// <param name="estimatedValues">The model's estimates, indexed by dataset row (must cover every index in <paramref name="rows"/>).</param>
    /// <param name="rows">The dataset rows on which the impacts are calculated.</param>
    /// <param name="progressCallback">
    /// Optional; called before each variable with the progress fraction and a message.
    /// Returning true cancels the calculation.
    /// </param>
    /// <returns>(variable name, impact) pairs ordered by descending impact, or null if cancelled.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     IRegressionModel model,
     IRegressionProblemData problemData,
     IEnumerable<double> estimatedValues,
     IEnumerable<int> rows,
     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     Func<double, string, bool> progressCallback = null) {

      // The rows are enumerated once per variable (and once per candidate value for factors); materialize them once.
      var rowList = rows.ToList();

      IEnumerable<double> targetValues;
      double originalValue;

      PrepareData(rowList, problemData, estimatedValues, out targetValues, out originalValue);

      var impacts = new Dictionary<string, double>();
      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      // Keep the order of the variables as defined in the dataset.
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();

      int curIdx = 0;
      int count = allowedInputVariables.Count;

      foreach (var inputVariable in allowedInputVariables) {
        //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
        if (progressCallback != null) {
          curIdx++;
          if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
        }
        impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData.Dataset, rowList, targetValues, originalValue, replacementMethod, factorReplacementMethod);
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable of <paramref name="solution"/>.
    /// </summary>
    /// <param name="data">Ignored; the rows are taken from <paramref name="rows"/>. Kept for backward compatibility.</param>
    public static double CalculateImpact(string variableName,
      IRegressionSolution solution,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of a single variable: <paramref name="originalValue"/> minus the calculator value
    /// obtained after replacing the variable in a copy of <paramref name="dataset"/>.
    /// </summary>
    /// <param name="targetValues">Target values for <paramref name="rows"/>, in the same order.</param>
    /// <exception cref="NotSupportedException">If the variable is neither double- nor string-valued.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      IDataset dataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {

      // Replacements are applied to a cloned, modifiable copy so the caller's dataset is never altered.
      var modifiableDataset = ((Dataset)(dataset).Clone()).ToModifiable();

      if (dataset.VariableHasType<double>(variableName)) {
        return CalculateImpactForDouble(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
      }
      if (dataset.VariableHasType<string>(variableName)) {
        return CalculateImpactForString(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
      }
      throw new NotSupportedException("Variable not supported");
    }

    // Extracts the target values of the given rows and computes the calculator value of the unmodified estimates.
    private static void PrepareData(IEnumerable<int> rows,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      out IEnumerable<double> targetValues,
      out double originalValue) {
      OnlineCalculatorError error;

      // Use indexable lists: ElementAt on a non-IList sequence restarts the enumeration for every row,
      // which is quadratic for lazily evaluated estimates. The results are materialized because
      // targetValues is enumerated again for every variable.
      var targetVariableValueList = problemData.TargetVariableValues.ToList();
      var estimatedValueList = estimatedValues as IList<double> ?? estimatedValues.ToList();

      targetValues = rows.Select(r => targetVariableValueList[r]).ToList();
      var estimatedValuesPartition = rows.Select(r => estimatedValueList[r]).ToList();
      originalValue = CalculateValue(targetValues, estimatedValuesPartition, out error);

      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
    }

    // Impact of a double-valued variable for the given replacement method.
    private static double CalculateImpactForDouble(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod) {
      OnlineCalculatorError error;
      var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, replacementMethod);
      var newValue = CalculateValue(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
      return originalValue - newValue;
    }

    // Impact of a string-valued (factor) variable. For "Best", every distinct value occurring in the rows is tried
    // as a constant replacement and the smallest resulting impact is returned.
    private static double CalculateImpactForString(string variableName,
      IRegressionModel model,
      IDataset problemData,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      FactorReplacementMethodEnum factorReplacementMethod) {

      OnlineCalculatorError error;
      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
        // try replacing with all possible values and find the best replacement value
        var smallestImpact = double.PositiveInfinity;
        // The original column is restored after every evaluation, so one copy suffices for all candidates.
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        foreach (var repl in problemData.GetStringValues(variableName, rows).Distinct()) {
          var newEstimates = EvaluateModelWithReplacedVariable(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, problemData.Rows).ToList());
          var newValue = CalculateValue(targetValues, newEstimates, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

          var curImpact = originalValue - newValue;
          if (curImpact < smallestImpact) smallestImpact = curImpact;
        }
        return smallestImpact;
      } else {
        // for replacement methods shuffle and mode
        // calculate impacts for factor variables
        var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, factorReplacementMethod);
        var newValue = CalculateValue(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

        return originalValue - newValue;
      }
    }

    // Builds a replacement column for a double variable and returns the model's estimates for the given rows.
    // Median/Average/Noise statistics and the shuffle are computed over the given rows only; the fixed seed makes
    // Shuffle and Noise reproducible. Rows outside the selection are filled with NaN for Shuffle and Noise.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      double replacementValue;
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Builds a replacement column for a string variable (Mode or Shuffle) and returns the model's estimates for the
    // given rows. Rows outside the selection are filled with string.Empty for Shuffle.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
      IRegressionModel model, string variable, ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;
      IRandom rand;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Temporarily replaces the column of `variable`, evaluates the model on `rows`, and restores the original column.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IList originalValues, IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IList replacementValues) {
      dataset.ReplaceVariable(variable, replacementValues);
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        // Restore the original column even if the model evaluation throws.
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    // Feeds target/estimate pairs into the shared static calculator in a single pass.
    // Throws ArgumentException if both sequences differ in length.
    private static double CalculateValue(IEnumerable<double> targets, IEnumerable<double> estimates, out OnlineCalculatorError error) {
      calculator.Reset();

      bool targetsHasNextValue = false;
      bool estimatesHasNextValue = false;

      // IEnumerator<T> is IDisposable; dispose both enumerators once the pass is finished.
      using (var targetsEnumerator = targets.GetEnumerator())
      using (var estimatesEnumerator = estimates.GetEnumerator()) {
        targetsHasNextValue = targetsEnumerator.MoveNext();
        estimatesHasNextValue = estimatesEnumerator.MoveNext();

        while (targetsHasNextValue && estimatesHasNextValue) {
          calculator.Add(targetsEnumerator.Current, estimatesEnumerator.Current);
          targetsHasNextValue = targetsEnumerator.MoveNext();
          estimatesHasNextValue = estimatesEnumerator.MoveNext();
        }
      }

      //Check if there is an equal quantity of targets and estimates
      if (targetsHasNextValue != estimatesHasNextValue) {
        throw new ArgumentException(string.Format("Targets and Estimates must be of equal length ({0},{1})", targets.Count(), estimates.Count()));
      }

      error = calculator.ErrorState;
      return calculator.Value;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.