Free cookie consent management tool by TermsFeed Policy Generator

source: branches/RBFRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/RadialBasisFunctions/RadialBasisFunctionModel.cs @ 14386

Last change on this file since 14386 was 14386, checked in by bwerth, 8 years ago

#2699 moved RadialBasisFunctions from Problems.SurrogateProblem to Algorithms.DataAnalysis

File size: 6.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
31namespace HeuristicLab.Algorithms.DataAnalysis {
32  /// <summary>
33  /// Represents a Radial Basis Function regression model.
34  /// </summary>
35  [StorableClass]
36  [Item("RBFModel", "Represents a Gaussian process posterior.")]
37  public sealed class RadialBasisFunctionModel : RegressionModel, IConfidenceRegressionModel {
38    public override IEnumerable<string> VariablesUsedForPrediction
39    {
40      get { return allowedInputVariables; }
41    }
42
43    [Storable]
44    private string[] allowedInputVariables;
45    public string[] AllowedInputVariables
46    {
47      get { return allowedInputVariables; }
48    }
49
50    [Storable]
51    private double[] alpha;
52    [Storable]
53    private IDataset trainingDataset; // it is better to store the original training dataset completely because this is more efficient in persistence
54    [Storable]
55    private int[] trainingRows;
56    [Storable]
57    private IKernelFunction<double[]> kernel;
58    [Storable]
59    private DoubleMatrix gramInvert;
60
61
62    [StorableConstructor]
63    private RadialBasisFunctionModel(bool deserializing) : base(deserializing) { }
64    private RadialBasisFunctionModel(RadialBasisFunctionModel original, Cloner cloner)
65      : base(original, cloner) {
66      // shallow copies of arrays because they cannot be modified
67      trainingRows = original.trainingRows;
68      allowedInputVariables = original.allowedInputVariables;
69      alpha = original.alpha;
70      trainingDataset = original.trainingDataset;
71      kernel = original.kernel;
72    }
73    public RadialBasisFunctionModel(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows, IKernelFunction<double[]> kernel)
74      : base(targetVariable) {
75      name = ItemName;
76      description = ItemDescription;
77      this.allowedInputVariables = allowedInputVariables.ToArray();
78      trainingRows = rows.ToArray();
79      trainingDataset = dataset;
80      this.kernel = (IKernelFunction<double[]>)kernel.Clone();
81      try {
82        var data = ExtractData(dataset, trainingRows);
83        var qualities = dataset.GetDoubleValues(targetVariable, trainingRows).ToArray();
84        int info;
85        alglib.densesolverlsreport denseSolveRep;
86        var gr = BuildGramMatrix(data);
87        alglib.rmatrixsolvels(gr, data.Length + 1, data.Length + 1, qualities.Concat(new[] { 0.0 }).ToArray(), 0.0, out info, out denseSolveRep, out alpha);
88        if (info != 1) throw new ArgumentException("Could not create Model.");
89        gramInvert = new DoubleMatrix(gr).Invert();
90      }
91      catch (alglib.alglibexception ae) {
92        // wrap exception so that calling code doesn't have to know about alglib implementation
93        throw new ArgumentException("There was a problem in the calculation of the RBF process model", ae);
94      }
95    }
96    private double[][] ExtractData(IDataset dataset, IEnumerable<int> rows) {
97      return rows.Select(r => allowedInputVariables.Select(v => dataset.GetDoubleValue(v, r)).ToArray()).ToArray();
98    }
99
100    public override IDeepCloneable Clone(Cloner cloner) {
101      return new RadialBasisFunctionModel(this, cloner);
102    }
103
104    #region IRegressionModel Members
105    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
106      var solutions = ExtractData(dataset, rows);
107      var data = ExtractData(trainingDataset, trainingRows);
108      return solutions.Select(solution => alpha.Zip(data, (a, d) => a * kernel.Get(solution, d)).Sum() + 1 * alpha[alpha.Length - 1]).ToArray();
109    }
110    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
111      return new RadialBasisFunctionRegressionSolution(this, new RegressionProblemData(problemData));
112    }
113    #endregion
114
115    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
116      var data = ExtractData(trainingDataset, trainingRows);
117      return ExtractData(dataset, rows).Select(x => GetVariance(x, data));
118    }
119    public double LeaveOneOutCrossValidationRootMeanSquaredError() {
120      return Math.Sqrt(alpha.Select((t, i) => t / gramInvert[i, i]).Sum(d => d * d) / gramInvert.Rows);
121    }
122
123    #region helpers
124    private double[,] BuildGramMatrix(double[][] data) {
125      var size = data.Length + 1;
126      var gram = new double[size, size];
127      for (var i = 0; i < size; i++)
128        for (var j = i; j < size; j++) {
129          if (j == size - 1 && i == size - 1) gram[i, j] = 0;
130          else if (j == size - 1 || i == size - 1) gram[j, i] = gram[i, j] = 1;
131          else gram[j, i] = gram[i, j] = kernel.Get(data[i], data[j]); //symmteric Matrix --> half of the work
132        }
133      return gram;
134    }
135    private double GetVariance(double[] solution, IEnumerable<double[]> data) {
136      var phiT = data.Select(x => kernel.Get(x, solution)).Concat(new[] { 1.0 }).ToColumnVector();
137      var res = phiT.Transpose().Mul(gramInvert.Mul(phiT))[0, 0];
138      return res > 0 ? res : 0;
139    }
140    #endregion
141  }
142}
Note: See TracBrowser for help on using the repository browser.