source: branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs @ 12332

Last change on this file since 12332 was 12332, checked in by gkronber, 5 years ago

#2261: initial import of gradient boosted trees for regression

File size: 2.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
7using HeuristicLab.Problems.DataAnalysis;
8
namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model whose estimate for a row is the weighted sum of the estimates of a
  /// list of component regression models: sum over i of weights[i] * models[i](row).
  /// </summary>
  [Item("GradientBoostedTreesModel", "")]
  [StorableClass]
  public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {

    // Component models; models[i] contributes to the estimate with factor weights[i].
    // NOTE(review): [Storable] fields are presumably persisted by member name, so these
    // names should not be changed without checking existing serialized files.
    [Storable]
    private readonly IList<IRegressionModel> models;
    [Storable]
    private readonly IList<double> weights;

    /// <summary>Constructor used by the persistence framework when deserializing.</summary>
    [StorableConstructor]
    private GradientBoostedTreesModel(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor: copies the weights and deep-clones every component model.</summary>
    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
      : base(original, cloner) {
      this.weights = new List<double>(original.weights);
      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
    }

    /// <summary>Creates an ensemble model from component models and their weights.</summary>
    /// <param name="models">The component regression models; copied into a new list.</param>
    /// <param name="weights">One weight per model, in the same order; copied into a new list.</param>
    /// <exception cref="ArgumentNullException"><paramref name="models"/> or <paramref name="weights"/> is null.</exception>
    /// <exception cref="ArgumentException">The number of models differs from the number of weights.</exception>
    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
      : base() {
      this.models = new List<IRegressionModel>(models);
      this.weights = new List<double>(weights);

      if (this.models.Count != this.weights.Count) {
        throw new ArgumentException("The number of models must match the number of weights.", "weights");
      }
    }

    /// <summary>Creates a deep clone of this model via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostedTreesModel(this, cloner);
    }

    /// <summary>
    /// Lazily returns, for each row, the weighted sum of the component models' estimates.
    /// </summary>
    /// <remarks>
    /// Every component model is queried with the same <paramref name="rows"/>, so that sequence
    /// is enumerated once per model. Output stops as soon as any component sequence is exhausted.
    /// An ensemble without component models yields 0 for every row (the empty sum).
    /// </remarks>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      if (models.Count == 0) {
        // Without this guard the lock-step loop below would never terminate,
        // because there would be no enumerator that could ever run out.
        foreach (var row in rows) {
          yield return 0.0;
        }
        yield break;
      }

      var enumerators = new List<IEnumerator<double>>(models.Count);
      try {
        foreach (var model in models) {
          enumerators.Add(model.GetEstimatedValues(dataset, rows).GetEnumerator());
        }

        // Advance all component enumerators in lock step; accumulate in model order.
        while (true) {
          double sum = 0.0;
          for (int i = 0; i < enumerators.Count; i++) {
            if (!enumerators[i].MoveNext()) yield break;
            sum += weights[i] * enumerators[i].Current;
          }
          yield return sum;
        }
      } finally {
        // Runs on normal completion, on early disposal by the caller, and if creating
        // a later enumerator throws, so no component enumerator is leaked.
        foreach (var enumerator in enumerators) {
          enumerator.Dispose();
        }
      }
    }

    /// <summary>
    /// Wraps this model in a <see cref="RegressionSolution"/> that uses a clone of
    /// <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    }
  }
}
Note: See TracBrowser for help on using the repository browser.