Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT-trunkintegration/Tests/Test.cs @ 12632

Last change on this file since 12632 was 12632, checked in by gkronber, 9 years ago

#2261 implemented node expansion using a priority queue (and changed parameter MaxDepth to MaxSize). Moved unit tests to a separate project.

File size: 9.5 KB
Line 
1using System;
2using System.Collections;
3using System.Globalization;
4using System.Linq;
5using System.Runtime.CompilerServices;
6using System.Threading;
7using HeuristicLab.Data;
8using HeuristicLab.Optimization;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Random;
11using Microsoft.VisualStudio.TestTools.UnitTesting;
12
13namespace HeuristicLab.Algorithms.DataAnalysis.GradientBoostedTrees {
14  [TestClass()]
15  public class Test {
/// <summary>
/// Builds regression trees for a series of tiny hand-made data sets and prints the resulting
/// rules to the console. The method contains no assertions; the rules given in the comments
/// are the expected output for manual comparison.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "short")]
public void DecisionTreeTest() {
  // every data set below has the same column layout: target y first, then inputs x1 and x2
  var variableNames = new string[] { "y", "x1", "x2" };

  // --- x2 is constant, only x1 carries information ---
  // expected:
  //   x1 <= 15 -> 2
  //   x1 >  15 -> 1
  var constantSecondInput = new double[,] {
    { 1, 20, 0 },
    { 1, 20, 0 },
    { 2, 10, 0 },
    { 2, 10, 0 },
  };
  BuildTree(constantSecondInput, variableNames, 10);

  // --- irrelevant variables must be ignored ---
  // expected:
  //   x1 <= 15 -> 2
  //   x1 >  15 -> 1
  var irrelevantSecondInput = new double[,] {
    { 1, 20, 1 },
    { 1, 20, -1 },
    { 2, 10, -1 },
    { 2, 10, 1 },
  };
  BuildTree(irrelevantSecondInput, variableNames, 10);

  // --- the split must be by x1 first ---
  // expected:
  //   x1 <= 15 AND x2 <= 0 -> 3
  //   x1 <= 15 AND x2 >  0 -> 4
  //   x1 >  15 AND x2 <= 0 -> 2
  //   x1 >  15 AND x2 >  0 -> 1
  var splitByFirstInput = new double[,] {
    { 1, 20, 1 },
    { 2, 20, -1 },
    { 3, 10, -1 },
    { 4, 10, 1 },
  };
  BuildTree(splitByFirstInput, variableNames, 10);

  // --- averaging of the ys that end up in the same leaf ---
  // expected:
  //   x1 <= 15 AND x2 <= 0 -> 3
  //   x1 <= 15 AND x2 >  0 -> 4
  //   x1 >  15 AND x2 <= 0 -> 2
  //   x1 >  15 AND x2 >  0 -> 1
  var averagedTargets = new double[,] {
    { 0.5, 20, 1 },
    { 1.5, 20, 1 },
    { 1.5, 20, -1 },
    { 2.5, 20, -1 },
    { 2.5, 10, -1 },
    { 3.5, 10, -1 },
    { 3.5, 10, 1 },
    { 4.5, 10, 1 },
  };
  BuildTree(averagedTargets, variableNames, 10);

  // --- diagonal split (no split possible) ---
  // expected: a split cannot be found
  //   -> 5.50
  var diagonal = new double[,] {
    { 10, 1, 1 },
    { 1, 1, 2 },
    { 1, 2, 1 },
    { 10, 2, 2 },
  };
  BuildTree(diagonal, variableNames, 3);

  // --- almost diagonal split ---
  var almostDiagonal = new double[,] {
    { 10, 1, 1 },
    { 1, 1, 2 },
    { 1, 2, 1 },
    { 10.1, 2, 2 },
  };
  // expected (two possible solutions):
  //   x2 <= 1.5 -> 5.50
  //   x2 >  1.5 -> 5.55
  BuildTree(almostDiagonal, variableNames, 3);
  // expected with the larger limit:
  //   x1 <= 1.5 AND x2 <= 1.5 -> 10
  //   x1 <= 1.5 AND x2 >  1.5 -> 1
  //   x1 >  1.5 AND x2 <= 1.5 -> 1
  //   x1 >  1.5 AND x2 >  1.5 -> 10.1
  BuildTree(almostDiagonal, variableNames, 7);

  // --- unbalanced split ---
  // expected:
  //   x1 <= 1.5 -> -1.0
  //   x1 >  1.5 AND x2 <= 1.5 -> 0.9
  //   x1 >  1.5 AND x2 >  1.5 -> 1.1
  var unbalancedByFirstInput = new double[,] {
    { -1, 1, 1 },
    { -1, 1, 2 },
    { 0.9, 2, 1 },
    { 1.1, 2, 2 },
  };
  BuildTree(unbalancedByFirstInput, variableNames, 10);

  // --- unbalanced split ---
  // expected (two possible solutions):
  //   x2 <= 1.5 -> -1.0
  //   x2 >  1.5 AND x1 <= 1.5 -> -1.0
  //   x2 >  1.5 AND x1 >  1.5 ->  1.0
  var unbalancedSingleOutlier = new double[,] {
    { -1, 1, 1 },
    { -1, 1, 2 },
    { -1, 2, 1 },
    { 1, 2, 2 },
  };
  BuildTree(unbalancedSingleOutlier, variableNames, 10);
}
168
/// <summary>
/// Runs the gradient boosted trees algorithm on the real-world regression instance whose name
/// contains "Tower", leaving the loss function at its default (squared error, judging by the
/// test name), and compares the reported training and test loss with stored reference values.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "long")]
public void GradientBoostingTestTowerSquaredError() {
  var algorithm = new GradientBoostedTreesAlgorithm();

  // load the single instance whose name contains "Tower" into a fresh regression problem
  var instanceProvider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
  var towerDescriptor = instanceProvider.GetDataDescriptors().Single(d => d.Name.Contains("Tower"));
  var problem = new RegressionProblem();
  var towerData = instanceProvider.LoadData(towerDescriptor);
  problem.Load(towerData);

  // algorithm configuration; the seed is fixed, presumably to keep the reference values reproducible
  algorithm.Problem = problem;
  algorithm.Seed = 0;
  algorithm.SetSeedRandomly = false;
  algorithm.Iterations = 5000;
  algorithm.MaxSize = 20;

  RunAlgorithm(algorithm);

  const double tolerance = 1E-6;
  var trainLoss = (DoubleValue)algorithm.Results["Loss (train)"].Value;
  Assert.AreEqual(267.68704241153921, trainLoss.Value, tolerance);
  var testLoss = (DoubleValue)algorithm.Results["Loss (test)"].Value;
  Assert.AreEqual(393.84704062205469, testLoss.Value, tolerance);
}
192
/// <summary>
/// Runs the gradient boosted trees algorithm on the real-world regression instance whose name
/// contains "Tower", using the first valid loss function whose string representation contains
/// "Absolute", and compares the reported training and test loss with stored reference values.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "long")]
public void GradientBoostingTestTowerAbsoluteError() {
  var algorithm = new GradientBoostedTreesAlgorithm();

  // load the single instance whose name contains "Tower" into a fresh regression problem
  var instanceProvider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
  var towerDescriptor = instanceProvider.GetDataDescriptors().Single(d => d.Name.Contains("Tower"));
  var problem = new RegressionProblem();
  var towerData = instanceProvider.LoadData(towerDescriptor);
  problem.Load(towerData);

  // algorithm configuration; the seed is fixed, presumably to keep the reference values reproducible
  algorithm.Problem = problem;
  algorithm.Seed = 0;
  algorithm.SetSeedRandomly = false;
  algorithm.Iterations = 1000;
  algorithm.MaxSize = 20;
  algorithm.Nu = 0.02;
  var lossFunctionParameter = algorithm.LossFunctionParameter;
  lossFunctionParameter.Value = lossFunctionParameter.ValidValues.First(loss => loss.ToString().Contains("Absolute"));

  RunAlgorithm(algorithm);

  const double tolerance = 1E-6;
  var trainLoss = (DoubleValue)algorithm.Results["Loss (train)"].Value;
  Assert.AreEqual(10.551385044666661, trainLoss.Value, tolerance);
  var testLoss = (DoubleValue)algorithm.Results["Loss (test)"].Value;
  Assert.AreEqual(12.918001745581172, testLoss.Value, tolerance);
}
218
/// <summary>
/// Runs the gradient boosted trees algorithm on the real-world regression instance whose name
/// contains "Tower", using the first valid loss function whose string representation contains
/// "Relative", and compares the reported training and test loss with stored reference values.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "long")]
public void GradientBoostingTestTowerRelativeError() {
  var algorithm = new GradientBoostedTreesAlgorithm();

  // load the single instance whose name contains "Tower" into a fresh regression problem
  var instanceProvider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
  var towerDescriptor = instanceProvider.GetDataDescriptors().Single(d => d.Name.Contains("Tower"));
  var problem = new RegressionProblem();
  var towerData = instanceProvider.LoadData(towerDescriptor);
  problem.Load(towerData);

  // algorithm configuration; the seed is fixed, presumably to keep the reference values reproducible
  algorithm.Problem = problem;
  algorithm.Seed = 0;
  algorithm.SetSeedRandomly = false;
  algorithm.Iterations = 3000;
  algorithm.MaxSize = 20;
  algorithm.Nu = 0.005;
  var lossFunctionParameter = algorithm.LossFunctionParameter;
  lossFunctionParameter.Value = lossFunctionParameter.ValidValues.First(loss => loss.ToString().Contains("Relative"));

  RunAlgorithm(algorithm);

  const double tolerance = 1E-6;
  var trainLoss = (DoubleValue)algorithm.Results["Loss (train)"].Value;
  Assert.AreEqual(0.061954221604374943, trainLoss.Value, tolerance);
  var testLoss = (DoubleValue)algorithm.Results["Loss (test)"].Value;
  Assert.AreEqual(0.06316303473499961, testLoss.Value, tolerance);
}
244
/// <summary>
/// Prepares and starts <paramref name="a"/>, then blocks the calling thread until the algorithm
/// raises either its Stopped or its ExceptionOccurred event. Fails the test when an exception
/// was reported.
/// </summary>
/// <remarks>Same as in SamplesUtil.</remarks>
/// <param name="a">The algorithm to run.</param>
private void RunAlgorithm(IAlgorithm a) {
  // NOTE: the handle is intentionally not disposed. Both handlers call Set(), so a handler that
  // fires after WaitOne() has returned would otherwise hit an already disposed handle.
  var trigger = new EventWaitHandle(false, EventResetMode.ManualReset);
  Exception ex = null;
  a.Stopped += (src, e) => { trigger.Set(); };
  a.ExceptionOccurred += (src, e) => { ex = e.Value; trigger.Set(); };
  a.Prepare();
  a.Start();
  trigger.WaitOne();

  // IsNull (instead of AreEqual with swapped expected/actual) and a message that carries the
  // exception, so that a failing run shows what actually went wrong inside the algorithm
  Assert.IsNull(ex, "The algorithm reported an exception: {0}", ex);
}
257
258    #region helper
/// <summary>
/// Builds a regression tree for the given data, using all rows as training partition and an empty
/// test partition, and prints the rules of the resulting tree to the console followed by a blank line.
/// </summary>
/// <param name="xy">Data matrix; its first dimension is the number of rows. Passed to the data set together with <paramref name="allVariables"/>.</param>
/// <param name="allVariables">Variable names; the first one is used as target, all remaining ones as allowed inputs.</param>
/// <param name="maxDepth">Passed unchanged as first argument to the tree builder.</param>
private void BuildTree(double[,] xy, string[] allVariables, int maxDepth) {
  int nRows = xy.GetLength(0);
  var allowedInputs = allVariables.Skip(1);
  var dataset = new Dataset(allVariables, xy);
  var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
  // training partition = all rows, test partition = empty
  problemData.TrainingPartition.Start = 0;
  problemData.TrainingPartition.End = nRows;
  problemData.TestPartition.Start = nRows;
  problemData.TestPartition.End = nRows;
  var rand = new MersenneTwister(31415);
  var builder = new RegressionTreeBuilder(problemData, rand);
  // original note: "maximal depth and use all rows and cols"
  // NOTE(review): the first argument may be a size limit rather than a depth by now - confirm against the builder
  var model = (GradientBoostedTreesModel)builder.CreateRegressionTree(maxDepth, 1, 1);
  // direct casts instead of 'as': an unexpected model type fails right here with an
  // InvalidCastException naming the type, not later with a NullReferenceException
  var constM = (ConstantRegressionModel)model.Models.First();
  var treeM = (RegressionTreeModel)model.Models.Skip(1).First();
  WriteTree(treeM.tree, 0, "", constM.Constant);
  Console.WriteLine();
}
276
/// <summary>
/// Recursively writes the subtree rooted at <paramref name="idx"/> to the console, one line per
/// leaf in the form "rule -> value". All numbers are formatted with the invariant culture, so
/// leaf values and split thresholds use the same decimal separator on every machine.
/// </summary>
/// <param name="tree">Array holding the tree nodes; child nodes are referenced by index.</param>
/// <param name="idx">Index of the node to write.</param>
/// <param name="partialRule">Conditions collected on the path from the root to this node ("" for the root).</param>
/// <param name="offset">Value that is added to every leaf value before it is printed.</param>
private void WriteTree(RegressionTreeModel.TreeNode[] tree, int idx, string partialRule, double offset) {
  var n = tree[idx];

  // leaf: print the collected rule together with the shifted leaf value
  if (n.VarName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
    Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "{0} -> {1:F}", partialRule, n.Val + offset));
    return;
  }

  // inner node: both branches extend the same rule prefix and differ only in the comparison
  var prefix = partialRule + (string.IsNullOrEmpty(partialRule) ? "" : " and ");
  WriteTree(tree, n.LeftIdx,
    string.Format(CultureInfo.InvariantCulture, "{0}{1} <= {2:F}", prefix, n.VarName, n.Val), offset);
  WriteTree(tree, n.RightIdx,
    string.Format(CultureInfo.InvariantCulture, "{0}{1} >  {2:F}", prefix, n.VarName, n.Val), offset);
}
296    #endregion
297  }
298}
Note: See TracBrowser for help on using the repository browser.