#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Random;

namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of each input variable of a classification solution by replacing the
  /// variable's values (e.g. with median/average/shuffled/noisy values for numeric variables, or
  /// best/mode/shuffled values for factor variables) and measuring the resulting drop in accuracy.
  /// </summary>
  [StorableClass]
  [Item("ClassificationSolution Impacts Calculator", "Calculation of the impacts of input variables for any classification solution")]
  public sealed class ClassificationSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    #region Parameters/Properties
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string FactorReplacementParameterName = "Factor Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    public FactorReplacementMethodEnum FactorReplacementMethod {
      get { return FactorReplacementParameter.Value.Value; }
      set { FactorReplacementParameter.Value.Value = value; }
    }
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }
    #endregion

    #region Ctor/Cloner
    [StorableConstructor]
    private ClassificationSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private ClassificationSolutionVariableImpactsCalculator(ClassificationSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public ClassificationSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ClassificationSolutionVariableImpactsCalculator(this, cloner);
    }
    #endregion

    /// <summary>
    /// Calculates the variable impacts of <paramref name="solution"/> using the replacement methods and
    /// data partition configured in this item's parameters.
    /// </summary>
    //mkommend: annoying name clash with static method, open to better naming suggestions
    public IEnumerable<Tuple<string, double>> Calculate(IClassificationSolution solution) {
      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of a classification solution on the given data partition.
    /// </summary>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IClassificationSolution solution,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedClassValues, replacementMethod, factorReplacementMethod, dataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts of a classification model on the given data partition.
    /// </summary>
    /// <param name="estimatedValues">Estimated class values of the model, indexed by dataset row.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IClassificationModel model,
      IClassificationProblemData problemData,
      IEnumerable<double> estimatedValues,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum dataPartition = DataPartitionEnum.Training) {
      IEnumerable<int> rows = GetPartitionRows(dataPartition, problemData);
      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the variable impacts of a classification model on the given rows.
    /// Variables that are allowed inputs but not used by the model get an impact of 0.
    /// </summary>
    /// <param name="model">The model whose input variables are evaluated.</param>
    /// <param name="problemData">Problem data providing the dataset and target variable.</param>
    /// <param name="estimatedClassValues">Estimated class values of the model, indexed by dataset row
    /// (element <c>r</c> is the estimate for row <c>r</c>).</param>
    /// <param name="rows">The row indices on which the impacts are calculated.</param>
    /// <param name="replacementMethod">Replacement method for numeric (double) variables.</param>
    /// <param name="factorReplacementMethod">Replacement method for factor (string) variables.</param>
    /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
    /// <exception cref="InvalidOperationException">The accuracy of the original estimates could not be calculated.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IClassificationModel model,
      IClassificationProblemData problemData,
      IEnumerable<double> estimatedClassValues,
      IEnumerable<int> rows,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // The row indices and target values are enumerated again for every input variable; materialize them once.
      rows = rows.ToList();

      //Calculate original quality-values (via calculator, default is Accuracy)
      OnlineCalculatorError error;
      IEnumerable<double> targetValuesPartition = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToList();
      // Materialize the estimates once: ElementAt on a lazy sequence restarts the enumeration for every row,
      // which makes the lookup quadratic and re-evaluates the source each time.
      var allEstimatedValues = estimatedClassValues.ToList();
      IEnumerable<double> estimatedValuesPartition = rows.Select(r => allEstimatedValues[r]);
      var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");

      var impacts = new Dictionary<string, double>();
      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      // keep the dataset's variable order; variables missing from the dataset are skipped
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();

      // work on a modifiable copy so the original dataset is never touched
      var modifiableDataset = ((Dataset)(problemData.Dataset).Clone()).ToModifiable();
      foreach (var inputVariable in allowedInputVariables) {
        if (model.VariablesUsedForPrediction.Contains(inputVariable)) {
          impacts[inputVariable] = CalculateImpact(inputVariable, model, modifiableDataset, rows, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
        } else {
          impacts[inputVariable] = 0;
        }
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable as the original accuracy minus the accuracy obtained
    /// after replacing the variable's values on <paramref name="rows"/>.
    /// </summary>
    /// <param name="variableName">The variable to replace; must be a double or string variable of the dataset.</param>
    /// <param name="model">The model to evaluate.</param>
    /// <param name="modifiableDataset">Dataset in which the variable is temporarily replaced; the original column is restored afterwards.</param>
    /// <param name="rows">The row indices on which the model is evaluated.</param>
    /// <param name="targetValues">Target values for <paramref name="rows"/>.</param>
    /// <param name="originalValue">Accuracy of the model with the original (unreplaced) inputs.</param>
    /// <param name="replacementMethod">Replacement method used if the variable is numeric.</param>
    /// <param name="factorReplacementMethod">Replacement method used if the variable is a factor.</param>
    /// <returns>The impact (original accuracy minus replaced accuracy).</returns>
    /// <exception cref="InvalidOperationException">The accuracy with replaced inputs could not be calculated.</exception>
    /// <exception cref="ArgumentException">The replacement method is unknown.</exception>
    /// <exception cref="NotSupportedException">The variable is neither a double nor a string variable.</exception>
    public static double CalculateImpact(string variableName, IClassificationModel model, ModifiableDataset modifiableDataset, IEnumerable<int> rows, IEnumerable<double> targetValues, double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle, FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {

      double impact = 0;
      OnlineCalculatorError error;
      IRandom random;
      double replacementValue;
      IEnumerable<double> newEstimates = null;
      double newValue = 0;

      if (modifiableDataset.VariableHasType<double>(variableName)) {
        #region NumericalVariable
        var originalValues = modifiableDataset.GetReadOnlyDoubleValues(variableName).ToList();
        List<double> replacementValues;

        switch (replacementMethod) {
          case ReplacementMethodEnum.Median:
            replacementValue = rows.Select(r => originalValues[r]).Median();
            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
            break;
          case ReplacementMethodEnum.Average:
            replacementValue = rows.Select(r => originalValues[r]).Average();
            replacementValues = Enumerable.Repeat(replacementValue, modifiableDataset.Rows).ToList();
            break;
          case ReplacementMethodEnum.Shuffle:
            // new var has same empirical distribution but the relation to y is broken
            random = new FastRandom(31415);
            // prepare a complete column for the dataset
            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
            // shuffle only the selected rows
            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
            int i = 0;
            // update column values
            foreach (var r in rows) {
              replacementValues[r] = shuffledValues[i++];
            }
            break;
          case ReplacementMethodEnum.Noise:
            var avg = rows.Select(r => originalValues[r]).Average();
            var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
            random = new FastRandom(31415);
            // prepare a complete column for the dataset
            replacementValues = Enumerable.Repeat(double.NaN, modifiableDataset.Rows).ToList();
            // update column values
            foreach (var r in rows) {
              replacementValues[r] = NormalDistributedRandom.NextDouble(random, avg, stdDev);
            }
            break;

          default:
            throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacementMethod));
        }

        newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
        newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }

        impact = originalValue - newValue;
        #endregion
      } else if (modifiableDataset.VariableHasType<string>(variableName)) {
        #region FactorVariable
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        List<string> replacementValues;

        switch (factorReplacementMethod) {
          case FactorReplacementMethodEnum.Best:
            // try replacing with all possible values and find the best replacement value,
            // i.e. the one causing the smallest loss of accuracy
            var smallestImpact = double.PositiveInfinity;
            foreach (var repl in modifiableDataset.GetStringValues(variableName, rows).Distinct()) {
              newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows,
                Enumerable.Repeat(repl, modifiableDataset.Rows).ToList());
              newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
              if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

              var curImpact = originalValue - newValue;
              if (curImpact < smallestImpact) smallestImpact = curImpact;
            }
            impact = smallestImpact;
            break;
          case FactorReplacementMethodEnum.Mode:
            var mostCommonValue = rows.Select(r => originalValues[r])
              .GroupBy(v => v)
              .OrderByDescending(g => g.Count())
              .First().Key;
            replacementValues = Enumerable.Repeat(mostCommonValue, modifiableDataset.Rows).ToList();

            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

            impact = originalValue - newValue;
            break;
          case FactorReplacementMethodEnum.Shuffle:
            // new var has same empirical distribution but the relation to y is broken
            random = new FastRandom(31415);
            // prepare a complete column for the dataset
            replacementValues = Enumerable.Repeat(string.Empty, modifiableDataset.Rows).ToList();
            // shuffle only the selected rows
            var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(random).ToList();
            int i = 0;
            // update column values
            foreach (var r in rows) {
              replacementValues[r] = shuffledValues[i++];
            }

            newEstimates = GetReplacedEstimates(originalValues, model, variableName, modifiableDataset, rows, replacementValues);
            newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
            if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

            impact = originalValue - newValue;
            break;
          default:
            throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", factorReplacementMethod));
        }
        #endregion
      } else {
        throw new NotSupportedException("Variable not supported");
      }

      return impact;
    }

    /// <summary>
    /// Replaces the values of the original model-variables with the replacement variables, calculates the new estimated values
    /// and changes the value of the model-variables back to the original ones.
    /// </summary>
    /// <param name="originalValues">The original column of <paramref name="variableName"/>; written back afterwards.</param>
    /// <param name="model">The model used to estimate class values.</param>
    /// <param name="variableName">The variable whose column is replaced.</param>
    /// <param name="modifiableDataset">The dataset in which the column is temporarily replaced.</param>
    /// <param name="rows">The row indices to estimate.</param>
    /// <param name="replacementValues">A complete replacement column (one value per dataset row).</param>
    /// <returns>The (materialized) estimated class values for <paramref name="rows"/>.</returns>
    private static IEnumerable<double> GetReplacedEstimates(
      IList originalValues,
      IClassificationModel model,
      string variableName,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IList replacementValues) {
      modifiableDataset.ReplaceVariable(variableName, replacementValues);
      try {
        var discModel = model as IDiscriminantFunctionClassificationModel;
        if (discModel != null) {
          // NOTE(review): the model parameters are recalculated on the replaced data and are not reset
          // afterwards in this method; confirm whether callers expect the passed model to be modified.
          var problemData = new ClassificationProblemData(modifiableDataset, modifiableDataset.VariableNames, model.TargetVariable);
          discModel.RecalculateModelParameters(problemData, rows);
        }

        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedClassValues(modifiableDataset, rows).ToList();
      } finally {
        // always restore the original column, even if the model evaluation throws
        modifiableDataset.ReplaceVariable(variableName, originalValues);
      }
    }

    /// <summary>
    /// Calculates the accuracy of <paramref name="estimatedValues"/> with respect to <paramref name="targetValues"/>.
    /// Despite its name, the returned value is the quality (accuracy), not an impact; the impact is the
    /// difference between two such values.
    /// </summary>
    /// <param name="targetValues">The actual values</param>
    /// <param name="estimatedValues">The calculated/replaced values</param>
    /// <param name="errorState">The error state of the accuracy calculator.</param>
    /// <returns>The accuracy of the estimated values.</returns>
    /// <exception cref="ArgumentException">The two sequences have different lengths.</exception>
    public static double CalculateVariableImpact(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
      //Theoretically, all calculators implement a static Calculate-Method which provides the same functionality
      //as the code below does. But this way we can easily swap the calculator later on, so the user
      //could choose a Calculator during runtime in future versions.
      IOnlineCalculator calculator = new OnlineAccuracyCalculator();
      using (IEnumerator<double> firstEnumerator = targetValues.GetEnumerator())
      using (IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator()) {
        bool firstHasNext = false, secondHasNext = false;
        // always move forward both enumerators (do not use short-circuit evaluation!)
        while ((firstHasNext = firstEnumerator.MoveNext()) & (secondHasNext = secondEnumerator.MoveNext())) {
          double original = firstEnumerator.Current;
          double estimated = secondEnumerator.Current;
          calculator.Add(original, estimated);
          if (calculator.ErrorState != OnlineCalculatorError.None) break;
        }

        // When the loop ends normally, both enumerations have the same length only if both were exhausted
        // in the same step. Using the recorded results (instead of calling MoveNext again) also detects a
        // difference of exactly one element.
        if (calculator.ErrorState == OnlineCalculatorError.None && firstHasNext != secondHasNext) {
          throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
        }
      }

      errorState = calculator.ErrorState;
      return calculator.Value;
    }

    /// <summary>
    /// Returns a collection of the row-indices for a given DataPartition (training, test or all rows).
    /// </summary>
    /// <param name="dataPartition">The requested partition.</param>
    /// <param name="problemData">The problem data defining the partitions.</param>
    /// <returns>The row indices of the requested partition.</returns>
    /// <exception cref="NotSupportedException">The partition is unknown.</exception>
    public static IEnumerable<int> GetPartitionRows(DataPartitionEnum dataPartition, IClassificationProblemData problemData) {
      IEnumerable<int> rows;

      switch (dataPartition) {
        case DataPartitionEnum.All:
          rows = problemData.AllIndices;
          break;
        case DataPartitionEnum.Test:
          rows = problemData.TestIndices;
          break;
        case DataPartitionEnum.Training:
          rows = problemData.TrainingIndices;
          break;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }

      return rows;
    }
  }
}