- Timestamp:
- 08/17/15 18:33:31 (9 years ago)
- Location:
- branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Files:
-
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
r12871 r12872 82 82 get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 83 83 } 84 public IConstrainedValueParameter< ILossFunction> LossFunctionParameter {85 get { return (IConstrainedValueParameter< ILossFunction>)Parameters[LossFunctionParameterName]; }84 public IConstrainedValueParameter<StringValue> LossFunctionParameter { 85 get { return (IConstrainedValueParameter<StringValue>)Parameters[LossFunctionParameterName]; } 86 86 } 87 87 public IFixedValueParameter<IntValue> UpdateIntervalParameter { … … 164 164 Parameters[CreateSolutionParameterName].Hidden = true; 165 165 166 var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>(); 167 Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions))); 168 LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default 169 } 170 171 [StorableHook(HookType.AfterDeserialization)] 172 private void AfterDeserialization() { 173 // BackwardsCompatibility3.4 174 #region Backwards compatible code, remove with 3.5 175 // parameter type has been changed 176 var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>; 177 if (lossFunctionParam != null) { 178 Parameters.Remove(LossFunctionParameterName); 179 var selectedValue = lossFunctionParam.Value; // to be restored below 180 181 var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>(); 182 Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions))); 183 // try to restore selected value 184 var selectedLossFunction = 185 LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value); 186 if (selectedLossFunction != null) { 187 LossFunctionParameter.Value = selectedLossFunction; 188 } else { 189 LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE 190 } 191 } 192 #endregion 193 } 166 var lossFunctionNames = ApplicationManager.Manager.GetInstances<ILossFunction>().Select(l => new StringValue(l.ToString()).AsReadOnly()); 167 Parameters.Add(new ConstrainedValueParameter<StringValue>(LossFunctionParameterName, "The loss function", new ItemSet<StringValue>(lossFunctionNames))); 168 LossFunctionParameter.ActualValue = LossFunctionParameter.ValidValues.First(l => l.Value.Contains("Squared")); // squared error loss is the default 169 } 170 194 171 195 172 protected override void Run(CancellationToken cancellationToken) { … … 210 187 // init 211 188 var problemData = (IRegressionProblemData)Problem.ProblemData.Clone(); 212 var lossFunction = LossFunctionParameter.Value; 189 var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>() 190 .Single(l => l.ToString() == LossFunctionParameter.Value.Value); 213 191 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu); 214 192 … … 255 233 // produce solution 256 234 if (CreateSolution) { 257 var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction ,235 var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction.ToString(), 258 236 Iterations, MaxSize, R, M, Nu, state.GetModel()); 259 237 -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
r12871 r12872 45 45 private readonly uint seed; 46 46 [Storable] 47 private ILossFunction lossFunction;47 private string lossFunctionName; 48 48 [Storable] 49 49 private double r; … … 66 66 67 67 this.trainingProblemData = cloner.Clone(original.trainingProblemData); 68 this.lossFunction = cloner.Clone(original.lossFunction);69 68 this.seed = original.seed; 69 this.lossFunctionName = original.lossFunctionName; 70 70 this.iterations = original.iterations; 71 71 this.maxSize = original.maxSize; … … 76 76 77 77 // create only the surrogate model without an actual model 78 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)78 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu) 79 79 : base("Gradient boosted tree model", string.Empty) { 80 80 this.trainingProblemData = trainingProblemData; 81 81 this.seed = seed; 82 this.lossFunction = lossFunction;82 this.lossFunctionName = lossFunctionName; 83 83 this.iterations = iterations; 84 84 this.maxSize = maxSize; … … 89 89 90 90 // wrap an actual model in a surrograte 91 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)92 : this(trainingProblemData, seed, lossFunction , iterations, maxSize, r, m, nu) {91 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model) 92 : this(trainingProblemData, seed, lossFunctionName, iterations, maxSize, r, m, nu) { 93 93 this.actualModel = model; 94 94 } … … 110 110 111 111 private IRegressionModel RecalculateModel() { 112 var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>().Single(l => l.ToString() == lossFunctionName); 112 113 return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model; 113 114 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs
r12871 r12872 23 23 using System; 24 24 using System.Collections.Generic; 25 using System.Diagnostics; 26 using System.Linq; 25 27 using HeuristicLab.Common; 26 using HeuristicLab.Core;27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;28 28 29 29 namespace HeuristicLab.Algorithms.DataAnalysis { 30 30 // loss function for the weighted absolute error 31 [StorableClass] 32 [Item("Absolute error loss", "")] 33 public class AbsoluteErrorLoss : Item, ILossFunction { 34 public AbsoluteErrorLoss() { } 35 31 public class AbsoluteErrorLoss : ILossFunction { 36 32 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) { 37 33 var targetEnum = target.GetEnumerator(); … … 81 77 } 82 78 83 #region item implementation 84 private AbsoluteErrorLoss(AbsoluteErrorLoss original, Cloner cloner) : base(original, cloner) { } 85 86 public override IDeepCloneable Clone(Cloner cloner) { 87 return new AbsoluteErrorLoss(this, cloner); 79 public override string ToString() { 80 return "Absolute error loss"; 88 81 } 89 #endregion90 82 } 91 83 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs
r12871 r12872 22 22 23 23 using System.Collections.Generic; 24 using HeuristicLab.Core;25 24 26 25 namespace HeuristicLab.Algorithms.DataAnalysis { … … 28 27 // target represents the target vector (original targets from the problem data, never changed) 29 28 // pred represents the current vector of predictions (a weighted combination of models learned so far, this vector is updated after each step) 30 public interface ILossFunction : IItem{29 public interface ILossFunction { 31 30 // returns the loss of the current prediction vector 32 31 double GetLoss(IEnumerable<double> target, IEnumerable<double> pred); -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs
r12871 r12872 26 26 using System.Linq; 27 27 using HeuristicLab.Common; 28 using HeuristicLab.Core;29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;30 28 31 29 namespace HeuristicLab.Algorithms.DataAnalysis { 32 30 // Greedy Function Approximation: A Gradient Boosting Machine (page 9) 33 [StorableClass] 34 [Item("Logistic regression loss", "")] 35 public class LogisticRegressionLoss : Item, ILossFunction { 36 public LogisticRegressionLoss() { } 37 31 public class LogisticRegressionLoss : ILossFunction { 38 32 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) { 39 33 var targetEnum = target.GetEnumerator(); … … 89 83 } 90 84 91 #region item implementation 92 private LogisticRegressionLoss(LogisticRegressionLoss original, Cloner cloner) : base(original, cloner) { } 93 94 public override IDeepCloneable Clone(Cloner cloner) { 95 return new LogisticRegressionLoss(this, cloner); 85 public override string ToString() { 86 return "Logistic regression loss"; 96 87 } 97 #endregion98 99 88 } 100 89 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs
r12871 r12872 26 26 using System.Linq; 27 27 using HeuristicLab.Common; 28 using HeuristicLab.Core;29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;30 28 31 29 namespace HeuristicLab.Algorithms.DataAnalysis { 32 30 // relative error loss is a special case of weighted absolute error loss with weights = (1/target) 33 [StorableClass] 34 [Item("Relative error loss", "")] 35 public class RelativeErrorLoss : Item, ILossFunction { 36 public RelativeErrorLoss() { } 37 31 public class RelativeErrorLoss : ILossFunction { 38 32 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) { 39 33 var targetEnum = target.GetEnumerator(); … … 111 105 } 112 106 113 #region item implementation 114 private RelativeErrorLoss(RelativeErrorLoss original, Cloner cloner) : base(original, cloner) { } 115 116 public override IDeepCloneable Clone(Cloner cloner) { 117 return new RelativeErrorLoss(this, cloner); 107 public override string ToString() { 108 return "Relative error loss"; 118 109 } 119 #endregion120 110 } 121 111 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs
r12871 r12872 24 24 using System.Collections.Generic; 25 25 using System.Linq; 26 using HeuristicLab.Common;27 using HeuristicLab.Core;28 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;29 26 30 27 namespace HeuristicLab.Algorithms.DataAnalysis { 31 [StorableClass] 32 [Item("Squared error loss", "")] 33 public class SquaredErrorLoss : Item, ILossFunction { 34 public SquaredErrorLoss() { } 35 28 public class SquaredErrorLoss : ILossFunction { 36 29 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) { 37 30 var targetEnum = target.GetEnumerator(); … … 77 70 } 78 71 79 #region item implementation 80 private SquaredErrorLoss(SquaredErrorLoss original, Cloner cloner) : base(original, cloner) { } 81 82 public override IDeepCloneable Clone(Cloner cloner) { 83 return new SquaredErrorLoss(this, cloner); 72 public override string ToString() { 73 return "Squared error loss"; 84 74 } 85 #endregion86 75 } 87 76 }
Note: See TracChangeset
for help on using the changeset viewer.