using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Classification;
using HeuristicLab.Problems.Instances.DataAnalysis;
using HeuristicLab.Random;
using Microsoft.VisualStudio.TestTools.UnitTesting;
namespace HeuristicLab.Problems.DataAnalysis.Tests {
  /// <summary>
  /// Tests for <c>ClassificationSolutionVariableImpactsCalculator</c> on constant, k-NN, LDA and
  /// symbolic nearest-neighbour classification solutions.
  /// </summary>
  [TestClass()]
  public class ClassificationVariableImpactCalculationTest {
    /// <summary>
    /// Gets or sets the test context which provides
    /// information about and functionality for the current test run.
    /// </summary>
    public TestContext TestContext { get; set; }

    /// <summary>
    /// A constant model on the Iris problem is expected to yield an impact of 0 for every input variable.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void ConstantModelVariableImpactTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationModel model = new ConstantModel(5, "y");
      IClassificationSolution solution = new ClassificationSolution(model, problemData);
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForConstantModel();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Checks the impacts of a 3-nearest-neighbour solution on the Iris problem against reference values.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void KNNIrisVariableImpactTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForIrisKNNModel();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Checks the impacts of a linear discriminant analysis solution on the Iris problem against reference values.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void LDAIrisVariableImpactTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationSolution solution = LinearDiscriminantAnalysis.CreateLinearDiscriminantAnalysisSolution(problemData);
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForIrisLDAModel();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Checks the impacts of a symbolic nearest-neighbour model (expression over x1..x5) on the synthetic problem.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void CustomModelVariableImpactTest() {
      IClassificationProblemData problemData = CreateDefaultProblem();
      IClassificationSolution solution = CreateSymbolicNearestNeighbourSolution(problemData, CreateCustomExpressionTree());
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblem();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Same as <see cref="CustomModelVariableImpactTest"/>, but the expression uses x1 only as x1/x1.
    /// The synthetic inputs are drawn with random.Next(1, 100) and so are never 0, which makes x1
    /// cancel out; its impact is expected to be 0.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void CustomModelVariableImpactNoInfluenceTest() {
      IClassificationProblemData problemData = CreateDefaultProblem();
      IClassificationSolution solution = CreateSymbolicNearestNeighbourSolution(problemData, CreateCustomExpressionTreeNoInfluenceX1());
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblemNoInfluence();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Replaces the problem data of an Iris k-NN solution with the Mammography data set.
    /// The test passes only if an <see cref="ArgumentException"/> is thrown after the replacement.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    [ExpectedException(typeof(ArgumentException))]
    public void WrongDataSetVariableImpactClassificationTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);
      // Sanity check: calculating impacts on the data the solution was built for must not throw.
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);

