Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs @ 14345

Last change on this file since 14345 was 14345, checked in by gkronber, 7 years ago

#2690: implemented methods to generate symbolic expression tree solutions for decision tree models (random forest and gradient boosted) as well as views which make it possible to inspect each of the individual trees in a GBT and RF solution

File size: 13.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Analysis;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.PluginInfrastructure;
34using HeuristicLab.Problems.DataAnalysis;
35
36namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// Gradient boosted trees (GBT) for regression problems. The boosting steps themselves are
/// delegated to <c>GradientBoostedTreesAlgorithmStatic</c>; this class owns the parameters,
/// publishes progress into <c>Results</c>, and builds the final solution.
/// </summary>
[Item("Gradient Boosted Trees (GBT)", "Gradient boosted trees algorithm. Specific implementation of gradient boosting for regression trees. Friedman, J. \"Greedy Function Approximation: A Gradient Boosting Machine\", IMS 1999 Reitz Lecture.")]
[StorableClass]
[Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 125)]
public class GradientBoostedTreesAlgorithm : BasicAlgorithm {
  /// <summary>Only regression problems are accepted by this algorithm.</summary>
  public override Type ProblemType {
    get { return typeof(IRegressionProblem); }
  }
  /// <summary>Typed view of <c>base.Problem</c> as a regression problem.</summary>
  public new IRegressionProblem Problem {
    get { return (IRegressionProblem)base.Problem; }
    set { base.Problem = value; }
  }

  #region ParameterNames
  private const string IterationsParameterName = "Iterations";
  private const string MaxSizeParameterName = "Maximum Tree Size";
  private const string NuParameterName = "Nu";
  private const string RParameterName = "R";
  private const string MParameterName = "M";
  private const string SeedParameterName = "Seed";
  private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
  private const string LossFunctionParameterName = "LossFunction";
  private const string UpdateIntervalParameterName = "UpdateInterval";
  private const string CreateSolutionParameterName = "CreateSolution";
  #endregion

