Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2947_ConfigurableIndexedDataTable/HeuristicLab.Algorithms.DataAnalysis/3.4/KernelRidgeRegression/KernelRidgeRegressionModel.cs @ 16830

Last change on this file since 16830 was 15583, checked in by swagner, 7 years ago

#2640: Updated year of copyrights in license headers

File size: 9.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Kernel ridge regression model. A prediction for a row x is sum_j alpha_j * k(x_j, x) over the stored
  /// (optionally scaled) training inputs x_j, mapped back from the standardized target scale via yScale/yOffset.
  /// </summary>
  [StorableClass]
  [Item("KernelRidgeRegressionModel", "A kernel ridge regression model")]
  public sealed class KernelRidgeRegressionModel : RegressionModel {
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private readonly string[] allowedInputVariables;
    /// <summary>Names of the input variables (a copy; modifying it does not affect the model).</summary>
    public string[] AllowedInputVariables {
      get { return allowedInputVariables.ToArray(); }
    }


    /// <summary>
    /// Root mean squared leave-one-out cross-validation error on the training rows, in the original target scale.
    /// </summary>
    [Storable]
    public double LooCvRMSE { get; private set; }

    // dual coefficients, one per training row, fitted on the standardized target
    [Storable]
    private readonly double[] alpha;

    [Storable]
    private readonly double[,] trainX; // it is better to store the original training dataset completely because this is more efficient in persistence

    // one transformation per input variable; null when the inputs are not scaled
    [Storable]
    private readonly ITransformation<double>[] scaling;

    [Storable]
    private readonly ICovarianceFunction kernel;

    [Storable]
    private readonly double lambda;

    [Storable]
    private readonly double yOffset; // implementation works for zero-mean, unit-variance target variables

    [Storable]
    private readonly double yScale;

    [StorableConstructor]
    private KernelRidgeRegressionModel(bool deserializing) : base(deserializing) { }
    private KernelRidgeRegressionModel(KernelRidgeRegressionModel original, Cloner cloner)
      : base(original, cloner) {
      // shallow copies of arrays because they cannot be modified
      // (kernel and scaling are shared as well; this class never modifies them after construction)
      allowedInputVariables = original.allowedInputVariables;
      alpha = original.alpha;
      trainX = original.trainX;
      scaling = original.scaling;
      lambda = original.lambda;
      LooCvRMSE = original.LooCvRMSE;

      yOffset = original.yOffset;
      yScale = original.yScale;
      kernel = original.kernel;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new KernelRidgeRegressionModel(this, cloner);
    }

    /// <summary>
    /// Fits a kernel ridge regression model by solving G * alpha = y for the standardized target,
    /// where G = K + lambda * I is the Gram matrix of the training inputs, and computes the
    /// leave-one-out cross-validation RMSE from the diagonal of G^-1.
    /// </summary>
    /// <param name="dataset">Dataset providing the input and target values.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Names of the input variables.</param>
    /// <param name="rows">Training rows.</param>
    /// <param name="scaleInputs">If true, each input variable is min-max scaled using the training rows.</param>
    /// <param name="kernel">Covariance function; all of its parameters must already be fixed.</param>
    /// <param name="lambda">Regularization constant added to the diagonal of the Gram matrix.</param>
    /// <returns>The fitted model with <see cref="LooCvRMSE"/> set.</returns>
    /// <exception cref="ArgumentException">
    /// The kernel has unspecified parameters, or the linear system / matrix inversion failed.
    /// </exception>
    public static KernelRidgeRegressionModel Create(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows,
      bool scaleInputs, ICovarianceFunction kernel, double lambda = 0.1) {
      var trainingRows = rows.ToArray();
      var model = new KernelRidgeRegressionModel(dataset, targetVariable, allowedInputVariables, trainingRows, scaleInputs, kernel, lambda);

      try {
        int info;
        int n = model.trainX.GetLength(0);
        alglib.densesolverreport denseSolveRep;
        var gram = BuildGramMatrix(model.trainX, lambda, kernel);
        var l = new double[n, n];
        Array.Copy(gram, l, l.Length);

        double[] alpha; // assigned by the solver in both branches below
        double[,] invG;
        // standardize the target in the same way as the model does (see yOffset / yScale)
        var y = dataset.GetDoubleValues(targetVariable, trainingRows).ToArray();
        for (int i = 0; i < y.Length; i++) {
          y[i] -= model.yOffset;
          y[i] *= model.yScale;
        }
        // try a Cholesky decomposition first; it reports failure when G cannot be factored as symmetric positive definite
        var res = alglib.trfac.spdmatrixcholesky(ref l, n, false);
        if (!res) { // fall back to an LU decomposition if the Cholesky decomposition failed
          int[] pivots;
          var lu = new double[n, n];
          Array.Copy(gram, lu, lu.Length);
          alglib.rmatrixlu(ref lu, n, n, out pivots);
          alglib.rmatrixlusolve(lu, pivots, n, y, out info, out denseSolveRep, out alpha);
          if (info != 1) throw new ArgumentException("Could not create model.");
          // for LOO-CV we need to build the inverse of the gram matrix
          alglib.matinvreport rep;
          invG = lu;  // rename, the LU factors are inverted in place
          alglib.rmatrixluinverse(ref invG, pivots, n, out info, out rep);
        } else {
          alglib.spdmatrixcholeskysolve(l, n, false, y, out info, out denseSolveRep, out alpha);
          if (info != 1) throw new ArgumentException("Could not create model.");
          // for LOO-CV we need to build the inverse of the gram matrix
          alglib.matinvreport rep;
          invG = l;   // rename, the Cholesky factor is inverted in place
          alglib.spdmatrixcholeskyinverse(ref invG, n, false, out info, out rep);
        }
        if (info != 1) throw new ArgumentException("Could not invert Gram matrix.");

        // leave-one-out prediction for row i: pred_i - alpha_i / (G^-1)_ii;
        // only the diagonal of invG is read, and errors are mapped back to the original target scale
        var ssqLooError = 0.0;
        for (int i = 0; i < n; i++) {
          var pred_i = Util.ScalarProd(Util.GetRow(gram, i).ToArray(), alpha);
          var looPred_i = pred_i - alpha[i] / invG[i, i];
          var error = (y[i] - looPred_i) / model.yScale;
          ssqLooError += error * error;
        }

        Array.Copy(alpha, model.alpha, n);
        model.LooCvRMSE = Math.Sqrt(ssqLooError / n);
      } catch (alglib.alglibexception ae) {
        // wrap exception so that calling code doesn't have to know about alglib implementation
        throw new ArgumentException("There was a problem in the calculation of the kernel ridge regression model", ae);
      }
      return model;
    }

    /// <summary>
    /// Prepares an unfitted model: validates the kernel, stores a clone of it, builds the (optionally scaled)
    /// training input matrix and the target standardization. The coefficients are filled in by <see cref="Create"/>.
    /// </summary>
    private KernelRidgeRegressionModel(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, int[] rows,
      bool scaleInputs, ICovarianceFunction kernel, double lambda = 0.1) : base(targetVariable) {
      this.allowedInputVariables = allowedInputVariables.ToArray();
      if (kernel.GetNumberOfParameters(this.allowedInputVariables.Length) > 0) throw new ArgumentException("All parameters in the kernel function must be specified.");
      name = ItemName;
      description = ItemDescription;

      this.kernel = (ICovarianceFunction)kernel.Clone();
      this.lambda = lambda;
      if (scaleInputs) scaling = CreateScaling(dataset, rows, this.allowedInputVariables);
      trainX = ExtractData(dataset, rows, this.allowedInputVariables, scaling);
      var y = dataset.GetDoubleValues(targetVariable, rows).ToArray();
      yOffset = y.Average();
      // A constant target has zero standard deviation. 1/0 = +Infinity, which would turn the centered
      // targets (all 0) into NaN (0 * Infinity). Use unit scale instead.
      var yStdDev = y.StandardDeviation();
      yScale = yStdDev > 0.0 ? 1.0 / yStdDev : 1.0;
      alpha = new double[trainX.GetLength(0)];
    }


    #region IRegressionModel Members
    /// <summary>
    /// Predicts the target for the given rows as sum_j alpha_j * k(trainX_j, x), mapped back to the original target scale.
    /// </summary>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      var newX = ExtractData(dataset, rows, allowedInputVariables, scaling);
      var dim = newX.GetLength(1);
      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());

      var pred = new double[newX.GetLength(0)];
      for (int i = 0; i < pred.Length; i++) {
        double sum = 0.0;
        for (int j = 0; j < alpha.Length; j++) {
          sum += alpha[j] * cov.CrossCovariance(trainX, newX, j, i);
        }
        pred[i] = sum / yScale + yOffset;
      }
      return pred;
    }
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
    #endregion

    #region helpers
    /// <summary>Builds the symmetric matrix G = K + lambda * I over the rows of <paramref name="data"/>.</summary>
    private static double[,] BuildGramMatrix(double[,] data, double lambda, ICovarianceFunction kernel) {
      var n = data.GetLength(0);
      var dim = data.GetLength(1);
      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());
      var gram = new double[n, n];
      // G = (K + λ I)
      for (var i = 0; i < n; i++) {
        for (var j = i; j < n; j++) {
          gram[i, j] = gram[j, i] = cov.Covariance(data, i, j); // symmetric matrix
        }
        gram[i, i] += lambda;
      }
      return gram;
    }

    /// <summary>
    /// Creates one min-max transformation per input variable from the given rows (maps [min, max] to [0, 1]).
    /// </summary>
    private static ITransformation<double>[] CreateScaling(IDataset dataset, int[] rows, IReadOnlyCollection<string> allowedInputVariables) {
      var trans = new ITransformation<double>[allowedInputVariables.Count];
      int i = 0;
      foreach (var variable in allowedInputVariables) {
        var lin = new LinearTransformation(allowedInputVariables);
        var values = dataset.GetDoubleValues(variable, rows).ToArray(); // enumerate the dataset only once
        var max = values.Max();
        var min = values.Min();
        var range = max - min;
        if (range > 0.0) {
          lin.Multiplier = 1.0 / range;
          lin.Addend = -min / range;
        } else {
          // Constant variable: 1 / range would produce Infinity/NaN features.
          // Only shift it instead; presumably the transformation computes x * Multiplier + Addend,
          // so the training value maps to 0.
          lin.Multiplier = 1.0;
          lin.Addend = -min;
        }
        trans[i] = lin;
        i++;
      }
      return trans;
    }

    /// <summary>
    /// Copies the input variables of the given rows into a [row, variable] matrix,
    /// applying <paramref name="scaling"/> per variable when it is not null.
    /// </summary>
    private static double[,] ExtractData(IDataset dataset, IEnumerable<int> rows, IReadOnlyCollection<string> allowedInputVariables, ITransformation<double>[] scaling = null) {
      // Materialize the rows once. They are read once per variable and also define the number of
      // result rows, which keeps the method working when there are no input variables.
      var rowIndices = rows.ToArray();
      double[][] variables;
      if (scaling != null) {
        variables = allowedInputVariables
          .Select((variableName, i) => scaling[i].Apply(dataset.GetDoubleValues(variableName, rowIndices)).ToArray())
          .ToArray();
      } else {
        variables = allowedInputVariables
          .Select(variableName => dataset.GetDoubleValues(variableName, rowIndices).ToArray())
          .ToArray();
      }
      int n = rowIndices.Length;
      var res = new double[n, variables.Length];
      for (int r = 0; r < n; r++) {
        for (int c = 0; c < variables.Length; c++) {
          res[r, c] = variables[c][r];
        }
      }
      return res;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.