source: trunk/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs @ 15871

Last change on this file since 15871 was 15871, checked in by mkommend, 17 months ago

#2910: Added recalculation of thresholds for IDiscriminantClassificationModels during impact calculation.

File size: 14.7 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of each input variable of a classification solution.
  /// The impact of a variable is the accuracy of the original solution minus the accuracy
  /// obtained after the values of that variable have been replaced.
  /// </summary>
  [StorableClass]
  [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")]
  public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>Replacement strategies for double-valued input variables.</summary>
    public enum ReplacementMethodEnum {
      /// <summary>Every row receives the median of the variable over the selected rows.</summary>
      Median,
      /// <summary>Every row receives the average of the variable over the selected rows.</summary>
      Average,
      /// <summary>The values of the selected rows are shuffled among the selected rows.</summary>
      Shuffle,
      /// <summary>The selected rows receive normally distributed random values using the average and standard deviation of the selected rows.</summary>
      Noise
    }

    /// <summary>Replacement strategies for string-valued (factor) input variables.</summary>
    public enum FactorReplacementMethodEnum {
      /// <summary>Every distinct value is tried as a constant replacement and the smallest resulting impact is reported.</summary>
      Best,
      /// <summary>Every row receives the most common value of the selected rows.</summary>
      Mode,
      /// <summary>The values of the selected rows are shuffled among the selected rows.</summary>
      Shuffle
    }

    /// <summary>Data partition on which the impacts are calculated.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    private const string AccuracyErrorMessage = "Error during accuracy calculation.";
    private const string ReplacedAccuracyErrorMessage = "Error during accuracy calculation with replaced inputs.";

    // fixed seed, so that the Shuffle and Noise replacements yield reproducible results
    private const int RandomSeed = 31415;

    /// <summary>Parameter holding the replacement method for double variables.</summary>
    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    /// <summary>Parameter holding the data partition on which the impacts are calculated.</summary>
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Gets or sets the value of <see cref="ReplacementParameter"/>.</summary>
    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Gets or sets the value of <see cref="DataPartitionParameter"/>.</summary>
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new ClassificationSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator that uses the median replacement on the training partition by default.
    /// </summary>
    public ClassificationSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    /// <summary>
    /// Calculates the variable impacts for <paramref name="solution"/> using the configured
    /// <see cref="DataPartition"/> and <see cref="ReplacementMethod"/>. The factor replacement method
    /// is not configurable here; the default of <see cref="CalculateImpacts"/> is used.
    /// </summary>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) {
      return CalculateImpacts(solution, DataPartition, ReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of every allowed input variable and every variable used by the model.
    /// </summary>
    /// <param name="solution">The classification solution to analyze.</param>
    /// <param name="data">The partition whose rows are used for the calculation.</param>
    /// <param name="replacementMethod">The replacement method for double variables.</param>
    /// <param name="factorReplacementMethod">The replacement method for string variables.</param>
    /// <returns>Tuples of variable name and impact (original accuracy minus accuracy with the replaced variable), ordered by descending impact.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="solution"/> is null.</exception>
    /// <exception cref="ArgumentException">Thrown if <paramref name="data"/> or one of the replacement methods cannot be handled.</exception>
    /// <exception cref="InvalidOperationException">Thrown if an accuracy calculation reports an error.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IClassificationSolution solution,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      if (solution == null) throw new ArgumentNullException("solution");

      var problemData = solution.ProblemData;
      var dataset = problemData.Dataset;
      var model = (IClassificationModel)solution.Model.Clone(); //mkommend: clone of model is necessary, because the thresholds for IDiscriminantClassificationModels are updated

      IEnumerable<int> partitionRows;
      IEnumerable<double> originalEstimates;
      switch (data) {
        case DataPartitionEnum.All:
          partitionRows = problemData.AllIndices;
          originalEstimates = solution.EstimatedClassValues;
          break;
        case DataPartitionEnum.Training:
          partitionRows = problemData.TrainingIndices;
          originalEstimates = solution.EstimatedTrainingClassValues;
          break;
        case DataPartitionEnum.Test:
          partitionRows = problemData.TestIndices;
          originalEstimates = solution.EstimatedTestClassValues;
          break;
        default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
      }

      // the rows are enumerated once per variable (or more), therefore they are materialized up front
      var rows = partitionRows.ToList();
      var targetValues = dataset.GetDoubleValues(problemData.TargetVariable, rows).ToList();
      var originalAccuracy = CalculateAccuracy(targetValues, originalEstimates, AccuracyErrorMessage);

      var impacts = new Dictionary<string, double>();
      var modifiableDataset = ((Dataset)dataset).ToModifiable();

      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
      var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();

      // calculate impacts for double variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<double>)) {
        var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows, replacementMethod);
        var newAccuracy = CalculateAccuracy(targetValues, newEstimates, ReplacedAccuracyErrorMessage);

        impacts[inputVariable] = originalAccuracy - newAccuracy;
      }

      // calculate impacts for string variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<string>)) {
        if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
          // try replacing with all possible values and find the best replacement value
          var smallestImpact = double.PositiveInfinity;
          // materialized, so that the candidate values are not read lazily while a dataset is being modified
          var candidateValues = dataset.GetStringValues(inputVariable, rows).Distinct().ToList();
          foreach (var repl in candidateValues) {
            var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows,
              Enumerable.Repeat(repl, dataset.Rows));
            var newAccuracy = CalculateAccuracy(targetValues, newEstimates, ReplacedAccuracyErrorMessage);

            var impact = originalAccuracy - newAccuracy;
            if (impact < smallestImpact) smallestImpact = impact;
          }
          impacts[inputVariable] = smallestImpact;
        } else {
          // for replacement methods shuffle and mode
          var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows,
            factorReplacementMethod);
          var newAccuracy = CalculateAccuracy(targetValues, newEstimates, ReplacedAccuracyErrorMessage);

          impacts[inputVariable] = originalAccuracy - newAccuracy;
        }
      }

      // materialized, so that callers do not re-sort the impacts on every enumeration
      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value)).ToList();
    }

    /// <summary>
    /// Calculates the accuracy of <paramref name="estimates"/> with respect to <paramref name="targetValues"/>
    /// and throws an <see cref="InvalidOperationException"/> with <paramref name="errorMessage"/> if the calculator reports an error.
    /// </summary>
    private static double CalculateAccuracy(IEnumerable<double> targetValues, IEnumerable<double> estimates, string errorMessage) {
      OnlineCalculatorError error;
      var accuracy = OnlineAccuracyCalculator.Calculate(targetValues, estimates, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException(errorMessage);
      return accuracy;
    }

    /// <summary>
    /// Builds the replacement column for a double variable according to <paramref name="replacement"/>
    /// and returns the class values estimated by the model with that column in place.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown if <paramref name="replacement"/> cannot be handled.</exception>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      // lazy view on the values of the selected rows
      IEnumerable<double> selectedValues = rows.Select(r => originalValues[r]);
      double replacementValue;
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = selectedValues.Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = selectedValues.Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset, rows that are not selected remain NaN
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = selectedValues.Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = selectedValues.Average();
          var stdDev = selectedValues.StandardDeviation();
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset, rows that are not selected remain NaN
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Builds the replacement column for a string variable according to <paramref name="replacement"/>
    /// and returns the class values estimated by the model with that column in place.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown if <paramref name="replacement"/> cannot be handled (e.g. Best, which is handled by the caller).</exception>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
      IClassificationModel model, string variable, ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;
      IRandom rand;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset, rows that are not selected remain empty
          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Temporarily replaces the double variable with <paramref name="replacementValues"/>, estimates the class values
    /// for <paramref name="rows"/> and restores the original column afterwards (also if the evaluation throws).
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        return EstimateClassValues(model, dataset, rows);
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    /// <summary>
    /// Temporarily replaces the string variable with <paramref name="replacementValues"/>, estimates the class values
    /// for <paramref name="rows"/> and restores the original column afterwards (also if the evaluation throws).
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        return EstimateClassValues(model, dataset, rows);
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    /// <summary>
    /// Estimates the class values of <paramref name="rows"/> on the current state of <paramref name="dataset"/>.
    /// For discriminant function models the model parameters are recalculated on the dataset first.
    /// </summary>
    private static List<double> EstimateClassValues(IClassificationModel model, ModifiableDataset dataset, IEnumerable<int> rows) {
      var discModel = model as IDiscriminantFunctionClassificationModel;
      if (discModel != null) {
        var problemData = new ClassificationProblemData(dataset, dataset.VariableNames, model.TargetVariable);
        discModel.RecalculateModelParameters(problemData, rows);
      }

      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
      return model.GetEstimatedClassValues(dataset, rows).ToList();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.