#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 * and the BEACON Center for the Study of Evolution in Action.
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using GradientBoostedTrees;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Static entry points for training a gradient boosted trees regression model,
  /// either in one call (<c>TrainGbm</c>) or step by step
  /// (<c>CreateGbmState</c> followed by repeated <c>MakeStep</c> calls).
  /// </summary>
  public static class GradientBoostedTreesAlgorithmStatic {
    #region static API

    /// <summary>
    /// Opaque handle to the state of a boosting run. Instances are created by
    /// <c>CreateGbmState</c> and advanced by <c>MakeStep</c>.
    /// </summary>
    public interface IGbmState {
      /// <summary>Returns a model built from the trees and weights collected so far.</summary>
      IRegressionModel GetModel();

      /// <summary>Returns the loss on the training partition divided by the number of training rows.</summary>
      double GetTrainLoss();

      /// <summary>Returns the loss on the test partition divided by the number of test rows.</summary>
      double GetTestLoss();

      /// <summary>Returns the variable relevance values reported by the tree builder.</summary>
      // NOTE(review): the generic argument was lost in this file; KeyValuePair<string, double>
      // (variable name, relevance) is assumed - confirm against RegressionTreeBuilder.GetVariableRelevance.
      IEnumerable<KeyValuePair<string, double>> GetVariableRelevance();
    }

    // created through factory method
    private class GbmState : IGbmState {
      internal IRegressionProblemData problemData { get; set; }
      internal MersenneTwister random { get; set; }
      internal ILossFunction lossFunction { get; set; }

      // per-step defaults used by MakeStep(IGbmState); all four are forwarded to
      // MakeStep(IGbmState, int, double, double, double)
      internal int maxDepth { get; set; }
      internal double nu { get; set; } // factor applied to each tree's prediction
      internal double r { get; set; }  // forwarded to the tree builder
      internal double m { get; set; }  // forwarded to the tree builder

      internal readonly RegressionTreeBuilder treeBuilder;

      // array members (allocate only once)
      internal readonly double[] pred;     // current predictions for the training rows
      internal readonly double[] predTest; // current predictions for the test rows
      internal readonly double[] w;        // row weights for the training rows
      internal readonly double[] y;        // target values for the training rows
      internal readonly int[] activeIdx;   // 0..nRows-1, handed to the tree builder
      internal readonly double[] rim;      // buffer for the loss gradient (pseudo-residuals)

      // models[i] is weighted by weights[i]; both lists grow by one entry per step
      internal readonly IList<IRegressionModel> models;
      internal readonly IList<double> weights;

      /// <summary>
      /// Allocates all working arrays and seeds the ensemble with a constant model
      /// that predicts the mean of the training targets.
      /// </summary>
      public GbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxDepth, double r, double m, double nu) {
        // default settings for MaxDepth, Nu and R
        this.maxDepth = maxDepth;
        this.nu = nu;
        this.r = r;
        this.m = m;

        random = new MersenneTwister(randSeed);
        this.problemData = problemData;
        this.lossFunction = lossFunction;

        int nRows = problemData.TrainingIndices.Count();

        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
        // weights are all 1 for now (HL doesn't support weights yet)
        w = Enumerable.Repeat(1.0, nRows).ToArray();

        treeBuilder = new RegressionTreeBuilder(problemData, random);

        activeIdx = Enumerable.Range(0, nRows).ToArray();

        // prepare arrays (allocate only once)
        double f0 = y.Average(); // default prediction (constant)
        pred = Enumerable.Repeat(f0, nRows).ToArray();
        predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray();
        rim = new double[nRows];

        models = new List<IRegressionModel>();
        weights = new List<double>();
        // add constant model
        models.Add(new ConstantRegressionModel(f0));
        weights.Add(1.0);
      }

      public IRegressionModel GetModel() {
        return new GradientBoostedTreesModel(models, weights);
      }

      public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() {
        return treeBuilder.GetVariableRelevance();
      }

      public double GetTrainLoss() {
        int nRows = y.Length;
        return lossFunction.GetLoss(y, pred, w) / nRows;
      }

      public double GetTestLoss() {
        // materialize once: the values are needed both for the row count and for the loss itself
        var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices).ToArray();
        var wTest = Enumerable.Repeat(1.0, yTest.Length); // ones
        return lossFunction.GetLoss(yTest, predTest, wTest) / yTest.Length;
      }
    }

    /// <summary>
    /// Simple interface: trains with <c>SquaredErrorLoss</c> for a fixed number of iterations.
    /// </summary>
    /// <param name="problemData">Regression problem to train on.</param>
    /// <param name="maxDepth">Passed to the tree builder for every tree.</param>
    /// <param name="nu">Factor applied to each tree's prediction; asserted to be in (0, 1].</param>
    /// <param name="r">Passed to the tree builder; asserted to be in (0, 1].</param>
    /// <param name="maxIterations">Number of boosting steps (trees) to perform.</param>
    /// <returns>A regression solution wrapping the trained model and a clone of <paramref name="problemData"/>.</returns>
    public static IRegressionSolution TrainGbm(IRegressionProblemData problemData, int maxDepth, double nu, double r, int maxIterations) {
      return TrainGbm(problemData, new SquaredErrorLoss(), maxDepth, nu, r, maxIterations);
    }

    /// <summary>
    /// Simple interface: trains with the given loss function for a fixed number of iterations.
    /// </summary>
    /// <param name="problemData">Regression problem to train on.</param>
    /// <param name="lossFunction">Loss whose gradient and line search drive each step.</param>
    /// <param name="maxDepth">Passed to the tree builder for every tree.</param>
    /// <param name="nu">Factor applied to each tree's prediction; asserted to be in (0, 1].</param>
    /// <param name="r">Passed to the tree builder; asserted to be in (0, 1].</param>
    /// <param name="maxIterations">Number of boosting steps (trees) to perform.</param>
    /// <param name="randSeed">Seed for the random number generator handed to the tree builder.</param>
    /// <returns>A regression solution wrapping the trained model and a clone of <paramref name="problemData"/>.</returns>
    public static IRegressionSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxDepth, double nu, double r, int maxIterations, uint randSeed = 31415) {
      // NOTE: Contract.Assert is a conditional call (DEBUG / CONTRACTS_FULL symbols),
      // so these range checks are not enforced in builds without those symbols.
      Contract.Assert(r > 0);
      Contract.Assert(r <= 1.0);
      Contract.Assert(nu > 0);
      Contract.Assert(nu <= 1.0);

      // m keeps the default of CreateGbmState; the other settings are overwritten below
      var state = (GbmState)CreateGbmState(problemData, lossFunction, randSeed);
      state.maxDepth = maxDepth;
      state.r = r;
      state.nu = nu;

      for (int iter = 0; iter < maxIterations; iter++) {
        MakeStep(state);
      }

      var model = new GradientBoostedTreesModel(state.models, state.weights);
      return new RegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    }

    /// <summary>
    /// Creates the state for a boosting run with custom stepping and termination.
    /// The values given here are the defaults used by <c>MakeStep(IGbmState)</c>.
    /// </summary>
    public static IGbmState CreateGbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxDepth = 3, double r = 0.66, double m = 0.5, double nu = 0.01) {
      return new GbmState(problemData, lossFunction, randSeed, maxDepth, r, m, nu);
    }

    /// <summary>
    /// Performs one boosting step using the maxDepth, nu, r and m stored in the state.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="state"/> is null or was not created by <c>CreateGbmState</c>.
    /// </exception>
    public static void MakeStep(IGbmState state) {
      var gbmState = ToGbmState(state);
      MakeStep(gbmState, gbmState.maxDepth, gbmState.nu, gbmState.r, gbmState.m);
    }

    /// <summary>
    /// Performs one boosting step with explicitly given settings, which allows
    /// dynamic adaptation of maxDepth, nu, r and m between steps.
    /// Adds one tree (weighted by <paramref name="nu"/>) to the state and updates
    /// the cached training and test predictions.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="state"/> is null or was not created by <c>CreateGbmState</c>.
    /// </exception>
    public static void MakeStep(IGbmState state, int maxDepth, double nu, double r, double m) {
      var gbmState = ToGbmState(state);

      var problemData = gbmState.problemData;
      var lossFunction = gbmState.lossFunction;
      var yPred = gbmState.pred;
      var yPredTest = gbmState.predTest;
      var w = gbmState.w;
      var treeBuilder = gbmState.treeBuilder;
      var y = gbmState.y;
      var activeIdx = gbmState.activeIdx;
      var rim = gbmState.rim;

      // copy output of gradient function to pre-allocated rim array (pseudo-residuals)
      int rimIdx = 0;
      foreach (var g in lossFunction.GetLossGradient(y, yPred, w)) {
        rim[rimIdx++] = g;
      }

      var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(rim, maxDepth, activeIdx, lossFunction.GetLineSearchFunc(y, yPred, w), r, m);

      // update predictions for training set
      int i = 0;
      foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices)) {
        yPred[i] = yPred[i] + nu * pred;
        i++;
      }

      // update predictions for validation set
      i = 0;
      foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TestIndices)) {
        yPredTest[i] = yPredTest[i] + nu * pred;
        i++;
      }

      gbmState.weights.Add(nu);
      gbmState.models.Add(tree);
    }

    // Casts the public handle back to the private implementation; rejects null and foreign implementations.
    private static GbmState ToGbmState(IGbmState state) {
      var gbmState = state as GbmState;
      if (gbmState == null) throw new ArgumentException("The state must be an instance created by CreateGbmState.", "state");
      return gbmState;
    }
    #endregion
  }
}