Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2904_CalculateImpacts/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/ClassificationVariableImpactCalculationTest.cs @ 16067

Last change on this file since 16067 was 16067, checked in by fholzing, 6 years ago

#2904: Added additional Unit-Test for performance measurement

File size: 10.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using HeuristicLab.Algorithms.DataAnalysis;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis.Symbolic;
9using HeuristicLab.Problems.DataAnalysis.Symbolic.Classification;
10using HeuristicLab.Problems.Instances.DataAnalysis;
11using HeuristicLab.Random;
12using Microsoft.VisualStudio.TestTools.UnitTesting;
13
namespace HeuristicLab.Problems.DataAnalysis.Tests {

  /// <summary>
  /// Tests for <see cref="ClassificationSolutionVariableImpactsCalculator"/>: the impacts
  /// calculated for several classification solutions are compared against fixed expected
  /// values, and the two CalculateImpacts entry points are checked against each other.
  /// </summary>
  [TestClass()]
  public class ClassificationVariableImpactCalculationTest {
    /// <summary>
    ///Gets or sets the test context which provides
    ///information about and functionality for the current test run.
    ///</summary>
    public TestContext TestContext { get; set; }

    /// <summary>
    /// Upper bound (exclusive) for the absolute difference between an expected and a calculated impact.
    /// </summary>
    private const double Epsilon = 0.00001;

    /// <summary>
    /// A constant model on the Iris data is expected to yield an impact of 0 for every input variable.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void ConstantModelVariableImpactTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationModel model = new ConstantModel(5, "y");
      IClassificationSolution solution = new ClassificationSolution(model, problemData);

      CheckDefaultAsserts(solution, GetExpectedValuesForConstantModel());
    }

    /// <summary>
    /// Checks the impacts of a nearest neighbour (k = 3) solution created on the Iris data.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void KNNIrisVariableImpactTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);

