Free cookie consent management tool by TermsFeed Policy Generator

source: branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs @ 14713

Last change on this file since 14713 was 14558, checked in by bwerth, 8 years ago

#2700 made TSNE compatible with the new pausable BasicAlgs, removed rescaling of scatterplots during alg to give it a more movie-esque feel

File size: 13.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Analysis;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.PluginInfrastructure;
34using HeuristicLab.Problems.DataAnalysis;
35
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Gradient boosted trees for regression problems (Friedman, "Greedy Function Approximation:
  /// A Gradient Boosting Machine"). The boosting steps themselves are delegated to
  /// <c>GradientBoostedTreesAlgorithmStatic</c>; this class wires up the parameters, the result
  /// display and the creation of the final solution.
  /// </summary>
  [Item("Gradient Boosted Trees (GBT)", "Gradient boosted trees algorithm. Specific implementation of gradient boosting for regression trees. Friedman, J. \"Greedy Function Approximation: A Gradient Boosting Machine\", IMS 1999 Reitz Lecture.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 125)]
  public class GradientBoostedTreesAlgorithm : BasicAlgorithm, IDataAnalysisAlgorithm<IRegressionProblem> {
    public override Type ProblemType
    {
      get { return typeof(IRegressionProblem); }
    }
    public new IRegressionProblem Problem
    {
      get { return (IRegressionProblem)base.Problem; }
      set { base.Problem = value; }
    }
    /// <summary>Pausing is not supported; <see cref="Run"/> executes all iterations in one call.</summary>
    public override bool SupportsPause
    {
      get { return false; }
    }

    #region ParameterNames
    private const string IterationsParameterName = "Iterations";
    private const string MaxSizeParameterName = "Maximum Tree Size";
    private const string NuParameterName = "Nu";
    private const string RParameterName = "R";
    private const string MParameterName = "M";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string LossFunctionParameterName = "LossFunction";
    private const string UpdateIntervalParameterName = "UpdateInterval";
    private const string CreateSolutionParameterName = "CreateSolution";
    #endregion

    #region ParameterProperties
    public IFixedValueParameter<IntValue> IterationsParameter
    {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> MaxSizeParameter
    {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxSizeParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> NuParameter
    {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> RParameter
    {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MParameter
    {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter
    {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public FixedValueParameter<BoolValue> SetSeedRandomlyParameter
    {
      get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IConstrainedValueParameter<ILossFunction> LossFunctionParameter
    {
      get { return (IConstrainedValueParameter<ILossFunction>)Parameters[LossFunctionParameterName]; }
    }
    public IFixedValueParameter<IntValue> UpdateIntervalParameter
    {
      get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CreateSolutionParameter
    {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion

    #region Properties
    /// <summary>Number of boosting iterations.</summary>
    public int Iterations
    {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }
    /// <summary>Seed of the pseudo random number generator (overwritten when <see cref="SetSeedRandomly"/> is true).</summary>
    public int Seed
    {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>If true, <see cref="Seed"/> is replaced by a random value at the start of each run.</summary>
    public bool SetSeedRandomly
    {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    /// <summary>Maximal size of the tree learned in each step.</summary>
    public int MaxSize
    {
      get { return MaxSizeParameter.Value.Value; }
      set { MaxSizeParameter.Value.Value = value; }
    }
    /// <summary>Learning rate (step size of the gradient update).</summary>
    public double Nu
    {
      get { return NuParameter.Value.Value; }
      set { NuParameter.Value.Value = value; }
    }
    /// <summary>Ratio of training rows selected randomly in each step.</summary>
    public double R
    {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }
    /// <summary>Ratio of variables selected randomly in each step.</summary>
    public double M
    {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }
    /// <summary>Whether a solution is added to the results at the end of the run.</summary>
    public bool CreateSolution
    {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected GradientBoostedTreesAlgorithm(bool deserializing) : base(deserializing) { }

    protected GradientBoostedTreesAlgorithm(GradientBoostedTreesAlgorithm original, Cloner cloner)
      : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostedTreesAlgorithm(this, cloner);
    }

    /// <summary>
    /// Creates the algorithm with an empty default regression problem, default parameter values
    /// and the squared error loss selected.
    /// </summary>
    public GradientBoostedTreesAlgorithm() {
      Problem = new RegressionProblem(); // default problem

      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "Number of iterations (set as high as possible, adjust in combination with nu, when increasing iterations also decrease nu)", new IntValue(1000)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxSizeParameterName, "Maximal size of the tree learned in each step (prefer smaller sizes if possible)", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "Ratio of training rows selected randomly in each step (0 < R <= 1)", new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName, "Ratio of variables selected randomly in each step (0 < M <= 1)", new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(NuParameterName, "Learning rate nu (step size for the gradient update, should be small 0 < nu < 0.1)", new DoubleValue(0.002)));
      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "", new IntValue(100)));
      Parameters[UpdateIntervalParameterName].Hidden = true;
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName, "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;

      Parameters.Add(CreateLossFunctionParameter());
      LossFunctionParameter.Value = GetDefaultLossFunction(); // squared error loss is the default
    }

    /// <summary>
    /// Builds the loss-function parameter whose valid values are all <see cref="ILossFunction"/>
    /// instances provided by the plugin infrastructure.
    /// </summary>
    private static ConstrainedValueParameter<ILossFunction> CreateLossFunctionParameter() {
      var lossFunctions = ApplicationManager.Manager.GetInstances<ILossFunction>();
      return new ConstrainedValueParameter<ILossFunction>(LossFunctionParameterName, "The loss function", new ItemSet<ILossFunction>(lossFunctions));
    }

    /// <summary>
    /// Returns the first valid loss function whose name contains "Squared" (squared error loss).
    /// Throws <see cref="InvalidOperationException"/> if no such loss function is available.
    /// </summary>
    private ILossFunction GetDefaultLossFunction() {
      return LossFunctionParameter.ValidValues.First(f => f.ToString().Contains("Squared"));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.4
      #region Backwards compatible code, remove with 3.5
      // Older files stored the loss function as a string choice; replace that parameter with the
      // ILossFunction-typed one and try to keep the previously selected loss function.
      var lossFunctionParam = Parameters[LossFunctionParameterName] as ConstrainedValueParameter<StringValue>;
      if (lossFunctionParam != null) {
        Parameters.Remove(LossFunctionParameterName);
        var selectedValue = lossFunctionParam.Value; // to be restored below
        // the old parameter may have had no selection; avoid dereferencing null in that case
        var selectedName = selectedValue != null ? selectedValue.Value : null;

        Parameters.Add(CreateLossFunctionParameter());
        // try to restore selected value, otherwise fall back to squared error
        var selectedLossFunction =
          LossFunctionParameter.ValidValues.FirstOrDefault(f => f.ToString() == selectedName);
        LossFunctionParameter.Value = selectedLossFunction ?? GetDefaultLossFunction();
      }
      #endregion
    }

    /// <summary>
    /// Runs <see cref="Iterations"/> boosting steps on a clone of the problem data, periodically
    /// recording training and test loss, then adds variable relevance, the final test loss and
    /// (if <see cref="CreateSolution"/> is set) a solution to the results. For
    /// <c>LogisticRegressionLoss</c> a classification solution is produced, otherwise a regression solution.
    /// </summary>
    /// <param name="cancellationToken">Checked before every step; cancellation throws <see cref="OperationCanceledException"/>.</param>
    protected override void Run(CancellationToken cancellationToken) {
      // Set up the algorithm
      if (SetSeedRandomly) Seed = new System.Random().Next();

      // Set up the results display
      var iterations = new IntValue(0);
      Results.Add(new Result("Iterations", iterations));

      var table = new DataTable("Qualities");
      table.Rows.Add(new DataRow("Loss (train)"));
      table.Rows.Add(new DataRow("Loss (test)"));
      Results.Add(new Result("Qualities", table));
      var curLoss = new DoubleValue();
      Results.Add(new Result("Loss (train)", curLoss));

      // init: train on a clone of the problem data; the produced solution references this clone
      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
      var lossFunction = LossFunctionParameter.Value;
      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, lossFunction, (uint)Seed, MaxSize, R, M, Nu);

      // read the loop configuration once
      var maxIterations = Iterations;
      var updateInterval = UpdateIntervalParameter.Value.Value;

      // Loop until iteration limit reached or canceled.
      for (int i = 0; i < maxIterations; i++) {
        cancellationToken.ThrowIfCancellationRequested();

        GradientBoostedTreesAlgorithmStatic.MakeStep(state);

        // intermediate results; an interval of 0 disables them (i % 0 would throw DivideByZeroException)
        if (updateInterval != 0 && i % updateInterval == 0) {
          curLoss.Value = state.GetTrainLoss();
          table.Rows["Loss (train)"].Values.Add(curLoss.Value);
          table.Rows["Loss (test)"].Values.Add(state.GetTestLoss());
          iterations.Value = i + 1; // number of completed steps, consistent with the final value below
        }
      }

      // final results
      iterations.Value = maxIterations;
      curLoss.Value = state.GetTrainLoss();
      var testLoss = state.GetTestLoss();
      table.Rows["Loss (train)"].Values.Add(curLoss.Value);
      table.Rows["Loss (test)"].Values.Add(testLoss);

      // produce variable relevance
      var orderedImpacts = state.GetVariableRelevance().Select(t => new { name = t.Key, impact = t.Value }).ToList();

      var impacts = new DoubleMatrix();
      var matrix = (IStringConvertibleMatrix)impacts;
      matrix.Rows = orderedImpacts.Count;
      matrix.RowNames = orderedImpacts.Select(x => x.name);
      matrix.Columns = 1;
      matrix.ColumnNames = new string[] { "Relative variable relevance" };

      // values are written through the string interface, which rounds them to two decimals
      int rowIdx = 0;
      foreach (var p in orderedImpacts) {
        matrix.SetValue(string.Format("{0:N2}", p.impact), rowIdx++, 0);
      }

      Results.Add(new Result("Variable relevance", impacts));
      Results.Add(new Result("Loss (test)", new DoubleValue(testLoss)));

      // produce solution
      if (CreateSolution) {
        var model = state.GetModel();

        // for logistic regression we produce a classification solution
        if (lossFunction is LogisticRegressionLoss) {
          var classificationModel = new DiscriminantFunctionClassificationModel(model,
            new AccuracyMaximizationThresholdCalculator());
          var classificationProblemData = new ClassificationProblemData(problemData.Dataset,
            problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations);
          classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);

          var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData);
          Results.Add(new Result("Solution", classificationSolution));
        } else {
          // otherwise we produce a regression solution
          Results.Add(new Result("Solution", new GradientBoostedTreesSolution(model, problemData)));
        }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.