Changeset 13065


Ignore:
Timestamp:
10/26/15 20:44:41 (5 years ago)
Author:
gkronber
Message:

#2450: Marked the constructor of GBTModel as obsolete and wrapped GBTModels in GBTModelSurrogates where necessary in the API. Removed an unused internal method (RegressionTreeBuilder.CreateRegressionTree) from the API.

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r12873 r13065  
    255255      // produce solution
    256256      if (CreateSolution) {
    257         var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction,
    258           Iterations, MaxSize, R, M, Nu, state.GetModel());
     257        var model = state.GetModel();
    259258
    260259        // for logistic regression we produce a classification solution
    261260        if (lossFunction is LogisticRegressionLoss) {
    262           var model = new DiscriminantFunctionClassificationModel(surrogateModel,
     261          var classificationModel = new DiscriminantFunctionClassificationModel(model,
    263262            new AccuracyMaximizationThresholdCalculator());
    264263          var classificationProblemData = new ClassificationProblemData(problemData.Dataset,
    265264            problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations);
    266           model.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);
    267 
    268           var classificationSolution = new DiscriminantFunctionClassificationSolution(model, classificationProblemData);
     265          classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);
     266
     267          var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData);
    269268          Results.Add(new Result("Solution", classificationSolution));
    270269        } else {
    271270          // otherwise we produce a regression solution
    272           Results.Add(new Result("Solution", new RegressionSolution(surrogateModel, problemData)));
     271          Results.Add(new Result("Solution", new RegressionSolution(model, problemData)));
    273272        }
    274273      }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r12710 r13065  
    5252      internal RegressionTreeBuilder treeBuilder { get; private set; }
    5353
     54      private readonly uint randSeed;
    5455      private MersenneTwister random { get; set; }
    5556
     
    7172        this.m = m;
    7273
     74        this.randSeed = randSeed;
    7375        random = new MersenneTwister(randSeed);
    7476        this.problemData = problemData;
     
    99101
    100102      public IRegressionModel GetModel() {
    101         return new GradientBoostedTreesModel(models, weights);
     103#pragma warning disable 618
     104        var model = new GradientBoostedTreesModel(models, weights);
     105#pragma warning restore 618
     106        // we don't know the number of iterations here but the number of weights is equal
     107        // to the number of iterations + 1 (for the constant model)
     108        // wrap the actual model in a surrogate that enables persistence and lazy recalculation of the model if necessary
     109        return new GradientBoostedTreesModelSurrogate(problemData, randSeed, lossFunction, weights.Count - 1, maxSize, r, m, nu, model);
    102110      }
    103111      public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() {
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs

    r12868 r13065  
    7676      this.isCompatibilityLoaded = original.isCompatibilityLoaded;
    7777    }
     78    [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")]
    7879    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
    7980      : base("Gradient boosted tree model", string.Empty) {
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs

    r12873 r13065  
    2121#endregion
    2222
    23 using System;
    2423using System.Collections.Generic;
    25 using System.Linq;
    2624using HeuristicLab.Common;
    2725using HeuristicLab.Core;
    2826using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    29 using HeuristicLab.PluginInfrastructure;
    3027using HeuristicLab.Problems.DataAnalysis;
    3128
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12700 r13065  
    119119    }
    120120
    121     // simple API produces a single regression tree optimizing sum of squared errors
    122     // this can be used if only a simple regression tree should be produced
    123     // for a set of trees use the method CreateRegressionTreeForGradientBoosting below
    124     //
    125     // r and m work in the same way as for alglib random forest
    126     // r is fraction of rows to use for training
    127     // m is fraction of variables to use for training
    128     public IRegressionModel CreateRegressionTree(int maxSize, double r = 0.5, double m = 0.5) {
    129       // subtract mean of y first
    130       var yAvg = y.Average();
    131       for (int i = 0; i < y.Length; i++) y[i] -= yAvg;
    132 
    133       var seLoss = new SquaredErrorLoss();
    134 
    135       var model = CreateRegressionTreeForGradientBoosting(y, curPred, maxSize, problemData.TrainingIndices.ToArray(), seLoss, r, m);
    136 
    137       return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
    138     }
    139 
    140121    // specific interface that allows specifying the target labels and the training rows, which is necessary for gradient boosted trees
    141122    public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, double[] curPred, int maxSize, int[] idx, ILossFunction lossFunction, double r = 0.5, double m = 0.5) {
Note: See TracChangeset for help on using the changeset viewer.