Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/08/18 10:44:51 (6 years ago)
Author:
fholzing
Message:

#2904: Refactored RegressionSolutionVariableImpactsCalculator. We don't dependent on the solution anymore. The impact can be calculated for a single variable. The calculator can be chosen.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs

    r15816 r15831  
    2323
    2424using System;
     25using System.Collections;
    2526using System.Collections.Generic;
    2627using System.Linq;
     
    9697    }
    9798
    98     private static void PrepareData(DataPartitionEnum partition,
     99    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    99100      IRegressionSolution solution,
    100       out IEnumerable<int> rows,
     101      DataPartitionEnum data = DataPartitionEnum.Training,
     102      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     103      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     104      Func<double, string, bool> progressCallback = null) {
     105      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, data, replacementMethod, factorReplacementMethod, progressCallback);
     106    }
     107
     108    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     109      IRegressionModel model,
     110      IRegressionProblemData problemData,
     111      IEnumerable<double> estimatedValues,
     112      DataPartitionEnum data = DataPartitionEnum.Training,
     113      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     114      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     115      Func<double, string, bool> progressCallback = null,
     116      IOnlineCalculator calculator = null) {
     117      //PearsonsRSquared is the default calculator
     118      if (calculator == null) { calculator = new OnlinePearsonsRSquaredCalculator(); }
     119      IEnumerable<int> rows;
     120
     121      switch (data) {
     122        case DataPartitionEnum.All:
     123          rows = problemData.AllIndices;
     124          break;
     125        case DataPartitionEnum.Test:
     126          rows = problemData.TestIndices;
     127          break;
     128        case DataPartitionEnum.Training:
     129          rows = problemData.TrainingIndices;
     130          break;
     131        default:
     132          throw new NotSupportedException("DataPartition not supported");
     133      }
     134
     135      return CalculateImpacts(model, problemData, estimatedValues, rows, calculator, replacementMethod, factorReplacementMethod, progressCallback);
     136    }
     137
     138    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     139     IRegressionModel model,
     140     IRegressionProblemData problemData,
     141     IEnumerable<double> estimatedValues,
     142     IEnumerable<int> rows,
     143     IOnlineCalculator calculator,
     144     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     145     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     146     Func<double, string, bool> progressCallback = null) {
     147
     148      IEnumerable<double> targetValues;
     149      double originalValue = -1;
     150
     151      PrepareData(rows, problemData, estimatedValues, out targetValues, out originalValue, calculator);
     152
     153      var impacts = new Dictionary<string, double>();
     154      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
     155      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
     156
     157      int curIdx = 0;
     158      int count = allowedInputVariables
     159        .Where(v => problemData.Dataset.VariableHasType<double>(v) || problemData.Dataset.VariableHasType<string>(v))
     160        .Count();
     161
     162      foreach (var inputVariable in allowedInputVariables) {
     163        //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
     164        if (progressCallback != null) {
     165          curIdx++;
     166          if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
     167        }
     168        impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData.Dataset, rows, targetValues, originalValue, calculator, replacementMethod, factorReplacementMethod);
     169      }
     170
     171      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
     172    }
     173
     174    public static double CalculateImpact(string variableName,
     175      IRegressionSolution solution,
     176      IEnumerable<int> rows,
     177      IEnumerable<double> targetValues,
     178      double originalValue,
     179      IOnlineCalculator calculator,
     180      DataPartitionEnum data = DataPartitionEnum.Training,
     181      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     182      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
     183      return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, calculator, replacementMethod, factorReplacementMethod);
     184    }
     185
     186    public static double CalculateImpact(string variableName,
     187      IRegressionModel model,
     188      IDataset dataset,
     189      IEnumerable<int> rows,
     190      IEnumerable<double> targetValues,
     191      double originalValue,
     192      IOnlineCalculator calculator,
     193      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     194      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
     195
     196      double impact = 0;
     197      var modifiableDataset = ((Dataset)dataset).ToModifiable();
     198
     199      // calculate impacts for double variables
     200      if (dataset.VariableHasType<double>(variableName)) {
     201        impact = CalculateImpactForDouble(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod, calculator);
     202      } else if (dataset.VariableHasType<string>(variableName)) {
     203        impact = CalculateImpactForString(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod, calculator);
     204      } else {
     205        throw new NotSupportedException("Variable not supported");
     206      }
     207      return impact;
     208    }
     209
     210    private static void PrepareData(IEnumerable<int> rows,
     211      IRegressionProblemData problemData,
     212      IEnumerable<double> estimatedValues,
    101213      out IEnumerable<double> targetValues,
    102       out double originalR2) {
     214      out double originalValue,
     215      IOnlineCalculator calculator) {
    103216      OnlineCalculatorError error;
    104217
    105       switch (partition) {
    106         case DataPartitionEnum.All:
    107           rows = solution.ProblemData.AllIndices;
    108           targetValues = solution.ProblemData.TargetVariableValues.ToList();
    109           originalR2 = OnlinePearsonsRCalculator.Calculate(solution.ProblemData.TargetVariableValues, solution.EstimatedValues, out error);
    110           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation.");
    111           originalR2 = originalR2 * originalR2;
    112           break;
    113         case DataPartitionEnum.Training:
    114           rows = solution.ProblemData.TrainingIndices;
    115           targetValues = solution.ProblemData.TargetVariableTrainingValues.ToList();
    116           originalR2 = solution.TrainingRSquared;
    117           break;
    118         case DataPartitionEnum.Test:
    119           rows = solution.ProblemData.TestIndices;
    120           targetValues = solution.ProblemData.TargetVariableTestValues.ToList();
    121           originalR2 = solution.TestRSquared;
    122           break;
    123         default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", partition));
    124       }
     218      var targetVariableValueList = problemData.TargetVariableValues.ToList();
     219      targetValues = rows.Select(v => targetVariableValueList.ElementAt(v));
     220      var estimatedValuesPartition = rows.Select(v => estimatedValues.ElementAt(v));
     221      originalValue = calculator.CalculateValue(targetValues, estimatedValuesPartition, out error);
     222
     223      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
    125224    }
    126225
    127226    private static double CalculateImpactForDouble(string variableName,
    128       IRegressionSolution solution,
     227      IRegressionModel model,
    129228      ModifiableDataset modifiableDataset,
    130229      IEnumerable<int> rows,
    131230      IEnumerable<double> targetValues,
    132       double originalR2,
    133       ReplacementMethodEnum replacementMethod) {
     231      double originalValue,
     232      ReplacementMethodEnum replacementMethod,
     233      IOnlineCalculator calculator) {
    134234      OnlineCalculatorError error;
    135       var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, replacementMethod);
    136       var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
    137       if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
    138       return originalR2 - (newR2 * newR2);
     235      var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, replacementMethod);
     236      var newValue = calculator.CalculateValue(targetValues, newEstimates, out error);
     237      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
     238      return originalValue - newValue;
    139239    }
    140240
    141241    private static double CalculateImpactForString(string variableName,
    142       IRegressionSolution solution,
     242      IRegressionModel model,
     243      IDataset problemData,
    143244      ModifiableDataset modifiableDataset,
    144245      IEnumerable<int> rows,
    145246      IEnumerable<double> targetValues,
    146       double originalR2,
    147       FactorReplacementMethodEnum factorReplacementMethod) {
     247      double originalValue,
     248      FactorReplacementMethodEnum factorReplacementMethod,
     249      IOnlineCalculator calculator) {
    148250
    149251      OnlineCalculatorError error;
     
    151253        // try replacing with all possible values and find the best replacement value
    152254        var smallestImpact = double.PositiveInfinity;
    153         foreach (var repl in solution.ProblemData.Dataset.GetStringValues(variableName, rows).Distinct()) {
    154           var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, solution.ProblemData.Dataset.Rows));
    155           var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
    156           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
    157 
    158           var curImpact = originalR2 - (newR2 * newR2);
     255        foreach (var repl in problemData.GetStringValues(variableName, rows).Distinct()) {
     256          var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
     257          var newEstimates = EvaluateModelWithReplacedVariable(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, problemData.Rows).ToList());
     258          var newValue = calculator.CalculateValue(targetValues, newEstimates, out error);
     259          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     260
     261          var curImpact = originalValue - newValue;
    159262          if (curImpact < smallestImpact) smallestImpact = curImpact;
    160263        }
     
    163266        // for replacement methods shuffle and mode
    164267        // calculate impacts for factor variables
    165         var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, factorReplacementMethod);
    166         var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
    167         if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
    168 
    169         return originalR2 - (newR2 * newR2);
    170       }
    171     }
    172     public static double CalculateImpact(string variableName,
    173       IRegressionSolution solution,
    174       IEnumerable<int> rows,
    175       IEnumerable<double> targetValues,
    176       double originalR2,
    177       DataPartitionEnum data = DataPartitionEnum.Training,
    178       ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
    179       FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    180 
    181       double impact = 0;
    182       var modifiableDataset = ((Dataset)solution.ProblemData.Dataset).ToModifiable();
    183 
    184       // calculate impacts for double variables
    185       if (solution.ProblemData.Dataset.VariableHasType<double>(variableName)) {
    186         impact = CalculateImpactForDouble(variableName, solution, modifiableDataset, rows, targetValues, originalR2, replacementMethod);
    187       } else if (solution.ProblemData.Dataset.VariableHasType<string>(variableName)) {
    188         impact = CalculateImpactForString(variableName, solution, modifiableDataset, rows, targetValues, originalR2, factorReplacementMethod);
    189       } else {
    190         throw new NotSupportedException("Variable not supported");
    191       }
    192       return impact;
    193     }
    194 
    195     public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    196       IRegressionSolution solution,
    197       DataPartitionEnum data = DataPartitionEnum.Training,
    198       ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
    199       FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    200       Func<double, string, bool> progressCallback = null) {
    201 
    202       IEnumerable<int> rows;
    203       IEnumerable<double> targetValues;
    204       double originalR2 = -1;
    205 
    206       PrepareData(data, solution, out rows, out targetValues, out originalR2);
    207 
    208       var impacts = new Dictionary<string, double>();
    209       var inputvariables = new HashSet<string>(solution.ProblemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
    210       var allowedInputVariables = solution.ProblemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
    211 
    212       int curIdx = 0;
    213       int count = allowedInputVariables.Where(solution.ProblemData.Dataset.VariableHasType<double>).Count();
    214       // calculate impacts for double variables
    215       foreach (var inputVariable in allowedInputVariables) {
    216         //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
    217         if (progressCallback != null) {
    218           curIdx++;
    219           if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
    220         }
    221         impacts[inputVariable] = CalculateImpact(inputVariable, solution, rows, targetValues, originalR2, data, replacementMethod, factorReplacementMethod);
    222       }
    223 
    224       return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
     268        var newEstimates = EvaluateModelWithReplacedVariable(model, variableName, modifiableDataset, rows, factorReplacementMethod);
     269        var newValue = calculator.CalculateValue(targetValues, newEstimates, out error);
     270        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     271
     272        return originalValue - newValue;
     273      }
    225274    }
    226275
     
    269318      }
    270319
    271       return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
     320      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
    272321    }
    273322
     
    305354      }
    306355
    307       return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    308     }
    309 
    310     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
    311       ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
    312       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    313       dataset.ReplaceVariable(variable, replacementValues.ToList());
     356      return EvaluateModelWithReplacedVariable(originalValues, model, variable, dataset, rows, replacementValues);
     357    }
     358
     359    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IList originalValues, IRegressionModel model, string variable,
     360      ModifiableDataset dataset, IEnumerable<int> rows, IList replacementValues) {
     361      dataset.ReplaceVariable(variable, replacementValues);
    314362      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    315363      var estimates = model.GetEstimatedValues(dataset, rows).ToList();
     
    318366      return estimates;
    319367    }
    320     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable,
    321       ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
    322       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    323       dataset.ReplaceVariable(variable, replacementValues.ToList());
    324       //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    325       var estimates = model.GetEstimatedValues(dataset, rows).ToList();
    326       dataset.ReplaceVariable(variable, originalValues);
    327 
    328       return estimates;
    329     }
    330368  }
    331369}
Note: See TracChangeset for help on using the changeset viewer.