Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs @ 12868

Last change on this file since 12868 was 12868, checked in by gkronber, 9 years ago

#2450: introduced surrogate for GBT-models which recalculates the actual model on demand to improve persistence of GBT solutions

File size: 4.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
31namespace HeuristicLab.Algorithms.DataAnalysis {
32  [StorableClass]
33  [Item("Gradient boosted tree model", "")]
34  // this is essentially a collection of weighted regression models
35  public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {
36    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
37    #region Backwards compatible code, remove with 3.5
38    private bool isCompatibilityLoaded = false; // only set to true if the model is deserialized from the old format, needed to make sure that information is serialized again if it was loaded from the old format
39
40    [Storable(Name = "models")]
41    private IList<IRegressionModel> __persistedModels {
42      set {
43        this.isCompatibilityLoaded = true;
44        this.models.Clear();
45        foreach (var m in value) this.models.Add(m);
46      }
47      get { if (this.isCompatibilityLoaded) return models; else return null; }
48    }
49    [Storable(Name = "weights")]
50    private IList<double> __persistedWeights {
51      set {
52        this.isCompatibilityLoaded = true;
53        this.weights.Clear();
54        foreach (var w in value) this.weights.Add(w);
55      }
56      get { if (this.isCompatibilityLoaded) return weights; else return null; }
57    }
58    #endregion
59
60    private readonly IList<IRegressionModel> models;
61    public IEnumerable<IRegressionModel> Models { get { return models; } }
62
63    private readonly IList<double> weights;
64    public IEnumerable<double> Weights { get { return weights; } }
65
66    [StorableConstructor]
67    private GradientBoostedTreesModel(bool deserializing)
68      : base(deserializing) {
69      models = new List<IRegressionModel>();
70      weights = new List<double>();
71    }
72    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
73      : base(original, cloner) {
74      this.weights = new List<double>(original.weights);
75      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
76      this.isCompatibilityLoaded = original.isCompatibilityLoaded;
77    }
78    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
79      : base("Gradient boosted tree model", string.Empty) {
80      this.models = new List<IRegressionModel>(models);
81      this.weights = new List<double>(weights);
82
83      if (this.models.Count != this.weights.Count) throw new ArgumentException();
84    }
85
86    public override IDeepCloneable Clone(Cloner cloner) {
87      return new GradientBoostedTreesModel(this, cloner);
88    }
89
90    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
91      // allocate target array go over all models and add up weighted estimation for each row
92      if (!rows.Any()) return Enumerable.Empty<double>(); // return immediately if rows is empty. This prevents multiple iteration over lazy rows enumerable.
93      // (which essentially looks up indexes in a dictionary)
94      var res = new double[rows.Count()];
95      for (int i = 0; i < models.Count; i++) {
96        var w = weights[i];
97        var m = models[i];
98        int r = 0;
99        foreach (var est in m.GetEstimatedValues(dataset, rows)) {
100          res[r++] += w * est;
101        }
102      }
103      return res;
104    }
105
106    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
107      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
108    }
109  }
110}
Note: See TracBrowser for help on using the repository browser.