Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/ClassificationVariableImpactCalculationTest.cs @ 18079

Last change on this file since 18079 was 17948, checked in by gkronber, 4 years ago

#3117: updated reference results in unit tests to match new results caused by updated alglib version.

File size: 11.7 KB
RevLine 
[16061]1using System;
2using System.Collections.Generic;
[16067]3using System.Diagnostics;
[16061]4using System.Linq;
5using HeuristicLab.Algorithms.DataAnalysis;
[16065]6using HeuristicLab.Common;
[16061]7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis.Symbolic;
[16065]9using HeuristicLab.Problems.DataAnalysis.Symbolic.Classification;
[16061]10using HeuristicLab.Problems.Instances.DataAnalysis;
[16065]11using HeuristicLab.Random;
[16061]12using Microsoft.VisualStudio.TestTools.UnitTesting;
13
14namespace HeuristicLab.Problems.DataAnalysis.Tests {
15
16  [TestClass()]
17  public class ClassificationVariableImpactCalculationTest {
[16067]18    private TestContext testContextInstance;
19    /// <summary>
20    ///Gets or sets the test context which provides
21    ///information about and functionality for the current test run.
22    ///</summary>
23    public TestContext TestContext {
24      get { return testContextInstance; }
25      set { testContextInstance = value; }
26    }
27
[16061]28
29    [TestMethod]
30    [TestCategory("Problems.DataAnalysis")]
31    [TestProperty("Time", "short")]
32    public void ConstantModelVariableImpactTest() {
33      IClassificationProblemData problemData = LoadIrisProblem();
34      IClassificationModel model = new ConstantModel(5, "y");
35      IClassificationSolution solution = new ClassificationSolution(model, problemData);
36      Dictionary<string, double> expectedImpacts = GetExpectedValuesForConstantModel();
37
38      CheckDefaultAsserts(solution, expectedImpacts);
39    }
40
41    [TestMethod]
42    [TestCategory("Problems.DataAnalysis")]
43    [TestProperty("Time", "short")]
44    public void KNNIrisVariableImpactTest() {
45      IClassificationProblemData problemData = LoadIrisProblem();
46      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);
47      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
48      Dictionary<string, double> expectedImpacts = GetExpectedValuesForIrisKNNModel();
49
50      CheckDefaultAsserts(solution, expectedImpacts);
51    }
52
53
54    [TestMethod]
55    [TestCategory("Problems.DataAnalysis")]
56    [TestProperty("Time", "short")]
[16416]57    public void LDAIrisVariableImpactTest() {
58      IClassificationProblemData problemData = LoadIrisProblem();
59      IClassificationSolution solution = LinearDiscriminantAnalysis.CreateLinearDiscriminantAnalysisSolution(problemData);
60      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
61      Dictionary<string, double> expectedImpacts = GetExpectedValuesForIrisLDAModel();
62
63      CheckDefaultAsserts(solution, expectedImpacts);
64    }
65
66
67    [TestMethod]
68    [TestCategory("Problems.DataAnalysis")]
69    [TestProperty("Time", "short")]
[16065]70    public void CustomModelVariableImpactTest() {
71      IClassificationProblemData problemData = CreateDefaultProblem();
72      ISymbolicExpressionTree tree = CreateCustomExpressionTree();
73      var model = new SymbolicNearestNeighbourClassificationModel(problemData.TargetVariable, 3, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
74      model.RecalculateModelParameters(problemData, problemData.TrainingIndices);
75      IClassificationSolution solution = new ClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
76      Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblem();
77
78      CheckDefaultAsserts(solution, expectedImpacts);
79    }
80
81    [TestMethod]
82    [TestCategory("Problems.DataAnalysis")]
83    [TestProperty("Time", "short")]
84    public void CustomModelVariableImpactNoInfluenceTest() {
85      IClassificationProblemData problemData = CreateDefaultProblem();
86      ISymbolicExpressionTree tree = CreateCustomExpressionTreeNoInfluenceX1();
87      var model = new SymbolicNearestNeighbourClassificationModel(problemData.TargetVariable, 3, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
88      model.RecalculateModelParameters(problemData, problemData.TrainingIndices);
89      IClassificationSolution solution = new ClassificationSolution(model, (IClassificationProblemData)problemData.Clone());
90      Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblemNoInfluence();
91
92      CheckDefaultAsserts(solution, expectedImpacts);
93    }
94
95    [TestMethod]
96    [TestCategory("Problems.DataAnalysis")]
97    [TestProperty("Time", "short")]
[16061]98    [ExpectedException(typeof(ArgumentException))]
[16416]99    public void WrongDataSetVariableImpactClassificationTest() {
[16061]100      IClassificationProblemData problemData = LoadIrisProblem();
101      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);
102      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
103      Dictionary<string, double> expectedImpacts = GetExpectedValuesForIrisKNNModel();
104
105      solution.ProblemData = LoadMammographyProblem();
106      ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
107    }
108
[16067]109
110    [TestMethod]
111    [TestCategory("Problems.DataAnalysis")]
112    [TestProperty("Time", "medium")]
[16416]113    public void PerformanceVariableImpactClassificationTest() {
[16067]114      int rows = 1500;
115      int columns = 77;
116      IClassificationProblemData problemData = CreateDefaultProblem(rows, columns);
117      IClassificationSolution solution = NearestNeighbourClassification.CreateNearestNeighbourClassificationSolution(problemData, 3);
118
119      Stopwatch watch = new Stopwatch();
120      watch.Start();
121      var results = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
122      watch.Stop();
123
124      TestContext.WriteLine("");
125      TestContext.WriteLine("Calculated cells per millisecond: {0}.", rows * columns / watch.ElapsedMilliseconds);
126    }
127
[16061]128    #region Load ClassificationProblemData
129    private IClassificationProblemData LoadIrisProblem() {
130      UCIInstanceProvider provider = new UCIInstanceProvider();
131      var instance = provider.GetDataDescriptors().Where(x => x.Name.Equals("Iris, M. Marshall, 1988")).Single();
132      return provider.LoadData(instance);
133    }
134    private IClassificationProblemData LoadMammographyProblem() {
135      UCIInstanceProvider provider = new UCIInstanceProvider();
136      var instance = provider.GetDataDescriptors().Where(x => x.Name.Equals("Mammography, M. Elter, 2007")).Single();
137      return provider.LoadData(instance);
138    }
[16065]139    private IClassificationProblemData CreateDefaultProblem() {
140      List<string> allowedInputVariables = new List<string>() { "x1", "x2", "x3", "x4", "x5" };
141      string targetVariable = "y";
142      var variableNames = allowedInputVariables.Union(targetVariable.ToEnumerable());
143      double[,] variableValues = new double[100, variableNames.Count()];
144
145      FastRandom random = new FastRandom(12345);
146      int len0 = variableValues.GetLength(0);
147      int len1 = variableValues.GetLength(1);
148      for (int i = 0; i < len0; i++) {
149        for (int j = 0; j < len1; j++) {
150          if (j == len1 - 1) {
151            variableValues[i, j] = (j + i) % 2;
152          } else {
153            variableValues[i, j] = random.Next(1, 100);
154          }
155        }
156      }
157
158      Dataset dataset = new Dataset(variableNames, variableValues);
159      var ret = new ClassificationProblemData(dataset, allowedInputVariables, targetVariable);
160
161      ret.SetClassName(0, "NOK");
162      ret.SetClassName(1, "OK");
163      return ret;
164    }
[16067]165
166    private IClassificationProblemData CreateDefaultProblem(int rows, int columns) {
167      List<string> allowedInputVariables = Enumerable.Range(0, columns - 1).Select(x => "x" + x.ToString()).ToList();
168      string targetVariable = "y";
169      var variableNames = allowedInputVariables.Union(targetVariable.ToEnumerable());
170      double[,] variableValues = new double[rows, columns];
171
172      FastRandom random = new FastRandom(12345);
173      int len0 = variableValues.GetLength(0);
174      int len1 = variableValues.GetLength(1);
175      for (int i = 0; i < len0; i++) {
176        for (int j = 0; j < len1; j++) {
177          if (j == len1 - 1) {
178            variableValues[i, j] = (j + i) % 2;
179          } else {
180            variableValues[i, j] = random.Next(1, 100);
181          }
182        }
183      }
184
185      Dataset dataset = new Dataset(variableNames, variableValues);
186      var ret = new ClassificationProblemData(dataset, allowedInputVariables, targetVariable);
187
188      ret.SetClassName(0, "NOK");
189      ret.SetClassName(1, "OK");
190      return ret;
191    }
192
[16061]193    #endregion
194
195    #region Create SymbolicExpressionTree
[16065]196    private ISymbolicExpressionTree CreateCustomExpressionTree() {
197      return new InfixExpressionParser().Parse("x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
198    }
[16061]199    private ISymbolicExpressionTree CreateCustomExpressionTreeNoInfluenceX1() {
200      return new InfixExpressionParser().Parse("x1/x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
201    }
202    #endregion
203
204    #region Get Expected Values     
205    private Dictionary<string, double> GetExpectedValuesForConstantModel() {
206      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
207      expectedImpacts.Add("petal_length", 0);
208      expectedImpacts.Add("petal_width", 0);
209      expectedImpacts.Add("sepal_length", 0);
210      expectedImpacts.Add("sepal_width", 0);
211
212      return expectedImpacts;
213    }
214    private Dictionary<string, double> GetExpectedValuesForIrisKNNModel() {
215      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
[17948]216      expectedImpacts.Add("petal_length", 0.22);
217      expectedImpacts.Add("petal_width", 0.35);
218      expectedImpacts.Add("sepal_length", 0.15);
[16061]219      expectedImpacts.Add("sepal_width", 0.05);
220
221      return expectedImpacts;
222    }
[16065]223    private Dictionary<string, double> GetExpectedValuesForCustomProblem() {
224      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
225      expectedImpacts.Add("x1", 0.04);
226      expectedImpacts.Add("x2", 0.22);
227      expectedImpacts.Add("x3", 0.26);
228      expectedImpacts.Add("x4", 0.24);
229      expectedImpacts.Add("x5", 0.2);
230
231      return expectedImpacts;
232    }
233    private Dictionary<string, double> GetExpectedValuesForCustomProblemNoInfluence() {
234      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
235      expectedImpacts.Add("x1", 0);
236      expectedImpacts.Add("x2", 0.22);
237      expectedImpacts.Add("x3", 0.14);
238      expectedImpacts.Add("x4", 0.3);
239      expectedImpacts.Add("x5", 0.44);
240
241      return expectedImpacts;
242    }
[16416]243    private Dictionary<string, double> GetExpectedValuesForIrisLDAModel() {
244      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
245      expectedImpacts.Add("sepal_width", 0.01);
246      expectedImpacts.Add("sepal_length", 0.03);
247      expectedImpacts.Add("petal_width", 0.2);
248      expectedImpacts.Add("petal_length", 0.5);
249
250      return expectedImpacts;
251    }
[16061]252    #endregion
253
254    private void CheckDefaultAsserts(IClassificationSolution solution, Dictionary<string, double> expectedImpacts) {
255      IClassificationProblemData problemData = solution.ProblemData;
256      IEnumerable<double> estimatedValues = solution.GetEstimatedClassValues(solution.ProblemData.TrainingIndices);
257
258      var solutionImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution);
259      var modelImpacts = ClassificationSolutionVariableImpactsCalculator.CalculateImpacts(solution.Model, problemData, estimatedValues, problemData.TrainingIndices);
260
261      //Both ways should return equal results
262      Assert.IsTrue(solutionImpacts.SequenceEqual(modelImpacts));
263
264      //Check if impacts are as expected
265      Assert.AreEqual(modelImpacts.Count(), expectedImpacts.Count);
[16416]266      Assert.IsTrue(modelImpacts.All(v => v.Item2.IsAlmost(expectedImpacts[v.Item1])));
[16061]267    }
268  }
269}
Note: See TracBrowser for help on using the repository browser.