Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/17/15 18:31:39 (9 years ago)
Author:
gkronber
Message:

#2434 derived ILossFunction from IItem to allow execution on hive without privileged flag (made an "after deserialization"-hook necessary to convert the parameter type)

Location:
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r12869 r12871  
    8282      get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    8383    }
    84     public IConstrainedValueParameter<StringValue> LossFunctionParameter {
    85       get { return (IConstrainedValueParameter<StringValue>)Parameters[LossFunctionParameterName]; }
     84    public IConstrainedValueParameter<ILossFunction> LossFunctionParameter {
     85      get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; }
    8686    }
    8787    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
     
    164164      Parameters[CreateSolutionParameterName].Hidden = true;
    165165
    166       var lossFunctionNames = ApplicationManager.Manager.GetInstances<ILossFunction>().Select(l => new StringValue(l.ToString()).AsReadOnly());
    167       Parameters.Add(new ConstrainedValueParameter<StringValue>(LossFunctionParameterName, "The loss function", new ItemSet<StringValue>(lossFunctionNames)));
    168       LossFunctionParameter.ActualValue = LossFunctionParameter.ValidValues.First(l => l.Value.Contains("Squared")); // squared error loss is the default
    169     }
    170 
     166      var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
     167      Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
     168      LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default
     169    }
     170
     171    [StorableHook(HookType.AfterDeserialization)]
     172    private void AfterDeserialization() {
     173      // BackwardsCompatibility3.4
     174      #region Backwards compatible code, remove with 3.5
     175      // parameter type has been changed
     176      var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>;
     177      if (lossFunctionParam != null) {
     178        Parameters.Remove(LossFunctionParameterName);
     179        var selectedValue = lossFunctionParam.Value; // to be restored below
     180
     181        var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
     182        Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
     183        // try to restore selected value
     184        var selectedLossFunction =
     185          LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value);
     186        if (selectedLossFunction != null) {
     187          LossFunctionParameter.Value = selectedLossFunction;
     188        } else {
     189          LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE
     190        }
     191      }
     192      #endregion
     193    }
    171194
    172195    protected override void Run(CancellationToken cancellationToken) {
     
    187210      // init
    188211      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    189       var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>()
    190         .Single(l => l.ToString() == LossFunctionParameter.Value.Value);
     212      var lossFunction = LossFunctionParameter.Value;
    191213      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);
    192214
     
    233255      // produce solution
    234256      if (CreateSolution) {
    235         var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction.ToString(),
     257        var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction,
    236258          Iterations, MaxSize, R, M, Nu, state.GetModel());
    237259
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs

    r12869 r12871  
    4545    private readonly uint seed;
    4646    [Storable]
    47     private string lossFunctionName;
     47    private ILossFunction lossFunction;
    4848    [Storable]
    4949    private double r;
     
    6666
    6767      this.trainingProblemData = cloner.Clone(original.trainingProblemData);
     68      this.lossFunction = cloner.Clone(original.lossFunction);
    6869      this.seed = original.seed;
    69       this.lossFunctionName = original.lossFunctionName;
    7070      this.iterations = original.iterations;
    7171      this.maxSize = original.maxSize;
     
    7676
    7777    // create only the surrogate model without an actual model
    78     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu)
     78    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)
    7979      : base("Gradient boosted tree model", string.Empty) {
    8080      this.trainingProblemData = trainingProblemData;
    8181      this.seed = seed;
    82       this.lossFunctionName = lossFunctionName;
     82      this.lossFunction = lossFunction;
    8383      this.iterations = iterations;
    8484      this.maxSize = maxSize;
     
    8989
    9090    // wrap an actual model in a surrograte
    91     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)
    92       : this(trainingProblemData, seed, lossFunctionName, iterations, maxSize, r, m, nu) {
     91    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)
     92      : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) {
    9393      this.actualModel = model;
    9494    }
     
    110110
    111111    private IRegressionModel RecalculateModel() {
    112       var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>().Single(l => l.ToString() == lossFunctionName);
    113112      return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model;
    114113    }
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs

    r12700 r12871  
    2323using System;
    2424using System.Collections.Generic;
    25 using System.Diagnostics;
    26 using System.Linq;
    2725using HeuristicLab.Common;
     26using HeuristicLab.Core;
     27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2828
    2929namespace HeuristicLab.Algorithms.DataAnalysis {
    3030  // loss function for the weighted absolute error
    31   public class AbsoluteErrorLoss : ILossFunction {
     31  [StorableClass]
     32  [Item("Absolute error loss", "")]
     33  public class AbsoluteErrorLoss : Item, ILossFunction {
     34    public AbsoluteErrorLoss() { }
     35
    3236    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3337      var targetEnum = target.GetEnumerator();
     
    7781    }
    7882
    79     public override string ToString() {
    80       return "Absolute error loss";
     83    #region item implementation
     84    private AbsoluteErrorLoss(AbsoluteErrorLoss original, Cloner cloner) : base(original, cloner) { }
     85
     86    public override IDeepCloneable Clone(Cloner cloner) {
     87      return new AbsoluteErrorLoss(this, cloner);
    8188    }
     89    #endregion
    8290  }
    8391}
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs

    r12700 r12871  
    2222
    2323using System.Collections.Generic;
     24using HeuristicLab.Core;
    2425
    2526namespace HeuristicLab.Algorithms.DataAnalysis {
     
    2728  // target represents the target vector  (original targets from the problem data, never changed)
    2829  // pred   represents the current vector of predictions (a weighted combination of models learned so far, this vector is updated after each step)
    29   public interface ILossFunction {
     30  public interface ILossFunction : IItem {
    3031    // returns the loss of the current prediction vector
    3132    double GetLoss(IEnumerable<double> target, IEnumerable<double> pred);
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs

    r12700 r12871  
    2626using System.Linq;
    2727using HeuristicLab.Common;
     28using HeuristicLab.Core;
     29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2830
    2931namespace HeuristicLab.Algorithms.DataAnalysis {
    3032  // Greedy Function Approximation: A Gradient Boosting Machine (page 9)
    31   public class LogisticRegressionLoss : ILossFunction {
     33  [StorableClass]
     34  [Item("Logistic regression loss", "")]
     35  public class LogisticRegressionLoss : Item, ILossFunction {
     36    public LogisticRegressionLoss() { }
     37
    3238    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3339      var targetEnum = target.GetEnumerator();
     
    8389    }
    8490
    85     public override string ToString() {
    86       return "Logistic regression loss";
     91    #region item implementation
     92    private LogisticRegressionLoss(LogisticRegressionLoss original, Cloner cloner) : base(original, cloner) { }
     93
     94    public override IDeepCloneable Clone(Cloner cloner) {
     95      return new LogisticRegressionLoss(this, cloner);
    8796    }
     97    #endregion
     98
    8899  }
    89100}
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs

    r12700 r12871  
    2626using System.Linq;
    2727using HeuristicLab.Common;
     28using HeuristicLab.Core;
     29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2830
    2931namespace HeuristicLab.Algorithms.DataAnalysis {
    3032  // relative error loss is a special case of weighted absolute error loss with weights = (1/target)
    31   public class RelativeErrorLoss : ILossFunction {
     33  [StorableClass]
     34  [Item("Relative error loss", "")]
     35  public class RelativeErrorLoss : Item, ILossFunction {
     36    public RelativeErrorLoss() { }
     37
    3238    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3339      var targetEnum = target.GetEnumerator();
     
    105111    }
    106112
    107     public override string ToString() {
    108       return "Relative error loss";
     113    #region item implementation
     114    private RelativeErrorLoss(RelativeErrorLoss original, Cloner cloner) : base(original, cloner) { }
     115
     116    public override IDeepCloneable Clone(Cloner cloner) {
     117      return new RelativeErrorLoss(this, cloner);
    109118    }
     119    #endregion
    110120  }
    111121}
  • branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs

    r12700 r12871  
    2424using System.Collections.Generic;
    2525using System.Linq;
     26using HeuristicLab.Common;
     27using HeuristicLab.Core;
     28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2629
    2730namespace HeuristicLab.Algorithms.DataAnalysis {
    28   public class SquaredErrorLoss : ILossFunction {
     31  [StorableClass]
     32  [Item("Squared error loss", "")]
     33  public class SquaredErrorLoss : Item, ILossFunction {
     34    public SquaredErrorLoss() { }
     35
    2936    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3037      var targetEnum = target.GetEnumerator();
     
    7077    }
    7178
    72     public override string ToString() {
    73       return "Squared error loss";
     79    #region item implementation
     80    private SquaredErrorLoss(SquaredErrorLoss original, Cloner cloner) : base(original, cloner) { }
     81
     82    public override IDeepCloneable Clone(Cloner cloner) {
     83      return new SquaredErrorLoss(this, cloner);
    7484    }
     85    #endregion
    7586  }
    7687}
Note: See TracChangeset for help on using the changeset viewer.