Free cookie consent management tool by TermsFeed Policy Generator

Changeset 12872


Ignore:
Timestamp:
08/17/15 18:33:31 (9 years ago)
Author:
gkronber
Message:

#2434 reverse merge of r12871 (changes should be applied directly to trunk)

Location:
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r12871 r12872  
    8282      get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    8383    }
    84     public IConstrainedValueParameter<ILossFunction> LossFunctionParameter {
    85       get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; }
     84    public IConstrainedValueParameter<StringValue> LossFunctionParameter {
     85      get { return (IConstrainedValueParameter<StringValue>)Parameters[LossFunctionParameterName]; }
    8686    }
    8787    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
     
    164164      Parameters[CreateSolutionParameterName].Hidden = true;
    165165
    166       var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
    167       Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
    168       LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default
    169     }
    170 
    171     [StorableHook(HookType.AfterDeserialization)]
    172     private void AfterDeserialization() {
    173       // BackwardsCompatibility3.4
    174       #region Backwards compatible code, remove with 3.5
    175       // parameter type has been changed
    176       var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>;
    177       if (lossFunctionParam != null) {
    178         Parameters.Remove(LossFunctionParameterName);
    179         var selectedValue = lossFunctionParam.Value; // to be restored below
    180 
    181         var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
    182         Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
    183         // try to restore selected value
    184         var selectedLossFunction =
    185           LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value);
    186         if (selectedLossFunction != null) {
    187           LossFunctionParameter.Value = selectedLossFunction;
    188         } else {
    189           LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE
    190         }
    191       }
    192       #endregion
    193     }
     166      var lossFunctionNames = ApplicationManager.Manager.GetInstances<ILossFunction>().Select(l => new StringValue(l.ToString()).AsReadOnly());
     167      Parameters.Add(new ConstrainedValueParameter<StringValue>(LossFunctionParameterName, "The loss function", new ItemSet<StringValue>(lossFunctionNames)));
     168      LossFunctionParameter.ActualValue = LossFunctionParameter.ValidValues.First(l => l.Value.Contains("Squared")); // squared error loss is the default
     169    }
     170
    194171
    195172    protected override void Run(CancellationToken cancellationToken) {
     
    210187      // init
    211188      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    212       var lossFunction = LossFunctionParameter.Value;
     189      var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>()
     190        .Single(l => l.ToString() == LossFunctionParameter.Value.Value);
    213191      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);
    214192
     
    255233      // produce solution
    256234      if (CreateSolution) {
    257         var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction,
     235        var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction.ToString(),
    258236          Iterations, MaxSize, R, M, Nu, state.GetModel());
    259237
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs

    r12871 r12872  
    4545    private readonly uint seed;
    4646    [Storable]
    47     private ILossFunction lossFunction;
     47    private string lossFunctionName;
    4848    [Storable]
    4949    private double r;
     
    6666
    6767      this.trainingProblemData = cloner.Clone(original.trainingProblemData);
    68       this.lossFunction = cloner.Clone(original.lossFunction);
    6968      this.seed = original.seed;
     69      this.lossFunctionName = original.lossFunctionName;
    7070      this.iterations = original.iterations;
    7171      this.maxSize = original.maxSize;
     
    7676
    7777    // create only the surrogate model without an actual model
    78     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)
     78    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu)
    7979      : base("Gradient boosted tree model", string.Empty) {
    8080      this.trainingProblemData = trainingProblemData;
    8181      this.seed = seed;
    82       this.lossFunction = lossFunction;
     82      this.lossFunctionName = lossFunctionName;
    8383      this.iterations = iterations;
    8484      this.maxSize = maxSize;
     
    8989
    9090    // wrap an actual model in a surrograte
    91     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)
    92       : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) {
     91    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)
     92      : this(trainingProblemData, seed, lossFunctionName, iterations, maxSize, r, m, nu) {
    9393      this.actualModel = model;
    9494    }
     
    110110
    111111    private IRegressionModel RecalculateModel() {
     112      var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>().Single(l => l.ToString() == lossFunctionName);
    112113      return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model;
    113114    }
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs

    r12871 r12872  
    2323using System;
    2424using System.Collections.Generic;
     25using System.Diagnostics;
     26using System.Linq;
    2527using HeuristicLab.Common;
    26 using HeuristicLab.Core;
    27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2828
    2929namespace HeuristicLab.Algorithms.DataAnalysis {
    3030  // loss function for the weighted absolute error
    31   [StorableClass]
    32   [Item("Absolute error loss", "")]
    33   public class AbsoluteErrorLoss : Item, ILossFunction {
    34     public AbsoluteErrorLoss() { }
    35 
     31  public class AbsoluteErrorLoss : ILossFunction {
    3632    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3733      var targetEnum = target.GetEnumerator();
     
    8177    }
    8278
    83     #region item implementation
    84     private AbsoluteErrorLoss(AbsoluteErrorLoss original, Cloner cloner) : base(original, cloner) { }
    85 
    86     public override IDeepCloneable Clone(Cloner cloner) {
    87       return new AbsoluteErrorLoss(this, cloner);
     79    public override string ToString() {
     80      return "Absolute error loss";
    8881    }
    89     #endregion
    9082  }
    9183}
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs

    r12871 r12872  
    2222
    2323using System.Collections.Generic;
    24 using HeuristicLab.Core;
    2524
    2625namespace HeuristicLab.Algorithms.DataAnalysis {
     
    2827  // target represents the target vector  (original targets from the problem data, never changed)
    2928  // pred   represents the current vector of predictions (a weighted combination of models learned so far, this vector is updated after each step)
    30   public interface ILossFunction : IItem {
     29  public interface ILossFunction {
    3130    // returns the loss of the current prediction vector
    3231    double GetLoss(IEnumerable<double> target, IEnumerable<double> pred);
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs

    r12871 r12872  
    2626using System.Linq;
    2727using HeuristicLab.Common;
    28 using HeuristicLab.Core;
    29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3028
    3129namespace HeuristicLab.Algorithms.DataAnalysis {
    3230  // Greedy Function Approximation: A Gradient Boosting Machine (page 9)
    33   [StorableClass]
    34   [Item("Logistic regression loss", "")]
    35   public class LogisticRegressionLoss : Item, ILossFunction {
    36     public LogisticRegressionLoss() { }
    37 
     31  public class LogisticRegressionLoss : ILossFunction {
    3832    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3933      var targetEnum = target.GetEnumerator();
     
    8983    }
    9084
    91     #region item implementation
    92     private LogisticRegressionLoss(LogisticRegressionLoss original, Cloner cloner) : base(original, cloner) { }
    93 
    94     public override IDeepCloneable Clone(Cloner cloner) {
    95       return new LogisticRegressionLoss(this, cloner);
     85    public override string ToString() {
     86      return "Logistic regression loss";
    9687    }
    97     #endregion
    98 
    9988  }
    10089}
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs

    r12871 r12872  
    2626using System.Linq;
    2727using HeuristicLab.Common;
    28 using HeuristicLab.Core;
    29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3028
    3129namespace HeuristicLab.Algorithms.DataAnalysis {
    3230  // relative error loss is a special case of weighted absolute error loss with weights = (1/target)
    33   [StorableClass]
    34   [Item("Relative error loss", "")]
    35   public class RelativeErrorLoss : Item, ILossFunction {
    36     public RelativeErrorLoss() { }
    37 
     31  public class RelativeErrorLoss : ILossFunction {
    3832    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3933      var targetEnum = target.GetEnumerator();
     
    111105    }
    112106
    113     #region item implementation
    114     private RelativeErrorLoss(RelativeErrorLoss original, Cloner cloner) : base(original, cloner) { }
    115 
    116     public override IDeepCloneable Clone(Cloner cloner) {
    117       return new RelativeErrorLoss(this, cloner);
     107    public override string ToString() {
     108      return "Relative error loss";
    118109    }
    119     #endregion
    120110  }
    121111}
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs

    r12871 r12872  
    2424using System.Collections.Generic;
    2525using System.Linq;
    26 using HeuristicLab.Common;
    27 using HeuristicLab.Core;
    28 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2926
    3027namespace HeuristicLab.Algorithms.DataAnalysis {
    31   [StorableClass]
    32   [Item("Squared error loss", "")]
    33   public class SquaredErrorLoss : Item, ILossFunction {
    34     public SquaredErrorLoss() { }
    35 
     28  public class SquaredErrorLoss : ILossFunction {
    3629    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3730      var targetEnum = target.GetEnumerator();
     
    7770    }
    7871
    79     #region item implementation
    80     private SquaredErrorLoss(SquaredErrorLoss original, Cloner cloner) : base(original, cloner) { }
    81 
    82     public override IDeepCloneable Clone(Cloner cloner) {
    83       return new SquaredErrorLoss(this, cloner);
     72    public override string ToString() {
     73      return "Squared error loss";
    8474    }
    85     #endregion
    8675  }
    8776}
Note: See TracChangeset for help on using the changeset viewer.