Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/06/18 16:16:51 (6 years ago)
Author:
fholzing
Message:

#2904: Refactored Methods for better readability. Fixed styling mistakes.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs

    r16041 r16051  
    9090    public RegressionSolutionVariableImpactsCalculator()
    9191      : base() {
    92       Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
     92      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle)));
    9393      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
    9494      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
     
    110110      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    111111      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
    112       IEnumerable<int> rows = (GetPartitionRows(dataPartition, solution.ProblemData));
     112
     113      IEnumerable<int> rows = GetPartitionRows(dataPartition, solution.ProblemData);
    113114      IEnumerable<double> estimatedValues = solution.GetEstimatedValues(rows);
    114115      return CalculateImpacts(solution.Model, solution.ProblemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
     
    122123     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    123124     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    124       //Calculate original quality-values (via calculator, default is R²)
    125       OnlineCalculatorError error;
     125
     126      //fholzing: guard in case a different dataset is loaded and the model uses variables missing from it; otherwise this check is negligible
     127      var missingVariables = model.VariablesUsedForPrediction.Except(problemData.Dataset.VariableNames);
     128      if (missingVariables.Any()) {
     129        throw new InvalidOperationException(string.Format("Can not calculate variable impacts, because the model uses inputs missing in the dataset ({0})", string.Join(", ", missingVariables)));
     130      }
    126131      IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
    127       var originalCalculatorValue = CalculateVariableImpact(targetValues, estimatedValues, out error);
    128       if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
     132      var originalQuality = CalculateQuality(targetValues, estimatedValues);
    129133
    130134      var impacts = new Dictionary<string, double>();
    131135      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
    132       var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
    133136      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
    134137
    135 
    136       foreach (var inputVariable in allowedInputVariables) {
    137         if (model.VariablesUsedForPrediction.Contains(inputVariable)) {
    138           impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rows, targetValues, originalCalculatorValue, replacementMethod, factorReplacementMethod);
    139         } else {
    140           impacts[inputVariable] = 0;
    141         }
    142       }
    143 
    144       return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
     138      foreach (var inputVariable in inputvariables) {
     139        impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData, modifiableDataset, rows, replacementMethod, factorReplacementMethod, targetValues, originalQuality);
     140      }
     141
     142      return impacts.Select(i => Tuple.Create(i.Key, i.Value));
    145143    }
    146144
    /// <summary>
    /// Calculates the impact of a single variable as the difference between a reference quality
    /// and the quality obtained after the values of the variable have been replaced.
    /// </summary>
    /// <param name="variableName">Name of the variable whose impact is calculated.</param>
    /// <param name="model">The regression model that is evaluated.</param>
    /// <param name="problemData">Problem data providing the dataset and the target variable.</param>
    /// <param name="modifiableDataset">Dataset on which the replacement is applied and evaluated.</param>
    /// <param name="rows">Rows on which the quality is evaluated.</param>
    /// <param name="replacementMethod">Replacement method used for double variables.</param>
    /// <param name="factorReplacementMethod">Replacement method used for string (factor) variables.</param>
    /// <param name="targetValues">Target values for <paramref name="rows"/>; read from <paramref name="problemData"/> when null.</param>
    /// <param name="quality">Reference quality; calculated from the model estimates on <paramref name="modifiableDataset"/> when NaN.</param>
    /// <returns>0.0 if the model does not use the variable, otherwise the reference quality minus the quality with replaced values.</returns>
    /// <exception cref="InvalidOperationException">The model uses the variable but the dataset does not contain it.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      IRegressionProblemData problemData,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      IEnumerable<double> targetValues = null,
      double quality = double.NaN) {

      // variables the model does not use cannot have any impact
      if (!model.VariablesUsedForPrediction.Contains(variableName)) { return 0.0; }
      if (!problemData.Dataset.VariableNames.Contains(variableName)) {
        throw new InvalidOperationException(string.Format("Can not calculate variable impact, because the model uses inputs missing in the dataset ({0})", variableName));
      }

      if (targetValues == null) {
        targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
      }
      // NaN never compares equal to anything (not even to itself), therefore double.IsNaN must be used
      // to detect that the caller did not supply a reference quality
      if (double.IsNaN(quality)) {
        quality = CalculateQuality(targetValues, model.GetEstimatedValues(modifiableDataset, rows));
      }

      IList originalValues;
      IList replacementValues = GetReplacementValues(modifiableDataset, variableName, model, rows, targetValues, out originalValues, replacementMethod, factorReplacementMethod);

      double newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, replacementValues, targetValues);
      return quality - newValue;
    }
     175
     176    private static IList GetReplacementValues(ModifiableDataset modifiableDataset,
     177      string variableName,
     178      IRegressionModel model,
     179      IEnumerable<int> rows,
    151180      IEnumerable<double> targetValues,
    152       double originalValue,
     181      out IList originalValues,
    153182      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    154183      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    155       double impact = 0;
    156       OnlineCalculatorError error;
    157       IRandom random;
    158       double replacementValue;
    159       IEnumerable<double> newEstimates = null;
    160       double newValue = 0;
    161 
     184
     185      IList replacementValues = null;
    162186      if (modifiableDataset.VariableHasType<double>(variableName)) {
    163         #region NumericalVariable
    164         var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
    165         List<double> replacementValues;
    166 
    167         switch (replacementMethod) {
    168           case ReplacementMethodEnum.Median:
    169             replacementValue = rows.Select(r => originalValues[r]).Median();
    170             replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
    171             break;
    172           case ReplacementMethodEnum.Average:
    173             replacementValue = rows.Select(r => originalValues[r]).Average();
    174             replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
    175             break;
    176           case ReplacementMethodEnum.Shuffle:
    177             // new var has same empirical distribution but the relation to y is broken
    178             random = new FastRandom(31415);
    179             // prepare a complete column for the dataset
    180             replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
    181             // shuffle only the selected rows
    182             var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
    183             int i = 0;
    184             // update column values
    185             foreach (var r in rows) {
    186               replacementValues[r] = shuffledValues[i++];
    187             }
    188             break;
    189           case ReplacementMethodEnum.Noise:
    190             var avg = rows.Select(r => originalValues[r]).Average();
    191             var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
    192             random = new FastRandom(31415);
    193             // prepare a complete column for the dataset
    194             replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
    195             // update column values
    196             foreach (var r in rows) {
    197               replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
    198             }
    199             break;
    200 
    201           default:
    202             throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
    203         }
    204 
    205         newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
    206         newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    207         if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
    208 
    209         impact = originalValue - newValue;
    210         #endregion
     187        originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
     188        replacementValues = GetReplacementValuesForDouble(modifiableDataset, rows, (List<double>)originalValues, replacementMethod);
    211189      } else if (modifiableDataset.VariableHasType<string>(variableName)) {
    212         #region FactorVariable
    213         var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
    214         List<string> replacementValues;
    215 
    216         switch (factorReplacementMethod) {
    217           case FactorReplacementMethodEnum.Best:
    218             // try replacing with all possible values and find the best replacement value
    219             var smallestImpact = double.PositiveInfinity;
    220             foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
    221               newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
    222               newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    223               if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    224 
    225               var curImpact = originalValue - newValue;
    226               if (curImpact < smallestImpact) smallestImpact = curImpact;
    227             }
    228             impact = smallestImpact;
    229             break;
    230           case FactorReplacementMethodEnum.Mode:
    231             var mostCommonValue = rows.Select(r => originalValues[r])
    232               .GroupBy(v => v)
    233               .OrderByDescending(g => g.Count())
    234               .First().Key;
    235             replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
    236 
    237             newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
    238             newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    239             if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    240 
    241             impact = originalValue - newValue;
    242             break;
    243           case FactorReplacementMethodEnum.Shuffle:
    244             // new var has same empirical distribution but the relation to y is broken
    245             random = new FastRandom(31415);
    246             // prepare a complete column for the dataset
    247             replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
    248             // shuffle only the selected rows
    249             var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
    250             int i = 0;
    251             // update column values
    252             foreach (var r in rows) {
    253               replacementValues[r] = shuffledValues[i++];
    254             }
    255 
    256             newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
    257             newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    258             if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    259 
    260             impact = originalValue - newValue;
    261             break;
    262           default:
    263             throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
    264         }
    265         #endregion
     190        originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
     191        replacementValues = GetReplacementValuesForString(model, modifiableDataset, variableName, rows, originalValues, targetValues, factorReplacementMethod);
    266192      } else {
    267193        throw new NotSupportedException("Variable not supported");
    268194      }
    269195
    270       return impact;
    271     }
    272 
    273     /// <summary>
    274     /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
    275     /// and changes the value of the model-variables back to the original ones.
    276     /// </summary>
    277     /// <param name="originalValues"></param>
    278     /// <param name="model"></param>
    279     /// <param name="variableName"></param>
    280     /// <param name="modifiableDataset"></param>
    281     /// <param name="rows"></param>
    282     /// <param name="replacementValues"></param>
    283     /// <returns></returns>
    284     private static IEnumerable<double> GetReplacedEstimates(
     196      return replacementValues;
     197    }
     198
     199    private static IList GetReplacementValuesForDouble(ModifiableDataset modifiableDataset,
     200      IEnumerable<int> rows,
     201      List<double> originalValues,
     202      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle) {
     203
     204      IRandom random = new FastRandom(31415);
     205      List<double> replacementValues;
     206      double replacementValue;
     207
     208      switch (replacementMethod) {
     209        case ReplacementMethodEnum.Median:
     210          replacementValue = rows.Select(r => originalValues[r]).Median();
     211          replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     212          break;
     213        case ReplacementMethodEnum.Average:
     214          replacementValue = rows.Select(r => originalValues[r]).Average();
     215          replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     216          break;
     217        case ReplacementMethodEnum.Shuffle:
     218          // new var has same empirical distribution but the relation to y is broken
     219          // prepare a complete column for the dataset
     220          replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     221          // shuffle only the selected rows
     222          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     223          int i = 0;
     224          // update column values
     225          foreach (var r in rows) {
     226            replacementValues[r] = shuffledValues[i++];
     227          }
     228          break;
     229        case ReplacementMethodEnum.Noise:
     230          var avg = rows.Select(r => originalValues[r]).Average();
     231          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
     232          // prepare a complete column for the dataset
     233          replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     234          // update column values
     235          foreach (var r in rows) {
     236            replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
     237          }
     238          break;
     239
     240        default:
     241          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
     242      }
     243
     244      return replacementValues;
     245    }
     246
     247    private static IList GetReplacementValuesForString(IRegressionModel model,
     248      ModifiableDataset modifiableDataset,
     249      string variableName,
     250      IEnumerable<int> rows,
    285251      IList originalValues,
     252      IEnumerable<double> targetValues,
     253      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Shuffle) {
     254
     255      IList replacementValues = null;
     256      IRandom random = new FastRandom(31415);
     257
     258      switch (factorReplacementMethod) {
     259        case FactorReplacementMethodEnum.Best:
     260          // try replacing with all possible values and find the best replacement value
     261          var bestQuality = double.NegativeInfinity;
     262          foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
     263            List<string> curReplacementValues = Enumerable.Repeat(repl, modifiableDataset.Rows).ToList();
     264            //fholzing: this result could be used later on (theoretically), but is neglected for better readability/method consistency
     265            var newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, replacementValues, targetValues);
     266            var curQuality = newValue;
     267
     268            if (curQuality > bestQuality) {
     269              bestQuality = curQuality;
     270              replacementValues = curReplacementValues;
     271            }
     272          }
     273          break;
     274        case FactorReplacementMethodEnum.Mode:
     275          var mostCommonValue = rows.Select(r => originalValues[r])
     276            .GroupBy(v => v)
     277            .OrderByDescending(g => g.Count())
     278            .First().Key;
     279          replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
     280          break;
     281        case FactorReplacementMethodEnum.Shuffle:
     282          // new var has same empirical distribution but the relation to y is broken
     283          // prepare a complete column for the dataset
     284          replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
     285          // shuffle only the selected rows
     286          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     287          int i = 0;
     288          // update column values
     289          foreach (var r in rows) {
     290            replacementValues[r] = shuffledValues[i++];
     291          }
     292          break;
     293        default:
     294          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
     295      }
     296
     297      return replacementValues;
     298    }
     299
     300    private static double CalculateQualityForReplacement(
    286301      IRegressionModel model,
     302      ModifiableDataset modifiableDataset,
    287303      string variableName,
    288       ModifiableDataset modifiableDataset,
    289       IEnumerable<int> rows,
    290       IList replacementValues) {
     304      IList originalValues,
     305      IEnumerable<int> rows,
     306      IList replacementValues,
     307      IEnumerable<double> targetValues) {
     308
    291309      modifiableDataset.ReplaceVariable(variableName, replacementValues);
    292310      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    293311      var estimates = model.GetEstimatedValues(modifiableDataset, rows).ToList();
     312      var ret = CalculateQuality(targetValues, estimates);
    294313      modifiableDataset.ReplaceVariable(variableName, originalValues);
    295314
    296       return estimates;
    297     }
    298 
    299     /// <summary>
    300     /// Calculates and returns the VariableImpact (calculated via Pearsons R²).
    301     /// </summary>
    302     /// <param name="targetValues">The actual values</param>
    303     /// <param name="estimatedValues">The calculated/replaced values</param>
    304     /// <param name="errorState"></param>
    305     /// <returns></returns>
    306     public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
    307       //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
    308       //as the code below does. But this way we can easily swap the calculator later on, so the user 
    309       //could choose a Calculator during runtime in future versions.
    310       IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator();
    311       IEnumerator<double> firstEnumerator = targetValues.GetEnumerator();
    312       IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator();
    313 
    314       // always move forward both enumerators (do not use short-circuit evaluation!)
    315       while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) {
    316         double original = firstEnumerator.Current;
    317         double estimated = secondEnumerator.Current;
    318         calculator.Add(original, estimated);
    319         if (calculator.ErrorState != OnlineCalculatorError.None) break;
    320       }
    321 
    322       // check if both enumerators are at the end to make sure both enumerations have the same length
    323       if (calculator.ErrorState == OnlineCalculatorError.None &&
    324            (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) {
    325         throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
    326       } else {
    327         errorState = calculator.ErrorState;
    328         return calculator.Value;
    329       }
    330     }
    331 
    332     /// <summary>
    333     /// Returns a collection of the row-indices for a given DataPartition (training or test)
    334     /// </summary>
    335     /// <param name="dataPartition"></param>
    336     /// <param name="problemData"></param>
    337     /// <returns></returns>
     315      return ret;
     316    }
     317
     318    public static double CalculateQuality(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues) {
     319      OnlineCalculatorError errorState;
     320      var ret = OnlinePearsonsRCalculator.Calculate(targetValues, estimatedValues, out errorState);
     321      if (errorState != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
     322      return ret * ret;
     323    }
     324
    338325    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IRegressionProblemData problemData) {
    339326      IEnumerable<int> rows;
Note: See TracChangeset for help on using the changeset viewer.