1 | using System;
2 | using System.Collections;
3 | using System.Globalization;
4 | using System.Linq;
5 | using System.Runtime.CompilerServices;
6 | using HeuristicLab.Problems.DataAnalysis;
7 | using HeuristicLab.Random;
8 | using Microsoft.VisualStudio.TestTools.UnitTesting;
9 |
10 | namespace HeuristicLab.Algorithms.DataAnalysis.GradientBoostedTrees {
/// <summary>
/// Smoke tests for single regression trees created by <see cref="RegressionTreeBuilder"/>.
/// The learned trees are written to the console as rules; the rules that follow from the
/// data are given as comments next to each data set and have to be compared manually.
/// </summary>
[TestClass()]
public class Test {
  /// <summary>
  /// Builds one regression tree for each of several tiny hand-made data sets and prints it.
  /// </summary>
  /// <remarks>
  /// In every data matrix the first column is the target "y" and the remaining columns are
  /// the inputs "x1" and "x2". The expected rules in the comments are read off the data
  /// (thresholds assume a split between the two distinct input values); they are not asserted.
  /// </remarks>
  [TestMethod]
  [TestCategory("Algorithms.DataAnalysis")]
  [TestProperty("Time", "short")]
  public void DecisionTreeTest() {
    {
      // x2 is constant, y depends on x1 only
      var xy = new double[,]
      {
        {1, 20, 0},
        {1, 20, 0},
        {2, 10, 0},
        {2, 10, 0},
      };
      var allVariables = new string[] { "y", "x1", "x2" };

      // x1 <= 15 -> 2
      // x1 >  15 -> 1
      BuildTree(xy, allVariables, 10);
    }


    {
      // x2 varies but carries no information about y
      var xy = new double[,]
      {
        {1, 20, 1},
        {1, 20, -1},
        {2, 10, -1},
        {2, 10, 1},
      };
      var allVariables = new string[] { "y", "x1", "x2" };

      // ignore irrelevant variables
      // x1 <= 15 -> 2
      // x1 >  15 -> 1
      BuildTree(xy, allVariables, 10);
    }

    {
      // split must be by x1 first
      var xy = new double[,]
      {
        {1, 20, 1},
        {2, 20, -1},
        {3, 10, -1},
        {4, 10, 1},
      };

      var allVariables = new string[] { "y", "x1", "x2" };

      // x1 <= 15 AND x2 <= 0 -> 3
      // x1 <= 15 AND x2 >  0 -> 4
      // x1 >  15 AND x2 <= 0 -> 2
      // x1 >  15 AND x2 >  0 -> 1
      BuildTree(xy, allVariables, 10);
    }

    {
      // averaging ys: every (x1, x2) combination occurs twice with different targets
      var xy = new double[,]
      {
        {0.5, 20, 1},
        {1.5, 20, 1},
        {1.5, 20, -1},
        {2.5, 20, -1},
        {2.5, 10, -1},
        {3.5, 10, -1},
        {3.5, 10, 1},
        {4.5, 10, 1},
      };

      var allVariables = new string[] { "y", "x1", "x2" };

      // x1 <= 15 AND x2 <= 0 -> 3   (mean of 2.5 and 3.5)
      // x1 <= 15 AND x2 >  0 -> 4   (mean of 3.5 and 4.5)
      // x1 >  15 AND x2 <= 0 -> 2   (mean of 1.5 and 2.5)
      // x1 >  15 AND x2 >  0 -> 1   (mean of 0.5 and 1.5)
      BuildTree(xy, allVariables, 10);
    }


    {
      // diagonal split (no split possible)
      var xy = new double[,]
      {
        {10, 1, 1},
        {1, 1, 2},
        {1, 2, 1},
        {10, 2, 2},
      };

      var allVariables = new string[] { "y", "x1", "x2" };

      // split cannot be found
      BuildTree(xy, allVariables, 3);
    }
    {
      // almost diagonal split
      var xy = new double[,]
      {
        {10, 1, 1},
        {1, 1, 2},
        {1, 2, 1},
        {10.1, 2, 2},
      };

      var allVariables = new string[] { "y", "x1", "x2" };
      // x1 <= 1.5 AND x2 <= 1.5 -> 10
      // x1 <= 1.5 AND x2 >  1.5 -> 1
      // x1 >  1.5 AND x2 <= 1.5 -> 1
      // x1 >  1.5 AND x2 >  1.5 -> 10.1
      BuildTree(xy, allVariables, 3);
    }
    {
      // unbalanced split
      var xy = new double[,]
      {
        {-1, 1, 1},
        {-1, 1, 2},
        {0.9, 2, 1},
        {1.1, 2, 2},
      };

      var allVariables = new string[] { "y", "x1", "x2" };
      // x1 <= 1.5 -> -1.0
      // x1 >  1.5 AND x2 <= 1.5 -> 0.9
      // x1 >  1.5 AND x2 >  1.5 -> 1.1
      BuildTree(xy, allVariables, 3);
    }

  }

  /// <summary>
  /// Creates a regression problem from <paramref name="xy"/>, builds a single regression tree
  /// for it and writes the tree to the console followed by an empty line.
  /// </summary>
  /// <param name="xy">Data matrix with one row per observation and one column per variable.</param>
  /// <param name="allVariables">
  /// Variable names for the columns of <paramref name="xy"/>; the first name is used as the
  /// target variable, all remaining names as allowed inputs.
  /// </param>
  /// <param name="maxDepth">Maximal depth passed to the tree builder.</param>
  private void BuildTree(double[,] xy, string[] allVariables, int maxDepth) {
    int nRows = xy.GetLength(0);
    var allowedInputs = allVariables.Skip(1);
    var dataset = new Dataset(allVariables, xy);
    var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());

    // use every row for training and leave the test partition empty
    problemData.TrainingPartition.Start = 0;
    problemData.TrainingPartition.End = nRows;
    problemData.TestPartition.Start = nRows;
    problemData.TestPartition.End = nRows;

    // fixed seed so that repeated runs use the same random number sequence
    var rand = new MersenneTwister(31415);
    var builder = new RegressionTreeBuilder(problemData, rand);
    var model = (GradientBoostedTreesModel)builder.CreateRegressionTree(maxDepth, 1, 1); // maximal depth and use all rows and cols

    // The result is expected to hold a constant model followed by the tree model.
    // Fail with a clear message instead of a NullReferenceException if that is not the case.
    var models = model.Models.Take(2).ToArray();
    Assert.AreEqual(2, models.Length, "Expected a constant model followed by a regression tree model.");
    var constM = models[0] as ConstantRegressionModel;
    var treeM = models[1] as RegressionTreeModel;
    Assert.IsNotNull(constM, "The first model is not a ConstantRegressionModel.");
    Assert.IsNotNull(treeM, "The second model is not a RegressionTreeModel.");

    WriteTree(treeM.tree, 0, "", constM.Constant);
    Console.WriteLine();
  }

  /// <summary>
  /// Recursively writes the sub-tree rooted at <paramref name="idx"/> to the console,
  /// one line per leaf in the form "&lt;conditions&gt; -&gt; &lt;value&gt;".
  /// </summary>
  /// <param name="tree">Array of tree nodes; child nodes are referenced by array index.</param>
  /// <param name="idx">Index of the node to print.</param>
  /// <param name="partialRule">Conditions collected on the path from the root to this node.</param>
  /// <param name="offset">Constant that is added to every printed leaf value.</param>
  private void WriteTree(RegressionTreeModel.TreeNode[] tree, int idx, string partialRule, double offset) {
    var n = tree[idx];
    if (n.varName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
      // leaf: format with the invariant culture, exactly like the conditions below,
      // so that the printed rule does not mix decimal separators on non-English systems
      Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "{0} -> {1:F}", partialRule, n.val + offset));
    } else {
      // inner node: extend the rule with the condition for the left and for the right branch
      var prefix = string.IsNullOrEmpty(partialRule) ? "" : partialRule + " and ";
      WriteTree(tree, n.leftIdx,
        string.Format(CultureInfo.InvariantCulture, "{0}{1} <= {2:F}", prefix, n.varName, n.val),
        offset);
      WriteTree(tree, n.rightIdx,
        string.Format(CultureInfo.InvariantCulture, "{0}{1} > {2:F}", prefix, n.varName, n.val),
        offset);
    }
  }
}
181 | }