Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/10/19 17:39:38 (5 years ago)
Author:
gkronber
Message:

#2994: merged r17007:17118 from trunk to branch

Location:
branches/2994-AutoDiffForIntervals
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2994-AutoDiffForIntervals

  • branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis

  • branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis/3.4

  • branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r16565 r17120  
    2121#endregion
    2222
     23using System;
    2324using System.Linq;
    2425using System.Threading;
     26using HeuristicLab.Algorithms.DataAnalysis.GradientBoostedTrees;
    2527using HeuristicLab.Analysis;
    2628using HeuristicLab.Common;
     
    4850    private const string LossFunctionParameterName = "LossFunction";
    4951    private const string UpdateIntervalParameterName = "UpdateInterval";
    50     private const string CreateSolutionParameterName = "CreateSolution";
     52    private const string ModelCreationParameterName = "ModelCreation";
    5153    #endregion
    5254
     
    7981      get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
    8082    }
    81     public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    82       get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
     83    private IFixedValueParameter<EnumValue<ModelCreation>> ModelCreationParameter {
     84      get { return (IFixedValueParameter<EnumValue<ModelCreation>>)Parameters[ModelCreationParameterName]; }
    8385    }
    8486    #endregion
     
    113115      set { MParameter.Value.Value = value; }
    114116    }
    115     public bool CreateSolution {
    116       get { return CreateSolutionParameter.Value.Value; }
    117       set { CreateSolutionParameter.Value.Value = value; }
     117    public ModelCreation ModelCreation {
     118      get { return ModelCreationParameter.Value.Value; }
     119      set { ModelCreationParameter.Value.Value = value; }
    118120    }
    119121    #endregion
     
    146148      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    147149      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    148       Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, "Maximal size of the tree learned in each step (prefer smaller sizes if possible)", new IntValue(10)));
     150      Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, "Maximal size of the tree learned in each step (prefer smaller sizes (3 to 10) if possible)", new IntValue(10)));
    149151      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "Ratio of training rows selected randomly in each step (0 < R <= 1)", new DoubleValue(0.5)));
    150152      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "Ratio of variables selected randomly in each step (0 < M <= 1)", new DoubleValue(0.5)));
     
    152154      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "", new IntValue(100)));
    153155      Parameters[UpdateIntervalParameterName].Hidden = true;
    154       Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    155       Parameters[CreateSolutionParameterName].Hidden = true;
     156      Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(ModelCreation.Model)));
     157      Parameters[ModelCreationParameterName].Hidden = true;
    156158
    157159      var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
     
    164166      // BackwardsCompatibility3.4
    165167      #region Backwards compatible code, remove with 3.5
     168
     169      #region LossFunction
    166170      // parameter type has been changed
    167171      var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>;
     
    182186      }
    183187      #endregion
     188
     189      #region CreateSolution
     190      // parameter type has been changed
     191      if (Parameters.ContainsKey("CreateSolution")) {
     192        var createSolutionParam = Parameters["CreateSolution"] as FixedValueParameter<BoolValue>;
     193        Parameters.Remove(createSolutionParam);
     194
     195        ModelCreation value = createSolutionParam.Value.Value ? ModelCreation.Model : ModelCreation.QualityOnly;
     196        Parameters.Add(new FixedValueParameter<EnumValue<ModelCreation>>(ModelCreationParameterName, "Defines the results produced at the end of the run (Surrogate => Less disk space, lazy recalculation of model)", new EnumValue<ModelCreation>(value)));
     197        Parameters[ModelCreationParameterName].Hidden = true;
     198      }
     199      #endregion
     200      #endregion
    184201    }
    185202
     
    248265
    249266      // produce solution
    250       if (CreateSolution) {
    251         var model = state.GetModel();
     267      if (ModelCreation == ModelCreation.SurrogateModel || ModelCreation == ModelCreation.Model) {
     268        IRegressionModel model = state.GetModel();
     269
     270        if (ModelCreation == ModelCreation.SurrogateModel) {
     271          model = new GradientBoostedTreesModelSurrogate((GradientBoostedTreesModel)model, problemData, (uint)Seed, lossFunction, Iterations, MaxSize, R, M, Nu);
     272        }
    252273
    253274        // for logistic regression we produce a classification solution
     
    271292          Results.Add(new Result("Solution", new GradientBoostedTreesSolution(model, problemData)));
    272293        }
     294      } else if (ModelCreation == ModelCreation.QualityOnly) {
     295        //Do nothing
     296      } else {
     297        throw new NotImplementedException("Selected parameter for CreateSolution isn't implemented yet");
    273298      }
    274299    }
Note: See TracChangeset for help on using the changeset viewer.