Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/17/15 18:38:17 (9 years ago)
Author:
gkronber
Message:

#2434: merged r12873 from trunk to branch

Location:
branches/crossvalidation-2434
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/crossvalidation-2434

  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis

  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r12872 → r12874 (old revision → new revision)
    8282      get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    8383    }
    84     public IConstrainedValueParameter<StringValue> LossFunctionParameter {
    85       get { return (IConstrainedValueParameter<StringValue>)Parameters[LossFunctionParameterName]; }
     84    public IConstrainedValueParameter<ILossFunction> LossFunctionParameter {
     85      get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; }
    8686    }
    8787    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
     
    164164      Parameters[CreateSolutionParameterName].Hidden = true;
    165165
    166       var lossFunctionNames = ApplicationManager.Manager.GetInstances<ILossFunction>().Select(l => new StringValue(l.ToString()).AsReadOnly());
    167       Parameters.Add(new ConstrainedValueParameter<StringValue>(LossFunctionParameterName, "The loss function", new ItemSet<StringValue>(lossFunctionNames)));
    168       LossFunctionParameter.ActualValue = LossFunctionParameter.ValidValues.First(l => l.Value.Contains("Squared")); // squared error loss is the default
    169     }
    170 
     166      var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
     167      Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
     168      LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default
     169    }
     170
     171    [StorableHook(HookType.AfterDeserialization)]
     172    private void AfterDeserialization() {
     173      // BackwardsCompatibility3.4
     174      #region Backwards compatible code, remove with 3.5
     175      // parameter type has been changed
     176      var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>;
     177      if (lossFunctionParam != null) {
     178        Parameters.Remove(LossFunctionParameterName);
     179        var selectedValue = lossFunctionParam.Value; // to be restored below
     180
     181        var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
     182        Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
     183        // try to restore selected value
     184        var selectedLossFunction =
     185          LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value);
     186        if (selectedLossFunction != null) {
     187          LossFunctionParameter.Value = selectedLossFunction;
     188        } else {
     189          LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE
     190        }
     191      }
     192      #endregion
     193    }
    171194
    172195    protected override void Run(CancellationToken cancellationToken) {
     
    187210      // init
    188211      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    189       var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>()
    190         .Single(l => l.ToString() == LossFunctionParameter.Value.Value);
     212      var lossFunction = LossFunctionParameter.Value;
    191213      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);
    192214
     
    233255      // produce solution
    234256      if (CreateSolution) {
    235         var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction.ToString(),
     257        var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction,
    236258          Iterations, MaxSize, R, M, Nu, state.GetModel());
    237259
Note: See TracChangeset for help on using the changeset viewer.