Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs @ 16041

Last change on this file since 16041 was 16041, checked in by fholzing, 6 years ago

#2904: Removed ElementAt

File size: 18.4 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections;
26using System.Collections.Generic;
27using System.Linq;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of the input variables of a classification model: each variable used by the model
  /// is replaced (see <see cref="ReplacementMethodEnum"/> / <see cref="FactorReplacementMethodEnum"/>), the model
  /// is re-evaluated, and the impact is the original accuracy minus the accuracy obtained with replaced inputs.
  /// </summary>
  [StorableClass]
  [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")]
  public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    #region Parameters/Properties
    /// <summary>Replacement strategy for numeric (double) input variables.</summary>
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }
    /// <summary>Replacement strategy for factor (string) input variables.</summary>
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }
    /// <summary>Selects which rows of the problem data are used for the calculation.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string FactorReplacementParameterName = "Factor Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    // Fixed seed so that Shuffle and Noise replacements are reproducible between runs.
    private const int RandomSeed = 31415;

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    public FactorReplacementMethodEnum FactorReplacementMethod {
      get { return FactorReplacementParameter.Value.Value; }
      set { FactorReplacementParameter.Value.Value = value; }
    }
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }
    #endregion

    #region Ctor/Cloner
    [StorableConstructor]
    private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public ClassificationSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ClassificationSolutionVariableImpactsCalculator(this, cloner);
    }
    #endregion

    /// <summary>
    /// Calculates the variable impacts of <paramref name="solution"/> using the replacement methods and the
    /// data partition configured in this item's parameters.
    /// </summary>
    /// <param name="solution">The classification solution to analyze.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) {
      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of <paramref name="solution"/> on the given data partition.
    /// </summary>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IClassificationSolution solution,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of <paramref name="model"/> on the rows of the given data partition.
    /// </summary>
    /// <param name="estimatedValues">Estimated class values indexed by dataset row.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IClassificationModel model,
      IClassificationProblemData problemData,
      IEnumerable<double> estimatedValues,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
      IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData);
      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of every input variable on the given rows.
    /// </summary>
    /// <remarks>
    /// The considered variables are the allowed input variables of <paramref name="problemData"/> together with
    /// the variables used by <paramref name="model"/>. Variables that the model does not use get an impact of 0.
    /// If the model is an <see cref="IDiscriminantFunctionClassificationModel"/>, its parameters are recalculated
    /// for each replacement. That happens on a clone, so the caller's model is left unchanged.
    /// </remarks>
    /// <param name="model">The model whose variable impacts are calculated.</param>
    /// <param name="problemData">Provides the dataset, target variable and allowed input variables.</param>
    /// <param name="estimatedClassValues">Estimated class values indexed by dataset row (every row in <paramref name="rows"/> must be a valid index).</param>
    /// <param name="rows">Dataset rows on which the impacts are calculated.</param>
    /// <param name="replacementMethod">Replacement method for numeric variables.</param>
    /// <param name="factorReplacementMethod">Replacement method for factor variables.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    /// <exception cref="InvalidOperationException">The original accuracy cannot be calculated.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     IClassificationModel model,
     IClassificationProblemData problemData,
     IEnumerable<double> estimatedClassValues,
     IEnumerable<int> rows,
     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // The rows are enumerated once per model evaluation, so materialize them a single time.
      var partitionRows = rows.ToList();
      // Index the estimates directly instead of calling ElementAt per row (O(n^2) on non-list sources).
      var allEstimates = estimatedClassValues as IList<double> ?? estimatedClassValues.ToList();

      //Calculate original quality-values (via calculator, default is Accuracy)
      var targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, partitionRows).ToList();
      var estimatedValuesPartition = partitionRows.Select(r => allEstimates[r]).ToList();
      OnlineCalculatorError error;
      var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");

      // VariablesUsedForPrediction may be computed on every access, so query it only once.
      var usedVariables = new HashSet<string>(model.VariablesUsedForPrediction);
      var inputVariables = new HashSet<string>(problemData.AllowedInputVariables);
      inputVariables.UnionWith(usedVariables);
      // Keep the dataset's variable order for the evaluation.
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();
      var modifiableDataset = ((Dataset)problemData.Dataset.Clone()).ToModifiable();

      // Recalculating the parameters of a discriminant function model mutates it; work on a clone so the
      // caller's model does not end up with parameters fitted to replaced (corrupted) inputs.
      var evaluationModel = model is IDiscriminantFunctionClassificationModel
        ? (IClassificationModel)model.Clone()
        : model;

      var impacts = new Dictionary<string, double>();
      foreach (var inputVariable in allowedInputVariables) {
        impacts[inputVariable] = usedVariables.Contains(inputVariable)
          ? CalculateImpact(inputVariable, evaluationModel, modifiableDataset, partitionRows, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod)
          : 0.0;
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable: <paramref name="originalValue"/> minus the accuracy of
    /// <paramref name="model"/> after the variable's values were replaced.
    /// </summary>
    /// <param name="variableName">Name of the double or string variable to replace.</param>
    /// <param name="model">The model to evaluate. A discriminant function model gets its parameters recalculated (mutated).</param>
    /// <param name="modifiableDataset">Dataset in which the variable is temporarily replaced; the original column is restored afterwards.</param>
    /// <param name="rows">Dataset rows used for the evaluation.</param>
    /// <param name="targetValues">Target values for <paramref name="rows"/>, in the same order.</param>
    /// <param name="originalValue">Accuracy obtained without replacement.</param>
    /// <param name="replacementMethod">Replacement method used when the variable is numeric.</param>
    /// <param name="factorReplacementMethod">Replacement method used when the variable is a factor. <c>Best</c> reports the smallest impact over all observed values.</param>
    /// <returns>The impact of the variable.</returns>
    /// <exception cref="NotSupportedException">The variable is neither of type double nor string.</exception>
    /// <exception cref="ArgumentException">The replacement method is unknown.</exception>
    /// <exception cref="InvalidOperationException">The accuracy with replaced inputs cannot be calculated.</exception>
    public static double CalculateImpact(string variableName,
      IClassificationModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // Both sequences are enumerated once per model evaluation (several times for factor 'Best').
      var rowList = rows as IList<int> ?? rows.ToList();
      var targetList = targetValues as IList<double> ?? targetValues.ToList();

      if (modifiableDataset.VariableHasType<double>(variableName)) {
        // Snapshot of the column; used both to build the replacement and to restore the dataset.
        var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
        var replacementValues = CreateNumericReplacementValues(originalValues, rowList, modifiableDataset.Rows, replacementMethod);
        return EvaluateReplacementImpact(variableName, model, modifiableDataset, rowList, originalValues, replacementValues, targetList, originalValue);
      }

      if (modifiableDataset.VariableHasType<string>(variableName)) {
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        return CalculateFactorImpact(variableName, model, modifiableDataset, rowList, targetList, originalValue, originalValues, factorReplacementMethod);
      }

      throw new NotSupportedException("Variable not supported");
    }

    /// <summary>
    /// Builds a complete replacement column for a numeric variable. Only <paramref name="rows"/> are
    /// meaningful for Shuffle and Noise; all other rows are set to NaN.
    /// </summary>
    private static List<double> CreateNumericReplacementValues(IList<double> originalValues, IList<int> rows, int rowCount, ReplacementMethodEnum replacementMethod) {
      switch (replacementMethod) {
        case ReplacementMethodEnum.Median:
          return Enumerable.Repeat(rows.Select(r => originalValues[r]).Median(), rowCount).ToList();
        case ReplacementMethodEnum.Average:
          return Enumerable.Repeat(rows.Select(r => originalValues[r]).Average(), rowCount).ToList();
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          return CreateShuffledColumn(originalValues, rows, double.NaN, rowCount);
        case ReplacementMethodEnum.Noise:
          // normally distributed values with the mean and standard deviation of the selected rows
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          var random = new FastRandom(RandomSeed);
          var noisyValues = Enumerable.Repeat(double.NaN, rowCount).ToList();
          foreach (var r in rows) {
            noisyValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
          }
          return noisyValues;
        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
      }
    }

    /// <summary>
    /// Calculates the impact of a factor (string) variable according to <paramref name="factorReplacementMethod"/>.
    /// </summary>
    private static double CalculateFactorImpact(string variableName,
      IClassificationModel model,
      ModifiableDataset modifiableDataset,
      IList<int> rows,
      IList<double> targetValues,
      double originalValue,
      List<string> originalValues,
      FactorReplacementMethodEnum factorReplacementMethod) {
      switch (factorReplacementMethod) {
        case FactorReplacementMethodEnum.Best:
          // Try replacing with every observed value and keep the smallest impact.
          // The candidates are materialized first because each evaluation replaces this variable in the dataset.
          var candidates = rows.Select(r => originalValues[r]).Distinct().ToList();
          var smallestImpact = double.PositiveInfinity;
          foreach (var candidate in candidates) {
            var constantColumn = Enumerable.Repeat(candidate, modifiableDataset.Rows).ToList();
            var curImpact = EvaluateReplacementImpact(variableName, model, modifiableDataset, rows, originalValues, constantColumn, targetValues, originalValue);
            if (curImpact < smallestImpact) smallestImpact = curImpact;
          }
          return smallestImpact;
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          var modeColumn = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
          return EvaluateReplacementImpact(variableName, model, modifiableDataset, rows, originalValues, modeColumn, targetValues, originalValue);
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          var shuffledColumn = CreateShuffledColumn(originalValues, rows, string.Empty, modifiableDataset.Rows);
          return EvaluateReplacementImpact(variableName, model, modifiableDataset, rows, originalValues, shuffledColumn, targetValues, originalValue);
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
      }
    }

    /// <summary>
    /// Builds a full column of <paramref name="rowCount"/> entries in which the values of <paramref name="rows"/> are
    /// shuffled among themselves (fixed seed) and every other row holds <paramref name="fillValue"/>.
    /// </summary>
    private static List<T> CreateShuffledColumn<T>(IList<T> originalValues, IList<int> rows, T fillValue, int rowCount) {
      var random = new FastRandom(RandomSeed);
      var column = Enumerable.Repeat(fillValue, rowCount).ToList();
      var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
      for (int i = 0; i < rows.Count; i++) {
        column[rows[i]] = shuffledValues[i];
      }
      return column;
    }

    /// <summary>
    /// Evaluates the model with <paramref name="replacementValues"/> in place of the variable and returns
    /// <paramref name="originalValue"/> minus the resulting accuracy.
    /// </summary>
    private static double EvaluateReplacementImpact(string variableName,
      IClassificationModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IList originalValues,
      IList replacementValues,
      IEnumerable<double> targetValues,
      double originalValue) {
      var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
      OnlineCalculatorError error;
      var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
      return originalValue - newValue;
    }

    /// <summary>
    /// Replaces the values of the model variable with the replacement values, calculates the new estimated values
    /// and always (also on exceptions) restores the original values of the variable.
    /// </summary>
    /// <param name="originalValues">Complete original column; written back into the dataset afterwards.</param>
    /// <param name="model">Model to evaluate. A discriminant function model gets its parameters recalculated on the replaced data (this mutates the model).</param>
    /// <param name="variableName">Variable to replace.</param>
    /// <param name="modifiableDataset">Dataset that is temporarily modified.</param>
    /// <param name="rows">Rows for which estimates are produced.</param>
    /// <param name="replacementValues">Complete replacement column.</param>
    /// <returns>The materialized estimated class values for <paramref name="rows"/>.</returns>
    private static IEnumerable<double> GetReplacedEstimates(
      IList originalValues,
      IClassificationModel model,
      string variableName,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IList replacementValues) {
      modifiableDataset.ReplaceVariable(variableName, replacementValues);
      try {
        var discModel = model as IDiscriminantFunctionClassificationModel;
        if (discModel != null) {
          var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
          discModel.RecalculateModelParameters(problemData, rows);
        }

        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
      } finally {
        // Restore the dataset even if the model throws, so callers never observe a corrupted column.
        modifiableDataset.ReplaceVariable(variableName, originalValues);
      }
    }

    /// <summary>
    /// Calculates the accuracy of <paramref name="estimatedValues"/> with respect to <paramref name="targetValues"/>
    /// (via <see cref="OnlineAccuracyCalculator"/>).
    /// </summary>
    /// <param name="targetValues">The actual values</param>
    /// <param name="estimatedValues">The calculated/replaced values</param>
    /// <param name="errorState">The error state of the calculator; the result is only meaningful if it is <c>None</c>.</param>
    /// <returns>The calculator's value (accuracy).</returns>
    /// <exception cref="ArgumentException">The two sequences have different lengths (and no calculator error occurred).</exception>
    public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
      //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
      //as the code below does. But this way we can easily swap the calculator later on, so the user
      //could choose a Calculator during runtime in future versions.
      IOnlineCalculator calculator = new OnlineAccuracyCalculator();
      using (IEnumerator<double> targetEnumerator = targetValues.GetEnumerator())
      using (IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator()) {
        bool hasTarget = false;
        bool hasEstimate = false;
        // always move forward both enumerators (do not use short-circuit evaluation!)
        while ((hasTarget = targetEnumerator.MoveNext()) & (hasEstimate = estimatedEnumerator.MoveNext())) {
          calculator.Add(targetEnumerator.Current, estimatedEnumerator.Current);
          if (calculator.ErrorState != OnlineCalculatorError.None) break;
        }

        // When the loop ended normally, both enumerators must be exhausted at the same time. The results of the
        // last MoveNext calls are used; calling MoveNext again would skip a single surplus element.
        if (calculator.ErrorState == OnlineCalculatorError.None && (hasTarget || hasEstimate)) {
          throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
        }
      }

      errorState = calculator.ErrorState;
      return calculator.Value;
    }

    /// <summary>
    /// Returns a collection of the row-indices for a given DataPartition (training, test or all)
    /// </summary>
    /// <param name="dataPartition">The partition to select.</param>
    /// <param name="problemData">The problem data providing the indices.</param>
    /// <returns>The row indices of the selected partition.</returns>
    /// <exception cref="NotSupportedException">The partition value is unknown.</exception>
    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) {
      switch (dataPartition) {
        case DataPartitionEnum.All:
          return problemData.AllIndices;
        case DataPartitionEnum.Test:
          return problemData.TestIndices;
        case DataPartitionEnum.Training:
          return problemData.TrainingIndices;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.