  #region ParameterProperties
  public IFixedValueParameter<IntValue> IterationsParameter {
    get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
  }
  public IFixedValueParameter<IntValue> MaxSizeParameter {
    get { return (IFixedValueParameter<IntValue>)Parameters[MaxSizeParameterName]; }
  }
  public IFixedValueParameter<DoubleValue> NuParameter {
    get { return (IFixedValueParameter<DoubleValue>)Parameters[NuParameterName]; }
  }
  public IFixedValueParameter<DoubleValue> RParameter {
    get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
  }
  public IFixedValueParameter<DoubleValue> MParameter {
    get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
  }
  public IFixedValueParameter<IntValue> SeedParameter {
    get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
  }
  public FixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
  }
  public IConstrainedValueParameter<ILossFunction> LossFunctionParameter {
    get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; }
  }
  public IFixedValueParameter<IntValue> UpdateIntervalParameter {
    get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
  }
  public IFixedValueParameter<BoolValue> CreateSolutionParameter {
    get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
  }
  #endregion

  #region Properties
  /// <summary>Number of boosting steps performed by <see cref="Run"/>.</summary>
  public int Iterations {
    get { return IterationsParameter.Value.Value; }
    set { IterationsParameter.Value.Value = value; }
  }
  /// <summary>Seed passed to the GBM state; overwritten at run start when <see cref="SetSeedRandomly"/> is true.</summary>
  public int Seed {
    get { return SeedParameter.Value.Value; }
    set { SeedParameter.Value.Value = value; }
  }
  /// <summary>When true, <see cref="Run"/> replaces <see cref="Seed"/> with a fresh random value.</summary>
  public bool SetSeedRandomly {
    get { return SetSeedRandomlyParameter.Value.Value; }
    set { SetSeedRandomlyParameter.Value.Value = value; }
  }
  /// <summary>Maximal size of the tree learned in each step.</summary>
  public int MaxSize {
    get { return MaxSizeParameter.Value.Value; }
    set { MaxSizeParameter.Value.Value = value; }
  }
  /// <summary>Learning rate (step size of the gradient update).</summary>
  public double Nu {
    get { return NuParameter.Value.Value; }
    set { NuParameter.Value.Value = value; }
  }
  /// <summary>Ratio of training rows selected randomly in each step (0 &lt; R &lt;= 1).</summary>
  public double R {
    get { return RParameter.Value.Value; }
    set { RParameter.Value.Value = value; }
  }
  /// <summary>Ratio of variables selected randomly in each step (0 &lt; M &lt;= 1).</summary>
  public double M {
    get { return MParameter.Value.Value; }
    set { MParameter.Value.Value = value; }
  }
  /// <summary>When true, a "Solution" result is produced at the end of <see cref="Run"/>.</summary>
  public bool CreateSolution {
    get { return CreateSolutionParameter.Value.Value; }
    set { CreateSolutionParameter.Value.Value = value; }
  }
  #endregion

  [StorableConstructor]
  protected GradientBoostedTreesAlgorithm(bool deserializing) : base(deserializing) { }

  protected GradientBoostedTreesAlgorithm(GradientBoostedTreesAlgorithm original, Cloner cloner)
    : base(original, cloner) {
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new GradientBoostedTreesAlgorithm(this, cloner);
  }

  /// <summary>
  /// Creates the algorithm with a default <see cref="RegressionProblem"/> and default parameter
  /// values. Squared error is selected as the default loss function.
  /// </summary>
  public GradientBoostedTreesAlgorithm() {
    Problem = new RegressionProblem(); // default problem

    Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "Number of iterations (set as high as possible, adjust in combination with nu, when increasing iterations also decrease nu)", new IntValue(1000)));
    Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
    Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
    Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, "Maximal size of the tree learned in each step (prefer smaller sizes if possible)", new IntValue(10)));
    Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "Ratio of training rows selected randomly in each step (0 < R <= 1)", new DoubleValue(0.5)));
    Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "Ratio of variables selected randomly in each step (0 < M <= 1)", new DoubleValue(0.5)));
    Parameters.Add(new FixedValueParameter<DoubleValue>(NuParameterName, "Learning rate nu (step size for the gradient update, should be small 0 < nu < 0.1)", new DoubleValue(0.002)));
    Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "", new IntValue(100)));
    Parameters[UpdateIntervalParameterName].Hidden = true;
    Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
    Parameters[CreateSolutionParameterName].Hidden = true;

    var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
    Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));
    LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // squared error loss is the default
  }

  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
    // BackwardsCompatibility3.4
    #region Backwards compatible code, remove with 3.5
    // Older files stored the loss function as a ConstrainedValueParameter<StringValue>;
    // replace it with the typed parameter and try to keep the user's selection.
    var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>;
    if (lossFunctionParam != null) {
      Parameters.Remove(LossFunctionParameterName);
      var selectedValue = lossFunctionParam.Value; // may be null if nothing was selected

      var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
      Parameters.Add(new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions)));

      // try to restore the selected value by matching names; guard against a missing selection
      ILossFunction selectedLossFunction = null;
      if (selectedValue != null) {
        selectedLossFunction = LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedValue.Value);
      }
      if (selectedLossFunction != null) {
        LossFunctionParameter.Value = selectedLossFunction;
      } else {
        LossFunctionParameter.Value = LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared")); // default: SE
      }
    }
    #endregion
  }

  /// <summary>
  /// Runs <see cref="Iterations"/> boosting steps on a clone of the problem data.
  /// Train and test losses are appended to the "Qualities" table every <c>UpdateInterval</c>
  /// iterations and once more at the end. Afterwards a "Variable relevance" matrix, the final
  /// "Loss (test)" and (if <see cref="CreateSolution"/>) a "Solution" result are added. When the
  /// loss function is <c>LogisticRegressionLoss</c> the solution is a discriminant-function
  /// classification solution, otherwise a <c>GradientBoostedTreesSolution</c>.
  /// </summary>
  /// <param name="cancellationToken">Checked before every boosting step.</param>
  protected override void Run(CancellationToken cancellationToken) {
    // Set up the algorithm
    if (SetSeedRandomly) Seed = new System.Random().Next();

    // Set up the results display
    var iterations = new IntValue(0);
    Results.Add(new Result("Iterations", iterations));

    var table = new DataTable("Qualities");
    table.Rows.Add(new DataRow("Loss (train)"));
    table.Rows.Add(new DataRow("Loss (test)"));
    Results.Add(new Result("Qualities", table));
    var curLoss = new DoubleValue();
    Results.Add(new Result("Loss (train)", curLoss));

    // init: work on a clone so the problem's data is not affected by the run
    var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    var lossFunction = LossFunctionParameter.Value;
    var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);

    // UpdateInterval is a hidden parameter. A value of 0 would make the modulo below throw
    // DivideByZeroException mid-run; treat it as "report only the final losses" instead.
    var updateInterval = UpdateIntervalParameter.Value.Value;
    var reportIntermediate = updateInterval != 0;

    // Loop until iteration limit reached or canceled.
    for (int i = 0; i < Iterations; i++) {
      cancellationToken.ThrowIfCancellationRequested();

      GradientBoostedTreesAlgorithmStatic.MakeStep(state);

      // iteration results
      if (reportIntermediate && i % updateInterval == 0) {
        curLoss.Value = state.GetTrainLoss();
        table.Rows["Loss (train)"].Values.Add(curLoss.Value);
        table.Rows["Loss (test)"].Values.Add(state.GetTestLoss());
        iterations.Value = i;
      }
    }

    // final results; the state no longer changes, so the test loss is evaluated only once
    iterations.Value = Iterations;
    curLoss.Value = state.GetTrainLoss();
    var finalTestLoss = state.GetTestLoss();
    table.Rows["Loss (train)"].Values.Add(curLoss.Value);
    table.Rows["Loss (test)"].Values.Add(finalTestLoss);

    // produce variable relevance
    var orderedImpacts = state.GetVariableRelevance().Select(t => new { name = t.Key, impact = t.Value }).ToList();

    var impacts = new DoubleMatrix();
    var matrix = (IStringConvertibleMatrix)impacts;
    matrix.Rows = orderedImpacts.Count;
    matrix.RowNames = orderedImpacts.Select(x => x.name);
    matrix.Columns = 1;
    matrix.ColumnNames = new string[] { "Relative variable relevance" };

    // values are written through the string interface, which rounds them to two decimals
    int rowIdx = 0;
    foreach (var p in orderedImpacts) {
      matrix.SetValue(string.Format("{0:N2}", p.impact), rowIdx++, 0);
    }

    Results.Add(new Result("Variable relevance", impacts));
    Results.Add(new Result("Loss (test)", new DoubleValue(finalTestLoss)));

    // produce solution
    if (CreateSolution) {
      var model = state.GetModel();

      // for logistic regression we produce a classification solution
      if (lossFunction is LogisticRegressionLoss) {
        var classificationModel = new DiscriminantFunctionClassificationModel(model,
          new AccuracyMaximizationThresholdCalculator());
        var classificationProblemData = new ClassificationProblemData(problemData.Dataset,
          problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations);
        classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);

        var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData);
        Results.Add(new Result("Solution", classificationSolution));
      } else {
        // otherwise we produce a regression solution
        Results.Add(new Result("Solution", new GradientBoostedTreesSolution(model, problemData)));
      }
    }
  }
}
276}
Note: See TracBrowser for help on using the repository browser.