#region License Information
/* HeuristicLab
* Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Random;
namespace HeuristicLab.Problems.DataAnalysis {
[StorableClass]
[Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")]
public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem {
  #region Parameters/Properties
  /// <summary>
  /// Strategies for replacing the values of a numeric (double) input variable during impact calculation.
  /// </summary>
  public enum ReplacementMethodEnum {
    /// <summary>Replace every value with the median of the variable over the selected rows.</summary>
    Median,
    /// <summary>Replace every value with the mean of the variable over the selected rows.</summary>
    Average,
    /// <summary>Randomly permute the values of the selected rows (fixed seed).</summary>
    Shuffle,
    /// <summary>Draw normally distributed values with the mean and standard deviation of the selected rows (fixed seed).</summary>
    Noise
  }
  /// <summary>
  /// Strategies for replacing the values of a factor (string) input variable during impact calculation.
  /// </summary>
  public enum FactorReplacementMethodEnum {
    /// <summary>Try every value observed in the selected rows and report the smallest resulting impact.</summary>
    Best,
    /// <summary>Replace every value with the most frequent value of the selected rows.</summary>
    Mode,
    /// <summary>Randomly permute the values of the selected rows (fixed seed).</summary>
    Shuffle
  }
  /// <summary>
  /// The rows of the problem data on which impacts are calculated.
  /// </summary>
  public enum DataPartitionEnum {
    Training,
    Test,
    All
  }

  private const string ReplacementParameterName = "Replacement Method";
  private const string FactorReplacementParameterName = "Factor Replacement Method";
  private const string DataPartitionParameterName = "DataPartition";

  public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
    get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
  }
  public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
    get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
  }
  public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
    get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
  }

  public ReplacementMethodEnum ReplacementMethod {
    get { return ReplacementParameter.Value.Value; }
    set { ReplacementParameter.Value.Value = value; }
  }
  public FactorReplacementMethodEnum FactorReplacementMethod {
    get { return FactorReplacementParameter.Value.Value; }
    set { FactorReplacementParameter.Value.Value = value; }
  }
  public DataPartitionEnum DataPartition {
    get { return DataPartitionParameter.Value.Value; }
    set { DataPartitionParameter.Value.Value = value; }
  }
  #endregion

  #region Ctor/Cloner
  [StorableConstructor]
  private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
  private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner)
    : base(original, cloner) { }
  public ClassificationSolutionVariableImpactsCalculator()
    : base() {
    Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
    Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
    Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new ClassificationSolutionVariableImpactsCalculator(this, cloner);
  }
  #endregion

  //mkommend: annoying name clash with static method, open to better naming suggestions
  /// <summary>
  /// Calculates the variable impacts of <paramref name="solution"/> using the replacement methods and the
  /// data partition configured in this item's parameters.
  /// </summary>
  public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) {
    return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
  }

  /// <summary>
  /// Calculates the impact of every input variable of <paramref name="solution"/> on the given data partition.
  /// The solution's model is not modified.
  /// </summary>
  /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
  public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    IClassificationSolution solution,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
    return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition);
  }

  /// <summary>
  /// Calculates the impact of every input variable of <paramref name="model"/> on the given data partition.
  /// </summary>
  /// <param name="estimatedValues">Estimated class values of the model, indexed by dataset row.</param>
  /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
  public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    IClassificationModel model,
    IClassificationProblemData problemData,
    IEnumerable<double> estimatedValues,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
    IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData);
    return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
  }

  /// <summary>
  /// Calculates the impact of every input variable on <paramref name="rows"/>. The impact of a variable is the
  /// original accuracy minus the accuracy obtained after replacing the variable's values. Variables that are
  /// allowed inputs but not used by the model get impact 0.
  /// </summary>
  /// <param name="estimatedClassValues">Estimated class values of the model, indexed by dataset row.</param>
  /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
  /// <exception cref="InvalidOperationException">The original accuracy cannot be calculated.</exception>
  public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    IClassificationModel model,
    IClassificationProblemData problemData,
    IEnumerable<double> estimatedClassValues,
    IEnumerable<int> rows,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    // Materialize the inputs once: rows and targets are re-enumerated for every variable (and every candidate
    // value of FactorReplacementMethodEnum.Best), and estimatedClassValues is accessed by row index. ElementAt on
    // a lazy sequence would be O(n) per access.
    var rowList = rows.ToList();
    var targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rowList).ToList();
    var estimatedValuesList = estimatedClassValues as IList<double> ?? estimatedClassValues.ToList();
    var estimatedValuesPartition = rowList.Select(r => estimatedValuesList[r]).ToList();

    //Calculate original quality-values (via calculator, default is Accuracy)
    OnlineCalculatorError error;
    var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
    if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");

    // Thresholds of discriminant function models are recalculated on the modified data; use a clone so the
    // caller's model is not left with thresholds fitted to replaced inputs.
    var impactModel = CloneIfDiscriminantFunctionModel(model);

    var variablesUsedForPrediction = new HashSet<string>(impactModel.VariablesUsedForPrediction);
    var inputVariables = new HashSet<string>(problemData.AllowedInputVariables.Union(variablesUsedForPrediction));
    // report variables in the order in which they appear in the dataset
    var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();
    var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();

    var impacts = new Dictionary<string, double>();
    foreach (var inputVariable in allowedInputVariables) {
      if (variablesUsedForPrediction.Contains(inputVariable)) {
        impacts[inputVariable] = CalculateImpactCore(inputVariable, impactModel, modifiableDataset, rowList, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
      } else {
        impacts[inputVariable] = 0;
      }
    }

    return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
  }

  /// <summary>
  /// Calculates the impact of a single variable: <paramref name="originalValue"/> minus the accuracy obtained on
  /// <paramref name="rows"/> after replacing the variable's values in <paramref name="modifiableDataset"/>.
  /// The variable is restored in the dataset afterwards, and <paramref name="model"/> is not modified.
  /// </summary>
  /// <exception cref="ArgumentException">The replacement method is unknown.</exception>
  /// <exception cref="InvalidOperationException">The accuracy with replaced inputs cannot be calculated.</exception>
  /// <exception cref="NotSupportedException">The variable is neither a double nor a string variable.</exception>
  public static double CalculateImpact(string variableName,
    IClassificationModel model,
    ModifiableDataset modifiableDataset,
    IEnumerable<int> rows,
    IEnumerable<double> targetValues,
    double originalValue,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    return CalculateImpactCore(variableName, CloneIfDiscriminantFunctionModel(model), modifiableDataset,
      rows.ToList(), targetValues.ToList(), originalValue, replacementMethod, factorReplacementMethod);
  }

  /// <summary>
  /// Implementation of <see cref="CalculateImpact"/>. It works directly on <paramref name="model"/>, whose
  /// parameters may be recalculated if it is a discriminant function model.
  /// </summary>
  private static double CalculateImpactCore(string variableName,
    IClassificationModel model,
    ModifiableDataset modifiableDataset,
    IList<int> rows,
    IList<double> targetValues,
    double originalValue,
    ReplacementMethodEnum replacementMethod,
    FactorReplacementMethodEnum factorReplacementMethod) {
    if (modifiableDataset.VariableHasType<double>(variableName)) {
      #region NumericalVariable
      var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
      var replacementValues = GetNumericReplacementValues(originalValues, rows, modifiableDataset.Rows, replacementMethod);
      return CalculateReplacementImpact(originalValues, model, variableName, modifiableDataset, rows, replacementValues, targetValues, originalValue);
      #endregion
    }

    if (modifiableDataset.VariableHasType<string>(variableName)) {
      #region FactorVariable
      var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
      switch (factorReplacementMethod) {
        case FactorReplacementMethodEnum.Best: {
            // try replacing with all values observed in the selected rows and keep the smallest impact;
            // candidates come from the snapshot, not from the dataset that is modified inside the loop
            var candidates = rows.Select(r => originalValues[r]).Distinct().ToList();
            var smallestImpact = double.PositiveInfinity;
            foreach (var candidate in candidates) {
              var replacementValues = Enumerable.Repeat(candidate, modifiableDataset.Rows).ToList();
              var curImpact = CalculateReplacementImpact(originalValues, model, variableName, modifiableDataset, rows, replacementValues, targetValues, originalValue);
              if (curImpact < smallestImpact) smallestImpact = curImpact;
            }
            return smallestImpact;
          }
        case FactorReplacementMethodEnum.Mode: {
            // OrderByDescending is stable, so ties go to the value that occurs first in the selected rows
            var mostCommonValue = rows.Select(r => originalValues[r])
              .GroupBy(v => v)
              .OrderByDescending(g => g.Count())
              .First().Key;
            var replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();
            return CalculateReplacementImpact(originalValues, model, variableName, modifiableDataset, rows, replacementValues, targetValues, originalValue);
          }
        case FactorReplacementMethodEnum.Shuffle: {
            // new var has same empirical distribution but the relation to y is broken
            IRandom random = new FastRandom(31415);
            // prepare a complete column for the dataset; rows outside the selection are never evaluated
            var replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
            // shuffle only the selected rows
            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
            for (int i = 0; i < rows.Count; i++) replacementValues[rows[i]] = shuffledValues[i];
            return CalculateReplacementImpact(originalValues, model, variableName, modifiableDataset, rows, replacementValues, targetValues, originalValue);
          }
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
      }
      #endregion
    }

    throw new NotSupportedException("Variable not supported");
  }

  /// <summary>
  /// Builds a complete replacement column for a numeric variable. Only the entries of <paramref name="rows"/>
  /// are meaningful for the Shuffle and Noise methods; all other entries are NaN.
  /// </summary>
  private static List<double> GetNumericReplacementValues(IList<double> originalValues, IList<int> rows, int datasetRows, ReplacementMethodEnum replacementMethod) {
    switch (replacementMethod) {
      case ReplacementMethodEnum.Median: {
          var median = rows.Select(r => originalValues[r]).Median();
          return Enumerable.Repeat(median, datasetRows).ToList();
        }
      case ReplacementMethodEnum.Average: {
          var average = rows.Select(r => originalValues[r]).Average();
          return Enumerable.Repeat(average, datasetRows).ToList();
        }
      case ReplacementMethodEnum.Shuffle: {
          // new var has same empirical distribution but the relation to y is broken
          IRandom random = new FastRandom(31415);
          // prepare a complete column for the dataset
          var replacementValues = Enumerable.Repeat(double.NaN, datasetRows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
          for (int i = 0; i < rows.Count; i++) replacementValues[rows[i]] = shuffledValues[i];
          return replacementValues;
        }
      case ReplacementMethodEnum.Noise: {
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          IRandom random = new FastRandom(31415);
          // prepare a complete column for the dataset
          var replacementValues = Enumerable.Repeat(double.NaN, datasetRows).ToList();
          foreach (var r in rows) replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
          return replacementValues;
        }
      default:
        throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
    }
  }

  /// <summary>
  /// Evaluates the model with <paramref name="variableName"/> replaced by <paramref name="replacementValues"/>
  /// and returns <paramref name="originalValue"/> minus the resulting accuracy.
  /// </summary>
  /// <exception cref="InvalidOperationException">The accuracy with replaced inputs cannot be calculated.</exception>
  private static double CalculateReplacementImpact(
    IList originalValues,
    IClassificationModel model,
    string variableName,
    ModifiableDataset modifiableDataset,
    IEnumerable<int> rows,
    IList replacementValues,
    IEnumerable<double> targetValues,
    double originalValue) {
    var newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
    OnlineCalculatorError error;
    var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
    return originalValue - newValue;
  }

  /// <summary>
  /// Returns a deep clone of <paramref name="model"/> if it is an <see cref="IDiscriminantFunctionClassificationModel"/>
  /// (whose parameters are recalculated during impact calculation), otherwise the model itself.
  /// </summary>
  private static IClassificationModel CloneIfDiscriminantFunctionModel(IClassificationModel model) {
    if (model is IDiscriminantFunctionClassificationModel) return (IClassificationModel)model.Clone();
    return model;
  }

  /// <summary>
  /// Replaces the values of <paramref name="variableName"/> in <paramref name="modifiableDataset"/> with
  /// <paramref name="replacementValues"/>, calculates the estimated class values for <paramref name="rows"/> and
  /// restores <paramref name="originalValues"/>, even if estimation fails. For discriminant function models, the
  /// model parameters are recalculated on the replaced data before estimating; this modifies <paramref name="model"/>.
  /// </summary>
  /// <param name="originalValues">Complete column used to restore the variable afterwards.</param>
  /// <param name="replacementValues">Complete replacement column (one entry per dataset row).</param>
  /// <returns>The materialized estimated class values for <paramref name="rows"/>.</returns>
  private static IEnumerable<double> GetReplacedEstimates(
    IList originalValues,
    IClassificationModel model,
    string variableName,
    ModifiableDataset modifiableDataset,
    IEnumerable<int> rows,
    IList replacementValues) {
    modifiableDataset.ReplaceVariable(variableName, replacementValues);
    try {
      var discModel = model as IDiscriminantFunctionClassificationModel;
      if (discModel != null) {
        var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
        discModel.RecalculateModelParameters(problemData, rows);
      }

      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
      return model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
    } finally {
      modifiableDataset.ReplaceVariable(variableName, originalValues);
    }
  }

  /// <summary>
  /// Calculates the quality (accuracy, via <see cref="OnlineAccuracyCalculator"/>) of
  /// <paramref name="estimatedValues"/> with respect to <paramref name="targetValues"/>.
  /// </summary>
  /// <param name="targetValues">The actual values.</param>
  /// <param name="estimatedValues">The calculated/replaced values.</param>
  /// <param name="errorState">The error state of the calculator after processing the values.</param>
  /// <returns>The calculator's value.</returns>
  /// <exception cref="ArgumentException">Both sequences have different lengths.</exception>
  public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
    //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
    //as the code below does. But this way we can easily swap the calculator later on, so the user
    //could choose a Calculator during runtime in future versions.
    IOnlineCalculator calculator = new OnlineAccuracyCalculator();
    using (IEnumerator<double> firstEnumerator = targetValues.GetEnumerator())
    using (IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator()) {
      // always move forward both enumerators (do not use short-circuit evaluation!)
      while (firstEnumerator.MoveNext() & secondEnumerator.MoveNext()) {
        calculator.Add(firstEnumerator.Current, secondEnumerator.Current);
        if (calculator.ErrorState != OnlineCalculatorError.None) break;
      }

      // check if both enumerators are at the end to make sure both enumerations have the same length
      if (calculator.ErrorState == OnlineCalculatorError.None &&
          (secondEnumerator.MoveNext() || firstEnumerator.MoveNext())) {
        throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
      }
    }

    errorState = calculator.ErrorState;
    return calculator.Value;
  }

  /// <summary>
  /// Returns the row indices of the given data partition (training, test or all rows) of <paramref name="problemData"/>.
  /// </summary>
  /// <exception cref="NotSupportedException">The partition is unknown.</exception>
  public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) {
    IEnumerable<int> rows;
    switch (dataPartition) {
      case DataPartitionEnum.All:
        rows = problemData.AllIndices;
        break;
      case DataPartitionEnum.Test:
        rows = problemData.TestIndices;
        break;
      case DataPartitionEnum.Training:
        rows = problemData.TrainingIndices;
        break;
      default:
        throw new NotSupportedException("DataPartition not supported");
    }
    return rows;
  }
}
}