Ignore:
Timestamp:
08/01/18 14:01:08 (3 years ago)
Author:
fholzing
Message:

#2904: Streamlined the variableimpactcalculator code on both Regression and Classification. Taken over the regression-code for classification with some minor adaptations.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs

    r15674 r16036  
    2323
    2424using System;
     25using System.Collections;
    2526using System.Collections.Generic;
    2627using System.Linq;
     
    3637  [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")]
    3738  public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem {
     39    #region Parameters/Properties
    3840    public enum ReplacementMethodEnum {
    3941      Median,
     
    5456
    5557    private const string ReplacementParameterName = "Replacement Method";
     58    private const string FactorReplacementParameterName = "Factor Replacement Method";
    5659    private const string DataPartitionParameterName = "DataPartition";
    5760
    5861    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
    5962      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
     63    }
     64    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
     65      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    6066    }
    6167    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
     
    6773      set { ReplacementParameter.Value.Value = value; }
    6874    }
     75    public FactorReplacementMethodEnum FactorReplacementMethod {
     76      get { return FactorReplacementParameter.Value.Value; }
     77      set { FactorReplacementParameter.Value.Value = value; }
     78    }
    6979    public DataPartitionEnum DataPartition {
    7080      get { return DataPartitionParameter.Value.Value; }
    7181      set { DataPartitionParameter.Value.Value = value; }
    7282    }
    73 
    74 
     83    #endregion
     84
     85    #region Ctor/Cloner
    7586    [StorableConstructor]
    7687    private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    7788    private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner)
    7889      : base(original, cloner) { }
     90    public ClassificationSolutionVariableImpactsCalculator()
     91      : base() {
     92      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle)));
     93      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
     94    }
     95
    7996    public override IDeepCloneable Clone(Cloner cloner) {
    8097      return new ClassificationSolutionVariableImpactsCalculator(this, cloner);
    8198    }
    82 
    83     public ClassificationSolutionVariableImpactsCalculator()
    84       : base() {
    85       Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
    86       Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    87     }
     99    #endregion
    88100
    89101    //mkommend: annoying name clash with static method, open to better naming suggestions
    90102    public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) {
    91       return CalculateImpacts(solution, DataPartition, ReplacementMethod);
     103      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    92104    }
    93105
    94106    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    95107      IClassificationSolution solution,
    96       DataPartitionEnum data = DataPartitionEnum.Training,
    97       ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     108      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     109      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     110      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
     111      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition);
     112    }
     113
     114    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     115      IClassificationModel model,
     116      IClassificationProblemData problemData,
     117      IEnumerable<double> estimatedValues,
     118      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     119      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     120      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
     121      IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData);
     122      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
     123    }
     124
     125
     126    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     127     IClassificationModel model,
     128     IClassificationProblemData problemData,
     129     IEnumerable<double> estimatedClassValues,
     130     IEnumerable<int> rows,
     131     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     132     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
     133      //Calculate original quality-values (via calculator, default is Accuracy)
     134      OnlineCalculatorError error;
     135      IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     136      IEnumerable<double> estimatedValuesPartition = rows.Select(v => estimatedClassValues.ElementAt(v));
     137      var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
     138      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
     139
     140      var impacts = new Dictionary<string, double>();
     141      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
     142      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
     143      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
     144
     145      foreach (var inputVariable in allowedInputVariables) {
     146        impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
     147      }
     148
     149      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
     150    }
     151
     152
     153    public static double CalculateImpact(string variableName,
     154      IClassificationModel model,
     155      ModifiableDataset modifiableDataset,
     156      IEnumerable<int> rows,
     157      IEnumerable<double> targetValues,
     158      double originalValue,
     159      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    98160      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    99 
    100       var problemData = solution.ProblemData;
    101       var dataset = problemData.Dataset;
    102 
     161      double impact = 0;
     162      OnlineCalculatorError error;
     163      IRandom random;
     164      double replacementValue;
     165      IEnumerable<double> newEstimates = null;
     166      double newValue = 0;
     167
     168      if (modifiableDataset.VariableHasType<double>(variableName)) {
     169        #region NumericalVariable
     170        var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
     171        List<double> replacementValues;
     172        IRandom rand;
     173
     174        switch (replacementMethod) {
     175          case ReplacementMethodEnum.Median:
     176            replacementValue = rows.Select(r => originalValues[r]).Median();
     177            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     178            break;
     179          case ReplacementMethodEnum.Average:
     180            replacementValue = rows.Select(r => originalValues[r]).Average();
     181            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     182            break;
     183          case ReplacementMethodEnum.Shuffle:
     184            // new var has same empirical distribution but the relation to y is broken
     185            rand = new FastRandom(31415);
     186            // prepare a complete column for the dataset
     187            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     188            // shuffle only the selected rows
     189            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
     190            int i = 0;
     191            // update column values
     192            foreach (var r in rows) {
     193              replacementValues[r] = shuffledValues[i++];
     194            }
     195            break;
     196          case ReplacementMethodEnum.Noise:
     197            var avg = rows.Select(r => originalValues[r]).Average();
     198            var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
     199            rand = new FastRandom(31415);
     200            // prepare a complete column for the dataset
     201            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     202            // update column values
     203            foreach (var r in rows) {
     204              replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
     205            }
     206            break;
     207
     208          default:
     209            throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
     210        }
     211
     212        newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     213        newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     214        if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
     215
     216        impact = originalValue - newValue;
     217        #endregion
     218      } else if (modifiableDataset.VariableHasType<string>(variableName)) {
     219        #region FactorVariable
     220        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
     221        List<string> replacementValues;
     222
     223        switch (factorReplacementMethod) {
     224          case FactorReplacementMethodEnum.Best:
     225            // try replacing with all possible values and find the best replacement value
     226            var smallestImpact = double.PositiveInfinity;
     227            foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
     228              newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
     229              newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     230              if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     231
     232              var curImpact = originalValue - newValue;
     233              if (curImpact < smallestImpact) smallestImpact = curImpact;
     234            }
     235            impact = smallestImpact;
     236            break;
     237          case FactorReplacementMethodEnum.Mode:
     238            var mostCommonValue = rows.Select(r => originalValues[r])
     239              .GroupBy(v => v)
     240              .OrderByDescending(g => g.Count())
     241              .First().Key;
     242            replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
     243
     244            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     245            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     246            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     247
     248            impact = originalValue - newValue;
     249            break;
     250          case FactorReplacementMethodEnum.Shuffle:
     251            // new var has same empirical distribution but the relation to y is broken
     252            random = new FastRandom(31415);
     253            // prepare a complete column for the dataset
     254            replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
     255            // shuffle only the selected rows
     256            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     257            int i = 0;
     258            // update column values
     259            foreach (var r in rows) {
     260              replacementValues[r] = shuffledValues[i++];
     261            }
     262
     263            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     264            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     265            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     266
     267            impact = originalValue - newValue;
     268            break;
     269          default:
     270            throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
     271        }
     272        #endregion
     273      } else {
     274        throw new NotSupportedException("Variable not supported");
     275      }
     276
     277      return impact;
     278    }
     279
     280    /// <summary>
     281    /// Calculates and returns the VariableImpact (calculated via Accuracy).
     282    /// </summary>
     283    /// <param name="targetValues">The actual values</param>
     284    /// <param name="estimatedValues">The calculated/replaced values</param>
     285    /// <param name="errorState"></param>
     286    /// <returns></returns>
     287    public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
     288      //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
     289      //as the code below does. But this way we can easily swap the calculator later on, so the user 
     290      //could choose a Calculator during runtime in future versions.
     291      IOnlineCalculator calculator = new OnlineAccuracyCalculator();
     292      IEnumerator<double> firstEnumerator = targetValues.GetEnumerator();
     293      IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator();
     294
     295      // always move forward both enumerators (do not use short-circuit evaluation!)
     296      while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) {
     297        double original = firstEnumerator.Current;
     298        double estimated = secondEnumerator.Current;
     299        calculator.Add(original, estimated);
     300        if (calculator.ErrorState != OnlineCalculatorError.None) break;
     301      }
     302
     303      // check if both enumerators are at the end to make sure both enumerations have the same length
     304      if (calculator.ErrorState == OnlineCalculatorError.None &&
     305           (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) {
     306        throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
     307      } else {
     308        errorState = calculator.ErrorState;
     309        return calculator.Value;
     310      }
     311    }
     312
     313    /// <summary>
     314    /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
     315    /// and changes the value of the model-variables back to the original ones.
     316    /// </summary>
     317    /// <param name="originalValues"></param>
     318    /// <param name="model"></param>
     319    /// <param name="variableName"></param>
     320    /// <param name="modifiableDataset"></param>
     321    /// <param name="rows"></param>
     322    /// <param name="replacementValues"></param>
     323    /// <returns></returns>
     324    private static IEnumerable<double> GetReplacedEstimates(
     325     IList originalValues,
     326     IClassificationModel model,
     327     string variableName,
     328     ModifiableDataset modifiableDataset,
     329     IEnumerable<int> rows,
     330     IList replacementValues) {
     331      modifiableDataset.ReplaceVariable(variableName, replacementValues);
     332
     333      var discModel = model as IDiscriminantFunctionClassificationModel;
     334      if (discModel != null) {
     335        var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
     336        discModel.RecalculateModelParameters(problemData, rows);
     337      }
     338
     339      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
     340      var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
     341      modifiableDataset.ReplaceVariable(variableName, originalValues);
     342
     343      return estimates;
     344    }
     345
     346
     347    /// <summary>
     348    /// Returns a collection of the row-indices for a given DataPartition (training or test)
     349    /// </summary>
     350    /// <param name="dataPartition"></param>
     351    /// <param name="problemData"></param>
     352    /// <returns></returns>
     353    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) {
    103354      IEnumerable<int> rows;
    104       IEnumerable<double> targetValues;
    105       double originalAccuracy;
    106 
    107       OnlineCalculatorError error;
    108 
    109       switch (data) {
     355
     356      switch (dataPartition) {
    110357        case DataPartitionEnum.All:
    111358          rows = problemData.AllIndices;
    112           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.AllIndices).ToList();
    113           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedClassValues, out error);
    114           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
     359          break;
     360        case DataPartitionEnum.Test:
     361          rows = problemData.TestIndices;
    115362          break;
    116363        case DataPartitionEnum.Training:
    117364          rows = problemData.TrainingIndices;
    118           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList();
    119           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTrainingClassValues, out error);
    120           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
    121           break;
    122         case DataPartitionEnum.Test:
    123           rows = problemData.TestIndices;
    124           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToList();
    125           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTestClassValues, out error);
    126           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
    127           break;
    128         default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
    129       }
    130 
    131       var impacts = new Dictionary<string, double>();
    132       var modifiableDataset = ((Dataset)dataset).ToModifiable();
    133 
    134       var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
    135       var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
    136 
    137       // calculate impacts for double variables
    138       foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) {
    139         var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);
    140         var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    141         if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
    142 
    143         impacts[inputVariable] = originalAccuracy - newAccuracy;
    144       }
    145 
    146       // calculate impacts for string variables
    147       foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<string>)) {
    148         if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
    149           // try replacing with all possible values and find the best replacement value
    150           var smallestImpact = double.PositiveInfinity;
    151           foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) {
    152             var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
    153               Enumerable.Repeat(repl, dataset.Rows));
    154             var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    155             if (error != OnlineCalculatorError.None)
    156               throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");
    157 
    158             var impact = originalAccuracy - newAccuracy;
    159             if (impact < smallestImpact) smallestImpact = impact;
    160           }
    161           impacts[inputVariable] = smallestImpact;
    162         } else {
    163           // for replacement methods shuffle and mode
    164           // calculate impacts for factor variables
    165 
    166           var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
    167             factorReplacementMethod);
    168           var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    169           if (error != OnlineCalculatorError.None)
    170             throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");
    171 
    172           impacts[inputVariable] = originalAccuracy - newAccuracy;
    173         }
    174       } // foreach
    175       return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    176     }
    177 
    178     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
    179       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    180       double replacementValue;
    181       List<double> replacementValues;
    182       IRandom rand;
    183 
    184       switch (replacement) {
    185         case ReplacementMethodEnum.Median:
    186           replacementValue = rows.Select(r => originalValues[r]).Median();
    187           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
    188           break;
    189         case ReplacementMethodEnum.Average:
    190           replacementValue = rows.Select(r => originalValues[r]).Average();
    191           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
    192           break;
    193         case ReplacementMethodEnum.Shuffle:
    194           // new var has same empirical distribution but the relation to y is broken
    195           rand = new FastRandom(31415);
    196           // prepare a complete column for the dataset
    197           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
    198           // shuffle only the selected rows
    199           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
    200           int i = 0;
    201           // update column values
    202           foreach (var r in rows) {
    203             replacementValues[r] = shuffledValues[i++];
    204           }
    205           break;
    206         case ReplacementMethodEnum.Noise:
    207           var avg = rows.Select(r => originalValues[r]).Average();
    208           var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
    209           rand = new FastRandom(31415);
    210           // prepare a complete column for the dataset
    211           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
    212           // update column values
    213           foreach (var r in rows) {
    214             replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
    215           }
    216           break;
    217 
    218         default:
    219           throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
    220       }
    221 
    222       return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    223     }
    224 
    225     private static IEnumerable<double> EvaluateModelWithReplacedVariable(
    226       IClassificationModel model, string variable, ModifiableDataset dataset,
    227       IEnumerable<int> rows,
    228       FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
    229       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    230       List<string> replacementValues;
    231       IRandom rand;
    232 
    233       switch (replacement) {
    234         case FactorReplacementMethodEnum.Mode:
    235           var mostCommonValue = rows.Select(r => originalValues[r])
    236             .GroupBy(v => v)
    237             .OrderByDescending(g => g.Count())
    238             .First().Key;
    239           replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
    240           break;
    241         case FactorReplacementMethodEnum.Shuffle:
    242           // new var has same empirical distribution but the relation to y is broken
    243           rand = new FastRandom(31415);
    244           // prepare a complete column for the dataset
    245           replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
    246           // shuffle only the selected rows
    247           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
    248           int i = 0;
    249           // update column values
    250           foreach (var r in rows) {
    251             replacementValues[r] = shuffledValues[i++];
    252           }
    253365          break;
    254366        default:
    255           throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
    256       }
    257 
    258       return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    259     }
    260 
    261     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
    262       ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
    263       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    264       dataset.ReplaceVariable(variable, replacementValues.ToList());
    265       //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    266       var estimates = model.GetEstimatedClassValues(dataset, rows).ToList();
    267       dataset.ReplaceVariable(variable, originalValues);
    268 
    269       return estimates;
    270     }
    271     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
    272       ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
    273       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    274       dataset.ReplaceVariable(variable, replacementValues.ToList());
    275       //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    276       var estimates = model.GetEstimatedClassValues(dataset, rows).ToList();
    277       dataset.ReplaceVariable(variable, originalValues);
    278 
    279       return estimates;
    280     }
     367          throw new NotSupportedException("DataPartition not supported");
     368      }
     369
     370      return rows;
     371    }
     372
    281373  }
    282374}
Note: See TracChangeset for help on using the changeset viewer.