source: branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 16036

Last change on this file since 16036 was 16036, checked in by fholzing, 3 years ago

#2904: Streamlined the VariableImpactsCalculator code for both Regression and Classification. Adopted the regression code for classification with some minor adaptations.

File size: 17.9 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections;
26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of each input variable of a regression model.
  /// The values of a variable are replaced (median, average, shuffle, noise, or for factor variables best/mode/shuffle),
  /// the model is re-evaluated, and the impact is the decrease of Pearson's R² between target and estimated values.
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    #region Parameters/Properties
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string FactorReplacementParameterName = "Factor Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    public FactorReplacementMethodEnum FactorReplacementMethod {
      get { return FactorReplacementParameter.Value.Value; }
      set { FactorReplacementParameter.Value.Value = value; }
    }
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }
    #endregion

    #region Ctor/Cloner
    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }
    #endregion

    /// <summary>
    /// Calculates the variable impacts of <paramref name="solution"/> using the replacement methods
    /// and data partition configured in this item's parameters.
    /// </summary>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of a regression solution on the given data partition.
    /// </summary>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, replacementMethod, factorReplacementMethod, dataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of <paramref name="model"/> on the rows of the given data partition.
    /// </summary>
    /// <param name="estimatedValues">Estimated values of the model, indexed by dataset row.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
      IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData);
      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of every input variable (allowed inputs plus variables used by the model)
    /// on the given rows. The impact is the original R² minus the R² obtained with replaced variable values.
    /// </summary>
    /// <param name="model">The model whose estimates are re-evaluated.</param>
    /// <param name="problemData">Provides the dataset, the target values and the allowed input variables.</param>
    /// <param name="estimatedValues">Estimated values of the model, indexed by dataset row.</param>
    /// <param name="rows">Dataset row indices on which the impacts are calculated.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    /// <exception cref="InvalidOperationException">If the original R² cannot be calculated.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     IRegressionModel model,
     IRegressionProblemData problemData,
     IEnumerable<double> estimatedValues,
     IEnumerable<int> rows,
     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // Materialize all inputs once. They are enumerated repeatedly below, and calling ElementAt on a lazy
      // sequence (e.g. solution.EstimatedValues) would re-evaluate that sequence for every single row.
      var rowIndices = AsList(rows);
      var allTargetValues = AsList(problemData.TargetVariableValues);
      var allEstimatedValues = AsList(estimatedValues);
      var targetValuesPartition = rowIndices.Select(r => allTargetValues[r]).ToList();
      var estimatedValuesPartition = rowIndices.Select(r => allEstimatedValues[r]).ToList();

      //Calculate original quality-values (via calculator, default is R²)
      OnlineCalculatorError error;
      var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");

      var impacts = new Dictionary<string, double>();
      var inputVariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      // keep the dataset's variable order
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();
      // work on a copy so that the problem data's dataset is never modified
      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();

      foreach (var inputVariable in allowedInputVariables) {
        impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rowIndices, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable: <paramref name="originalValue"/> minus the R² of the model
    /// estimates obtained after replacing the variable's values. The variable is restored in
    /// <paramref name="modifiableDataset"/> after each evaluation.
    /// </summary>
    /// <param name="variableName">A double or string variable of <paramref name="modifiableDataset"/>.</param>
    /// <param name="rows">Dataset row indices on which the model is evaluated.</param>
    /// <param name="targetValues">Target values of <paramref name="rows"/>, in the same order.</param>
    /// <param name="originalValue">R² of the unmodified model estimates on <paramref name="rows"/>.</param>
    /// <exception cref="NotSupportedException">If the variable is neither a double nor a string variable.</exception>
    /// <exception cref="ArgumentException">If the replacement method is unknown.</exception>
    /// <exception cref="InvalidOperationException">If the R² of the replaced estimates cannot be calculated.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // rows and targets are enumerated several times (once per model evaluation)
      var rowIndices = AsList(rows);
      var targets = AsList(targetValues);

      if (modifiableDataset.VariableHasType<double>(variableName)) {
        return CalculateNumericalImpact(variableName, model, modifiableDataset, rowIndices, targets, originalValue, replacementMethod);
      }
      if (modifiableDataset.VariableHasType<string>(variableName)) {
        return CalculateFactorImpact(variableName, model, modifiableDataset, rowIndices, targets, originalValue, factorReplacementMethod);
      }
      throw new NotSupportedException("Variable not supported");
    }

    /// <summary>
    /// Impact of a double variable whose values on <paramref name="rows"/> are replaced according to <paramref name="replacementMethod"/>.
    /// </summary>
    private static double CalculateNumericalImpact(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IList<int> rows,
      IList<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod) {
      var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
      var partitionValues = rows.Select(r => originalValues[r]);
      List<double> replacementValues;
      IRandom random;

      switch (replacementMethod) {
        case ReplacementMethodEnum.Median:
          replacementValues = Enumerable.Repeat(partitionValues.Median(), modifiableDataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValues = Enumerable.Repeat(partitionValues.Average(), modifiableDataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          random = new FastRandom(31415);
          // prepare a complete column; only the selected rows are passed to the model
          replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = partitionValues.Shuffle(random).ToList();
          for (int i = 0; i < rows.Count; i++) {
            replacementValues[rows[i]] = shuffledValues[i];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = partitionValues.Average();
          var stdDev = partitionValues.StandardDeviation();
          random = new FastRandom(31415);
          // prepare a complete column; only the selected rows are passed to the model
          replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
          }
          break;
        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
      }

      return CalculateReplacementImpact(originalValues, replacementValues, model, variableName, modifiableDataset, rows, targetValues, originalValue);
    }

    /// <summary>
    /// Impact of a string (factor) variable whose values are replaced according to <paramref name="factorReplacementMethod"/>.
    /// For <see cref="FactorReplacementMethodEnum.Best"/> the smallest impact over all values occurring in <paramref name="rows"/>
    /// is returned (double.PositiveInfinity if <paramref name="rows"/> is empty).
    /// </summary>
    private static double CalculateFactorImpact(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IList<int> rows,
      IList<double> targetValues,
      double originalValue,
      FactorReplacementMethodEnum factorReplacementMethod) {
      var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
      var partitionValues = rows.Select(r => originalValues[r]);
      List<string> replacementValues;

      switch (factorReplacementMethod) {
        case FactorReplacementMethodEnum.Best:
          // try replacing with all possible values and keep the best replacement (= smallest impact).
          // The candidates come from the local copy of the column, so the replacements performed
          // inside the loop cannot influence them.
          var smallestImpact = double.PositiveInfinity;
          foreach (var repl in partitionValues.Distinct().ToList()) {
            replacementValues = Enumerable.Repeat(repl, modifiableDataset.Rows).ToList();
            var curImpact = CalculateReplacementImpact(originalValues, replacementValues, model, variableName, modifiableDataset, rows, targetValues, originalValue);
            if (curImpact < smallestImpact) smallestImpact = curImpact;
          }
          return smallestImpact;
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = partitionValues
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          var random = new FastRandom(31415);
          // prepare a complete column; only the selected rows are passed to the model
          replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = partitionValues.Shuffle(random).ToList();
          for (int i = 0; i < rows.Count; i++) {
            replacementValues[rows[i]] = shuffledValues[i];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
      }

      return CalculateReplacementImpact(originalValues, replacementValues, model, variableName, modifiableDataset, rows, targetValues, originalValue);
    }

    /// <summary>
    /// Evaluates the model with <paramref name="variableName"/> replaced by <paramref name="replacementValues"/>
    /// and returns <paramref name="originalValue"/> minus the resulting R².
    /// </summary>
    /// <exception cref="InvalidOperationException">If the R² of the replaced estimates cannot be calculated.</exception>
    private static double CalculateReplacementImpact(
      IList originalValues,
      IList replacementValues,
      IRegressionModel model,
      string variableName,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue) {
      var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
      OnlineCalculatorError error;
      var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
      return originalValue - newValue;
    }

    /// <summary>
    /// Replaces the values of the model variable with the replacement values, calculates the new estimated values
    /// and changes the values of the model variable back to the original ones (also if the model throws).
    /// </summary>
    /// <param name="originalValues">Column that is written back into the dataset afterwards.</param>
    /// <param name="model">The model that is evaluated.</param>
    /// <param name="variableName">The variable that is temporarily replaced.</param>
    /// <param name="modifiableDataset">The dataset that is temporarily modified.</param>
    /// <param name="rows">Rows for which estimates are calculated.</param>
    /// <param name="replacementValues">Complete replacement column for the variable.</param>
    /// <returns>The materialized estimated values for <paramref name="rows"/>.</returns>
    private static IEnumerable<double> GetReplacedEstimates(
      IList originalValues,
      IRegressionModel model,
      string variableName,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IList replacementValues) {
      modifiableDataset.ReplaceVariable(variableName, replacementValues);
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(modifiableDataset, rows).ToList();
      } finally {
        // restore the dataset even if the model evaluation fails
        modifiableDataset.ReplaceVariable(variableName, originalValues);
      }
    }

    /// <summary>
    /// Calculates and returns the quality of the estimated values, i.e. Pearson's R² between
    /// target and estimated values. Impacts are derived as the difference of two such values.
    /// </summary>
    /// <param name="targetValues">The actual values</param>
    /// <param name="estimatedValues">The calculated/replaced values</param>
    /// <param name="errorState">Error state of the calculator after processing the values.</param>
    /// <returns>The R² value of the calculator.</returns>
    /// <exception cref="ArgumentException">If both enumerations have a different number of elements.</exception>
    public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
      //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
      //as the code below does. But this way we can easily swap the calculator later on, so the user 
      //could choose a Calculator during runtime in future versions.
      IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator();
      using (IEnumerator<double> targetEnumerator = targetValues.GetEnumerator())
      using (IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator()) {
        bool targetHasNext, estimatedHasNext;
        // always move forward both enumerators (do not use short-circuit evaluation!) and remember the results,
        // so that a length difference of a single element is detected as well
        while ((targetHasNext = targetEnumerator.MoveNext()) & (estimatedHasNext = estimatedEnumerator.MoveNext())) {
          calculator.Add(targetEnumerator.Current, estimatedEnumerator.Current);
          if (calculator.ErrorState != OnlineCalculatorError.None) break;
        }

        // after a regular end of the loop exactly one exhausted enumerator means different lengths
        if (calculator.ErrorState == OnlineCalculatorError.None && targetHasNext != estimatedHasNext) {
          throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
        }
      }

      errorState = calculator.ErrorState;
      return calculator.Value;
    }

    /// <summary>
    /// Returns a collection of the row-indices for a given DataPartition (training, test or all)
    /// </summary>
    /// <param name="dataPartition">The requested partition.</param>
    /// <param name="problemData">Provides the partition indices.</param>
    /// <returns>The row indices of the partition.</returns>
    /// <exception cref="NotSupportedException">If the partition is unknown.</exception>
    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IRegressionProblemData problemData) {
      switch (dataPartition) {
        case DataPartitionEnum.All:
          return problemData.AllIndices;
        case DataPartitionEnum.Test:
          return problemData.TestIndices;
        case DataPartitionEnum.Training:
          return problemData.TrainingIndices;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }
    }

    /// <summary>
    /// Returns <paramref name="values"/> itself if it already supports indexed access, otherwise a materialized copy.
    /// </summary>
    private static IList<T> AsList<T>(IEnumerable<T> values) {
      return values as IList<T> ?? values.ToList();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.