Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/17/19 14:36:59 (5 years ago)
Author:
chaider
Message:

#2971 merged DataAnalysis.Views from trunk to branch

Location:
branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis

  • branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis/3.4

    • Property svn:mergeinfo set to (toggle deleted branches)
      /branches/2839_HiveProjectManagement/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible
      /branches/2915-AbsoluteSymbol/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible
      /branches/2947_ConfigurableIndexedDataTable/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible
      /branches/2965_CancelablePersistence/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible
      /stable/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible
      /trunk/HeuristicLab.Problems.DataAnalysis/3.4 merged eligible
      /branches/2892_LR-prediction-intervals/HeuristicLab.Problems.DataAnalysis/3.4 15743-16388
      /branches/2904_CalculateImpacts/3.4 15808-16421
      /branches/2966_interval_calculation/HeuristicLab.Problems.DataAnalysis/3.4 16320-16406
      /branches/Async/HeuristicLab.Problems.DataAnalysis/3.4 13329-15286
      /branches/Classification-Extensions/HeuristicLab.Problems.DataAnalysis/3.4 11606-11761
      /branches/ClassificationModelComparison/HeuristicLab.Problems.DataAnalysis/3.4 9073-13099
      /branches/CloningRefactoring/HeuristicLab.Problems.DataAnalysis/3.4 4656-4721
      /branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4 5471-5808
      /branches/DataAnalysis SolutionEnsembles/HeuristicLab.Problems.DataAnalysis/3.4 5815-6180
      /branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis/3.4 4220,4226,4236-4238,4389,4458-4459,4462,4464
      /branches/DataAnalysisCSVImport/HeuristicLab.Problems.DataAnalysis/3.4 8713-8875
      /branches/DataPreprocessing/HeuristicLab.Problems.DataAnalysis/3.4 10085-11101
      /branches/DatasetFeatureCorrelation/HeuristicLab.Problems.DataAnalysis/3.4 8035-8538
      /branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4 6284-6795
      /branches/GP.Symbols (TimeLag, Diff, Integral)/HeuristicLab.Problems.DataAnalysis/3.4 5060
      /branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Problems.DataAnalysis/3.4 11570-12508
      /branches/HeuristicLab.Problems.Orienteering/HeuristicLab.Problems.DataAnalysis/3.4 11130-12721
      /branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Problems.DataAnalysis/3.4 13819-14091
      /branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4 7098-8789
      /branches/LogResidualEvaluator/HeuristicLab.Problems.DataAnalysis/3.4 10202-10483
      /branches/NET40/sources/HeuristicLab.Problems.DataAnalysis/3.4 5138-5162
      /branches/ParallelEngine/HeuristicLab.Problems.DataAnalysis/3.4 5175-5192
      /branches/ProblemInstancesRegressionAndClassification/HeuristicLab.Problems.DataAnalysis/3.4 7570-7810
      /branches/QAPAlgorithms/HeuristicLab.Problems.DataAnalysis/3.4 6350-6627
      /branches/Restructure trunk solution/HeuristicLab.Problems.DataAnalysis/3.4 6828
      /branches/SimplifierViewsProgress/HeuristicLab.Problems.DataAnalysis/3.4 15318-15370
      /branches/SpectralKernelForGaussianProcesses/HeuristicLab.Problems.DataAnalysis/3.4 10204-10479
      /branches/Trunk/HeuristicLab.Problems.DataAnalysis/3.4 6829-6865
      /branches/histogram/HeuristicLab.Problems.DataAnalysis/3.4 5959-6341
      /branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis/3.4 14232-14825
  • branches/2971_named_intervals/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs

    r15871 r16536  
    2323
    2424using System;
     25using System.Collections;
    2526using System.Collections.Generic;
    2627using System.Linq;
     
    3637  [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")]
    3738  public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem {
     39    #region Parameters/Properties
    3840    public enum ReplacementMethodEnum {
    3941      Median,
     
    5456
    5557    private const string ReplacementParameterName = "Replacement Method";
     58    private const string FactorReplacementParameterName = "Factor Replacement Method";
    5659    private const string DataPartitionParameterName = "DataPartition";
    5760
    5861    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
    5962      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
     63    }
     64    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
     65      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    6066    }
    6167    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
     
    6773      set { ReplacementParameter.Value.Value = value; }
    6874    }
     75    public FactorReplacementMethodEnum FactorReplacementMethod {
     76      get { return FactorReplacementParameter.Value.Value; }
     77      set { FactorReplacementParameter.Value.Value = value; }
     78    }
    6979    public DataPartitionEnum DataPartition {
    7080      get { return DataPartitionParameter.Value.Value; }
    7181      set { DataPartitionParameter.Value.Value = value; }
    7282    }
    73 
    74 
     83    #endregion
     84
     85    #region Ctor/Cloner
    7586    [StorableConstructor]
    7687    private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    7788    private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner)
    7889      : base(original, cloner) { }
     90    public ClassificationSolutionVariableImpactsCalculator()
     91      : base() {
     92      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle)));
     93      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
     94      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
     95    }
     96
    7997    public override IDeepCloneable Clone(Cloner cloner) {
    8098      return new ClassificationSolutionVariableImpactsCalculator(this, cloner);
    8199    }
    82 
    83     public ClassificationSolutionVariableImpactsCalculator()
    84       : base() {
    85       Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
    86       Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    87     }
     100    #endregion
    88101
    89102    //mkommend: annoying name clash with static method, open to better naming suggestions
    90103    public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) {
    91       return CalculateImpacts(solution, DataPartition, ReplacementMethod);
     104      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    92105    }
    93106
    94107    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    95108      IClassificationSolution solution,
    96       DataPartitionEnum data = DataPartitionEnum.Training,
    97       ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     109      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     110      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     111      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
     112
     113      IEnumerable<int> rows = GetPartitionRows(dataPartition, solution.ProblemData);
     114      IEnumerable<double> estimatedClassValues = solution.GetEstimatedClassValues(rows);
     115      var model = (IClassificationModel)solution.Model.Clone(); //mkommend: clone of model is necessary, because the thresholds for IDiscriminantClassificationModels are updated
     116
     117      return CalculateImpacts(model, solution.ProblemData, estimatedClassValues, rows, replacementMethod, factorReplacementMethod);
     118    }
     119
     120    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     121     IClassificationModel model,
     122     IClassificationProblemData problemData,
     123     IEnumerable<double> estimatedClassValues,
     124     IEnumerable<int> rows,
     125     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     126     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
     127
     128      //fholzing: try and catch in case a different dataset is loaded, otherwise statement is neglectable
     129      var missingVariables = model.VariablesUsedForPrediction.Except(problemData.Dataset.VariableNames);
     130      if (missingVariables.Any()) {
     131        throw new InvalidOperationException(string.Format("Can not calculate variable impacts, because the model uses inputs missing in the dataset ({0})", string.Join(", ", missingVariables)));
     132      }
     133      IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     134      var originalQuality = CalculateQuality(targetValues, estimatedClassValues);
     135
     136      var impacts = new Dictionary<string, double>();
     137      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
     138      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
     139
     140      foreach (var inputVariable in inputvariables) {
     141        impacts[inputVariable] = CalculateImpact(inputVariable, model, problemData, modifiableDataset, rows, replacementMethod, factorReplacementMethod, targetValues, originalQuality);
     142      }
     143
     144      return impacts.Select(i => Tuple.Create(i.Key, i.Value));
     145    }
     146
     147    public static double CalculateImpact(string variableName,
     148      IClassificationModel model,
     149      IClassificationProblemData problemData,
     150      ModifiableDataset modifiableDataset,
     151      IEnumerable<int> rows,
     152      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     153      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     154      IEnumerable<double> targetValues = null,
     155      double quality = double.NaN) {
     156
     157      if (!model.VariablesUsedForPrediction.Contains(variableName)) { return 0.0; }
     158      if (!problemData.Dataset.VariableNames.Contains(variableName)) {
     159        throw new InvalidOperationException(string.Format("Can not calculate variable impact, because the model uses inputs missing in the dataset ({0})", variableName));
     160      }
     161
     162      if (targetValues == null) {
     163        targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     164      }
     165      if (quality == double.NaN) {
     166        quality = CalculateQuality(model.GetEstimatedClassValues(modifiableDataset, rows), targetValues);
     167      }
     168
     169      IList originalValues = null;
     170      IList replacementValues = GetReplacementValues(modifiableDataset, variableName, model, rows, targetValues, out originalValues, replacementMethod, factorReplacementMethod);
     171
     172      double newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, replacementValues, targetValues);
     173      double impact = quality - newValue;
     174
     175      return impact;
     176    }
     177
     178    private static IList GetReplacementValues(ModifiableDataset modifiableDataset,
     179      string variableName,
     180      IClassificationModel model,
     181      IEnumerable<int> rows,
     182      IEnumerable<double> targetValues,
     183      out IList originalValues,
     184      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    98185      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    99186
    100       var problemData = solution.ProblemData;
    101       var dataset = problemData.Dataset;
    102       var model = (IClassificationModel)solution.Model.Clone(); //mkommend: clone of model is necessary, because the thresholds for IDiscriminantClassificationModels are updated
    103 
    104       IEnumerable<int> rows;
    105       IEnumerable<double> targetValues;
    106       double originalAccuracy;
    107 
    108       OnlineCalculatorError error;
    109 
    110       switch (data) {
    111         case DataPartitionEnum.All:
    112           rows = problemData.AllIndices;
    113           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.AllIndices).ToList();
    114           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedClassValues, out error);
    115           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
    116           break;
    117         case DataPartitionEnum.Training:
    118           rows = problemData.TrainingIndices;
    119           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList();
    120           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTrainingClassValues, out error);
    121           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
    122           break;
    123         case DataPartitionEnum.Test:
    124           rows = problemData.TestIndices;
    125           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToList();
    126           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTestClassValues, out error);
    127           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
    128           break;
    129         default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
    130       }
    131 
    132       var impacts = new Dictionary<string, double>();
    133       var modifiableDataset = ((Dataset)dataset).ToModifiable();
    134 
    135       var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
    136       var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
    137 
    138       // calculate impacts for double variables
    139       foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) {
    140         var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows, replacementMethod);
    141         var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    142         if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
    143 
    144         impacts[inputVariable] = originalAccuracy - newAccuracy;
    145       }
    146 
    147       // calculate impacts for string variables
    148       foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<string>)) {
    149         if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
    150           // try replacing with all possible values and find the best replacement value
    151           var smallestImpact = double.PositiveInfinity;
    152           foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) {
    153             var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows,
    154               Enumerable.Repeat(repl, dataset.Rows));
    155             var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    156             if (error != OnlineCalculatorError.None)
    157               throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");
    158 
    159             var impact = originalAccuracy - newAccuracy;
    160             if (impact < smallestImpact) smallestImpact = impact;
    161           }
    162           impacts[inputVariable] = smallestImpact;
    163         } else {
    164           // for replacement methods shuffle and mode
    165           // calculate impacts for factor variables
    166 
    167           var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows,
    168             factorReplacementMethod);
    169           var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    170           if (error != OnlineCalculatorError.None)
    171             throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");
    172 
    173           impacts[inputVariable] = originalAccuracy - newAccuracy;
    174         }
    175       } // foreach
    176       return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    177     }
    178 
    179     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
    180       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
     187      IList replacementValues = null;
     188      if (modifiableDataset.VariableHasType<double>(variableName)) {
     189        originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
     190        replacementValues = GetReplacementValuesForDouble(modifiableDataset, rows, (List<double>)originalValues, replacementMethod);
     191      } else if (modifiableDataset.VariableHasType<string>(variableName)) {
     192        originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
     193        replacementValues = GetReplacementValuesForString(model, modifiableDataset, variableName, rows, (List<string>)originalValues, targetValues, factorReplacementMethod);
     194      } else {
     195        throw new NotSupportedException("Variable not supported");
     196      }
     197
     198      return replacementValues;
     199    }
     200
     201    private static IList GetReplacementValuesForDouble(ModifiableDataset modifiableDataset,
     202      IEnumerable<int> rows,
     203      List<double> originalValues,
     204      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle) {
     205
     206      IRandom random = new FastRandom(31415);
     207      List<double> replacementValues;
    181208      double replacementValue;
    182       List<double> replacementValues;
    183       IRandom rand;
    184 
    185       switch (replacement) {
     209
     210      switch (replacementMethod) {
    186211        case ReplacementMethodEnum.Median:
    187212          replacementValue = rows.Select(r => originalValues[r]).Median();
    188           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
     213          replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
    189214          break;
    190215        case ReplacementMethodEnum.Average:
    191216          replacementValue = rows.Select(r => originalValues[r]).Average();
    192           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
     217          replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
    193218          break;
    194219        case ReplacementMethodEnum.Shuffle:
    195220          // new var has same empirical distribution but the relation to y is broken
    196           rand = new FastRandom(31415);
    197221          // prepare a complete column for the dataset
    198           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
     222          replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
    199223          // shuffle only the selected rows
    200           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
     224          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
    201225          int i = 0;
    202226          // update column values
     
    208232          var avg = rows.Select(r => originalValues[r]).Average();
    209233          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
    210           rand = new FastRandom(31415);
    211234          // prepare a complete column for the dataset
    212           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
     235          replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
    213236          // update column values
    214237          foreach (var r in rows) {
    215             replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
     238            replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
    216239          }
    217240          break;
    218241
    219242        default:
    220           throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
    221       }
    222 
    223       return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    224     }
    225 
    226     private static IEnumerable<double> EvaluateModelWithReplacedVariable(
    227       IClassificationModel model, string variable, ModifiableDataset dataset,
    228       IEnumerable<int> rows,
    229       FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
    230       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    231       List<string> replacementValues;
    232       IRandom rand;
    233 
    234       switch (replacement) {
     243          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
     244      }
     245
     246      return replacementValues;
     247    }
     248
     249    private static IList GetReplacementValuesForString(IClassificationModel model,
     250      ModifiableDataset modifiableDataset,
     251      string variableName,
     252      IEnumerable<int> rows,
     253      List<string> originalValues,
     254      IEnumerable<double> targetValues,
     255      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Shuffle) {
     256
     257      List<string> replacementValues = null;
     258      IRandom random = new FastRandom(31415);
     259
     260      switch (factorReplacementMethod) {
     261        case FactorReplacementMethodEnum.Best:
     262          // try replacing with all possible values and find the best replacement value
     263          var bestQuality = double.NegativeInfinity;
     264          foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
     265            List<string> curReplacementValues = Enumerable.Repeat(repl, modifiableDataset.Rows).ToList();
     266            //fholzing: this result could be used later on (theoretically), but is neglected for better readability/method consistency
     267            var newValue = CalculateQualityForReplacement(model, modifiableDataset, variableName, originalValues, rows, curReplacementValues, targetValues);
     268            var curQuality = newValue;
     269
     270            if (curQuality > bestQuality) {
     271              bestQuality = curQuality;
     272              replacementValues = curReplacementValues;
     273            }
     274          }
     275          break;
    235276        case FactorReplacementMethodEnum.Mode:
    236277          var mostCommonValue = rows.Select(r => originalValues[r])
     
    238279            .OrderByDescending(g => g.Count())
    239280            .First().Key;
    240           replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
     281          replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
    241282          break;
    242283        case FactorReplacementMethodEnum.Shuffle:
    243284          // new var has same empirical distribution but the relation to y is broken
    244           rand = new FastRandom(31415);
    245285          // prepare a complete column for the dataset
    246           replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
     286          replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
    247287          // shuffle only the selected rows
    248           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
     288          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
    249289          int i = 0;
    250290          // update column values
     
    254294          break;
    255295        default:
    256           throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
    257       }
    258 
    259       return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    260     }
    261 
    262     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
    263       ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
    264       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    265       dataset.ReplaceVariable(variable, replacementValues.ToList());
    266 
     296          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
     297      }
     298
     299      return replacementValues;
     300    }
     301
     302    private static double CalculateQualityForReplacement(
     303      IClassificationModel model,
     304      ModifiableDataset modifiableDataset,
     305      string variableName,
     306      IList originalValues,
     307      IEnumerable<int> rows,
     308      IList replacementValues,
     309      IEnumerable<double> targetValues) {
     310
     311      modifiableDataset.ReplaceVariable(variableName, replacementValues);
    267312      var discModel = model as IDiscriminantFunctionClassificationModel;
    268313      if (discModel != null) {
    269         var problemData = new ClassificationProblemData(dataset, dataset.VariableNames, model.TargetVariable);
     314        var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
    270315        discModel.RecalculateModelParameters(problemData, rows);
    271316      }
    272317
    273318      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    274       var estimates = model.GetEstimatedClassValues(dataset, rows).ToList();
    275       dataset.ReplaceVariable(variable, originalValues);
    276 
    277       return estimates;
    278     }
    279     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
    280       ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
    281       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    282       dataset.ReplaceVariable(variable, replacementValues.ToList());
    283 
    284 
    285       var discModel = model as IDiscriminantFunctionClassificationModel;
    286       if (discModel != null) {
    287         var problemData = new ClassificationProblemData(dataset, dataset.VariableNames, model.TargetVariable);
    288         discModel.RecalculateModelParameters(problemData, rows);
    289       }
    290 
    291       //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    292       var estimates = model.GetEstimatedClassValues(dataset, rows).ToList();
    293       dataset.ReplaceVariable(variable, originalValues);
    294 
    295       return estimates;
     319      var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
     320      var ret = CalculateQuality(targetValues, estimates);
     321      modifiableDataset.ReplaceVariable(variableName, originalValues);
     322
     323      return ret;
     324    }
     325
     326    public static double CalculateQuality(IEnumerable<double> targetValues, IEnumerable<double> estimatedClassValues) {
     327      OnlineCalculatorError errorState;
     328      var ret = OnlineAccuracyCalculator.Calculate(targetValues, estimatedClassValues, out errorState);
     329      if (errorState != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
     330      return ret;
     331    }
     332
     333    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) {
     334      IEnumerable<int> rows;
     335
     336      switch (dataPartition) {
     337        case DataPartitionEnum.All:
     338          rows = problemData.AllIndices;
     339          break;
     340        case DataPartitionEnum.Test:
     341          rows = problemData.TestIndices;
     342          break;
     343        case DataPartitionEnum.Training:
     344          rows = problemData.TrainingIndices;
     345          break;
     346        default:
     347          throw new NotSupportedException("DataPartition not supported");
     348      }
     349
     350      return rows;
    296351    }
    297352  }
Note: See TracChangeset for help on using the changeset viewer.