Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs @ 16045

Last change to this file as of revision 16045: revision 15287, checked in by jkarder 7 years ago

#2258: merged Async branch into trunk

File size: 11.2 KB
Line 
1using System;
2using System.Collections;
3using System.IO;
4using System.Linq;
5using HeuristicLab.Data;
6using HeuristicLab.Problems.DataAnalysis;
7using Microsoft.VisualStudio.TestTools.UnitTesting;
8
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Tests for the gradient boosted trees (GBT) implementation:
  /// single-tree construction on small hand-made data sets, partial-dependence output,
  /// a persistence round-trip of a fitted tree, and regression checks of the final
  /// training/test loss on the real-world "Tower" data set for several loss functions.
  /// </summary>
  [TestClass()]
  public class GradientBoostingTest {
    /// <summary>
    /// Fits single regression trees on tiny data sets whose best splits are known and
    /// prints them. The expected tree for each case is given in the comments; the printed
    /// trees are not compared automatically (BuildTree only checks that a regression tree
    /// was produced). Column order of every matrix: y (target), x1, x2.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void DecisionTreeTest() {
      {
        var xy = new double[,]
        {
          {-1, 20, 0},
          {-1, 20, 0},
          { 1, 10, 0},
          { 1, 10, 0},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }


      {
        var xy = new double[,]
        {
          {-1, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 1, 10, 1},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // ignore irrelevant variables
        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        // split must be by x1 first
        var xy = new double[,]
        {
          {-2, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 2, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // averaging ys
        var xy = new double[,]
        {
          {-2.5, 20,  1},
          {-1.5, 20,  1},
          {-1.5, 20, -1},
          {-0.5, 20, -1},
          {0.5, 10, -1},
          {1.5, 10, -1},
          {1.5, 10, 1},
          {2.5, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }


      {
        // diagonal split (no split possible)
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // split cannot be found
        // -> 0.0
        BuildTree(xy, allVariables, 3);
      }
      {
        // almost diagonal split
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1.0001, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> 0
        // x2 >  1.5 -> 0 (not quite)
        BuildTree(xy, allVariables, 3);

        // x1 <= 1.5 AND x2 <= 1.5 -> 1
        // x1 <= 1.5 AND x2 >  1.5 -> -1
        // x1 >  1.5 AND x2 <= 1.5 -> -1
        // x1 >  1.5 AND x2 >  1.5 -> 1 (not quite)
        BuildTree(xy, allVariables, 7);
      }
      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {0.9, 2, 1},
          {1.1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // x1 <= 1.5 -> -1.0
        // x1 >  1.5 AND x2 <= 1.5 -> 0.9
        // x1 >  1.5 AND x2 >  1.5 -> 1.1
        BuildTree(xy, allVariables, 10);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 3, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 >  1.5 ->  3.0
        BuildTree(xy, allVariables, 10);
      }
    }

    /// <summary>
    /// Runs 1000 boosting steps on the Tower data set and evaluates the final model on a
    /// one-column data set that holds the sorted values of the most relevant variable.
    /// It prints every 10th (value, estimate) pair and checks that one estimate is
    /// produced per evaluated row.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void TestDecisionTreePartialDependence() {
      var problemData = LoadTowerProblem().ProblemData;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02);
      for (int i = 0; i < 1000; i++)
        GradientBoostedTreesAlgorithmStatic.MakeStep(state);


      var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First();
      Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value);
      var model = (IGradientBoostedTreesModel)state.GetModel();
      // print the first tree; Models[0] is presumably the initial constant model
      var treeM = model.Models.Skip(1).First();
      Console.WriteLine(treeM.ToString());
      Console.WriteLine();

      // data set that contains only the most relevant variable, with its values sorted ascending
      var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray();
      var ds = new ModifiableDataset(new string[] { mostImportantVar.Key },
        new IList[] { mostImportantVarValues.ToList<double>() });

      var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray();
      Assert.AreEqual(mostImportantVarValues.Length, estValues.Length, "expected one estimated value per evaluated row");

      for (int i = 0; i < mostImportantVarValues.Length; i += 10) {
        Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]);
      }
    }

    /// <summary>
    /// Fits one tree on the Tower data set, serializes it with the XML persistence
    /// and checks that the deserialized tree prints identically to the original.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void TestDecisionTreePersistence() {
      var problemData = LoadTowerProblem().ProblemData;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1);
      GradientBoostedTreesAlgorithmStatic.MakeStep(state);

      var model = (IGradientBoostedTreesModel)state.GetModel();
      var treeM = model.Models.Skip(1).First();
      var origStr = treeM.ToString();

      byte[] serialized;
      using (var memStream = new MemoryStream()) {
        Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream);
        // ToArray copies exactly the bytes that were written. GetBuffer would hand out the
        // whole internal buffer, including unused capacity past the end of the data.
        serialized = memStream.ToArray();
      }
      using (var restoreStream = new MemoryStream(serialized)) {
        var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream);
        var restoredStr = restoredTree.ToString();
        Assert.AreEqual(origStr, restoredStr);
      }
    }

    /// <summary>
    /// Regression test: GBT with squared-error loss (the algorithm's loss is not changed),
    /// seed 0, 5000 iterations and MaxSize 20 on the Tower data set. The final training
    /// and test losses are compared against reference values (tolerance 1e-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerSquaredError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 5000;
      gbt.MaxSize = 20;
      gbt.CreateSolution = false;
      #endregion

      // Start() is expected to block until the run has finished; results are read right after it.
      gbt.Start();

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(267.68704241153921, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(393.84704062205469, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Regression test: GBT with the loss function whose name contains "Absolute", seed 0,
    /// 1000 iterations, MaxSize 20 and Nu 0.02 on the Tower data set. The final losses are
    /// compared against reference values (tolerance 1e-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerAbsoluteError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 1000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.02;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute"));
      gbt.CreateSolution = false;
      #endregion

      // Start() is expected to block until the run has finished; results are read right after it.
      gbt.Start();

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(10.551385044666661, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(12.918001745581172, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Regression test: GBT with the loss function whose name contains "Relative", seed 0,
    /// 3000 iterations, MaxSize 20 and Nu 0.005 on the Tower data set. The final losses are
    /// compared against reference values (tolerance 1e-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerRelativeError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 3000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.005;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative"));
      gbt.CreateSolution = false;
      #endregion

      // Start() is expected to block until the run has finished; results are read right after it.
      gbt.Start();

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(0.061954221604374943, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(0.06316303473499961, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    #region helper
    /// <summary>
    /// Loads the real-world regression instance whose name contains "Tower" into a new
    /// <see cref="RegressionProblem"/>. <c>Single</c> throws if not exactly one instance matches.
    /// </summary>
    private static RegressionProblem LoadTowerProblem() {
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
      var regProblem = new RegressionProblem();
      regProblem.Load(provider.LoadData(instance));
      return regProblem;
    }

    /// <summary>
    /// Trains a gradient boosted trees model with a single iteration on all rows of
    /// <paramref name="xy"/>. The call uses squared-error loss, nu = 1, r = 1, m = 1 and seed 31415.
    /// The method prints the tree fitted in that iteration.
    /// </summary>
    /// <param name="xy">Data matrix, one row per sample, columns ordered as <paramref name="allVariables"/>.</param>
    /// <param name="allVariables">Variable names; the first is the target, the remaining ones are the allowed inputs.</param>
    /// <param name="maxSize">Passed as <c>maxSize</c> to <c>TrainGbm</c>.</param>
    private static void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
      int nRows = xy.GetLength(0);
      var allowedInputs = allVariables.Skip(1);
      var dataset = new Dataset(allVariables, xy);
      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
      // train on every row and leave the test partition empty
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = nRows;
      problemData.TestPartition.Start = nRows;
      problemData.TestPartition.End = nRows;
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
      var model = solution.Model;
      // Models[0] is presumably the initial constant model; Models[1] is the tree of the single iteration.
      var treeM = model.Models.Skip(1).First();
      // fail with a clear message instead of a NullReferenceException if no tree was produced
      Assert.IsInstanceOfType(treeM, typeof(RegressionTreeModel), "expected the second model component to be a regression tree");

      Console.WriteLine(treeM.ToString());
      Console.WriteLine();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.