Changeset 13157


Ignore:
Timestamp:
11/14/15 08:10:18 (3 years ago)
Author:
gkronber
Message:

#2450 made the changes suggested by mkommend in the review. This is definitely a big improvement, thx!

Location:
trunk/sources
Files:
2 added
5 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r13100 r13157  
    130130
    131131    // simple interface
    132     public static IRegressionSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxSize, double nu, double r, double m, int maxIterations, uint randSeed = 31415) {
     132    public static GradientBoostedTreesSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxSize, double nu, double r, double m, int maxIterations, uint randSeed = 31415) {
    133133      Contract.Assert(r > 0);
    134134      Contract.Assert(r <= 1.0);
     
    143143
    144144      var model = state.GetModel();
    145       return new RegressionSolution(model, (IRegressionProblemData)problemData.Clone());
     145      return new GradientBoostedTreesSolution(model, (IRegressionProblemData)problemData.Clone());
    146146    }
    147147
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs

    r13065 r13157  
    3333  [Item("Gradient boosted tree model", "")]
    3434  // this is essentially a collection of weighted regression models
    35   public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {
     35  public sealed class GradientBoostedTreesModel : NamedItem, IGradientBoostedTreesModel {
    3636    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
    3737    #region Backwards compatible code, remove with 3.5
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs

    r13066 r13157  
    3333  // recalculate the actual GBT model on demand
    3434  [Item("Gradient boosted tree model", "")]
    35   public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IRegressionModel {
     35  public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IGradientBoostedTreesModel {
    3636    // don't store the actual model!
    37     private IRegressionModel actualModel; // the actual model is only recalculated when necessary
    38     public IRegressionModel Model { get { return actualModel; } }
     37    private IGradientBoostedTreesModel actualModel; // the actual model is only recalculated when necessary
    3938
    4039    [Storable]
     
    8786
    8887    // wrap an actual model in a surrograte
    89     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)
     88    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IGradientBoostedTreesModel model)
    9089      : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) {
    9190      this.actualModel = model;
     
    107106
    108107
    109     private IRegressionModel RecalculateModel() {
     108    private IGradientBoostedTreesModel RecalculateModel() {
    110109      return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model;
     110    }
     111
     112    public IEnumerable<IRegressionModel> Models {
     113      get {
     114        if (actualModel == null) actualModel = RecalculateModel();
     115        return actualModel.Models;
     116      }
     117    }
     118
     119    public IEnumerable<double> Weights {
     120      get {
     121        if (actualModel == null) actualModel = RecalculateModel();
     122        return actualModel.Weights;
     123      }
    111124    }
    112125  }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r13120 r13157  
    201201    <Compile Include="GaussianProcess\GaussianProcessRegressionSolution.cs" />
    202202    <Compile Include="GaussianProcess\ICovarianceFunction.cs" />
     203    <Compile Include="GradientBoostedTrees\IGradientBoostedTreesModel.cs" />
    203204    <Compile Include="GradientBoostedTrees\GradientBoostedTreesModelSurrogate.cs" />
    204205    <Compile Include="GradientBoostedTrees\GradientBoostedTreesAlgorithm.cs" />
     
    211212    <Compile Include="GradientBoostedTrees\LossFunctions\RelativeErrorLoss.cs" />
    212213    <Compile Include="GradientBoostedTrees\LossFunctions\SquaredErrorLoss.cs" />
     214    <Compile Include="GradientBoostedTrees\GradientBoostedTreesSolution.cs" />
    213215    <Compile Include="GradientBoostedTrees\RegressionTreeBuilder.cs" />
    214216    <Compile Include="GradientBoostedTrees\RegressionTreeModel.cs" />
  • trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs

    r13066 r13157  
    269269      problemData.TestPartition.End = nRows;
    270270      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
    271       var model = (GradientBoostedTreesModel)((GradientBoostedTreesModelSurrogate)solution.Model).Model;
     271      var model = solution.Model;
    272272      var treeM = model.Models.Skip(1).First() as RegressionTreeModel;
    273273
Note: See TracChangeset for help on using the changeset viewer.