Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
02/26/18 14:03:27 (7 years ago)
Author:
fholzing
Message:

#2904: Split the method CalculateImpact into smaller parts

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs

    r15802 r15815  
    5656    private const string DataPartitionParameterName = "DataPartition";
    5757
    58     public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
     58    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter
     59    {
    5960      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    6061    }
    61     public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
     62    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter
     63    {
    6264      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    6365    }
    6466
    65     public ReplacementMethodEnum ReplacementMethod {
     67    public ReplacementMethodEnum ReplacementMethod
     68    {
    6669      get { return ReplacementParameter.Value.Value; }
    6770      set { ReplacementParameter.Value.Value = value; }
    6871    }
    69     public DataPartitionEnum DataPartition {
     72    public DataPartitionEnum DataPartition
     73    {
    7074      get { return DataPartitionParameter.Value.Value; }
    7175      set { DataPartitionParameter.Value.Value = value; }
     
    9296    }
    9397
     98    private static void PrepareData(DataPartitionEnum partition,
     99      IRegressionSolution solution,
     100      out IEnumerable<int> rows,
     101      out IEnumerable<double> targetValues,
     102      out double originalR2) {
              // Fills the three out parameters for the requested data partition:
              //   rows         - the row indices of the partition,
              //   targetValues - the target variable values of the partition (materialized with ToList()),
              //   originalR2   - the R² of the unmodified solution on the partition.
              // Throws InvalidOperationException if the R² calculation for the All partition reports an error,
              // and ArgumentException for a partition value not covered by the switch.
     103      OnlineCalculatorError error;
     104
     105      switch (partition) {
     106        case DataPartitionEnum.All:
                  // For All, R² is computed here from the target values and the solution's estimated values.
     107          rows = solution.ProblemData.AllIndices;
     108          targetValues = solution.ProblemData.TargetVariableValues.ToList();
     109          originalR2 = OnlinePearsonsRCalculator.Calculate(solution.ProblemData.TargetVariableValues, solution.EstimatedValues, out error);
     110          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation.");
                  // The calculator result is squared to obtain R².
     111          originalR2 = originalR2 * originalR2;
     112          break;
     113        case DataPartitionEnum.Training:
                  // For Training, the R² value is read from the solution (TrainingRSquared) instead of being recomputed.
     114          rows = solution.ProblemData.TrainingIndices;
     115          targetValues = solution.ProblemData.TargetVariableTrainingValues.ToList();
     116          originalR2 = solution.TrainingRSquared;
     117          break;
     118        case DataPartitionEnum.Test:
                  // For Test, the R² value is read from the solution (TestRSquared) instead of being recomputed.
     119          rows = solution.ProblemData.TestIndices;
     120          targetValues = solution.ProblemData.TargetVariableTestValues.ToList();
     121          originalR2 = solution.TestRSquared;
     122          break;
     123        default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", partition));
     124      }
     125    }
     126
     127    private static double CalculateImpactForDouble(string variableName,
     128      IRegressionSolution solution,
     129      ModifiableDataset modifiableDataset,
     130      IEnumerable<int> rows,
     131      IEnumerable<double> targetValues,
     132      double originalR2,
     133      ReplacementMethodEnum replacementMethod) {
              // Returns the impact of one double-typed variable: originalR2 minus the R² obtained
              // when the model is evaluated with that variable replaced according to replacementMethod.
              // Throws InvalidOperationException if the R² calculation reports an error.
     134      OnlineCalculatorError error;
              // Model estimates on the given rows with the variable replaced; the replacement itself
              // is done by EvaluateModelWithReplacedVariable.
     135      var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, replacementMethod);
     136      var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
     137      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during R² calculation with replaced inputs."); }
              // The calculator result is squared before it is subtracted from the original R².
     138      return originalR2 - (newR2 * newR2);
     139    }
     140
     141    private static double CalculateImpactForString(string variableName,
     142      IRegressionSolution solution,
     143      ModifiableDataset modifiableDataset,
     144      IEnumerable<int> rows,
     145      IEnumerable<double> targetValues,
     146      double originalR2,
     147      FactorReplacementMethodEnum factorReplacementMethod) {
              // Returns the impact of one string-typed (factor) variable: originalR2 minus the R² obtained
              // when the model is evaluated with that variable replaced.
              // For FactorReplacementMethodEnum.Best every distinct value of the variable in the given rows
              // is tried as a constant replacement and the smallest resulting impact is returned;
              // for any other method a single evaluation with that method is performed.
              // Throws InvalidOperationException if an R² calculation reports an error.
     148
     149      OnlineCalculatorError error;
     150      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
     151        // try replacing with all possible values and find the best replacement value
                // smallestImpact stays at double.PositiveInfinity (and is returned as such)
                // if the loop below has no distinct values to iterate over.
     152        var smallestImpact = double.PositiveInfinity;
     153        foreach (var repl in solution.ProblemData.Dataset.GetStringValues(variableName, rows).Distinct()) {
                  // The replacement column repeats the candidate value once per dataset row (Dataset.Rows).
     154          var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, solution.ProblemData.Dataset.Rows));
     155          var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
     156          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
     157
                  // The calculator result is squared before it is subtracted from the original R².
     158          var curImpact = originalR2 - (newR2 * newR2);
     159          if (curImpact < smallestImpact) smallestImpact = curImpact;
     160        }
     161        return smallestImpact;
     162      } else {
     163        // for replacement methods shuffle and mode
     164        // calculate impacts for factor variables
     165        var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, variableName, modifiableDataset, rows, factorReplacementMethod);
     166        var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
     167        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
     168
     169        return originalR2 - (newR2 * newR2);
     170      }
     171    }
     172    public static double CalculateImpact(string variableName,
     173      IRegressionSolution solution,
     174      DataPartitionEnum data = DataPartitionEnum.Training,
     175      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     176      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     177      IEnumerable<int> rows = null,
     178      IEnumerable<double> targetValues = null,
     179    double originalR2 = -1) {
              // Calculates the impact of a single variable by dispatching on its type in the dataset:
              // double variables go to CalculateImpactForDouble, string variables to CalculateImpactForString.
              // Throws NotSupportedException for a variable of any other type.
              // NOTE(review): the 'data' parameter is not read anywhere in this method.
              // NOTE(review): rows, targetValues and originalR2 are forwarded unchanged, so the defaults
              // (null, null, -1) reach the helpers as-is; callers presumably have to supply values
              // matching the intended partition (e.g. from PrepareData) - confirm.
     180
     181      double impact = 0;
              // The direct cast requires the problem data's dataset to be a Dataset instance.
     182      var modifiableDataset = ((Dataset)solution.ProblemData.Dataset).ToModifiable();
     183
     184      // calculate impacts for double variables
     185      if (solution.ProblemData.Dataset.VariableHasType<double>(variableName)) {
     186        impact = CalculateImpactForDouble(variableName, solution, modifiableDataset, rows, targetValues, originalR2, replacementMethod);
     187      } else if (solution.ProblemData.Dataset.VariableHasType<string>(variableName)) {
     188        impact = CalculateImpactForString(variableName, solution, modifiableDataset, rows, targetValues, originalR2, factorReplacementMethod);
     189      } else {
     190        throw new NotSupportedException("Variable not supported");
     191      }
     192      return impact;
     193    }
     194
    94195    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    95196      IRegressionSolution solution,
     
    99200      Func<double, string, bool> progressCallback = null) {
    100201
    101       var problemData = solution.ProblemData;
    102       var dataset = problemData.Dataset;
    103 
    104202      IEnumerable<int> rows;
    105203      IEnumerable<double> targetValues;
    106204      double originalR2 = -1;
    107205
    108       OnlineCalculatorError error;
    109 
    110       switch (data) {
    111         case DataPartitionEnum.All:
    112           rows = solution.ProblemData.AllIndices;
    113           targetValues = problemData.TargetVariableValues.ToList();
    114           originalR2 = OnlinePearsonsRCalculator.Calculate(problemData.TargetVariableValues, solution.EstimatedValues, out error);
    115           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation.");
    116           originalR2 = originalR2 * originalR2;
    117           break;
    118         case DataPartitionEnum.Training:
    119           rows = problemData.TrainingIndices;
    120           targetValues = problemData.TargetVariableTrainingValues.ToList();
    121           originalR2 = solution.TrainingRSquared;
    122           break;
    123         case DataPartitionEnum.Test:
    124           rows = problemData.TestIndices;
    125           targetValues = problemData.TargetVariableTestValues.ToList();
    126           originalR2 = solution.TestRSquared;
    127           break;
    128         default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
    129       }
     206      PrepareData(data, solution, out rows, out targetValues, out originalR2);
    130207
    131208      var impacts = new Dictionary<string, double>();
    132       var modifiableDataset = ((Dataset)dataset).ToModifiable();
    133 
    134       var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
    135       var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
     209      var inputvariables = new HashSet<string>(solution.ProblemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
     210      var allowedInputVariables = solution.ProblemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
    136211
    137212      int curIdx = 0;
    138       int count = allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>).Count();
     213      int count = allowedInputVariables.Where(solution.ProblemData.Dataset.VariableHasType<double>).Count();
    139214      // calculate impacts for double variables
    140       foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) {
     215      foreach (var inputVariable in allowedInputVariables) {
    141216        //Report the current progress in percent. If the callback returns true, it means the execution shall be stopped
    142217        if (progressCallback != null) {
     
    144219          if (progressCallback((double)curIdx / count, string.Format("Calculating impact for variable {0} ({1} of {2})", inputVariable, curIdx, count))) { return null; }
    145220        }
    146         var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);
    147         var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
    148         if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
    149 
    150         newR2 = newR2 * newR2;
    151         var impact = originalR2 - newR2;
    152         impacts[inputVariable] = impact;
    153       }
    154 
    155       // calculate impacts for string variables
    156       foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<string>)) {
    157         if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
    158           // try replacing with all possible values and find the best replacement value
    159           var smallestImpact = double.PositiveInfinity;
    160           foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) {
    161             var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
    162               Enumerable.Repeat(repl, dataset.Rows));
    163             var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
    164             if (error != OnlineCalculatorError.None)
    165               throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
    166 
    167             newR2 = newR2 * newR2;
    168             var impact = originalR2 - newR2;
    169             if (impact < smallestImpact) smallestImpact = impact;
    170           }
    171           impacts[inputVariable] = smallestImpact;
    172         } else {
    173           // for replacement methods shuffle and mode
    174           // calculate impacts for factor variables
    175 
    176           var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
    177             factorReplacementMethod);
    178           var newR2 = OnlinePearsonsRCalculator.Calculate(targetValues, newEstimates, out error);
    179           if (error != OnlineCalculatorError.None)
    180             throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
    181 
    182           newR2 = newR2 * newR2;
    183           var impact = originalR2 - newR2;
    184           impacts[inputVariable] = impact;
    185         }
    186       } // foreach
     221        impacts[inputVariable] = CalculateImpact(inputVariable, solution, data, replacementMethod, factorReplacementMethod, rows, targetValues, originalR2);
     222      }
     223
    187224      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    188225    }
    189 
    190226
    191227    private static IEnumerable<double> EvaluateModelWithReplacedVariable(IRegressionModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
Note: See TracChangeset for help on using the changeset viewer.