Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs @ 13398

Last change on this file since 13398 was 13157, checked in by gkronber, 9 years ago

#2450 made the changes suggested by mkommend in the review. This is definitely a big improvement, thx!

File size: 8.7 KB
RevLine 
[12620]1using System;
2using System.Linq;
[12632]3using System.Threading;
4using HeuristicLab.Data;
5using HeuristicLab.Optimization;
[12620]6using HeuristicLab.Problems.DataAnalysis;
7using Microsoft.VisualStudio.TestTools.UnitTesting;
8
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Tests for the gradient boosted trees algorithm:
  /// a smoke test that builds single regression trees for small hand-crafted data sets, and
  /// regression tests that pin the final train/test loss on the "Tower" real-world instance
  /// for three different loss functions.
  /// </summary>
  [TestClass()]
  public class GradientBoostingTest {
    // absolute tolerance used when comparing the pinned loss values of the Tower tests
    private const double LossTolerance = 1E-6;

    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void DecisionTreeTest() {
      // Each scope builds one tree for a tiny data set and prints it.
      // Column 0 is the target y, columns 1 and 2 are the inputs x1 and x2.
      // The comments describe the expected tree; the test itself only checks that
      // tree construction succeeds (the tree structure is not asserted).
      {
        var xy = new double[,]
        {
          {-1, 20, 0},
          {-1, 20, 0},
          { 1, 10, 0},
          { 1, 10, 0},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        var xy = new double[,]
        {
          {-1, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 1, 10, 1},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // ignore irrelevant variables
        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        // split must be by x1 first
        var xy = new double[,]
        {
          {-2, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 2, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // averaging ys
        var xy = new double[,]
        {
          {-2.5, 20,  1},
          {-1.5, 20,  1},
          {-1.5, 20, -1},
          {-0.5, 20, -1},
          {0.5, 10, -1},
          {1.5, 10, -1},
          {1.5, 10, 1},
          {2.5, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // diagonal split (no split possible)
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // split cannot be found
        // -> 0.0
        BuildTree(xy, allVariables, 3);
      }

      {
        // almost diagonal split
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1.0001, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> 0
        // x2 >  1.5 -> 0 (not quite)
        BuildTree(xy, allVariables, 3);

        // x1 <= 1.5 AND x2 <= 1.5 -> 1
        // x1 <= 1.5 AND x2 >  1.5 -> -1
        // x1 >  1.5 AND x2 <= 1.5 -> -1
        // x1 >  1.5 AND x2 >  1.5 -> 1 (not quite)
        BuildTree(xy, allVariables, 7);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {0.9, 2, 1},
          {1.1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // x1 <= 1.5 -> -1.0
        // x1 >  1.5 AND x2 <= 1.5 -> 0.9
        // x1 >  1.5 AND x2 >  1.5 -> 1.1
        BuildTree(xy, allVariables, 10);
      }

      {
        // unbalanced split (single outlier)
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 3, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 >  1.5 ->  3.0
        BuildTree(xy, allVariables, 10);
      }
    }

    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerSquaredError() {
      // keeps the algorithm's default loss function and learning rate (per the test name, squared error)
      RunTowerTest(iterations: 5000, maxSize: 20, nu: null, lossFunctionName: null,
                   expectedTrainLoss: 267.68704241153921, expectedTestLoss: 393.84704062205469);
    }

    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerAbsoluteError() {
      RunTowerTest(iterations: 1000, maxSize: 20, nu: 0.02, lossFunctionName: "Absolute",
                   expectedTrainLoss: 10.551385044666661, expectedTestLoss: 12.918001745581172);
    }

    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerRelativeError() {
      RunTowerTest(iterations: 3000, maxSize: 20, nu: 0.005, lossFunctionName: "Relative",
                   expectedTrainLoss: 0.061954221604374943, expectedTestLoss: 0.06316303473499961);
    }

    #region helper
    /// <summary>
    /// Runs gradient boosted trees with seed 0 on the "Tower" instance of the real-world regression
    /// instance provider and asserts the final "Loss (train)" and "Loss (test)" results.
    /// </summary>
    /// <param name="iterations">Number of boosting iterations.</param>
    /// <param name="maxSize">Value assigned to <c>MaxSize</c> (maximum tree size).</param>
    /// <param name="nu">Learning rate; <c>null</c> leaves the algorithm's default unchanged.</param>
    /// <param name="lossFunctionName">Substring that selects the loss function among the valid values of
    /// <c>LossFunctionParameter</c> (first match by <c>ToString()</c>); <c>null</c> leaves the default unchanged.</param>
    /// <param name="expectedTrainLoss">Pinned training loss.</param>
    /// <param name="expectedTestLoss">Pinned test loss.</param>
    private void RunTowerTest(int iterations, int maxSize, double? nu, string lossFunctionName,
                              double expectedTrainLoss, double expectedTestLoss) {
      var gbt = new GradientBoostedTreesAlgorithm();
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
      var regProblem = new RegressionProblem();
      regProblem.Load(provider.LoadData(instance));

      #region Algorithm Configuration
      // the problem is assigned first, before any other parameter, as in the original per-test setup
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = iterations;
      gbt.MaxSize = maxSize;
      if (nu.HasValue) {
        gbt.Nu = nu.Value;
      }
      if (lossFunctionName != null) {
        gbt.LossFunctionParameter.Value =
          gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains(lossFunctionName));
      }
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(expectedTrainLoss, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, LossTolerance);
      Assert.AreEqual(expectedTestLoss, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, LossTolerance);
    }

    /// <summary>
    /// Prepares and starts <paramref name="a"/> and blocks until it raises either <c>Stopped</c> or
    /// <c>ExceptionOccurred</c>. Fails the test if an exception was reported.
    /// Adapted from the equivalent helper in SamplesUtil.
    /// </summary>
    private void RunAlgorithm(IAlgorithm a) {
      var trigger = new EventWaitHandle(false, EventResetMode.ManualReset);
      Exception ex = null;
      a.Stopped += (src, e) => { trigger.Set(); };
      a.ExceptionOccurred += (src, e) => { ex = e.Value; trigger.Set(); };
      a.Prepare();
      a.Start();
      trigger.WaitOne();

      // Report the full exception (message and stack trace). The previous Assert.AreEqual(ex, null)
      // had expected/actual swapped and only showed the exception's type name.
      if (ex != null) {
        Assert.Fail("The algorithm raised an exception: " + ex);
      }
    }

    /// <summary>
    /// Trains a gradient boosting model with a single iteration on <paramref name="xy"/> and prints the
    /// resulting regression tree.
    /// </summary>
    /// <param name="xy">Data matrix; column 0 is the target, the remaining columns are inputs.</param>
    /// <param name="allVariables">Variable names for the columns of <paramref name="xy"/>; the first is the target.</param>
    /// <param name="maxSize">Maximum tree size passed to the trainer.</param>
    private void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
      int nRows = xy.GetLength(0);
      var allowedInputs = allVariables.Skip(1);
      var dataset = new Dataset(allVariables, xy);
      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
      // train on every row, leave the test partition empty
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = nRows;
      problemData.TestPartition.Start = nRows;
      problemData.TestPartition.End = nRows;
      // one iteration with nu = 1 so the printed tree is fitted directly to the targets;
      // r and m are presumably the row / variable sampling ratios (1 = use all) — confirm in TrainGbm
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
      var model = solution.Model;
      // the tree from the single iteration is the second ensemble member
      // (the first is presumably the initial constant model — verify)
      var treeM = model.Models.Skip(1).First() as RegressionTreeModel;
      Assert.IsNotNull(treeM, "Expected the second model of the boosted ensemble to be a RegressionTreeModel.");

      Console.WriteLine(treeM.ToString());
      Console.WriteLine();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.