Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/28/18 17:41:20 (7 years ago)
Author:
mkommend
Message:

#2910: Added recalculation of thresholds for IDiscriminantFunctionClassificationModels during impact calculation.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolutionVariableImpactsCalculator.cs

    r15674 r15871  
    100100      var problemData = solution.ProblemData;
    101101      var dataset = problemData.Dataset;
     102      var model = (IClassificationModel)solution.Model.Clone(); //mkommend: clone of model is necessary, because the thresholds for IDiscriminantClassificationModels are updated
    102103
    103104      IEnumerable<int> rows;
     
    137138      // calculate impacts for double variables
    138139      foreach (var inputVariable in allowedInputVariables.Where(problemData.Dataset.VariableHasType<double>)) {
    139         var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows, replacementMethod);
     140        var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows, replacementMethod);
    140141        var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
    141142        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during R² calculation with replaced inputs.");
     
    150151          var smallestImpact = double.PositiveInfinity;
    151152          foreach (var repl in problemData.Dataset.GetStringValues(inputVariable, rows).Distinct()) {
    152             var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
     153            var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows,
    153154              Enumerable.Repeat(repl, dataset.Rows));
    154155            var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
     
    164165          // calculate impacts for factor variables
    165166
    166           var newEstimates = EvaluateModelWithReplacedVariable(solution.Model, inputVariable, modifiableDataset, rows,
     167          var newEstimates = EvaluateModelWithReplacedVariable(model, inputVariable, modifiableDataset, rows,
    167168            factorReplacementMethod);
    168169          var newAccuracy = OnlineAccuracyCalculator.Calculate(targetValues, newEstimates, out error);
     
    263264      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    264265      dataset.ReplaceVariable(variable, replacementValues.ToList());
     266
     267      var discModel = model as IDiscriminantFunctionClassificationModel;
     268      if (discModel != null) {
     269        var problemData = new ClassificationProblemData(dataset, dataset.VariableNames, model.TargetVariable);
     270        discModel.RecalculateModelParameters(problemData, rows);
     271      }
     272
    265273      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    266274      var estimates = model.GetEstimatedClassValues(dataset, rows).ToList();
     
    273281      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    274282      dataset.ReplaceVariable(variable, replacementValues.ToList());
     283
     284
     285      var discModel = model as IDiscriminantFunctionClassificationModel;
     286      if (discModel != null) {
     287        var problemData = new ClassificationProblemData(dataset, dataset.VariableNames, model.TargetVariable);
     288        discModel.RecalculateModelParameters(problemData, rows);
     289      }
     290
    275291      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
    276292      var estimates = model.GetEstimatedClassValues(dataset, rows).ToList();
Note: See TracChangeset for help on using the changeset viewer.