Changeset 13184 for stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
- Timestamp:
- 11/16/15 19:49:40 (8 years ago)
- Location:
- stable
- Files:
-
- 2 edited
- 1 copied
Legend:
- Unmodified
- Added
- Removed
-
stable
- Property svn:mergeinfo changed
/trunk/sources merged: 12868,12873,12875,13065-13066,13157-13158
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Algorithms.DataAnalysis merged: 12868,12873,12875,13065-13066,13157-13158
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
r12868 r13184 21 21 #endregion 22 22 23 using System;24 23 using System.Collections.Generic; 25 using System.Linq;26 24 using HeuristicLab.Common; 27 25 using HeuristicLab.Core; 28 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 using HeuristicLab.PluginInfrastructure;30 27 using HeuristicLab.Problems.DataAnalysis; 31 28 … … 36 33 // recalculate the actual GBT model on demand 37 34 [Item("Gradient boosted tree model", "")] 38 public sealed class GradientBoostedTreesModelSurrogate : NamedItem, I RegressionModel {35 public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IGradientBoostedTreesModel { 39 36 // don't store the actual model! 40 private I RegressionModel actualModel; // the actual model is only recalculated when necessary37 private IGradientBoostedTreesModel actualModel; // the actual model is only recalculated when necessary 41 38 42 39 [Storable] … … 45 42 private readonly uint seed; 46 43 [Storable] 47 private string lossFunctionName;44 private ILossFunction lossFunction; 48 45 [Storable] 49 46 private double r; … … 66 63 67 64 this.trainingProblemData = cloner.Clone(original.trainingProblemData); 65 this.lossFunction = cloner.Clone(original.lossFunction); 68 66 this.seed = original.seed; 69 this.lossFunctionName = original.lossFunctionName;70 67 this.iterations = original.iterations; 71 68 this.maxSize = original.maxSize; … … 76 73 77 74 // create only the surrogate model without an actual model 78 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu)75 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu) 79 76 : base("Gradient boosted tree model", string.Empty) { 80 77 this.trainingProblemData = trainingProblemData; 81 78 this.seed = seed; 82 this.lossFunction Name = lossFunctionName;79 this.lossFunction = lossFunction; 83 80 this.iterations = iterations; 84 81 this.maxSize = maxSize; … … 89 86 90 87 // wrap an actual model in a surrograte 91 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)92 : this(trainingProblemData, seed, lossFunction Name, iterations, maxSize, r, m, nu) {88 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IGradientBoostedTreesModel model) 89 : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) { 93 90 this.actualModel = model; 94 91 } … … 109 106 110 107 111 private IRegressionModel RecalculateModel() { 112 var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>().Single(l => l.ToString() == lossFunctionName); 108 private IGradientBoostedTreesModel RecalculateModel() { 113 109 return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model; 110 } 111 112 public IEnumerable<IRegressionModel> Models { 113 get { 114 if (actualModel == null) actualModel = RecalculateModel(); 115 return actualModel.Models; 116 } 117 } 118 119 public IEnumerable<double> Weights { 120 get { 121 if (actualModel == null) actualModel = RecalculateModel(); 122 return actualModel.Weights; 123 } 114 124 } 115 125 }
Note: See TracChangeset
for help on using the changeset viewer.