#region License Information /* HeuristicLab * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL) * and the BEACON Center for the Study of Evolution in Action. * * This file is part of HeuristicLab. * * HeuristicLab is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * HeuristicLab is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with HeuristicLab. If not, see . */ #endregion using System; using System.Collections.Generic; using System.Diagnostics.Contracts; using System.Linq; using HeuristicLab.Problems.DataAnalysis; using HeuristicLab.Random; namespace HeuristicLab.Algorithms.DataAnalysis { public static class GradientBoostedTreesAlgorithmStatic { #region static API public interface IGbmState { IRegressionModel GetModel(); double GetTrainLoss(); double GetTestLoss(); IEnumerable> GetVariableRelevance(); } // created through factory method // GbmState details are private API users can only use methods from IGbmState private class GbmState : IGbmState { internal IRegressionProblemData problemData { get; set; } internal MersenneTwister random { get; private set; } internal ILossFunction lossFunction { get; set; } internal int maxSize { get; set; } internal double nu { get; set; } internal double r { get; set; } internal double m { get; set; } internal readonly RegressionTreeBuilder treeBuilder; // array members (allocate only once) internal double[] pred; internal double[] predTest; internal double[] y; internal int[] activeIdx; internal double[] pseudoRes; internal IList models; internal IList weights; public GbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxSize, double r, double m, double nu) { // default settings for MaxSize, Nu and R this.maxSize = maxSize; this.nu = nu; this.r = r; this.m = m; random = new MersenneTwister(randSeed); this.problemData = problemData; this.lossFunction = lossFunction; int nRows = problemData.TrainingIndices.Count(); y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray(); treeBuilder = new RegressionTreeBuilder(problemData, random); activeIdx = Enumerable.Range(0, nRows).ToArray(); var zeros = Enumerable.Repeat(0.0, nRows).ToArray(); double f0 = lossFunction.LineSearch(y, zeros, activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors) pred = Enumerable.Repeat(f0, nRows).ToArray(); predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray(); pseudoRes = new double[nRows]; models = new List(); weights = new List(); // add constant model models.Add(new ConstantRegressionModel(f0)); weights.Add(1.0); } public IRegressionModel GetModel() { return new GradientBoostedTreesModel(models, weights); } public IEnumerable> GetVariableRelevance() { return treeBuilder.GetVariableRelevance(); } public double GetTrainLoss() { int nRows = y.Length; return lossFunction.GetLoss(y, pred) / nRows; } public double GetTestLoss() { var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices); var nRows = problemData.TestIndices.Count(); return lossFunction.GetLoss(yTest, predTest) / nRows; } } // simple interface public static IRegressionSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxSize, double nu, double r, double m, int maxIterations, uint randSeed = 31415) { Contract.Assert(r > 0); Contract.Assert(r <= 1.0); Contract.Assert(nu > 0); Contract.Assert(nu <= 1.0); var state = (GbmState)CreateGbmState(problemData, lossFunction, randSeed); state.maxSize = maxSize; state.r = r; state.m = m; state.nu = nu; for (int iter = 0; iter < maxIterations; iter++) { MakeStep(state); } var model = new GradientBoostedTreesModel(state.models, state.weights); return new RegressionSolution(model, (IRegressionProblemData)problemData.Clone()); } // for custom stepping & termination public static IGbmState CreateGbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxSize = 3, double r = 0.66, double m = 0.5, double nu = 0.01) { return new GbmState(problemData, lossFunction, randSeed, maxSize, r, m, nu); } // use default settings for maxDepth, nu, r from state public static void MakeStep(IGbmState state) { var gbmState = state as GbmState; if (gbmState == null) throw new ArgumentException("state"); MakeStep(gbmState, gbmState.maxSize, gbmState.nu, gbmState.r, gbmState.m); } // allow dynamic adaptation of maxDepth, nu and r (even though this is not used) public static void MakeStep(IGbmState state, int maxSize, double nu, double r, double m) { var gbmState = state as GbmState; if (gbmState == null) throw new ArgumentException("state"); var problemData = gbmState.problemData; var lossFunction = gbmState.lossFunction; var yPred = gbmState.pred; var yPredTest = gbmState.predTest; var treeBuilder = gbmState.treeBuilder; var y = gbmState.y; var activeIdx = gbmState.activeIdx; var pseudoRes = gbmState.pseudoRes; // copy output of gradient function to pre-allocated rim array (pseudo-residuals) int rimIdx = 0; foreach (var g in lossFunction.GetLossGradient(y, yPred)) { pseudoRes[rimIdx++] = g; } var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(pseudoRes, yPred, maxSize, activeIdx, lossFunction, r, m); int i = 0; // TODO: slow because of multiple calls to GetDoubleValue for each row index foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices)) { yPred[i] = yPred[i] + nu * pred; i++; } // update predictions for validation set i = 0; foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TestIndices)) { yPredTest[i] = yPredTest[i] + nu * pred; i++; } gbmState.weights.Add(nu); gbmState.models.Add(tree); } #endregion } }