Free cookie consent management tool by TermsFeed Policy Generator

source: branches/PausableBasicAlgorithm/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs @ 15428

Last change on this file since 15428 was 13157, checked in by gkronber, 9 years ago

#2450 made the changes suggested by mkommend in the review. This is definitely a big improvement, thx!

File size: 4.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Gradient boosted trees model: essentially a collection of regression models whose
  /// estimations are combined as a weighted sum (see <see cref="GetEstimatedValues"/>).
  /// </summary>
  [StorableClass]
  [Item("Gradient boosted tree model", "")]
  public sealed class GradientBoostedTreesModel : NamedItem, IGradientBoostedTreesModel {
    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
    #region Backwards compatible code, remove with 3.5
    // Only set to true if the model is deserialized from the old format. This is needed to make sure
    // that the models and weights are serialized again if they were loaded from the old format;
    // otherwise the persisted getters below return null.
    private bool isCompatibilityLoaded = false;

    [Storable(Name = "models")]
    private IList<IRegressionModel> __persistedModels {
      set {
        // The getter returns null when the model was not loaded from the old format. If such a null
        // value is handed back on deserialization, ignore it instead of throwing in the foreach below.
        if (value == null) return;
        this.isCompatibilityLoaded = true;
        this.models.Clear();
        foreach (var m in value) this.models.Add(m);
      }
      get { if (this.isCompatibilityLoaded) return models; else return null; }
    }
    [Storable(Name = "weights")]
    private IList<double> __persistedWeights {
      set {
        // see __persistedModels: tolerate the null value written for models not loaded from the old format
        if (value == null) return;
        this.isCompatibilityLoaded = true;
        this.weights.Clear();
        foreach (var w in value) this.weights.Add(w);
      }
      get { if (this.isCompatibilityLoaded) return weights; else return null; }
    }
    #endregion

    private readonly IList<IRegressionModel> models;
    /// <summary>
    /// The component regression models. The i-th model is scaled by the i-th entry of <see cref="Weights"/>.
    /// </summary>
    public IEnumerable<IRegressionModel> Models { get { return models; } }

    private readonly IList<double> weights;
    /// <summary>
    /// The weight of each component model, index-aligned with <see cref="Models"/>.
    /// </summary>
    public IEnumerable<double> Weights { get { return weights; } }

    [StorableConstructor]
    private GradientBoostedTreesModel(bool deserializing)
      : base(deserializing) {
      // empty lists; they are filled by the backwards-compatible storable setters when loading the old format
      models = new List<IRegressionModel>();
      weights = new List<double>();
    }
    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
      : base(original, cloner) {
      this.weights = new List<double>(original.weights);
      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
      // keep the compatibility flag so a clone of an old-format model is also persisted completely
      this.isCompatibilityLoaded = original.isCompatibilityLoaded;
    }

    /// <summary>
    /// Creates a model from component models and their weights.
    /// </summary>
    /// <param name="models">The component regression models.</param>
    /// <param name="weights">The weight of each component model (same order and count as <paramref name="models"/>).</param>
    /// <exception cref="ArgumentException">The number of models and weights differ.</exception>
    [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")]
    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
      : base("Gradient boosted tree model", string.Empty) {
      this.models = new List<IRegressionModel>(models);
      this.weights = new List<double>(weights);

      if (this.models.Count != this.weights.Count)
        throw new ArgumentException("The number of models must match the number of weights.", "weights");
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostedTreesModel(this, cloner);
    }

    /// <summary>
    /// Calculates, for each row, the weighted sum of the estimations of all component models.
    /// </summary>
    /// <param name="dataset">The dataset passed to every component model.</param>
    /// <param name="rows">The rows to estimate. They are enumerated exactly once.</param>
    /// <returns>One value per row: the sum over i of weights[i] times the estimation of models[i].</returns>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      // Materialize the rows once. They may be a lazy (potentially expensive or non-repeatable)
      // enumerable, but they are needed for the result size and again by every component model.
      var rowsArray = rows.ToArray();
      // return immediately if there are no rows, so the component models are not called at all
      if (rowsArray.Length == 0) return Enumerable.Empty<double>();

      // allocate the target array, go over all models and add up the weighted estimation for each row
      var res = new double[rowsArray.Length];
      for (int i = 0; i < models.Count; i++) {
        var w = weights[i];
        var m = models[i];
        int r = 0;
        foreach (var est in m.GetEstimatedValues(dataset, rowsArray)) {
          res[r++] += w * est;
        }
      }
      return res;
    }

    /// <summary>
    /// Creates a regression solution for this model on a clone of the given problem data.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    }
  }
}
Note: See TracBrowser for help on using the repository browser.