Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/Test.cs @ 12620

Last change on this file since 12620 was 12620, checked in by gkronber, 9 years ago

#2261: Corrected the check for whether a split is useful, added a unit test class, and added a detailed comment on the split quality calculation.

File size: 5.2 KB
Line 
1using System;
2using System.Collections;
3using System.Globalization;
4using System.Linq;
5using System.Runtime.CompilerServices;
6using HeuristicLab.Problems.DataAnalysis;
7using HeuristicLab.Random;
8using Microsoft.VisualStudio.TestTools.UnitTesting;
9
10namespace HeuristicLab.Algorithms.DataAnalysis.GradientBoostedTrees {
11  [TestClass()]
12  public class Test {
13    [TestMethod]
14    [TestCategory("Algorithms.DataAnalysis")]
15    [TestProperty("Time", "short")]
16    public void DecisionTreeTest() {
17      {
18        var xy = new double[,]
19        {
20          {1, 20, 0},
21          {1, 20, 0},
22          {2, 10, 0},
23          {2, 10, 0},
24        };
25        var allVariables = new string[] { "y", "x1", "x2" };
26
27        // x1 <= 15 -> 1
28        // x1 >  15 -> 2
29        BuildTree(xy, allVariables, 10);
30      }
31
32
33      {
34        var xy = new double[,]
35        {
36          {1, 20,  1},
37          {1, 20, -1},
38          {2, 10, -1},
39          {2, 10, 1},
40        };
41        var allVariables = new string[] { "y", "x1", "x2" };
42
43        // ignore irrelevant variables
44        // x1 <= 15 -> 1
45        // x1 >  15 -> 2
46        BuildTree(xy, allVariables, 10);
47      }
48
49      {
50        // split must be by x1 first
51        var xy = new double[,]
52        {
53          {1, 20,  1},
54          {2, 20, -1},
55          {3, 10, -1},
56          {4, 10, 1},
57        };
58
59        var allVariables = new string[] { "y", "x1", "x2" };
60
61        // x1 <= 15 AND x2 <= 0 -> 3
62        // x1 <= 15 AND x2 >  0 -> 4
63        // x1 >  15 AND x2 <= 0 -> 1
64        // x1 >  15 AND x2 >  0 -> 2
65        BuildTree(xy, allVariables, 10);
66      }
67
68      {
69        // averaging ys
70        var xy = new double[,]
71        {
72          {0.5, 20,  1},
73          {1.5, 20,  1},
74          {1.5, 20, -1},
75          {2.5, 20, -1},
76          {2.5, 10, -1},
77          {3.5, 10, -1},
78          {3.5, 10, 1},
79          {4.5, 10, 1},
80        };
81
82        var allVariables = new string[] { "y", "x1", "x2" };
83
84        // x1 <= 15 AND x2 <= 0 -> 3
85        // x1 <= 15 AND x2 >  0 -> 4
86        // x1 >  15 AND x2 <= 0 -> 1
87        // x1 >  15 AND x2 >  0 -> 2
88        BuildTree(xy, allVariables, 10);
89      }
90
91
92      {
93        // diagonal split (no split possible)
94        var xy = new double[,]
95        {
96          {10, 1, 1},
97          {1, 1, 2},
98          {1, 2, 1},
99          {10, 2, 2},
100        };
101
102        var allVariables = new string[] { "y", "x1", "x2" };
103
104        // split cannot be found
105        BuildTree(xy, allVariables, 3);
106      }
107      {
108        // almost diagonal split
109        var xy = new double[,]
110        {
111          {10, 1, 1},
112          {1, 1, 2},
113          {1, 2, 1},
114          {10.1, 2, 2},
115        };
116
117        var allVariables = new string[] { "y", "x1", "x2" };
118        // x1 <= 1.5 AND x2 <= 1.5 -> 10
119        // x1 <= 1.5 AND x2 >  1.5 -> 1
120        // x1 >  1.5 AND x2 <= 1.5 -> 1
121        // x1 >  1.5 AND x2 >  1.5 -> 10.1
122        BuildTree(xy, allVariables, 3);
123      }
124      {
125        // unbalanced split
126        var xy = new double[,]
127        {
128          {-1, 1, 1},
129          {-1, 1, 2},
130          {0.9, 2, 1},
131          {1.1, 2, 2},
132        };
133
134        var allVariables = new string[] { "y", "x1", "x2" };
135        // x1 <= 1.5 -> -1.0
136        // x1 >  1.5 AND x2 <= 1.5 -> 0.9
137        // x1 >  1.5 AND x2 >  1.5 -> 1.1
138        BuildTree(xy, allVariables, 3);
139      }
140
141    }
142
143    private void BuildTree(double[,] xy, string[] allVariables, int maxDepth) {
144      int nRows = xy.GetLength(0);
145      var allowedInputs = allVariables.Skip(1);
146      var dataset = new Dataset(allVariables, xy);
147      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
148      problemData.TrainingPartition.Start = 0;
149      problemData.TrainingPartition.End = nRows;
150      problemData.TestPartition.Start = nRows;
151      problemData.TestPartition.End = nRows;
152      var rand = new MersenneTwister(31415);
153      var builder = new RegressionTreeBuilder(problemData, rand);
154      var model = (GradientBoostedTreesModel)builder.CreateRegressionTree(maxDepth, 1, 1); // maximal depth and use all rows and cols
155      var constM = model.Models.First() as ConstantRegressionModel;
156      var treeM = model.Models.Skip(1).First() as RegressionTreeModel;
157      WriteTree(treeM.tree, 0, "", constM.Constant);
158      Console.WriteLine();
159    }
160
161    private void WriteTree(RegressionTreeModel.TreeNode[] tree, int idx, string partialRule, double offset) {
162      var n = tree[idx];
163      if (n.varName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
164        Console.WriteLine("{0} -> {1:F}", partialRule, n.val + offset);
165      } else {
166        WriteTree(tree, n.leftIdx,
167          string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}",
168          partialRule,
169          string.IsNullOrEmpty(partialRule) ? "" : " and ",
170          n.varName,
171          n.val), offset);
172        WriteTree(tree, n.rightIdx,
173          string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} >  {3:F}",
174          partialRule,
175          string.IsNullOrEmpty(partialRule) ? "" : " and ",
176          n.varName,
177          n.val), offset);
178      }
179    }
180  }
181}
Note: See TracBrowser for help on using the repository browser.