[12620] | 1 | using System;
|
---|
[15973] | 2 | using System.Collections;
|
---|
| 3 | using System.IO;
|
---|
[12620] | 4 | using System.Linq;
|
---|
[12632] | 5 | using HeuristicLab.Data;
|
---|
[12620] | 6 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 7 | using Microsoft.VisualStudio.TestTools.UnitTesting;
|
---|
| 8 |
|
---|
[12658] | 9 | namespace HeuristicLab.Algorithms.DataAnalysis {
|
---|
[12620] | 10 | [TestClass()]
|
---|
[12710] | 11 | public class GradientBoostingTest {
|
---|
/// <summary>
/// Smoke test that builds a tree for several tiny hand-made datasets via BuildTree.
/// The splits one would expect for each dataset are listed as comments next to it;
/// the method itself does not assert anything.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "short")]
public void DecisionTreeTest() {
  // every matrix below stores the target y in column 0, followed by the inputs x1 and x2
  var variableNames = new string[] { "y", "x1", "x2" };

  // x2 is constant, only x1 separates the targets:
  // x1 <= 15 -> 1
  // x1 > 15 -> -1
  var constantX2 = new double[,] {
    { -1, 20, 0 },
    { -1, 20, 0 },
    { 1, 10, 0 },
    { 1, 10, 0 },
  };
  BuildTree(constantX2, variableNames, 10);

  // ignore irrelevant variables (x2 varies but y depends on x1 only):
  // x1 <= 15 -> 1
  // x1 > 15 -> -1
  var irrelevantX2 = new double[,] {
    { -1, 20, 1 },
    { -1, 20, -1 },
    { 1, 10, -1 },
    { 1, 10, 1 },
  };
  BuildTree(irrelevantX2, variableNames, 10);

  // split must be by x1 first:
  // x1 <= 15 AND x2 <= 0 -> 1
  // x1 <= 15 AND x2 > 0 -> 2
  // x1 > 15 AND x2 <= 0 -> -1
  // x1 > 15 AND x2 > 0 -> -2
  var twoLevelSplit = new double[,] {
    { -2, 20, 1 },
    { -1, 20, -1 },
    { 1, 10, -1 },
    { 2, 10, 1 },
  };
  BuildTree(twoLevelSplit, variableNames, 10);

  // averaging ys (two observations per combination of x1 and x2):
  // x1 <= 15 AND x2 <= 0 -> 1
  // x1 <= 15 AND x2 > 0 -> 2
  // x1 > 15 AND x2 <= 0 -> -1
  // x1 > 15 AND x2 > 0 -> -2
  var averagedTargets = new double[,] {
    { -2.5, 20, 1 },
    { -1.5, 20, 1 },
    { -1.5, 20, -1 },
    { -0.5, 20, -1 },
    { 0.5, 10, -1 },
    { 1.5, 10, -1 },
    { 1.5, 10, 1 },
    { 2.5, 10, 1 },
  };
  BuildTree(averagedTargets, variableNames, 10);

  // diagonal split (no split possible); split cannot be found:
  // -> 0.0
  var diagonal = new double[,] {
    { 1, 1, 1 },
    { -1, 1, 2 },
    { -1, 2, 1 },
    { 1, 2, 2 },
  };
  BuildTree(diagonal, variableNames, 3);

  // almost diagonal split
  var almostDiagonal = new double[,] {
    { 1, 1, 1 },
    { -1, 1, 2 },
    { -1, 2, 1 },
    { 1.0001, 2, 2 },
  };
  // maximum size 3 (two possible solutions):
  // x2 <= 1.5 -> 0
  // x2 > 1.5 -> 0 (not quite)
  BuildTree(almostDiagonal, variableNames, 3);
  // maximum size 7:
  // x1 <= 1.5 AND x2 <= 1.5 -> 1
  // x1 <= 1.5 AND x2 > 1.5 -> -1
  // x1 > 1.5 AND x2 <= 1.5 -> -1
  // x1 > 1.5 AND x2 > 1.5 -> 1 (not quite)
  BuildTree(almostDiagonal, variableNames, 7);

  // unbalanced split:
  // x1 <= 1.5 -> -1.0
  // x1 > 1.5 AND x2 <= 1.5 -> 0.9
  // x1 > 1.5 AND x2 > 1.5 -> 1.1
  var unbalancedByX1 = new double[,] {
    { -1, 1, 1 },
    { -1, 1, 2 },
    { 0.9, 2, 1 },
    { 1.1, 2, 2 },
  };
  BuildTree(unbalancedByX1, variableNames, 10);

  // unbalanced split (two possible solutions):
  // x2 <= 1.5 -> -1.0
  // x2 > 1.5 AND x1 <= 1.5 -> -1.0
  // x2 > 1.5 AND x1 > 1.5 -> 3.0
  var unbalancedCorner = new double[,] {
    { -1, 1, 1 },
    { -1, 1, 2 },
    { -1, 2, 1 },
    { 3, 2, 2 },
  };
  BuildTree(unbalancedCorner, variableNames, 10);
}
|
---|
| 164 |
|
---|
/// <summary>
/// Performs 1000 boosting steps on the real-world regression instance whose name contains
/// "Tower", selects the variable with the highest relevance and prints the model output for
/// every tenth value of that variable (values sorted ascending). Everything is written to
/// the console; nothing is asserted.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "short")]
public void TestDecisionTreePartialDependence() {
  var instanceProvider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
  var towerDescriptor = instanceProvider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
  var problem = new RegressionProblem();
  problem.Load(instanceProvider.LoadData(towerDescriptor));
  var data = problem.ProblemData;

  var gbmState = GradientBoostedTreesAlgorithmStatic.CreateGbmState(data, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02);
  for (int step = 0; step < 1000; step++) {
    GradientBoostedTreesAlgorithmStatic.MakeStep(gbmState);
  }

  // report the variable with the largest relevance value
  var topVariable = gbmState.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First();
  Console.WriteLine("var: {0} relevance: {1}", topVariable.Key, topVariable.Value);

  // print the second entry of the model list
  var gbtModel = (IGradientBoostedTreesModel)gbmState.GetModel();
  var secondModel = gbtModel.Models.Skip(1).First();
  Console.WriteLine(secondModel.ToString());
  Console.WriteLine();

  // dataset that contains nothing but the sorted values of the selected variable
  var sortedValues = data.Dataset.GetDoubleValues(topVariable.Key).OrderBy(x => x).ToArray();
  var columnNames = new string[] { topVariable.Key };
  var columnValues = new IList[] { sortedValues.ToList<double>() };
  var singleVariableDataset = new ModifiableDataset(columnNames, columnValues);

  var allRows = Enumerable.Range(0, sortedValues.Length);
  var predictions = gbtModel.GetEstimatedValues(singleVariableDataset, allRows).ToArray();

  // only every tenth row is printed
  for (int row = 0; row < sortedValues.Length; row += 10) {
    Console.WriteLine("{0,-5:N3} {1,-5:N3}", sortedValues[row], predictions[row]);
  }
}
|
---|
| 196 |
|
---|
/// <summary>
/// Round-trips a model through the XML persistence: performs one boosting step on the
/// real-world regression instance whose name contains "Tower", serializes the second entry
/// of the model list into memory, deserializes it again and requires the string
/// representations of the original and the restored object to be equal.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "short")]
public void TestDecisionTreePersistence() {
  var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
  var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
  var regProblem = new RegressionProblem();
  regProblem.Load(provider.LoadData(instance));
  var problemData = regProblem.ProblemData;
  var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1);
  GradientBoostedTreesAlgorithmStatic.MakeStep(state);

  var model = ((IGradientBoostedTreesModel)state.GetModel());
  var treeM = model.Models.Skip(1).First();
  var origStr = treeM.ToString();
  using (var memStream = new MemoryStream()) {
    Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream);
    // ToArray copies exactly the bytes that were written. GetBuffer would return the whole
    // internal buffer, i.e. the serialized data followed by unused capacity, and the restore
    // stream would then contain trailing bytes that were never part of the serialization.
    var buf = memStream.ToArray();
    using (var restoreStream = new MemoryStream(buf)) {
      var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream);
      var restoredStr = restoredTree.ToString();
      Assert.AreEqual(origStr, restoredStr);
    }
  }
}
|
---|
| 222 |
|
---|
/// <summary>
/// Regression test on the real-world instance whose name contains "Tower": runs
/// GradientBoostedTreesAlgorithm with seed 0, 5000 iterations and maximum size 20 (no
/// solution is created) and compares the reported training and test loss with hard-coded
/// reference values using a tolerance of 1E-6.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "long")]
public void GradientBoostingTestTowerSquaredError() {
  var algorithm = new GradientBoostedTreesAlgorithm();
  var instanceProvider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
  var towerDescriptor = instanceProvider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
  var problem = new RegressionProblem();
  problem.Load(instanceProvider.LoadData(towerDescriptor));

  #region Algorithm Configuration
  algorithm.Problem = problem;
  algorithm.Seed = 0;
  algorithm.SetSeedRandomly = false;
  algorithm.Iterations = 5000;
  algorithm.MaxSize = 20;
  algorithm.CreateSolution = false;
  #endregion

  algorithm.Start();

  Console.WriteLine(algorithm.ExecutionTime);

  // the training loss is checked first, the test loss is only read afterwards
  var trainLoss = (DoubleValue)algorithm.Results["Loss (train)"].Value;
  Assert.AreEqual(267.68704241153921, trainLoss.Value, 1E-6);
  var testLoss = (DoubleValue)algorithm.Results["Loss (test)"].Value;
  Assert.AreEqual(393.84704062205469, testLoss.Value, 1E-6);
}
|
---|
| 248 |
|
---|
/// <summary>
/// Regression test on the real-world instance whose name contains "Tower": runs
/// GradientBoostedTreesAlgorithm with seed 0, 1000 iterations, maximum size 20, Nu 0.02 and
/// the first valid loss function whose string representation contains "Absolute" (no
/// solution is created). The reported training and test loss are compared with hard-coded
/// reference values using a tolerance of 1E-6.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "long")]
public void GradientBoostingTestTowerAbsoluteError() {
  var algorithm = new GradientBoostedTreesAlgorithm();
  var instanceProvider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
  var towerDescriptor = instanceProvider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
  var problem = new RegressionProblem();
  problem.Load(instanceProvider.LoadData(towerDescriptor));

  #region Algorithm Configuration
  algorithm.Problem = problem;
  algorithm.Seed = 0;
  algorithm.SetSeedRandomly = false;
  algorithm.Iterations = 1000;
  algorithm.MaxSize = 20;
  algorithm.Nu = 0.02;
  // select the loss function by name from the list of valid values
  var absoluteLoss = algorithm.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute"));
  algorithm.LossFunctionParameter.Value = absoluteLoss;
  algorithm.CreateSolution = false;
  #endregion

  algorithm.Start();

  Console.WriteLine(algorithm.ExecutionTime);

  // the training loss is checked first, the test loss is only read afterwards
  var trainLoss = (DoubleValue)algorithm.Results["Loss (train)"].Value;
  Assert.AreEqual(10.551385044666661, trainLoss.Value, 1E-6);
  var testLoss = (DoubleValue)algorithm.Results["Loss (test)"].Value;
  Assert.AreEqual(12.918001745581172, testLoss.Value, 1E-6);
}
|
---|
| 276 |
|
---|
/// <summary>
/// Regression test on the real-world instance whose name contains "Tower": runs
/// GradientBoostedTreesAlgorithm with seed 0, 3000 iterations, maximum size 20, Nu 0.005 and
/// the first valid loss function whose string representation contains "Relative" (no
/// solution is created). The reported training and test loss are compared with hard-coded
/// reference values using a tolerance of 1E-6.
/// </summary>
[TestMethod]
[TestCategory("Algorithms.DataAnalysis")]
[TestProperty("Time", "long")]
public void GradientBoostingTestTowerRelativeError() {
  var algorithm = new GradientBoostedTreesAlgorithm();
  var instanceProvider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
  var towerDescriptor = instanceProvider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
  var problem = new RegressionProblem();
  problem.Load(instanceProvider.LoadData(towerDescriptor));

  #region Algorithm Configuration
  algorithm.Problem = problem;
  algorithm.Seed = 0;
  algorithm.SetSeedRandomly = false;
  algorithm.Iterations = 3000;
  algorithm.MaxSize = 20;
  algorithm.Nu = 0.005;
  // select the loss function by name from the list of valid values
  var relativeLoss = algorithm.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative"));
  algorithm.LossFunctionParameter.Value = relativeLoss;
  algorithm.CreateSolution = false;
  #endregion

  algorithm.Start();

  Console.WriteLine(algorithm.ExecutionTime);

  // the training loss is checked first, the test loss is only read afterwards
  var trainLoss = (DoubleValue)algorithm.Results["Loss (train)"].Value;
  Assert.AreEqual(0.061954221604374943, trainLoss.Value, 1E-6);
  var testLoss = (DoubleValue)algorithm.Results["Loss (test)"].Value;
  Assert.AreEqual(0.06316303473499961, testLoss.Value, 1E-6);
}
|
---|
| 304 |
|
---|
| 305 | #region helper
|
---|
/// <summary>
/// Trains a gradient boosted trees model with a single iteration on the given data and
/// prints the second entry of the resulting model list to the console. All rows are used
/// for training; the test partition is empty.
/// </summary>
/// <param name="xy">Data matrix with one row per observation; its columns are named by <paramref name="allVariables"/>.</param>
/// <param name="allVariables">Column names; the first name is used as target variable, the remaining names as allowed inputs.</param>
/// <param name="maxSize">Value passed on to TrainGbm as maximum size.</param>
private void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
  int nRows = xy.GetLength(0);
  var allowedInputs = allVariables.Skip(1);
  var dataset = new Dataset(allVariables, xy);
  var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
  // training partition covers every row, the test partition [nRows, nRows) is empty
  problemData.TrainingPartition.Start = 0;
  problemData.TrainingPartition.End = nRows;
  problemData.TestPartition.Start = nRows;
  problemData.TestPartition.End = nRows;
  var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
  var model = solution.Model;
  // Direct cast instead of 'as': if the entry is not a RegressionTreeModel the test fails
  // right here with an InvalidCastException instead of a NullReferenceException below.
  var treeM = (RegressionTreeModel)model.Models.Skip(1).First();

  Console.WriteLine(treeM.ToString());
  Console.WriteLine();
}
|
---|
[12632] | 322 | #endregion
|
---|
[12620] | 323 | }
|
---|
| 324 | }
|
---|