#region License Information
/* HeuristicLab
* Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Random;
namespace HeuristicLab.Problems.DataAnalysis {
[StorableClass]
[Item("RegressionSolution Impacts Calculator", "Calculation of the impacts of input variables for any regression solution")]
public sealed class RegressionSolutionVariableImpactsCalculator : ParameterizedNamedItem {
  /// <summary>
  /// Strategies for replacing the values of a numerical (double) input variable over the selected rows.
  /// </summary>
  public enum ReplacementMethodEnum {
    /// <summary>Replace every value with the median of the selected rows.</summary>
    Median,
    /// <summary>Replace every value with the mean of the selected rows.</summary>
    Average,
    /// <summary>Randomly permute the values of the selected rows.</summary>
    Shuffle,
    /// <summary>Draw normally distributed values with the mean and standard deviation of the selected rows.</summary>
    Noise
  }

  /// <summary>
  /// Strategies for replacing the values of a factor (string) input variable over the selected rows.
  /// </summary>
  public enum FactorReplacementMethodEnum {
    /// <summary>Try every distinct value occurring in the selected rows and keep the smallest resulting impact.</summary>
    Best,
    /// <summary>Replace every value with the most frequent value of the selected rows.</summary>
    Mode,
    /// <summary>Randomly permute the values of the selected rows.</summary>
    Shuffle
  }

  /// <summary>
  /// Selects which rows of the problem data are used for the impact calculation.
  /// </summary>
  public enum DataPartitionEnum {
    Training,
    Test,
    All
  }

  private const string ReplacementParameterName = "Replacement Method";
  private const string FactorReplacementParameterName = "Factor Replacement Method";
  private const string DataPartitionParameterName = "DataPartition";

  // Fixed seed so that shuffle/noise replacements (and therefore impacts) are reproducible.
  private const int RandomSeed = 31415;

  public IFixedValueParameter<EnumValue<ReplacementMethodEnum>> ReplacementParameter {
    get { return (IFixedValueParameter<EnumValue<ReplacementMethodEnum>>)Parameters[ReplacementParameterName]; }
  }
  public IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>> FactorReplacementParameter {
    get { return (IFixedValueParameter<EnumValue<FactorReplacementMethodEnum>>)Parameters[FactorReplacementParameterName]; }
  }
  public IFixedValueParameter<EnumValue<DataPartitionEnum>> DataPartitionParameter {
    get { return (IFixedValueParameter<EnumValue<DataPartitionEnum>>)Parameters[DataPartitionParameterName]; }
  }

  public ReplacementMethodEnum ReplacementMethod {
    get { return ReplacementParameter.Value.Value; }
    set { ReplacementParameter.Value.Value = value; }
  }
  public FactorReplacementMethodEnum FactorReplacementMethod {
    get { return FactorReplacementParameter.Value.Value; }
    set { FactorReplacementParameter.Value.Value = value; }
  }
  public DataPartitionEnum DataPartition {
    get { return DataPartitionParameter.Value.Value; }
    set { DataPartitionParameter.Value.Value = value; }
  }

  [StorableConstructor]
  private RegressionSolutionVariableImpactsCalculator(bool deserializing) : base(deserializing) { }
  private RegressionSolutionVariableImpactsCalculator(RegressionSolutionVariableImpactsCalculator original, Cloner cloner)
    : base(original, cloner) { }
  public override IDeepCloneable Clone(Cloner cloner) {
    return new RegressionSolutionVariableImpactsCalculator(this, cloner);
  }

  /// <summary>
  /// Creates a calculator with Median replacement, Best factor replacement and the Training partition.
  /// </summary>
  public RegressionSolutionVariableImpactsCalculator()
    : base() {
    Parameters.Add(new FixedValueParameter<EnumValue<ReplacementMethodEnum>>(ReplacementParameterName, "The replacement method for variables during impact calculation.", new EnumValue<ReplacementMethodEnum>(ReplacementMethodEnum.Median)));
    Parameters.Add(new FixedValueParameter<EnumValue<FactorReplacementMethodEnum>>(FactorReplacementParameterName, "The replacement method for factor variables during impact calculation.", new EnumValue<FactorReplacementMethodEnum>(FactorReplacementMethodEnum.Best)));
    Parameters.Add(new FixedValueParameter<EnumValue<DataPartitionEnum>>(DataPartitionParameterName, "The data partition on which the impacts are calculated.", new EnumValue<DataPartitionEnum>(DataPartitionEnum.Training)));
  }

  //mkommend: annoying name clash with static method, open to better naming suggestions
  /// <summary>
  /// Calculates the impacts of all input variables of <paramref name="solution"/> using the
  /// replacement methods and data partition configured in this item's parameters.
  /// </summary>
  public IEnumerable<Tuple<string, double>> Calculate(IRegressionSolution solution) {
    return CalculateImpacts(solution, ReplacementMethod, FactorReplacementMethod, DataPartition);
  }

  /// <summary>
  /// Calculates variable impacts for a regression solution on the selected data partition.
  /// </summary>
  /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
  public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    IRegressionSolution solution,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    DataPartitionEnum data = DataPartitionEnum.Training) {
    return CalculateImpacts(solution.Model, solution.ProblemData, solution.EstimatedValues, replacementMethod, factorReplacementMethod, data);
  }

  /// <summary>
  /// Calculates variable impacts for a model on the selected data partition.
  /// </summary>
  /// <param name="estimatedValues">The model's estimates, indexed by dataset row (one value per row of the dataset).</param>
  /// <exception cref="NotSupportedException">If <paramref name="data"/> is not a known partition.</exception>
  public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    IRegressionModel model,
    IRegressionProblemData problemData,
    IEnumerable<double> estimatedValues,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    DataPartitionEnum data = DataPartitionEnum.Training) {
    var rows = GetPartitionRows(problemData, data);
    return CalculateImpacts(model, problemData, estimatedValues, rows, replacementMethod, factorReplacementMethod);
  }

  /// <summary>
  /// Calculates the impact of a single variable on the selected data partition.
  /// </summary>
  /// <param name="estimatedValues">The model's estimates, indexed by dataset row.</param>
  /// <returns>The R² of the original estimates minus the R² after replacing the variable.</returns>
  /// <exception cref="NotSupportedException">Unknown partition, or the variable is neither double nor string.</exception>
  /// <exception cref="InvalidOperationException">If the R² of the original estimates cannot be computed.</exception>
  public static double CalculateImpact(string variableName, IRegressionModel model, IRegressionProblemData problemData, IEnumerable<double> estimatedValues, DataPartitionEnum dataPartition, ReplacementMethodEnum replMethod, FactorReplacementMethodEnum factorReplMethod) {
    var rows = GetPartitionRows(problemData, dataPartition).ToList();
    List<double> targetValuesPartition;
    var originalCalculatorValue = CalculateOriginalQuality(problemData, estimatedValues, rows, out targetValuesPartition);
    return CalculateImpact(variableName, model, problemData.Dataset, rows, targetValuesPartition, originalCalculatorValue, replMethod, factorReplMethod);
  }

  /// <summary>
  /// Calculates the impacts of all considered input variables on the given rows.
  /// Considered variables are the union of the problem's allowed input variables and the
  /// variables used by the model, taken in dataset order.
  /// </summary>
  /// <param name="estimatedValues">The model's estimates, indexed by dataset row.</param>
  /// <param name="rows">Dataset row indices to evaluate on.</param>
  /// <returns>(variable name, impact) pairs ordered by descending impact.</returns>
  /// <exception cref="InvalidOperationException">If the R² of the original estimates cannot be computed.</exception>
  public static IEnumerable<Tuple<string, double>> CalculateImpacts(
    IRegressionModel model,
    IRegressionProblemData problemData,
    IEnumerable<double> estimatedValues,
    IEnumerable<int> rows,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    // rows are enumerated once per variable (and more), so materialize them once
    var rowList = rows.ToList();

    //Calculate original quality-values (via calculator, default is R²)
    List<double> targetValuesPartition;
    var originalCalculatorValue = CalculateOriginalQuality(problemData, estimatedValues, rowList, out targetValuesPartition);

    var inputVariables = new HashSet<string>(problemData.AllowedInputVariables.Union(model.VariablesUsedForPrediction));
    var allowedInputVariables = problemData.Dataset.VariableNames.Where(v => inputVariables.Contains(v)).ToList();

    // A single working copy suffices: every replacement restores the original column afterwards.
    var modifiableDataset = CreateModifiableCopy(problemData.Dataset);
    var impacts = new Dictionary<string, double>();
    foreach (var inputVariable in allowedInputVariables) {
      impacts[inputVariable] = CalculateImpactCore(inputVariable, model, problemData.Dataset, modifiableDataset, rowList, targetValuesPartition, originalCalculatorValue, replacementMethod, factorReplacementMethod);
    }
    return impacts.OrderByDescending(i => i.Value).Select(i => Tuple.Create(i.Key, i.Value));
  }

  /// <summary>
  /// Calculates the impact of a single variable of a solution on the given rows.
  /// </summary>
  /// <param name="targetValues">Target values of <paramref name="rows"/>, in the same order.</param>
  /// <param name="originalValue">R² of the unmodified estimates on <paramref name="rows"/>.</param>
  /// <param name="data">Not used by this overload (the rows are given explicitly); kept for backward compatibility.</param>
  public static double CalculateImpact(string variableName,
    IRegressionSolution solution,
    IEnumerable<int> rows,
    IEnumerable<double> targetValues,
    double originalValue,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best,
    DataPartitionEnum data = DataPartitionEnum.Training) {
    return CalculateImpact(variableName, solution.Model, solution.ProblemData.Dataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
  }

  /// <summary>
  /// Calculates the impact of a single variable on the given rows of <paramref name="dataset"/>.
  /// The dataset itself is not modified; replacements are applied to a private copy.
  /// </summary>
  /// <param name="targetValues">Target values of <paramref name="rows"/>, in the same order.</param>
  /// <param name="originalValue">R² of the unmodified estimates on <paramref name="rows"/>.</param>
  /// <returns><paramref name="originalValue"/> minus the R² obtained after replacing the variable.</returns>
  /// <exception cref="NotSupportedException">If the variable is neither of type double nor string.</exception>
  /// <exception cref="InvalidOperationException">If the R² with replaced inputs cannot be computed.</exception>
  public static double CalculateImpact(string variableName,
    IRegressionModel model,
    IDataset dataset,
    IEnumerable<int> rows,
    IEnumerable<double> targetValues,
    double originalValue,
    ReplacementMethodEnum replacementMethod = ReplacementMethodEnum.Shuffle,
    FactorReplacementMethodEnum factorReplacementMethod = FactorReplacementMethodEnum.Best) {
    var modifiableDataset = CreateModifiableCopy(dataset);
    return CalculateImpactCore(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, replacementMethod, factorReplacementMethod);
  }

  // Dispatches to the numerical or factor impact calculation, working on an existing modifiable copy.
  private static double CalculateImpactCore(string variableName,
    IRegressionModel model,
    IDataset dataset,
    ModifiableDataset modifiableDataset,
    IEnumerable<int> rows,
    IEnumerable<double> targetValues,
    double originalValue,
    ReplacementMethodEnum replacementMethod,
    FactorReplacementMethodEnum factorReplacementMethod) {
    if (dataset.VariableHasType<double>(variableName)) {
      return CalculateImpactForNumericalVariables(variableName, model, modifiableDataset, rows, targetValues, originalValue, replacementMethod);
    }
    if (dataset.VariableHasType<string>(variableName)) {
      return CalculateImpactForFactorVariables(variableName, model, dataset, modifiableDataset, rows, targetValues, originalValue, factorReplacementMethod);
    }
    throw new NotSupportedException("Variable not supported");
  }

  // Maps a partition to the corresponding row indices of the problem data.
  private static IEnumerable<int> GetPartitionRows(IRegressionProblemData problemData, DataPartitionEnum partition) {
    switch (partition) {
      case DataPartitionEnum.All:
        return problemData.AllIndices;
      case DataPartitionEnum.Test:
        return problemData.TestIndices;
      case DataPartitionEnum.Training:
        return problemData.TrainingIndices;
      default:
        throw new NotSupportedException("DataPartition not supported");
    }
  }

  // Clones the dataset and converts the clone into a ModifiableDataset for in-place variable replacement.
  private static ModifiableDataset CreateModifiableCopy(IDataset dataset) {
    return ((Dataset)(dataset).Clone()).ToModifiable();
  }

  // Picks values[r] for every r in rows. Indexing a materialized list avoids the O(n) cost of
  // ElementAt per row on lazily evaluated sequences (which made the selection O(n²)).
  private static List<double> SelectRows(IEnumerable<double> values, IEnumerable<int> rows) {
    var indexable = values as IList<double> ?? values.ToList();
    return rows.Select(r => indexable[r]).ToList();
  }

  // Computes the R² of the unmodified estimates on rows and returns the corresponding target values.
  private static double CalculateOriginalQuality(IRegressionProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<int> rows, out List<double> targetValuesPartition) {
    targetValuesPartition = SelectRows(problemData.TargetVariableValues, rows);
    var estimatedValuesPartition = SelectRows(estimatedValues, rows);
    OnlineCalculatorError error;
    var originalCalculatorValue = CalculateVariableImpact(targetValuesPartition, estimatedValuesPartition, out error);
    if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation.");
    return originalCalculatorValue;
  }

  // Impact of a double variable: original R² minus R² after replacing the variable's values.
  private static double CalculateImpactForNumericalVariables(string variableName,
    IRegressionModel model,
    ModifiableDataset modifiableDataset,
    IEnumerable<int> rows,
    IEnumerable<double> targetValues,
    double originalValue,
    ReplacementMethodEnum replacementMethod) {
    OnlineCalculatorError error;
    var newEstimates = GetReplacedValuesForNumericalVariables(model, variableName, modifiableDataset, rows, replacementMethod);
    var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
    if (error != OnlineCalculatorError.None) { throw new InvalidOperationException("Error during calculation with replaced inputs."); }
    return originalValue - newValue;
  }

  // Impact of a string (factor) variable. For Best, every distinct value occurring in rows is tried
  // as a constant replacement and the smallest resulting impact is returned.
  private static double CalculateImpactForFactorVariables(string variableName,
    IRegressionModel model,
    IDataset dataset,
    ModifiableDataset modifiableDataset,
    IEnumerable<int> rows,
    IEnumerable<double> targetValues,
    double originalValue,
    FactorReplacementMethodEnum factorReplacementMethod) {
    OnlineCalculatorError error;
    if (factorReplacementMethod == FactorReplacementMethodEnum.Best) {
      // the column is restored after every replacement, so the originals can be read once
      var originalValues = modifiableDataset.GetReadOnlyStringValues(variableName).ToList();
      var smallestImpact = double.PositiveInfinity;
      foreach (var repl in dataset.GetStringValues(variableName, rows).Distinct()) {
        var newEstimates = GetReplacedValues(originalValues, model, variableName, modifiableDataset, rows, Enumerable.Repeat(repl, dataset.Rows).ToList());
        var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
        if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
        var curImpact = originalValue - newValue;
        if (curImpact < smallestImpact) smallestImpact = curImpact;
      }
      return smallestImpact;
    } else {
      // for replacement methods shuffle and mode
      var newEstimates = GetReplacedValuesForFactorVariables(model, variableName, modifiableDataset, rows, factorReplacementMethod);
      var newValue = CalculateVariableImpact(targetValues, newEstimates, out error);
      if (error != OnlineCalculatorError.None) throw new InvalidOperationException("Error during calculation with replaced inputs.");
      return originalValue - newValue;
    }
  }

  // Builds a replacement column for a double variable (only the selected rows matter) and returns
  // the model's estimates on rows with that column in place.
  private static IEnumerable<double> GetReplacedValuesForNumericalVariables(
    IRegressionModel model,
    string variable,
    ModifiableDataset dataset,
    IEnumerable<int> rows,
    ReplacementMethodEnum replacement = ReplacementMethodEnum.Shuffle) {
    var originalValues = dataset.GetReadOnlyDoubleValues(variable).ToList();
    double replacementValue;
    List<double> replacementValues;
    IRandom rand;

    switch (replacement) {
      case ReplacementMethodEnum.Median:
        replacementValue = rows.Select(r => originalValues[r]).Median();
        replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
        break;
      case ReplacementMethodEnum.Average:
        replacementValue = rows.Select(r => originalValues[r]).Average();
        replacementValues = Enumerable.Repeat(replacementValue, dataset.Rows).ToList();
        break;
      case ReplacementMethodEnum.Shuffle:
        // new var has same empirical distribution but the relation to y is broken
        rand = new FastRandom(RandomSeed);
        // prepare a complete column for the dataset; rows outside the selection stay NaN
        replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
        // shuffle only the selected rows
        var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
        int i = 0;
        foreach (var r in rows) {
          replacementValues[r] = shuffledValues[i++];
        }
        break;
      case ReplacementMethodEnum.Noise:
        var avg = rows.Select(r => originalValues[r]).Average();
        var stdDev = rows.Select(r => originalValues[r]).StandardDeviation();
        rand = new FastRandom(RandomSeed);
        // prepare a complete column for the dataset; rows outside the selection stay NaN
        replacementValues = Enumerable.Repeat(double.NaN, dataset.Rows).ToList();
        foreach (var r in rows) {
          replacementValues[r] = NormalDistributedRandom.NextDouble(rand, avg, stdDev);
        }
        break;

      default:
        throw new ArgumentException(string.Format("ReplacementMethod {0} cannot be handled.", replacement));
    }

    return GetReplacedValues(originalValues, model, variable, dataset, rows, replacementValues);
  }

  // Builds a replacement column for a string variable (Mode or Shuffle) and returns the model's
  // estimates on rows with that column in place.
  private static IEnumerable<double> GetReplacedValuesForFactorVariables(
    IRegressionModel model,
    string variable,
    ModifiableDataset dataset,
    IEnumerable<int> rows,
    FactorReplacementMethodEnum replacement = FactorReplacementMethodEnum.Shuffle) {
    var originalValues = dataset.GetReadOnlyStringValues(variable).ToList();
    List<string> replacementValues;
    IRandom rand;

    switch (replacement) {
      case FactorReplacementMethodEnum.Mode:
        // ties are resolved in favor of the value that occurs first (OrderByDescending is stable)
        var mostCommonValue = rows.Select(r => originalValues[r])
          .GroupBy(v => v)
          .OrderByDescending(g => g.Count())
          .First().Key;
        replacementValues = Enumerable.Repeat(mostCommonValue, dataset.Rows).ToList();
        break;
      case FactorReplacementMethodEnum.Shuffle:
        // new var has same empirical distribution but the relation to y is broken
        rand = new FastRandom(RandomSeed);
        // prepare a complete column for the dataset; rows outside the selection stay empty
        replacementValues = Enumerable.Repeat(string.Empty, dataset.Rows).ToList();
        // shuffle only the selected rows
        var shuffledValues = rows.Select(r => originalValues[r]).Shuffle(rand).ToList();
        int i = 0;
        foreach (var r in rows) {
          replacementValues[r] = shuffledValues[i++];
        }
        break;
      default:
        throw new ArgumentException(string.Format("FactorReplacementMethod {0} cannot be handled.", replacement));
    }

    return GetReplacedValues(originalValues, model, variable, dataset, rows, replacementValues);
  }

  // Temporarily swaps in replacementValues for the variable, evaluates the model on rows and
  // restores the original column, even if the model throws.
  private static IEnumerable<double> GetReplacedValues(
    IList originalValues,
    IRegressionModel model,
    string variable,
    ModifiableDataset dataset,
    IEnumerable<int> rows,
    IList replacementValues) {
    dataset.ReplaceVariable(variable, replacementValues);
    try {
      //mkommend: ToList is used on purpose to avoid lazy evaluation that could result in wrong estimates due to variable replacements
      return model.GetEstimatedValues(dataset, rows).ToList();
    } finally {
      dataset.ReplaceVariable(variable, originalValues);
    }
  }

  // Computes Pearson's R² between the two sequences. Sets errorState to the calculator's error state.
  // Throws ArgumentException if the sequences have different lengths (checked only when no error occurred).
  private static double CalculateVariableImpact(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues, out OnlineCalculatorError errorState) {
    var calculator = new OnlinePearsonsRSquaredCalculator();
    using (var originalEnumerator = originalValues.GetEnumerator())
    using (var estimatedEnumerator = estimatedValues.GetEnumerator()) {
      // always move forward both enumerators (do not use short-circuit evaluation!)
      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
        calculator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
        if (calculator.ErrorState != OnlineCalculatorError.None) break;
      }

      // check if both enumerators are at the end to make sure both enumerations have the same length
      if (calculator.ErrorState == OnlineCalculatorError.None &&
          (estimatedEnumerator.MoveNext() || originalEnumerator.MoveNext())) {
        throw new ArgumentException("Number of elements in first and second enumeration doesn't match.");
      }
    }
    errorState = calculator.ErrorState;
    return calculator.Value;
  }
}
}