Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/06/18 17:35:11 (6 years ago)
Author:
fholzing
Message:

#2904: Also applied changes from regression to classification

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs

    r16041 r16055  
    9090    public ClassificationSolutionVariableImpactsCalculator()
    9191      : base() {
    92       Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
     92      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle)));
    9393      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
    9494      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
     
    110110      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    111111      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
    112       return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition);
    113     }
    114 
    115     public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    116       IClassificationModel model,
    117       IClassificationProblemData problemData,
    118       IEnumerable<double> estimatedValues,
    119       ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    120       FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    121       DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
    122       IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData);
    123       return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
     112
     113      IEnumerable<int> rows = GetPartitionRows(dataPartition, solution.ProblemData);
     114      IEnumerable<double> estimatedClassValues = solution.GetEstimatedClassValues(rows);
     115      return CalculateImpacts(solution.Model, solution.ProblemData, estimatedClassValues, rows, replacementMethod, factorReplacementMethod);
    124116    }
    125117
     
    131123     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    132124     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    133       //Calculate original quality-values (via calculator, default is Accuracy)   
    134       OnlineCalculatorError error;
    135       IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
    136       IEnumerable<double> estimatedValuesPartition = rows.Select(v => estimatedClassValues.ElementAt(v));
    137       var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
    138       if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
     125
     126      //fholzing: try and catch in case a different dataset is loaded, otherwise statement is neglectable
     127      var missingVariables = model.VariablesUsedForPrediction.Except(problemData.Dataset.VariableNames);
     128      if (missingVariables.Any()) {
     129        throw new InvalidOperationException(string.Format("Can not calculate variable impacts, because the model uses inputs missing in the dataset ({0})", string.Join(", ", missingVariables)));
     130      }
     131      IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     132      var originalQuality = CalculateQuality(targetValues, estimatedClassValues);
    139133
    140134      var impacts = new Dictionary<string, double>();
    141135      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
    142       var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
    143136      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
    144137
    145       foreach (var inputVariable in allowedInputVariables) {
    146         if (model.VariablesUsedForPrediction.Contains(inputVariable)) {
    147           impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
    148         } else {
    149           impacts[inputVariable] = 0;
    150         }
    151       }
    152 
    153       return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
     138      foreach (var inputVariable in inputvariables) {
     139        impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData, modifiableDataset, rows, replacementMethod, factorReplacementMethod, targetValues, originalQuality);
     140      }
     141
     142      return impacts.Select(i => Tuple.Create(i.Key, i.Value));
    154143    }
    155144
    156145    public static double CalculateImpact(string variableName,
    157146      IClassificationModel model,
     147      IClassificationProblemData problemData,
    158148      ModifiableDataset modifiableDataset,
    159149      IEnumerable<int> rows,
     150      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     151      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     152      IEnumerable<double> targetValues = null,
     153      double quality = double.NaN) {
     154
     155      if (!model.VariablesUsedForPrediction.Contains(variableName)) { return 0.0; }
     156      if (!problemData.Dataset.VariableNames.Contains(variableName)) {
     157        throw new InvalidOperationException(string.Format("Can not calculate variable impact, because the model uses inputs missing in the dataset ({0})", variableName));
     158      }
     159
     160      if (targetValues == null) {
     161        targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     162      }
     163      if (quality == double.NaN) {
     164        quality = CalculateQuality(model.GetEstimatedClassValues(modifiableDataset, rows), targetValues);
     165      }
     166
     167      IList originalValues = null;
     168      IList replacementValues = GetReplacementValues(modifiableDataset, variableName, model, rows, targetValues, out originalValues, replacementMethod, factorReplacementMethod);
     169
     170      double newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, replacementValues, targetValues);
     171      double impact = quality - newValue;
     172
     173      return impact;
     174    }
     175
     176    private static IList GetReplacementValues(ModifiableDataset modifiableDataset,
     177      string variableName,
     178      IClassificationModel model,
     179      IEnumerable<int> rows,
    160180      IEnumerable<double> targetValues,
    161       double originalValue,
     181      out IList originalValues,
    162182      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    163183      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    164       double impact = 0;
    165       OnlineCalculatorError error;
    166       IRandom random;
    167       double replacementValue;
    168       IEnumerable<double> newEstimates = null;
    169       double newValue = 0;
    170 
     184
     185      IList replacementValues = null;
    171186      if (modifiableDataset.VariableHasType<double>(variableName)) {
    172         #region NumericalVariable
    173         var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
    174         List<double> replacementValues;
    175 
    176         switch (replacementMethod) {
    177           case ReplacementMethodEnum.Median:
    178             replacementValue = rows.Select(r => originalValues[r]).Median();
    179             replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
    180             break;
    181           case ReplacementMethodEnum.Average:
    182             replacementValue = rows.Select(r => originalValues[r]).Average();
    183             replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
    184             break;
    185           case ReplacementMethodEnum.Shuffle:
    186             // new var has same empirical distribution but the relation to y is broken
    187             random = new FastRandom(31415);
    188             // prepare a complete column for the dataset
    189             replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
    190             // shuffle only the selected rows
    191             var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
    192             int i = 0;
    193             // update column values
    194             foreach (var r in rows) {
    195               replacementValues[r] = shuffledValues[i++];
    196             }
    197             break;
    198           case ReplacementMethodEnum.Noise:
    199             var avg = rows.Select(r => originalValues[r]).Average();
    200             var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
    201             random = new FastRandom(31415);
    202             // prepare a complete column for the dataset
    203             replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
    204             // update column values
    205             foreach (var r in rows) {
    206               replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
    207             }
    208             break;
    209 
    210           default:
    211             throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
    212         }
    213 
    214         newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
    215         newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    216         if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
    217 
    218         impact = originalValue - newValue;
    219         #endregion
     187        originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
     188        replacementValues = GetReplacementValuesForDouble(modifiableDataset, rows, (List<double>)originalValues, replacementMethod);
    220189      } else if (modifiableDataset.VariableHasType<string>(variableName)) {
    221         #region FactorVariable
    222         var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
    223         List<string> replacementValues;
    224 
    225         switch (factorReplacementMethod) {
    226           case FactorReplacementMethodEnum.Best:
    227             // try replacing with all possible values and find the best replacement value
    228             var smallestImpact = double.PositiveInfinity;
    229             foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
    230               newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
    231               newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    232               if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    233 
    234               var curImpact = originalValue - newValue;
    235               if (curImpact < smallestImpact) smallestImpact = curImpact;
    236             }
    237             impact = smallestImpact;
    238             break;
    239           case FactorReplacementMethodEnum.Mode:
    240             var mostCommonValue = rows.Select(r => originalValues[r])
    241               .GroupBy(v => v)
    242               .OrderByDescending(g => g.Count())
    243               .First().Key;
    244             replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
    245 
    246             newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
    247             newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    248             if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    249 
    250             impact = originalValue - newValue;
    251             break;
    252           case FactorReplacementMethodEnum.Shuffle:
    253             // new var has same empirical distribution but the relation to y is broken
    254             random = new FastRandom(31415);
    255             // prepare a complete column for the dataset
    256             replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
    257             // shuffle only the selected rows
    258             var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
    259             int i = 0;
    260             // update column values
    261             foreach (var r in rows) {
    262               replacementValues[r] = shuffledValues[i++];
    263             }
    264 
    265             newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
    266             newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    267             if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    268 
    269             impact = originalValue - newValue;
    270             break;
    271           default:
    272             throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
    273         }
    274         #endregion
     190        originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
     191        replacementValues = GetReplacementValuesForString(model, modifiableDataset, variableName, rows, originalValues, targetValues, factorReplacementMethod);
    275192      } else {
    276193        throw new NotSupportedException("Variable not supported");
    277194      }
    278195
    279       return impact;
    280     }
    281 
    282     /// <summary>
    283     /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
    284     /// and changes the value of the model-variables back to the original ones.
    285     /// </summary>
    286     /// <param name="originalValues"></param>
    287     /// <param name="model"></param>
    288     /// <param name="variableName"></param>
    289     /// <param name="modifiableDataset"></param>
    290     /// <param name="rows"></param>
    291     /// <param name="replacementValues"></param>
    292     /// <returns></returns>
    293     private static IEnumerable<double> GetReplacedEstimates(
     196      return replacementValues;
     197    }
     198
     199    private static IList GetReplacementValuesForDouble(ModifiableDataset modifiableDataset,
     200      IEnumerable<int> rows,
     201      List<double> originalValues,
     202      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle) {
     203
     204      IRandom random = new FastRandom(31415);
     205      List<double> replacementValues;
     206      double replacementValue;
     207
     208      switch (replacementMethod) {
     209        case ReplacementMethodEnum.Median:
     210          replacementValue = rows.Select(r => originalValues[r]).Median();
     211          replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     212          break;
     213        case ReplacementMethodEnum.Average:
     214          replacementValue = rows.Select(r => originalValues[r]).Average();
     215          replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     216          break;
     217        case ReplacementMethodEnum.Shuffle:
     218          // new var has same empirical distribution but the relation to y is broken
     219          // prepare a complete column for the dataset
     220          replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     221          // shuffle only the selected rows
     222          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     223          int i = 0;
     224          // update column values
     225          foreach (var r in rows) {
     226            replacementValues[r] = shuffledValues[i++];
     227          }
     228          break;
     229        case ReplacementMethodEnum.Noise:
     230          var avg = rows.Select(r => originalValues[r]).Average();
     231          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
     232          // prepare a complete column for the dataset
     233          replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     234          // update column values
     235          foreach (var r in rows) {
     236            replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
     237          }
     238          break;
     239
     240        default:
     241          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
     242      }
     243
     244      return replacementValues;
     245    }
     246
     247    private static IList GetReplacementValuesForString(IClassificationModel model,
     248      ModifiableDataset modifiableDataset,
     249      string variableName,
     250      IEnumerable<int> rows,
    294251      IList originalValues,
     252      IEnumerable<double> targetValues,
     253      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Shuffle) {
     254
     255      IList replacementValues = null;
     256      IRandom random = new FastRandom(31415);
     257
     258      switch (factorReplacementMethod) {
     259        case FactorReplacementMethodEnum.Best:
     260          // try replacing with all possible values and find the best replacement value
     261          var bestQuality = double.NegativeInfinity;
     262          foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
     263            List<string> curReplacementValues = Enumerable.Repeat(repl, modifiableDataset.Rows).ToList();
     264            //fholzing: this result could be used later on (theoretically), but is neglected for better readability/method consistency
     265            var newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, curReplacementValues, targetValues);
     266            var curQuality = newValue;
     267
     268            if (curQuality > bestQuality) {
     269              bestQuality = curQuality;
     270              replacementValues = curReplacementValues;
     271            }
     272          }
     273          break;
     274        case FactorReplacementMethodEnum.Mode:
     275          var mostCommonValue = rows.Select(r => originalValues[r])
     276            .GroupBy(v => v)
     277            .OrderByDescending(g => g.Count())
     278            .First().Key;
     279          replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
     280          break;
     281        case FactorReplacementMethodEnum.Shuffle:
     282          // new var has same empirical distribution but the relation to y is broken
     283          // prepare a complete column for the dataset
     284          replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
     285          // shuffle only the selected rows
     286          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     287          int i = 0;
     288          // update column values
     289          foreach (var r in rows) {
     290            replacementValues[r] = shuffledValues[i++];
     291          }
     292          break;
     293        default:
     294          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
     295      }
     296
     297      return replacementValues;
     298    }
     299
     300    private static double CalculateQualityForReplacement(
    295301      IClassificationModel model,
     302      ModifiableDataset modifiableDataset,
    296303      string variableName,
    297       ModifiableDataset modifiableDataset,
    298       IEnumerable<int> rows,
    299       IList replacementValues) {
     304      IList originalValues,
     305      IEnumerable<int> rows,
     306      IList replacementValues,
     307      IEnumerable<double> targetValues) {
     308
    300309      modifiableDataset.ReplaceVariable(variableName, replacementValues);
    301 
    302310      var discModel = model as IDiscriminantFunctionClassificationModel;
    303311      if (discModel != null) {
     
    308316      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    309317      var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
     318      var ret = CalculateQuality(targetValues, estimates);
    310319      modifiableDataset.ReplaceVariable(variableName, originalValues);
    311320
    312       return estimates;
    313     }
    314 
    315     /// <summary>
    316     /// Calculates and returns the VariableImpact (calculated via Accuracy).
    317     /// </summary>
    318     /// <param name="targetValues">The actual values</param>
    319     /// <param name="estimatedValues">The calculated/replaced values</param>
    320     /// <param name="errorState"></param>
    321     /// <returns></returns>
    322     public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
    323       //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
    324       //as the code below does. But this way we can easily swap the calculator later on, so the user 
    325       //could choose a Calculator during runtime in future versions.
    326       IOnlineCalculator calculator = new OnlineAccuracyCalculator();
    327       IEnumerator<double> firstEnumerator = targetValues.GetEnumerator();
    328       IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator();
    329 
    330       // always move forward both enumerators (do not use short-circuit evaluation!)
    331       while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) {
    332         double original = firstEnumerator.Current;
    333         double estimated = secondEnumerator.Current;
    334         calculator.Add(original, estimated);
    335         if (calculator.ErrorState != OnlineCalculatorError.None) break;
    336       }
    337 
    338       // check if both enumerators are at the end to make sure both enumerations have the same length
    339       if (calculator.ErrorState == OnlineCalculatorError.None &&
    340            (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) {
    341         throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
    342       } else {
    343         errorState = calculator.ErrorState;
    344         return calculator.Value;
    345       }
    346     }
    347 
    348     /// <summary>
    349     /// Returns a collection of the row-indices for a given DataPartition (training or test)
    350     /// </summary>
    351     /// <param name="dataPartition"></param>
    352     /// <param name="problemData"></param>
    353     /// <returns></returns>
     321      return ret;
     322    }
     323
     324    public static double CalculateQuality(IEnumerable<double> targetValues, IEnumerable<double> estimatedClassValues) {
     325      OnlineCalculatorError errorState;
     326      var ret = OnlineAccuracyCalculator.Calculate(targetValues, estimatedClassValues, out errorState);
     327      if (errorState != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
     328      return ret;
     329    }
     330
    354331    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) {
    355332      IEnumerable<int> rows;
Note: See TracChangeset for help on using the changeset viewer.