Changeset 17030
- Timestamp:
- 06/25/19 17:36:07 (5 years ago)
- Location:
- trunk/HeuristicLab.Algorithms.DataAnalysis/3.4
- Files:
-
- 1 added
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
r16565 r17030 21 21 #endregion 22 22 23 using System; 23 24 using System.Linq; 24 25 using System.Threading; 26 using HeuristicLab.Algorithms.DataAnalysis.GradientBoostedTrees; 25 27 using HeuristicLab.Analysis; 26 28 using HeuristicLab.Common; … … 48 50 private const string LossFunctionParameterName = "LossFunction"; 49 51 private const string UpdateIntervalParameterName = "UpdateInterval"; 50 private const string CreateSolutionParameterName = "CreateSolution";52 private const string ModelCreationParameterName = "ModelCreation"; 51 53 #endregion 52 54 … … 79 81 get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; } 80 82 } 81 p ublic IFixedValueParameter<BoolValue> CreateSolutionParameter {82 get { return (IFixedValueParameter< BoolValue>)Parameters[CreateSolutionParameterName]; }83 private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter { 84 get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; } 83 85 } 84 86 #endregion … … 113 115 set { MParameter.Value.Value = value; } 114 116 } 115 public bool CreateSolution {116 get { return CreateSolutionParameter.Value.Value; }117 set { CreateSolutionParameter.Value.Value = value; }117 public ModelCreation ModelCreation { 118 get { return ModelCreationParameter.Value.Value; } 119 set { ModelCreationParameter.Value.Value = value; } 118 120 } 119 121 #endregion … … 146 148 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0))); 147 149 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true))); 148 Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, "Maximal size of the tree learned in each step (prefer smaller sizes if possible)", new IntValue(10)));150 Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, "Maximal size of the tree learned in each step (prefer smaller sizes (3 to 10) if possible)", new IntValue(10))); 149 151 Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "Ratio of training rows selected randomly in each step (0 < R <= 1)", new DoubleValue(0.5))); 150 152 Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "Ratio of variables selected randomly in each step (0 < M <= 1)", new DoubleValue(0.5))); … … 152 154 Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "", new IntValue(100))); 153 155 Parameters[UpdateIntervalParameterName].Hidden = true; 154 Parameters.Add(new FixedValueParameter< BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));155 Parameters[ CreateSolutionParameterName].Hidden = true;156 Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model))); 157 Parameters[ModelCreationParameterName].Hidden = true; 156 158 157 159 var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>(); … … 164 166 // BackwardsCompatibility3.4 165 167 #region Backwards compatible code, remove with 3.5 168 169 #region LossFunction 166 170 // parameter type has been changed 167 171 var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>; … … 182 186 } 183 187 #endregion 188 189 #region CreateSolution 190 // parameter type has been changed 191 if (Parameters.ContainsKey("CreateSolution")) { 192 var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>; 193 Parameters.Remove(createSolutionParam); 194 195 ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly; 196 Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value))); 197 Parameters[ModelCreationParameterName].Hidden = true; 198 } 199 #endregion 200 #endregion 184 201 } 185 202 … … 248 265 249 266 // produce solution 250 if (CreateSolution) { 251 var model = state.GetModel(); 267 if (ModelCreation == ModelCreation.SurrogateModel || ModelCreation == ModelCreation.Model) { 268 IRegressionModel model = state.GetModel(); 269 270 if (ModelCreation == ModelCreation.SurrogateModel) { 271 model = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction, Iterations, MaxSize, R, M, Nu, (GradientBoostedTreesModel)model); 272 } 252 273 253 274 // for logistic regression we produce a classification solution … … 271 292 Results.Add(new Result("Solution", new GradientBoostedTreesSolution(model, problemData))); 272 293 } 294 } else if (ModelCreation == ModelCreation.QualityOnly) { 295 //Do nothing 296 } else { 297 throw new NotImplementedException("Selected parameter for CreateSolution isn't implemented yet"); 273 298 } 274 299 } -
trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs
r16565 r17030 101 101 102 102 public IRegressionModel GetModel() { 103 #pragma warning disable 618 104 var model = new GradientBoostedTreesModel(models, weights); 105 #pragma warning restore 618 106 // we don't know the number of iterations here but the number of weights is equal 107 // to the number of iterations + 1 (for the constant model) 108 // wrap the actual model in a surrogate that enables persistence and lazy recalculation of the model if necessary 109 return new GradientBoostedTreesModelSurrogate(problemData, randSeed, lossFunction, weights.Count - 1, maxSize, r, m, nu, model); 103 return new GradientBoostedTreesModel(models, weights); 110 104 } 111 105 public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() { -
trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs
r16565 r17030 36 36 // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models 37 37 #region Backwards compatible code, remove with 3.5 38 private bool isCompatibilityLoaded = false; // only set to true if the model is deserialized from the old format, needed to make sure that information is serialized again if it was loaded from the old format39 38 40 39 [Storable(Name = "models")] 41 40 private IList<IRegressionModel> __persistedModels { 42 41 set { 43 this.isCompatibilityLoaded = true;44 42 this.models.Clear(); 45 43 foreach (var m in value) this.models.Add(m); 46 44 } 47 get { if (this.isCompatibilityLoaded) return models; else return null; }45 get { return models; } 48 46 } 49 47 [Storable(Name = "weights")] 50 48 private IList<double> __persistedWeights { 51 49 set { 52 this.isCompatibilityLoaded = true;53 50 this.weights.Clear(); 54 51 foreach (var w in value) this.weights.Add(w); 55 52 } 56 get { if (this.isCompatibilityLoaded) return weights; else return null; }53 get { return weights; } 57 54 } 58 55 #endregion … … 77 74 this.weights = new List<double>(original.weights); 78 75 this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m))); 79 this.isCompatibilityLoaded = original.isCompatibilityLoaded;80 76 } 81 [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")] 77 82 78 internal GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) 83 79 : base(string.Empty, "Gradient boosted tree model", string.Empty) { -
trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj
r16658 r17030 217 217 <Compile Include="GradientBoostedTrees\LossFunctions\SquaredErrorLoss.cs" /> 218 218 <Compile Include="GradientBoostedTrees\GradientBoostedTreesSolution.cs" /> 219 <Compile Include="GradientBoostedTrees\ModelCreation.cs" /> 219 220 <Compile Include="GradientBoostedTrees\RegressionTreeBuilder.cs" /> 220 221 <Compile Include="GradientBoostedTrees\RegressionTreeModel.cs" />
Note: See TracChangeset
for help on using the changeset viewer.