1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using HeuristicLab.Common;
|
---|
5 | using HeuristicLab.Core;
|
---|
6 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
7 | using HeuristicLab.Problems.DataAnalysis;
|
---|
8 |
|
---|
namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model that combines several regression models into a weighted sum:
  /// the estimate for a row is sum(weights[i] * models[i] estimate for that row).
  /// </summary>
  [Item("GradientBoostedTreesSolution", "")]
  [StorableClass]
  public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {

    // models[i] is scaled by weights[i]; both lists always have the same length
    // (enforced by the public constructor).
    [Storable]
    private readonly IList<IRegressionModel> models;
    [Storable]
    private readonly IList<double> weights;

    /// <summary>
    /// Constructor marked for the persistence framework; it only forwards the flag to the base class.
    /// </summary>
    [StorableConstructor]
    private GradientBoostedTreesModel(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Cloning constructor: copies the weights and clones every member model through <paramref name="cloner"/>.
    /// </summary>
    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
      : base(original, cloner) {
      this.weights = new List<double>(original.weights);
      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
    }

    /// <summary>
    /// Creates a model from the given member models and their weights.
    /// Both sequences are copied into private lists.
    /// </summary>
    /// <param name="models">The member models.</param>
    /// <param name="weights">The weight for each member model, in the same order as <paramref name="models"/>.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="models"/> or <paramref name="weights"/> is null.</exception>
    /// <exception cref="ArgumentException">Thrown when the number of models differs from the number of weights.</exception>
    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
      : base() {
      // Check explicitly so the exception names the actual parameter instead of List<T>'s "collection".
      if (models == null) throw new ArgumentNullException("models");
      if (weights == null) throw new ArgumentNullException("weights");

      this.models = new List<IRegressionModel>(models);
      this.weights = new List<double>(weights);

      if (this.models.Count != this.weights.Count) {
        throw new ArgumentException(
          string.Format("The number of models ({0}) must match the number of weights ({1}).",
            this.models.Count, this.weights.Count));
      }
    }

    /// <summary>
    /// Creates a deep clone of this model using the given <paramref name="cloner"/>.
    /// </summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostedTreesModel(this, cloner);
    }

    /// <summary>
    /// Lazily yields the weighted sum of all member model estimates, one value per row.
    /// </summary>
    /// <param name="dataset">The dataset passed on to every member model.</param>
    /// <param name="rows">The rows to estimate; passed on to (and therefore enumerated by) every member model.</param>
    /// <returns>
    /// One estimate per row. The sequence ends as soon as any member model stops producing values.
    /// For a model without members every row yields 0.0.
    /// </returns>
    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
      // With no member models the weighted sum is empty. Enumerable.All is true for an empty
      // sequence, so the loop below would never terminate; emit one zero per row instead.
      if (models.Count == 0) {
        foreach (var row in rows) {
          yield return 0.0;
        }
        yield break;
      }

      var enumerators = new List<IEnumerator<double>>(models.Count);
      try {
        // Collected inside the try so that already opened enumerators are released
        // even if a later model throws while creating its enumerator.
        foreach (var model in models) {
          enumerators.Add(model.GetEstimatedValues(dataset, rows).GetEnumerator());
        }

        // Advance all member enumerators in lock-step; All() stops at the first exhausted one.
        while (enumerators.All(e => e.MoveNext())) {
          double sum = 0.0;
          for (int i = 0; i < enumerators.Count; i++) {
            sum += weights[i] * enumerators[i].Current;
          }
          yield return sum;
        }
      } finally {
        // Also runs when the caller stops enumerating early and disposes this iterator.
        foreach (var e in enumerators) {
          e.Dispose();
        }
      }
    }

    /// <summary>
    /// Wraps this model in a <see cref="RegressionSolution"/> that uses a clone of <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    }
  }
}
|
---|