#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Random;

namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Calculates the impact of input variables on a regression model. The impact of a variable is the
  /// difference between the quality of the original estimates and the quality of the estimates obtained
  /// after the values of that variable have been replaced (quality is computed with
  /// <see cref="OnlinePearsonsRSquaredCalculator"/>).
  /// </summary>
  [StorableClass]
  [Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
  public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
    /// <summary>Replacement strategies for numerical (double) variables.</summary>
    public enum ReplacementMethodEnum {
      Median,
      Average,
      Shuffle,
      Noise
    }

    /// <summary>Replacement strategies for factor (string) variables.</summary>
    public enum FactorReplacementMethodEnum {
      Best,
      Mode,
      Shuffle
    }

    /// <summary>Data partition whose rows are used for the impact calculation.</summary>
    public enum DataPartitionEnum {
      Training,
      Test,
      All
    }

    private const string ReplacementParameterName = "Replacement Method";
    private const string FactorReplacementParameterName = "Factor Replacement Method";
    private const string DataPartitionParameterName = "DataPartition";

    public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
      get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
    }
    public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
      get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
    }

    public ReplacementMethodEnum ReplacementMethod {
      get { return ReplacementParameter.Value.Value; }
      set { ReplacementParameter.Value.Value = value; }
    }
    public FactorReplacementMethodEnum FactorReplacementMethod {
      get { return FactorReplacementParameter.Value.Value; }
      set { FactorReplacementParameter.Value.Value = value; }
    }
    public DataPartitionEnum DataPartition {
      get { return DataPartitionParameter.Value.Value; }
      set { DataPartitionParameter.Value.Value = value; }
    }

    [StorableConstructor]
    private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
    private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
      : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionSolutionVariableImpactsCalculator(this, cloner);
    }

    /// <summary>
    /// Creates a calculator with the defaults Median (numerical replacement), Best (factor replacement)
    /// and Training (data partition).
    /// </summary>
    public RegressionSolutionVariableImpactsCalculator()
      : base() {
      Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
      Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
      Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
    }

    //mkommend: annoying name clash with static method, open to better naming suggestions
    /// <summary>
    /// Calculates the variable impacts for <paramref name="solution"/> using the replacement methods and
    /// data partition configured in the parameters of this instance.
    /// </summary>
    public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
      return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
    }

    /// <summary>
    /// Calculates the variable impacts for the model, problem data and estimated values of
    /// <paramref name="solution"/> on the given data partition.
    /// </summary>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionSolution solution,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training) {
      return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, replacementMethod, factorReplacementMethod, data);
    }

    /// <summary>
    /// Calculates the variable impacts on the rows of the given data partition.
    /// </summary>
    /// <exception cref="NotSupportedException">Thrown for an unknown <paramref name="data"/> value.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training) {
      IEnumerable<int> rows = GetPartitionRows(problemData, data);
      return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of the single variable <paramref name="variableName"/> on the rows of the
    /// given data partition.
    /// </summary>
    /// <param name="estimatedValues">Estimated values that are indexed by row index.</param>
    /// <returns>Quality of the original estimates minus quality of the estimates with replaced values.</returns>
    /// <exception cref="NotSupportedException">Thrown for an unknown partition or a variable that is neither double nor string.</exception>
    /// <exception cref="InvalidOperationException">Thrown when the quality calculation reports an error.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      DataPartitionEnum dataPartition,
      ReplacementMethodEnum replMethod,
      FactorReplacementMethodEnum factorReplMethod) {
      // materialize once; the values are enumerated repeatedly further down
      var rows = GetPartitionRows(problemData, dataPartition).ToList();
      var targetValuesPartition = SelectRows(problemData.TargetVariableValues, rows);
      var estimatedValuesPartition = SelectRows(estimatedValues, rows);
      var originalCalculatorValue = CalculateOriginalQuality(targetValuesPartition, estimatedValuesPartition);

      return CalculateImpact(variableName, model, problemData.Dataset, rows, targetValuesPartition, originalCalculatorValue, replMethod, factorReplMethod);
    }

    /// <summary>
    /// Calculates the impacts of all variables that are allowed inputs of <paramref name="problemData"/>
    /// or used by <paramref name="model"/> for prediction, evaluated on <paramref name="rows"/>.
    /// </summary>
    /// <param name="estimatedValues">Estimated values that are indexed by row index.</param>
    /// <returns>Tuples of variable name and impact, ordered by descending impact.</returns>
    /// <exception cref="InvalidOperationException">Thrown when the quality calculation reports an error.</exception>
    public static IEnumerable<Tuple<string, double>> CalculateImpacts(
      IRegressionModel model,
      IRegressionProblemData problemData,
      IEnumerable<double> estimatedValues,
      IEnumerable<int> rows,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      // materialize once; otherwise rows and target values are re-evaluated for every input variable
      var rowList = rows.ToList();

      //Calculate original quality-values (via calculator, default is R²)
      var targetValuesPartition = SelectRows(problemData.TargetVariableValues, rowList);
      var estimatedValuesPartition = SelectRows(estimatedValues, rowList);
      var originalCalculatorValue = CalculateOriginalQuality(targetValuesPartition, estimatedValuesPartition);

      var impacts = new Dictionary<string, double>();
      var inputvariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
      var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputvariables.Contains(v)).ToList();

      if (allowedInputVariables.Count > 0) {
        // a single modifiable copy is shared by all variables; every replacement is reverted after evaluation
        var modifiableDataset = CreateModifiableCopy(problemData.Dataset);
        foreach (var inputVariable in allowedInputVariables) {
          impacts[inputVariable] = CalculateImpactOnCopy(inputVariable, model, problemData.Dataset, modifiableDataset, rowList, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
        }
      }

      return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
    }

    /// <summary>
    /// Calculates the impact of a single variable for the model and dataset of <paramref name="solution"/>.
    /// </summary>
    /// <param name="data">Not used by this overload; the evaluated rows are given by <paramref name="rows"/>.</param>
    public static double CalculateImpact(string variableName,
      IRegressionSolution solution,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
      DataPartitionEnum data = DataPartitionEnum.Training) {
      return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
    }

    /// <summary>
    /// Calculates the impact of a single variable. The replacement is applied to a modifiable copy of
    /// <paramref name="dataset"/>, the passed dataset itself is not changed.
    /// </summary>
    /// <param name="targetValues">Target values of <paramref name="rows"/>, in the same order.</param>
    /// <param name="originalValue">Quality of the estimates without any replacement.</param>
    /// <returns><paramref name="originalValue"/> minus the quality with replaced values.</returns>
    /// <exception cref="NotSupportedException">Thrown when the variable is neither double nor string.</exception>
    /// <exception cref="InvalidOperationException">Thrown when the quality calculation reports an error.</exception>
    public static double CalculateImpact(string variableName,
      IRegressionModel model,
      IDataset dataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
      FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
      var modifiableDataset = CreateModifiableCopy(dataset);
      // rows and target values are enumerated several times below, materialize them once
      return CalculateImpactOnCopy(variableName, model, dataset, modifiableDataset, rows.ToList(), targetValues.ToList(), originalValue, replacementMethod, factorReplacementMethod);
    }

    // Maps a data partition to the corresponding row indices of the problem data.
    private static IEnumerable<int> GetPartitionRows(IRegressionProblemData problemData, DataPartitionEnum dataPartition) {
      switch (dataPartition) {
        case DataPartitionEnum.All:
          return problemData.AllIndices;
        case DataPartitionEnum.Test:
          return problemData.TestIndices;
        case DataPartitionEnum.Training:
          return problemData.TrainingIndices;
        default:
          throw new NotSupportedException("DataPartition not supported");
      }
    }

    // Picks the values at the given row indices; the source is enumerated exactly once.
    private static List<double> SelectRows(IEnumerable<double> values, IEnumerable<int> rows) {
      var allValues = values.ToList();
      return rows.Select(r => allValues[r]).ToList();
    }

    // Quality of the unmodified estimates; throws if the calculator reports an error.
    private static double CalculateOriginalQuality(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues) {
      OnlineCalculatorError error;
      var quality = CalculateVariableImpact(targetValues, estimatedValues, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
      return quality;
    }

    // Creates the modifiable working copy on which variable values are replaced.
    private static ModifiableDataset CreateModifiableCopy(IDataset dataset) {
      return ((Dataset)(dataset).Clone()).ToModifiable();
    }

    // Dispatches to the numerical or factor implementation depending on the type of the variable.
    // modifiableDataset is changed temporarily and holds the original values again on return.
    private static double CalculateImpactOnCopy(string variableName,
      IRegressionModel model,
      IDataset dataset,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod,
      FactorReplacementMethodEnum factorReplacementMethod) {
      // calculate impacts for double variables
      if (dataset.VariableHasType<double>(variableName)) {
        return CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
      }
      // calculate impacts for factor variables
      if (dataset.VariableHasType<string>(variableName)) {
        return CalculateImpactForFactorVariables(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
      }
      throw new NotSupportedException("Variable not supported");
    }

    // Impact of a double variable: original quality minus quality with replaced values.
    private static double CalculateImpactForNumericalVariables(string variableName,
      IRegressionModel model,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      ReplacementMethodEnum replacementMethod) {
      OnlineCalculatorError error;
      var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod);
      var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
      return originalValue - newValue;
    }

    // Impact of a string variable. For Best every distinct value occurring in rows is tried as constant
    // replacement and the smallest resulting impact is returned.
    private static double CalculateImpactForFactorVariables(string variableName,
      IRegressionModel model,
      IDataset dataset,
      ModifiableDataset modifiableDataset,
      IEnumerable<int> rows,
      IEnumerable<double> targetValues,
      double originalValue,
      FactorReplacementMethodEnum factorReplacementMethod) {
      OnlineCalculatorError error;
      if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
        // try replacing with all possible values and find the best replacement value
        var smallestImpact = double.PositiveInfinity;
        // the column is restored after every evaluation, so one copy of the original values is sufficient
        var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
        foreach (var repl in dataset.GetStringValues(variableName, rows).Distinct()) {
          var newEstimates = GetReplacedValues(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, dataset.Rows).ToList());
          var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
          if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

          var curImpact = originalValue - newValue;
          if (curImpact < smallestImpact) smallestImpact = curImpact;
        }
        return smallestImpact;
      } else {
        // for replacement methods shuffle and mode
        // calculate impacts for factor variables
        var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod);
        var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");

        return originalValue - newValue;
      }
    }

    // Builds the replacement column for a double variable and returns the estimates of the model on rows.
    // The random generator uses a fixed seed (presumably to obtain reproducible impacts).
    private static IEnumerable<double> GetReplacedValuesForNumericalVariables(
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
      double replacementValue;
      List<double> replacementValues;
      IRandom rand;

      switch (replacement) {
        case ReplacementMethodEnum.Median:
          replacementValue = rows.Select(r => originalValues[r]).Median();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Average:
          replacementValue = rows.Select(r => originalValues[r]).Average();
          replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
          break;
        case ReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        case ReplacementMethodEnum.Noise:
          var avg = rows.Select(r => originalValues[r]).Average();
          var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
          }
          break;

        default:
          throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
      }

      return GetReplacedValues(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Builds the replacement column for a string variable and returns the estimates of the model on rows.
    // The random generator uses a fixed seed (presumably to obtain reproducible impacts).
    private static IEnumerable<double> GetReplacedValuesForFactorVariables(
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
      var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
      List<string> replacementValues;
      IRandom rand;

      switch (replacement) {
        case FactorReplacementMethodEnum.Mode:
          var mostCommonValue = rows.Select(r => originalValues[r])
            .GroupBy(v => v)
            .OrderByDescending(g => g.Count())
            .First().Key;
          replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
          break;
        case FactorReplacementMethodEnum.Shuffle:
          // new var has same empirical distribution but the relation to y is broken
          rand = new FastRandom(31415);
          // prepare a complete column for the dataset
          replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
          // shuffle only the selected rows
          var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
          int i = 0;
          // update column values
          foreach (var r in rows) {
            replacementValues[r] = shuffledValues[i++];
          }
          break;
        default:
          throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
      }

      return GetReplacedValues(originalValues, model, variable, dataset, rows, replacementValues);
    }

    // Temporarily replaces the column of the variable, evaluates the model on rows and restores the
    // original column afterwards (also when the evaluation throws).
    private static IEnumerable<double> GetReplacedValues(
      IList originalValues,
      IRegressionModel model,
      string variable,
      ModifiableDataset dataset,
      IEnumerable<int> rows,
      IList replacementValues) {
      dataset.ReplaceVariable(variable, replacementValues);
      try {
        //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
        return model.GetEstimatedValues(dataset, rows).ToList();
      } finally {
        dataset.ReplaceVariable(variable, originalValues);
      }
    }

    // Feeds pairs of original and estimated values into an OnlinePearsonsRSquaredCalculator and returns
    // its value; the error state of the calculator is reported through errorState.
    // Throws an ArgumentException if the calculation succeeded but the sequences differ in length.
    private static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
      var calculator = new OnlinePearsonsRSquaredCalculator();
      bool hasOriginal = false;
      bool hasEstimated = false;

      using (IEnumerator<double> firstEnumerator = originalValues.GetEnumerator())
      using (IEnumerator<double> secondEnumerator = estimatedValues.GetEnumerator()) {
        while (true) {
          // always move forward both enumerators and remember the results; calling MoveNext again after
          // the loop would not detect a sequence that is longer by exactly one element
          hasOriginal = firstEnumerator.MoveNext();
          hasEstimated = secondEnumerator.MoveNext();
          if (!hasOriginal || !hasEstimated) break;

          double original = firstEnumerator.Current;
          double estimated = secondEnumerator.Current;
          calculator.Add(original, estimated);
          if (calculator.ErrorState != OnlineCalculatorError.None) break;
        }
      }

      // check if both enumerators are at the end to make sure both enumerations have the same length
      if (calculator.ErrorState == OnlineCalculatorError.None &&
          (hasOriginal || hasEstimated)) {
        throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
      } else {
        errorState = calculator.ErrorState;
        return calculator.Value;
      }
    }
  }
}