Changeset 12868


Ignore:
Timestamp:
08/17/15 16:11:47 (4 years ago)
Author:
gkronber
Message:

#2450: introduced surrogate for GBT-models which recalculates the actual model on demand to improve persistence of GBT solutions

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
1 added
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r12632 r12868  
    233233      // produce solution
    234234      if (CreateSolution) {
     235        var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction.ToString(),
     236          Iterations, MaxSize, R, M, Nu, state.GetModel());
     237
    235238        // for logistic regression we produce a classification solution
    236239        if (lossFunction is LogisticRegressionLoss) {
    237           var model = new DiscriminantFunctionClassificationModel(state.GetModel(),
     240          var model = new DiscriminantFunctionClassificationModel(surrogateModel,
    238241            new AccuracyMaximizationThresholdCalculator());
    239242          var classificationProblemData = new ClassificationProblemData(problemData.Dataset,
     
    245248        } else {
    246249          // otherwise we produce a regression solution
    247           Results.Add(new Result("Solution", new RegressionSolution(state.GetModel(), problemData)));
     250          Results.Add(new Result("Solution", new RegressionSolution(surrogateModel, problemData)));
    248251        }
    249252      }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs

    r12660 r12868  
    3434  // this is essentially a collection of weighted regression models
    3535  public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {
    36     [Storable]
     36    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
     37    #region Backwards compatible code, remove with 3.5
     38    private bool isCompatibilityLoaded = false; // only set to true if the model is deserialized from the old format, needed to make sure that information is serialized again if it was loaded from the old format
     39
     40    [Storable(Name = "models")]
     41    private IList<IRegressionModel> __persistedModels {
     42      set {
     43        this.isCompatibilityLoaded = true;
     44        this.models.Clear();
     45        foreach (var m in value) this.models.Add(m);
     46      }
     47      get { if (this.isCompatibilityLoaded) return models; else return null; }
     48    }
     49    [Storable(Name = "weights")]
     50    private IList<double> __persistedWeights {
     51      set {
     52        this.isCompatibilityLoaded = true;
     53        this.weights.Clear();
     54        foreach (var w in value) this.weights.Add(w);
     55      }
     56      get { if (this.isCompatibilityLoaded) return weights; else return null; }
     57    }
     58    #endregion
     59
    3760    private readonly IList<IRegressionModel> models;
    3861    public IEnumerable<IRegressionModel> Models { get { return models; } }
    3962
    40     [Storable]
    4163    private readonly IList<double> weights;
    4264    public IEnumerable<double> Weights { get { return weights; } }
    4365
    4466    [StorableConstructor]
    45     private GradientBoostedTreesModel(bool deserializing) : base(deserializing) { }
     67    private GradientBoostedTreesModel(bool deserializing)
     68      : base(deserializing) {
     69      models = new List<IRegressionModel>();
     70      weights = new List<double>();
     71    }
    4672    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
    4773      : base(original, cloner) {
    4874      this.weights = new List<double>(original.weights);
    4975      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
     76      this.isCompatibilityLoaded = original.isCompatibilityLoaded;
    5077    }
    5178    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
     
    6491      // allocate target array go over all models and add up weighted estimation for each row
    6592      if (!rows.Any()) return Enumerable.Empty<double>(); // return immediately if rows is empty. This prevents multiple iteration over lazy rows enumerable.
    66                                                           // (which essentially looks up indexes in a dictionary)
     93      // (which essentially looks up indexes in a dictionary)
    6794      var res = new double[rows.Count()];
    6895      for (int i = 0; i < models.Count; i++) {
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r12700 r12868  
    8282    [Storable]
    8383    // to prevent storing the references to data caches in nodes
     84    // seemingly it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary) TODO
    8485    private Tuple<string, double, int, int>[] SerializedTree {
    8586      get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r12817 r12868  
    195195    <Compile Include="GaussianProcess\GaussianProcessRegressionSolution.cs" />
    196196    <Compile Include="GaussianProcess\ICovarianceFunction.cs" />
     197    <Compile Include="GradientBoostedTrees\GradientBoostedTreesModelSurrogate.cs" />
    197198    <Compile Include="GradientBoostedTrees\GradientBoostedTreesAlgorithm.cs" />
    198199    <Compile Include="GradientBoostedTrees\GradientBoostedTreesAlgorithmStatic.cs" />
Note: See TracChangeset for help on using the changeset viewer.