Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 15831

Last change on this file since 15831 was 15831, checked in by fholzing, 6 years ago

#2904: Refactored RegressionSolutionVariableImpactsCalculator. We don't depend on the solution anymore. The impact can be calculated for a single variable. The calculator can be chosen.

File size: 17.2 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections;
26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impacts of the input variables of a regression model. Each input variable is
  /// replaced (by a constant, by shuffled values, by noise, or by a factor level), the model is
  /// re-evaluated, and the impact is the difference between the calculator value of the original
  /// estimates and the calculator value of the estimates obtained with the replaced variable.
  /// </summary>
  /// <remarks>
  /// NOTE(review): impacts are computed as <c>originalValue - newValue</c>, so a positive impact means
  /// the replacement lowered the calculator value. That reads naturally for measures where larger is
  /// better (the default is Pearson's R²). For error measures such as MSE, both the sign and the
  /// selection made by the factor replacement method <c>Best</c> are inverted. Confirm this is intended.
  /// </remarks>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>How double-valued input variables are replaced.</summary>
    public enum ReplacementMethodEnum {
      /// <summary>Constant median of the variable over the selected rows.</summary>
      Median,
      /// <summary>Constant average of the variable over the selected rows.</summary>
      Average,
      /// <summary>The values of the selected rows in a seeded random permutation.</summary>
      Shuffle,
      /// <summary>Normally distributed values with the mean and standard deviation of the selected rows.</summary>
      Noise
    }
    /// <summary>How string-valued (factor) input variables are replaced.</summary>
    public enum FactorReplacementMethodEnum {
      /// <summary>Every distinct value of the selected rows is tried as a constant; the smallest impact is reported.</summary>
      Best,
      /// <summary>Constant most frequent value of the selected rows.</summary>
      Mode,
      /// <summary>The values of the selected rows in a seeded random permutation.</summary>
      Shuffle
    }
    /// <summary>The data partition on which the impacts are calculated.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";
    // Fixed seed so that the Shuffle and Noise replacements (and therefore the impacts) are reproducible.
    private const int RandomSeed = 31415;

    /// <summary>Parameter holding the replacement method used by <see cref="Calculate"/>.</summary>
    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter
    {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    /// <summary>Parameter holding the data partition used by <see cref="Calculate"/>.</summary>
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter
    {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Gets or sets the value of <see cref="ReplacementParameter"/>.</summary>
    public ReplacementMethodEnum ReplacementMethod
    {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the value of <see cref="DataPartitionParameter"/>.</summary>
    public DataPartitionEnum DataPartition
    {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>Creates a calculator using the Median replacement on the Training partition.</summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    //mkommend: annoying name clash with static method, open to better naming suggestions
    /// <summary>
    /// Calculates the impacts for <paramref name="solution"/> using the configured
    /// <see cref="DataPartition"/> and <see cref="ReplacementMethod"/>.
    /// </summary>
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, DataPartition, ReplacementMethod);
    }

    /// <summary>
    /// Calculates the impacts of all input variables of <paramref name="solution"/> on the given partition.
    /// </summary>
    /// <param name="calculator">The quality measure; <c>null</c> uses Pearson's R².</param>
    /// <returns>(variable, impact) tuples by descending impact, or <c>null</c> if cancelled via <paramref name="progressCallback"/>.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      Func<double, string, bool> progressCallback = null,
      IOnlineCalculator calculator = null) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, data, replacementMethod, factorReplacementMethod, progressCallback, calculator);
    }

    /// <summary>
    /// Calculates the impacts of all input variables of <paramref name="model"/> on the rows of the chosen partition.
    /// </summary>
    /// <param name="estimatedValues">Estimates of <paramref name="model"/> for the dataset rows, indexed by row.</param>
    /// <param name="data">Selects the rows of <paramref name="problemData"/> (training, test or all indices).</param>
    /// <param name="calculator">The quality measure; <c>null</c> uses Pearson's R².</param>
    /// <exception cref="NotSupportedException">Thrown for an unknown <paramref name="data"/> value.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      Func<double, string, bool> progressCallback = null,
      IOnlineCalculator calculator = null) {
      //PearsonsRSquared is the default calculator
      if (calculator == null) { calculator = new OnlinePearsonsRSquaredCalculator(); }
      IEnumerable<int> rows;

      switch (data) {
        case DataPartitionEnum.All:
          rows = problemData.AllIndices;
          break;
        case DataPartitionEnum.Test:
          rows = problemData.TestIndices;
          break;
        case DataPartitionEnum.Training:
          rows = problemData.TrainingIndices;
          break;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }

      return CalculateImpacts(model, problemData, estimatedValues, rows, calculator, replacementMethod, factorReplacementMethod, progressCallback);
    }

    /// <summary>
    /// Calculates the impacts of all input variables (the allowed inputs of <paramref name="problemData"/> plus
    /// the variables used by <paramref name="model"/>) on the given rows.
    /// </summary>
    /// <param name="estimatedValues">Estimates of <paramref name="model"/> for the dataset rows, indexed by row.</param>
    /// <param name="rows">The dataset rows on which the impacts are calculated.</param>
    /// <param name="calculator">The quality measure; <c>null</c> uses Pearson's R² (same default as the partition overload).</param>
    /// <param name="progressCallback">Called before each variable with the progress fraction and a message; returning <c>true</c> cancels.</param>
    /// <returns>(variable, impact) tuples by descending impact, or <c>null</c> if cancelled via <paramref name="progressCallback"/>.</returns>
    /// <exception cref="InvalidOperationException">Thrown when the calculator reports an error.</exception>
    /// <exception cref="NotSupportedException">Thrown for input variables that are neither double nor string.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     IRegressionModel model,
     IRegressionProblemData problemData,
     IEnumerable<double> estimatedValues,
     IEnumerable<int> rows,
     IOnlineCalculator calculator,
     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     Func<double, string, bool> progressCallback = null) {
      if (calculator == null) { calculator = new OnlinePearsonsRSquaredCalculator(); }
      // rows are enumerated once per variable (and once per factor level for "Best"), so materialize them once
      var rowList = rows.ToList();

      IEnumerable<double> targetValues;
      double originalValue;
      PrepareData(rowList, problemData, estimatedValues, out targetValues, out originalValue, calculator);

      var impacts = new Dictionary<string, double>();
      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();

      int curIdx = 0;
      int count = allowedInputVariables
        .Count(v => problemData.Dataset.VariableHasType<double>(v) || problemData.Dataset.VariableHasType<string>(v));

      // A single modifiable copy is shared by all variables; every evaluation restores the replaced column afterwards.
      ModifiableDataset modifiableDataset = null;
      foreach (var inputVariable in allowedInputVariables) {
        //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
        if (progressCallback != null) {
          curIdx++;
          if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
        }
        if (modifiableDataset == null) { modifiableDataset = ((Dataset)problemData.Dataset).ToModifiable(); }
        impacts[inputVariable] = CalculateImpactWithModifiableDataset(inputVariable, model, problemData.Dataset, modifiableDataset, rowList, targetValues, originalValue, calculator, replacementMethod, factorReplacementMethod);
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable of <paramref name="solution"/>.
    /// </summary>
    /// <param name="originalValue">The calculator value of the unmodified estimates on <paramref name="rows"/>.</param>
    /// <param name="data">Not used; the rows are given by <paramref name="rows"/>. Kept for source compatibility.</param>
    public static double CalculateImpact(string variableName,
      IRegressionSolution solution,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      IOnlineCalculator calculator,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, calculator, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of a single variable of <paramref name="model"/>. The model is evaluated on a copy
    /// of <paramref name="dataset"/>, so the given dataset is not modified.
    /// </summary>
    /// <param name="targetValues">Target values of <paramref name="rows"/>, in the same order.</param>
    /// <param name="originalValue">The calculator value of the unmodified estimates on <paramref name="rows"/>.</param>
    /// <returns><paramref name="originalValue"/> minus the calculator value with the variable replaced.</returns>
    /// <exception cref="NotSupportedException">Thrown when the variable is neither double nor string.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      IDataset dataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      IOnlineCalculator calculator,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      var modifiableDataset = ((Dataset)dataset).ToModifiable();
      return CalculateImpactWithModifiableDataset(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, calculator, replacementMethod, factorReplacementMethod);
    }

    // Dispatches on the variable type; modifiableDataset is a working copy of dataset whose columns are replaced and restored.
    private static double CalculateImpactWithModifiableDataset(string variableName,
      IRegressionModel model,
      IDataset dataset,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      IOnlineCalculator calculator,
      ReplacementMethodEnum replacementMethod,
      FactorReplacementMethodEnum factorReplacementMethod) {
      if (dataset.VariableHasType<double>(variableName)) {
        return CalculateImpactForDouble(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod, calculator);
      }
      if (dataset.VariableHasType<string>(variableName)) {
        return CalculateImpactForString(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod, calculator);
      }
      throw new NotSupportedException("Variable not supported");
    }

    // Extracts the target values of rows and computes the calculator value of the original estimates on them.
    private static void PrepareData(IEnumerable<int> rows,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      out IEnumerable<double> targetValues,
      out double originalValue,
      IOnlineCalculator calculator) {
      OnlineCalculatorError error;

      // Materialize both sources once: ElementAt on a lazy sequence is O(index), which made this quadratic,
      // and the target values are enumerated again for every variable.
      var targetVariableValueList = problemData.TargetVariableValues.ToList();
      var estimatedValueList = estimatedValues as IList<double> ?? estimatedValues.ToList();
      targetValues = rows.Select(r => targetVariableValueList[r]).ToList();
      var estimatedValuesPartition = rows.Select(r => estimatedValueList[r]).ToList();
      originalValue = calculator.CalculateValue(targetValues, estimatedValuesPartition, out error);

      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
    }

    private static double CalculateImpactForDouble(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod,
      IOnlineCalculator calculator) {
      OnlineCalculatorError error;
      var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, replacementMethod);
      var newValue = calculator.CalculateValue(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
      return originalValue - newValue;
    }

    private static double CalculateImpactForString(string variableName,
      IRegressionModel model,
      IDataset dataset,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      FactorReplacementMethodEnum factorReplacementMethod,
      IOnlineCalculator calculator) {

      OnlineCalculatorError error;
      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
        // try replacing with all possible values and find the best replacement value (smallest impact)
        var smallestImpact = double.PositiveInfinity;
        // hoisted: the column is restored to these values after every evaluation
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        foreach (var repl in dataset.GetStringValues(variableName, rows).Distinct()) {
          var newEstimates = EvaluateModelWithReplacedVariable(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, dataset.Rows).ToList());
          var newValue = calculator.CalculateValue(targetValues, newEstimates, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

          var curImpact = originalValue - newValue;
          if (curImpact < smallestImpact) smallestImpact = curImpact;
        }
        return smallestImpact;
      } else {
        // for replacement methods shuffle and mode
        var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, factorReplacementMethod);
        var newValue = calculator.CalculateValue(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

        return originalValue - newValue;
      }
    }

    // Evaluates the model with a double variable replaced according to the given method (statistics over rows only).
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      double replacementValue;
      List<double> replacementValues;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          replacementValues = ShuffleSelectedRows(originalValues, rows, dataset.Rows, double.NaN);
          break;
        case ReplacementMethodEnum.Noise:
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          IRandom rand = new FastRandom(RandomSeed);
          // prepare a complete column; rows outside the selection are never evaluated
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Evaluates the model with a string (factor) variable replaced by its mode or by shuffled values.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
      IRegressionModel model, string variable, ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          // ties are resolved in favour of the value that occurs first (OrderByDescending is stable)
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          replacementValues = ShuffleSelectedRows(originalValues, rows, dataset.Rows, string.Empty);
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Builds a full column of length rowCount where the selected rows hold a seeded permutation of their
    // original values and all other rows hold fillValue.
    private static List<T> ShuffleSelectedRows<T>(IList<T> originalValues, IEnumerable<int> rows, int rowCount, T fillValue) {
      IRandom rand = new FastRandom(RandomSeed);
      var replacementValues = Enumerable.Repeat(fillValue, rowCount).ToList();
      var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
      int i = 0;
      foreach (var r in rows) {
        replacementValues[r] = shuffledValues[i++];
      }
      return replacementValues;
    }

    // Temporarily replaces the column, evaluates the model on rows and always restores the original column.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IList originalValues, IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IList replacementValues) {
      dataset.ReplaceVariable(variable, replacementValues);
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.