Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/16/15 19:49:40 (8 years ago)
Author:
gkronber
Message:

#2450: merged r12868,r12873,r12875,r13065:13066,r13157:13158 from trunk to stable

Location:
stable
Files:
2 edited
1 copied

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs

    r12868 r13184  
    2121#endregion
    2222
    23 using System;
    2423using System.Collections.Generic;
    25 using System.Linq;
    2624using HeuristicLab.Common;
    2725using HeuristicLab.Core;
    2826using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    29 using HeuristicLab.PluginInfrastructure;
    3027using HeuristicLab.Problems.DataAnalysis;
    3128
     
    3633  // recalculate the actual GBT model on demand
    3734  [Item("Gradient boosted tree model", "")]
    38   public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IRegressionModel {
     35  public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IGradientBoostedTreesModel {
    3936    // don't store the actual model!
    40     private IRegressionModel actualModel; // the actual model is only recalculated when necessary
     37    private IGradientBoostedTreesModel actualModel; // the actual model is only recalculated when necessary
    4138
    4239    [Storable]
     
    4542    private readonly uint seed;
    4643    [Storable]
    47     private string lossFunctionName;
     44    private ILossFunction lossFunction;
    4845    [Storable]
    4946    private double r;
     
    6663
    6764      this.trainingProblemData = cloner.Clone(original.trainingProblemData);
     65      this.lossFunction = cloner.Clone(original.lossFunction);
    6866      this.seed = original.seed;
    69       this.lossFunctionName = original.lossFunctionName;
    7067      this.iterations = original.iterations;
    7168      this.maxSize = original.maxSize;
     
    7673
    7774    // create only the surrogate model without an actual model
    78     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu)
     75    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)
    7976      : base("Gradient boosted tree model", string.Empty) {
    8077      this.trainingProblemData = trainingProblemData;
    8178      this.seed = seed;
    82       this.lossFunctionName = lossFunctionName;
     79      this.lossFunction = lossFunction;
    8380      this.iterations = iterations;
    8481      this.maxSize = maxSize;
     
    8986
    9087    // wrap an actual model in a surrograte
    91     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)
    92       : this(trainingProblemData, seed, lossFunctionName, iterations, maxSize, r, m, nu) {
     88    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IGradientBoostedTreesModel model)
     89      : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) {
    9390      this.actualModel = model;
    9491    }
     
    109106
    110107
    111     private IRegressionModel RecalculateModel() {
    112       var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>().Single(l => l.ToString() == lossFunctionName);
     108    private IGradientBoostedTreesModel RecalculateModel() {
    113109      return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model;
     110    }
     111
     112    public IEnumerable<IRegressionModel> Models {
     113      get {
     114        if (actualModel == null) actualModel = RecalculateModel();
     115        return actualModel.Models;
     116      }
     117    }
     118
     119    public IEnumerable<double> Weights {
     120      get {
     121        if (actualModel == null) actualModel = RecalculateModel();
     122        return actualModel.Weights;
     123      }
    114124    }
    115125  }
Note: See TracChangeset for help on using the changeset viewer.