Ignore:
Timestamp:
08/01/18 14:01:08 (3 years ago)
Author:
fholzing
Message:

#2904: Streamlined the variableimpactcalculator code on both Regression and Classification. Taken over the regression-code for classification with some minor adaptations.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs

    r16035 r16036  
    100100    #endregion
    101101
    102     #region Public Methods/Wrappers
    103102    //mkommend: annoying name clash with static method, open to better naming suggestions
    104103    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
     
    159158      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    160159      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    161 
    162160      double impact = 0;
    163 
    164       // calculate impacts for double variables
     161      OnlineCalculatorError error;
     162      IRandom random;
     163      double replacementValue;
     164      IEnumerable<double> newEstimates = null;
     165      double newValue = 0;
     166
    165167      if (modifiableDataset.VariableHasType<double>(variableName)) {
    166         impact = CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
     168        #region NumericalVariable
     169        var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
     170        List<double> replacementValues;
     171
     172        switch (replacementMethod) {
     173          case ReplacementMethodEnum.Median:
     174            replacementValue = rows.Select(r => originalValues[r]).Median();
     175            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     176            break;
     177          case ReplacementMethodEnum.Average:
     178            replacementValue = rows.Select(r => originalValues[r]).Average();
     179            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     180            break;
     181          case ReplacementMethodEnum.Shuffle:
     182            // new var has same empirical distribution but the relation to y is broken
     183            random = new FastRandom(31415);
     184            // prepare a complete column for the dataset
     185            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     186            // shuffle only the selected rows
     187            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     188            int i = 0;
     189            // update column values
     190            foreach (var r in rows) {
     191              replacementValues[r] = shuffledValues[i++];
     192            }
     193            break;
     194          case ReplacementMethodEnum.Noise:
     195            var avg = rows.Select(r => originalValues[r]).Average();
     196            var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
     197            random = new FastRandom(31415);
     198            // prepare a complete column for the dataset
     199            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     200            // update column values
     201            foreach (var r in rows) {
     202              replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
     203            }
     204            break;
     205
     206          default:
     207            throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
     208        }
     209
     210        newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     211        newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     212        if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
     213
     214        impact = originalValue - newValue;
     215        #endregion
    167216      } else if (modifiableDataset.VariableHasType<string>(variableName)) {
    168         impact = CalculateImpactForFactorVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
     217        #region FactorVariable
     218        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
     219        List<string> replacementValues;
     220
     221        switch (factorReplacementMethod) {
     222          case FactorReplacementMethodEnum.Best:
     223            // try replacing with all possible values and find the best replacement value
     224            var smallestImpact = double.PositiveInfinity;
     225            foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
     226              newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
     227              newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     228              if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     229
     230              var curImpact = originalValue - newValue;
     231              if (curImpact < smallestImpact) smallestImpact = curImpact;
     232            }
     233            impact = smallestImpact;
     234            break;
     235          case FactorReplacementMethodEnum.Mode:
     236            var mostCommonValue = rows.Select(r => originalValues[r])
     237              .GroupBy(v => v)
     238              .OrderByDescending(g => g.Count())
     239              .First().Key;
     240            replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
     241
     242            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     243            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     244            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     245
     246            impact = originalValue - newValue;
     247            break;
     248          case FactorReplacementMethodEnum.Shuffle:
     249            // new var has same empirical distribution but the relation to y is broken
     250            random = new FastRandom(31415);
     251            // prepare a complete column for the dataset
     252            replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
     253            // shuffle only the selected rows
     254            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     255            int i = 0;
     256            // update column values
     257            foreach (var r in rows) {
     258              replacementValues[r] = shuffledValues[i++];
     259            }
     260
     261            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     262            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     263            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     264
     265            impact = originalValue - newValue;
     266            break;
     267          default:
     268            throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
     269        }
     270        #endregion
    169271      } else {
    170272        throw new NotSupportedException("Variable not supported");
    171273      }
     274
    172275      return impact;
    173276    }
    174     #endregion
    175 
    176     private static double CalculateImpactForNumericalVariables(string variableName,
    177       IRegressionModel model,
    178       ModifiableDataset modifiableDataset,
    179       IEnumerable<int> rows,
    180       IEnumerable<double> targetValues,
    181       double originalValue,
    182       ReplacementMethodEnum replacementMethod) {
    183       OnlineCalculatorError error;
    184       var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod);
    185       var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    186       if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
    187       return originalValue - newValue;
    188     }
    189 
    190     private static double CalculateImpactForFactorVariables(string variableName,
    191       IRegressionModel model,
    192       ModifiableDataset modifiableDataset,
    193       IEnumerable<int> rows,
    194       IEnumerable<double> targetValues,
    195       double originalValue,
    196       FactorReplacementMethodEnum factorReplacementMethod) {
    197 
    198       OnlineCalculatorError error;
    199       if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
    200         // try replacing with all possible values and find the best replacement value
    201         var smallestImpact = double.PositiveInfinity;
    202         foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
    203           var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
    204           var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
    205           var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    206           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    207 
    208           var curImpact = originalValue - newValue;
    209           if (curImpact < smallestImpact) smallestImpact = curImpact;
    210         }
    211         return smallestImpact;
    212       } else {
    213         // for replacement methods shuffle and mode
    214         // calculate impacts for factor variables
    215         var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod);
    216         var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    217         if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    218 
    219         return originalValue - newValue;
    220       }
    221     }
    222 
    223     private static IEnumerable<double> GetReplacedValuesForNumericalVariables(
    224       IRegressionModel model,
    225       string variable,
    226       ModifiableDataset dataset,
    227       IEnumerable<int> rows,
    228       ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) {
    229       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    230       double replacementValue;
    231       List<double> replacementValues;
    232       IRandom rand;
    233 
    234       switch (replacement) {
    235         case ReplacementMethodEnum.Median:
    236           replacementValue = rows.Select(r => originalValues[r]).Median();
    237           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
    238           break;
    239         case ReplacementMethodEnum.Average:
    240           replacementValue = rows.Select(r => originalValues[r]).Average();
    241           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
    242           break;
    243         case ReplacementMethodEnum.Shuffle:
    244           // new var has same empirical distribution but the relation to y is broken
    245           rand = new FastRandom(31415);
    246           // prepare a complete column for the dataset
    247           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
    248           // shuffle only the selected rows
    249           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
    250           int i = 0;
    251           // update column values
    252           foreach (var r in rows) {
    253             replacementValues[r] = shuffledValues[i++];
    254           }
    255           break;
    256         case ReplacementMethodEnum.Noise:
    257           var avg = rows.Select(r => originalValues[r]).Average();
    258           var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
    259           rand = new FastRandom(31415);
    260           // prepare a complete column for the dataset
    261           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
    262           // update column values
    263           foreach (var r in rows) {
    264             replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
    265           }
    266           break;
    267 
    268         default:
    269           throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
    270       }
    271 
    272       return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues);
    273     }
    274 
    275     private static IEnumerable<double> GetReplacedValuesForFactorVariables(
    276       IRegressionModel model,
    277       string variable,
    278       ModifiableDataset dataset,
    279       IEnumerable<int> rows,
    280       FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
    281       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    282       List<string> replacementValues;
    283       IRandom rand;
    284 
    285       switch (replacement) {
    286         case FactorReplacementMethodEnum.Mode:
    287           var mostCommonValue = rows.Select(r => originalValues[r])
    288             .GroupBy(v => v)
    289             .OrderByDescending(g => g.Count())
    290             .First().Key;
    291           replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
    292           break;
    293         case FactorReplacementMethodEnum.Shuffle:
    294           // new var has same empirical distribution but the relation to y is broken
    295           rand = new FastRandom(31415);
    296           // prepare a complete column for the dataset
    297           replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
    298           // shuffle only the selected rows
    299           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
    300           int i = 0;
    301           // update column values
    302           foreach (var r in rows) {
    303             replacementValues[r] = shuffledValues[i++];
    304           }
    305           break;
    306         default:
    307           throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
    308       }
    309 
    310       return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues);
    311     }
    312 
     277
     278    /// <summary>
     279    /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
     280    /// and changes the value of the model-variables back to the original ones.
     281    /// </summary>
     282    /// <param name="originalValues"></param>
     283    /// <param name="model"></param>
     284    /// <param name="variableName"></param>
     285    /// <param name="modifiableDataset"></param>
     286    /// <param name="rows"></param>
     287    /// <param name="replacementValues"></param>
     288    /// <returns></returns>
    313289    private static IEnumerable<double> GetReplacedEstimates(
    314290      IList originalValues,
    315291      IRegressionModel model,
    316       string variable,
    317       ModifiableDataset dataset,
     292      string variableName,
     293      ModifiableDataset modifiableDataset,
    318294      IEnumerable<int> rows,
    319295      IList replacementValues) {
    320       dataset.ReplaceVariable(variable, replacementValues);
     296      modifiableDataset.ReplaceVariable(variableName, replacementValues);
    321297      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    322       var estimates = model.GetEstimatedValues(dataset, rows).ToList();
    323       dataset.ReplaceVariable(variable, originalValues);
     298      var estimates = model.GetEstimatedValues(modifiableDataset, rows).ToList();
     299      modifiableDataset.ReplaceVariable(variableName, originalValues);
    324300
    325301      return estimates;
    326302    }
    327303
    328     public static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
    329       IEnumerator<double> firstEnumerator = originalValues.GetEnumerator();
     304    /// <summary>
     305    /// Calculates and returns the VariableImpact (calculated via Pearsons R²).
     306    /// </summary>
     307    /// <param name="targetValues">The actual values</param>
     308    /// <param name="estimatedValues">The calculated/replaced values</param>
     309    /// <param name="errorState"></param>
     310    /// <returns></returns>
     311    public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
     312      //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
     313      //as the code below does. But this way we can easily swap the calculator later on, so the user 
     314      //could choose a Calculator during runtime in future versions.
     315      IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator();
     316      IEnumerator<double> firstEnumerator = targetValues.GetEnumerator();
    330317      IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator();
    331       var calculator = new OnlinePearsonsRSquaredCalculator();
    332318
    333319      // always move forward both enumerators (do not use short-circuit evaluation!)
     
    349335    }
    350336
     337    /// <summary>
     338    /// Returns a collection of the row-indices for a given DataPartition (training or test)
     339    /// </summary>
     340    /// <param name="dataPartition"></param>
     341    /// <param name="problemData"></param>
     342    /// <returns></returns>
    351343    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IRegressionProblemData problemData) {
    352344      IEnumerable<int> rows;
Note: See TracChangeset for help on using the changeset viewer.