Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2434_crossvalidation/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs @ 15728

Last change on this file since 15728 was 14029, checked in by gkronber, 8 years ago

#2434: merged trunk changes r12934:14026 from trunk to branch

File size: 4.9 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 * and the BEACON Center for the Study of Evolution in Action.
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
22
23using System;
24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// A collection of weighted regression models. The estimate for a row is the sum,
  /// over all contained models, of the model's weight times the model's estimate.
  /// </summary>
  [StorableClass]
  [Item("Gradient boosted tree model", "")]
  public sealed class GradientBoostedTreesModel : RegressionModel, IGradientBoostedTreesModel {
    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
    #region Backwards compatible code, remove with 3.5
    // Only set to true if the model is deserialized from the old format. Needed to make sure
    // that the information is serialized again if it was loaded from the old format.
    private bool isCompatibilityLoaded = false;

    /// <summary>
    /// Persistence hook for the old storage format (stored under the name "models").
    /// Reading yields the model list only for instances that were loaded from the old format,
    /// otherwise null.
    /// </summary>
    [Storable(Name = "models")]
    private IList<IRegressionModel> __persistedModels {
      set {
        // The getter below yields null for instances that were not loaded from the old format,
        // so a null value is tolerated here instead of being dereferenced, and it does not
        // mark the instance as compatibility-loaded.
        // NOTE(review): whether the persistence framework actually passes null here is not
        // visible in this file - confirm.
        if (value == null) return;
        this.isCompatibilityLoaded = true;
        this.models.Clear();
        foreach (var m in value) this.models.Add(m);
      }
      get { if (this.isCompatibilityLoaded) return models; else return null; }
    }

    /// <summary>
    /// Persistence hook for the old storage format (stored under the name "weights").
    /// Reading yields the weight list only for instances that were loaded from the old format,
    /// otherwise null.
    /// </summary>
    [Storable(Name = "weights")]
    private IList<double> __persistedWeights {
      set {
        // Same reasoning as for __persistedModels: the getter can yield null, so accept it.
        if (value == null) return;
        this.isCompatibilityLoaded = true;
        this.weights.Clear();
        foreach (var w in value) this.weights.Add(w);
      }
      get { if (this.isCompatibilityLoaded) return weights; else return null; }
    }
    #endregion

    /// <summary>
    /// The distinct variable names used by any of the contained models, in ascending order.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
    }

    private readonly IList<IRegressionModel> models;
    /// <summary>The contained regression models; element i is paired with element i of <see cref="Weights"/>.</summary>
    public IEnumerable<IRegressionModel> Models { get { return models; } }

    private readonly IList<double> weights;
    /// <summary>The weights applied to the estimates of the corresponding entries of <see cref="Models"/>.</summary>
    public IEnumerable<double> Weights { get { return weights; } }

    /// <summary>
    /// Deserialization constructor. Starts with empty lists which the backwards-compatibility
    /// setters fill when an old-format model is loaded.
    /// </summary>
    [StorableConstructor]
    private GradientBoostedTreesModel(bool deserializing)
      : base(deserializing) {
      models = new List<IRegressionModel>();
      weights = new List<double>();
    }

    /// <summary>
    /// Cloning constructor: copies the weights, clones every contained model through
    /// <paramref name="cloner"/> and carries over the compatibility flag.
    /// </summary>
    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
      : base(original, cloner) {
      this.weights = new List<double>(original.weights);
      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
      this.isCompatibilityLoaded = original.isCompatibilityLoaded;
    }

    /// <summary>
    /// Creates a model from the given models and their weights.
    /// </summary>
    /// <param name="models">The regression models to combine.</param>
    /// <param name="weights">One weight per model, in the same order as <paramref name="models"/>.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="models"/> or <paramref name="weights"/> is null.</exception>
    /// <exception cref="ArgumentException">If the number of models differs from the number of weights.</exception>
    [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")]
    internal GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
      : base(string.Empty, "Gradient boosted tree model", string.Empty) {
      if (models == null) throw new ArgumentNullException("models");
      if (weights == null) throw new ArgumentNullException("weights");

      this.models = new List<IRegressionModel>(models);
      this.weights = new List<double>(weights);

      if (this.models.Count != this.weights.Count)
        throw new ArgumentException("The number of models must match the number of weights.");
    }

    /// <summary>Creates a deep clone of this model using the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostedTreesModel(this, cloner);
    }

    /// <summary>
    /// Calculates the weighted sum of the estimates of all contained models for each row.
    /// </summary>
    /// <param name="dataset">The dataset handed on to every contained model.</param>
    /// <param name="rows">The rows to estimate; enumerated at most once by this method.</param>
    /// <returns>One value per row, in the order of <paramref name="rows"/>; empty if there are no rows.</returns>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      // Materialize the rows exactly once. Checking Any(), calling Count() and then handing the
      // same lazy enumerable to every model would re-evaluate the source sequence each time.
      var rowList = rows as ICollection<int> ?? rows.ToList();
      if (rowList.Count == 0) return Enumerable.Empty<double>();

      // allocate target array, go over all models and add up the weighted estimation for each row
      var res = new double[rowList.Count];
      for (int i = 0; i < models.Count; i++) {
        var w = weights[i];
        var m = models[i];
        int r = 0;
        foreach (var est in m.GetEstimatedValues(dataset, rowList)) {
          res[r++] += w * est;
        }
      }
      return res;
    }

    /// <summary>
    /// Creates a regression solution for this model on a clone of <paramref name="problemData"/>.
    /// </summary>
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    }

  }
}
Note: See TracBrowser for help on using the repository browser.