      // Result is discarded on purpose; the call is kept from the original test because
      // CheckDefaultAsserts calculates the impacts again and compares them itself.
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);

      CheckDefaultAsserts(solution, GetExpectedValuesForIrisKNNModel());
    }

    /// <summary>
    /// Checks the impacts of a symbolic nearest neighbour model whose expression uses x1..x5.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void CustomModelVariableImpactTest() {
      IClassificationProblemData problemData = CreateDefaultProblem();
      IClassificationSolution solution = CreateSymbolicSolution(problemData, CreateCustomExpressionTree());

      CheckDefaultAsserts(solution, GetExpectedValuesForCustomProblem());
    }

    /// <summary>
    /// Same as <see cref="CustomModelVariableImpactTest"/>, but x1 only occurs as x1/x1 in the
    /// expression, so its expected impact is 0.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void CustomModelVariableImpactNoInfluenceTest() {
      IClassificationProblemData problemData = CreateDefaultProblem();
      IClassificationSolution solution = CreateSymbolicSolution(problemData, CreateCustomExpressionTreeNoInfluenceX1());

      CheckDefaultAsserts(solution, GetExpectedValuesForCustomProblemNoInfluence());
    }

    /// <summary>
    /// A solution trained on the Iris data whose problem data is replaced by the Mammography
    /// data is expected to raise an <see cref="ArgumentException"/>.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    [ExpectedException(typeof(ArgumentException))]
    public void WrongDataSetTest() {
      IClassificationProblemData problemData = LoadIrisProblem();
      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);

      // First run with the matching data set.
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);

      // Then swap in a different data set and calculate again.
      solution.ProblemData = LoadMammographyProblem();
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
    }

    /// <summary>
    /// Measures how long the impact calculation takes on a generated 1500 x 77 problem and
    /// writes the throughput (cells per millisecond) to the test output. Nothing is asserted.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "medium")]
    public void PerformanceTest() {
      int rows = 1500;
      int columns = 77;
      IClassificationProblemData problemData = CreateDefaultProblem(rows, columns);
      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);

      Stopwatch watch = Stopwatch.StartNew();
      // Materialise the result so that any deferred evaluation is part of the measurement.
      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution).ToList();
      watch.Stop();

      // Floating-point division: no truncation and no DivideByZeroException
      // when the measured time is below one millisecond.
      double cellsPerMillisecond = rows * columns / watch.Elapsed.TotalMilliseconds;

      TestContext.WriteLine("");
      TestContext.WriteLine("Calculated cells per millisecond: {0}.", cellsPerMillisecond);
    }

    #region Load ClassificationProblemData
    /// <summary>Loads the UCI instance named "Iris, M. Marshall, 1988".</summary>
    private IClassificationProblemData LoadIrisProblem() {
      return LoadUciProblem("Iris, M. Marshall, 1988");
    }

    /// <summary>Loads the UCI instance named "Mammography, M. Elter, 2007".</summary>
    private IClassificationProblemData LoadMammographyProblem() {
      return LoadUciProblem("Mammography, M. Elter, 2007");
    }

    /// <summary>
    /// Loads the single UCI data descriptor whose name equals <paramref name="instanceName"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">No descriptor or more than one descriptor has that name.</exception>
    private static IClassificationProblemData LoadUciProblem(string instanceName) {
      UCIInstanceProvider provider = new UCIInstanceProvider();
      var instance = provider.GetDataDescriptors().Single(x => instanceName.Equals(x.Name));
      return provider.LoadData(instance);
    }

    /// <summary>
    /// Creates a generated problem with 100 rows, the inputs x1..x5 and the target y.
    /// </summary>
    private IClassificationProblemData CreateDefaultProblem() {
      List<string> allowedInputVariables = new List<string>() { "x1", "x2", "x3", "x4", "x5" };
      return CreateProblem(allowedInputVariables, 100);
    }

    /// <summary>
    /// Creates a generated problem with <paramref name="rows"/> rows and
    /// <paramref name="columns"/> columns: the inputs x0..x(columns-2) plus the target y.
    /// </summary>
    private IClassificationProblemData CreateDefaultProblem(int rows, int columns) {
      List<string> allowedInputVariables = Enumerable.Range(0, columns - 1).Select(x => "x" + x.ToString()).ToList();
      return CreateProblem(allowedInputVariables, rows);
    }

    /// <summary>
    /// Builds the data set shared by both CreateDefaultProblem overloads. Input cells are drawn
    /// via random.Next(1, 100) from a FastRandom seeded with 12345, row by row; the target column
    /// (the last one) alternates between 0 and 1. The classes 0 and 1 are named "NOK" and "OK".
    /// </summary>
    private static IClassificationProblemData CreateProblem(List<string> allowedInputVariables, int rows) {
      string targetVariable = "y";
      // Materialised once: the names are counted here and enumerated again by the Dataset.
      List<string> variableNames = allowedInputVariables.Union(targetVariable.ToEnumerable()).ToList();
      double[,] variableValues = new double[rows, variableNames.Count];

      FastRandom random = new FastRandom(12345);
      int targetColumn = variableValues.GetLength(1) - 1;
      for (int i = 0; i < rows; i++) {
        for (int j = 0; j < targetColumn; j++) {
          variableValues[i, j] = random.Next(1, 100);
        }
        variableValues[i, targetColumn] = (targetColumn + i) % 2;
      }

      Dataset dataset = new Dataset(variableNames, variableValues);
      var ret = new ClassificationProblemData(dataset, allowedInputVariables, targetVariable);

      ret.SetClassName(0, "NOK");
      ret.SetClassName(1, "OK");
      return ret;
    }

    /// <summary>
    /// Wraps <paramref name="tree"/> in a symbolic nearest neighbour model (k = 3), recalculates
    /// its parameters on the training rows and returns a solution on a clone of the problem data.
    /// </summary>
    private static IClassificationSolution CreateSymbolicSolution(IClassificationProblemData problemData, ISymbolicExpressionTree tree) {
      var model = new SymbolicNearestNeighbourClassificationModel(problemData.TargetVariable, 3, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
      model.RecalculateModelParameters(problemData, problemData.TrainingIndices);
      return new ClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
    }

    #endregion

    #region Create SymbolicExpressionTree
    /// <summary>Parses an expression in which each of x1..x5 occurs.</summary>
    private ISymbolicExpressionTree CreateCustomExpressionTree() {
      return new InfixExpressionParser().Parse("x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
    }

    /// <summary>Parses the same expression with x1 replaced by x1/x1.</summary>
    private ISymbolicExpressionTree CreateCustomExpressionTreeNoInfluenceX1() {
      return new InfixExpressionParser().Parse("x1/x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
    }
    #endregion

    #region Get Expected Values
    private Dictionary<string, double> GetExpectedValuesForConstantModel() {
      return new Dictionary<string, double> {
        { "petal_length", 0 },
        { "petal_width", 0 },
        { "sepal_length", 0 },
        { "sepal_width", 0 }
      };
    }

    private Dictionary<string, double> GetExpectedValuesForIrisKNNModel() {
      return new Dictionary<string, double> {
        { "petal_length", 0.21 },
        { "petal_width", 0.25 },
        { "sepal_length", 0.05 },
        { "sepal_width", 0.05 }
      };
    }

    private Dictionary<string, double> GetExpectedValuesForCustomProblem() {
      return new Dictionary<string, double> {
        { "x1", 0.04 },
        { "x2", 0.22 },
        { "x3", 0.26 },
        { "x4", 0.24 },
        { "x5", 0.2 }
      };
    }

    private Dictionary<string, double> GetExpectedValuesForCustomProblemNoInfluence() {
      return new Dictionary<string, double> {
        { "x1", 0 },
        { "x2", 0.22 },
        { "x3", 0.14 },
        { "x4", 0.3 },
        { "x5", 0.44 }
      };
    }
    #endregion

    /// <summary>
    /// Calculates the impacts once from the solution and once from its model, problem data and
    /// estimated training values, and asserts that (1) both results are sequence-equal,
    /// (2) the number of impacts matches <paramref name="expectedImpacts"/> and (3) every impact
    /// differs from its expected value by less than <see cref="Epsilon"/>.
    /// </summary>
    /// <param name="solution">Solution whose variable impacts are calculated.</param>
    /// <param name="expectedImpacts">Expected impact per variable name.</param>
    private void CheckDefaultAsserts(IClassificationSolution solution, Dictionary<string, double> expectedImpacts) {
      IClassificationProblemData problemData = solution.ProblemData;
      // Materialise everything once; the sequences below are enumerated several times.
      IEnumerable<double> estimatedValues = solution.GetEstimatedClassValues(problemData.TrainingIndices).ToList();

      var solutionImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution).ToList();
      var modelImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution.Model, problemData, estimatedValues, problemData.TrainingIndices).ToList();

      //Both ways should return equal results
      Assert.IsTrue(solutionImpacts.SequenceEqual(modelImpacts), "Solution-based and model-based impacts differ.");

      //Check if impacts are as expected
      Assert.AreEqual(expectedImpacts.Count, modelImpacts.Count, "Unexpected number of variable impacts.");
      foreach (var impact in modelImpacts) {
        double expected;
        Assert.IsTrue(expectedImpacts.TryGetValue(impact.Item1, out expected), "No expected impact defined for variable '{0}'.", impact.Item1);
        Assert.IsTrue(Math.Abs(expected - impact.Item2) < Epsilon, "Impact of variable '{0}': expected {1}, actual {2}.", impact.Item1, expected, impact.Item2);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.