Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2904_CalculateImpacts/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/ClassificationVariableImpactCalculationTest.cs @ 16065

Last change on this file since 16065 was 16065, checked in by fholzing, 6 years ago

#2904: Added more Unit-Tests for Classification

File size: 8.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Algorithms.DataAnalysis;
5using HeuristicLab.Common;
6using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
7using HeuristicLab.Problems.DataAnalysis.Symbolic;
8using HeuristicLab.Problems.DataAnalysis.Symbolic.Classification;
9using HeuristicLab.Problems.Instances.DataAnalysis;
10using HeuristicLab.Random;
11using Microsoft.VisualStudio.TestTools.UnitTesting;
12
namespace HeuristicLab.Problems.DataAnalysis.Tests {

  /// <summary>
  /// Unit tests for <see cref="ClassificationSolutionVariableImpactsCalculator"/>.
  /// Each test builds a classification solution, calculates its variable impacts and
  /// compares them against hard-coded reference values.
  /// </summary>
  [TestClass()]
  public class ClassificationVariableImpactCalculationTest {
    // Upper bound (exclusive) for the absolute deviation between expected and calculated impact.
    private static readonly double epsilon = 0.00001;

    /// <summary>
    /// Checks the impacts of a <see cref="ConstantModel"/> on the Iris data;
    /// every variable is expected to have an impact of 0.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void ConstantModelVariableImpactTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationModel model = new ConstantModel(5, "y");
      IClassificationSolution solution = new ClassificationSolution(model, problemData);
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForConstantModel();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Checks the impacts of a nearest neighbour classification solution (k = 3) on the Iris data.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void KNNIrisVariableImpactTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);
      // The result of this call is discarded; CheckDefaultAsserts calculates the impacts again,
      // so the asserted values are those of a repeated calculation on the same solution.
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForIrisKNNModel();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Checks the impacts of a symbolic nearest neighbour model (k = 3) whose expression tree
    /// uses all five input variables of the generated default problem.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void CustomModelVariableImpactTest() {
      IClassificationProblemData problemData = CreateDefaultProblem();
      ISymbolicExpressionTree tree = CreateCustomExpressionTree();
      var model = new SymbolicNearestNeighbourClassificationModel(problemData.TargetVariable, 3, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
      model.RecalculateModelParameters(problemData, problemData.TrainingIndices);
      IClassificationSolution solution = new ClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblem();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Same setup as <see cref="CustomModelVariableImpactTest"/>, but the expression tree contains
    /// "x1/x1", so x1 cancels out of the expression and its expected impact is 0.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void CustomModelVariableImpactNoInfluenceTest() {
      IClassificationProblemData problemData = CreateDefaultProblem();
      ISymbolicExpressionTree tree = CreateCustomExpressionTreeNoInfluenceX1();
      var model = new SymbolicNearestNeighbourClassificationModel(problemData.TargetVariable, 3, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
      model.RecalculateModelParameters(problemData, problemData.TrainingIndices);
      IClassificationSolution solution = new ClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
      Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblemNoInfluence();

      CheckDefaultAsserts(solution, expectedImpacts);
    }

    /// <summary>
    /// Replaces the problem data of an Iris solution with the Mammography data and expects an
    /// <see cref="ArgumentException"/> from the statements of this test.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    [ExpectedException(typeof(ArgumentException))]
    public void WrongDataSetTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);
      // Calculation on the data set the solution was created with.
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);

      // NOTE(review): the expected exception may already be raised by the ProblemData setter
      // rather than by the second CalculateImpacts call - confirm against the solution implementation.
      solution.ProblemData = LoadMammographyProblem();
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
    }

    #region Load ClassificationProblemData
    /// <summary>Loads the UCI instance named "Iris, M. Marshall, 1988".</summary>
    private IClassificationProblemData LoadIrisProblem() {
      return LoadUciProblem("Iris, M. Marshall, 1988");
    }

    /// <summary>Loads the UCI instance named "Mammography, M. Elter, 2007".</summary>
    private IClassificationProblemData LoadMammographyProblem() {
      return LoadUciProblem("Mammography, M. Elter, 2007");
    }

    /// <summary>
    /// Loads the instance of the <see cref="UCIInstanceProvider"/> whose descriptor name equals
    /// <paramref name="descriptorName"/>. Single() throws if there is not exactly one match.
    /// </summary>
    private static IClassificationProblemData LoadUciProblem(string descriptorName) {
      UCIInstanceProvider provider = new UCIInstanceProvider();
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Equals(descriptorName));
      return provider.LoadData(instance);
    }

    /// <summary>
    /// Creates a problem with 100 rows, the input variables x1..x5 and the target variable y.
    /// Inputs are drawn from a <see cref="FastRandom"/> seeded with 12345 (reproducible);
    /// the target alternates between the two classes from row to row.
    /// </summary>
    private IClassificationProblemData CreateDefaultProblem() {
      List<string> allowedInputVariables = new List<string>() { "x1", "x2", "x3", "x4", "x5" };
      string targetVariable = "y";
      // Materialized so that the union is evaluated only once (size of the array and Dataset constructor).
      List<string> variableNames = allowedInputVariables.Union(targetVariable.ToEnumerable()).ToList();
      double[,] variableValues = new double[100, variableNames.Count];

      FastRandom random = new FastRandom(12345);
      int rows = variableValues.GetLength(0);
      int targetColumn = variableValues.GetLength(1) - 1; // the target variable is the last column
      for (int i = 0; i < rows; i++) {
        // Input columns are filled row by row, left to right, to keep the random sequence stable.
        for (int j = 0; j < targetColumn; j++) {
          variableValues[i, j] = random.Next(1, 100);
        }
        variableValues[i, targetColumn] = (targetColumn + i) % 2;
      }

      Dataset dataset = new Dataset(variableNames, variableValues);
      var ret = new ClassificationProblemData(dataset, allowedInputVariables, targetVariable);

      ret.SetClassName(0, "NOK");
      ret.SetClassName(1, "OK");
      return ret;
    }
    #endregion

    #region Create SymbolicExpressionTree
    /// <summary>Parses an expression in which every input variable x1..x5 occurs.</summary>
    private ISymbolicExpressionTree CreateCustomExpressionTree() {
      return new InfixExpressionParser().Parse("x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
    }

    /// <summary>Parses an expression in which x1 only occurs as the factor "x1/x1".</summary>
    private ISymbolicExpressionTree CreateCustomExpressionTreeNoInfluenceX1() {
      return new InfixExpressionParser().Parse("x1/x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
    }
    #endregion

    #region Get Expected Values
    /// <summary>Reference impacts for <see cref="ConstantModelVariableImpactTest"/>.</summary>
    private Dictionary<string, double> GetExpectedValuesForConstantModel() {
      return new Dictionary<string, double> {
        { "petal_length", 0 },
        { "petal_width", 0 },
        { "sepal_length", 0 },
        { "sepal_width", 0 }
      };
    }

    /// <summary>Reference impacts for the nearest neighbour solution on the Iris data.</summary>
    private Dictionary<string, double> GetExpectedValuesForIrisKNNModel() {
      return new Dictionary<string, double> {
        { "petal_length", 0.21 },
        { "petal_width", 0.25 },
        { "sepal_length", 0.05 },
        { "sepal_width", 0.05 }
      };
    }

    /// <summary>Reference impacts for <see cref="CustomModelVariableImpactTest"/>.</summary>
    private Dictionary<string, double> GetExpectedValuesForCustomProblem() {
      return new Dictionary<string, double> {
        { "x1", 0.04 },
        { "x2", 0.22 },
        { "x3", 0.26 },
        { "x4", 0.24 },
        { "x5", 0.2 }
      };
    }

    /// <summary>Reference impacts for <see cref="CustomModelVariableImpactNoInfluenceTest"/>.</summary>
    private Dictionary<string, double> GetExpectedValuesForCustomProblemNoInfluence() {
      return new Dictionary<string, double> {
        { "x1", 0 },
        { "x2", 0.22 },
        { "x3", 0.14 },
        { "x4", 0.3 },
        { "x5", 0.44 }
      };
    }
    #endregion

    /// <summary>
    /// Calculates the impacts through the solution overload and through the model overload of
    /// <see cref="ClassificationSolutionVariableImpactsCalculator"/>, asserts that both return the
    /// same sequence and that every impact deviates less than <see cref="epsilon"/> from its
    /// expected value.
    /// </summary>
    /// <param name="solution">Solution whose variable impacts are calculated.</param>
    /// <param name="expectedImpacts">Expected impact per variable name.</param>
    private void CheckDefaultAsserts(IClassificationSolution solution, Dictionary<string, double> expectedImpacts) {
      IClassificationProblemData problemData = solution.ProblemData;
      IEnumerable<double> estimatedValues = solution.GetEstimatedClassValues(problemData.TrainingIndices);

      // Materialize both results once; they are enumerated by several asserts below.
      var solutionImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution).ToList();
      var modelImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution.Model, problemData, estimatedValues, problemData.TrainingIndices).ToList();

      //Both ways should return equal results
      Assert.IsTrue(solutionImpacts.SequenceEqual(modelImpacts),
        "Impacts calculated from the solution differ from impacts calculated from the model.");

      //Check if impacts are as expected
      Assert.AreEqual(expectedImpacts.Count, modelImpacts.Count, "Unexpected number of variable impacts.");
      foreach (var impact in modelImpacts) {
        double expected;
        Assert.IsTrue(expectedImpacts.TryGetValue(impact.Item1, out expected),
          string.Format("No expected impact defined for variable '{0}'.", impact.Item1));
        // Strict '<' comparison: a NaN impact makes the condition false and fails the test.
        Assert.IsTrue(Math.Abs(expected - impact.Item2) < epsilon,
          string.Format("Impact of variable '{0}' is {1}, expected {2} (tolerance {3}).", impact.Item1, impact.Item2, expected, epsilon));
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.