[12620] | 1 | using System;
|
---|
| 2 | using System.Linq;
|
---|
[12632] | 3 | using System.Threading;
|
---|
| 4 | using HeuristicLab.Data;
|
---|
| 5 | using HeuristicLab.Optimization;
|
---|
[12620] | 6 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 7 | using Microsoft.VisualStudio.TestTools.UnitTesting;
|
---|
| 8 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis {
  [TestClass()]
  public class GradientBoostingTest {
    /// <summary>
    /// Builds single regression trees for a set of small hand-crafted data sets and prints them.
    /// The tree expected for each data set is written in the comments next to the data.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this test contains no assertions. It only fails if tree construction throws;
    /// the printed trees have to be compared with the comments manually.
    /// </remarks>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void DecisionTreeTest() {
      // Column 0 of every data matrix is the target y, the remaining columns are the inputs.
      var allVariables = new string[] { "y", "x1", "x2" };

      {
        var xy = new double[,]
        {
          {-1, 20, 0},
          {-1, 20, 0},
          { 1, 10, 0},
          { 1, 10, 0},
        };

        // x1 <= 15 -> 1
        // x1 > 15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        var xy = new double[,]
        {
          {-1, 20, 1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 1, 10, 1},
        };

        // ignore irrelevant variables
        // x1 <= 15 -> 1
        // x1 > 15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        // split must be by x1 first
        var xy = new double[,]
        {
          {-2, 20, 1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 2, 10, 1},
        };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 > 0 -> 2
        // x1 > 15 AND x2 <= 0 -> -1
        // x1 > 15 AND x2 > 0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // averaging ys
        var xy = new double[,]
        {
          {-2.5, 20, 1},
          {-1.5, 20, 1},
          {-1.5, 20, -1},
          {-0.5, 20, -1},
          {0.5, 10, -1},
          {1.5, 10, -1},
          {1.5, 10, 1},
          {2.5, 10, 1},
        };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 > 0 -> 2
        // x1 > 15 AND x2 <= 0 -> -1
        // x1 > 15 AND x2 > 0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // diagonal split (no split possible)
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1, 2, 2},
        };

        // split cannot be found
        // -> 0.0
        BuildTree(xy, allVariables, 3);
      }

      {
        // almost diagonal split
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1.0001, 2, 2},
        };

        // (two possible solutions)
        // x2 <= 1.5 -> 0
        // x2 > 1.5 -> 0 (not quite)
        BuildTree(xy, allVariables, 3);

        // x1 <= 1.5 AND x2 <= 1.5 -> 1
        // x1 <= 1.5 AND x2 > 1.5 -> -1
        // x1 > 1.5 AND x2 <= 1.5 -> -1
        // x1 > 1.5 AND x2 > 1.5 -> 1 (not quite)
        BuildTree(xy, allVariables, 7);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {0.9, 2, 1},
          {1.1, 2, 2},
        };

        // x1 <= 1.5 -> -1.0
        // x1 > 1.5 AND x2 <= 1.5 -> 0.9
        // x1 > 1.5 AND x2 > 1.5 -> 1.1
        BuildTree(xy, allVariables, 10);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 3, 2, 2},
        };

        // (two possible solutions)
        // x2 <= 1.5 -> -1.0
        // x2 > 1.5 AND x1 <= 1.5 -> -1.0
        // x2 > 1.5 AND x1 > 1.5 -> 3.0
        BuildTree(xy, allVariables, 10);
      }
    }

    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerSquaredError() {
      // keeps the algorithm's default loss function and default Nu
      RunTowerTest(iterations: 5000, nu: null, lossFunctionName: null,
        expectedTrainLoss: 267.68704241153921, expectedTestLoss: 393.84704062205469);
    }

    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerAbsoluteError() {
      RunTowerTest(iterations: 1000, nu: 0.02, lossFunctionName: "Absolute",
        expectedTrainLoss: 10.551385044666661, expectedTestLoss: 12.918001745581172);
    }

    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerRelativeError() {
      RunTowerTest(iterations: 3000, nu: 0.005, lossFunctionName: "Relative",
        expectedTrainLoss: 0.061954221604374943, expectedTestLoss: 0.06316303473499961);
    }

    /// <summary>
    /// Runs gradient boosted trees (seed 0, MaxSize 20, no solution creation) on the "Tower"
    /// real-world regression instance and compares the final losses with reference values.
    /// </summary>
    /// <param name="iterations">Value assigned to <c>Iterations</c>.</param>
    /// <param name="nu">Value assigned to <c>Nu</c>; <c>null</c> keeps the algorithm's default.</param>
    /// <param name="lossFunctionName">
    /// Substring of the loss function's <c>ToString()</c> used to select it from the valid values;
    /// <c>null</c> keeps the default loss function.
    /// </param>
    /// <param name="expectedTrainLoss">Reference value of the "Loss (train)" result (tolerance 1E-6).</param>
    /// <param name="expectedTestLoss">Reference value of the "Loss (test)" result (tolerance 1E-6).</param>
    private void RunTowerTest(int iterations, double? nu, string lossFunctionName, double expectedTrainLoss, double expectedTestLoss) {
      var gbt = new GradientBoostedTreesAlgorithm();
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
      var regProblem = new RegressionProblem();
      regProblem.Load(provider.LoadData(instance));

      #region Algorithm Configuration
      // Assignment order matches the original per-test configuration.
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = iterations;
      gbt.MaxSize = 20;
      if (nu.HasValue) gbt.Nu = nu.Value;
      if (lossFunctionName != null)
        gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains(lossFunctionName));
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(expectedTrainLoss, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(expectedTestLoss, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Prepares and starts <paramref name="a"/>, blocks until it raises Stopped or ExceptionOccurred,
    /// and fails the calling test if ExceptionOccurred was raised.
    /// </summary>
    /// <remarks>Based on the helper in SamplesUtil.</remarks>
    private void RunAlgorithm(IAlgorithm a) {
      // Intentionally not disposed: the handlers below are never unsubscribed and may still call
      // Set() after WaitOne() has returned (e.g. if Stopped is also raised after ExceptionOccurred).
      var trigger = new EventWaitHandle(false, EventResetMode.ManualReset);
      Exception ex = null;
      a.Stopped += (src, e) => { trigger.Set(); };
      a.ExceptionOccurred += (src, e) => { ex = e.Value; trigger.Set(); };
      a.Prepare();
      a.Start();
      trigger.WaitOne();

      // IsNull (instead of AreEqual with swapped expected/actual) so a failure clearly reports the
      // exception, including its stack trace.
      Assert.IsNull(ex, "Algorithm raised an exception: {0}", ex);
    }

    #region helper
    /// <summary>
    /// Trains a gradient boosting model with a single iteration (nu = 1, r = 1, m = 1) on all rows of
    /// <paramref name="xy"/> using squared error loss and prints the resulting regression tree.
    /// </summary>
    /// <param name="xy">Data matrix; its columns are named by <paramref name="allVariables"/>.</param>
    /// <param name="allVariables">Variable names; the first is the target, the others are the allowed inputs.</param>
    /// <param name="maxSize">Maximum tree size passed to TrainGbm.</param>
    private void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
      int nRows = xy.GetLength(0);
      var allowedInputs = allVariables.Skip(1);
      var dataset = new Dataset(allVariables, xy);
      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
      // train on all rows, test partition is empty
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = nRows;
      problemData.TestPartition.Start = nRows;
      problemData.TestPartition.End = nRows;
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
      var model = solution.Model;
      // Skips the first model (presumably the initial constant model — verify) and takes the tree
      // built in the single iteration. A direct cast fails with an explicit InvalidCastException
      // instead of a NullReferenceException further down.
      var treeM = (RegressionTreeModel)model.Models.Skip(1).First();

      Console.WriteLine(treeM.ToString());
      Console.WriteLine();
    }
    #endregion
  }
}
|
---|