Free cookie consent management tool by TermsFeed Policy Generator

source: branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs @ 14498

Last change on this file since 14498 was 14498, checked in by gkronber, 7 years ago

#2650: merged r14457:14494 from trunk to branch (resolving conflicts)

File size: 10.9 KB
Line 
1#region License Information
2
3/* HeuristicLab
4 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Random;
33
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impacts of the input variables of a regression solution.
  /// The impact of a variable is the original R² (squared Pearson correlation of target and estimated values)
  /// minus the R² obtained after the values of that variable have been replaced.
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>How the values of a double-valued input variable are replaced during impact calculation.</summary>
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }

    /// <summary>The part of the problem data on which impacts are calculated.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    // Fixed seed for the Shuffle and Noise replacements so that repeated calculations yield identical impacts.
    private const int ReplacementRandomSeed = 31415;

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    /// <summary>Replacement method used by <see cref="Calculate"/> (stored in the "Replacement Method" parameter).</summary>
    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    /// <summary>Data partition used by <see cref="Calculate"/> (stored in the "DataPartition" parameter).</summary>
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }


    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator whose parameters default to <see cref="ReplacementMethodEnum.Median"/>
    /// and <see cref="DataPartitionEnum.Training"/>.
    /// </summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    /// <summary>
    /// Calculates the variable impacts of <paramref name="solution"/> using the data partition and the
    /// replacement method configured in this item's parameters.
    /// </summary>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, DataPartition, ReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of the input variables of <paramref name="solution"/>.
    /// Considered variables are the union of the problem data's allowed input variables and the variables
    /// used by the model, in dataset order; only double- and string-typed variables receive an impact.
    /// </summary>
    /// <param name="solution">The regression solution whose model is evaluated.</param>
    /// <param name="data">The partition whose rows are used for the calculation.</param>
    /// <param name="replacementMethod">How double-valued variables are replaced. String-valued (factor) variables
    /// are instead replaced, in turn, by every distinct value occurring in the partition, and the smallest
    /// resulting impact is reported (+Infinity if the partition contains no value of the variable).</param>
    /// <returns>Pairs of (variable name, impact), ordered by descending impact.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="solution"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="data"/> is not a handled partition.</exception>
    /// <exception cref="InvalidOperationException">An R² value could not be calculated.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(IRegressionSolution solution,
      DataPartitionEnum data = DataPartitionEnum.Training,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median) {
      if (solution == null) throw new ArgumentNullException("solution");

      var problemData = solution.ProblemData;
      var dataset = problemData.Dataset;

      // Rows and targets are materialized once: the index properties are plain IEnumerables (possibly deferred
      // queries) and are enumerated many times below (per variable, per factor value, per statistic).
      List<int> rows;
      List<double> targetValues;
      double originalR2;

      switch (data) {
        case DataPartitionEnum.All:
          rows = problemData.AllIndices.ToList();
          targetValues = problemData.TargetVariableValues.ToList();
          originalR2 = CalculateRSquared(targetValues, solution.EstimatedValues, "Error during R² calculation.");
          break;
        case DataPartitionEnum.Training:
          rows = problemData.TrainingIndices.ToList();
          targetValues = problemData.TargetVariableTrainingValues.ToList();
          originalR2 = solution.TrainingRSquared;
          break;
        case DataPartitionEnum.Test:
          rows = problemData.TestIndices.ToList();
          targetValues = problemData.TargetVariableTestValues.ToList();
          originalR2 = solution.TestRSquared;
          break;
        default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
      }

      var impacts = new Dictionary<string, double>();
      // variable values are temporarily replaced in this modifiable dataset, never in problemData.Dataset directly
      var modifiableDataset = ((Dataset)dataset).ToModifiable();

      var inputVariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
      // keep the variable order of the dataset
      var allowedInputVariables = dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();

      // calculate impacts for double variables
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<double>)) {
        var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);
        var newR2 = CalculateRSquared(targetValues, newEstimates, "Error during R² calculation with replaced inputs.");
        impacts[inputVariable] = originalR2 - newR2;
      }

      // calculate impacts for factor variables: the least harmful constant replacement defines the impact
      foreach (var inputVariable in allowedInputVariables.Where(dataset.VariableHasType<string>)) {
        var smallestImpact = double.PositiveInfinity;
        foreach (var repl in dataset.GetStringValues(inputVariable, rows).Distinct()) {
          var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, Enumerable.Repeat(repl, dataset.Rows));
          var newR2 = CalculateRSquared(targetValues, newEstimates, "Error during R² calculation with replaced inputs.");
          var impact = originalR2 - newR2;
          if (impact < smallestImpact) smallestImpact = impact;
        }
        impacts[inputVariable] = smallestImpact;
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    // Squared Pearson correlation of target and estimated values; throws InvalidOperationException(errorMessage)
    // when the online calculator reports an error.
    private static double CalculateRSquared(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, string errorMessage) {
      OnlineCalculatorError error;
      var r = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException(errorMessage);
      return r * r;
    }

    // Builds a replacement column for a double variable according to the replacement method and returns the model's
    // estimates for the given rows with that column in place. Statistics (median, average, standard deviation)
    // are computed from the selected rows only; for Shuffle and Noise, rows outside the selection are set to NaN.
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      var rowList = rows as IList<int> ?? rows.ToList();
      // values of the variable in the selected rows, in row order; basis of every replacement method
      IEnumerable<double> rowValues = rowList.Select(r => originalValues[r]).ToList();
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValues = Enumerable.Repeat(rowValues.Median(), dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValues = Enumerable.Repeat(rowValues.Average(), dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(ReplacementRandomSeed);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rowValues.Shuffle(rand).ToList();
          for (int i = 0; i < rowList.Count; i++) {
            replacementValues[rowList[i]] = shuffledValues[i];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = rowValues.Average();
          var stdDev = rowValues.StandardDeviation();
          rand = new FastRandom(ReplacementRandomSeed);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          foreach (var r in rowList) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return EvaluateModelWithReplacedVariable(model, variable, dataset, rowList, replacementValues);
    }

    // Temporarily replaces the double column of `variable`, evaluates the model on `rows`, and always restores
    // the original column (also when model evaluation throws).
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    // Temporarily replaces the string (factor) column of `variable`, evaluates the model on `rows`, and always
    // restores the original column (also when model evaluation throws).
    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
      ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      dataset.ReplaceVariable(variable, replacementValues.ToList());
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.