Free cookie consent management tool by TermsFeed Policy Generator

source: branches/PersistenceReintegration/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 15234

Last change on this file since 15234 was 15018, checked in by gkronber, 8 years ago

#2520 introduced StorableConstructorFlag type for StorableConstructors

File size: 13.8 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of each input variable of a regression solution.
  /// The impact of a variable is the difference between the R² of the original solution and the R²
  /// (squared Pearson's R between target and estimated values) obtained after the values of that
  /// variable have been replaced according to the selected replacement method.
  /// </summary>
  [StorableType("de95788b-0353-4996-b307-c66432460ed2")]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>
    /// Replacement methods for double variables.
    /// </summary>
    [StorableType("0bf10277-e9e2-45e1-bd14-36691f5ec384")]
    public enum ReplacementMethodEnum {
      /// <summary>All values are replaced by the median of the variable over the selected rows.</summary>
      Median,
      /// <summary>All values are replaced by the average of the variable over the selected rows.</summary>
      Average,
      /// <summary>The values of the selected rows are shuffled (fixed random seed).</summary>
      Shuffle,
      /// <summary>The values of the selected rows are replaced by normally distributed random values
      /// using the average and standard deviation of the variable over the selected rows (fixed random seed).</summary>
      Noise
    }

    /// <summary>
    /// Replacement methods for string (factor) variables.
    /// </summary>
    [StorableType("713485d4-ce6c-4066-8a1b-6c809456fde1")]
    public enum FactorReplacementMethodEnum {
      /// <summary>Every distinct value occurring in the selected rows is tried as constant replacement
      /// and the smallest resulting impact is reported.</summary>
      Best,
      /// <summary>All values are replaced by the most common value in the selected rows.</summary>
      Mode,
      /// <summary>The values of the selected rows are shuffled (fixed random seed).</summary>
      Shuffle
    }

    /// <summary>
    /// Data partition on which the impacts are calculated.
    /// </summary>
    [StorableType("d8dac633-f199-4fb5-b7e6-92ddfcf4be94")]
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Replacement method for double variables (value of <see cref="ReplacementParameter"/>).</summary>
    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Data partition used for the calculation (value of <see cref="DataPartitionParameter"/>).</summary>
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(StorableConstructorFlag deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator with the defaults: replacement method Median, data partition Training.
    /// </summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    /// <summary>
    /// Calculates the variable impacts for <paramref name="solution"/> using the configured
    /// <see cref="DataPartition"/> and <see cref="ReplacementMethod"/>. Factor variables are handled with
    /// the default factor replacement method of <see cref="CalculateImpacts"/>.
    /// </summary>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, DataPartition, ReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of every input variable (allowed input variables of the problem data
    /// together with the variables used by the model) on the R² of the solution.
    /// </summary>
    /// <param name="solution">The regression solution to analyze.</param>
    /// <param name="data">The data partition on which the impacts are calculated.</param>
    /// <param name="replacementMethod">The replacement method for double variables.</param>
    /// <param name="factorReplacementMethod">The replacement method for string variables.</param>
    /// <returns>Pairs of variable name and impact (original R² minus R² with replaced variable),
    /// ordered by descending impact.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="solution"/> is null.</exception>
    /// <exception cref="ArgumentException">Thrown when the data partition or a replacement method is unknown.</exception>
    /// <exception cref="InvalidOperationException">Thrown when an R² value cannot be calculated.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {

      if (solution == null) throw new ArgumentNullException("solution");

      var problemData = solution.ProblemData;
      var dataset = problemData.Dataset;

      IEnumerable<int> rows;
      IEnumerable<double> targetValues;
      double originalR2 = -1;

      OnlineCalculatorError error;

      switch (data) {
        case DataPartitionEnum.All:
          rows = problemData.AllIndices;
          targetValues = problemData.TargetVariableValues.ToList();
          // the solution provides no R² for all rows, therefore it is calculated here
          originalR2 = OnlinePearsonsRCalculator.Calculate(targetValues, solution.EstimatedValues, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation.");
          originalR2 = originalR2 * originalR2;
          break;
        case DataPartitionEnum.Training:
          rows = problemData.TrainingIndices;
          targetValues = problemData.TargetVariableTrainingValues.ToList();
          originalR2 = solution.TrainingRSquared;
          break;
        case DataPartitionEnum.Test:
          rows = problemData.TestIndices;
          targetValues = problemData.TargetVariableTestValues.ToList();
          originalR2 = solution.TestRSquared;
          break;
        default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
      }

      var impacts = new Dictionary<string, double>();
      // variables are replaced in a modifiable version of the dataset
      var modifiableDataset = ((Dataset)dataset).ToModifiable();

      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
      // keep the order of the variables in the dataset
      var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();

      // calculate impacts for double variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<double>)) {
        var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);
        impacts[inputVariable] = CalculateImpact(targetValues, newEstimates, originalR2);
      }

      // calculate impacts for string variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<string>)) {
        if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
          // try replacing with all possible values and find the best replacement value
          // (remains PositiveInfinity if the selected rows contain no value at all)
          var smallestImpact = double.PositiveInfinity;
          foreach (var repl in dataset.GetStringValues(inputVariable, rows).Distinct()) {
            var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
              Enumerable.Repeat(repl, dataset.Rows));
            var impact = CalculateImpact(targetValues, newEstimates, originalR2);
            if (impact < smallestImpact) smallestImpact = impact;
          }
          impacts[inputVariable] = smallestImpact;
        } else {
          // for replacement methods shuffle and mode
          // calculate impacts for factor variables
          var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
            factorReplacementMethod);
          impacts[inputVariable] = CalculateImpact(targetValues, newEstimates, originalR2);
        }
      } // foreach
      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact as difference between <paramref name="originalR2"/> and the R² (squared Pearson's R)
    /// of <paramref name="targetValues"/> and <paramref name="newEstimates"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the calculator reports an error.</exception>
    private static double CalculateImpact(IEnumerable<double> targetValues, IEnumerable<double> newEstimates, double originalR2) {
      OnlineCalculatorError error;
      var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None)
        throw new InvalidOperationException("Error during R² calculation with replaced inputs.");

      newR2 = newR2 * newR2;
      return originalR2 - newR2;
    }

    /// <summary>
    /// Evaluates the model on <paramref name="rows"/> after the double variable <paramref name="variable"/>
    /// has been replaced according to <paramref name="replacement"/>. The replacement values are derived
    /// from the values of the variable in <paramref name="rows"/> only.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when the replacement method is unknown.</exception>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      double replacementValue;
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset (rows that are not selected remain NaN)
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        case ReplacementMethodEnum.Noise:
          // materialize the selected values once, they are used for average and standard deviation
          var selectedValues = rows.Select(r => originalValues[r]).ToList();
          var avg = selectedValues.Average();
          var stdDev = selectedValues.StandardDeviation();
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset (rows that are not selected remain NaN)
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Evaluates the model on <paramref name="rows"/> after the string variable <paramref name="variable"/>
    /// has been replaced according to <paramref name="replacement"/>. Only Mode and Shuffle are supported here.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when the replacement method is not supported.</exception>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(
      IRegressionModel model, string variable, ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;
      IRandom rand;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset (rows that are not selected remain empty)
          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    }

    /// <summary>
    /// Temporarily replaces the double variable <paramref name="variable"/> in <paramref name="dataset"/> with
    /// <paramref name="replacementValues"/>, evaluates the model on <paramref name="rows"/> and restores the
    /// original values afterwards (also when the evaluation fails).
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        // always restore the original column, the dataset is reused for the next variable
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    /// <summary>
    /// Temporarily replaces the string variable <paramref name="variable"/> in <paramref name="dataset"/> with
    /// <paramref name="replacementValues"/>, evaluates the model on <paramref name="rows"/> and restores the
    /// original values afterwards (also when the evaluation fails).
    /// </summary>
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        // always restore the original column, the dataset is reused for the next variable
        dataset.ReplaceVariable(variable, originalValues);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.