Changeset 13065
- Timestamp:
- 10/26/15 20:44:41 (9 years ago)
- Location:
- trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
r12873 r13065 255 255 // produce solution 256 256 if (CreateSolution) { 257 var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction, 258 Iterations, MaxSize, R, M, Nu, state.GetModel()); 257 var model = state.GetModel(); 259 258 260 259 // for logistic regression we produce a classification solution 261 260 if (lossFunction is LogisticRegressionLoss) { 262 var model = new DiscriminantFunctionClassificationModel(surrogateModel,261 var classificationModel = new DiscriminantFunctionClassificationModel(model, 263 262 new AccuracyMaximizationThresholdCalculator()); 264 263 var classificationProblemData = new ClassificationProblemData(problemData.Dataset, 265 264 problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations); 266 model.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);267 268 var classificationSolution = new DiscriminantFunctionClassificationSolution( model, classificationProblemData);265 classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices); 266 267 var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData); 269 268 Results.Add(new Result("Solution", classificationSolution)); 270 269 } else { 271 270 // otherwise we produce a regression solution 272 Results.Add(new Result("Solution", new RegressionSolution( surrogateModel, problemData)));271 Results.Add(new Result("Solution", new RegressionSolution(model, problemData))); 273 272 } 274 273 } -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs
r12710 r13065 52 52 internal RegressionTreeBuilder treeBuilder { get; private set; } 53 53 54 private readonly uint randSeed; 54 55 private MersenneTwister random { get; set; } 55 56 … … 71 72 this.m = m; 72 73 74 this.randSeed = randSeed; 73 75 random = new MersenneTwister(randSeed); 74 76 this.problemData = problemData; … … 99 101 100 102 public IRegressionModel GetModel() { 101 return new GradientBoostedTreesModel(models, weights); 103 #pragma warning disable 618 104 var model = new GradientBoostedTreesModel(models, weights); 105 #pragma warning restore 618 106 // we don't know the number of iterations here but the number of weights is equal 107 // to the number of iterations + 1 (for the constant model) 108 // wrap the actual model in a surrogate that enables persistence and lazy recalculation of the model if necessary 109 return new GradientBoostedTreesModelSurrogate(problemData, randSeed, lossFunction, weights.Count - 1, maxSize, r, m, nu, model); 102 110 } 103 111 public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() { -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs
r12868 r13065 76 76 this.isCompatibilityLoaded = original.isCompatibilityLoaded; 77 77 } 78 [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")] 78 79 public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) 79 80 : base("Gradient boosted tree model", string.Empty) { -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
r12873 r13065 21 21 #endregion 22 22 23 using System;24 23 using System.Collections.Generic; 25 using System.Linq;26 24 using HeuristicLab.Common; 27 25 using HeuristicLab.Core; 28 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 using HeuristicLab.PluginInfrastructure;30 27 using HeuristicLab.Problems.DataAnalysis; 31 28 -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12700 r13065 119 119 } 120 120 121 // simple API produces a single regression tree optimizing sum of squared errors122 // this can be used if only a simple regression tree should be produced123 // for a set of trees use the method CreateRegressionTreeForGradientBoosting below124 //125 // r and m work in the same way as for alglib random forest126 // r is fraction of rows to use for training127 // m is fraction of variables to use for training128 public IRegressionModel CreateRegressionTree(int maxSize, double r = 0.5, double m = 0.5) {129 // subtract mean of y first130 var yAvg = y.Average();131 for (int i = 0; i < y.Length; i++) y[i] -= yAvg;132 133 var seLoss = new SquaredErrorLoss();134 135 var model = CreateRegressionTreeForGradientBoosting(y, curPred, maxSize, problemData.TrainingIndices.ToArray(), seLoss, r, m);136 137 return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });138 }139 140 121 // specific interface that allows to specify the target labels and the training rows which is necessary when for gradient boosted trees 141 122 public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, double[] curPred, int maxSize, int[] idx, ILossFunction lossFunction, double r = 0.5, double m = 0.5) {
Note: See TracChangeset
for help on using the changeset viewer.