Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs @ 12700

Last change on this file since 12700 was 12700, checked in by gkronber, 9 years ago

#2261: copied GBT implementation from branch to trunk

File size: 8.7 KB
Line 
1using System;
2using System.Linq;
3using System.Threading;
4using HeuristicLab.Data;
5using HeuristicLab.Optimization;
6using HeuristicLab.Problems.DataAnalysis;
7using Microsoft.VisualStudio.TestTools.UnitTesting;
8
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Tests for the gradient boosted trees (GBT) implementation:
  /// single regression trees built in one boosting iteration on tiny hand-made data sets, and
  /// end-to-end runs on the "Tower" real-world regression instance with squared, absolute and
  /// relative error loss.
  /// </summary>
  [TestClass()]
  public class Test {
    /// <summary>
    /// Builds one regression tree for each small data set and prints it to the console.
    /// The expected tree for every case is documented in the comments; it is not asserted.
    /// Column 0 of each data matrix is the target "y", columns 1 and 2 are the inputs "x1" and "x2".
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void DecisionTreeTest() {
      {
        var xy = new double[,]
        {
          {-1, 20, 0},
          {-1, 20, 0},
          { 1, 10, 0},
          { 1, 10, 0},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        var xy = new double[,]
        {
          {-1, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 1, 10, 1},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // ignore irrelevant variables
        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        // split must be by x1 first
        var xy = new double[,]
        {
          {-2, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 2, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // averaging ys
        var xy = new double[,]
        {
          {-2.5, 20,  1},
          {-1.5, 20,  1},
          {-1.5, 20, -1},
          {-0.5, 20, -1},
          {0.5, 10, -1},
          {1.5, 10, -1},
          {1.5, 10, 1},
          {2.5, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // diagonal split (no split possible)
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // split cannot be found
        // -> 0.0
        BuildTree(xy, allVariables, 3);
      }

      {
        // almost diagonal split
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1.0001, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> 0
        // x2 >  1.5 -> 0 (not quite)
        BuildTree(xy, allVariables, 3);

        // x1 <= 1.5 AND x2 <= 1.5 -> 1
        // x1 <= 1.5 AND x2 >  1.5 -> -1
        // x1 >  1.5 AND x2 <= 1.5 -> -1
        // x1 >  1.5 AND x2 >  1.5 -> 1 (not quite)
        BuildTree(xy, allVariables, 7);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {0.9, 2, 1},
          {1.1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // x1 <= 1.5 -> -1.0
        // x1 >  1.5 AND x2 <= 1.5 -> 0.9
        // x1 >  1.5 AND x2 >  1.5 -> 1.1
        BuildTree(xy, allVariables, 10);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 3, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 >  1.5 ->  3.0
        BuildTree(xy, allVariables, 10);
      }
    }

    /// <summary>GBT with the default loss function on the Tower instance; checks train/test loss.</summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerSquaredError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = CreateTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false; // fixed seed so the expected losses below are reproducible
      gbt.Iterations = 5000;
      gbt.MaxSize = 20;
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      AssertLosses(gbt, 267.68704241153921, 393.84704062205469);
    }

    /// <summary>GBT with the "Absolute" error loss on the Tower instance; checks train/test loss.</summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerAbsoluteError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = CreateTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false; // fixed seed so the expected losses below are reproducible
      gbt.Iterations = 1000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.02;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute"));
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      AssertLosses(gbt, 10.551385044666661, 12.918001745581172);
    }

    /// <summary>GBT with the "Relative" error loss on the Tower instance; checks train/test loss.</summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerRelativeError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = CreateTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false; // fixed seed so the expected losses below are reproducible
      gbt.Iterations = 3000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.005;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative"));
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      AssertLosses(gbt, 0.061954221604374943, 0.06316303473499961);
    }

    /// <summary>
    /// Prepares and starts <paramref name="a"/>, blocks until it raises either Stopped or
    /// ExceptionOccurred, and fails the test if an exception was reported.
    /// Adapted from the equivalent helper in SamplesUtil.
    /// </summary>
    /// <param name="a">The configured algorithm to run.</param>
    private void RunAlgorithm(IAlgorithm a) {
      // Completed by whichever event fires first: null on a normal stop, the exception on failure.
      // TrySetResult is thread-safe and idempotent, so a second, late event is harmless. Unlike a
      // wait handle, there is nothing to dispose that a late event handler could still touch.
      var completion = new System.Threading.Tasks.TaskCompletionSource<Exception>();
      a.Stopped += (src, e) => completion.TrySetResult(null);
      a.ExceptionOccurred += (src, e) => completion.TrySetResult(e.Value);
      a.Prepare();
      a.Start();

      // Blocks the test thread until the algorithm reports completion (no timeout).
      Exception ex = completion.Task.Result;

      // Report the actual exception, including its stack trace, instead of a reversed
      // expected/actual comparison.
      Assert.IsNull(ex, "The algorithm raised an exception: " + ex);
    }

    #region helper
    /// <summary>
    /// Loads the "Tower" instance of the real-world regression benchmark provider into a new
    /// regression problem.
    /// </summary>
    /// <returns>A regression problem holding the Tower data.</returns>
    private static RegressionProblem CreateTowerProblem() {
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
      // Single() fails loudly if the instance is missing or the name becomes ambiguous.
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
      var regProblem = new RegressionProblem();
      regProblem.Load(provider.LoadData(instance));
      return regProblem;
    }

    /// <summary>
    /// Asserts the "Loss (train)" and "Loss (test)" results of a finished GBT run, each within 1E-6.
    /// </summary>
    /// <param name="gbt">The algorithm after <see cref="RunAlgorithm"/> has completed.</param>
    /// <param name="expectedTrainLoss">Expected value of the "Loss (train)" result.</param>
    /// <param name="expectedTestLoss">Expected value of the "Loss (test)" result.</param>
    private static void AssertLosses(GradientBoostedTreesAlgorithm gbt, double expectedTrainLoss, double expectedTestLoss) {
      Assert.AreEqual(expectedTrainLoss, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(expectedTestLoss, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Trains a GBT model with a single iteration (nu = 1, r = 1, m = 1, squared error loss) on
    /// all rows of <paramref name="xy"/> and prints the resulting regression tree.
    /// </summary>
    /// <param name="xy">Data matrix; column i holds the values of variable allVariables[i].</param>
    /// <param name="allVariables">Variable names; the first is the target, the rest are inputs.</param>
    /// <param name="maxSize">Maximum tree size passed to the trainer.</param>
    private void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
      int nRows = xy.GetLength(0);
      var allowedInputs = allVariables.Skip(1);
      var dataset = new Dataset(allVariables, xy);
      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
      // Every row is used for training; the test partition is empty ([nRows, nRows)).
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = nRows;
      problemData.TestPartition.Start = nRows;
      problemData.TestPartition.End = nRows;
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
      var model = (GradientBoostedTreesModel)solution.Model;
      // NOTE(review): presumably Models holds the initial model first, followed by the single tree
      // fitted in the one iteration — confirm. A direct cast reports a wrong element type clearly
      // instead of yielding null and failing later with a NullReferenceException.
      var treeM = (RegressionTreeModel)model.Models.Skip(1).First();

      Console.WriteLine(treeM.ToString());
      Console.WriteLine();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.