
Timestamp: 08/17/15 18:33:31 (9 years ago)
Author: gkronber
Message: #2434 reverse merge of r12871 (changes should be applied directly to trunk)

File: 1 edited

  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

--- GradientBoostedTreesAlgorithm.cs (r12871)
+++ GradientBoostedTreesAlgorithm.cs (r12872)

@@ -82,6 +82,6 @@
       get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     }
-    public IConstrainedValueParameter<ILossFunction> LossFunctionParameter {
-      get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; }
+    public IConstrainedValueParameter<StringValue> LossFunctionParameter {
+      get { return (IConstrainedValueParameter<StringValue>)Parameters[LossFunctionParameterName]; }
     }
     public IFixedValueParameter<IntValue> UpdateIntervalParameter {
@@ -164,32 +164,9 @@
       Parameters[CreateSolutionParameterName].Hidden = true;
 
-      var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
-      Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
-      LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default
-    }
-
-    [StorableHook(HookType.AfterDeserialization)]
-    private void AfterDeserialization() {
-      // BackwardsCompatibility3.4
-      #region Backwards compatible code, remove with 3.5
-      // parameter type has been changed
-      var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>;
-      if (lossFunctionParam != null) {
-        Parameters.Remove(LossFunctionParameterName);
-        var selectedValue = lossFunctionParam.Value; // to be restored below
-
-        var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
-        Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
-        // try to restore selected value
-        var selectedLossFunction =
-          LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value);
-        if (selectedLossFunction != null) {
-          LossFunctionParameter.Value = selectedLossFunction;
-        } else {
-          LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE
-        }
-      }
-      #endregion
-    }
+      var lossFunctionNames = ApplicationManager.Manager.GetInstances<ILossFunction>().Select(l => new StringValue(l.ToString()).AsReadOnly());
+      Parameters.Add(new ConstrainedValueParameter<StringValue>(LossFunctionParameterName, "The loss function", new ItemSet<StringValue>(lossFunctionNames)));
+      LossFunctionParameter.ActualValue = LossFunctionParameter.ValidValues.First(l => l.Value.Contains("Squared")); // squared error loss is the default
+    }
+
 
     protected override void Run(CancellationToken cancellationToken) {
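Note on the restored constructor in the hunk above: r12872 keeps only the display names of the available ILossFunction implementations in the parameter, as read-only StringValue choices, and selects the squared-error loss by name as the default. The following standalone C# sketch mirrors that idea outside of HeuristicLab; ILossFunction, SquaredErrorLoss and AbsoluteErrorLoss are illustrative stand-ins, and reflection over the current assembly stands in for ApplicationManager.Manager.GetInstances<ILossFunction>().

using System;
using System.Collections.Generic;
using System.Linq;

// Stand-ins for the HeuristicLab types; only the naming pattern matters here.
public interface ILossFunction { }
public class SquaredErrorLoss : ILossFunction { public override string ToString() => "Squared error loss"; }
public class AbsoluteErrorLoss : ILossFunction { public override string ToString() => "Absolute error loss"; }

public static class LossFunctionChoices {
  // Discover all ILossFunction implementations and return their display names,
  // analogous to GetInstances<ILossFunction>() followed by ToString().
  public static IReadOnlyList<string> DiscoverNames() {
    return typeof(ILossFunction).Assembly.GetTypes()
      .Where(t => typeof(ILossFunction).IsAssignableFrom(t) && t.IsClass && !t.IsAbstract)
      .Select(t => ((ILossFunction)Activator.CreateInstance(t)).ToString())
      .OrderBy(name => name)
      .ToList();
  }

  public static void Main() {
    var names = DiscoverNames();
    // Default to the squared-error loss by name, as the constructor above does.
    var selected = names.First(n => n.Contains("Squared"));
    Console.WriteLine("choices: " + string.Join(", ", names) + "; default: " + selected);
  }
}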
     
@@ -210,5 +187,6 @@
       // init
       var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
-      var lossFunction = LossFunctionParameter.Value;
+      var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>()
+        .Single(l => l.ToString() == LossFunctionParameter.Value.Value);
       var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);
 
@@ -255,5 +233,5 @@
       // produce solution
       if (CreateSolution) {
-        var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction,
+        var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction.ToString(),
           Iterations, MaxSize, R, M, Nu, state.GetModel());
 
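Note on the Run/CreateSolution changes above: since the parameter now stores only a name, the algorithm resolves the selected string back to a concrete loss function before building the GBM state, and the surrogate model is given the name (lossFunction.ToString()) rather than the instance. A minimal standalone sketch of that run-time lookup, with a plain dictionary standing in for the plugin manager (names and delegate-based losses are assumptions for illustration only):

using System;
using System.Collections.Generic;

public static class RunTimeLossLookup {
  public static void Main() {
    // Name -> implementation; in the algorithm the instances come from the plugin
    // manager and the selected name from the StringValue parameter.
    var available = new Dictionary<string, Func<double, double, double>> {
      ["Squared error loss"] = (y, yhat) => (y - yhat) * (y - yhat),
      ["Absolute error loss"] = (y, yhat) => Math.Abs(y - yhat),
    };

    var selectedName = "Squared error loss";   // what the loss-function parameter stores
    if (!available.TryGetValue(selectedName, out var lossFunction))
      throw new InvalidOperationException("Unknown loss function: " + selectedName);

    Console.WriteLine(lossFunction(3.0, 2.5));  // resolved implementation used for training
    Console.WriteLine(selectedName);            // only the name is handed on to the surrogate model
  }
}

The trade-off illustrated here is that only a string has to be kept with the stored run, while the concrete implementation is looked up again each time the algorithm executes.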