#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HEAL.Attic;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Generalized additive model (GAM) for regression. Fits one uni-variate penalized
  /// regression spline per input variable and combines them additively, using the
  /// backfitting algorithm (see The Elements of Statistical Learning, p. 298).
  /// </summary>
  [Item("Generalized Additive Model (GAM)", "Generalized additive model using uni-variate penalized regression splines as base learner.")]
  [StorableType("98A887E7-73DD-4602-BD6C-2F6B9E6FBBC5")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 600)]
  public sealed class GeneralizedAdditiveModelAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region ParameterNames
    private const string IterationsParameterName = "Iterations";
    private const string LambdaParameterName = "Lambda";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string CreateSolutionParameterName = "CreateSolution";
    #endregion

    #region ParameterProperties
    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> LambdaParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
    }
    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public FixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (FixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CreateSolutionParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; }
    }
    #endregion

    #region Properties
    // Number of backfitting sweeps over all terms.
    public int Iterations {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }
    // Smoothing penalty passed to the penalized spline fit (log-scale, see parameter description).
    public double Lambda {
      get { return LambdaParameter.Value.Value; }
      set { LambdaParameter.Value.Value = value; }
    }
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    public bool CreateSolution {
      get { return CreateSolutionParameter.Value.Value; }
      set { CreateSolutionParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private GeneralizedAdditiveModelAlgorithm(StorableConstructorFlag deserializing)
      : base(deserializing) {
    }

    private GeneralizedAdditiveModelAlgorithm(GeneralizedAdditiveModelAlgorithm original, Cloner cloner)
      : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new GeneralizedAdditiveModelAlgorithm(this, cloner);
    }

    public GeneralizedAdditiveModelAlgorithm() {
      Problem = new RegressionProblem(); // default problem

      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName,
        "Number of iterations. Try a large value and check convergence of the error over iterations. Usually, only a few iterations (e.g. 10) are needed for convergence.", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(LambdaParameterName,
        "The penalty parameter for the penalized regression splines. Set to a value between -8 (weak smoothing) and 8 (strong smooting). Usually, a value between -4 and 4 should be fine", new DoubleValue(3)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName,
        "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName,
        "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(CreateSolutionParameterName,
        "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true)));
      Parameters[CreateSolutionParameterName].Hidden = true;
    }

    /// <summary>
    /// Runs the backfitting algorithm: repeatedly re-fits each term's spline to the
    /// partial residual (target minus all other terms) until the iteration limit is
    /// reached or the run is cancelled.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      // Set up the algorithm
      if (SetSeedRandomly) Seed = new System.Random().Next();
      var rand = new MersenneTwister((uint)Seed);

      // calculates a GAM model using univariate non-linear functions
      // using backfitting algorithm (see The Elements of Statistical Learning page 298)

      // init
      var problemData = Problem.ProblemData;
      var ds = problemData.Dataset;
      var trainRows = problemData.TrainingIndices.ToArray();
      var testRows = problemData.TestIndices.ToArray();
      var avgY = problemData.TargetVariableTrainingValues.Average();
      var inputVars = problemData.AllowedInputVariables.ToArray();

      int nTerms = inputVars.Length;

      #region init results
      // Set up the results display
      var iterations = new IntValue(0);
      Results.Add(new Result("Iterations", iterations));

      var table = new DataTable("Qualities");
      var rmseRow = new DataRow("RMSE (train)");
      var rmseRowTest = new DataRow("RMSE (test)");
      table.Rows.Add(rmseRow);
      table.Rows.Add(rmseRowTest);
      Results.Add(new Result("Qualities", table));

      var curRMSE = new DoubleValue();
      var curRMSETest = new DoubleValue();
      Results.Add(new Result("RMSE (train)", curRMSE));
      Results.Add(new Result("RMSE (test)", curRMSETest));

      // calculate table with residual contributions of each term
      // NOTE(review): the values stored below are MSEs of the residual before re-fitting
      // each term, although the table is labeled "RSS" — confirm intended semantics.
      var rssTable = new DoubleMatrix(nTerms, 1, new string[] { "RSS" }, inputVars);
      Results.Add(new Result("RSS Values", rssTable));
      #endregion

      // start with a set of constant models = 0
      IRegressionModel[] f = new IRegressionModel[nTerms];
      for (int i = 0; i < f.Length; i++) {
        f[i] = new ConstantModel(0.0, problemData.TargetVariable);
      }

      // init res which contains the current residual vector
      double[] res = problemData.TargetVariableTrainingValues.Select(yi => yi - avgY).ToArray();
      double[] resTest = problemData.TargetVariableTestValues.Select(yi => yi - avgY).ToArray();

      curRMSE.Value = RMSE(res);
      curRMSETest.Value = RMSE(resTest);
      rmseRow.Values.Add(curRMSE.Value);
      rmseRowTest.Values.Add(curRMSETest.Value);

      double lambda = Lambda;
      var idx = Enumerable.Range(0, nTerms).ToArray();

      // Loop until iteration limit reached or canceled.
      for (int i = 0; i < Iterations && !cancellationToken.IsCancellationRequested; i++) {
        // shuffle order of terms in each iteration to remove bias on earlier terms
        idx.ShuffleInPlace(rand);

        foreach (var inputIdx in idx) {
          var inputVar = inputVars[inputIdx];
          // first remove the effect of the previous model for the inputIdx (by adding the output of the current model to the residual)
          AddInPlace(res, f[inputIdx].GetEstimatedValues(ds, trainRows));
          AddInPlace(resTest, f[inputIdx].GetEstimatedValues(ds, testRows));

          rssTable[inputIdx, 0] = MSE(res);
          f[inputIdx] = RegressSpline(problemData, inputVar, res, lambda);

          SubtractInPlace(res, f[inputIdx].GetEstimatedValues(ds, trainRows));
          SubtractInPlace(resTest, f[inputIdx].GetEstimatedValues(ds, testRows));
        }

        curRMSE.Value = RMSE(res);
        curRMSETest.Value = RMSE(resTest);
        rmseRow.Values.Add(curRMSE.Value);
        rmseRowTest.Values.Add(curRMSETest.Value);
        // NOTE(review): after the last pass this reports Iterations - 1; confirm whether
        // the displayed count should be i + 1.
        iterations.Value = i;
      }

      // produce solution
      if (CreateSolution) {
        // add the global intercept (mean of the target) as an additional constant term
        var model = new RegressionEnsembleModel(f.Concat(new[] { new ConstantModel(avgY, problemData.TargetVariable) }));
        model.AverageModelEstimates = false;
        var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
        Results.Add(new Result("Ensemble solution", solution));
      }
    }

    /// <summary>Mean squared error of a residual vector.</summary>
    public static double MSE(IEnumerable<double> residuals) {
      var mse = residuals.Select(r => r * r).Average();
      return mse;
    }

    /// <summary>Root mean squared error of a residual vector.</summary>
    public static double RMSE(IEnumerable<double> residuals) {
      var mse = MSE(residuals);
      var rmse = Math.Sqrt(mse);
      return rmse;
    }

    // Fits a penalized regression spline of the current residual over a single input variable.
    private IRegressionModel RegressSpline(IRegressionProblemData problemData, string inputVar, double[] target, double lambda) {
      var x = problemData.Dataset.GetDoubleValues(inputVar, problemData.TrainingIndices).ToArray();
      var y = (double[])target.Clone();
      int info;
      alglib.spline1dinterpolant s;
      alglib.spline1dfitreport rep;
      // heuristic for number of knots (Elements of Statistical Learning)
      int numKnots = (int)Math.Min(50, 3 * Math.Sqrt(x.Length));

      alglib.spline1dfitpenalized(x, y, numKnots, lambda, out info, out s, out rep);
      return new Spline1dModel(s.innerobj, problemData.TargetVariable, inputVar);
    }

    // Element-wise a[i] += enumerable[i]; assumes enumerable has at least a.Length elements.
    private static void AddInPlace(double[] a, IEnumerable<double> enumerable) {
      int i = 0;
      foreach (var elem in enumerable) {
        a[i] += elem;
        i++;
      }
    }

    // Element-wise a[i] -= enumerable[i]; assumes enumerable has at least a.Length elements.
    private static void SubtractInPlace(double[] a, IEnumerable<double> enumerable) {
      int i = 0;
      foreach (var elem in enumerable) {
        a[i] -= elem;
        i++;
      }
    }
  }
}