Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/KernelRidgeRegression/KernelRidgeRegressionModel.cs @ 18079

Last change on this file since 18079 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 9.7 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
[14888]23using System.Collections.Generic;
[14386]24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
[17097]27using HEAL.Attic;
[14386]28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Kernel ridge regression model.
  /// Predictions are computed as f(x) = (sum_j alpha_j * k(x_j, x)) / yScale + yOffset,
  /// where x_j are the stored training inputs and k is the (fully parameterized) kernel.
  /// The coefficients alpha are obtained by solving (K + lambda * I) * alpha = y'
  /// for the standardized target y' = (y - yOffset) * yScale.
  /// </summary>
  [StorableType("4148D88C-6081-4D84-B718-C949CA5AA766")]
  [Item("KernelRidgeRegressionModel", "A kernel ridge regression model")]
  public sealed class KernelRidgeRegressionModel : RegressionModel {
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private readonly string[] allowedInputVariables;
    /// <summary>Gets a copy of the names of the input variables used by the model.</summary>
    public string[] AllowedInputVariables {
      get { return allowedInputVariables.ToArray(); }
    }

    /// <summary>Root mean squared leave-one-out error on the training rows, in units of the target variable.</summary>
    [Storable]
    public double LooCvRMSE { get; private set; }

    // coefficients of the kernel expansion, one per training row
    [Storable]
    private readonly double[] alpha;

    // training inputs (already scaled when input scaling is enabled); the extracted matrix is stored
    // instead of a reference to the dataset because this is more efficient in persistence
    [Storable]
    private readonly double[,] trainX;

    // per-input-variable transformation; null when the model was created with scaleInputs == false
    [Storable]
    private readonly ITransformation<double>[] scaling;

    [Storable]
    private readonly ICovarianceFunction kernel;

    // regularization strength added to the diagonal of the kernel matrix
    [Storable]
    private readonly double lambda;

    // the solver works on a zero-mean, unit-variance target: y' = (y - yOffset) * yScale
    [Storable]
    private readonly double yOffset;

    [Storable]
    private readonly double yScale;

    [StorableConstructor]
    private KernelRidgeRegressionModel(StorableConstructorFlag _) : base(_) { }
    private KernelRidgeRegressionModel(KernelRidgeRegressionModel original, Cloner cloner)
      : base(original, cloner) {
      // shallow copies of arrays because they cannot be modified
      allowedInputVariables = original.allowedInputVariables;
      alpha = original.alpha;
      trainX = original.trainX;
      scaling = original.scaling;
      lambda = original.lambda;
      LooCvRMSE = original.LooCvRMSE;

      yOffset = original.yOffset;
      yScale = original.yScale;
      kernel = original.kernel;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new KernelRidgeRegressionModel(this, cloner);
    }

    /// <summary>
    /// Fits a kernel ridge regression model on the given training rows.
    /// </summary>
    /// <param name="dataset">Dataset providing input and target values.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Names of the input variables.</param>
    /// <param name="rows">Indices of the training rows.</param>
    /// <param name="scaleInputs">If true, each input is linearly rescaled using its min/max over the training rows.</param>
    /// <param name="kernel">Kernel (covariance) function whose parameters must all be specified; the model stores a clone.</param>
    /// <param name="lambda">Regularization strength added to the diagonal of the kernel matrix.</param>
    /// <returns>The fitted model, including its leave-one-out RMSE.</returns>
    /// <exception cref="ArgumentException">
    /// The kernel has unspecified parameters, the linear system could not be solved,
    /// the kernel matrix could not be inverted, or ALGLIB reported an error.
    /// </exception>
    public static KernelRidgeRegressionModel Create(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows,
      bool scaleInputs, ICovarianceFunction kernel, double lambda = 0.1) {
      var trainingRows = rows.ToArray();
      var model = new KernelRidgeRegressionModel(dataset, targetVariable, allowedInputVariables, trainingRows, scaleInputs, kernel, lambda);

      try {
        int n = model.trainX.GetLength(0);
        // G = K + lambda * I, built with the model's own kernel so that training and prediction use the same object
        var gram = BuildGramMatrix(model.trainX, lambda, model.kernel);

        // standardized target (see yOffset / yScale)
        var y = dataset.GetDoubleValues(targetVariable, trainingRows).ToArray();
        for (int i = 0; i < y.Length; i++) {
          y[i] -= model.yOffset;
          y[i] *= model.yScale;
        }

        int info;
        alglib.densesolverreport denseSolveRep;
        alglib.matinvreport invRep;
        double[] alpha;
        double[,] invG; // inverse of G, needed for the closed-form LOO-CV estimate

        // try a Cholesky decomposition first; the factorization overwrites its argument, so work on a copy of G
        var l = new double[n, n];
        Array.Copy(gram, l, l.Length);
        if (alglib.trfac.spdmatrixcholesky(ref l, n, false)) {
          alglib.spdmatrixcholeskysolve(l, n, false, y, out info, out denseSolveRep, out alpha);
          if (info != 1) throw new ArgumentException("Could not create model.");
          invG = l; // the inverse is computed in place from the Cholesky factor
          alglib.spdmatrixcholeskyinverse(ref invG, n, false, out info, out invRep);
        } else {
          // Cholesky failed (G is not numerically positive definite): fall back to an LU decomposition
          int[] pivots;
          var lu = new double[n, n];
          Array.Copy(gram, lu, lu.Length);
          alglib.rmatrixlu(ref lu, n, n, out pivots);
          alglib.rmatrixlusolve(lu, pivots, n, y, out info, out denseSolveRep, out alpha);
          if (info != 1) throw new ArgumentException("Could not create model.");
          invG = lu; // the inverse is computed in place from the LU factors
          alglib.rmatrixluinverse(ref invG, pivots, n, out info, out invRep);
        }
        if (info != 1) throw new ArgumentException("Could not invert Gram matrix.");

        Array.Copy(alpha, model.alpha, n);
        model.LooCvRMSE = CalculateLooCvRmse(gram, invG, alpha, y, model.yScale);
      } catch (alglib.alglibexception ae) {
        // wrap exception so that calling code doesn't have to know about alglib implementation
        throw new ArgumentException("There was a problem in the calculation of the kernel ridge regression model", ae);
      }
      return model;
    }

    private KernelRidgeRegressionModel(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, int[] rows,
      bool scaleInputs, ICovarianceFunction kernel, double lambda = 0.1) : base(targetVariable) {
      this.allowedInputVariables = allowedInputVariables.ToArray();
      if (kernel.GetNumberOfParameters(this.allowedInputVariables.Length) > 0) throw new ArgumentException("All parameters in the kernel function must be specified.");
      name = ItemName;
      description = ItemDescription;

      this.kernel = (ICovarianceFunction)kernel.Clone();
      this.lambda = lambda;
      if (scaleInputs) scaling = CreateScaling(dataset, rows, this.allowedInputVariables);
      trainX = ExtractData(dataset, rows, this.allowedInputVariables, scaling);
      var y = dataset.GetDoubleValues(targetVariable, rows).ToArray();
      yOffset = y.Average();
      // a constant target has zero (or undefined) standard deviation; 1/0 would give an infinite scale
      // and a NaN standardized target, so fall back to a scale of 1 in that case
      var yStdDev = y.StandardDeviation();
      yScale = yStdDev > 0 ? 1.0 / yStdDev : 1.0;
      alpha = new double[trainX.GetLength(0)];
    }


    #region IRegressionModel Members
    /// <summary>
    /// Predicts f(x) = (sum_j alpha_j * k(x_j, x)) / yScale + yOffset for each requested row.
    /// </summary>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      var newX = ExtractData(dataset, rows, allowedInputVariables, scaling);
      var dim = newX.GetLength(1);
      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());

      var pred = new double[newX.GetLength(0)];
      for (int i = 0; i < pred.Length; i++) {
        double sum = 0.0;
        for (int j = 0; j < alpha.Length; j++) {
          sum += alpha[j] * cov.CrossCovariance(trainX, newX, j, i);
        }
        // undo the target standardization applied during training
        pred[i] = sum / yScale + yOffset;
      }
      return pred;
    }
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
    #endregion

    #region helpers
    // Builds the symmetric matrix G = K + lambda * I, where K[i, j] is the kernel value between training rows i and j.
    private static double[,] BuildGramMatrix(double[,] data, double lambda, ICovarianceFunction kernel) {
      var n = data.GetLength(0);
      var dim = data.GetLength(1);
      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());
      var gram = new double[n, n];
      // G = (K + λ I)
      for (var i = 0; i < n; i++) {
        for (var j = i; j < n; j++) {
          gram[i, j] = gram[j, i] = cov.Covariance(data, i, j); // symmetric matrix
        }
        gram[i, i] += lambda;
      }
      return gram;
    }

    // Root mean squared leave-one-out error in original target units. The LOO prediction for row i is taken as
    // the in-sample prediction G_i * alpha minus alpha_i / (G^-1)_ii; y is the standardized target, so the
    // residual is divided by yScale to undo the standardization.
    private static double CalculateLooCvRmse(double[,] gram, double[,] invG, double[] alpha, double[] y, double yScale) {
      int n = alpha.Length;
      var ssqLooError = 0.0;
      for (int i = 0; i < n; i++) {
        var pred_i = Util.ScalarProd(Util.GetRow(gram, i).ToArray(), alpha);
        var looPred_i = pred_i - alpha[i] / invG[i, i];
        var error = (y[i] - looPred_i) / yScale;
        ssqLooError += error * error;
      }
      return Math.Sqrt(ssqLooError / n);
    }

    // Creates one linear transformation per input variable that maps [min, max] over the given rows to [0, 1].
    // Constant variables (max == min) are mapped to 0 instead of dividing by zero, which would make every
    // scaled value of that variable non-finite.
    private static ITransformation<double>[] CreateScaling(IDataset dataset, int[] rows, IReadOnlyCollection<string> allowedInputVariables) {
      var trans = new ITransformation<double>[allowedInputVariables.Count];
      int i = 0;
      foreach (var variable in allowedInputVariables) {
        var values = dataset.GetDoubleValues(variable, rows).ToArray();
        var max = values.Max();
        var min = values.Min();
        var range = max - min;
        var lin = new LinearTransformation(allowedInputVariables);
        if (range > 0) {
          lin.Multiplier = 1.0 / range;
          lin.Addend = -min / range;
        } else {
          lin.Multiplier = 1.0;
          lin.Addend = -min;
        }
        trans[i] = lin;
        i++;
      }
      return trans;
    }

    // Extracts the input variables for the given rows into a (rows x variables) matrix, applying the
    // per-variable scaling when one is given.
    private static double[,] ExtractData(IDataset dataset, IEnumerable<int> rows, IReadOnlyCollection<string> allowedInputVariables, ITransformation<double>[] scaling = null) {
      // materialize the rows once: they are read once per input variable and may be a lazily evaluated sequence
      var rowIndices = rows as int[] ?? rows.ToArray();
      var columns = allowedInputVariables.Select((variableName, columnIndex) => {
        var values = dataset.GetDoubleValues(variableName, rowIndices);
        if (scaling != null) values = scaling[columnIndex].Apply(values);
        return values.ToArray();
      }).ToArray();

      // one matrix row per requested row index (also well-defined when there are no input variables)
      int n = rowIndices.Length;
      var res = new double[n, columns.Length];
      for (int r = 0; r < n; r++) {
        for (int c = 0; c < columns.Length; c++) {
          res[r, c] = columns[c][r];
        }
      }
      return res;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.