Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/14/17 17:53:30 (7 years ago)
Author:
gkronber
Message:

#2699: made a number of changes mainly to RBF regression model

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/RBFRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/RadialBasisFunctions/RadialBasisFunctionModel.cs

    r14386 r14872  
    1 #region License Information
     1#region License Information
    22/* HeuristicLab
    33 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     
    2222using System;
    2323using System.Collections.Generic;
     24using System.Diagnostics;
    2425using System.Linq;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
    27 using HeuristicLab.Data;
    2828using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2929using HeuristicLab.Problems.DataAnalysis;
     
    3131namespace HeuristicLab.Algorithms.DataAnalysis {
    3232  /// <summary>
    33   /// Represents a Radial Basis Function regression model.
     33  /// Represents an RBF regression model.
    3434  /// </summary>
    3535  [StorableClass]
    36   [Item("RBFModel", "Represents a Gaussian process posterior.")]
     36  [Item("RBFModel", "An RBF regression model")]
    3737  public sealed class RadialBasisFunctionModel : RegressionModel, IConfidenceRegressionModel {
    38     public override IEnumerable<string> VariablesUsedForPrediction
    39     {
     38    public override IEnumerable<string> VariablesUsedForPrediction {
    4039      get { return allowedInputVariables; }
    4140    }
    4241
    4342    [Storable]
    44     private string[] allowedInputVariables;
    45     public string[] AllowedInputVariables
    46     {
     43    private readonly string[] allowedInputVariables;
     44    public string[] AllowedInputVariables {
    4745      get { return allowedInputVariables; }
    4846    }
    4947
    5048    [Storable]
    51     private double[] alpha;
    52     [Storable]
    53     private IDataset trainingDataset; // it is better to store the original training dataset completely because this is more efficient in persistence
    54     [Storable]
    55     private int[] trainingRows;
    56     [Storable]
    57     private IKernelFunction<double[]> kernel;
    58     [Storable]
    59     private DoubleMatrix gramInvert;
    60 
     49    private readonly double[] alpha;
     50
     51    [Storable]
     52    private readonly double[,] trainX; // it is better to store the original training dataset completely because this is more efficient in persistence
     53
     54    [Storable]
     55    private readonly ITransformation<double>[] scaling;
     56
     57    [Storable]
     58    private readonly ICovarianceFunction kernel;
     59
     60    private double[,] gramInvert; // not storable as it can be large (recreate after deserialization as required)
     61
     62    [Storable]
     63    private readonly double meanOffset; // implementation works for zero-mean target variables
    6164
    6265    [StorableConstructor]
     
    6568      : base(original, cloner) {
    6669      // shallow copies of arrays because they cannot be modified
    67       trainingRows = original.trainingRows;
    6870      allowedInputVariables = original.allowedInputVariables;
    6971      alpha = original.alpha;
    70       trainingDataset = original.trainingDataset;
    71       kernel = original.kernel;
    72     }
    73     public RadialBasisFunctionModel(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows, IKernelFunction<double[]> kernel)
     72      trainX = original.trainX;
     73      gramInvert = original.gramInvert;
     74      scaling = original.scaling;
     75
     76      meanOffset = original.meanOffset;
     77      if (original.kernel != null)
     78        kernel = cloner.Clone(original.kernel);
     79    }
     80    public RadialBasisFunctionModel(IDataset dataset, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows,
     81      bool scaleInputs, ICovarianceFunction kernel)
    7482      : base(targetVariable) {
     83      if (kernel.GetNumberOfParameters(allowedInputVariables.Count()) > 0) throw new ArgumentException("All parameters in the kernel function must be specified.");
    7584      name = ItemName;
    7685      description = ItemDescription;
    7786      this.allowedInputVariables = allowedInputVariables.ToArray();
    78       trainingRows = rows.ToArray();
    79       trainingDataset = dataset;
    80       this.kernel = (IKernelFunction<double[]>)kernel.Clone();
     87      var trainingRows = rows.ToArray();
     88      this.kernel = (ICovarianceFunction)kernel.Clone();
    8189      try {
    82         var data = ExtractData(dataset, trainingRows);
    83         var qualities = dataset.GetDoubleValues(targetVariable, trainingRows).ToArray();
     90        if (scaleInputs)
     91          scaling = CreateScaling(dataset, trainingRows);
     92        trainX = ExtractData(dataset, trainingRows, scaling);
     93        var y = dataset.GetDoubleValues(targetVariable, trainingRows).ToArray();
     94        meanOffset = y.Average();
     95        for (int i = 0; i < y.Length; i++) y[i] -= meanOffset;
    8496        int info;
     97        // TODO: improve efficiency by decomposing matrix once instead of solving the system and then inverting the matrix
    8598        alglib.densesolverlsreport denseSolveRep;
    86         var gr = BuildGramMatrix(data);
    87         alglib.rmatrixsolvels(gr, data.Length + 1, data.Length + 1, qualities.Concat(new[] { 0.0 }).ToArray(), 0.0, out info, out denseSolveRep, out alpha);
    88         if (info != 1) throw new ArgumentException("Could not create Model.");
    89         gramInvert = new DoubleMatrix(gr).Invert();
    90       }
    91       catch (alglib.alglibexception ae) {
     99        gramInvert = BuildGramMatrix(trainX);
     100        int n = trainX.GetLength(0);
     101        alglib.rmatrixsolvels(gramInvert, n, n, y, 0.0, out info, out denseSolveRep, out alpha);
     102        if (info != 1) throw new ArgumentException("Could not create model.");
     103
     104        alglib.matinvreport report;
     105        alglib.rmatrixinverse(ref gramInvert, out info, out report);
     106        if (info != 1) throw new ArgumentException("Could not invert matrix. Is it quadratic symmetric positive definite?");
     107
     108      } catch (alglib.alglibexception ae) {
    92109        // wrap exception so that calling code doesn't have to know about alglib implementation
    93         throw new ArgumentException("There was a problem in the calculation of the RBF process model", ae);
    94       }
    95     }
    96     private double[][] ExtractData(IDataset dataset, IEnumerable<int> rows) {
    97       return rows.Select(r => allowedInputVariables.Select(v => dataset.GetDoubleValue(v, r)).ToArray()).ToArray();
     110        throw new ArgumentException("There was a problem in the calculation of the RBF model", ae);
     111      }
     112    }
     113
     114    private ITransformation<double>[] CreateScaling(IDataset dataset, int[] rows) {
     115      var trans = new ITransformation<double>[allowedInputVariables.Length];
     116      int i = 0;
     117      foreach (var variable in allowedInputVariables) {
     118        var lin = new LinearTransformation(allowedInputVariables);
     119        var max = dataset.GetDoubleValues(variable, rows).Max();
     120        var min = dataset.GetDoubleValues(variable, rows).Min();
     121        lin.Multiplier = 1.0 / (max - min);
     122        lin.Addend = -min / (max - min);
     123        trans[i] = lin;
     124        i++;
     125      }
     126      return trans;
     127    }
     128
     129    private double[,] ExtractData(IDataset dataset, IEnumerable<int> rows, ITransformation<double>[] scaling = null) {
     130      double[][] variables;
     131      if (scaling != null) {
     132        variables =
     133          allowedInputVariables.Select((var, i) => scaling[i].Apply(dataset.GetDoubleValues(var, rows)).ToArray())
     134            .ToArray();
     135      } else {
     136        variables =
     137        allowedInputVariables.Select(var => dataset.GetDoubleValues(var, rows).ToArray()).ToArray();
     138      }
     139      int n = variables.First().Length;
     140      var res = new double[n, variables.Length];
     141      for (int r = 0; r < n; r++)
     142        for (int c = 0; c < variables.Length; c++) {
     143          res[r, c] = variables[c][r];
     144        }
     145      return res;
    98146    }
    99147
     
    104152    #region IRegressionModel Members
    105153    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    106       var solutions = ExtractData(dataset, rows);
    107       var data = ExtractData(trainingDataset, trainingRows);
    108       return solutions.Select(solution => alpha.Zip(data, (a, d) => a * kernel.Get(solution, d)).Sum() + 1 * alpha[alpha.Length - 1]).ToArray();
     154      var newX = ExtractData(dataset, rows, scaling);
     155      var dim = newX.GetLength(1);
     156      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());
     157
     158      var pred = new double[newX.GetLength(0)];
     159      for (int i = 0; i < pred.Length; i++) {
     160        double sum = meanOffset;
     161        for (int j = 0; j < alpha.Length; j++) {
     162          sum += alpha[j] * cov.CrossCovariance(trainX, newX, j, i);
     163        }
     164        pred[i] = sum;
     165      }
     166      return pred;
    109167    }
    110168    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    111       return new RadialBasisFunctionRegressionSolution(this, new RegressionProblemData(problemData));
     169      return new ConfidenceRegressionSolution(this, new RegressionProblemData(problemData));
    112170    }
    113171    #endregion
    114172
    115173    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    116       var data = ExtractData(trainingDataset, trainingRows);
    117       return ExtractData(dataset, rows).Select(x => GetVariance(x, data));
     174      if (gramInvert == null) CreateAndInvertGramMatrix();
     175      int n = gramInvert.GetLength(0);
     176      var newData = ExtractData(dataset, rows, scaling);
     177      var dim = newData.GetLength(1);
     178      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());
     179
     180      // TODO perf (matrix matrix multiplication)
     181      for (int i = 0; i < newData.GetLength(0); i++) {
     182        double[] p = new double[n];
     183
     184        for (int j = 0; j < trainX.GetLength(0); j++) {
     185          p[j] = cov.CrossCovariance(trainX, newData, j, i);
     186        }
     187
     188        var Ap = new double[n];
     189        alglib.ablas.rmatrixmv(n, n, gramInvert, 0, 0, 0, p, 0, ref Ap, 0);
     190        var res = 0.0;
     191        // dot product
     192        for (int j = 0; j < p.Length; j++) res += p[j] * Ap[j];
     193        yield return res > 0 ? res : 0;
     194      }
    118195    }
    119196    public double LeaveOneOutCrossValidationRootMeanSquaredError() {
    120       return Math.Sqrt(alpha.Select((t, i) => t / gramInvert[i, i]).Sum(d => d * d) / gramInvert.Rows);
    121     }
    122 
     197      if (gramInvert == null) CreateAndInvertGramMatrix();
     198      var n = gramInvert.GetLength(0);
     199      var s = 1.0 / n;
     200
     201      var sum = 0.0;
     202      for (int i = 0; i < alpha.Length; i++) {
     203        var x = alpha[i] / gramInvert[i, i];
     204        sum += x * x;
     205      }
     206      sum *= s;
     207      return Math.Sqrt(sum);
     208    }
     209
     210    private void CreateAndInvertGramMatrix() {
     211      try {
     212        gramInvert = BuildGramMatrix(trainX);
     213        int info = 0;
     214        alglib.matinvreport report;
     215        alglib.rmatrixinverse(ref gramInvert, out info, out report);
     216        if (info != 1)
     217          throw new ArgumentException("Could not invert matrix. Is it quadratic symmetric positive definite?");
     218      } catch (alglib.alglibexception) {
     219        // wrap exception so that calling code doesn't have to know about alglib implementation
     220        throw new ArgumentException("Could not invert matrix. Is it quadratic symmetric positive definite?");
     221      }
     222    }
    123223    #region helpers
    124     private double[,] BuildGramMatrix(double[][] data) {
    125       var size = data.Length + 1;
    126       var gram = new double[size, size];
    127       for (var i = 0; i < size; i++)
    128         for (var j = i; j < size; j++) {
    129           if (j == size - 1 && i == size - 1) gram[i, j] = 0;
    130           else if (j == size - 1 || i == size - 1) gram[j, i] = gram[i, j] = 1;
    131           else gram[j, i] = gram[i, j] = kernel.Get(data[i], data[j]); //symmetric matrix --> half of the work
     224    private double[,] BuildGramMatrix(double[,] data) {
     225      var n = data.GetLength(0);
     226      var dim = data.GetLength(1);
     227      var cov = kernel.GetParameterizedCovarianceFunction(new double[0], Enumerable.Range(0, dim).ToArray());
     228      var gram = new double[n, n];
     229      for (var i = 0; i < n; i++)
     230        for (var j = i; j < n; j++) {
     231          gram[j, i] = gram[i, j] = cov.Covariance(data, i, j); // symmetric matrix --> half of the work
    132232        }
    133233      return gram;
    134234    }
    135     private double GetVariance(double[] solution, IEnumerable<double[]> data) {
    136       var phiT = data.Select(x => kernel.Get(x, solution)).Concat(new[] { 1.0 }).ToColumnVector();
    137       var res = phiT.Transpose().Mul(gramInvert.Mul(phiT))[0, 0];
    138       return res > 0 ? res : 0;
    139     }
     235
    140236    #endregion
    141237  }
Note: See TracChangeset for help on using the changeset viewer.