source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs @ 16565

Last change on this file since 16565 was 16565, checked in by gkronber, 11 months ago

#2520: merged changes from PersistenceOverhaul branch (r16451:16564) into trunk

File size: 4.9 KB
RevLine 
[12590]1#region License Information
2/* HeuristicLab
[16565]3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[12590]4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
[12332]24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
[16565]28using HEAL.Attic;
[12332]29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Gradient boosted trees model. This is essentially a collection of weighted regression models:
  /// the estimate for a row is the sum over i of <c>weights[i] * models[i].Estimate(row)</c>.
  /// </summary>
  [StorableType("4EC1B359-D145-434C-A373-3EDD764D2D63")]
  [Item("Gradient boosted trees model", "")]
  public sealed class GradientBoostedTreesModel : RegressionModel, IGradientBoostedTreesModel {
    // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models
    #region Backwards compatible code, remove with 3.5
    // Only set to true if the model is deserialized from the old format. This makes sure the
    // models/weights are serialized again if they were loaded from the old format; otherwise the
    // getters below return null and nothing is persisted under these names.
    private bool isCompatibilityLoaded = false;

    [Storable(Name = "models")]
    private IList<IRegressionModel> __persistedModels {
      set {
        // The getter yields null for models not loaded from the old format, so null may be read
        // back here. It must neither mark the model as compatibility-loaded nor throw.
        if (value == null) return;
        this.isCompatibilityLoaded = true;
        this.models.Clear();
        foreach (var m in value) this.models.Add(m);
      }
      get { return this.isCompatibilityLoaded ? models : null; }
    }

    [Storable(Name = "weights")]
    private IList<double> __persistedWeights {
      set {
        // See __persistedModels: a persisted null must be ignored.
        if (value == null) return;
        this.isCompatibilityLoaded = true;
        this.weights.Clear();
        foreach (var w in value) this.weights.Add(w);
      }
      get { return this.isCompatibilityLoaded ? weights : null; }
    }
    #endregion

    /// <summary>
    /// The distinct names of all variables used by any of the component models, sorted by name.
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); }
    }

    private readonly IList<IRegressionModel> models;
    /// <summary>The component models; <c>Models</c> and <c>Weights</c> are index-aligned.</summary>
    public IEnumerable<IRegressionModel> Models { get { return models; } }

    private readonly IList<double> weights;
    /// <summary>The weight of each component model, in the same order as <c>Models</c>.</summary>
    public IEnumerable<double> Weights { get { return weights; } }

    [StorableConstructor]
    private GradientBoostedTreesModel(StorableConstructorFlag _) : base(_) {
      // The lists must exist before the compatibility setters above may fill them.
      models = new List<IRegressionModel>();
      weights = new List<double>();
    }

    private GradientBoostedTreesModel(GradientBoostedTreesModel original, Cloner cloner)
      : base(original, cloner) {
      this.weights = new List<double>(original.weights);
      this.models = new List<IRegressionModel>(original.models.Select(m => cloner.Clone(m)));
      this.isCompatibilityLoaded = original.isCompatibilityLoaded;
    }

    /// <summary>
    /// Creates a model from index-aligned sequences of component models and weights.
    /// </summary>
    /// <param name="models">The component models.</param>
    /// <param name="weights">The weight for each component model.</param>
    /// <exception cref="ArgumentNullException"><paramref name="models"/> or <paramref name="weights"/> is null.</exception>
    /// <exception cref="ArgumentException">The two sequences have different lengths.</exception>
    [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")]
    internal GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
      : base(string.Empty, "Gradient boosted tree model", string.Empty) {
      if (models == null) throw new ArgumentNullException("models");
      if (weights == null) throw new ArgumentNullException("weights");
      this.models = new List<IRegressionModel>(models);
      this.weights = new List<double>(weights);

      if (this.models.Count != this.weights.Count)
        throw new ArgumentException("The number of models must match the number of weights.");
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostedTreesModel(this, cloner);
    }

    /// <summary>
    /// Returns, for each requested row, the weighted sum of the estimates of all component models.
    /// </summary>
    /// <param name="dataset">The dataset to evaluate the component models on.</param>
    /// <param name="rows">The rows to evaluate. The sequence is enumerated exactly once.</param>
    /// <returns>One value per row, in the order of <paramref name="rows"/>; empty if no rows are given.</returns>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      // Materialize the rows once so that a lazy sequence (which may e.g. look up indexes in a
      // dictionary) is not re-evaluated for the count and again for every component model.
      var rowArray = rows.ToArray();
      if (rowArray.Length == 0) return Enumerable.Empty<double>();

      // Accumulate the weighted estimates of all models row by row.
      var res = new double[rowArray.Length];
      for (int i = 0; i < models.Count; i++) {
        var w = weights[i];
        var m = models[i];
        int r = 0;
        foreach (var est in m.GetEstimatedValues(dataset, rowArray)) {
          res[r++] += w * est;
        }
      }
      return res;
    }

    /// <summary>
    /// Creates a <c>RegressionSolution</c> for this model using a clone of <paramref name="problemData"/>.
    /// </summary>
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    }
  }
}
Note: See TracBrowser for help on using the repository browser.