Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/09/15 16:49:00 (9 years ago)
Author:
gkronber
Message:

#2261 hiding internals of GbmState

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r12697 r12698  
    4242    // GbmState details are private API users can only use methods from IGbmState
    4343    private class GbmState : IGbmState {
    44       internal IRegressionProblemData problemData { get; set; }
    45       internal MersenneTwister random { get; private set; }
    46       internal ILossFunction lossFunction { get; set; }
    47       internal int maxSize { get; set; }
    48       internal double nu { get; set; }
    49       internal double r { get; set; }
    50       internal double m { get; set; }
    51       internal readonly RegressionTreeBuilder treeBuilder;
     44      internal IRegressionProblemData problemData { get; private set; }
     45      internal ILossFunction lossFunction { get; private set; }
     46      internal int maxSize { get; private set; }
     47      internal double nu { get; private set; }
     48      internal double r { get; private set; }
     49      internal double m { get; private set; }
     50      internal RegressionTreeBuilder treeBuilder { get; private set; }
    5251
     52      private MersenneTwister random { get; set; }
    5353
    5454      // array members (allocate only once)
     
    5959      internal double[] pseudoRes;
    6060
    61       internal IList<IRegressionModel> models;
    62       internal IList<double> weights;
     61      private readonly IList<IRegressionModel> models;
     62      private readonly IList<double> weights;
    6363
    6464      public GbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxSize, double r, double m, double nu) {
     
    110110        return lossFunction.GetLoss(yTest, predTest) / nRows;
    111111      }
     112
     113      internal void AddModel(IRegressionModel m, double weight) {
     114        models.Add(m);
     115        weights.Add(weight);
     116      }
    112117    }
    113118
     
    119124      Contract.Assert(nu <= 1.0);
    120125
    121       var state = (GbmState)CreateGbmState(problemData, lossFunction, randSeed);
    122       state.maxSize = maxSize;
    123       state.r = r;
    124       state.m = m;
    125       state.nu = nu;
     126      var state = (GbmState)CreateGbmState(problemData, lossFunction, randSeed, maxSize, r, m, nu);
    126127
    127128      for (int iter = 0; iter < maxIterations; iter++) {
     
    129130      }
    130131
    131       var model = new GradientBoostedTreesModel(state.models, state.weights);
     132      var model = state.GetModel();
    132133      return new RegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    133134    }
     
    138139    }
    139140
    140     // use default settings for maxDepth, nu, r from state
     141    // use default settings for maxSize, nu, r from state
    141142    public static void MakeStep(IGbmState state) {
    142143      var gbmState = state as GbmState;
     
    146147    }
    147148
    148     // allow dynamic adaptation of maxDepth, nu and r (even though this is not used)
     149    // allow dynamic adaptation of maxSize, nu and r (even though this is not used)
    149150    public static void MakeStep(IGbmState state, int maxSize, double nu, double r, double m) {
    150151      var gbmState = state as GbmState;
     
    160161      var pseudoRes = gbmState.pseudoRes;
    161162
    162       // copy output of gradient function to pre-allocated rim array (pseudo-residuals)
     163      // copy output of gradient function to pre-allocated rim array (pseudo-residual per row and model)
    163164      int rimIdx = 0;
    164165      foreach (var g in lossFunction.GetLossGradient(y, yPred)) {
     
    181182      }
    182183
    183       gbmState.weights.Add(nu);
    184       gbmState.models.Add(tree);
     184      gbmState.AddModel(tree, nu);
    185185    }
    186186    #endregion
Note: See TracChangeset for help on using the changeset viewer.