Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs @ 15287

Last change on this file since 15287 was 15287, checked in by jkarder, 7 years ago

#2258: merged Async branch into trunk

File size: 11.2 KB
RevLine 
[12620]1using System;
[13895]2using System.Collections;
3using System.IO;
[12620]4using System.Linq;
[12632]5using HeuristicLab.Data;
[12620]6using HeuristicLab.Problems.DataAnalysis;
7using Microsoft.VisualStudio.TestTools.UnitTesting;
8
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Tests for the gradient boosted trees implementation:
  /// - single regression trees built from tiny hand-crafted data sets,
  /// - partial-dependence output and persistence of a boosted tree model,
  /// - end-to-end runs on the "Tower" real-world regression instance with different loss functions.
  /// </summary>
  [TestClass()]
  public class GradientBoostingTest {
    /// <summary>
    /// Builds a single regression tree (one boosting iteration, nu = 1) for several small data sets
    /// and prints it. The comments in each scenario describe the tree that is expected.
    /// The printed trees are not compared automatically; <see cref="BuildTree"/> only checks that a
    /// tree model was produced.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void DecisionTreeTest() {
      {
        var xy = new double[,]
        {
          {-1, 20, 0},
          {-1, 20, 0},
          { 1, 10, 0},
          { 1, 10, 0},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }


      {
        var xy = new double[,]
        {
          {-1, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 1, 10, 1},
        };
        var allVariables = new string[] { "y", "x1", "x2" };

        // ignore irrelevant variables
        // x1 <= 15 -> 1
        // x1 >  15 -> -1
        BuildTree(xy, allVariables, 10);
      }

      {
        // split must be by x1 first
        var xy = new double[,]
        {
          {-2, 20,  1},
          {-1, 20, -1},
          { 1, 10, -1},
          { 2, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }

      {
        // averaging ys
        var xy = new double[,]
        {
          {-2.5, 20,  1},
          {-1.5, 20,  1},
          {-1.5, 20, -1},
          {-0.5, 20, -1},
          {0.5, 10, -1},
          {1.5, 10, -1},
          {1.5, 10, 1},
          {2.5, 10, 1},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // x1 <= 15 AND x2 <= 0 -> 1
        // x1 <= 15 AND x2 >  0 -> 2
        // x1 >  15 AND x2 <= 0 -> -1
        // x1 >  15 AND x2 >  0 -> -2
        BuildTree(xy, allVariables, 10);
      }


      {
        // diagonal split (no split possible)
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };

        // split cannot be found
        // -> 0.0
        BuildTree(xy, allVariables, 3);
      }
      {
        // almost diagonal split
        var xy = new double[,]
        {
          { 1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 1.0001, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> 0
        // x2 >  1.5 -> 0 (not quite)
        BuildTree(xy, allVariables, 3);

        // x1 <= 1.5 AND x2 <= 1.5 -> 1
        // x1 <= 1.5 AND x2 >  1.5 -> -1
        // x1 >  1.5 AND x2 <= 1.5 -> -1
        // x1 >  1.5 AND x2 >  1.5 -> 1 (not quite)
        BuildTree(xy, allVariables, 7);
      }
      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {0.9, 2, 1},
          {1.1, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // x1 <= 1.5 -> -1.0
        // x1 >  1.5 AND x2 <= 1.5 -> 0.9
        // x1 >  1.5 AND x2 >  1.5 -> 1.1
        BuildTree(xy, allVariables, 10);
      }

      {
        // unbalanced split
        var xy = new double[,]
        {
          {-1, 1, 1},
          {-1, 1, 2},
          {-1, 2, 1},
          { 3, 2, 2},
        };

        var allVariables = new string[] { "y", "x1", "x2" };
        // (two possible solutions)
        // x2 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 <= 1.5 -> -1.0
        // x2 >  1.5 AND x1 >  1.5 ->  3.0
        BuildTree(xy, allVariables, 10);
      }
    }

    /// <summary>
    /// Trains 1000 boosting steps on the Tower instance, then evaluates the model on a data set that
    /// contains only the most relevant variable, with its observed values sorted ascending.
    /// Every 10th (value, estimate) pair is printed. No assertions are made; the output is for
    /// inspection only.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void TestDecisionTreePartialDependence() {
      var problemData = LoadTowerProblem().ProblemData;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02);
      for (int i = 0; i < 1000; i++)
        GradientBoostedTreesAlgorithmStatic.MakeStep(state);


      var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First();
      Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value);
      var model = ((IGradientBoostedTreesModel)state.GetModel());
      // Skip(1): the first entry of Models is not printed here, only the first tree after it.
      var treeM = model.Models.Skip(1).First();
      Console.WriteLine(treeM.ToString());
      Console.WriteLine();

      var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray();
      // Data set containing only the most relevant variable. All other inputs are absent from it.
      var ds = new ModifiableDataset(new string[] { mostImportantVar.Key },
        new IList[] { mostImportantVarValues.ToList<double>() });

      var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray();

      for (int i = 0; i < mostImportantVarValues.Length; i += 10) {
        Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]);
      }
    }

    /// <summary>
    /// Performs one boosting step (max tree size 100, nu = 1) on the Tower instance.
    /// Serializes the resulting tree with the XML persistence layer into memory and deserializes it again.
    /// Asserts that the restored tree's string representation equals the original's.
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void TestDecisionTreePersistence() {
      var problemData = LoadTowerProblem().ProblemData;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1);
      GradientBoostedTreesAlgorithmStatic.MakeStep(state);

      var model = ((IGradientBoostedTreesModel)state.GetModel());
      var treeM = model.Models.Skip(1).First();
      var origStr = treeM.ToString();
      using (var memStream = new MemoryStream()) {
        Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream);
        // ToArray copies exactly the bytes that were written.
        // GetBuffer would return the whole internal buffer, including unused zero padding past the
        // written length, and hand that garbage to the parser.
        // ToArray also works if Serialize has already closed the stream.
        var buf = memStream.ToArray();
        using (var restoreStream = new MemoryStream(buf)) {
          var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream);
          Assert.IsNotNull(restoredTree, "deserialization returned null");
          var restoredStr = restoredTree.ToString();
          Assert.AreEqual(origStr, restoredStr);
        }
      }
    }

    /// <summary>
    /// Runs the full algorithm on the Tower instance with the algorithm's default loss function
    /// (presumably squared error, per the test name — not set explicitly here).
    /// Settings: 5000 iterations, max tree size 20, fixed seed 0.
    /// Compares the train/test loss against reference values (tolerance 1e-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerSquaredError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 5000;
      gbt.MaxSize = 20;
      gbt.CreateSolution = false;
      #endregion

      // Assumes Start() blocks until the run has finished (the Results lookups below depend on it).
      gbt.Start();

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(267.68704241153921, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(393.84704062205469, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Runs the full algorithm on the Tower instance.
    /// Loss: the first valid loss function whose name contains "Absolute".
    /// Settings: 1000 iterations, max tree size 20, nu = 0.02, fixed seed 0.
    /// Compares the train/test loss against reference values (tolerance 1e-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerAbsoluteError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 1000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.02;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute"));
      gbt.CreateSolution = false;
      #endregion

      // Assumes Start() blocks until the run has finished (the Results lookups below depend on it).
      gbt.Start();

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(10.551385044666661, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(12.918001745581172, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    /// <summary>
    /// Runs the full algorithm on the Tower instance.
    /// Loss: the first valid loss function whose name contains "Relative".
    /// Settings: 3000 iterations, max tree size 20, nu = 0.005, fixed seed 0.
    /// Compares the train/test loss against reference values (tolerance 1e-6).
    /// </summary>
    [TestMethod]
    [TestCategory("Algorithms.DataAnalysis")]
    [TestProperty("Time", "long")]
    public void GradientBoostingTestTowerRelativeError() {
      var gbt = new GradientBoostedTreesAlgorithm();
      var regProblem = LoadTowerProblem();

      #region Algorithm Configuration
      gbt.Problem = regProblem;
      gbt.Seed = 0;
      gbt.SetSeedRandomly = false;
      gbt.Iterations = 3000;
      gbt.MaxSize = 20;
      gbt.Nu = 0.005;
      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative"));
      gbt.CreateSolution = false;
      #endregion

      // Assumes Start() blocks until the run has finished (the Results lookups below depend on it).
      gbt.Start();

      Console.WriteLine(gbt.ExecutionTime);
      Assert.AreEqual(0.061954221604374943, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
      Assert.AreEqual(0.06316303473499961, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
    }

    #region helper
    /// <summary>
    /// Creates a new regression problem loaded with the single real-world benchmark instance whose
    /// name contains "Tower". <c>Single</c> throws if zero or several instances match.
    /// </summary>
    private static RegressionProblem LoadTowerProblem() {
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
      var regProblem = new RegressionProblem();
      regProblem.Load(provider.LoadData(instance));
      return regProblem;
    }

    /// <summary>
    /// Trains a gradient boosted model with exactly one iteration and prints the resulting regression tree.
    /// Training settings: squared error loss, nu = 1, r = 1, m = 1.
    /// </summary>
    /// <param name="xy">Data matrix; column 0 is the target, the remaining columns are the inputs.</param>
    /// <param name="allVariables">Column names for <paramref name="xy"/>; the first name is the target variable.</param>
    /// <param name="maxSize">Maximum tree size passed to the trainer.</param>
    private void BuildTree(double[,] xy, string[] allVariables, int maxSize) {
      int nRows = xy.GetLength(0);
      var allowedInputs = allVariables.Skip(1);
      var dataset = new Dataset(allVariables, xy);
      var problemData = new RegressionProblemData(dataset, allowedInputs, allVariables.First());
      // Use every row for training and leave the test partition empty.
      problemData.TrainingPartition.Start = 0;
      problemData.TrainingPartition.End = nRows;
      problemData.TestPartition.Start = nRows;
      problemData.TestPartition.End = nRows;
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
      var model = solution.Model;
      // Skip(1): the tree built by the single iteration is the second entry of Models.
      var treeM = model.Models.Skip(1).First() as RegressionTreeModel;
      // Fail with a clear message instead of a NullReferenceException if the model is not a tree.
      Assert.IsNotNull(treeM, "expected the boosting iteration to produce a RegressionTreeModel");

      Console.WriteLine(treeM.ToString());
      Console.WriteLine();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.