Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs @ 15667

Last change on this file since 15667 was 15667, checked in by fholzing, 6 years ago

#2884: Added sorting mechanism (see #2871)

File size: 14.1 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of each input variable on a classification solution.
  /// The impact of a variable is the accuracy of the unmodified solution minus the
  /// accuracy obtained after the values of that variable have been replaced
  /// (see <see cref="ReplacementMethodEnum"/> and <see cref="FactorReplacementMethodEnum"/>).
  /// </summary>
  [StorableClass]
  [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")]
  public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>Replacement strategies for double-valued input variables.</summary>
    public enum ReplacementMethodEnum {
      /// <summary>Every row receives the median of the variable over the selected rows.</summary>
      Median,
      /// <summary>Every row receives the average of the variable over the selected rows.</summary>
      Average,
      /// <summary>The values of the selected rows are permuted among the selected rows.</summary>
      Shuffle,
      /// <summary>The selected rows receive normally distributed values with the average and standard deviation of the selected rows.</summary>
      Noise
    }

    /// <summary>Replacement strategies for string-valued (factor) input variables.</summary>
    public enum FactorReplacementMethodEnum {
      /// <summary>Every distinct value of the selected rows is tried as a constant replacement; the smallest resulting impact is reported.</summary>
      Best,
      /// <summary>Every row receives the most frequent value of the selected rows.</summary>
      Mode,
      /// <summary>The values of the selected rows are permuted among the selected rows.</summary>
      Shuffle
    }

    /// <summary>Data partition of the problem data on which the impacts are calculated.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    /// <summary>
    /// Sorting criteria for variable impacts.
    /// NOTE(review): not referenced inside this class; presumably consumed by callers that display the impacts - confirm.
    /// </summary>
    public enum SortingCriteria {
      ImpactValue,
      Occurrence,
      VariableName
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    // Fixed seed for the Shuffle and Noise replacements so that repeated calculations yield identical impacts.
    private const int RandomSeed = 31415;

    private const string AccuracyErrorMessage = "Error during accuracy calculation.";
    private const string ReplacedAccuracyErrorMessage = "Error during accuracy calculation with replaced inputs.";

    /// <summary>Parameter holding the replacement method used for double variables.</summary>
    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }

    /// <summary>Parameter holding the data partition on which the impacts are calculated.</summary>
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Gets or sets the value stored in <see cref="ReplacementParameter"/>.</summary>
    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }

    /// <summary>Gets or sets the value stored in <see cref="DataPartitionParameter"/>.</summary>
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }

    [StorableConstructor]
    private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }

    private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ClassificationSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator that replaces with the median and evaluates on the training partition by default.
    /// </summary>
    public ClassificationSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    //mkommend: annoying name clash with static method, open to better naming suggestions
    /// <summary>
    /// Calculates the variable impacts for <paramref name="solution"/> using the configured
    /// <see cref="DataPartition"/> and <see cref="ReplacementMethod"/>. Factor variables use the
    /// default factor replacement method of <see cref="CalculateImpacts"/>.
    /// </summary>
    /// <param name="solution">The classification solution to analyze.</param>
    /// <returns>Pairs of variable name and impact, ordered by descending impact.</returns>
    public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) {
      return CalculateImpacts(solution, DataPartition, ReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of every allowed or model-used input variable of <paramref name="solution"/>.
    /// </summary>
    /// <param name="solution">The classification solution to analyze.</param>
    /// <param name="data">The partition whose rows are used for the accuracy calculations.</param>
    /// <param name="replacementMethod">The replacement method for double variables.</param>
    /// <param name="factorReplacementMethod">The replacement method for string variables.</param>
    /// <returns>Pairs of variable name and impact (original accuracy minus accuracy with the variable replaced), ordered by descending impact.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="solution"/> is null.</exception>
    /// <exception cref="ArgumentException">Thrown when <paramref name="data"/> or a replacement method is not a handled enum value.</exception>
    /// <exception cref="InvalidOperationException">Thrown when an accuracy calculation reports an error.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IClassificationSolution solution,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      if (solution == null) throw new ArgumentNullException("solution");

      var problemData = solution.ProblemData;
      var dataset = problemData.Dataset;

      IEnumerable<int> partitionRows;
      IEnumerable<double> originalEstimates;
      switch (data) {
        case DataPartitionEnum.All:
          partitionRows = problemData.AllIndices;
          originalEstimates = solution.EstimatedClassValues;
          break;
        case DataPartitionEnum.Training:
          partitionRows = problemData.TrainingIndices;
          originalEstimates = solution.EstimatedTrainingClassValues;
          break;
        case DataPartitionEnum.Test:
          partitionRows = problemData.TestIndices;
          originalEstimates = solution.EstimatedTestClassValues;
          break;
        default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
      }

      // The rows are enumerated once per variable and replacement value below; materialize them once.
      var rows = partitionRows.ToList();
      var targetValues = dataset.GetDoubleValues(problemData.TargetVariable, rows).ToList();
      var originalAccuracy = CalculateAccuracy(targetValues, originalEstimates, AccuracyErrorMessage);

      var impacts = new Dictionary<string, double>();
      var modifiableDataset = ((Dataset)dataset).ToModifiable();

      // Consider the allowed inputs as well as all variables used by the model, in dataset column order.
      var inputVariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
      var allowedInputVariables = dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();

      // calculate impacts for double variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<double>)) {
        var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);
        var newAccuracy = CalculateAccuracy(targetValues, newEstimates, ReplacedAccuracyErrorMessage);
        impacts[inputVariable] = originalAccuracy - newAccuracy;
      }

      // calculate impacts for string (factor) variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<string>)) {
        if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
          // try replacing with all possible values and find the best replacement value
          var smallestImpact = double.PositiveInfinity;
          foreach (var repl in dataset.GetStringValues(inputVariable, rows).Distinct().ToList()) {
            var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
              Enumerable.Repeat(repl, dataset.Rows));
            var newAccuracy = CalculateAccuracy(targetValues, newEstimates, ReplacedAccuracyErrorMessage);

            var impact = originalAccuracy - newAccuracy;
            if (impact < smallestImpact) smallestImpact = impact;
          }
          impacts[inputVariable] = smallestImpact;
        } else {
          // replacement methods shuffle and mode
          var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
            factorReplacementMethod);
          var newAccuracy = CalculateAccuracy(targetValues, newEstimates, ReplacedAccuracyErrorMessage);
          impacts[inputVariable] = originalAccuracy - newAccuracy;
        }
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the accuracy of <paramref name="estimatedValues"/> with respect to <paramref name="targetValues"/>
    /// and throws an <see cref="InvalidOperationException"/> with <paramref name="errorMessage"/> if the calculator reports an error.
    /// </summary>
    private static double CalculateAccuracy(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, string errorMessage) {
      OnlineCalculatorError error;
      var accuracy = OnlineAccuracyCalculator.Calculate(targetValues, estimatedValues, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException(errorMessage);
      return accuracy;
    }

    /// <summary>
    /// Builds a replacement column for the double variable <paramref name="variable"/> according to
    /// <paramref name="replacement"/> and returns the model's estimated class values for <paramref name="rows"/>.
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      var selectedValues = rows.Select(r => originalValues[r]).ToList();
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValues = Enumerable.Repeat(selectedValues.Median(), dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValues = Enumerable.Repeat(selectedValues.Average(), dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset; rows outside the selection stay NaN
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = selectedValues.Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = selectedValues.Average();
          var stdDev = selectedValues.StandardDeviation();
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset; rows outside the selection stay NaN
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Builds a replacement column for the string variable <paramref name="variable"/> according to
    /// <paramref name="replacement"/> and returns the model's estimated class values for <paramref name="rows"/>.
    /// <see cref="FactorReplacementMethodEnum.Best"/> is not handled here and results in an <see cref="ArgumentException"/>.
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
      IClassificationModel model, string variable, ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      var selectedValues = rows.Select(r => originalValues[r]).ToList();
      List<string> replacementValues;
      IRandom rand;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = selectedValues
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(RandomSeed);
          // prepare a complete column for the dataset; rows outside the selection stay empty
          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = selectedValues.Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Temporarily replaces the double column <paramref name="variable"/> in <paramref name="dataset"/> with
    /// <paramref name="replacementValues"/>, evaluates the model on <paramref name="rows"/> and restores the original column.
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedClassValues(dataset, rows).ToList();
      } finally {
        // restore the original column even if the model evaluation throws
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    /// <summary>
    /// Temporarily replaces the string column <paramref name="variable"/> in <paramref name="dataset"/> with
    /// <paramref name="replacementValues"/>, evaluates the model on <paramref name="rows"/> and restores the original column.
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedClassValues(dataset, rows).ToList();
      } finally {
        // restore the original column even if the model evaluation throws
        dataset.ReplaceVariable(variable, originalValues);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.