source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 16018

Last change on this file since 16018 was 16018, checked in by fholzing, 20 months ago

#2904: Unified order of IFixedValueParameter. Use FactorReplacementMethod as default for .Calculate method

File size: 18.6 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections;
26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Random;
34
35namespace HeuristicLab.Problems.DataAnalysis {
36  [StorableClass]
37  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
38  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
39    private static IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator();
40
41    public enum ReplacementMethodEnum {
42      Median,
43      Average,
44      Shuffle,
45      Noise
46    }
47    public enum FactorReplacementMethodEnum {
48      Best,
49      Mode,
50      Shuffle
51    }
52    public enum DataPartitionEnum {
53      Training,
54      Test,
55      All
56    }
57
58    private const string ReplacementParameterName = "Replacement Method";
59    private const string FactorReplacementParameterName = "Factor Replacement Method";
60    private const string DataPartitionParameterName = "DataPartition";
61
62    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter
63    {
64      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
65    }
66    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter
67    {
68      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
69    }
70    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter
71    {
72      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
73    }
74
75    public ReplacementMethodEnum ReplacementMethod
76    {
77      get { return ReplacementParameter.Value.Value; }
78      set { ReplacementParameter.Value.Value = value; }
79    }
80    public FactorReplacementMethodEnum FactorReplacementMethod
81    {
82      get { return FactorReplacementParameter.Value.Value; }
83      set { FactorReplacementParameter.Value.Value = value; }
84    }
85    public DataPartitionEnum DataPartition
86    {
87      get { return DataPartitionParameter.Value.Value; }
88      set { DataPartitionParameter.Value.Value = value; }
89    }
90
91
92    [StorableConstructor]
93    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
94    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
95      : base(original, cloner) { }
96    public override IDeepCloneable Clone(Cloner cloner) {
97      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
98    }
99
100    public RegressionSolutionVariableImpactsCalculator()
101      : base() {
102      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
103      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
104      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
105    }
106
107    //mkommend: annoying name clash with static method, open to better naming suggestions
108    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
109      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
110    }
111
112    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
113      IRegressionSolution solution,
114      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
115      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
116      DataPartitionEnum data = DataPartitionEnum.Training,
117      Func<double, string, bool> progressCallback = null) {
118      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, replacementMethod, factorReplacementMethod, data, progressCallback);
119    }
120
121    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
122      IRegressionModel model,
123      IRegressionProblemData problemData,
124      IEnumerable<double> estimatedValues,
125      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
126      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
127      DataPartitionEnum data = DataPartitionEnum.Training,
128      Func<double, string, bool> progressCallback = null) {
129      IEnumerable<int> rows;
130
131      switch (data) {
132        case DataPartitionEnum.All:
133          rows = problemData.AllIndices;
134          break;
135        case DataPartitionEnum.Test:
136          rows = problemData.TestIndices;
137          break;
138        case DataPartitionEnum.Training:
139          rows = problemData.TrainingIndices;
140          break;
141        default:
142          throw new NotSupportedException("DataPartition not supported");
143      }
144
145      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod, progressCallback);
146    }
147
148    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
149     IRegressionModel model,
150     IRegressionProblemData problemData,
151     IEnumerable<double> estimatedValues,
152     IEnumerable<int> rows,
153     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
154     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
155     Func<double, string, bool> progressCallback = null) {
156
157      IEnumerable<double> targetValues;
158      double originalValue = -1;
159
160      PrepareData(rows, problemData, estimatedValues, out targetValues, out originalValue);
161
162      var impacts = new Dictionary<string, double>();
163      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
164      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
165
166      int curIdx = 0;
167      int count = allowedInputVariables.Count();
168
169      foreach (var inputVariable in allowedInputVariables) {
170        //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
171        if (progressCallback != null) {
172          curIdx++;
173          if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
174        }
175        impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData.Dataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
176      }
177
178      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
179    }
180
181    public static double CalculateImpact(string variableName,
182      IRegressionSolution solution,
183      IEnumerable<int> rows,
184      IEnumerable<double> targetValues,
185      double originalValue,
186      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
187      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
188      DataPartitionEnum data = DataPartitionEnum.Training) {
189      return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
190    }
191
192    public static double CalculateImpact(string variableName,
193      IRegressionModel model,
194      IDataset dataset,
195      IEnumerable<int> rows,
196      IEnumerable<double> targetValues,
197      double originalValue,
198      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
199      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
200
201      double impact = 0;
202      var modifiableDataset = ((Dataset)(dataset).Clone()).ToModifiable();
203
204      // calculate impacts for double variables
205      if (dataset.VariableHasType<double>(variableName)) {
206        impact = CalculateImpactForDouble(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
207      } else if (dataset.VariableHasType<string>(variableName)) {
208        impact = CalculateImpactForString(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
209      } else {
210        throw new NotSupportedException("Variable not supported");
211      }
212      return impact;
213    }
214
215    private static void PrepareData(IEnumerable<int> rows,
216      IRegressionProblemData problemData,
217      IEnumerable<double> estimatedValues,
218      out IEnumerable<double> targetValues,
219      out double originalValue) {
220      OnlineCalculatorError error;
221
222      var targetVariableValueList = problemData.TargetVariableValues.ToList();
223      targetValues = rows.Select(v => targetVariableValueList.ElementAt(v));
224      var estimatedValuesPartition = rows.Select(v => estimatedValues.ElementAt(v));
225      originalValue = CalculateValue(targetValues, estimatedValuesPartition, out error);
226
227      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
228    }
229
230    private static double CalculateImpactForDouble(string variableName,
231      IRegressionModel model,
232      ModifiableDataset modifiableDataset,
233      IEnumerable<int> rows,
234      IEnumerable<double> targetValues,
235      double originalValue,
236      ReplacementMethodEnum replacementMethod) {
237      OnlineCalculatorError error;
238      var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, replacementMethod);
239      var newValue = CalculateValue(targetValues, newEstimates, out error);
240      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
241      return originalValue - newValue;
242    }
243
244    private static double CalculateImpactForString(string variableName,
245      IRegressionModel model,
246      IDataset problemData,
247      ModifiableDataset modifiableDataset,
248      IEnumerable<int> rows,
249      IEnumerable<double> targetValues,
250      double originalValue,
251      FactorReplacementMethodEnum factorReplacementMethod) {
252
253      OnlineCalculatorError error;
254      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
255        // try replacing with all possible values and find the best replacement value
256        var smallestImpact = double.PositiveInfinity;
257        foreach (var repl in problemData.GetStringValues(variableName, rows).Distinct()) {
258          var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
259          var newEstimates = EvaluateModelWithReplacedVariable(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, problemData.Rows).ToList());
260          var newValue = CalculateValue(targetValues, newEstimates, out error);
261          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
262
263          var curImpact = originalValue - newValue;
264          if (curImpact < smallestImpact) smallestImpact = curImpact;
265        }
266        return smallestImpact;
267      } else {
268        // for replacement methods shuffle and mode
269        // calculate impacts for factor variables
270        var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, factorReplacementMethod);
271        var newValue = CalculateValue(targetValues, newEstimates, out error);
272        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
273
274        return originalValue - newValue;
275      }
276    }
277
278    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
279      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
280      double replacementValue;
281      List<double> replacementValues;
282      IRandom rand;
283
284      switch (replacement) {
285        case ReplacementMethodEnum.Median:
286          replacementValue = rows.Select(r => originalValues[r]).Median();
287          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
288          break;
289        case ReplacementMethodEnum.Average:
290          replacementValue = rows.Select(r => originalValues[r]).Average();
291          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
292          break;
293        case ReplacementMethodEnum.Shuffle:
294          // new var has same empirical distribution but the relation to y is broken
295          rand = new FastRandom(31415);
296          // prepare a complete column for the dataset
297          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
298          // shuffle only the selected rows
299          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
300          int i = 0;
301          // update column values
302          foreach (var r in rows) {
303            replacementValues[r] = shuffledValues[i++];
304          }
305          break;
306        case ReplacementMethodEnum.Noise:
307          var avg = rows.Select(r => originalValues[r]).Average();
308          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
309          rand = new FastRandom(31415);
310          // prepare a complete column for the dataset
311          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
312          // update column values
313          foreach (var r in rows) {
314            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
315          }
316          break;
317
318        default:
319          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
320      }
321
322      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
323    }
324
325    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
326      IRegressionModel model, string variable, ModifiableDataset dataset,
327      IEnumerable<int> rows,
328      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Best) {
329      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
330      List<string> replacementValues;
331      IRandom rand;
332
333      switch (replacement) {
334        case FactorReplacementMethodEnum.Mode:
335          var mostCommonValue = rows.Select(r => originalValues[r])
336            .GroupBy(v => v)
337            .OrderByDescending(g => g.Count())
338            .First().Key;
339          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
340          break;
341        case FactorReplacementMethodEnum.Shuffle:
342          // new var has same empirical distribution but the relation to y is broken
343          rand = new FastRandom(31415);
344          // prepare a complete column for the dataset
345          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
346          // shuffle only the selected rows
347          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
348          int i = 0;
349          // update column values
350          foreach (var r in rows) {
351            replacementValues[r] = shuffledValues[i++];
352          }
353          break;
354        default:
355          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
356      }
357
358      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
359    }
360
361    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IList originalValues, IRegressionModel model, string variable,
362      ModifiableDataset dataset, IEnumerable<int> rows, IList replacementValues) {
363      dataset.ReplaceVariable(variable, replacementValues);
364      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
365      var estimates = model.GetEstimatedValues(dataset, rows).ToList();
366      dataset.ReplaceVariable(variable, originalValues);
367
368      return estimates;
369    }
370
371    private static double CalculateValue(IEnumerable<double> targets, IEnumerable<double> estimates, out OnlineCalculatorError error) {
372      calculator.Reset();
373
374      var targetsEnumerator = targets.GetEnumerator();
375      var estimatesEnumerator = estimates.GetEnumerator();
376
377      bool targetsHasNextValue = targetsEnumerator.MoveNext();
378      bool estimatesHasNextValue = estimatesEnumerator.MoveNext();
379
380      while (targetsHasNextValue && estimatesHasNextValue) {
381        calculator.Add(targetsEnumerator.Current, estimatesEnumerator.Current);
382        targetsHasNextValue = targetsEnumerator.MoveNext();
383        estimatesHasNextValue = estimatesEnumerator.MoveNext();
384      }
385
386      //Check if there is an equal quantity of targets and estimates
387      if (targetsHasNextValue != estimatesHasNextValue) {
388        throw new ArgumentException(string.Format("Targets and Estimates must be of equal length ({0},{1})", targets.Count(), estimates.Count()));
389
390      }
391
392      error = calculator.ErrorState;
393      return calculator.Value;
394    }
395  }
396}
Note: See TracBrowser for help on using the repository browser.