source: stable/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/RegressionVariableImpactCalculationTest.cs @ 17163

Last change on this file since 17163 was 17163, checked in by gkronber, 2 months ago

#2892: merged r16443 from trunk to stable

File size: 15.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using HeuristicLab.Algorithms.DataAnalysis;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis.Symbolic;
9using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
10using HeuristicLab.Problems.Instances.DataAnalysis;
11using HeuristicLab.Random;
12using Microsoft.VisualStudio.TestTools.UnitTesting;
13
14namespace HeuristicLab.Problems.DataAnalysis.Tests {
15
16  [TestClass()]
17  public class RegressionVariableImpactCalculationTest {
18    private TestContext testContextInstance;
19    /// <summary>
20    ///Gets or sets the test context which provides
21    ///information about and functionality for the current test run.
22    ///</summary>
23    public TestContext TestContext {
24      get { return testContextInstance; }
25      set { testContextInstance = value; }
26    }
27
28
29    [TestMethod]
30    [TestCategory("Problems.DataAnalysis")]
31    [TestProperty("Time", "short")]
32    public void ConstantModelVariableImpactTest() {
33      IRegressionProblemData problemData = LoadDefaultTowerProblem();
34      IRegressionModel model = new ConstantModel(5, "y");
35      IRegressionSolution solution = new RegressionSolution(model, problemData);
36      Dictionary<string, double> expectedImpacts = GetExpectedValuesForConstantModel();
37
38      CheckDefaultAsserts(solution, expectedImpacts);
39    }
40
41    [TestMethod]
42    [TestCategory("Problems.DataAnalysis")]
43    [TestProperty("Time", "short")]
44    public void LinearRegressionModelVariableImpactTowerTest() {
45      IRegressionProblemData problemData = LoadDefaultTowerProblem();
46      double rmsError;
47      double cvRmsError;
48      var solution = LinearRegression.CreateSolution(problemData, out rmsError, out cvRmsError);
49      Dictionary<string, double> expectedImpacts = GetExpectedValuesForLRTower();
50
51      CheckDefaultAsserts(solution, expectedImpacts);
52    }
53
54    [TestMethod]
55    [TestCategory("Problems.DataAnalysis")]
56    [TestProperty("Time", "short")]
57    public void LinearRegressionModelVariableImpactMibaTest() {
58      IRegressionProblemData problemData = LoadDefaultMibaProblem();
59      double rmsError;
60      double cvRmsError;
61      var solution = LinearRegression.CreateSolution(problemData, out rmsError, out cvRmsError);
62      Dictionary<string, double> expectedImpacts = GetExpectedValuesForLRMiba();
63
64      CheckDefaultAsserts(solution, expectedImpacts);
65    }
66
67    [TestMethod]
68    [TestCategory("Problems.DataAnalysis")]
69    [TestProperty("Time", "short")]
70    public void RandomForestModelVariableImpactTowerTest() {
71      IRegressionProblemData problemData = LoadDefaultTowerProblem();
72      double rmsError;
73      double avgRelError;
74      double outOfBagRmsError;
75      double outofBagAvgRelError;
76      var solution = RandomForestRegression.CreateRandomForestRegressionSolution(problemData, 50, 0.2, 0.5, 1234, out rmsError, out avgRelError, out outOfBagRmsError, out outofBagAvgRelError);
77      Dictionary<string, double> expectedImpacts = GetExpectedValuesForRFTower();
78
79      CheckDefaultAsserts(solution, expectedImpacts);
80    }
81
82    [TestMethod]
83    [TestCategory("Problems.DataAnalysis")]
84    [TestProperty("Time", "short")]
85    public void CustomModelVariableImpactTest() {
86      IRegressionProblemData problemData = CreateDefaultProblem();
87      ISymbolicExpressionTree tree = CreateCustomExpressionTree();
88      IRegressionModel model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
89      IRegressionSolution solution = new RegressionSolution(model, (IRegressionProblemData)problemData.Clone());
90      Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblem();
91
92      CheckDefaultAsserts(solution, expectedImpacts);
93    }
94
95    [TestMethod]
96    [TestCategory("Problems.DataAnalysis")]
97    [TestProperty("Time", "short")]
98    public void CustomModelVariableImpactNoInfluenceTest() {
99      IRegressionProblemData problemData = CreateDefaultProblem();
100      ISymbolicExpressionTree tree = CreateCustomExpressionTreeNoInfluenceX1();
101      IRegressionModel model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter());
102      IRegressionSolution solution = new RegressionSolution(model, (IRegressionProblemData)problemData.Clone());
103      Dictionary<string, double> expectedImpacts = GetExpectedValuesForCustomProblemNoInfluence();
104
105      CheckDefaultAsserts(solution, expectedImpacts);
106    }
107
108    [TestMethod]
109    [TestCategory("Problems.DataAnalysis")]
110    [TestProperty("Time", "short")]
111    [ExpectedException(typeof(ArgumentException))]
112    public void WrongDataSetVariableImpactRegressionTest() {
113      IRegressionProblemData problemData = LoadDefaultTowerProblem();
114      double rmsError;
115      double cvRmsError;
116      var solution = LinearRegression.CreateSolution(problemData, out rmsError, out cvRmsError);
117      solution.ProblemData = LoadDefaultMibaProblem();
118      RegressionSolutionVariableImpactsCalculator.CalculateImpacts(solution);
119
120    }
121
122    [TestMethod]
123    [TestCategory("Problems.DataAnalysis")]
124    [TestProperty("Time", "medium")]
125    public void PerformanceVariableImpactRegressionTest() {
126      int rows = 20000;
127      int columns = 77;
128      var dataSet = OnlineCalculatorPerformanceTest.CreateRandomDataset(new MersenneTwister(1234), rows, columns);
129      IRegressionProblemData problemData = new RegressionProblemData(dataSet, dataSet.VariableNames.Except("y".ToEnumerable()), "y");
130      double rmsError;
131      double cvRmsError;
132      var solution = LinearRegression.CreateSolution(problemData, out rmsError, out cvRmsError);
133
134      Stopwatch watch = new Stopwatch();
135      watch.Start();
136      var results = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(solution);
137      watch.Stop();
138
139      TestContext.WriteLine("");
140      TestContext.WriteLine("Calculated cells per millisecond: {0}.", rows * columns / watch.ElapsedMilliseconds);
141
142    }
143
144    #region Load RegressionProblemData
145    private IRegressionProblemData LoadDefaultTowerProblem() {
146      RegressionRealWorldInstanceProvider provider = new RegressionRealWorldInstanceProvider();
147      var tower = new HeuristicLab.Problems.Instances.DataAnalysis.Tower();
148      return provider.LoadData(tower);
149    }
150    private IRegressionProblemData LoadDefaultMibaProblem() {
151      MibaFrictionRegressionInstanceProvider provider = new MibaFrictionRegressionInstanceProvider();
152      var cf1 = new HeuristicLab.Problems.Instances.DataAnalysis.CF1();
153      return provider.LoadData(cf1);
154    }
155    private IRegressionProblemData CreateDefaultProblem() {
156      List<string> allowedInputVariables = new List<string>() { "x1", "x2", "x3", "x4", "x5" };
157      string targetVariable = "y";
158      var variableNames = allowedInputVariables.Union(targetVariable.ToEnumerable());
159      double[,] variableValues = new double[100, variableNames.Count()];
160
161      FastRandom random = new FastRandom(12345);
162      for (int i = 0; i < variableValues.GetLength(0); i++) {
163        for (int j = 0; j < variableValues.GetLength(1); j++) {
164          variableValues[i, j] = random.Next(1, 100);
165        }
166      }
167
168      Dataset dataset = new Dataset(variableNames, variableValues);
169      return new RegressionProblemData(dataset, allowedInputVariables, targetVariable);
170    }
171    #endregion
172
173    #region Create SymbolicExpressionTree
174
175    private ISymbolicExpressionTree CreateCustomExpressionTree() {
176      return new InfixExpressionParser().Parse("x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
177    }
178    private ISymbolicExpressionTree CreateCustomExpressionTreeNoInfluenceX1() {
179      return new InfixExpressionParser().Parse("x1/x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12");
180    }
181    #endregion
182
183    #region Get Expected Values
184    private Dictionary<string, double> GetExpectedValuesForConstantModel() {
185      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
186      expectedImpacts.Add("x1", 0);
187      expectedImpacts.Add("x10", 0);
188      expectedImpacts.Add("x11", 0);
189      expectedImpacts.Add("x12", 0);
190      expectedImpacts.Add("x13", 0);
191      expectedImpacts.Add("x14", 0);
192      expectedImpacts.Add("x15", 0);
193      expectedImpacts.Add("x16", 0);
194      expectedImpacts.Add("x17", 0);
195      expectedImpacts.Add("x18", 0);
196      expectedImpacts.Add("x19", 0);
197      expectedImpacts.Add("x2", 0);
198      expectedImpacts.Add("x20", 0);
199      expectedImpacts.Add("x21", 0);
200      expectedImpacts.Add("x22", 0);
201      expectedImpacts.Add("x23", 0);
202      expectedImpacts.Add("x24", 0);
203      expectedImpacts.Add("x25", 0);
204      expectedImpacts.Add("x3", 0);
205      expectedImpacts.Add("x4", 0);
206      expectedImpacts.Add("x5", 0);
207      expectedImpacts.Add("x6", 0);
208      expectedImpacts.Add("x7", 0);
209      expectedImpacts.Add("x8", 0);
210      expectedImpacts.Add("x9", 0);
211
212      return expectedImpacts;
213    }
214    private Dictionary<string, double> GetExpectedValuesForLRTower() {
215      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
216      expectedImpacts.Add("x1", 0.639933657675427);
217      expectedImpacts.Add("x10", 0.0127006885259798);
218      expectedImpacts.Add("x11", 0.648236047877475);
219      expectedImpacts.Add("x12", 0.248350173524562);
220      expectedImpacts.Add("x13", 0.550889987109547);
221      expectedImpacts.Add("x14", 0.0882824237877192);
222      expectedImpacts.Add("x15", 0.0391276799061169);
223      expectedImpacts.Add("x16", 0.743632451088798);
224      expectedImpacts.Add("x17", 0.00254276857715308);
225      expectedImpacts.Add("x18", 0.0021548147614302);
226      expectedImpacts.Add("x19", 0.00513473927463037);
227      expectedImpacts.Add("x2", 0.0107583487931443);
228      expectedImpacts.Add("x20", 0.18085069746933);
229      expectedImpacts.Add("x21", 0.138053600700762);
230      expectedImpacts.Add("x22", 0.000339539790460086);
231      expectedImpacts.Add("x23", 0.362111965467117);
232      expectedImpacts.Add("x24", 0.0320167935572304);
233      expectedImpacts.Add("x25", 0.57460423230969);
234      expectedImpacts.Add("x3", 0.688142635515862);
235      expectedImpacts.Add("x4", 0.000176632348454664);
236      expectedImpacts.Add("x5", 0.0213915503114581);
237      expectedImpacts.Add("x6", 0.807976486909701);
238      expectedImpacts.Add("x7", 0.716217843319252);
239      expectedImpacts.Add("x8", 0.772701841392564);
240      expectedImpacts.Add("x9", 0.178418730050997);
241
242      return expectedImpacts;
243    }
244    private Dictionary<string, double> GetExpectedValuesForLRMiba() {
245      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
246      expectedImpacts.Add("Grooving", 0.0380558091030508);
247      expectedImpacts.Add("Material", 0.02195836766156);
248      expectedImpacts.Add("Material_Cat", 0.000338687689067418);
249      expectedImpacts.Add("Oil", 0.363464994447857);
250      expectedImpacts.Add("x10", 0.0015309669014415);
251      expectedImpacts.Add("x11", -3.60432578908609E-05);
252      expectedImpacts.Add("x12", 0.00118953859087612);
253      expectedImpacts.Add("x13", 0.00164240977191832);
254      expectedImpacts.Add("x14", 0.000688363685380056);
255      expectedImpacts.Add("x15", -4.75067203969948E-05);
256      expectedImpacts.Add("x16", 0.00130388206125076);
257      expectedImpacts.Add("x17", 0.132351838646134);
258      expectedImpacts.Add("x2", -2.47981401556574E-05);
259      expectedImpacts.Add("x20", 0.716541716605016);
260      expectedImpacts.Add("x22", 0.174959377282835);
261      expectedImpacts.Add("x3", -2.65979754026091E-05);
262      expectedImpacts.Add("x4", -1.24764212947603E-05);
263      expectedImpacts.Add("x5", 0.001184959455798);
264      expectedImpacts.Add("x6", 0.000743336665237626);
265      expectedImpacts.Add("x7", 0.00188965927889773);
266      expectedImpacts.Add("x8", 0.00415201581536351);
267      expectedImpacts.Add("x9", 0.00365653880518491);
268
269      return expectedImpacts;
270    }
271    private Dictionary<string, double> GetExpectedValuesForRFTower() {
272      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
273      expectedImpacts.Add("x5", 0.00138095702433039);
274      expectedImpacts.Add("x19", 0.00220739387855795);
275      expectedImpacts.Add("x14", 0.00225120540266954);
276      expectedImpacts.Add("x18", 0.00311857736968479);
277      expectedImpacts.Add("x9", 0.00313474690023097);
278      expectedImpacts.Add("x20", 0.00321781251408282);
279      expectedImpacts.Add("x21", 0.00397483365571383);
280      expectedImpacts.Add("x16", 0.00433280262892111);
281      expectedImpacts.Add("x15", 0.00529918809786456);
282      expectedImpacts.Add("x3", 0.00658791244929757);
283      expectedImpacts.Add("x24", 0.0078645281886035);
284      expectedImpacts.Add("x4", 0.00907314110749047);
285      expectedImpacts.Add("x13", 0.0102943761648944);
286      expectedImpacts.Add("x22", 0.0107132858548163);
287      expectedImpacts.Add("x12", 0.0157078677788507);
288      expectedImpacts.Add("x23", 0.0235857534562318);
289      expectedImpacts.Add("x7", 0.0304143401617055);
290      expectedImpacts.Add("x11", 0.0310773441767309);
291      expectedImpacts.Add("x25", 0.0328308945873665);
292      expectedImpacts.Add("x17", 0.0428771226844575);
293      expectedImpacts.Add("x10", 0.0456335367972532);
294      expectedImpacts.Add("x8", 0.049849257881126);
295      expectedImpacts.Add("x1", 0.0663686086323108);
296      expectedImpacts.Add("x2", 0.0799083890750926);
297      expectedImpacts.Add("x6", 0.196557814244287);
298
299      return expectedImpacts;
300    }
301    private Dictionary<string, double> GetExpectedValuesForCustomProblem() {
302      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
303      expectedImpacts.Add("x1", -0.000573340275115796);
304      expectedImpacts.Add("x2", 0.000781819784095592);
305      expectedImpacts.Add("x3", -0.000390473234921058);
306      expectedImpacts.Add("x4", -0.00116083274627995);
307      expectedImpacts.Add("x5", -0.00036161186207545);
308
309      return expectedImpacts;
310    }
311    private Dictionary<string, double> GetExpectedValuesForCustomProblemNoInfluence() {
312      Dictionary<string, double> expectedImpacts = new Dictionary<string, double>();
313      expectedImpacts.Add("x1", 0);
314      expectedImpacts.Add("x2", 0.00263393690342982);
315      expectedImpacts.Add("x3", -0.00053248037514929);
316      expectedImpacts.Add("x4", 0.00450365819257568);
317      expectedImpacts.Add("x5", -0.000550911612888904);
318
319      return expectedImpacts;
320    }
321    #endregion
322
323    private void CheckDefaultAsserts(IRegressionSolution solution, Dictionary<string, double> expectedImpacts) {
324      IRegressionProblemData problemData = solution.ProblemData;
325      IEnumerable<double> estimatedValues = solution.GetEstimatedValues(solution.ProblemData.TrainingIndices);
326
327      var solutionImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(solution);
328      var modelImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(solution.Model, problemData, estimatedValues, problemData.TrainingIndices);
329
330      //Both ways should return equal results
331      Assert.IsTrue(solutionImpacts.SequenceEqual(modelImpacts));
332
333      //Check if impacts are as expected
334      Assert.AreEqual(modelImpacts.Count(), expectedImpacts.Count);
335      Assert.IsTrue(modelImpacts.All(v => v.Item2.IsAlmost(expectedImpacts[v.Item1])));
336    }
337  }
338}
Note: See TracBrowser for help on using the repository browser.