Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GBM/GradientBoostingRegressionAlgorithm.cs @ 13646

Last change on this file since 13646 was 13646, checked in by gkronber, 8 years ago

#1795: added a data analysis algorithm for gradient boosting for regression which uses another regression algorithm as a base learner. Currently, only squared error loss is supported.

File size: 19.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Collections.Generic;
25using System.Linq;
26using System.Threading;
27using HeuristicLab.Analysis;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Optimization;
33using HeuristicLab.Parameters;
34using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Problems.DataAnalysis.Symbolic;
37using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
38using HeuristicLab.Random;
39
namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
  /// <summary>
  /// Gradient boosting machine (GBM) for regression with an exchangeable base learner.
  /// Only squared error loss is supported. In every iteration the base learner is trained on the current
  /// residuals of a random subset of the training rows (fraction R), using a random subset of the allowed
  /// input variables (fraction M). Its predictions, scaled by the learning rate Nu, are added to the boosted prediction.
  /// </summary>
  [Item("Gradient Boosting Machine Regression (GBM)",
    "Gradient boosting for any regression base learner (e.g. MCTS symbolic regression)")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 350)]
  public class GradientBoostingRegressionAlgorithm : BasicAlgorithm {
    public override Type ProblemType {
      get { return typeof(IRegressionProblem); }
    }

    public new IRegressionProblem Problem {
      get { return (IRegressionProblem)base.Problem; }
      set { base.Problem = value; }
    }

    #region ParameterNames

    private const string IterationsParameterName = "Iterations";
    private const string NuParameterName = "Nu";
    private const string MParameterName = "M";
    private const string RParameterName = "R";
    private const string RegressionAlgorithmParameterName = "RegressionAlgorithm";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string CreateSolutionParameterName = "CreateSolution";
    private const string RegressionAlgorithmSolutionResultParameterName = "RegressionAlgorithmResult";

    #endregion

    #region ParameterProperties

    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }

    public IFixedValueParameter<DoubleValue> NuParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[NuParameterName]; }
    }

    public IFixedValueParameter<DoubleValue> RParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[RParameterName]; }
    }

    public IFixedValueParameter<DoubleValue> MParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MParameterName]; }
    }

    // regression algorithms are currently: DataAnalysisAlgorithms, BasicAlgorithms and engine algorithms with no common interface
    public IConstrainedValueParameter<IAlgorithm> RegressionAlgorithmParameter {
      get { return (IConstrainedValueParameter<IAlgorithm>)Parameters[RegressionAlgorithmParameterName]; }
    }

    public IFixedValueParameter<StringValue> RegressionAlgorithmSolutionResultParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[RegressionAlgorithmSolutionResultParameterName]; }
    }

    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }

    public FixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }

    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }

    #endregion

    #region Properties

    public int Iterations {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }

    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }

    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }

    public double Nu {
      get { return NuParameter.Value.Value; }
      set { NuParameter.Value.Value = value; }
    }

    public double R {
      get { return RParameter.Value.Value; }
      set { RParameter.Value.Value = value; }
    }

    public double M {
      get { return MParameter.Value.Value; }
      set { MParameter.Value.Value = value; }
    }

    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }

    public IAlgorithm RegressionAlgorithm {
      get { return RegressionAlgorithmParameter.Value; }
    }

    public string RegressionAlgorithmResult {
      get { return RegressionAlgorithmSolutionResultParameter.Value.Value; }
      set { RegressionAlgorithmSolutionResultParameter.Value.Value = value; }
    }

    #endregion

    [StorableConstructor]
    protected GradientBoostingRegressionAlgorithm(bool deserializing)
      : base(deserializing) {
    }

    protected GradientBoostingRegressionAlgorithm(GradientBoostingRegressionAlgorithm original, Cloner cloner)
      : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new GradientBoostingRegressionAlgorithm(this, cloner);
    }

    /// <summary>
    /// Creates the algorithm with a default regression problem. The base learner can be chosen from
    /// linear regression, random forest, k-nearest neighbours and MCTS symbolic regression (the default).
    /// </summary>
    public GradientBoostingRegressionAlgorithm() {
      Problem = new RegressionProblem(); // default problem
      var mctsSymbReg = new MctsSymbolicRegressionAlgorithm();
      // TODO: standard symbolic regression GP (CreateSGP) is not offered as a base learner yet
      var regressionAlgs = new ItemSet<IAlgorithm>(new IAlgorithm[] {
        new LinearRegression(), new RandomForestRegression(), new NearestNeighbourRegression(),
        mctsSymbReg
      });
      foreach (var alg in regressionAlgs) alg.Prepare();

      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName,
        "Number of iterations", new IntValue(100)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName,
        "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName,
        "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(NuParameterName,
        "The learning rate nu when updating predictions in GBM (0 < nu <= 1)", new DoubleValue(0.5)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName,
        "The fraction of rows that are sampled randomly for the base learner in each iteration (0 < r <= 1)",
        new DoubleValue(1)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MParameterName,
        "The fraction of variables that are sampled randomly for the base learner in each iteration (0 < m <= 1)",
        new DoubleValue(0.5)));
      Parameters.Add(new ConstrainedValueParameter<IAlgorithm>(RegressionAlgorithmParameterName,
        "The regression algorithm to use as a base learner", regressionAlgs, mctsSymbReg));
      Parameters.Add(new FixedValueParameter<StringValue>(RegressionAlgorithmSolutionResultParameterName,
        "The name of the solution produced by the regression algorithm", new StringValue("Solution")));
      Parameters[RegressionAlgorithmSolutionResultParameterName].Hidden = true;
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
        "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;
    }

    /// <summary>
    /// Runs the boosting loop.
    /// </summary>
    /// <exception cref="ArgumentException">If Nu, R or M lies outside of (0, 1].</exception>
    /// <exception cref="NotSupportedException">If the base learner cannot be given a regression problem.</exception>
    protected override void Run(CancellationToken cancellationToken) {
      var nu = Nu;
      var rowFraction = R;
      var variableFraction = M;
      // enforce the ranges stated in the parameter descriptions (written negated so that NaN is rejected too)
      if (!(nu > 0.0 && nu <= 1.0))
        throw new ArgumentException("The learning rate Nu must be in the range (0, 1].");
      if (!(rowFraction > 0.0 && rowFraction <= 1.0))
        throw new ArgumentException("The row sampling fraction R must be in the range (0, 1].");
      if (!(variableFraction > 0.0 && variableFraction <= 1.0))
        throw new ArgumentException("The variable sampling fraction M must be in the range (0, 1].");

      // Set up the algorithm
      if (SetSeedRandomly) Seed = new System.Random().Next();
      var rand = new MersenneTwister((uint)Seed);

      // Set up the results display
      var iterations = new IntValue(0);
      Results.Add(new Result("Iterations", iterations));

      // NOTE: despite their names, the "Loss" results hold the squared Pearson correlation (R²)
      // between the boosted prediction and the original target (see below), i.e. larger is better.
      var table = new DataTable("Qualities");
      table.Rows.Add(new DataRow("Loss (train)"));
      table.Rows.Add(new DataRow("Loss (test)"));
      Results.Add(new Result("Qualities", table));
      var curLoss = new DoubleValue();
      var curTestLoss = new DoubleValue();
      Results.Add(new Result("Loss (train)", curLoss));
      Results.Add(new Result("Loss (test)", curTestLoss));
      var runCollection = new RunCollection();
      Results.Add(new Result("Runs", runCollection));

      // init
      var problemData = Problem.ProblemData;
      var dataset = problemData.Dataset;
      var targetVarName = problemData.TargetVariable;
      var inputVariables = problemData.AllowedInputVariables.ToArray();
      var trainingRows = problemData.TrainingIndices.ToArray();
      var testRows = problemData.TestIndices.ToArray();

      var y = dataset.GetDoubleValues(targetVarName, trainingRows).ToArray();
      var yTest = dataset.GetDoubleValues(targetVarName, testRows).ToArray();
      // boosted predictions, starting at 0 for every row
      var yPred = new double[trainingRows.Length];
      var yPredTest = new double[testRows.Length];
      // current targets for the base learner; for squared error loss these are the residuals y - yPred
      var curY = (double[])y.Clone();
      var curYTest = (double[])yTest.Clone();

      var mVars = (int)Math.Ceiling(variableFraction * inputVariables.Length);
      var rRows = (int)Math.Ceiling(rowFraction * trainingRows.Length);
      var maxIterations = Iterations;
      var alg = RegressionAlgorithm;
      var models = new List<IRegressionModel>();
      try {
        // Loop until iteration limit reached or canceled.
        for (int i = 0; i < maxIterations; i++) {
          cancellationToken.ThrowIfCancellationRequested();

          // positions (0-based, within the training rows) of the rows used by the base learner in this iteration
          var sampledPositions = Enumerable.Range(0, trainingRows.Length)
            .SampleRandomWithoutRepetition(rand, rRows)
            .ToArray();
          // materialized so the sampled subset stays fixed even if the problem data enumerates it several times
          var sampledInputs = inputVariables.SampleRandomWithoutRepetition(rand, mVars).ToArray();

          var iterationDataset = CreateIterationDataset(dataset, inputVariables, targetVarName,
            trainingRows, sampledPositions, curY, testRows, curYTest);
          var iterationProblemData = new RegressionProblemData(iterationDataset, sampledInputs, targetVarName);
          // the iteration dataset holds the sampled training rows first, followed by all test rows
          iterationProblemData.TrainingPartition.Start = 0;
          iterationProblemData.TrainingPartition.End = sampledPositions.Length;
          iterationProblemData.TestPartition.Start = sampledPositions.Length;
          iterationProblemData.TestPartition.End = sampledPositions.Length + testRows.Length;

          if (!TrySetProblemData(alg, iterationProblemData))
            throw new NotSupportedException("The algorithm cannot be used with GBM.");

          IRegressionModel model;
          IRun run;
          // try to find a model. The algorithm might fail to produce a model. In this case we just retry until the iterations are exhausted
          if (TryExecute(alg, RegressionAlgorithmResult, cancellationToken, out model, out run)) {
            // update predictions for training and test
            // and the new targets (in the case of squared error loss we simply use negative residuals)
            UpdatePredictions(model.GetEstimatedValues(dataset, trainingRows), nu, y, yPred, curY);
            UpdatePredictions(model.GetEstimatedValues(dataset, testRows), nu, yTest, yPredTest, curYTest);

            // determine quality; train and test need separate error states
            OnlineCalculatorError trainError, testError;
            var trainR = OnlinePearsonsRCalculator.Calculate(yPred, y, out trainError);
            var testR = OnlinePearsonsRCalculator.Calculate(yPredTest, yTest, out testError);

            // iteration results
            curLoss.Value = trainError == OnlineCalculatorError.None ? trainR * trainR : 0.0;
            curTestLoss.Value = testError == OnlineCalculatorError.None ? testR * testR : 0.0;

            models.Add(model);
          }

          // a failed execution may not have produced a run
          if (run != null) runCollection.Add(run);
          table.Rows["Loss (train)"].Values.Add(curLoss.Value);
          table.Rows["Loss (test)"].Values.Add(curTestLoss.Value);
          iterations.Value = i + 1;
        }

        // produce solution (only possible if at least one base model has been accepted)
        if (CreateSolution && models.Count > 0) {
          // when all our models are symbolic models we can easily combine them to a single model
          if (models.All(mdl => mdl is ISymbolicRegressionModel)) {
            Results.Add(new Result("Solution", CreateSymbolicSolution(models, nu, (IRegressionProblemData)problemData.Clone())));
          }
          // just produce an ensemble solution for now (TODO: correct scaling or linear regression for ensemble model weights)
          Results.Add(new Result("EnsembleSolution", new RegressionEnsembleSolution(models, (IRegressionProblemData)problemData.Clone())));
        }
      } finally {
        // reset everything
        alg.Prepare(true);
      }
    }

    /// <summary>
    /// Adds <paramref name="nu"/> times each base-model prediction to <paramref name="boostedPrediction"/> and
    /// stores the new residual <c>target - boostedPrediction</c> in <paramref name="residuals"/> (all position-aligned).
    /// </summary>
    private static void UpdatePredictions(IEnumerable<double> predictions, double nu,
      double[] target, double[] boostedPrediction, double[] residuals) {
      int row = 0;
      foreach (var pred in predictions) {
        boostedPrediction[row] += nu * pred;
        residuals[row] = target[row] - boostedPrediction[row];
        row++;
      }
    }

    /// <summary>
    /// Builds the dataset for one boosting iteration. It contains the sampled training rows (in sample order)
    /// followed by all test rows. Only the allowed input variables and the target variable are copied. The target
    /// column holds the current residuals (<paramref name="curY"/> for training, <paramref name="curYTest"/> for test).
    /// </summary>
    /// <param name="sampledPositions">Positions into <paramref name="trainingRows"/> / <paramref name="curY"/>, not dataset row indices.</param>
    private static IDataset CreateIterationDataset(IDataset source, string[] inputVariables, string targetVarName,
      int[] trainingRows, int[] sampledPositions, double[] curY, int[] testRows, double[] curYTest) {
      var nRows = sampledPositions.Length + testRows.Length;
      var variableNames = new List<string>(inputVariables.Length + 1);
      var columns = new List<List<double>>(inputVariables.Length + 1);

      foreach (var name in inputVariables) {
        if (name == targetVarName) continue; // the target column is added below with the residual values
        var sourceColumn = source.GetReadOnlyDoubleValues(name); // fetched once per variable, not per row
        var column = new List<double>(nRows);
        foreach (var pos in sampledPositions) column.Add(sourceColumn[trainingRows[pos]]);
        foreach (var row in testRows) column.Add(sourceColumn[row]);
        variableNames.Add(name);
        columns.Add(column);
      }

      var targetColumn = new List<double>(nRows);
      foreach (var pos in sampledPositions) targetColumn.Add(curY[pos]);
      targetColumn.AddRange(curYTest);
      variableNames.Add(targetVarName);
      columns.Add(targetColumn);

      return new ModifiableDataset(variableNames, columns);
    }

    /// <summary>
    /// Combines symbolic regression models into a single symbolic model computing <c>nu * (model_1 + ... + model_n)</c>.
    /// The estimation limits are the smallest lower and the largest upper limit of the individual models.
    /// </summary>
    /// <exception cref="InvalidOperationException">If <paramref name="models"/> contains no symbolic regression model.</exception>
    private static ISymbolicRegressionSolution CreateSymbolicSolution(List<IRegressionModel> models, double nu, IRegressionProblemData problemData) {
      var symbModels = models.OfType<ISymbolicRegressionModel>().ToArray();
      var lowerLimit = symbModels.Min(sm => sm.LowerEstimationLimit);
      var upperLimit = symbModels.Max(sm => sm.UpperEstimationLimit);
      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
      var progRootNode = new ProgramRootSymbol().CreateTreeNode();
      var startNode = new StartSymbol().CreateTreeNode();

      var addNode = new Addition().CreateTreeNode();
      var mulNode = new Multiplication().CreateTreeNode();
      var scaleNode = (ConstantTreeNode)new Constant().CreateTreeNode(); // all models are scaled using the same nu
      scaleNode.Value = nu;

      foreach (var symbModel in symbModels) {
        var relevantPart = symbModel.SymbolicExpressionTree.Root.GetSubtree(0).GetSubtree(0); // skip root and start
        addNode.AddSubtree((ISymbolicExpressionTreeNode)relevantPart.Clone());
      }

      mulNode.AddSubtree(addNode);
      mulNode.AddSubtree(scaleNode);
      startNode.AddSubtree(mulNode);
      progRootNode.AddSubtree(startNode);
      var t = new SymbolicExpressionTree(progRootNode);
      var combinedModel = new SymbolicRegressionModel(t, interpreter, lowerLimit, upperLimit);
      return new SymbolicRegressionSolution(combinedModel, problemData);
    }

    /// <summary>
    /// Installs <paramref name="problemData"/> in the problem of <paramref name="alg"/>. If the algorithm has no
    /// problem yet, a <see cref="SymbolicRegressionSingleObjectiveProblem"/> is created for it first.
    /// </summary>
    /// <returns>False if the algorithm's problem is not a regression problem or cannot be created.</returns>
    private static bool TrySetProblemData(IAlgorithm alg, IRegressionProblemData problemData) {
      if (alg.Problem == null) {
        try {
          // we try to use a symbolic regression problem (works for simple regression algs and GP)
          alg.Problem = new SymbolicRegressionSingleObjectiveProblem();
        } catch (Exception) {
          // deliberate probe: the caller reports an unsupported algorithm when this returns false
          return false;
        }
      }
      var prob = alg.Problem as IRegressionProblem;
      // a problem is set and it is not compatible
      if (prob == null) return false;
      // the problem data must also be set on a freshly created problem, otherwise its default data would be used
      prob.ProblemDataParameter.Value = problemData;
      return true;
    }

    /// <summary>
    /// Prepares and runs <paramref name="alg"/> synchronously and extracts the regression model it produced.
    /// Execution ends when the algorithm raises Stopped or Paused. An exception raised by the algorithm
    /// counts as failure. If <paramref name="cancellationToken"/> is signalled, the algorithm is asked to stop;
    /// after it has finished, an <see cref="OperationCanceledException"/> is thrown.
    /// </summary>
    /// <param name="regressionAlgorithmResultName">Result to use when the algorithm produces several regression solutions.</param>
    /// <param name="model">The accepted model, or null.</param>
    /// <param name="run">The run created by this execution, or null if none was created.</param>
    /// <returns>True if an acceptable model was found.</returns>
    private static bool TryExecute(IAlgorithm alg, string regressionAlgorithmResultName, CancellationToken cancellationToken,
      out IRegressionModel model, out IRun run) {
      model = null;
      run = null;
      Exception algorithmError = null;
      using (var finished = new ManualResetEvent(false)) {
        // the exception is only recorded; completion is signalled by Stopped/Paused, which are raised after the run has been stored
        EventHandler<EventArgs<Exception>> exceptionHandler = (sender, args) => algorithmError = args.Value;
        EventHandler finishedHandler = (sender, args) => finished.Set();
        alg.ExceptionOccurred += exceptionHandler;
        alg.Stopped += finishedHandler;
        alg.Paused += finishedHandler;
        try {
          alg.Prepare();
          var runsBefore = alg.Runs.Count;
          alg.Start();

          if (WaitHandle.WaitAny(new WaitHandle[] { finished, cancellationToken.WaitHandle }) == 1) {
            try {
              alg.Stop();
            } catch (Exception) {
              // best effort: the algorithm may be finishing concurrently, in which case Stopped is raised anyway
            }
            finished.WaitOne(); // never leave while the base learner is still running
            cancellationToken.ThrowIfCancellationRequested();
          }

          // only take a run that was created by this execution
          if (alg.Runs.Count > runsBefore) run = alg.Runs.Last();
          if (algorithmError != null) return false;

          var sols = alg.Results.Select(res => res.Value).OfType<IRegressionSolution>().ToArray();
          if (sols.Length == 0) return false;
          var sol = sols[0];
          if (sols.Length > 1 && alg.Results.ContainsKey(regressionAlgorithmResultName)) {
            // more than one solution => use regressionAlgorithmResult (if it actually is a regression solution)
            var namedSol = alg.Results[regressionAlgorithmResultName].Value as IRegressionSolution;
            if (namedSol != null) sol = namedSol;
          }

          var symbRegSol = sol as SymbolicRegressionSolution;
          // only accept symb reg solutions that do not hit the estimation limits
          // NaN evaluations would not be critical but are problematic if we want to combine all symbolic models into a single symbolic model
          var acceptable = symbRegSol == null ||
            (symbRegSol.TrainingLowerEstimationLimitHits == 0 && symbRegSol.TrainingUpperEstimationLimitHits == 0 &&
             symbRegSol.TestLowerEstimationLimitHits == 0 && symbRegSol.TestUpperEstimationLimitHits == 0 &&
             symbRegSol.TrainingNaNEvaluations == 0 && symbRegSol.TestNaNEvaluations == 0);
          if (acceptable) model = sol.Model;
        } finally {
          alg.ExceptionOccurred -= exceptionHandler;
          alg.Stopped -= finishedHandler;
          alg.Paused -= finishedHandler;
        }
      }
      return model != null;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.