Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT-trunkintegration/Tests/Test.cs @ 12661

Last change to this file: revision 12661, checked in by gkronber 9 years ago

#2261: regression tree builder should not be used from outside (made internal)

File size: 8.6 KB
Line 
1using System;
2using System.Collections;
3using System.Globalization;
4using System.Linq;
5using System.Runtime.CompilerServices;
6using System.Threading;
7using HeuristicLab.Data;
8using HeuristicLab.Optimization;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Random;
11using Microsoft.VisualStudio.TestTools.UnitTesting;
12
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Unit tests for gradient boosted trees (GBT): construction of single regression trees on
  /// tiny hand-crafted data sets, and reproducible end-to-end runs on the "Tower" instance.
  /// </summary>
  [TestClass]
  public class Test {
    /// <summary>
    /// Builds single regression trees on small hand-crafted data sets and prints them.
    /// The comments above each call describe the tree that is expected; the test itself only
    /// checks that tree construction completes without throwing.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void DecisionTreeTest() {
      {
        var xy = new double[,]
        {
          {-1, 20, 0},
          {-1, 20, 0},
          { 1, 10, 0},
          { 1, 10, 0},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }


      {
        var xy = new double[,]
        {
          {-1, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 1, 10, 1},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // ignore irrelevant variables
        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        // split must be by x1 first
        var xy = new double[,]
        {
          {-2, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 2, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // averaging ys
        var xy = new double[,]
        {
          {-2.5, 20,  1},
          {-1.5, 20,  1},
          {-1.5, 20, -1},
          {-0.5, 20, -1},
          {0.5, 10, -1},
          {1.5, 10, -1},
          {1.5, 10, 1},
          {2.5, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }


      {
        // diagonal split (no split possible)
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // split cannot be found
        // -> 0.0
        BuildTree(xy, allVariables, 3);
      }
      {
        // almost diagonal split
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1.0001, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> 0
        // x2 >  1.5 -> 0 (not quite)
        BuildTree(xy, allVariables, 3);

        // x1 <= 1.5 AND x2 <= 1.5 -> 1
        // x1 <= 1.5 AND x2 >  1.5 -> -1
        // x1 >  1.5 AND x2 <= 1.5 -> -1
        // x1 >  1.5 AND x2 >  1.5 -> 1 (not quite)
        BuildTree(xy, allVariables, 7);
      }
      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {0.9, 2, 1},
          {1.1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // x1 <= 1.5 -> -1.0
        // x1 >  1.5 AND x2 <= 1.5 -> 0.9
        // x1 >  1.5 AND x2 >  1.5 -> 1.1
        BuildTree(xy, allVariables, 10);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 3, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 >  1.5 ->  3.0
        BuildTree(xy, allVariables, 10);
      }
    }

    /// <summary>
    /// Runs GBT with the default loss on the "Tower" instance and checks the reproducible
    /// train/test losses obtained with seed 0.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerSquaredError() {
      var gbt = CreateTowerAlgorithm(iterations: 5000, maxSize: 20);

      RunAlgorithm(gbt);

      Assert.AreEqual(267.68704241153921, GetDoubleResult(gbt, "Loss (train)"), 1E-6);
      Assert.AreEqual(393.84704062205469, GetDoubleResult(gbt, "Loss (test)"), 1E-6);
    }

    /// <summary>
    /// Runs GBT with the loss function whose name contains "Absolute" on the "Tower" instance
    /// and checks the reproducible train/test losses obtained with seed 0.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerAbsoluteError() {
      var gbt = CreateTowerAlgorithm(iterations: 1000, maxSize: 20);
      gbt.Nu = 0.02;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute"));

      RunAlgorithm(gbt);

      Assert.AreEqual(10.551385044666661, GetDoubleResult(gbt, "Loss (train)"), 1E-6);
      Assert.AreEqual(12.918001745581172, GetDoubleResult(gbt, "Loss (test)"), 1E-6);
    }

    /// <summary>
    /// Runs GBT with the loss function whose name contains "Relative" on the "Tower" instance
    /// and checks the reproducible train/test losses obtained with seed 0.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerRelativeError() {
      var gbt = CreateTowerAlgorithm(iterations: 3000, maxSize: 20);
      gbt.Nu = 0.005;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative"));

      RunAlgorithm(gbt);

      Assert.AreEqual(0.061954221604374943, GetDoubleResult(gbt, "Loss (train)"), 1E-6);
      Assert.AreEqual(0.06316303473499961, GetDoubleResult(gbt, "Loss (test)"), 1E-6);
    }

    #region helper
    /// <summary>
    /// Creates a GBT algorithm for the "Tower" real-world regression instance.
    /// Seed 0 with SetSeedRandomly = false makes the losses asserted by the callers reproducible.
    /// </summary>
    /// <param name="iterations">Value assigned to <c>Iterations</c>.</param>
    /// <param name="maxSize">Value assigned to <c>MaxSize</c>.</param>
    /// <returns>The configured, not yet started algorithm.</returns>
    private static GradientBoostedTreesAlgorithm CreateTowerAlgorithm(int iterations, int maxSize) {
      var gbt = new GradientBoostedTreesAlgorithm();
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
      // Single() fails loudly if the provider ever exposes zero or several "Tower" instances.
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
      var regProblem = new RegressionProblem();
      regProblem.Load(provider.LoadData(instance));

      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = iterations;
      gbt.MaxSize = maxSize;
      return gbt;
    }

    /// <summary>
    /// Reads a <see cref="DoubleValue"/> result of the algorithm by name.
    /// </summary>
    /// <param name="gbt">Algorithm whose results are read.</param>
    /// <param name="resultName">Name of the result entry, e.g. "Loss (train)".</param>
    /// <returns>The wrapped double value.</returns>
    private static double GetDoubleResult(GradientBoostedTreesAlgorithm gbt, string resultName) {
      return ((DoubleValue)gbt.Results[resultName].Value).Value;
    }

    /// <summary>
    /// Prepares and starts <paramref name="a"/>. Blocks until either its Stopped or its
    /// ExceptionOccurred event fires, then asserts that no exception was reported.
    /// Based on the helper of the same name in SamplesUtil.
    /// </summary>
    /// <param name="a">The algorithm to run.</param>
    private void RunAlgorithm(IAlgorithm a) {
      // The wait handle is intentionally not disposed. A handler may still call Set() after
      // WaitOne() returns (e.g. if Stopped also fires after ExceptionOccurred).
      var trigger = new EventWaitHandle(false, EventResetMode.ManualReset);
      Exception ex = null;
      a.Stopped += (src, e) => { trigger.Set(); };
      a.ExceptionOccurred += (src, e) => { ex = e.Value; trigger.Set(); };
      a.Prepare();
      a.Start();
      trigger.WaitOne();

      // IsNull, not AreEqual(ex, null): on failure the message shows the actual exception.
      // AreEqual would show it as the "expected" value because its arguments were reversed.
      Assert.IsNull(ex, "Algorithm raised an exception: {0}", ex);
    }

    /// <summary>
    /// Trains a one-iteration gradient boosted trees model with squared error loss on all rows of
    /// <paramref name="xy"/> and prints the resulting regression tree to the console.
    /// </summary>
    /// <param name="xy">Data matrix whose columns are named by <paramref name="allVariables"/>, in order.</param>
    /// <param name="allVariables">Variable names; the first entry is the target, the remaining entries are the allowed inputs.</param>
    /// <param name="maxSize">Maximum tree size, passed through to TrainGbm.</param>
    private void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
      int nRows = xy.GetLength(0);
      var allowedInputs = allVariables.Skip(1);
      var dataset = new Dataset(allVariables, xy);
      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
      // Train on every row and leave the test partition empty.
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = nRows;
      problemData.TestPartition.Start = nRows;
      problemData.TestPartition.End = nRows;
      // A single iteration with nu = 1, r = 1, m = 1 and a fixed random seed.
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
      var model = (GradientBoostedTreesModel)solution.Model;
      // Take the second model of the ensemble, i.e. the tree fitted in the single iteration.
      // The first entry is presumably the initial constant model; confirm against TrainGbm.
      // The direct cast throws an InvalidCastException naming the actual type, rather than a
      // NullReferenceException at ToString() as the former 'as' cast did.
      var treeM = (RegressionTreeModel)model.Models.Skip(1).First();

      Console.WriteLine(treeM.ToString());
      Console.WriteLine();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.