[12620] | 1 | using System;
|
---|
[14023] | 2 | using System.Collections;
|
---|
| 3 | using System.IO;
|
---|
[12620] | 4 | using System.Linq;
|
---|
[12632] | 5 | using System.Threading;
|
---|
| 6 | using HeuristicLab.Data;
|
---|
| 7 | using HeuristicLab.Optimization;
|
---|
[12620] | 8 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 9 | using Microsoft.VisualStudio.TestTools.UnitTesting;
|
---|
| 10 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Tests for gradient boosted trees: construction of single regression trees on small
  /// hand-crafted data sets, evaluation of a trained model on a reduced dataset,
  /// persistence of a tree model, and full algorithm runs on the "Tower" real-world
  /// regression instance with different loss functions.
  /// </summary>
  [TestClass()]
  public class GradientBoostingTest {
    /// <summary>
    /// Builds one regression tree per hand-crafted data set and prints it.
    /// In every data set the first column is the target "y" and the remaining columns are the inputs.
    /// The comments in each case describe the tree that is expected to be found. The test itself only
    /// checks that tree building completes and yields a <see cref="RegressionTreeModel"/>.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void DecisionTreeTest() {
      {
        var xy = new double[,]
        {
          {-1, 20, 0},
          {-1, 20, 0},
          { 1, 10, 0},
          { 1, 10, 0},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 -> 1
        // x1 > 15 -> -1
        BuildTree(xy, allVariables, 10);
      }


      {
        var xy = new double[,]
        {
          {-1, 20, 1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 1, 10, 1},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // ignore irrelevant variables
        // x1 <= 15 -> 1
        // x1 > 15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        // split must be by x1 first
        var xy = new double[,]
        {
          {-2, 20, 1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 2, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 > 0 -> 2
        // x1 > 15 AND x2 <= 0 -> -1
        // x1 > 15 AND x2 > 0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // averaging ys
        var xy = new double[,]
        {
          {-2.5, 20, 1},
          {-1.5, 20, 1},
          {-1.5, 20, -1},
          {-0.5, 20, -1},
          {0.5, 10, -1},
          {1.5, 10, -1},
          {1.5, 10, 1},
          {2.5, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 > 0 -> 2
        // x1 > 15 AND x2 <= 0 -> -1
        // x1 > 15 AND x2 > 0 -> -2
        BuildTree(xy, allVariables, 10);
      }


      {
        // diagonal split (no split possible)
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // split cannot be found
        // -> 0.0
        BuildTree(xy, allVariables, 3);
      }
      {
        // almost diagonal split
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1.0001, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> 0
        // x2 > 1.5 -> 0 (not quite)
        BuildTree(xy, allVariables, 3);

        // x1 <= 1.5 AND x2 <= 1.5 -> 1
        // x1 <= 1.5 AND x2 > 1.5 -> -1
        // x1 > 1.5 AND x2 <= 1.5 -> -1
        // x1 > 1.5 AND x2 > 1.5 -> 1 (not quite)
        BuildTree(xy, allVariables, 7);
      }
      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {0.9, 2, 1},
          {1.1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // x1 <= 1.5 -> -1.0
        // x1 > 1.5 AND x2 <= 1.5 -> 0.9
        // x1 > 1.5 AND x2 > 1.5 -> 1.1
        BuildTree(xy, allVariables, 10);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 3, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions; the data is symmetric in x1 and x2, so the
        // alternative is the same tree with the roles of x1 and x2 swapped)
        // x2 <= 1.5 -> -1.0
        // x2 > 1.5 AND x1 <= 1.5 -> -1.0
        // x2 > 1.5 AND x1 > 1.5 -> 3.0
        BuildTree(xy, allVariables, 10);
      }
    }

    /// <summary>
    /// Trains a gradient boosted trees model on the "Tower" instance (1000 steps) and determines the
    /// most relevant input variable. It then prints the model's estimates for the sorted values of
    /// that variable (every 10th value).
    /// The dataset used for estimation contains only that single variable.
    /// NOTE(review): this relies on the model producing estimates when the other inputs are absent
    /// from the dataset; that behavior is implemented outside this file.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void TestDecisionTreePartialDependence() {
      var problemData = CreateTowerProblem().ProblemData;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02);
      for (int i = 0; i < 1000; i++) {
        GradientBoostedTreesAlgorithmStatic.MakeStep(state);
      }

      var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First();
      Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value);
      var model = ((IGradientBoostedTreesModel)state.GetModel());
      var treeM = model.Models.Skip(1).First();
      Console.WriteLine(treeM.ToString());
      Console.WriteLine();

      // dataset with a single column: the sorted values of the most relevant variable
      var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray();
      var ds = new ModifiableDataset(new string[] { mostImportantVar.Key },
        new IList[] { mostImportantVarValues.ToList<double>() });

      var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray();

      for (int i = 0; i < mostImportantVarValues.Length; i += 10) {
        Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]);
      }
    }

    /// <summary>
    /// Serializes a tree model produced by one boosting step with the XML persistence.
    /// It then deserializes the model and checks that the restored model has the same string
    /// representation as the original.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void TestDecisionTreePersistence() {
      var problemData = CreateTowerProblem().ProblemData;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1);
      GradientBoostedTreesAlgorithmStatic.MakeStep(state);

      var model = ((IGradientBoostedTreesModel)state.GetModel());
      var treeM = model.Models.Skip(1).First();
      var origStr = treeM.ToString();
      using (var memStream = new MemoryStream()) {
        Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream);
        // ToArray copies exactly the bytes that were written. GetBuffer would return the whole
        // internal buffer, including unused trailing capacity, which the parser would then see
        // as extra data. ToArray also works if the serializer has already closed the stream.
        var buf = memStream.ToArray();
        using (var restoreStream = new MemoryStream(buf)) {
          var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream);
          var restoredStr = restoredTree.ToString();
          Assert.AreEqual(origStr, restoredStr);
        }
      }
    }

    /// <summary>
    /// Runs the algorithm on "Tower" with the default loss function and checks training/test loss.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerSquaredError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = CreateTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 5000;
      gbt.MaxSize = 20;
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(267.68704241153921, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(393.84704062205469, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Runs the algorithm on "Tower" with the loss function whose name contains "Absolute"
    /// and checks training/test loss.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerAbsoluteError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = CreateTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 1000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.02;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute"));
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(10.551385044666661, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(12.918001745581172, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Runs the algorithm on "Tower" with the loss function whose name contains "Relative"
    /// and checks training/test loss.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerRelativeError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = CreateTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 3000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.005;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative"));
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(0.061954221604374943, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(0.06316303473499961, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    // same as in SamplesUtil
    /// <summary>
    /// Prepares and starts the algorithm, then blocks until it raises either Stopped or ExceptionOccurred.
    /// Fails the test if an exception was reported.
    /// </summary>
    private void RunAlgorithm(IAlgorithm a) {
      var trigger = new EventWaitHandle(false, EventResetMode.ManualReset);
      Exception ex = null;
      a.Stopped += (src, e) => { trigger.Set(); };
      a.ExceptionOccurred += (src, e) => { ex = e.Value; trigger.Set(); };
      a.Prepare();
      a.Start();
      trigger.WaitOne();

      // IsNull states the intent directly (AreEqual had expected/actual swapped) and
      // includes the reported exception in the failure message.
      Assert.IsNull(ex, "The algorithm raised an exception: {0}", ex);
    }

    #region helper
    /// <summary>
    /// Loads the real-world regression instance whose name contains "Tower" into a new regression problem.
    /// </summary>
    private static RegressionProblem CreateTowerProblem() {
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
      var regProblem = new RegressionProblem();
      regProblem.Load(provider.LoadData(instance));
      return regProblem;
    }

    /// <summary>
    /// Trains a gradient boosted trees model with a single iteration on <paramref name="xy"/>.
    /// It checks that the second component model is a <see cref="RegressionTreeModel"/> and prints it.
    /// </summary>
    /// <param name="xy">Data matrix whose columns correspond to <paramref name="allVariables"/>.</param>
    /// <param name="allVariables">Column names; the first one is the target, all others are allowed inputs.</param>
    /// <param name="maxSize">Maximum tree size passed to the trainer.</param>
    private void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
      int nRows = xy.GetLength(0);
      var allowedInputs = allVariables.Skip(1);
      var dataset = new Dataset(allVariables, xy);
      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
      // all rows are used for training; the test partition is empty
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = nRows;
      problemData.TestPartition.Start = nRows;
      problemData.TestPartition.End = nRows;
      // a single boosting iteration with nu = 1
      // NOTE(review): r and m are assumed to be row/variable sampling ratios (1 = no subsampling); confirm
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
      var model = solution.Model;
      // skip the first component model; the second is presumably the tree built in the single iteration
      var treeM = model.Models.Skip(1).First();
      // fail with a clear assertion instead of a NullReferenceException from an unchecked 'as' cast
      Assert.IsInstanceOfType(treeM, typeof(RegressionTreeModel));

      Console.WriteLine(treeM.ToString());
      Console.WriteLine();
    }
    #endregion
  }
}