      solution.ProblemData = LoadMammographyProblem();
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
    }

    /// <summary>
    /// Measures impact-calculation throughput for a k-NN solution on a larger synthetic data set
    /// and writes the number of data cells processed per millisecond to the test context.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "medium")]
    public void PerformanceVariableImpactClassificationTest() {
      const int rows = 1500;
      const int columns = 77;
      IClassificationProblemData problemData = CreateDefaultProblem(rows, columns);
      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);

      Stopwatch watch = Stopwatch.StartNew();
      // ToList() forces full enumeration so that any deferred work is included in the measurement.
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution).ToList();
      watch.Stop();

      // Clamp to 1 ms: a run finishing in under a millisecond would otherwise divide by zero.
      long elapsedMilliseconds = Math.Max(1L, watch.ElapsedMilliseconds);
      TestContext.WriteLine("");
      TestContext.WriteLine("Calculated cells per millisecond: {0}.", rows * columns / elapsedMilliseconds);
    }

    #region Load ClassificationProblemData
    private IClassificationProblemData LoadIrisProblem() {
      return LoadUciProblem("Iris, M. Marshall, 1988");
    }

    private IClassificationProblemData LoadMammographyProblem() {
      return LoadUciProblem("Mammography, M. Elter, 2007");
    }

    /// <summary>
    /// Loads the UCI instance whose descriptor name equals <paramref name="name"/>.
    /// Throws if there is not exactly one matching descriptor.
    /// </summary>
    private IClassificationProblemData LoadUciProblem(string name) {
      UCIInstanceProvider provider = new UCIInstanceProvider();
      var descriptor = provider.GetDataDescriptors().Single(x => x.Name.Equals(name));
      return provider.LoadData(descriptor);
    }

    /// <summary>
    /// Synthetic problem with 100 rows and the inputs x1..x5 plus target "y".
    /// </summary>
    private IClassificationProblemData CreateDefaultProblem() {
      List<string> allowedInputVariables = new List<string>() { "x1", "x2", "x3", "x4", "x5" };
      return CreateRandomBinaryProblem(allowedInputVariables, 100);
    }

    /// <summary>
    /// Synthetic problem with <paramref name="rows"/> rows and <paramref name="columns"/> columns:
    /// the inputs x0..x(columns-2) plus target "y".
    /// </summary>
    private IClassificationProblemData CreateDefaultProblem(int rows, int columns) {
      List<string> allowedInputVariables = Enumerable.Range(0, columns - 1).Select(x => "x" + x.ToString()).ToList();
      return CreateRandomBinaryProblem(allowedInputVariables, rows);
    }

    /// <summary>
    /// Builds a binary classification problem over <paramref name="allowedInputVariables"/> plus target "y".
    /// Inputs are filled row by row from FastRandom(12345).Next(1, 100). The target (last column) is
    /// (column + row) % 2, so it alternates between 0 and 1 from row to row. Classes 0 and 1 are named
    /// "NOK" and "OK".
    /// </summary>
    private IClassificationProblemData CreateRandomBinaryProblem(List<string> allowedInputVariables, int rows) {
      string targetVariable = "y";
      // Materialized once: the names are needed for the column count and for the dataset.
      List<string> variableNames = allowedInputVariables.Union(targetVariable.ToEnumerable()).ToList();
      int columnCount = variableNames.Count;
      int targetColumn = columnCount - 1;

      double[,] variableValues = new double[rows, columnCount];
      FastRandom random = new FastRandom(12345);
      for (int i = 0; i < rows; i++) {
        for (int j = 0; j < columnCount; j++) {
          if (j == targetColumn) {
            variableValues[i, j] = (j + i) % 2;
          } else {
            variableValues[i, j] = random.Next(1, 100);
          }
        }
      }

      Dataset dataset = new Dataset(variableNames, variableValues);
      var ret = new ClassificationProblemData(dataset, allowedInputVariables, targetVariable);
      ret.SetClassName(0, "NOK");
      ret.SetClassName(1, "OK");
      return ret;
    }
    #endregion

    #region Create SymbolicExpressionTree
    private ISymbolicExpressionTree CreateCustomExpressionTree() {
      return new InfixExpressionParser().Parse("x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
    }

    private ISymbolicExpressionTree CreateCustomExpressionTreeNoInfluenceX1() {
      return new InfixExpressionParser().Parse("x1/x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
    }

    /// <summary>
    /// Wraps <paramref name="tree"/> in a 3-nearest-neighbour symbolic classification model, fits it on the
    /// training partition and returns a solution bound to a clone of <paramref name="problemData"/>.
    /// </summary>
    private IClassificationSolution CreateSymbolicNearestNeighbourSolution(IClassificationProblemData problemData, ISymbolicExpressionTree tree) {
      var model = new SymbolicNearestNeighbourClassificationModel(problemData.TargetVariable, 3, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
      model.RecalculateModelParameters(problemData, problemData.TrainingIndices);
      return new ClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    }
    #endregion

    #region Get Expected Values
    // NOTE(review): these look like reference values recorded for the fixed data sets and seed above;
    // confirm against the original baseline run before changing any of them.
    private Dictionary<string, double> GetExpectedValuesForConstantModel() {
      return new Dictionary<string, double>() {
        { "petal_length", 0 },
        { "petal_width", 0 },
        { "sepal_length", 0 },
        { "sepal_width", 0 }
      };
    }

    private Dictionary<string, double> GetExpectedValuesForIrisKNNModel() {
      return new Dictionary<string, double>() {
        { "petal_length", 0.22 },
        { "petal_width", 0.35 },
        { "sepal_length", 0.15 },
        { "sepal_width", 0.05 }
      };
    }

    private Dictionary<string, double> GetExpectedValuesForCustomProblem() {
      return new Dictionary<string, double>() {
        { "x1", 0.04 },
        { "x2", 0.22 },
        { "x3", 0.26 },
        { "x4", 0.24 },
        { "x5", 0.2 }
      };
    }

    private Dictionary<string, double> GetExpectedValuesForCustomProblemNoInfluence() {
      return new Dictionary<string, double>() {
        { "x1", 0 },
        { "x2", 0.22 },
        { "x3", 0.14 },
        { "x4", 0.3 },
        { "x5", 0.44 }
      };
    }

    private Dictionary<string, double> GetExpectedValuesForIrisLDAModel() {
      return new Dictionary<string, double>() {
        { "sepal_width", 0.01 },
        { "sepal_length", 0.03 },
        { "petal_width", 0.2 },
        { "petal_length", 0.5 }
      };
    }
    #endregion

    /// <summary>
    /// Asserts that the solution-based and model-based CalculateImpacts overloads agree on the training
    /// partition, and that the impacts match <paramref name="expectedImpacts"/> variable by variable (IsAlmost).
    /// </summary>
    private void CheckDefaultAsserts(IClassificationSolution solution, Dictionary<string, double> expectedImpacts) {
      IClassificationProblemData problemData = solution.ProblemData;
      // Materialized once so the calculator does not re-evaluate the model if it enumerates repeatedly.
      IEnumerable<double> estimatedValues = solution.GetEstimatedClassValues(problemData.TrainingIndices).ToList();

      var solutionImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution).ToList();
      var modelImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution.Model, problemData, estimatedValues, problemData.TrainingIndices).ToList();

      // Both ways should return equal results.
      Assert.IsTrue(solutionImpacts.SequenceEqual(modelImpacts), "Solution-based and model-based impacts differ.");

      // Check if impacts are as expected (expected value first, actual second).
      Assert.AreEqual(expectedImpacts.Count, modelImpacts.Count, "Unexpected number of variable impacts.");
      foreach (var impact in modelImpacts) {
        double expected;
        Assert.IsTrue(expectedImpacts.TryGetValue(impact.Item1, out expected),
          string.Format("Unexpected variable '{0}' in calculated impacts.", impact.Item1));
        Assert.IsTrue(impact.Item2.IsAlmost(expected),
          string.Format("Impact of '{0}' was {1}, expected {2}.", impact.Item1, impact.Item2, expected));
      }
    }
  }
}