using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using HeuristicLab.Algorithms.DataAnalysis; using HeuristicLab.Common; using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; using HeuristicLab.Problems.DataAnalysis.Symbolic; using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression; using HeuristicLab.Problems.Instances.DataAnalysis; using HeuristicLab.Random; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace HeuristicLab.Problems.DataAnalysis.Tests { [TestClass()] public class RegressionVariableImpactCalculationTest { private TestContext testContextInstance; /// ///Gets or sets the test context which provides ///information about and functionality for the current test run. /// public TestContext TestContext { get { return testContextInstance; } set { testContextInstance = value; } } [TestMethod] [TestCategory("Problems.DataAnalysis")] [TestProperty("Time", "short")] public void ConstantModelVariableImpactTest() { IRegressionProblemData problemData = LoadDefaultTowerProblem(); IRegressionModel model = new ConstantModel(5, "y"); IRegressionSolution solution = new RegressionSolution(model, problemData); Dictionary expectedImpacts = GetExpectedValuesForConstantModel(); CheckDefaultAsserts(solution, expectedImpacts); } [TestMethod] [TestCategory("Problems.DataAnalysis")] [TestProperty("Time", "short")] public void LinearRegressionModelVariableImpactTowerTest() { IRegressionProblemData problemData = LoadDefaultTowerProblem(); double rmsError; double cvRmsError; var solution = LinearRegression.CreateSolution(problemData, out rmsError, out cvRmsError); Dictionary expectedImpacts = GetExpectedValuesForLRTower(); CheckDefaultAsserts(solution, expectedImpacts); } [TestMethod] [TestCategory("Problems.DataAnalysis")] [TestProperty("Time", "short")] public void LinearRegressionModelVariableImpactMibaTest() { IRegressionProblemData problemData = LoadDefaultMibaProblem(); double rmsError; double cvRmsError; var solution = LinearRegression.CreateSolution(problemData, out rmsError, out cvRmsError); Dictionary expectedImpacts = GetExpectedValuesForLRMiba(); CheckDefaultAsserts(solution, expectedImpacts); } [TestMethod] [TestCategory("Problems.DataAnalysis")] [TestProperty("Time", "short")] public void RandomForestModelVariableImpactTowerTest() { IRegressionProblemData problemData = LoadDefaultTowerProblem(); double rmsError; double avgRelError; double outOfBagRmsError; double outofBagAvgRelError; var solution = RandomForestRegression.CreateRandomForestRegressionSolution(problemData, 50, 0.2, 0.5, 1234, out rmsError, out avgRelError, out outOfBagRmsError, out outofBagAvgRelError); Dictionary expectedImpacts = GetExpectedValuesForRFTower(); CheckDefaultAsserts(solution, expectedImpacts); } [TestMethod] [TestCategory("Problems.DataAnalysis")] [TestProperty("Time", "short")] public void CustomModelVariableImpactTest() { IRegressionProblemData problemData = CreateDefaultProblem(); ISymbolicExpressionTree tree = CreateCustomExpressionTree(); IRegressionModel model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter()); IRegressionSolution solution = new RegressionSolution(model, (IRegressionProblemData)problemData.Clone()); Dictionary expectedImpacts = GetExpectedValuesForCustomProblem(); CheckDefaultAsserts(solution, expectedImpacts); } [TestMethod] [TestCategory("Problems.DataAnalysis")] [TestProperty("Time", "short")] public void CustomModelVariableImpactNoInfluenceTest() { IRegressionProblemData problemData = CreateDefaultProblem(); ISymbolicExpressionTree tree = CreateCustomExpressionTreeNoInfluenceX1(); IRegressionModel model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter()); IRegressionSolution solution = new RegressionSolution(model, (IRegressionProblemData)problemData.Clone()); Dictionary expectedImpacts = GetExpectedValuesForCustomProblemNoInfluence(); CheckDefaultAsserts(solution, expectedImpacts); } [TestMethod] [TestCategory("Problems.DataAnalysis")] [TestProperty("Time", "short")] [ExpectedException(typeof(ArgumentException))] public void WrongDataSetVariableImpactRegressionTest() { IRegressionProblemData problemData = LoadDefaultTowerProblem(); double rmsError; double cvRmsError; var solution = LinearRegression.CreateSolution(problemData, out rmsError, out cvRmsError); solution.ProblemData = LoadDefaultMibaProblem(); RegressionSolutionVariableImpactsCalculator.CalculateImpacts(solution); } [TestMethod] [TestCategory("Problems.DataAnalysis")] [TestProperty("Time", "medium")] public void PerformanceVariableImpactRegressionTest() { int rows = 20000; int columns = 77; var dataSet = OnlineCalculatorPerformanceTest.CreateRandomDataset(new MersenneTwister(1234), rows, columns); IRegressionProblemData problemData = new RegressionProblemData(dataSet, dataSet.VariableNames.Except("y".ToEnumerable()), "y"); double rmsError; double cvRmsError; var solution = LinearRegression.CreateSolution(problemData, out rmsError, out cvRmsError); Stopwatch watch = new Stopwatch(); watch.Start(); var results = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(solution); watch.Stop(); TestContext.WriteLine(""); TestContext.WriteLine("Calculated cells per millisecond: {0}.", rows * columns / watch.ElapsedMilliseconds); } #region Load RegressionProblemData private IRegressionProblemData LoadDefaultTowerProblem() { RegressionRealWorldInstanceProvider provider = new RegressionRealWorldInstanceProvider(); var tower = new HeuristicLab.Problems.Instances.DataAnalysis.Tower(); return provider.LoadData(tower); } private IRegressionProblemData LoadDefaultMibaProblem() { MibaFrictionRegressionInstanceProvider provider = new MibaFrictionRegressionInstanceProvider(); var cf1 = new HeuristicLab.Problems.Instances.DataAnalysis.CF1(); return provider.LoadData(cf1); } private IRegressionProblemData CreateDefaultProblem() { List allowedInputVariables = new List() { "x1", "x2", "x3", "x4", "x5" }; string targetVariable = "y"; var variableNames = allowedInputVariables.Union(targetVariable.ToEnumerable()); double[,] variableValues = new double[100, variableNames.Count()]; FastRandom random = new FastRandom(12345); for (int i = 0; i < variableValues.GetLength(0); i++) { for (int j = 0; j < variableValues.GetLength(1); j++) { variableValues[i, j] = random.Next(1, 100); } } Dataset dataset = new Dataset(variableNames, variableValues); return new RegressionProblemData(dataset, allowedInputVariables, targetVariable); } #endregion #region Create SymbolicExpressionTree private ISymbolicExpressionTree CreateCustomExpressionTree() { return new InfixExpressionParser().Parse("x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12"); } private ISymbolicExpressionTree CreateCustomExpressionTreeNoInfluenceX1() { return new InfixExpressionParser().Parse("x1/x1*x2 - x2*x2 + x3*x3 + x4*x4 - x5*x5 + 14/12"); } #endregion #region Get Expected Values private Dictionary GetExpectedValuesForConstantModel() { Dictionary expectedImpacts = new Dictionary(); expectedImpacts.Add("x1", 0); expectedImpacts.Add("x10", 0); expectedImpacts.Add("x11", 0); expectedImpacts.Add("x12", 0); expectedImpacts.Add("x13", 0); expectedImpacts.Add("x14", 0); expectedImpacts.Add("x15", 0); expectedImpacts.Add("x16", 0); expectedImpacts.Add("x17", 0); expectedImpacts.Add("x18", 0); expectedImpacts.Add("x19", 0); expectedImpacts.Add("x2", 0); expectedImpacts.Add("x20", 0); expectedImpacts.Add("x21", 0); expectedImpacts.Add("x22", 0); expectedImpacts.Add("x23", 0); expectedImpacts.Add("x24", 0); expectedImpacts.Add("x25", 0); expectedImpacts.Add("x3", 0); expectedImpacts.Add("x4", 0); expectedImpacts.Add("x5", 0); expectedImpacts.Add("x6", 0); expectedImpacts.Add("x7", 0); expectedImpacts.Add("x8", 0); expectedImpacts.Add("x9", 0); return expectedImpacts; } private Dictionary GetExpectedValuesForLRTower() { Dictionary expectedImpacts = new Dictionary(); expectedImpacts.Add("x1", 0.639933657675427); expectedImpacts.Add("x10", 0.0127006885259798); expectedImpacts.Add("x11", 0.648236047877475); expectedImpacts.Add("x12", 0.248350173524562); expectedImpacts.Add("x13", 0.550889987109547); expectedImpacts.Add("x14", 0.0882824237877192); expectedImpacts.Add("x15", 0.0391276799061169); expectedImpacts.Add("x16", 0.743632451088798); expectedImpacts.Add("x17", 0.00254276857715308); expectedImpacts.Add("x18", 0.0021548147614302); expectedImpacts.Add("x19", 0.00513473927463037); expectedImpacts.Add("x2", 0.0107583487931443); expectedImpacts.Add("x20", 0.18085069746933); expectedImpacts.Add("x21", 0.138053600700762); expectedImpacts.Add("x22", 0.000339539790460086); expectedImpacts.Add("x23", 0.362111965467117); expectedImpacts.Add("x24", 0.0320167935572304); expectedImpacts.Add("x25", 0.57460423230969); expectedImpacts.Add("x3", 0.688142635515862); expectedImpacts.Add("x4", 0.000176632348454664); expectedImpacts.Add("x5", 0.0213915503114581); expectedImpacts.Add("x6", 0.807976486909701); expectedImpacts.Add("x7", 0.716217843319252); expectedImpacts.Add("x8", 0.772701841392564); expectedImpacts.Add("x9", 0.178418730050997); return expectedImpacts; } private Dictionary GetExpectedValuesForLRMiba() { Dictionary expectedImpacts = new Dictionary(); expectedImpacts.Add("Grooving", 0.0380558091030508); expectedImpacts.Add("Material", 0.02195836766156); expectedImpacts.Add("Material_Cat", 0.000338687689067418); expectedImpacts.Add("Oil", 0.363464994447857); expectedImpacts.Add("x10", 0.0015309669014415); expectedImpacts.Add("x11", -3.60432578908609E-05); expectedImpacts.Add("x12", 0.00118953859087612); expectedImpacts.Add("x13", 0.00164240977191832); expectedImpacts.Add("x14", 0.000688363685380056); expectedImpacts.Add("x15", -4.75067203969948E-05); expectedImpacts.Add("x16", 0.00130388206125076); expectedImpacts.Add("x17", 0.132351838646134); expectedImpacts.Add("x2", -2.47981401556574E-05); expectedImpacts.Add("x20", 0.716541716605016); expectedImpacts.Add("x22", 0.174959377282835); expectedImpacts.Add("x3", -2.65979754026091E-05); expectedImpacts.Add("x4", -1.24764212947603E-05); expectedImpacts.Add("x5", 0.001184959455798); expectedImpacts.Add("x6", 0.000743336665237626); expectedImpacts.Add("x7", 0.00188965927889773); expectedImpacts.Add("x8", 0.00415201581536351); expectedImpacts.Add("x9", 0.00365653880518491); return expectedImpacts; } private Dictionary GetExpectedValuesForRFTower() { Dictionary expectedImpacts = new Dictionary(); expectedImpacts.Add("x5", 0.00138095702433039); expectedImpacts.Add("x19", 0.00220739387855795); expectedImpacts.Add("x14", 0.00225120540266954); expectedImpacts.Add("x18", 0.00311857736968479); expectedImpacts.Add("x9", 0.00313474690023097); expectedImpacts.Add("x20", 0.00321781251408282); expectedImpacts.Add("x21", 0.00397483365571383); expectedImpacts.Add("x16", 0.00433280262892111); expectedImpacts.Add("x15", 0.00529918809786456); expectedImpacts.Add("x3", 0.00658791244929757); expectedImpacts.Add("x24", 0.0078645281886035); expectedImpacts.Add("x4", 0.00907314110749047); expectedImpacts.Add("x13", 0.0102943761648944); expectedImpacts.Add("x22", 0.0107132858548163); expectedImpacts.Add("x12", 0.0157078677788507); expectedImpacts.Add("x23", 0.0235857534562318); expectedImpacts.Add("x7", 0.0304143401617055); expectedImpacts.Add("x11", 0.0310773441767309); expectedImpacts.Add("x25", 0.0328308945873665); expectedImpacts.Add("x17", 0.0428771226844575); expectedImpacts.Add("x10", 0.0456335367972532); expectedImpacts.Add("x8", 0.049849257881126); expectedImpacts.Add("x1", 0.0663686086323108); expectedImpacts.Add("x2", 0.0799083890750926); expectedImpacts.Add("x6", 0.196557814244287); return expectedImpacts; } private Dictionary GetExpectedValuesForCustomProblem() { Dictionary expectedImpacts = new Dictionary(); expectedImpacts.Add("x1", -0.000573340275115796); expectedImpacts.Add("x2", 0.000781819784095592); expectedImpacts.Add("x3", -0.000390473234921058); expectedImpacts.Add("x4", -0.00116083274627995); expectedImpacts.Add("x5", -0.00036161186207545); return expectedImpacts; } private Dictionary GetExpectedValuesForCustomProblemNoInfluence() { Dictionary expectedImpacts = new Dictionary(); expectedImpacts.Add("x1", 0); expectedImpacts.Add("x2", 0.00263393690342982); expectedImpacts.Add("x3", -0.00053248037514929); expectedImpacts.Add("x4", 0.00450365819257568); expectedImpacts.Add("x5", -0.000550911612888904); return expectedImpacts; } #endregion private void CheckDefaultAsserts(IRegressionSolution solution, Dictionary expectedImpacts) { IRegressionProblemData problemData = solution.ProblemData; IEnumerable estimatedValues = solution.GetEstimatedValues(solution.ProblemData.TrainingIndices); var solutionImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(solution); var modelImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(solution.Model, problemData, estimatedValues, problemData.TrainingIndices); //Both ways should return equal results Assert.IsTrue(solutionImpacts.SequenceEqual(modelImpacts)); //Check if impacts are as expected Assert.AreEqual(modelImpacts.Count(), expectedImpacts.Count); Assert.IsTrue(modelImpacts.All(v => v.Item2.IsAlmost(expectedImpacts[v.Item1]))); } } }