Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/15/16 10:02:15 (8 years ago)
Author:
gkronber
Message:

#2612: extended GBT to support calculation of partial dependence (as described in the greedy function approximation paper), changed persistence of regression tree models and added two unit tests.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs

    r13157 r13895  
    11using System;
     2using System.Collections;
     3using System.IO;
    24using System.Linq;
    35using System.Threading;
     
    160162        // x2 >  1.5 AND x1 >  1.5 ->  3.0
    161163        BuildTree(xy, allVariables, 10);
     164      }
     165    }
     166
     167    [TestMethod]
     168    [TestCategory("Algorithms.DataAnalysis")]
     169    [TestProperty("Time", "short")]
            // Trains a gradient boosted trees model on the real-world regression instance whose
            // name contains "Tower", then evaluates the model on a dataset that contains only the
            // most relevant input variable and prints the resulting (value, estimate) pairs.
            // The test contains no Assert calls: it only fails if one of the calls throws.
     170    public void TestDecisionTreePartialDependence() {
              // Single(...) throws unless exactly one data descriptor name contains "Tower".
     171      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
     172      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
     173      var regProblem = new RegressionProblem();
     174      regProblem.Load(provider.LoadData(instance));
     175      var problemData = regProblem.ProblemData;
              // Fixed random seed (31415); 1000 boosting steps with squared error loss.
     176      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02);
     177      for (int i = 0; i < 1000; i++)
     178        GradientBoostedTreesAlgorithmStatic.MakeStep(state);
     179
     180
              // Pick the variable with the largest relevance value reported by the GBM state.
     181      var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First();
     182      Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value);
     183      var model = ((IGradientBoostedTreesModel)state.GetModel());
              // Prints the second model of the ensemble.
              // NOTE(review): Skip(1) presumably skips a leading constant/offset model so that the
              // first regression tree is shown - confirm against the model implementation.
     184      var treeM = model.Models.Skip(1).First();
     185      Console.WriteLine(treeM.ToString());
     186      Console.WriteLine();
     187
              // Build a dataset with a single column: the values of the most relevant variable,
              // sorted ascending. No other input variables are present in this dataset.
     188      var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray();
     189      var ds = new ModifiableDataset(new string[] { mostImportantVar.Key },
     190        new IList[] { mostImportantVarValues.ToList<double>() });
     191
              // Evaluate the full ensemble on every row of the single-variable dataset.
     192      var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray();
     193
              // Print every 10th (input value, estimated value) pair.
     194      for (int i = 0; i < mostImportantVarValues.Length; i += 10) {
     195        Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]);
     196      }
     197    }
     198
     199    [TestMethod]
     200    [TestCategory("Algorithms.DataAnalysis")]
     201    [TestProperty("Time", "short")]
     202    public void TestDecisionTreePersistence() {
     203      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
     204      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
     205      var regProblem = new RegressionProblem();
     206      regProblem.Load(provider.LoadData(instance));
     207      var problemData = regProblem.ProblemData;
     208      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1);
     209      GradientBoostedTreesAlgorithmStatic.MakeStep(state);
     210
     211      var model = ((IGradientBoostedTreesModel)state.GetModel());
     212      var treeM = model.Models.Skip(1).First();
     213      var origStr = treeM.ToString();
     214      using (var memStream = new MemoryStream()) {
     215        Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream);
     216        var buf = memStream.GetBuffer();
     217        using (var restoreStream = new MemoryStream(buf)) {
     218          var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream);
     219          var restoredStr = restoredTree.ToString();
     220          Assert.AreEqual(origStr, restoredStr);
     221        }
    162222      }
    163223    }
Note: See TracChangeset for help on using the changeset viewer.