Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2883_GBTModelStorage/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs @ 15678

Last change on this file since 15678 was 15678, checked in by fholzing, 6 years ago

#2883: Implemented third option for complete storage

File size: 5.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Gradient boosted trees model: a weighted collection of regression models.
  /// The estimate for a row is the sum over all component models of
  /// <c>weights[i] * models[i].Estimate(row)</c>.
  /// </summary>
  [StorableClass]
  [Item("Gradient boosted trees model", "")]
  // this is essentially a collection of weighted regression models
  public sealed class GradientBoostedTreesModel : RegressionModel, IGradientBoostedTreesModel {
    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
    #region Backwards compatible code, remove with 3.5
    // Only set to true if the model is deserialized from the old format. This is needed to make
    // sure that the information is serialized again if it was loaded from the old format.
    private bool isCompatibilityLoaded = false;
    internal bool IsCompatibilityLoaded { get { return this.isCompatibilityLoaded; } set { this.isCompatibilityLoaded = value; } }

    // Persistence hook for the component models (old storage format).
    // Setting it replaces the contents of the model list and marks this instance as
    // compatibility-loaded. Getting it returns the models only for compatibility-loaded
    // instances and null otherwise (presumably so that newer models do not store this
    // field -- confirm against the persistence framework).
    [Storable(Name = "models")]
    private IList<IRegressionModel> __persistedModels {
      set {
        this.isCompatibilityLoaded = true;
        this.models.Clear();
        foreach (var m in value) this.models.Add(m);
      }
      get { if (this.isCompatibilityLoaded) return models; else return null; }
    }

    // Persistence hook for the model weights (old storage format). It behaves exactly like
    // __persistedModels, but for the weight list.
    [Storable(Name = "weights")]
    private IList<double> __persistedWeights {
      set {
        this.isCompatibilityLoaded = true;
        this.weights.Clear();
        foreach (var w in value) this.weights.Add(w);
      }
      get { if (this.isCompatibilityLoaded) return weights; else return null; }
    }
    #endregion

    /// <summary>
    /// Distinct, sorted union of the variables used by all component models.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
    }

    private readonly IList<IRegressionModel> models;
    /// <summary>The component regression models, in the same order as <see cref="Weights"/>.</summary>
    public IEnumerable<IRegressionModel> Models { get { return models; } }

    private readonly IList<double> weights;
    /// <summary>The weight of each component model, in the same order as <see cref="Models"/>.</summary>
    public IEnumerable<double> Weights { get { return weights; } }

    // Used by the persistence framework. It starts with empty lists, which the
    // __persistedModels / __persistedWeights setters fill when an old-format model is loaded.
    [StorableConstructor]
    private GradientBoostedTreesModel(bool deserializing)
      : base(deserializing) {
      models = new List<IRegressionModel>();
      weights = new List<double>();
    }

    // Cloning constructor. The component models are deep-cloned through the cloner, and the
    // weights are copied into a new list.
    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
      : base(original, cloner) {
      this.weights = new List<double>(original.weights);
      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
      this.isCompatibilityLoaded = original.isCompatibilityLoaded;
    }

    /// <summary>
    /// Creates a model from parallel sequences of component models and weights.
    /// </summary>
    /// <param name="models">The component regression models.</param>
    /// <param name="weights">The weight of each model; must have the same length as <paramref name="models"/>.</param>
    /// <exception cref="ArgumentException">The number of models and weights differs.</exception>
    [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")]
    internal GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
      : base(string.Empty, "Gradient boosted tree model", string.Empty) {
      this.models = new List<IRegressionModel>(models);
      this.weights = new List<double>(weights);

      if (this.models.Count != this.weights.Count)
        throw new ArgumentException("The number of models must match the number of weights.", "weights");
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostedTreesModel(this, cloner);
    }

    /// <summary>
    /// Returns, for each requested row, the weighted sum of the estimates of all component models.
    /// </summary>
    /// <param name="dataset">The dataset to evaluate.</param>
    /// <param name="rows">The row indices to evaluate; enumerated exactly once.</param>
    /// <returns>One estimated value per row, in the order of <paramref name="rows"/>.</returns>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      // Materialize the row indices once. The caller may pass a lazy enumerable (which may,
      // for instance, look up indexes in a dictionary), and every component model below
      // iterates over the rows again.
      var rowsArray = rows as int[] ?? rows.ToArray();
      if (rowsArray.Length == 0) return Enumerable.Empty<double>();

      // Allocate the target array, go over all models and add up the weighted estimation for each row.
      var res = new double[rowsArray.Length];
      for (int i = 0; i < models.Count; i++) {
        var w = weights[i];
        var m = models[i];
        int r = 0;
        foreach (var est in m.GetEstimatedValues(dataset, rowsArray)) {
          res[r++] += w * est;
        }
      }
      return res;
    }

    /// <summary>
    /// Creates a regression solution for this model on a clone of the given problem data.
    /// </summary>
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    }

  }
}
Note: See TracBrowser for help on using the repository browser.