Changeset 14029 for branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
- Timestamp:
- 07/08/16 14:40:02 (8 years ago)
- Location:
- branches/crossvalidation-2434
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
r12874 r14029 21 21 #endregion 22 22 23 using System;24 23 using System.Collections.Generic; 25 24 using System.Linq; … … 27 26 using HeuristicLab.Core; 28 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 using HeuristicLab.PluginInfrastructure;30 28 using HeuristicLab.Problems.DataAnalysis; 31 29 … … 36 34 // recalculate the actual GBT model on demand 37 35 [Item("Gradient boosted tree model", "")] 38 public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IRegressionModel {36 public sealed class GradientBoostedTreesModelSurrogate : RegressionModel, IGradientBoostedTreesModel { 39 37 // don't store the actual model! 40 private I RegressionModel actualModel; // the actual model is only recalculated when necessary38 private IGradientBoostedTreesModel actualModel; // the actual model is only recalculated when necessary 41 39 42 40 [Storable] … … 58 56 59 57 58 public override IEnumerable<string> VariablesUsedForPrediction { 59 get { return actualModel.Models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); } 60 } 61 60 62 [StorableConstructor] 61 63 private GradientBoostedTreesModelSurrogate(bool deserializing) : base(deserializing) { } … … 76 78 77 79 // create only the surrogate model without an actual model 78 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu) 79 : base("Gradient boosted tree model", string.Empty) { 80 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, 81 ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu) 82 : base(trainingProblemData.TargetVariable, "Gradient boosted tree model", string.Empty) { 80 83 this.trainingProblemData = trainingProblemData; 81 84 this.seed = seed; … … 89 92 90 93 // wrap an actual model in a surrograte 91 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model) 94 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, 95 ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, 96 IGradientBoostedTreesModel model) 92 97 : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) { 93 98 this.actualModel = model; … … 99 104 100 105 // forward message to actual model (recalculate model first if necessary) 101 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {106 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 102 107 if (actualModel == null) actualModel = RecalculateModel(); 103 108 return actualModel.GetEstimatedValues(dataset, rows); 104 109 } 105 110 106 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {111 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 107 112 return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone()); 108 113 } 109 114 115 private IGradientBoostedTreesModel RecalculateModel() { 116 return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model; 117 } 110 118 111 private IRegressionModel RecalculateModel() { 112 return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model; 119 public IEnumerable<IRegressionModel> Models { 120 get { 121 if (actualModel == null) actualModel = RecalculateModel(); 122 return actualModel.Models; 123 } 124 } 125 126 public IEnumerable<double> Weights { 127 get { 128 if (actualModel == null) actualModel = RecalculateModel(); 129 return actualModel.Weights; 130 } 113 131 } 114 132 }
Note: See TracChangeset
for help on using the changeset viewer.