using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace GradientBoostedTrees {
  /// <summary>
  /// Additive regression ensemble: the estimate for a row is the weighted sum
  /// <c>sum_i weights[i] * models[i](row)</c> of the member models' estimates.
  /// </summary>
  [Item("GradientBoostedTreesModel", "A weighted sum of regression models.")]
  [StorableClass]
  public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {
    // Member models; models[i] is scaled by weights[i]. Both lists always have the same length.
    [Storable]
    private readonly IList<IRegressionModel> models;
    [Storable]
    private readonly IList<double> weights;

    [StorableConstructor]
    private GradientBoostedTreesModel(bool deserializing) : base(deserializing) { }

    /// <summary>Deep-copy constructor used by <see cref="Clone"/>.</summary>
    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
      : base(original, cloner) {
      this.weights = new List<double>(original.weights);
      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
    }

    /// <summary>Creates an ensemble from parallel sequences of models and weights.</summary>
    /// <param name="models">The member models.</param>
    /// <param name="weights">One weight per model, in the same order.</param>
    /// <exception cref="ArgumentNullException">If either argument is null.</exception>
    /// <exception cref="ArgumentException">If the sequences have different lengths.</exception>
    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
      : base() {
      if (models == null) throw new ArgumentNullException("models");
      if (weights == null) throw new ArgumentNullException("weights");
      this.models = new List<IRegressionModel>(models);
      this.weights = new List<double>(weights);
      if (this.models.Count != this.weights.Count)
        throw new ArgumentException("The number of weights must equal the number of models.", "weights");
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostedTreesModel(this, cloner);
    }

    /// <summary>
    /// Lazily yields, for each row, the weighted sum of all member models' estimates.
    /// Enumeration stops as soon as any member model runs out of values.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      // Materialize once so a lazy or one-shot row sequence is not re-evaluated per member model.
      IList<int> rowList = rows as IList<int> ?? rows.ToList();

      if (models.Count == 0) {
        // An empty ensemble contributes nothing. Without this guard, All() over zero
        // enumerators is vacuously true and the loop below would never terminate.
        for (int r = 0; r < rowList.Count; r++) yield return 0.0;
        yield break;
      }

      var enumerators = new List<IEnumerator<double>>(models.Count);
      try {
        foreach (var model in models)
          enumerators.Add(model.GetEstimatedValues(dataset, rowList).GetEnumerator());

        while (enumerators.All(e => e.MoveNext())) {
          double sum = 0.0;
          for (int i = 0; i < enumerators.Count; i++)
            sum += weights[i] * enumerators[i].Current;
          yield return sum;
        }
      } finally {
        // Runs on normal completion, on early termination by the consumer, and on exceptions.
        foreach (var e in enumerators) e.Dispose();
      }
    }

    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    }
  }
}