Changeset 16036


Ignore:
Timestamp:
08/01/18 14:01:08 (13 months ago)
Author:
fholzing
Message:

#2904: Streamlined the variableimpactcalculator code on both Regression and Classification. Taken over the regression-code for classification with some minor adaptations.

Location:
branches/2904_CalculateImpacts
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs

    r15674 r16036  
    2323
    2424using System;
     25using System.Collections;
    2526using System.Collections.Generic;
    2627using System.Linq;
     
    3637  [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")]
    3738  public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem {
     39    #region Parameters/Properties
    3840    public enum ReplacementMethodEnum {
    3941      Median,
     
    5456
    5557    private const string ReplacementParameterName = "Replacement Method";
     58    private const string FactorReplacementParameterName = "Factor Replacement Method";
    5659    private const string DataPartitionParameterName = "DataPartition";
    5760
    5861    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
    5962      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
     63    }
     64    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
     65      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    6066    }
    6167    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
     
    6773      set { ReplacementParameter.Value.Value = value; }
    6874    }
     75    public FactorReplacementMethodEnum FactorReplacementMethod {
     76      get { return FactorReplacementParameter.Value.Value; }
     77      set { FactorReplacementParameter.Value.Value = value; }
     78    }
    6979    public DataPartitionEnum DataPartition {
    7080      get { return DataPartitionParameter.Value.Value; }
    7181      set { DataPartitionParameter.Value.Value = value; }
    7282    }
    73 
    74 
     83    #endregion
     84
     85    #region Ctor/Cloner
    7586    [StorableConstructor]
    7687    private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    7788    private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner)
    7889      : base(original, cloner) { }
     90    public ClassificationSolutionVariableImpactsCalculator()
     91      : base() {
     92      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle)));
     93      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
     94    }
     95
    7996    public override IDeepCloneable Clone(Cloner cloner) {
    8097      return new ClassificationSolutionVariableImpactsCalculator(this, cloner);
    8198    }
    82 
    83     public ClassificationSolutionVariableImpactsCalculator()
    84       : base() {
    85       Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
    86       Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    87     }
     99    #endregion
    88100
    89101    //mkommend: annoying name clash with static method, open to better naming suggestions
    90102    public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) {
    91       return CalculateImpacts(solution, DataPartition, ReplacementMethod);
     103      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    92104    }
    93105
    94106    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    95107      IClassificationSolution solution,
    96       DataPartitionEnum data = DataPartitionEnum.Training,
    97       ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Median,
     108      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     109      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     110      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
     111      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition);
     112    }
     113
     114    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     115      IClassificationModel model,
     116      IClassificationProblemData problemData,
     117      IEnumerable<double> estimatedValues,
     118      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     119      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
     120      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
     121      IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData);
     122      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
     123    }
     124
     125
     126    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     127     IClassificationModel model,
     128     IClassificationProblemData problemData,
     129     IEnumerable<double> estimatedClassValues,
     130     IEnumerable<int> rows,
     131     ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
     132     FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
     133      //Calculate original quality-values (via calculator, default is Accuracy)
     134      OnlineCalculatorError error;
     135      IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     136      IEnumerable<double> estimatedValuesPartition = rows.Select(v => estimatedClassValues.ElementAt(v));
     137      var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
     138      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
     139
     140      var impacts = new Dictionary<string, double>();
     141      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
     142      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
     143      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
     144
     145      foreach (var inputVariable in allowedInputVariables) {
     146        impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
     147      }
     148
     149      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
     150    }
     151
     152
     153    public static double CalculateImpact(string variableName,
     154      IClassificationModel model,
     155      ModifiableDataset modifiableDataset,
     156      IEnumerable<int> rows,
     157      IEnumerable<double> targetValues,
     158      double originalValue,
     159      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    98160      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    99 
    100       var problemData = solution.ProblemData;
    101       var dataset = problemData.Dataset;
    102 
     161      double impact = 0;
     162      OnlineCalculatorError error;
     163      IRandom random;
     164      double replacementValue;
     165      IEnumerable<double> newEstimates = null;
     166      double newValue = 0;
     167
     168      if (modifiableDataset.VariableHasType<double>(variableName)) {
     169        #region NumericalVariable
     170        var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
     171        List<double> replacementValues;
     172        IRandom rand;
     173
     174        switch (replacementMethod) {
     175          case ReplacementMethodEnum.Median:
     176            replacementValue = rows.Select(r => originalValues[r]).Median();
     177            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     178            break;
     179          case ReplacementMethodEnum.Average:
     180            replacementValue = rows.Select(r => originalValues[r]).Average();
     181            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     182            break;
     183          case ReplacementMethodEnum.Shuffle:
     184            // new var has same empirical distribution but the relation to y is broken
     185            rand = new FastRandom(31415);
     186            // prepare a complete column for the dataset
     187            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     188            // shuffle only the selected rows
     189            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
     190            int i = 0;
     191            // update column values
     192            foreach (var r in rows) {
     193              replacementValues[r] = shuffledValues[i++];
     194            }
     195            break;
     196          case ReplacementMethodEnum.Noise:
     197            var avg = rows.Select(r => originalValues[r]).Average();
     198            var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
     199            rand = new FastRandom(31415);
     200            // prepare a complete column for the dataset
     201            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     202            // update column values
     203            foreach (var r in rows) {
     204              replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
     205            }
     206            break;
     207
     208          default:
     209            throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
     210        }
     211
     212        newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     213        newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     214        if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
     215
     216        impact = originalValue - newValue;
     217        #endregion
     218      } else if (modifiableDataset.VariableHasType<string>(variableName)) {
     219        #region FactorVariable
     220        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
     221        List<string> replacementValues;
     222
     223        switch (factorReplacementMethod) {
     224          case FactorReplacementMethodEnum.Best:
     225            // try replacing with all possible values and find the best replacement value
     226            var smallestImpact = double.PositiveInfinity;
     227            foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
     228              newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
     229              newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     230              if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     231
     232              var curImpact = originalValue - newValue;
     233              if (curImpact < smallestImpact) smallestImpact = curImpact;
     234            }
     235            impact = smallestImpact;
     236            break;
     237          case FactorReplacementMethodEnum.Mode:
     238            var mostCommonValue = rows.Select(r => originalValues[r])
     239              .GroupBy(v => v)
     240              .OrderByDescending(g => g.Count())
     241              .First().Key;
     242            replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
     243
     244            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     245            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     246            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     247
     248            impact = originalValue - newValue;
     249            break;
     250          case FactorReplacementMethodEnum.Shuffle:
     251            // new var has same empirical distribution but the relation to y is broken
     252            random = new FastRandom(31415);
     253            // prepare a complete column for the dataset
     254            replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
     255            // shuffle only the selected rows
     256            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     257            int i = 0;
     258            // update column values
     259            foreach (var r in rows) {
     260              replacementValues[r] = shuffledValues[i++];
     261            }
     262
     263            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     264            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     265            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     266
     267            impact = originalValue - newValue;
     268            break;
     269          default:
     270            throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
     271        }
     272        #endregion
     273      } else {
     274        throw new NotSupportedException("Variable not supported");
     275      }
     276
     277      return impact;
     278    }
     279
     280    /// <summary>
     281    /// Calculates and returns the VariableImpact (calculated via Accuracy).
     282    /// </summary>
     283    /// <param name="targetValues">The actual values</param>
     284    /// <param name="estimatedValues">The calculated/replaced values</param>
     285    /// <param name="errorState"></param>
     286    /// <returns></returns>
     287    public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
     288      //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
     289      //as the code below does. But this way we can easily swap the calculator later on, so the user 
     290      //could choose a Calculator during runtime in future versions.
     291      IOnlineCalculator calculator = new OnlineAccuracyCalculator();
     292      IEnumerator<double> firstEnumerator = targetValues.GetEnumerator();
     293      IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator();
     294
     295      // always move forward both enumerators (do not use short-circuit evaluation!)
     296      while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) {
     297        double original = firstEnumerator.Current;
     298        double estimated = secondEnumerator.Current;
     299        calculator.Add(original, estimated);
     300        if (calculator.ErrorState != OnlineCalculatorError.None) break;
     301      }
     302
     303      // check if both enumerators are at the end to make sure both enumerations have the same length
     304      if (calculator.ErrorState == OnlineCalculatorError.None &&
     305           (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) {
     306        throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
     307      } else {
     308        errorState = calculator.ErrorState;
     309        return calculator.Value;
     310      }
     311    }
     312
     313    /// <summary>
     314    /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
     315    /// and changes the value of the model-variables back to the original ones.
     316    /// </summary>
     317    /// <param name="originalValues"></param>
     318    /// <param name="model"></param>
     319    /// <param name="variableName"></param>
     320    /// <param name="modifiableDataset"></param>
     321    /// <param name="rows"></param>
     322    /// <param name="replacementValues"></param>
     323    /// <returns></returns>
     324    private static IEnumerable<double> GetReplacedEstimates(
     325     IList originalValues,
     326     IClassificationModel model,
     327     string variableName,
     328     ModifiableDataset modifiableDataset,
     329     IEnumerable<int> rows,
     330     IList replacementValues) {
     331      modifiableDataset.ReplaceVariable(variableName, replacementValues);
     332
     333      var discModel = model as IDiscriminantFunctionClassificationModel;
     334      if (discModel != null) {
     335        var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
     336        discModel.RecalculateModelParameters(problemData, rows);
     337      }
     338
     339      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
     340      var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
     341      modifiableDataset.ReplaceVariable(variableName, originalValues);
     342
     343      return estimates;
     344    }
     345
     346
     347    /// <summary>
     348    /// Returns a collection of the row-indices for a given DataPartition (training or test)
     349    /// </summary>
     350    /// <param name="dataPartition"></param>
     351    /// <param name="problemData"></param>
     352    /// <returns></returns>
     353    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) {
    103354      IEnumerable<int> rows;
    104       IEnumerable<double> targetValues;
    105       double originalAccuracy;
    106 
    107       OnlineCalculatorError error;
    108 
    109       switch (data) {
     355
     356      switch (dataPartition) {
    110357        case DataPartitionEnum.All:
    111358          rows = problemData.AllIndices;
    112           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.AllIndices).ToList();
    113           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedClassValues, out error);
    114           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
     359          break;
     360        case DataPartitionEnum.Test:
     361          rows = problemData.TestIndices;
    115362          break;
    116363        case DataPartitionEnum.Training:
    117364          rows = problemData.TrainingIndices;
    118           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToList();
    119           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTrainingClassValues, out error);
    120           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
    121           break;
    122         case DataPartitionEnum.Test:
    123           rows = problemData.TestIndices;
    124           targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToList();
    125           originalAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, solution.EstimatedTestClassValues, out error);
    126           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during accuracy calculation.");
    127           break;
    128         default: throw new ArgumentException(string.Format("DataPartition {0} cannot be handled.", data));
    129       }
    130 
    131       var impacts = new Dictionary<string, double>();
    132       var modifiableDataset = ((Dataset)dataset).ToModifiable();
    133 
    134       var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(solution.Model.VariablesUsedForPrediction));
    135       var allowedInputVariables = dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();
    136 
    137       // calculate impacts for double variables
    138       foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) {
    139         var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);
    140         var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    141         if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
    142 
    143         impacts[inputVariable] = originalAccuracy - newAccuracy;
    144       }
    145 
    146       // calculate impacts for string variables
    147       foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<string>)) {
    148         if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
    149           // try replacing with all possible values and find the best replacement value
    150           var smallestImpact = double.PositiveInfinity;
    151           foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) {
    152             var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
    153               Enumerable.Repeat(repl, dataset.Rows));
    154             var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    155             if (error != OnlineCalculatorError.None)
    156               throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");
    157 
    158             var impact = originalAccuracy - newAccuracy;
    159             if (impact < smallestImpact) smallestImpact = impact;
    160           }
    161           impacts[inputVariable] = smallestImpact;
    162         } else {
    163           // for replacement methods shuffle and mode
    164           // calculate impacts for factor variables
    165 
    166           var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
    167             factorReplacementMethod);
    168           var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    169           if (error != OnlineCalculatorError.None)
    170             throw new InvalidOperationException("Error during accuracy calculation with replaced inputs.");
    171 
    172           impacts[inputVariable] = originalAccuracy - newAccuracy;
    173         }
    174       } // foreach
    175       return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    176     }
    177 
    178     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable, ModifiableDataset dataset, IEnumerable<int> rows, ReplacementMethodEnum replacement = ReplacementMethodEnum.Median) {
    179       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    180       double replacementValue;
    181       List<double> replacementValues;
    182       IRandom rand;
    183 
    184       switch (replacement) {
    185         case ReplacementMethodEnum.Median:
    186           replacementValue = rows.Select(r => originalValues[r]).Median();
    187           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
    188           break;
    189         case ReplacementMethodEnum.Average:
    190           replacementValue = rows.Select(r => originalValues[r]).Average();
    191           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
    192           break;
    193         case ReplacementMethodEnum.Shuffle:
    194           // new var has same empirical distribution but the relation to y is broken
    195           rand = new FastRandom(31415);
    196           // prepare a complete column for the dataset
    197           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
    198           // shuffle only the selected rows
    199           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
    200           int i = 0;
    201           // update column values
    202           foreach (var r in rows) {
    203             replacementValues[r] = shuffledValues[i++];
    204           }
    205           break;
    206         case ReplacementMethodEnum.Noise:
    207           var avg = rows.Select(r => originalValues[r]).Average();
    208           var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
    209           rand = new FastRandom(31415);
    210           // prepare a complete column for the dataset
    211           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
    212           // update column values
    213           foreach (var r in rows) {
    214             replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
    215           }
    216           break;
    217 
    218         default:
    219           throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
    220       }
    221 
    222       return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    223     }
    224 
    225     private static IEnumerable<double> EvaluateModelWithReplacedVariable(
    226       IClassificationModel model, string variable, ModifiableDataset dataset,
    227       IEnumerable<int> rows,
    228       FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
    229       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    230       List<string> replacementValues;
    231       IRandom rand;
    232 
    233       switch (replacement) {
    234         case FactorReplacementMethodEnum.Mode:
    235           var mostCommonValue = rows.Select(r => originalValues[r])
    236             .GroupBy(v => v)
    237             .OrderByDescending(g => g.Count())
    238             .First().Key;
    239           replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
    240           break;
    241         case FactorReplacementMethodEnum.Shuffle:
    242           // new var has same empirical distribution but the relation to y is broken
    243           rand = new FastRandom(31415);
    244           // prepare a complete column for the dataset
    245           replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
    246           // shuffle only the selected rows
    247           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
    248           int i = 0;
    249           // update column values
    250           foreach (var r in rows) {
    251             replacementValues[r] = shuffledValues[i++];
    252           }
    253365          break;
    254366        default:
    255           throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
    256       }
    257 
    258       return EvaluateModelWithReplacedVariable(model, variable, dataset, rows, replacementValues);
    259     }
    260 
    261     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
    262       ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<double> replacementValues) {
    263       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    264       dataset.ReplaceVariable(variable, replacementValues.ToList());
    265       //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    266       var estimates = model.GetEstimatedClassValues(dataset, rows).ToList();
    267       dataset.ReplaceVariable(variable, originalValues);
    268 
    269       return estimates;
    270     }
    271     private static IEnumerable<double> EvaluateModelWithReplacedVariable(IClassificationModel model, string variable,
    272       ModifiableDataset dataset, IEnumerable<int> rows, IEnumerable<string> replacementValues) {
    273       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    274       dataset.ReplaceVariable(variable, replacementValues.ToList());
    275       //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    276       var estimates = model.GetEstimatedClassValues(dataset, rows).ToList();
    277       dataset.ReplaceVariable(variable, originalValues);
    278 
    279       return estimates;
    280     }
     367          throw new NotSupportedException("DataPartition not supported");
     368      }
     369
     370      return rows;
     371    }
     372
    281373  }
    282374}
  • branches/2904_CalculateImpacts/3.4/Implementation/Regression/RegressionSolutionVariableImpactsCalculator.cs

    r16035 r16036  
    100100    #endregion
    101101
    102     #region Public Methods/Wrappers
    103102    //mkommend: annoying name clash with static method, open to better naming suggestions
    104103    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
     
    159158      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    160159      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    161 
    162160      double impact = 0;
    163 
    164       // calculate impacts for double variables
     161      OnlineCalculatorError error;
     162      IRandom random;
     163      double replacementValue;
     164      IEnumerable<double> newEstimates = null;
     165      double newValue = 0;
     166
    165167      if (modifiableDataset.VariableHasType<double>(variableName)) {
    166         impact = CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
     168        #region NumericalVariable
     169        var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
     170        List<double> replacementValues;
     171
     172        switch (replacementMethod) {
     173          case ReplacementMethodEnum.Median:
     174            replacementValue = rows.Select(r => originalValues[r]).Median();
     175            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     176            break;
     177          case ReplacementMethodEnum.Average:
     178            replacementValue = rows.Select(r => originalValues[r]).Average();
     179            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
     180            break;
     181          case ReplacementMethodEnum.Shuffle:
     182            // new var has same empirical distribution but the relation to y is broken
     183            random = new FastRandom(31415);
     184            // prepare a complete column for the dataset
     185            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     186            // shuffle only the selected rows
     187            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     188            int i = 0;
     189            // update column values
     190            foreach (var r in rows) {
     191              replacementValues[r] = shuffledValues[i++];
     192            }
     193            break;
     194          case ReplacementMethodEnum.Noise:
     195            var avg = rows.Select(r => originalValues[r]).Average();
     196            var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
     197            random = new FastRandom(31415);
     198            // prepare a complete column for the dataset
     199            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
     200            // update column values
     201            foreach (var r in rows) {
     202              replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
     203            }
     204            break;
     205
     206          default:
     207            throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
     208        }
     209
     210        newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     211        newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     212        if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
     213
     214        impact = originalValue - newValue;
     215        #endregion
    167216      } else if (modifiableDataset.VariableHasType<string>(variableName)) {
    168         impact = CalculateImpactForFactorVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
     217        #region FactorVariable
     218        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
     219        List<string> replacementValues;
     220
     221        switch (factorReplacementMethod) {
     222          case FactorReplacementMethodEnum.Best:
     223            // try replacing with all possible values and find the best replacement value
     224            var smallestImpact = double.PositiveInfinity;
     225            foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
     226              newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
     227              newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     228              if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     229
     230              var curImpact = originalValue - newValue;
     231              if (curImpact < smallestImpact) smallestImpact = curImpact;
     232            }
     233            impact = smallestImpact;
     234            break;
     235          case FactorReplacementMethodEnum.Mode:
     236            var mostCommonValue = rows.Select(r => originalValues[r])
     237              .GroupBy(v => v)
     238              .OrderByDescending(g => g.Count())
     239              .First().Key;
     240            replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
     241
     242            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     243            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     244            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     245
     246            impact = originalValue - newValue;
     247            break;
     248          case FactorReplacementMethodEnum.Shuffle:
     249            // new var has same empirical distribution but the relation to y is broken
     250            random = new FastRandom(31415);
     251            // prepare a complete column for the dataset
     252            replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
     253            // shuffle only the selected rows
     254            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
     255            int i = 0;
     256            // update column values
     257            foreach (var r in rows) {
     258              replacementValues[r] = shuffledValues[i++];
     259            }
     260
     261            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
     262            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
     263            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
     264
     265            impact = originalValue - newValue;
     266            break;
     267          default:
     268            throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
     269        }
     270        #endregion
    169271      } else {
    170272        throw new NotSupportedException("Variable not supported");
    171273      }
     274
    172275      return impact;
    173276    }
    174     #endregion
    175 
    176     private static double CalculateImpactForNumericalVariables(string variableName,
    177       IRegressionModel model,
    178       ModifiableDataset modifiableDataset,
    179       IEnumerable<int> rows,
    180       IEnumerable<double> targetValues,
    181       double originalValue,
    182       ReplacementMethodEnum replacementMethod) {
    183       OnlineCalculatorError error;
    184       var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod);
    185       var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    186       if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
    187       return originalValue - newValue;
    188     }
    189 
    190     private static double CalculateImpactForFactorVariables(string variableName,
    191       IRegressionModel model,
    192       ModifiableDataset modifiableDataset,
    193       IEnumerable<int> rows,
    194       IEnumerable<double> targetValues,
    195       double originalValue,
    196       FactorReplacementMethodEnum factorReplacementMethod) {
    197 
    198       OnlineCalculatorError error;
    199       if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
    200         // try replacing with all possible values and find the best replacement value
    201         var smallestImpact = double.PositiveInfinity;
    202         foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
    203           var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
    204           var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
    205           var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    206           if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    207 
    208           var curImpact = originalValue - newValue;
    209           if (curImpact < smallestImpact) smallestImpact = curImpact;
    210         }
    211         return smallestImpact;
    212       } else {
    213         // for replacement methods shuffle and mode
    214         // calculate impacts for factor variables
    215         var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod);
    216         var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    217         if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    218 
    219         return originalValue - newValue;
    220       }
    221     }
    222 
    223     private static IEnumerable<double> GetReplacedValuesForNumericalVariables(
    224       IRegressionModel model,
    225       string variable,
    226       ModifiableDataset dataset,
    227       IEnumerable<int> rows,
    228       ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) {
    229       var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    230       double replacementValue;
    231       List<double> replacementValues;
    232       IRandom rand;
    233 
    234       switch (replacement) {
    235         case ReplacementMethodEnum.Median:
    236           replacementValue = rows.Select(r => originalValues[r]).Median();
    237           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
    238           break;
    239         case ReplacementMethodEnum.Average:
    240           replacementValue = rows.Select(r => originalValues[r]).Average();
    241           replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
    242           break;
    243         case ReplacementMethodEnum.Shuffle:
    244           // new var has same empirical distribution but the relation to y is broken
    245           rand = new FastRandom(31415);
    246           // prepare a complete column for the dataset
    247           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
    248           // shuffle only the selected rows
    249           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
    250           int i = 0;
    251           // update column values
    252           foreach (var r in rows) {
    253             replacementValues[r] = shuffledValues[i++];
    254           }
    255           break;
    256         case ReplacementMethodEnum.Noise:
    257           var avg = rows.Select(r => originalValues[r]).Average();
    258           var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
    259           rand = new FastRandom(31415);
    260           // prepare a complete column for the dataset
    261           replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
    262           // update column values
    263           foreach (var r in rows) {
    264             replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
    265           }
    266           break;
    267 
    268         default:
    269           throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
    270       }
    271 
    272       return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues);
    273     }
    274 
    275     private static IEnumerable<double> GetReplacedValuesForFactorVariables(
    276       IRegressionModel model,
    277       string variable,
    278       ModifiableDataset dataset,
    279       IEnumerable<int> rows,
    280       FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
    281       var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    282       List<string> replacementValues;
    283       IRandom rand;
    284 
    285       switch (replacement) {
    286         case FactorReplacementMethodEnum.Mode:
    287           var mostCommonValue = rows.Select(r => originalValues[r])
    288             .GroupBy(v => v)
    289             .OrderByDescending(g => g.Count())
    290             .First().Key;
    291           replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
    292           break;
    293         case FactorReplacementMethodEnum.Shuffle:
    294           // new var has same empirical distribution but the relation to y is broken
    295           rand = new FastRandom(31415);
    296           // prepare a complete column for the dataset
    297           replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
    298           // shuffle only the selected rows
    299           var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
    300           int i = 0;
    301           // update column values
    302           foreach (var r in rows) {
    303             replacementValues[r] = shuffledValues[i++];
    304           }
    305           break;
    306         default:
    307           throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
    308       }
    309 
    310       return GetReplacedEstimates(originalValues, model, variable, dataset, rows, replacementValues);
    311     }
    312 
     277
     278    /// <summary>
     279    /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
     280    /// and changes the value of the model-variables back to the original ones.
     281    /// </summary>
     282    /// <param name="originalValues"></param>
     283    /// <param name="model"></param>
     284    /// <param name="variableName"></param>
     285    /// <param name="modifiableDataset"></param>
     286    /// <param name="rows"></param>
     287    /// <param name="replacementValues"></param>
     288    /// <returns></returns>
    313289    private static IEnumerable<double> GetReplacedEstimates(
    314290      IList originalValues,
    315291      IRegressionModel model,
    316       string variable,
    317       ModifiableDataset dataset,
     292      string variableName,
     293      ModifiableDataset modifiableDataset,
    318294      IEnumerable<int> rows,
    319295      IList replacementValues) {
    320       dataset.ReplaceVariable(variable, replacementValues);
     296      modifiableDataset.ReplaceVariable(variableName, replacementValues);
    321297      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    322       var estimates = model.GetEstimatedValues(dataset, rows).ToList();
    323       dataset.ReplaceVariable(variable, originalValues);
     298      var estimates = model.GetEstimatedValues(modifiableDataset, rows).ToList();
     299      modifiableDataset.ReplaceVariable(variableName, originalValues);
    324300
    325301      return estimates;
    326302    }
    327303
    328     public static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
    329       IEnumerator<double> firstEnumerator = originalValues.GetEnumerator();
     304    /// <summary>
     305    /// Calculates and returns the VariableImpact (calculated via Pearsons R²).
     306    /// </summary>
     307    /// <param name="targetValues">The actual values</param>
     308    /// <param name="estimatedValues">The calculated/replaced values</param>
     309    /// <param name="errorState"></param>
     310    /// <returns></returns>
     311    public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
     312      //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
     313      //as the code below does. But this way we can easily swap the calculator later on, so the user 
     314      //could choose a Calculator during runtime in future versions.
     315      IOnlineCalculator calculator = new OnlinePearsonsRSquaredCalculator();
     316      IEnumerator<double> firstEnumerator = targetValues.GetEnumerator();
    330317      IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator();
    331       var calculator = new OnlinePearsonsRSquaredCalculator();
    332318
    333319      // always move forward both enumerators (do not use short-circuit evaluation!)
     
    349335    }
    350336
     337    /// <summary>
     338    /// Returns a collection of the row-indices for a given DataPartition (training or test)
     339    /// </summary>
     340    /// <param name="dataPartition"></param>
     341    /// <param name="problemData"></param>
     342    /// <returns></returns>
    351343    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IRegressionProblemData problemData) {
    352344      IEnumerable<int> rows;
  • branches/2904_CalculateImpacts/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationSolutionVariableImpactsView.Designer.cs

    r15753 r16036  
    1919 */
    2020#endregion
     21
     22
    2123namespace HeuristicLab.Problems.DataAnalysis.Views {
    2224  partial class ClassificationSolutionVariableImpactsView {
     
    4446    /// </summary>
    4547    private void InitializeComponent() {
    46       this.variableImactsArrayView = new HeuristicLab.Data.Views.StringConvertibleArrayView();
    47       this.dataPartitionComboBox = new System.Windows.Forms.ComboBox();
    48       this.dataPartitionLabel = new System.Windows.Forms.Label();
    49       this.numericVarReplacementLabel = new System.Windows.Forms.Label();
    50       this.replacementComboBox = new System.Windows.Forms.ComboBox();
    51       this.factorVarReplacementLabel = new System.Windows.Forms.Label();
    52       this.factorVarReplComboBox = new System.Windows.Forms.ComboBox();
    5348      this.ascendingCheckBox = new System.Windows.Forms.CheckBox();
    5449      this.sortByLabel = new System.Windows.Forms.Label();
    5550      this.sortByComboBox = new System.Windows.Forms.ComboBox();
    56       this.backgroundWorker = new System.ComponentModel.BackgroundWorker();
     51      this.factorVarReplComboBox = new System.Windows.Forms.ComboBox();
     52      this.factorVarReplacementLabel = new System.Windows.Forms.Label();
     53      this.replacementComboBox = new System.Windows.Forms.ComboBox();
     54      this.numericVarReplacementLabel = new System.Windows.Forms.Label();
     55      this.dataPartitionLabel = new System.Windows.Forms.Label();
     56      this.dataPartitionComboBox = new System.Windows.Forms.ComboBox();
     57      this.variableImactsArrayView = new HeuristicLab.Data.Views.StringConvertibleArrayView();
    5758      this.SuspendLayout();
    5859      //
    59       // variableImactsArrayView
    60       //
    61       this.variableImactsArrayView.Anchor = ((System.Windows.Forms.AnchorStyles)((((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Bottom)
    62             | System.Windows.Forms.AnchorStyles.Left)
    63             | System.Windows.Forms.AnchorStyles.Right)));
    64       this.variableImactsArrayView.Caption = "StringConvertibleArray View";
    65       this.variableImactsArrayView.Content = null;
    66       this.variableImactsArrayView.Location = new System.Drawing.Point(3, 84);
    67       this.variableImactsArrayView.Name = "variableImactsArrayView";
    68       this.variableImactsArrayView.ReadOnly = true;
    69       this.variableImactsArrayView.Size = new System.Drawing.Size(662, 278);
    70       this.variableImactsArrayView.TabIndex = 2;
    71       //
    72       // dataPartitionComboBox
    73       //
    74       this.dataPartitionComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList;
    75       this.dataPartitionComboBox.FormattingEnabled = true;
    76       this.dataPartitionComboBox.Items.AddRange(new object[] {
    77             HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.Training,
    78             HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.Test,
    79             HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.All});
    80       this.dataPartitionComboBox.Location = new System.Drawing.Point(197, 3);
    81       this.dataPartitionComboBox.Name = "dataPartitionComboBox";
    82       this.dataPartitionComboBox.Size = new System.Drawing.Size(121, 21);
    83       this.dataPartitionComboBox.TabIndex = 1;
    84       this.dataPartitionComboBox.SelectedIndexChanged += new System.EventHandler(this.dataPartitionComboBox_SelectedIndexChanged);
    85       //
    86       // dataPartitionLabel
    87       //
    88       this.dataPartitionLabel.AutoSize = true;
    89       this.dataPartitionLabel.Location = new System.Drawing.Point(3, 6);
    90       this.dataPartitionLabel.Name = "dataPartitionLabel";
    91       this.dataPartitionLabel.Size = new System.Drawing.Size(73, 13);
    92       this.dataPartitionLabel.TabIndex = 0;
    93       this.dataPartitionLabel.Text = "Data partition:";
    94       //
    95       // numericVarReplacementLabel
    96       //
    97       this.numericVarReplacementLabel.AutoSize = true;
    98       this.numericVarReplacementLabel.Location = new System.Drawing.Point(3, 33);
    99       this.numericVarReplacementLabel.Name = "numericVarReplacementLabel";
    100       this.numericVarReplacementLabel.Size = new System.Drawing.Size(173, 13);
    101       this.numericVarReplacementLabel.TabIndex = 2;
    102       this.numericVarReplacementLabel.Text = "Replacement for numeric variables:";
     60      // ascendingCheckBox
     61      //
     62      this.ascendingCheckBox.AutoSize = true;
     63      this.ascendingCheckBox.Location = new System.Drawing.Point(534, 6);
     64      this.ascendingCheckBox.Name = "ascendingCheckBox";
     65      this.ascendingCheckBox.Size = new System.Drawing.Size(76, 17);
     66      this.ascendingCheckBox.TabIndex = 7;
     67      this.ascendingCheckBox.Text = "Ascending";
     68      this.ascendingCheckBox.UseVisualStyleBackColor = true;
     69      this.ascendingCheckBox.CheckedChanged += new System.EventHandler(this.ascendingCheckBox_CheckedChanged);
     70      //
     71      // sortByLabel
     72      //
     73      this.sortByLabel.AutoSize = true;
     74      this.sortByLabel.Location = new System.Drawing.Point(324, 6);
     75      this.sortByLabel.Name = "sortByLabel";
     76      this.sortByLabel.Size = new System.Drawing.Size(77, 13);
     77      this.sortByLabel.TabIndex = 4;
     78      this.sortByLabel.Text = "Sorting criteria:";
     79      //
     80      // sortByComboBox
     81      //
     82      this.sortByComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList;
     83      this.sortByComboBox.FormattingEnabled = true;
     84      this.sortByComboBox.Items.AddRange(new object[] {
     85            HeuristicLab.Problems.DataAnalysis.Views.ClassificationSolutionVariableImpactsView.SortingCriteria.ImpactValue,
     86            HeuristicLab.Problems.DataAnalysis.Views.ClassificationSolutionVariableImpactsView.SortingCriteria.Occurrence,
     87            HeuristicLab.Problems.DataAnalysis.Views.ClassificationSolutionVariableImpactsView.SortingCriteria.VariableName});
     88      this.sortByComboBox.Location = new System.Drawing.Point(407, 3);
     89      this.sortByComboBox.Name = "sortByComboBox";
     90      this.sortByComboBox.Size = new System.Drawing.Size(121, 21);
     91      this.sortByComboBox.TabIndex = 5;
     92      this.sortByComboBox.SelectedIndexChanged += new System.EventHandler(this.sortByComboBox_SelectedIndexChanged);
     93      //
     94      // factorVarReplComboBox
     95      //
     96      this.factorVarReplComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList;
     97      this.factorVarReplComboBox.FormattingEnabled = true;
     98      this.factorVarReplComboBox.Items.AddRange(new object[] {
     99            HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Best,
     100            HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Mode,
     101            HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Shuffle});
     102      this.factorVarReplComboBox.Location = new System.Drawing.Point(197, 57);
     103      this.factorVarReplComboBox.Name = "factorVarReplComboBox";
     104      this.factorVarReplComboBox.Size = new System.Drawing.Size(121, 21);
     105      this.factorVarReplComboBox.TabIndex = 1;
     106      this.factorVarReplComboBox.SelectedIndexChanged += new System.EventHandler(this.replacementComboBox_SelectedIndexChanged);
     107      //
     108      // factorVarReplacementLabel
     109      //
     110      this.factorVarReplacementLabel.AutoSize = true;
     111      this.factorVarReplacementLabel.Location = new System.Drawing.Point(3, 60);
     112      this.factorVarReplacementLabel.Name = "factorVarReplacementLabel";
     113      this.factorVarReplacementLabel.Size = new System.Drawing.Size(188, 13);
     114      this.factorVarReplacementLabel.TabIndex = 0;
     115      this.factorVarReplacementLabel.Text = "Replacement for categorical variables:";
    103116      //
    104117      // replacementComboBox
     
    117130      this.replacementComboBox.SelectedIndexChanged += new System.EventHandler(this.replacementComboBox_SelectedIndexChanged);
    118131      //
    119       // factorVarReplacementLabel
    120       //
    121       this.factorVarReplacementLabel.AutoSize = true;
    122       this.factorVarReplacementLabel.Location = new System.Drawing.Point(3, 60);
    123       this.factorVarReplacementLabel.Name = "factorVarReplacementLabel";
    124       this.factorVarReplacementLabel.Size = new System.Drawing.Size(188, 13);
    125       this.factorVarReplacementLabel.TabIndex = 0;
    126       this.factorVarReplacementLabel.Text = "Replacement for categorical variables:";
    127       //
    128       // factorVarReplComboBox
    129       //
    130       this.factorVarReplComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList;
    131       this.factorVarReplComboBox.FormattingEnabled = true;
    132       this.factorVarReplComboBox.Items.AddRange(new object[] {
    133             HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Best,
    134             HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Mode,
    135             HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum.Shuffle});
    136       this.factorVarReplComboBox.Location = new System.Drawing.Point(197, 57);
    137       this.factorVarReplComboBox.Name = "factorVarReplComboBox";
    138       this.factorVarReplComboBox.Size = new System.Drawing.Size(121, 21);
    139       this.factorVarReplComboBox.TabIndex = 1;
    140       this.factorVarReplComboBox.SelectedIndexChanged += new System.EventHandler(this.replacementComboBox_SelectedIndexChanged);
    141       //
    142       // ascendingCheckBox
    143       //
    144       this.ascendingCheckBox.AutoSize = true;
    145       this.ascendingCheckBox.Location = new System.Drawing.Point(534, 6);
    146       this.ascendingCheckBox.Name = "ascendingCheckBox";
    147       this.ascendingCheckBox.Size = new System.Drawing.Size(76, 17);
    148       this.ascendingCheckBox.TabIndex = 10;
    149       this.ascendingCheckBox.Text = "Ascending";
    150       this.ascendingCheckBox.UseVisualStyleBackColor = true;
    151       this.ascendingCheckBox.CheckedChanged += new System.EventHandler(this.ascendingCheckBox_CheckedChanged);
    152       //
    153       // sortByLabel
    154       //
    155       this.sortByLabel.AutoSize = true;
    156       this.sortByLabel.Location = new System.Drawing.Point(324, 6);
    157       this.sortByLabel.Name = "sortByLabel";
    158       this.sortByLabel.Size = new System.Drawing.Size(77, 13);
    159       this.sortByLabel.TabIndex = 8;
    160       this.sortByLabel.Text = "Sorting criteria:";
    161       //
    162       // sortByComboBox
    163       //
    164       this.sortByComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList;
    165       this.sortByComboBox.FormattingEnabled = true;
    166       this.sortByComboBox.Location = new System.Drawing.Point(407, 3);
    167       this.sortByComboBox.Name = "sortByComboBox";
    168       this.sortByComboBox.Size = new System.Drawing.Size(121, 21);
    169       this.sortByComboBox.TabIndex = 9;
    170       this.sortByComboBox.SelectedIndexChanged += new System.EventHandler(this.sortByComboBox_SelectedIndexChanged);
     132      // numericVarReplacementLabel
     133      //
     134      this.numericVarReplacementLabel.AutoSize = true;
     135      this.numericVarReplacementLabel.Location = new System.Drawing.Point(3, 33);
     136      this.numericVarReplacementLabel.Name = "numericVarReplacementLabel";
     137      this.numericVarReplacementLabel.Size = new System.Drawing.Size(173, 13);
     138      this.numericVarReplacementLabel.TabIndex = 2;
     139      this.numericVarReplacementLabel.Text = "Replacement for numeric variables:";
     140      //
     141      // dataPartitionLabel
     142      //
     143      this.dataPartitionLabel.AutoSize = true;
     144      this.dataPartitionLabel.Location = new System.Drawing.Point(3, 6);
     145      this.dataPartitionLabel.Name = "dataPartitionLabel";
     146      this.dataPartitionLabel.Size = new System.Drawing.Size(73, 13);
     147      this.dataPartitionLabel.TabIndex = 0;
     148      this.dataPartitionLabel.Text = "Data partition:";
     149      //
     150      // dataPartitionComboBox
     151      //
     152      this.dataPartitionComboBox.DropDownStyle = System.Windows.Forms.ComboBoxStyle.DropDownList;
     153      this.dataPartitionComboBox.FormattingEnabled = true;
     154      this.dataPartitionComboBox.Items.AddRange(new object[] {
     155            HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.Training,
     156            HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.Test,
     157            HeuristicLab.Problems.DataAnalysis.ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum.All});
     158      this.dataPartitionComboBox.Location = new System.Drawing.Point(197, 3);
     159      this.dataPartitionComboBox.Name = "dataPartitionComboBox";
     160      this.dataPartitionComboBox.Size = new System.Drawing.Size(121, 21);
     161      this.dataPartitionComboBox.TabIndex = 1;
     162      this.dataPartitionComboBox.SelectedIndexChanged += new System.EventHandler(this.dataPartitionComboBox_SelectedIndexChanged);
     163      //
     164      // variableImactsArrayView
     165      //
     166      this.variableImactsArrayView.Anchor = ((System.Windows.Forms.AnchorStyles)((((System.Windows.Forms.AnchorStyles.Top | System.Windows.Forms.AnchorStyles.Bottom)
     167            | System.Windows.Forms.AnchorStyles.Left)
     168            | System.Windows.Forms.AnchorStyles.Right)));
     169      this.variableImactsArrayView.Caption = "StringConvertibleArray View";
     170      this.variableImactsArrayView.Content = null;
     171      this.variableImactsArrayView.Location = new System.Drawing.Point(3, 84);
     172      this.variableImactsArrayView.Name = "variableImactsArrayView";
     173      this.variableImactsArrayView.ReadOnly = true;
     174      this.variableImactsArrayView.Size = new System.Drawing.Size(706, 278);
     175      this.variableImactsArrayView.TabIndex = 2;
    171176      //
    172177      // ClassificationSolutionVariableImpactsView
     
    185190      this.Controls.Add(this.variableImactsArrayView);
    186191      this.Name = "ClassificationSolutionVariableImpactsView";
    187       this.Size = new System.Drawing.Size(668, 365);
     192      this.Size = new System.Drawing.Size(712, 365);
    188193      this.VisibleChanged += new System.EventHandler(this.ClassificationSolutionVariableImpactsView_VisibleChanged);
    189194      this.ResumeLayout(false);
     
    201206    private System.Windows.Forms.Label factorVarReplacementLabel;
    202207    private System.Windows.Forms.ComboBox factorVarReplComboBox;
    203     private System.Windows.Forms.CheckBox ascendingCheckBox;
    204208    private System.Windows.Forms.Label sortByLabel;
    205209    private System.Windows.Forms.ComboBox sortByComboBox;
    206     private System.ComponentModel.BackgroundWorker backgroundWorker;
     210    private System.Windows.Forms.CheckBox ascendingCheckBox;
    207211  }
    208212}
  • branches/2904_CalculateImpacts/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationSolutionVariableImpactsView.cs

    r15753 r16036  
    3333  [Content(typeof(IClassificationSolution))]
    3434  public partial class ClassificationSolutionVariableImpactsView : DataAnalysisSolutionEvaluationView {
    35     #region Nested Types
    3635    private enum SortingCriteria {
    3736      ImpactValue,
     
    3938      VariableName
    4039    }
    41     #endregion
    42 
    43     #region Fields
    44     private Dictionary<string, double> rawVariableImpacts = new Dictionary<string, double>();
    45     private Thread thread;
    46     #endregion
    47 
    48     #region Getter/Setter
     40    private CancellationTokenSource cancellationToken = new CancellationTokenSource();
     41    private List<Tuple<string, double>> rawVariableImpacts = new List<Tuple<string, double>>();
     42
    4943    public new IClassificationSolution Content {
    5044      get { return (IClassificationSolution)base.Content; }
     
    5347      }
    5448    }
    55     #endregion
    56 
    57     #region Ctor
     49
    5850    public ClassificationSolutionVariableImpactsView()
    5951      : base() {
    6052      InitializeComponent();
    6153
    62       //Little workaround. If you fill the ComboBox-Items in the other partial class, the UI-Designer will moan.
    63       this.sortByComboBox.Items.AddRange(Enum.GetValues(typeof(SortingCriteria)).Cast<object>().ToArray());
    64       this.sortByComboBox.SelectedItem = SortingCriteria.ImpactValue;
    65 
    6654      //Set the default values
    6755      this.dataPartitionComboBox.SelectedIndex = 0;
    68       this.replacementComboBox.SelectedIndex = 0;
     56      this.replacementComboBox.SelectedIndex = 3;
    6957      this.factorVarReplComboBox.SelectedIndex = 0;
    70     }
    71     #endregion
    72 
    73     #region Events
     58      this.sortByComboBox.SelectedItem = SortingCriteria.ImpactValue;
     59    }
     60
    7461    protected override void RegisterContentEvents() {
    7562      base.RegisterContentEvents();
     
    7764      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    7865    }
    79 
    8066    protected override void DeregisterContentEvents() {
    8167      base.DeregisterContentEvents();
     
    8773      OnContentChanged();
    8874    }
    89 
    9075    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
    9176      OnContentChanged();
    9277    }
    93 
    9478    protected override void OnContentChanged() {
    9579      base.OnContentChanged();
     
    10084      }
    10185    }
    102 
    10386    private void ClassificationSolutionVariableImpactsView_VisibleChanged(object sender, EventArgs e) {
    104       if (thread == null) { return; }
    105 
    106       if (thread.IsAlive) { thread.Abort(); }
    107       thread = null;
    108     }
    109 
     87      cancellationToken.Cancel();
     88    }
    11089
    11190    private void dataPartitionComboBox_SelectedIndexChanged(object sender, EventArgs e) {
    11291      UpdateVariableImpact();
    11392    }
    114 
    11593    private void replacementComboBox_SelectedIndexChanged(object sender, EventArgs e) {
    11694      UpdateVariableImpact();
    11795    }
    118 
    11996    private void sortByComboBox_SelectedIndexChanged(object sender, EventArgs e) {
    12097      //Update the default ordering (asc,desc), but remove the eventHandler beforehand (otherwise the data would be ordered twice)
    12198      ascendingCheckBox.CheckedChanged -= ascendingCheckBox_CheckedChanged;
    122       switch ((SortingCriteria)sortByComboBox.SelectedItem) {
    123         case SortingCriteria.ImpactValue:
    124           ascendingCheckBox.Checked = false;
    125           break;
    126         case SortingCriteria.Occurrence:
    127           ascendingCheckBox.Checked = true;
    128           break;
    129         case SortingCriteria.VariableName:
    130           ascendingCheckBox.Checked = true;
    131           break;
    132         default:
    133           throw new NotImplementedException("Ordering for selected SortingCriteria not implemented");
    134       }
     99      ascendingCheckBox.Checked = (SortingCriteria)sortByComboBox.SelectedItem != SortingCriteria.ImpactValue;
    135100      ascendingCheckBox.CheckedChanged += ascendingCheckBox_CheckedChanged;
    136101
    137       UpdateDataOrdering();
    138     }
    139 
     102      UpdateOrdering();
     103    }
    140104    private void ascendingCheckBox_CheckedChanged(object sender, EventArgs e) {
    141       UpdateDataOrdering();
    142     }
    143 
    144     #endregion
    145 
    146     #region Helper Methods   
    147     private void UpdateVariableImpact() {
     105      UpdateOrdering();
     106    }
     107
     108    private async void UpdateVariableImpact() {
    148109      //Check if the selection is valid
    149110      if (Content == null) { return; }
     
    152113      if (factorVarReplComboBox.SelectedIndex < 0) { return; }
    153114
     115      IProgress progress;
     116
    154117      //Prepare arguments
    155118      var mainForm = (MainForm.WindowsForms.MainForm)MainFormManager.MainForm;
     
    159122
    160123      variableImactsArrayView.Caption = Content.Name + " Variable Impacts";
    161 
    162       mainForm.AddOperationProgressToView(this, "Calculating variable impacts for " + Content.Name);
    163 
    164       Task.Factory.StartNew(() => {
    165         thread = Thread.CurrentThread;
    166         //Remember the original ordering of the variables
    167         var impacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(Content, dataPartition, replMethod, factorReplMethod);
     124      progress = mainForm.AddOperationProgressToView(this, "Calculating variable impacts for " + Content.Name);
     125      progress.ProgressValue = 0;
     126
     127      cancellationToken = new CancellationTokenSource();
     128
     129      try {
    168130        var problemData = Content.ProblemData;
    169131        var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(Content.Model.VariablesUsedForPrediction));
    170         var originalVariableOrdering = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).Where(problemData.Dataset.VariableHasType<double>).ToList();
     132        //Remember the original ordering of the variables
     133        var originalVariableOrdering = problemData.Dataset.VariableNames
     134          .Where(v => inputvariables.Contains(v))
     135          .Where(v => problemData.Dataset.VariableHasType<double>(v) || problemData.Dataset.VariableHasType<string>(v))
     136          .ToList();
     137
     138        List<Tuple<string, double>> impacts = null;
     139
     140        await Task.Run(() => { impacts = CalculateVariableImpacts(originalVariableOrdering, (IClassificationModel)Content.Model.Clone(), problemData, Content.EstimatedClassValues, dataPartition, replMethod, factorReplMethod, cancellationToken.Token, progress); });
     141        if (impacts == null) { return; }
    171142
    172143        rawVariableImpacts.Clear();
    173         originalVariableOrdering.ForEach(v => rawVariableImpacts.Add(v, impacts.First(vv => vv.Item1 == v).Item2));
    174       }).ContinueWith((o) => {
    175         UpdateDataOrdering();
    176         mainForm.RemoveOperationProgressFromView(this);
    177         thread = null;
    178       }, TaskScheduler.FromCurrentSynchronizationContext());
     144        originalVariableOrdering.ForEach(v => rawVariableImpacts.Add(new Tuple<string, double>(v, impacts.First(vv => vv.Item1 == v).Item2)));
     145        UpdateOrdering();
     146      }
     147      finally {
     148        ((MainForm.WindowsForms.MainForm)MainFormManager.MainForm).RemoveOperationProgressFromView(this);
     149      }
     150    }
     151
     152    private List<Tuple<string, double>> CalculateVariableImpacts(List<string> originalVariableOrdering,
     153      IClassificationModel model,
     154      IClassificationProblemData problemData,
     155      IEnumerable<double> estimatedValues,
     156      ClassificationSolutionVariableImpactsCalculator.DataPartitionEnum dataPartition,
     157      ClassificationSolutionVariableImpactsCalculator.ReplacementMethodEnum replMethod,
     158      ClassificationSolutionVariableImpactsCalculator.FactorReplacementMethodEnum factorReplMethod,
     159      CancellationToken token,
     160      IProgress progress) {
     161      List<Tuple<string, double>> impacts = new List<Tuple<string, double>>();
     162      int count = originalVariableOrdering.Count;
     163      int i = 0;
     164      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
     165      IEnumerable<int> rows = ClassificationSolutionVariableImpactsCalculator.GetPartitionRows(dataPartition, problemData);
     166
     167      //Calculate original quality-values (via calculator, default is R²)
     168      OnlineCalculatorError error;
     169      IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     170      IEnumerable<double> estimatedValuesPartition = rows.Select(v => estimatedValues.ElementAt(v));
     171      var originalCalculatorValue = ClassificationSolutionVariableImpactsCalculator.CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
     172      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
     173
     174      foreach (var variableName in originalVariableOrdering) {
     175        if (cancellationToken.Token.IsCancellationRequested) { return null; }
     176        progress.ProgressValue = (double)++i / count;
     177        progress.Status = string.Format("Calculating impact for variable {0} ({1} of {2})", variableName, i, count);
     178
     179        double impact = ClassificationSolutionVariableImpactsCalculator.CalculateImpact(variableName, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replMethod, factorReplMethod);
     180        impacts.Add(new Tuple<string, double>(variableName, impact));
     181      }
     182
     183      return impacts;
    179184    }
    180185
     
    183188    /// The default is "Descending" by "VariableImpact" (as in previous versions)
    184189    /// </summary>
    185     private void UpdateDataOrdering() {
     190    private void UpdateOrdering() {
    186191      //Check if valid sortingCriteria is selected and data exists
    187192      if (sortByComboBox.SelectedIndex == -1) { return; }
     
    192197      bool ascending = ascendingCheckBox.Checked;
    193198
    194       IEnumerable<KeyValuePair<string, double>> orderedEntries = null;
     199      IEnumerable<Tuple<string, double>> orderedEntries = null;
    195200
    196201      //Sort accordingly
    197202      switch (selectedItem) {
    198203        case SortingCriteria.ImpactValue:
    199           orderedEntries = rawVariableImpacts.OrderBy(v => v.Value);
     204          orderedEntries = rawVariableImpacts.OrderBy(v => v.Item2);
    200205          break;
    201206        case SortingCriteria.Occurrence:
     
    203208          break;
    204209        case SortingCriteria.VariableName:
    205           orderedEntries = rawVariableImpacts.OrderBy(v => v.Key, new NaturalStringComparer());
     210          orderedEntries = rawVariableImpacts.OrderBy(v => v.Item1, new NaturalStringComparer());
    206211          break;
    207212        default:
     
    212217
    213218      //Write the data back
    214       var impactArray = new DoubleArray(orderedEntries.Select(i => i.Value).ToArray()) {
    215         ElementNames = orderedEntries.Select(i => i.Key)
     219      var impactArray = new DoubleArray(orderedEntries.Select(i => i.Item2).ToArray()) {
     220        ElementNames = orderedEntries.Select(i => i.Item1)
    216221      };
    217222
     
    221226      }
    222227    }
    223     #endregion 
    224228  }
    225229}
Note: See TracChangeset for help on using the changeset viewer.