
source: branches/RBFRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/KernelRidgeRegression/KernelRidgeRegressionModel.cs @ 14887

Last change on this file since 14887 was 14887, checked in by gkronber, 7 years ago

#2699: worked on kernel ridge regression. moved beta parameter to algorithm. reintroduced IKernel interface to restrict choice of kernel in kernel ridge regression. speed-up by cholesky decomposition and optimization of the calculation of the covariance matrix.

File size: 7.7 KB
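
For context: kernel ridge regression fits dual coefficients \(\alpha\) by solving the regularized linear system \((K + \lambda I)\,\alpha = y\), where \(K_{ij} = k(x_i, x_j)\) is the kernel (Gram) matrix over the training inputs. Predictions for a new point \(x\) are \(\hat{f}(x) = \sum_j \alpha_j\, k(x_j, x)\). Because \(K + \lambda I\) is symmetric positive definite for \(\lambda > 0\), the system can be solved with a Cholesky factorization (about \(n^3/3\) operations) instead of a general dense solver, which is the speed-up referred to in the commit message above.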
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.DataAnalysis.KernelRidgeRegression {
  [StorableClass]
  [Item("KernelRidgeRegressionModel", "A kernel ridge regression model")]
  public sealed class KernelRidgeRegressionModel : RegressionModel {
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private readonly string[] allowedInputVariables;
    public string[] AllowedInputVariables {
      get { return allowedInputVariables; }
    }

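    // dual coefficients of the fitted model (solution of (K + lambda*I) * alpha = y, see the constructor)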
    [Storable]
    private readonly double[] alpha;

    [Storable]
    private readonly double[,] trainX; // the original training data is stored in full because this is more efficient for persistence

    [Storable]
    private readonly ITransformation<double>[] scaling;

    [Storable]
    private readonly ICovarianceFunction kernel;

    [Storable]
    private readonly double lambda;

    [Storable]
    private readonly double yOffset; // the solver works on zero-mean targets, so the offset is stored and added back for predictions

    [Storable]
    private readonly double yScale;

    [StorableConstructor]
    private KernelRidgeRegressionModel(bool deserializing) : base(deserializing) { }
    private KernelRidgeRegressionModel(KernelRidgeRegressionModel original, Cloner cloner)
      : base(original, cloner) {
      // shallow copies of the arrays suffice because they are never modified after construction
      allowedInputVariables = original.allowedInputVariables;
      alpha = original.alpha;
      trainX = original.trainX;
      scaling = original.scaling;
      lambda = original.lambda;

      yOffset = original.yOffset;
      yScale = original.yScale;
      if (original.kernel != null)
        kernel = cloner.Clone(original.kernel);
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new KernelRidgeRegressionModel(this, cloner);
    }

    public KernelRidgeRegressionModel(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows,
      bool scaleInputs, ICovarianceFunction kernel, double lambda = 0.1) : base(targetVariable) {
      if (kernel.GetNumberOfParameters(allowedInputVariables.Count()) > 0) throw new ArgumentException("All parameters in the kernel function must be specified.");
      name = ItemName;
      description = ItemDescription;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      var trainingRows = rows.ToArray();
      this.kernel = (ICovarianceFunction)kernel.Clone();
      this.lambda = lambda;
      try {
        if (scaleInputs)
          scaling = CreateScaling(dataset, trainingRows);
        trainX = ExtractData(dataset, trainingRows, scaling);
        var y = dataset.GetDoubleValues(targetVariable, trainingRows).ToArray();
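        // standardize the target to zero mean and unit variance; yOffset and yScale
        // are stored so that predictions can be mapped back to the original scale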
        yOffset = y.Average();
        yScale = 1.0 / y.StandardDeviation();
        for (int i = 0; i < y.Length; i++) {
          y[i] -= yOffset;
          y[i] *= yScale;
        }
        int info;
        alglib.densesolverreport denseSolveRep;
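        // fit the dual coefficients by solving (K + lambda*I) * alpha = y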
        var gram = BuildGramMatrix(trainX, lambda);
        int n = trainX.GetLength(0);

        // Cholesky decomposition (K + λI is symmetric positive definite for λ > 0)
        var res = alglib.trfac.spdmatrixcholesky(ref gram, n, false);
        if (!res) throw new ArgumentException("Could not decompose the Gram matrix. Is it symmetric positive definite?");

        alglib.spdmatrixcholeskysolve(gram, n, false, y, out info, out denseSolveRep, out alpha);
        if (info != 1) throw new ArgumentException("Could not create model.");
      } catch (alglib.alglibexception ae) {
        // wrap the exception so that calling code doesn't have to know about the alglib implementation
        throw new ArgumentException("There was a problem in the calculation of the kernel ridge regression model.", ae);
      }
    }

    #region IRegressionModel Members
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      var newX = ExtractData(dataset, rows, scaling);
      var dim = newX.GetLength(1);
      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());

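      // prediction for row i: f(x_i) = sum_j alpha[j] * k(trainX_j, newX_i),
      // then invert the target standardization (divide by yScale, add yOffset)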
      var pred = new double[newX.GetLength(0)];
      for (int i = 0; i < pred.Length; i++) {
        double sum = 0.0;
        for (int j = 0; j < alpha.Length; j++) {
          sum += alpha[j] * cov.CrossCovariance(trainX, newX, j, i);
        }
        pred[i] = sum / yScale + yOffset;
      }
      return pred;
    }
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
    #endregion

    #region helpers
    private double[,] BuildGramMatrix(double[,] data, double lambda) {
      var n = data.GetLength(0);
      var dim = data.GetLength(1);
      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());
      var gram = new double[n, n];
      // G = (K + λ I)
      for (var i = 0; i < n; i++) {
        for (var j = i; j < n; j++) {
          gram[j, i] = cov.Covariance(data, i, j); // symmetric matrix --> fill only lower triangle
        }
        gram[i, i] += lambda;
      }
      return gram;
    }

    private ITransformation<double>[] CreateScaling(IDataset dataset, int[] rows) {
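      // scale each input variable linearly to [0, 1] based on its range in the training rows;
      // note: a constant column (max == min) would produce an infinite multiplier here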
      var trans = new ITransformation<double>[allowedInputVariables.Length];
      int i = 0;
      foreach (var variable in allowedInputVariables) {
        var lin = new LinearTransformation(allowedInputVariables);
        var max = dataset.GetDoubleValues(variable, rows).Max();
        var min = dataset.GetDoubleValues(variable, rows).Min();
        lin.Multiplier = 1.0 / (max - min);
        lin.Addend = -min / (max - min);
        trans[i] = lin;
        i++;
      }
      return trans;
    }

    private double[,] ExtractData(IDataset dataset, IEnumerable<int> rows, ITransformation<double>[] scaling = null) {
      double[][] variables;
      if (scaling != null) {
        variables =
          allowedInputVariables.Select((var, i) => scaling[i].Apply(dataset.GetDoubleValues(var, rows)).ToArray())
            .ToArray();
      } else {
        variables =
          allowedInputVariables.Select(var => dataset.GetDoubleValues(var, rows).ToArray()).ToArray();
      }
      int n = variables.First().Length;
      var res = new double[n, variables.Length];
      for (int r = 0; r < n; r++)
        for (int c = 0; c < variables.Length; c++) {
          res[r, c] = variables[c][r];
        }
      return res;
    }
    #endregion
  }
}
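
A minimal usage sketch (hypothetical, not from the repository): it assumes an existing IRegressionProblemData instance named problemData and an ICovarianceFunction named kernel whose hyperparameters are all fixed, since the constructor rejects kernels with free parameters.

// hypothetical usage sketch; problemData and kernel are assumed to exist
var model = new KernelRidgeRegressionModel(
  problemData.Dataset, problemData.TargetVariable,
  problemData.AllowedInputVariables, problemData.TrainingIndices,
  scaleInputs: true, kernel: kernel, lambda: 0.1);
var testEstimates = model.GetEstimatedValues(problemData.Dataset, problemData.TestIndices);
var solution = model.CreateRegressionSolution(problemData);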