Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/01/18 14:14:49 (6 years ago)
Author:
fholzing
Message:

#2904: Restyled ClassificationSolutionVariableImpactsCalculator to be nearly identical to it's Regression-pendant

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2904_CalculateImpacts/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs

    r16036 r16037  
    9090    public ClassificationSolutionVariableImpactsCalculator()
    9191      : base() {
    92       Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Shuffle)));
     92      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
     93      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
    9394      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    9495    }
     
    122123      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
    123124    }
    124 
    125125
    126126    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
     
    149149      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    150150    }
    151 
    152151
    153152    public static double CalculateImpact(string variableName,
     
    170169        var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
    171170        List<double> replacementValues;
    172         IRandom rand;
    173171
    174172        switch (replacementMethod) {
     
    183181          case ReplacementMethodEnum.Shuffle:
    184182            // new var has same empirical distribution but the relation to y is broken
    185             rand = new FastRandom(31415);
     183            random = new FastRandom(31415);
    186184            // prepare a complete column for the dataset
    187185            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
    188186            // shuffle only the selected rows
    189             var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
     187            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
    190188            int i = 0;
    191189            // update column values
     
    197195            var avg = rows.Select(r => originalValues[r]).Average();
    198196            var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
    199             rand = new FastRandom(31415);
     197            random = new FastRandom(31415);
    200198            // prepare a complete column for the dataset
    201199            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
    202200            // update column values
    203201            foreach (var r in rows) {
    204               replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
     202              replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
    205203            }
    206204            break;
     
    276274
    277275      return impact;
     276    }
     277
     278    /// <summary>
     279    /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
     280    /// and changes the value of the model-variables back to the original ones.
     281    /// </summary>
     282    /// <param name="originalValues"></param>
     283    /// <param name="model"></param>
     284    /// <param name="variableName"></param>
     285    /// <param name="modifiableDataset"></param>
     286    /// <param name="rows"></param>
     287    /// <param name="replacementValues"></param>
     288    /// <returns></returns>
     289    private static IEnumerable<double> GetReplacedEstimates(
     290      IList originalValues,
     291      IClassificationModel model,
     292      string variableName,
     293      ModifiableDataset modifiableDataset,
     294      IEnumerable<int> rows,
     295      IList replacementValues) {
     296      modifiableDataset.ReplaceVariable(variableName, replacementValues);
     297
     298      var discModel = model as IDiscriminantFunctionClassificationModel;
     299      if (discModel != null) {
     300        var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
     301        discModel.RecalculateModelParameters(problemData, rows);
     302      }
     303
     304      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
     305      var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
     306      modifiableDataset.ReplaceVariable(variableName, originalValues);
     307
     308      return estimates;
    278309    }
    279310
     
    312343
    313344    /// <summary>
    314     /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
    315     /// and changes the value of the model-variables back to the original ones.
    316     /// </summary>
    317     /// <param name="originalValues"></param>
    318     /// <param name="model"></param>
    319     /// <param name="variableName"></param>
    320     /// <param name="modifiableDataset"></param>
    321     /// <param name="rows"></param>
    322     /// <param name="replacementValues"></param>
    323     /// <returns></returns>
    324     private static IEnumerable<double> GetReplacedEstimates(
    325      IList originalValues,
    326      IClassificationModel model,
    327      string variableName,
    328      ModifiableDataset modifiableDataset,
    329      IEnumerable<int> rows,
    330      IList replacementValues) {
    331       modifiableDataset.ReplaceVariable(variableName, replacementValues);
    332 
    333       var discModel = model as IDiscriminantFunctionClassificationModel;
    334       if (discModel != null) {
    335         var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
    336         discModel.RecalculateModelParameters(problemData, rows);
    337       }
    338 
    339       //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    340       var estimates = model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
    341       modifiableDataset.ReplaceVariable(variableName, originalValues);
    342 
    343       return estimates;
    344     }
    345 
    346 
    347     /// <summary>
    348345    /// Returns a collection of the row-indices for a given DataPartition (training or test)
    349346    /// </summary>
     
    370367      return rows;
    371368    }
    372 
    373369  }
    374370}
Note: See TracChangeset for help on using the changeset viewer.