Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs @ 15199

Last change on this file since 15199 was 13895, checked in by gkronber, 8 years ago

#2612: extended GBT to support calculation of partial dependence (as described in the greedy function approximation paper), changed persistence of regression tree models and added two unit tests.

File size: 11.7 KB
RevLine 
[12620]1using System;
[13895]2using System.Collections;
3using System.IO;
[12620]4using System.Linq;
[12632]5using System.Threading;
6using HeuristicLab.Data;
7using HeuristicLab.Optimization;
[12620]8using HeuristicLab.Problems.DataAnalysis;
9using Microsoft.VisualStudio.TestTools.UnitTesting;
10
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Tests for gradient boosted trees (GBT): single-tree construction on small hand-made datasets,
  /// evaluation of a trained model on a single input variable, persistence of a regression tree model,
  /// and full algorithm runs with different loss functions on the "Tower" real-world regression instance.
  /// </summary>
  [TestClass()]
  public class GradientBoostingTest {
    /// <summary>
    /// Builds single regression trees on tiny datasets via <see cref="BuildTree"/> and prints them.
    /// In every dataset column 0 is the target "y"; columns 1 and 2 are the inputs "x1" and "x2".
    /// The comments before each <c>BuildTree</c> call describe the expected tree.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the expected trees are only documented in comments. <c>BuildTree</c> prints the
    /// resulting model but does not assert its structure, so this test fails only on exceptions.
    /// </remarks>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void DecisionTreeTest() {
      {
        var xy = new double[,]
        {
          {-1, 20, 0},
          {-1, 20, 0},
          { 1, 10, 0},
          { 1, 10, 0},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }


      {
        var xy = new double[,]
        {
          {-1, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 1, 10, 1},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // ignore irrelevant variables
        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        // split must be by x1 first
        var xy = new double[,]
        {
          {-2, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 2, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // averaging ys (each leaf holds two rows; the expected leaf value is their mean)
        var xy = new double[,]
        {
          {-2.5, 20,  1},
          {-1.5, 20,  1},
          {-1.5, 20, -1},
          {-0.5, 20, -1},
          {0.5, 10, -1},
          {1.5, 10, -1},
          {1.5, 10, 1},
          {2.5, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }


      {
        // diagonal split (no split possible)
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // split cannot be found
        // -> 0.0
        BuildTree(xy, allVariables, 3);
      }
      {
        // almost diagonal split
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1.0001, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> 0
        // x2 >  1.5 -> 0 (not quite)
        BuildTree(xy, allVariables, 3);

        // x1 <= 1.5 AND x2 <= 1.5 -> 1
        // x1 <= 1.5 AND x2 >  1.5 -> -1
        // x1 >  1.5 AND x2 <= 1.5 -> -1
        // x1 >  1.5 AND x2 >  1.5 -> 1 (not quite)
        BuildTree(xy, allVariables, 7);
      }
      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {0.9, 2, 1},
          {1.1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // x1 <= 1.5 -> -1.0
        // x1 >  1.5 AND x2 <= 1.5 -> 0.9
        // x1 >  1.5 AND x2 >  1.5 -> 1.1
        BuildTree(xy, allVariables, 10);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 3, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 >  1.5 ->  3.0
        BuildTree(xy, allVariables, 10);
      }
    }

    /// <summary>
    /// Runs 1000 GBT steps on the "Tower" data and prints the most relevant input variable.
    /// It then evaluates the model on a dataset that contains only that variable, with its values
    /// sorted ascending, and prints every 10th (value, estimate) pair.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this presumably relies on the tree models accepting a dataset that lacks the other
    /// input variables (partial-dependence support); no values are asserted.
    /// </remarks>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void TestDecisionTreePartialDependence() {
      var problemData = LoadTowerProblem().ProblemData;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02);
      for (int i = 0; i < 1000; i++)
        GradientBoostedTreesAlgorithmStatic.MakeStep(state);


      var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First();
      Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value);
      var model = ((IGradientBoostedTreesModel)state.GetModel());
      var treeM = model.Models.Skip(1).First();
      Console.WriteLine(treeM.ToString());
      Console.WriteLine();

      // dataset with a single column: the sorted values of the most relevant variable
      var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray();
      var ds = new ModifiableDataset(new string[] { mostImportantVar.Key },
        new IList[] { mostImportantVarValues.ToList() });

      var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray();

      for (int i = 0; i < mostImportantVarValues.Length; i += 10) {
        Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]);
      }
    }

    /// <summary>
    /// Serializes one model of a single-step GBT model with the XML persistence and deserializes it
    /// again. Asserts that the restored model's <c>ToString()</c> equals the original's.
    /// </summary>
    /// <remarks>
    /// The model taken is the one after the first entry of <c>model.Models</c>.
    /// NOTE(review): presumably the first entry is the initial (constant) model; confirm.
    /// </remarks>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void TestDecisionTreePersistence() {
      var problemData = LoadTowerProblem().ProblemData;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1);
      GradientBoostedTreesAlgorithmStatic.MakeStep(state);

      var model = ((IGradientBoostedTreesModel)state.GetModel());
      var treeM = model.Models.Skip(1).First();
      var origStr = treeM.ToString();
      using (var memStream = new MemoryStream()) {
        Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream);
        // ToArray() copies exactly the bytes that were written (and works even if the serializer closed
        // the stream). GetBuffer() would also return the unused capacity of the internal buffer, i.e.
        // trailing bytes that are not part of the serialized data.
        var buf = memStream.ToArray();
        using (var restoreStream = new MemoryStream(buf)) {
          var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream);
          var restoredStr = restoredTree.ToString();
          Assert.AreEqual(origStr, restoredStr);
        }
      }
    }

    /// <summary>
    /// Runs GBT with the default loss (5000 iterations, max tree size 20, seed 0) on "Tower".
    /// Compares the reported train and test losses against reference values (tolerance 1E-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerSquaredError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 5000;
      gbt.MaxSize = 20;
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(267.68704241153921, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(393.84704062205469, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Runs GBT with the first loss function whose name contains "Absolute" (1000 iterations,
    /// max tree size 20, nu 0.02, seed 0) on "Tower". Compares the reported losses against
    /// reference values (tolerance 1E-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerAbsoluteError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 1000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.02;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute"));
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(10.551385044666661, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(12.918001745581172, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Runs GBT with the first loss function whose name contains "Relative" (3000 iterations,
    /// max tree size 20, nu 0.005, seed 0) on "Tower". Compares the reported losses against
    /// reference values (tolerance 1E-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerRelativeError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 3000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.005;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative"));
      gbt.CreateSolution = false;
      #endregion

      RunAlgorithm(gbt);

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(0.061954221604374943, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(0.06316303473499961, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    // Based on the helper of the same name in SamplesUtil.
    /// <summary>
    /// Prepares and starts <paramref name="a"/> and blocks until it raises either Stopped or
    /// ExceptionOccurred. Fails the test, including the exception in the message, if an exception
    /// was reported.
    /// </summary>
    private void RunAlgorithm(IAlgorithm a) {
      // The wait handle is intentionally not disposed here: the handlers stay subscribed. The algorithm
      // could still raise the other event after the wait has been released (e.g. Stopped following
      // ExceptionOccurred), and Set() on a disposed handle would throw on the algorithm's thread.
      var trigger = new EventWaitHandle(false, EventResetMode.ManualReset);
      Exception ex = null;
      a.Stopped += (src, e) => { trigger.Set(); };
      a.ExceptionOccurred += (src, e) => { ex = e.Value; trigger.Set(); };
      a.Prepare();
      a.Start();
      trigger.WaitOne();

      Assert.IsNull(ex, "Algorithm execution failed: {0}", ex);
    }

    #region helper
    /// <summary>
    /// Loads the first (and only) real-world regression instance whose name contains "Tower"
    /// into a new <see cref="RegressionProblem"/>.
    /// </summary>
    private static RegressionProblem LoadTowerProblem() {
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
      var regProblem = new RegressionProblem();
      regProblem.Load(provider.LoadData(instance));
      return regProblem;
    }

    /// <summary>
    /// Trains a GBT model with a single iteration (nu = 1, r = 1, m = 1, randSeed = 31415) on all rows
    /// of <paramref name="xy"/> and prints the resulting regression tree.
    /// </summary>
    /// <param name="xy">Data matrix whose columns correspond to <paramref name="allVariables"/>.</param>
    /// <param name="allVariables">Column names; the first is the target variable, all others are allowed inputs.</param>
    /// <param name="maxSize">Maximum tree size passed to <c>TrainGbm</c>.</param>
    private void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
      int nRows = xy.GetLength(0);
      var allowedInputs = allVariables.Skip(1);
      var dataset = new Dataset(allVariables, xy);
      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
      // every row is used for training; the test partition is empty
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = nRows;
      problemData.TestPartition.Start = nRows;
      problemData.TestPartition.End = nRows;
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
      var model = solution.Model;
      // Direct cast: a model of an unexpected type fails with a clear InvalidCastException
      // instead of a NullReferenceException when printing.
      var treeM = (RegressionTreeModel)model.Models.Skip(1).First();

      Console.WriteLine(treeM.ToString());
      Console.WriteLine();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.