Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MathNetNumerics-Exploration-2789/Test/UnitTest1.cs @ 15450

Last change on this file since 15450 was 15450, checked in by gkronber, 6 years ago

#2789: more tests with cross-validation (CV) and automatic determination of the smoothing parameter

File size: 5.4 KB
Rev Line
[15313]1using System;
[15442]2using System.Diagnostics;
3using System.Linq;
[15313]4using HeuristicLab.Algorithms.DataAnalysis.Experimental;
5using HeuristicLab.Algorithms.GeneticAlgorithm;
[15442]6using HeuristicLab.Common;
7using HeuristicLab.Problems.DataAnalysis;
[15313]8using HeuristicLab.SequentialEngine;
9using Microsoft.VisualStudio.TestTools.UnitTesting;
10
11namespace Test {
[TestClass]
public class UnitTest1 {
  /// <summary>
  /// Smoke test: runs a genetic algorithm with the sequential engine on a single-objective
  /// symbolic regression problem whose evaluator is replaced by
  /// SymbolicRegressionConstantOptimizationEvaluator. The test makes no assertions; it only
  /// checks that configuring and starting the algorithm does not throw.
  /// </summary>
  [TestMethod]
  public void TestMethod1() {
    var ga = new GeneticAlgorithm();
    var symbReg = new HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.SymbolicRegressionSingleObjectiveProblem();
    var eval = new SymbolicRegressionConstantOptimizationEvaluator();
    symbReg.EvaluatorParameter.Value = eval;
    ga.Engine = new SequentialEngine();
    // GeneticAlgorithm's SetSeedRandomly parameter defaults to true. In that case the configured
    // Seed is replaced by a random one when the run starts, so it must be switched off for the
    // fixed seed below to take effect and make the run reproducible.
    ga.SetSeedRandomly.Value = false;
    ga.Seed.Value = 1234;
    ga.Problem = symbReg;

    // NOTE(review): this assumes Start() returns only after the run has finished in this
    // HeuristicLab version. If it runs asynchronously, exceptions from the run are not seen here.
    ga.Start();
  }
}
[15442]28
[TestClass]
public class PenalizedRegressionSplineTests {
  /// <summary>
  /// Builds the data set shared by the spline tests.
  /// Inputs run from -3.5 to 3.5 in steps of 0.02, with the end point included.
  /// Each target is alglib.normaldistr.normaldistribution(x), i.e. the standard normal CDF, plus
  /// Gaussian noise scaled by 0.01. The noise generator is seeded with (1234, 5678), so every call
  /// returns identical data.
  /// </summary>
  /// <param name="xs">Receives the input values.</param>
  /// <param name="ysNoise">Receives the noisy target values, one per input.</param>
  private static void GenerateNoisyData(out double[] xs, out double[] ysNoise) {
    xs = HeuristicLab.Common.SequenceGenerator.GenerateSteps(-3.5, 3.5, 0.02, includeEnd: true).ToArray();

    alglib.hqrndstate state;
    alglib.hqrndseed(1234, 5678, out state);
    // For each point, the CDF value is computed before the noise draw (C# evaluates operands left
    // to right). This keeps the random sequence tied to the points in order.
    ysNoise = xs.Select(xi => alglib.normaldistr.normaldistribution(xi) + alglib.hqrndnormal(state) * 0.01).ToArray();
  }

  /// <summary>
  /// Grid search over the integer values of rho in [-15, 15] for the penalized regression spline.
  /// Prints the training and leave-one-out (LOO) RMSE for each rho, then the rho with the lowest
  /// LOO RMSE and the average time per fit.
  /// </summary>
  [TestMethod]
  public void TestPenalizedRegressionSplinesCrossValidation() {
    double[] xs, ysNoise;
    GenerateNoisyData(out xs, out ysNoise);

    double bestRho = -15;
    double bestLooRmse = double.PositiveInfinity;
    var clock = Stopwatch.StartNew();
    int iters = 0;
    for (int rho = -15; rho <= 15; rho++) {
      double looRmse;
      double avgTrainRmse;
      Splines.CalculatePenalizedRegressionSpline(xs, ysNoise, rho, "y", new string[] { "x" }, out avgTrainRmse, out looRmse);
      iters++;
      Console.WriteLine("{0} {1} {2}", rho, avgTrainRmse, looRmse);
      if (looRmse < bestLooRmse) {
        bestLooRmse = looRmse;
        bestRho = rho;
      }
    }
    clock.Stop();
    // Elapsed.TotalMilliseconds keeps sub-millisecond precision; ElapsedMilliseconds would truncate.
    Console.WriteLine("Best rho {0}, RMSE (LOO): {1}, ms/run {2}", bestRho, bestLooRmse, clock.Elapsed.TotalMilliseconds / iters);
  }

  /// <summary>
  /// Search over the tolerance of the Reinsch smoothing spline. The tolerance starts at the standard
  /// deviation of the noisy targets and is multiplied by 0.8 while it is above 0.0001.
  /// For each fitted model, the test checks that it predicts one value per row of the data set.
  /// It prints the training and LOO RMSE per tolerance, then the tolerance with the lowest LOO RMSE.
  /// </summary>
  [TestMethod]
  public void TestReinschSplineCrossValidation() {
    double[] xs, ysNoise;
    GenerateNoisyData(out xs, out ysNoise);

    var tol = ysNoise.StandardDeviation();

    var ds = new Dataset(new string[] { "x", "y" }, new[] { xs.ToList(), ysNoise.ToList() });
    var rows = Enumerable.Range(0, xs.Length);

    var bestTol = double.PositiveInfinity;
    var bestLooRmse = double.PositiveInfinity;
    var clock = Stopwatch.StartNew();
    var iters = 0;
    while (tol > 0.0001) {
      double looRmse;
      double avgTrainRmse;
      // Fresh copies are passed on every call so that a routine which modifies its input arrays
      // cannot affect later iterations.
      // NOTE(review): it is unclear whether CalculateSmoothingSplineReinsch mutates its inputs.
      var model = Splines.CalculateSmoothingSplineReinsch(
        xs.ToArray(), ysNoise.ToArray(), tol, "y", new string[] { "x" }, out avgTrainRmse, out looRmse);

      // Materialize the predictions so the model is actually evaluated (the result may be lazy),
      // and check that there is one prediction per row.
      var yPred = model.GetEstimatedValues(ds, rows).ToArray();
      Assert.AreEqual(xs.Length, yPred.Length);

      Console.WriteLine("{0} {1} {2}", tol, avgTrainRmse, looRmse);
      if (looRmse < bestLooRmse) {
        bestLooRmse = looRmse;
        bestTol = tol;
      }
      tol *= 0.8;
      iters++;
    }
    clock.Stop();
    Console.WriteLine("Best tolerance {0}, RMSE (LOO): {1}, ms/run {2}", bestTol, bestLooRmse, clock.Elapsed.TotalMilliseconds / iters);
  }

  /// <summary>
  /// Uses the overload of CalculateSmoothingSplineReinsch that determines the tolerance itself,
  /// and prints the chosen tolerance and its LOO RMSE.
  /// </summary>
  [TestMethod]
  public void TestReinschSplineAutomaticTolerance() {
    double[] xs, ysNoise;
    GenerateNoisyData(out xs, out ysNoise);

    double optTol;
    double looRMSE;
    // The arrays are not used after this single call, so no defensive copies are needed.
    Splines.CalculateSmoothingSplineReinsch(xs, ysNoise, new string[] { "x" }, "y", out optTol, out looRMSE);

    Console.WriteLine("Best tolerance {0}, RMSE (LOO): {1}", optTol, looRMSE);
  }

  /// <summary>
  /// Runs GAM for 10 iterations with at most 3 interactions on the first "Poly" data set from
  /// VariousInstanceProvider. Prints the training and test RMSE of the "Ensemble solution" result.
  /// </summary>
  [TestMethod]
  public void TestGAM() {
    var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider();
    var problemData = provider.LoadData(provider.GetDataDescriptors().First(dd => dd.Name.Contains("Poly")));
    // Alternative real-world data set:
    // var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
    // var problemData = provider.LoadData(provider.GetDataDescriptors().First(dd => dd.Name.Contains("Chem")));

    var gam = new GAM();
    gam.MaxIterations = 10;
    gam.MaxInteractions = 3;
    gam.Problem.ProblemData = problemData;
    // NOTE(review): results are read immediately after Start(); this assumes Start() blocks until
    // the run has finished in this HeuristicLab version.
    gam.Start();

    var solution = (IRegressionSolution)gam.Results["Ensemble solution"].Value;

    Console.WriteLine("RMSE (train) {0}", solution.TrainingRootMeanSquaredError);
    Console.WriteLine("RMSE (test) {0}", solution.TestRootMeanSquaredError);
  }
}
[15313]134}
Note: See TracBrowser for help on using the repository browser.