Changeset 12698 for branches/GBT-trunkintegration
- Timestamp:
- 07/09/15 16:49:00 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs
r12697 r12698 42 42 // GbmState details are private API users can only use methods from IGbmState 43 43 private class GbmState : IGbmState { 44 internal IRegressionProblemData problemData { get; set; } 45 internal MersenneTwister random { get; private set; } 46 internal ILossFunction lossFunction { get; set; } 47 internal int maxSize { get; set; } 48 internal double nu { get; set; } 49 internal double r { get; set; } 50 internal double m { get; set; } 51 internal readonly RegressionTreeBuilder treeBuilder; 44 internal IRegressionProblemData problemData { get; private set; } 45 internal ILossFunction lossFunction { get; private set; } 46 internal int maxSize { get; private set; } 47 internal double nu { get; private set; } 48 internal double r { get; private set; } 49 internal double m { get; private set; } 50 internal RegressionTreeBuilder treeBuilder { get; private set; } 52 51 52 private MersenneTwister random { get; set; } 53 53 54 54 // array members (allocate only once) … … 59 59 internal double[] pseudoRes; 60 60 61 internalIList<IRegressionModel> models;62 internalIList<double> weights;61 private readonly IList<IRegressionModel> models; 62 private readonly IList<double> weights; 63 63 64 64 public GbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxSize, double r, double m, double nu) { … … 110 110 return lossFunction.GetLoss(yTest, predTest) / nRows; 111 111 } 112 113 internal void AddModel(IRegressionModel m, double weight) { 114 models.Add(m); 115 weights.Add(weight); 116 } 112 117 } 113 118 … … 119 124 Contract.Assert(nu <= 1.0); 120 125 121 var state = (GbmState)CreateGbmState(problemData, lossFunction, randSeed); 122 state.maxSize = maxSize; 123 state.r = r; 124 state.m = m; 125 state.nu = nu; 126 var state = (GbmState)CreateGbmState(problemData, lossFunction, randSeed, maxSize, r, m, nu); 126 127 127 128 for (int iter = 0; iter < maxIterations; iter++) { … … 129 130 } 130 131 131 var model = new GradientBoostedTreesModel(state.models, state.weights);132 var model = state.GetModel(); 132 133 return new RegressionSolution(model, (IRegressionProblemData)problemData.Clone()); 133 134 } … … 138 139 } 139 140 140 // use default settings for max Depth, nu, r from state141 // use default settings for maxSize, nu, r from state 141 142 public static void MakeStep(IGbmState state) { 142 143 var gbmState = state as GbmState; … … 146 147 } 147 148 148 // allow dynamic adaptation of max Depth, nu and r (even though this is not used)149 // allow dynamic adaptation of maxSize, nu and r (even though this is not used) 149 150 public static void MakeStep(IGbmState state, int maxSize, double nu, double r, double m) { 150 151 var gbmState = state as GbmState; … … 160 161 var pseudoRes = gbmState.pseudoRes; 161 162 162 // copy output of gradient function to pre-allocated rim array (pseudo-residual s)163 // copy output of gradient function to pre-allocated rim array (pseudo-residual per row and model) 163 164 int rimIdx = 0; 164 165 foreach (var g in lossFunction.GetLossGradient(y, yPred)) { … … 181 182 } 182 183 183 gbmState.weights.Add(nu); 184 gbmState.models.Add(tree); 184 gbmState.AddModel(tree, nu); 185 185 } 186 186 #endregion
Note: See TracChangeset
for help on using the changeset viewer.