Free cookie consent management tool by TermsFeed Policy Generator

Changeset 13184


Ignore:
Timestamp:
11/16/15 19:49:40 (9 years ago)
Author:
gkronber
Message:

#2450: merged r12868,r12873,r12875,r13065:13066,r13157:13158 from trunk to stable

Location:
stable
Files:
15 edited
3 copied

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs

    r12632 r13184  
    8282      get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    8383    }
    84     public IConstrainedValueParameter<StringValue> LossFunctionParameter {
    85       get { return (IConstrainedValueParameter<StringValue>)Parameters[LossFunctionParameterName]; }
     84    public IConstrainedValueParameter<ILossFunction> LossFunctionParameter {
     85      get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; }
    8686    }
    8787    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
     
    164164      Parameters[CreateSolutionParameterName].Hidden = true;
    165165
    166       var lossFunctionNames = ApplicationManager.Manager.GetInstances<ILossFunction>().Select(l => new StringValue(l.ToString()).AsReadOnly());
    167       Parameters.Add(new ConstrainedValueParameter<StringValue>(LossFunctionParameterName, "The loss function", new ItemSet<StringValue>(lossFunctionNames)));
    168       LossFunctionParameter.ActualValue = LossFunctionParameter.ValidValues.First(l => l.Value.Contains("Squared")); // squared error loss is the default
    169     }
    170 
     166      var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
     167      Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
     168      LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default
     169    }
     170
     171    [StorableHook(HookType.AfterDeserialization)]
     172    private void AfterDeserialization() {
     173      // BackwardsCompatibility3.4
     174      #region Backwards compatible code, remove with 3.5
     175      // parameter type has been changed
     176      var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>;
     177      if (lossFunctionParam != null) {
     178        Parameters.Remove(LossFunctionParameterName);
     179        var selectedValue = lossFunctionParam.Value; // to be restored below
     180
     181        var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
     182        Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
     183        // try to restore selected value
     184        var selectedLossFunction =
     185          LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value);
     186        if (selectedLossFunction != null) {
     187          LossFunctionParameter.Value = selectedLossFunction;
     188        } else {
     189          LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE
     190        }
     191      }
     192      #endregion
     193    }
    171194
    172195    protected override void Run(CancellationToken cancellationToken) {
     
    187210      // init
    188211      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    189       var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>()
    190         .Single(l => l.ToString() == LossFunctionParameter.Value.Value);
     212      var lossFunction = LossFunctionParameter.Value;
    191213      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);
    192214
     
    233255      // produce solution
    234256      if (CreateSolution) {
     257        var model = state.GetModel();
     258
    235259        // for logistic regression we produce a classification solution
    236260        if (lossFunction is LogisticRegressionLoss) {
    237           var model = new DiscriminantFunctionClassificationModel(state.GetModel(),
     261          var classificationModel = new DiscriminantFunctionClassificationModel(model,
    238262            new AccuracyMaximizationThresholdCalculator());
    239263          var classificationProblemData = new ClassificationProblemData(problemData.Dataset,
    240264            problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations);
    241           model.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);
    242 
    243           var classificationSolution = new DiscriminantFunctionClassificationSolution(model, classificationProblemData);
     265          classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);
     266
     267          var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData);
    244268          Results.Add(new Result("Solution", classificationSolution));
    245269        } else {
    246270          // otherwise we produce a regression solution
    247           Results.Add(new Result("Solution", new RegressionSolution(state.GetModel(), problemData)));
     271          Results.Add(new Result("Solution", new RegressionSolution(model, problemData)));
    248272        }
    249273      }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r13156 r13184  
    5252      internal RegressionTreeBuilder treeBuilder { get; private set; }
    5353
     54      private readonly uint randSeed;
    5455      private MersenneTwister random { get; set; }
    5556
     
    7172        this.m = m;
    7273
     74        this.randSeed = randSeed;
    7375        random = new MersenneTwister(randSeed);
    7476        this.problemData = problemData;
     
    99101
    100102      public IRegressionModel GetModel() {
    101         return new GradientBoostedTreesModel(models, weights);
     103#pragma warning disable 618
     104        var model = new GradientBoostedTreesModel(models, weights);
     105#pragma warning restore 618
     106        // we don't know the number of iterations here but the number of weights is equal
     107        // to the number of iterations + 1 (for the constant model)
     108        // wrap the actual model in a surrogate that enables persistence and lazy recalculation of the model if necessary
     109        return new GradientBoostedTreesModelSurrogate(problemData, randSeed, lossFunction, weights.Count - 1, maxSize, r, m, nu, model);
    102110      }
    103111      public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() {
     
    122130
    123131    // simple interface
    124     public static IRegressionSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxSize, double nu, double r, double m, int maxIterations, uint randSeed = 31415) {
     132    public static GradientBoostedTreesSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxSize, double nu, double r, double m, int maxIterations, uint randSeed = 31415) {
    125133      Contract.Assert(r > 0);
    126134      Contract.Assert(r <= 1.0);
     
    135143
    136144      var model = state.GetModel();
    137       return new RegressionSolution(model, (IRegressionProblemData)problemData.Clone());
     145      return new GradientBoostedTreesSolution(model, (IRegressionProblemData)problemData.Clone());
    138146    }
    139147
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs

    r12660 r13184  
    3333  [Item("Gradient boosted tree model", "")]
    3434  // this is essentially a collection of weighted regression models
    35   public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {
    36     [Storable]
     35  public sealed class GradientBoostedTreesModel : NamedItem, IGradientBoostedTreesModel {
     36    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
     37    #region Backwards compatible code, remove with 3.5
     38    private bool isCompatibilityLoaded = false; // only set to true if the model is deserialized from the old format, needed to make sure that information is serialized again if it was loaded from the old format
     39
     40    [Storable(Name = "models")]
     41    private IList<IRegressionModel> __persistedModels {
     42      set {
     43        this.isCompatibilityLoaded = true;
     44        this.models.Clear();
     45        foreach (var m in value) this.models.Add(m);
     46      }
     47      get { if (this.isCompatibilityLoaded) return models; else return null; }
     48    }
     49    [Storable(Name = "weights")]
     50    private IList<double> __persistedWeights {
     51      set {
     52        this.isCompatibilityLoaded = true;
     53        this.weights.Clear();
     54        foreach (var w in value) this.weights.Add(w);
     55      }
     56      get { if (this.isCompatibilityLoaded) return weights; else return null; }
     57    }
     58    #endregion
     59
    3760    private readonly IList<IRegressionModel> models;
    3861    public IEnumerable<IRegressionModel> Models { get { return models; } }
    3962
    40     [Storable]
    4163    private readonly IList<double> weights;
    4264    public IEnumerable<double> Weights { get { return weights; } }
    4365
    4466    [StorableConstructor]
    45     private GradientBoostedTreesModel(bool deserializing) : base(deserializing) { }
     67    private GradientBoostedTreesModel(bool deserializing)
     68      : base(deserializing) {
     69      models = new List<IRegressionModel>();
     70      weights = new List<double>();
     71    }
    4672    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
    4773      : base(original, cloner) {
    4874      this.weights = new List<double>(original.weights);
    4975      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
     76      this.isCompatibilityLoaded = original.isCompatibilityLoaded;
    5077    }
     78    [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")]
    5179    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
    5280      : base("Gradient boosted tree model", string.Empty) {
     
    6492      // allocate target array go over all models and add up weighted estimation for each row
    6593      if (!rows.Any()) return Enumerable.Empty<double>(); // return immediately if rows is empty. This prevents multiple iteration over lazy rows enumerable.
    66                                                           // (which essentially looks up indexes in a dictionary)
     94      // (which essentially looks up indexes in a dictionary)
    6795      var res = new double[rows.Count()];
    6896      for (int i = 0; i < models.Count; i++) {
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs

    r12868 r13184  
    2121#endregion
    2222
    23 using System;
    2423using System.Collections.Generic;
    25 using System.Linq;
    2624using HeuristicLab.Common;
    2725using HeuristicLab.Core;
    2826using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    29 using HeuristicLab.PluginInfrastructure;
    3027using HeuristicLab.Problems.DataAnalysis;
    3128
     
    3633  // recalculate the actual GBT model on demand
    3734  [Item("Gradient boosted tree model", "")]
    38   public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IRegressionModel {
     35  public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IGradientBoostedTreesModel {
    3936    // don't store the actual model!
    40     private IRegressionModel actualModel; // the actual model is only recalculated when necessary
     37    private IGradientBoostedTreesModel actualModel; // the actual model is only recalculated when necessary
    4138
    4239    [Storable]
     
    4542    private readonly uint seed;
    4643    [Storable]
    47     private string lossFunctionName;
     44    private ILossFunction lossFunction;
    4845    [Storable]
    4946    private double r;
     
    6663
    6764      this.trainingProblemData = cloner.Clone(original.trainingProblemData);
     65      this.lossFunction = cloner.Clone(original.lossFunction);
    6866      this.seed = original.seed;
    69       this.lossFunctionName = original.lossFunctionName;
    7067      this.iterations = original.iterations;
    7168      this.maxSize = original.maxSize;
     
    7673
    7774    // create only the surrogate model without an actual model
    78     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu)
     75    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu)
    7976      : base("Gradient boosted tree model", string.Empty) {
    8077      this.trainingProblemData = trainingProblemData;
    8178      this.seed = seed;
    82       this.lossFunctionName = lossFunctionName;
     79      this.lossFunction = lossFunction;
    8380      this.iterations = iterations;
    8481      this.maxSize = maxSize;
     
    8986
    9087    // wrap an actual model in a surrograte
    91     public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, string lossFunctionName, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model)
    92       : this(trainingProblemData, seed, lossFunctionName, iterations, maxSize, r, m, nu) {
     88    public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IGradientBoostedTreesModel model)
     89      : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) {
    9390      this.actualModel = model;
    9491    }
     
    109106
    110107
    111     private IRegressionModel RecalculateModel() {
    112       var lossFunction = ApplicationManager.Manager.GetInstances<ILossFunction>().Single(l => l.ToString() == lossFunctionName);
     108    private IGradientBoostedTreesModel RecalculateModel() {
    113109      return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model;
     110    }
     111
     112    public IEnumerable<IRegressionModel> Models {
     113      get {
     114        if (actualModel == null) actualModel = RecalculateModel();
     115        return actualModel.Models;
     116      }
     117    }
     118
     119    public IEnumerable<double> Weights {
     120      get {
     121        if (actualModel == null) actualModel = RecalculateModel();
     122        return actualModel.Weights;
     123      }
    114124    }
    115125  }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs

    r12700 r13184  
    2323using System;
    2424using System.Collections.Generic;
    25 using System.Diagnostics;
    26 using System.Linq;
    2725using HeuristicLab.Common;
     26using HeuristicLab.Core;
     27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2828
    2929namespace HeuristicLab.Algorithms.DataAnalysis {
    3030  // loss function for the weighted absolute error
    31   public class AbsoluteErrorLoss : ILossFunction {
     31  [StorableClass]
     32  [Item("Absolute error loss", "")]
     33  public sealed class AbsoluteErrorLoss : Item, ILossFunction {
     34    public AbsoluteErrorLoss() { }
     35
    3236    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3337      var targetEnum = target.GetEnumerator();
     
    7781    }
    7882
    79     public override string ToString() {
    80       return "Absolute error loss";
     83    #region item implementation
     84    [StorableConstructor]
     85    private AbsoluteErrorLoss(bool deserializing) : base(deserializing) { }
     86
     87    private AbsoluteErrorLoss(AbsoluteErrorLoss original, Cloner cloner) : base(original, cloner) { }
     88
     89    public override IDeepCloneable Clone(Cloner cloner) {
     90      return new AbsoluteErrorLoss(this, cloner);
    8191    }
     92    #endregion
    8293  }
    8394}
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs

    r12700 r13184  
    2222
    2323using System.Collections.Generic;
     24using HeuristicLab.Core;
    2425
    2526namespace HeuristicLab.Algorithms.DataAnalysis {
     
    2728  // target represents the target vector  (original targets from the problem data, never changed)
    2829  // pred   represents the current vector of predictions (a weighted combination of models learned so far, this vector is updated after each step)
    29   public interface ILossFunction {
     30  public interface ILossFunction : IItem {
    3031    // returns the loss of the current prediction vector
    3132    double GetLoss(IEnumerable<double> target, IEnumerable<double> pred);
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs

    r12700 r13184  
    2626using System.Linq;
    2727using HeuristicLab.Common;
     28using HeuristicLab.Core;
     29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2830
    2931namespace HeuristicLab.Algorithms.DataAnalysis {
    3032  // Greedy Function Approximation: A Gradient Boosting Machine (page 9)
    31   public class LogisticRegressionLoss : ILossFunction {
     33  [StorableClass]
     34  [Item("Logistic regression loss", "")]
     35  public sealed class LogisticRegressionLoss : Item, ILossFunction {
     36    public LogisticRegressionLoss() { }
     37
    3238    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3339      var targetEnum = target.GetEnumerator();
     
    8389    }
    8490
    85     public override string ToString() {
    86       return "Logistic regression loss";
     91    #region item implementation
     92    [StorableConstructor]
     93    private LogisticRegressionLoss(bool deserializing) : base(deserializing) { }
     94
     95    private LogisticRegressionLoss(LogisticRegressionLoss original, Cloner cloner) : base(original, cloner) { }
     96
     97    public override IDeepCloneable Clone(Cloner cloner) {
     98      return new LogisticRegressionLoss(this, cloner);
    8799    }
     100    #endregion
     101
    88102  }
    89103}
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs

    r12700 r13184  
    2626using System.Linq;
    2727using HeuristicLab.Common;
     28using HeuristicLab.Core;
     29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2830
    2931namespace HeuristicLab.Algorithms.DataAnalysis {
    3032  // relative error loss is a special case of weighted absolute error loss with weights = (1/target)
    31   public class RelativeErrorLoss : ILossFunction {
     33  [StorableClass]
     34  [Item("Relative error loss", "")]
     35  public sealed class RelativeErrorLoss : Item, ILossFunction {
     36    public RelativeErrorLoss() { }
     37
    3238    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3339      var targetEnum = target.GetEnumerator();
     
    105111    }
    106112
    107     public override string ToString() {
    108       return "Relative error loss";
     113    #region item implementation
     114    [StorableConstructor]
     115    private RelativeErrorLoss(bool deserializing) : base(deserializing) { }
     116
     117    private RelativeErrorLoss(RelativeErrorLoss original, Cloner cloner) : base(original, cloner) { }
     118
     119    public override IDeepCloneable Clone(Cloner cloner) {
     120      return new RelativeErrorLoss(this, cloner);
    109121    }
     122    #endregion
    110123  }
    111124}
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs

    r12700 r13184  
    2424using System.Collections.Generic;
    2525using System.Linq;
     26using HeuristicLab.Common;
     27using HeuristicLab.Core;
     28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2629
    2730namespace HeuristicLab.Algorithms.DataAnalysis {
    28   public class SquaredErrorLoss : ILossFunction {
     31  [StorableClass]
     32  [Item("Squared error loss", "")]
     33  public sealed class SquaredErrorLoss : Item, ILossFunction {
     34    public SquaredErrorLoss() { }
     35
    2936    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3037      var targetEnum = target.GetEnumerator();
     
    7077    }
    7178
    72     public override string ToString() {
    73       return "Squared error loss";
     79    #region item implementation
     80    [StorableConstructor]
     81    private SquaredErrorLoss(bool deserializing) : base(deserializing) { }
     82
     83    private SquaredErrorLoss(SquaredErrorLoss original, Cloner cloner) : base(original, cloner) { }
     84
     85    public override IDeepCloneable Clone(Cloner cloner) {
     86      return new SquaredErrorLoss(this, cloner);
    7487    }
     88    #endregion
    7589  }
    7690}
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12700 r13184  
    119119    }
    120120
    121     // simple API produces a single regression tree optimizing sum of squared errors
    122     // this can be used if only a simple regression tree should be produced
    123     // for a set of trees use the method CreateRegressionTreeForGradientBoosting below
    124     //
    125     // r and m work in the same way as for alglib random forest
    126     // r is fraction of rows to use for training
    127     // m is fraction of variables to use for training
    128     public IRegressionModel CreateRegressionTree(int maxSize, double r = 0.5, double m = 0.5) {
    129       // subtract mean of y first
    130       var yAvg = y.Average();
    131       for (int i = 0; i < y.Length; i++) y[i] -= yAvg;
    132 
    133       var seLoss = new SquaredErrorLoss();
    134 
    135       var model = CreateRegressionTreeForGradientBoosting(y, curPred, maxSize, problemData.TrainingIndices.ToArray(), seLoss, r, m);
    136 
    137       return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
    138     }
    139 
    140121    // specific interface that allows to specify the target labels and the training rows which is necessary when for gradient boosted trees
    141122    public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, double[] curPred, int maxSize, int[] idx, ILossFunction lossFunction, double r = 0.5, double m = 0.5) {
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r12700 r13184  
    8282    [Storable]
    8383    // to prevent storing the references to data caches in nodes
     84    // seemingly it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary) TODO
    8485    private Tuple<string, double, int, int>[] SerializedTree {
    8586      get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r13156 r13184  
    200200    <Compile Include="GaussianProcess\GaussianProcessRegressionSolution.cs" />
    201201    <Compile Include="GaussianProcess\ICovarianceFunction.cs" />
     202    <Compile Include="GradientBoostedTrees\IGradientBoostedTreesModel.cs" />
     203    <Compile Include="GradientBoostedTrees\GradientBoostedTreesModelSurrogate.cs" />
    202204    <Compile Include="GradientBoostedTrees\GradientBoostedTreesAlgorithm.cs" />
    203205    <Compile Include="GradientBoostedTrees\GradientBoostedTreesAlgorithmStatic.cs" />
     
    209211    <Compile Include="GradientBoostedTrees\LossFunctions\RelativeErrorLoss.cs" />
    210212    <Compile Include="GradientBoostedTrees\LossFunctions\SquaredErrorLoss.cs" />
     213    <Compile Include="GradientBoostedTrees\GradientBoostedTreesSolution.cs" />
    211214    <Compile Include="GradientBoostedTrees\RegressionTreeBuilder.cs" />
    212215    <Compile Include="GradientBoostedTrees\RegressionTreeModel.cs" />
  • stable/HeuristicLab.Tests

  • stable/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs

    r12711 r13184  
    269269      problemData.TestPartition.End = nRows;
    270270      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), maxSize, nu: 1, r: 1, m: 1, maxIterations: 1, randSeed: 31415);
    271       var model = (GradientBoostedTreesModel)solution.Model;
     271      var model = solution.Model;
    272272      var treeM = model.Models.Skip(1).First() as RegressionTreeModel;
    273273
Note: See TracChangeset for help on using the changeset viewer.