Changeset 12871 for branches/crossvalidation-2434
- Timestamp:
- 08/17/15 18:31:39 (9 years ago)
- Location:
- branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Files:
-
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
r12869 r12871 82 82 get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 83 83 } 84 public IConstrainedValueParameter< StringValue> LossFunctionParameter {85 get { return (IConstrainedValueParameter< StringValue>)Parameters[LossFunctionParameterName]; }84 public IConstrainedValueParameter<ILossFunction> LossFunctionParameter { 85 get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; } 86 86 } 87 87 public IFixedValueParameter<IntValue> UpdateIntervalParameter { … … 164 164 Parameters[CreateSolutionParameterName].Hidden = true; 165 165 166 var lossFunctionNames = ApplicationManager.Manager.GetInstances<ILossFunction>().Select(l => new StringValue(l.ToString()).AsReadOnly()); 167 Parameters.Add(new ConstrainedValueParameter<StringValue>(LossFunctionParameterName, "The loss function", new ItemSet<StringValue>(lossFunctionNames))); 168 LossFunctionParameter.ActualValue = LossFunctionParameter.ValidValues.First(l => l.Value.Contains("Squared")); // squared error loss is the default 169 } 170 166 var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>(); 167 Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions))); 168 LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default 169 } 170 171 [StorableHook(HookType.AfterDeserialization)] 172 private void AfterDeserialization() { 173 // BackwardsCompatibility3.4 174 #region Backwards compatible code, remove with 3.5 175 // parameter type has been changed 176 var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>; 177 if (lossFunctionParam != null) { 178 Parameters.Remove(LossFunctionParameterName); 179 var selectedValue = lossFunctionParam.Value; // to be restored below 180 181 var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>(); 182 Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions))); 183 // try to restore selected value 184 var selectedLossFunction = 185 LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value); 186 if (selectedLossFunction != null) { 187 LossFunctionParameter.Value = selectedLossFunction; 188 } else { 189 LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE 190 } 191 } 192 #endregion 193 } 171 194 172 195 protected override void Run(CancellationToken cancellationToken) { … … 187 210 // init 188 211 var problemData = (IRegressionProblemData)Problem.ProblemData.Clone(); 189 var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>() 190 .Single(l => l.ToString() == LossFunctionParameter.Value.Value); 212 var lossFunction = LossFunctionParameter.Value; 191 213 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu); 192 214 … … 233 255 // produce solution 234 256 if (CreateSolution) { 235 var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction .ToString(),257 var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction, 236 258 Iterations, MaxSize, R, M, Nu, state.GetModel()); 237 259 -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
r12869 r12871 45 45 private readonly uint seed; 46 46 [Storable] 47 private string lossFunctionName;47 private ILossFunction lossFunction; 48 48 [Storable] 49 49 private double r; … … 66 66 67 67 this.trainingProblemData = cloner.Clone(original.trainingProblemData); 68 this.lossFunction = cloner.Clone(original.lossFunction); 68 69 this.seed = original.seed; 69 this.lossFunctionName = original.lossFunctionName;70 70 this.iterations = original.iterations; 71 71 this.maxSize = original.maxSize; … … 76 76 77 77 // create only the surrogate model without an actual model 78 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu)78 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu) 79 79 : base("Gradient boosted tree model", string.Empty) { 80 80 this.trainingProblemData = trainingProblemData; 81 81 this.seed = seed; 82 this.lossFunction Name = lossFunctionName;82 this.lossFunction = lossFunction; 83 83 this.iterations = iterations; 84 84 this.maxSize = maxSize; … … 89 89 90 90 // wrap an actual model in a surrograte 91 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)92 : this(trainingProblemData, seed, lossFunction Name, iterations, maxSize, r, m, nu) {91 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model) 92 : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) { 93 93 this.actualModel = model; 94 94 } … … 110 110 111 111 private IRegressionModel RecalculateModel() { 112 var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>().Single(l => l.ToString() == lossFunctionName);113 112 return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model; 114 113 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs
r12700 r12871 23 23 using System; 24 24 using System.Collections.Generic; 25 using System.Diagnostics;26 using System.Linq;27 25 using HeuristicLab.Common; 26 using HeuristicLab.Core; 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 28 29 29 namespace HeuristicLab.Algorithms.DataAnalysis { 30 30 // loss function for the weighted absolute error 31 public class AbsoluteErrorLoss : ILossFunction { 31 [StorableClass] 32 [Item("Absolute error loss", "")] 33 public class AbsoluteErrorLoss : Item, ILossFunction { 34 public AbsoluteErrorLoss() { } 35 32 36 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) { 33 37 var targetEnum = target.GetEnumerator(); … … 77 81 } 78 82 79 public override string ToString() { 80 return "Absolute error loss"; 83 #region item implementation 84 private AbsoluteErrorLoss(AbsoluteErrorLoss original, Cloner cloner) : base(original, cloner) { } 85 86 public override IDeepCloneable Clone(Cloner cloner) { 87 return new AbsoluteErrorLoss(this, cloner); 81 88 } 89 #endregion 82 90 } 83 91 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs
r12700 r12871 22 22 23 23 using System.Collections.Generic; 24 using HeuristicLab.Core; 24 25 25 26 namespace HeuristicLab.Algorithms.DataAnalysis { … … 27 28 // target represents the target vector (original targets from the problem data, never changed) 28 29 // pred represents the current vector of predictions (a weighted combination of models learned so far, this vector is updated after each step) 29 public interface ILossFunction {30 public interface ILossFunction : IItem { 30 31 // returns the loss of the current prediction vector 31 32 double GetLoss(IEnumerable<double> target, IEnumerable<double> pred); -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs
r12700 r12871 26 26 using System.Linq; 27 27 using HeuristicLab.Common; 28 using HeuristicLab.Core; 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 30 29 31 namespace HeuristicLab.Algorithms.DataAnalysis { 30 32 // Greedy Function Approximation: A Gradient Boosting Machine (page 9) 31 public class LogisticRegressionLoss : ILossFunction { 33 [StorableClass] 34 [Item("Logistic regression loss", "")] 35 public class LogisticRegressionLoss : Item, ILossFunction { 36 public LogisticRegressionLoss() { } 37 32 38 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) { 33 39 var targetEnum = target.GetEnumerator(); … … 83 89 } 84 90 85 public override string ToString() { 86 return "Logistic regression loss"; 91 #region item implementation 92 private LogisticRegressionLoss(LogisticRegressionLoss original, Cloner cloner) : base(original, cloner) { } 93 94 public override IDeepCloneable Clone(Cloner cloner) { 95 return new LogisticRegressionLoss(this, cloner); 87 96 } 97 #endregion 98 88 99 } 89 100 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs
r12700 r12871 26 26 using System.Linq; 27 27 using HeuristicLab.Common; 28 using HeuristicLab.Core; 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 30 29 31 namespace HeuristicLab.Algorithms.DataAnalysis { 30 32 // relative error loss is a special case of weighted absolute error loss with weights = (1/target) 31 public class RelativeErrorLoss : ILossFunction { 33 [StorableClass] 34 [Item("Relative error loss", "")] 35 public class RelativeErrorLoss : Item, ILossFunction { 36 public RelativeErrorLoss() { } 37 32 38 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) { 33 39 var targetEnum = target.GetEnumerator(); … … 105 111 } 106 112 107 public override string ToString() { 108 return "Relative error loss"; 113 #region item implementation 114 private RelativeErrorLoss(RelativeErrorLoss original, Cloner cloner) : base(original, cloner) { } 115 116 public override IDeepCloneable Clone(Cloner cloner) { 117 return new RelativeErrorLoss(this, cloner); 109 118 } 119 #endregion 110 120 } 111 121 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs
r12700 r12871 24 24 using System.Collections.Generic; 25 25 using System.Linq; 26 using HeuristicLab.Common; 27 using HeuristicLab.Core; 28 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 26 29 27 30 namespace HeuristicLab.Algorithms.DataAnalysis { 28 public class SquaredErrorLoss : ILossFunction { 31 [StorableClass] 32 [Item("Squared error loss", "")] 33 public class SquaredErrorLoss : Item, ILossFunction { 34 public SquaredErrorLoss() { } 35 29 36 public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) { 30 37 var targetEnum = target.GetEnumerator(); … … 70 77 } 71 78 72 public override string ToString() { 73 return "Squared error loss"; 79 #region item implementation 80 private SquaredErrorLoss(SquaredErrorLoss original, Cloner cloner) : base(original, cloner) { } 81 82 public override IDeepCloneable Clone(Cloner cloner) { 83 return new SquaredErrorLoss(this, cloner); 74 84 } 85 #endregion 75 86 } 76 87 }
Note: See TracChangeset
for help on using the changeset viewer.