Free cookie consent management tool by TermsFeed Policy Generator

Changeset 14291 for trunk/sources


Ignore:
Timestamp:
09/19/16 17:46:47 (8 years ago)
Author:
gkronber
Message:

#2660: changed calculation of variable relevance values for variable interaction networks based on sampling Gaussian processes.
Instead of taking inverse length scale (ARD). We calculate the deviation from the original function (y) after individual variables are removed (y').

File:
1 edited

Legend:

Unmodified
Added
Removed
  • TabularUnified trunk/sources/HeuristicLab.Problems.Instances.DataAnalysis/3.3/Regression/VariableNetworks/VariableNetwork.cs

    r14271 r14291  
    2626using HeuristicLab.Common;
    2727using HeuristicLab.Core;
     28using HeuristicLab.Problems.DataAnalysis;
    2829using HeuristicLab.Random;
    2930
     
    195196      int nl = xs.Length;
    196197      int nRows = xs.First().Count;
    197       double[,] K = new double[nRows, nRows];
    198 
    199       // sample length-scales
     198
     199      // sample u iid ~ N(0, 1)
     200      var u = Enumerable.Range(0, nRows).Select(_ => NormalDistributedRandom.NextDouble(random, 0, 1)).ToArray();
     201
     202      // sample actual length-scales
    200203      var l = Enumerable.Range(0, nl)
    201204        .Select(_ => random.NextDouble() * 2 + 0.5)
    202205        .ToArray();
    203       // calculate covariance matrix
     206
     207      double[,] K = CalculateCovariance(xs, l);
     208
     209      // decompose
     210      alglib.trfac.spdmatrixcholesky(ref K, nRows, false);
     211
     212
     213      // calc y = Lu
     214      var y = new double[u.Length];
     215      alglib.ablas.rmatrixmv(nRows, nRows, K, 0, 0, 0, u, 0, ref y, 0);
     216
     217      // calculate relevance by removing dimensions
     218      relevance = CalculateRelevance(y, u, xs, l);
     219
     220
     221      // calculate variable relevance
     222      // as per Rasmussen and Williams "Gaussian Processes for Machine Learning" page 106:
     223      // ,,For the squared exponential covariance function [...] the l1, ..., lD hyperparameters
     224      // play the role of characteristic length scales [...]. Such a covariance function implements
     225      // automatic relevance determination (ARD) [Neal, 1996], since the inverse of the length-scale
     226      // determines how relevant an input is: if the length-scale has a very large value, the covariance
     227      // will become almost independent of that input, effectively removing it from inference.''
     228      // relevance = l.Select(li => 1.0 / li).ToArray();
     229
     230      return y;
     231    }
     232
     233    // calculate variable relevance based on removal of variables
     234    //  1) to remove a variable we set it's length scale to infinity (no relation of the variable value to the target)
     235    //  2) calculate MSE of the original target values (y) to the updated targes y' (after variable removal)
     236    //  3) relevance is larger if MSE(y,y') is large
     237    //  4) scale impacts so that the most important variable has impact = 1
     238    private double[] CalculateRelevance(double[] y, double[] u, List<double>[] xs, double[] l) {
     239      int nRows = xs.First().Count;
     240      var changedL = new double[l.Length];
     241      var relevance = new double[l.Length];
     242      for (int i = 0; i < l.Length; i++) {
     243        Array.Copy(l, changedL, changedL.Length);
     244        changedL[i] = double.MaxValue;
     245        var changedK = CalculateCovariance(xs, changedL);
     246
     247        var yChanged = new double[u.Length];
     248        alglib.ablas.rmatrixmv(nRows, nRows, changedK, 0, 0, 0, u, 0, ref yChanged, 0);
     249
     250        OnlineCalculatorError error;
     251        var mse = OnlineMeanSquaredErrorCalculator.Calculate(y, yChanged, out error);
     252        if (error != OnlineCalculatorError.None) mse = double.MaxValue;
     253        relevance[i] = mse;
     254      }
     255      // scale so that max relevance is 1.0
     256      var maxRel = relevance.Max();
     257      for (int i = 0; i < relevance.Length; i++) relevance[i] /= maxRel;
     258      return relevance;
     259    }
     260
     261    private double[,] CalculateCovariance(List<double>[] xs, double[] l) {
     262      int nRows = xs.First().Count;
     263      double[,] K = new double[nRows, nRows];
    204264      for (int r = 0; r < nRows; r++) {
    205265        double[] xi = xs.Select(x => x[r]).ToArray();
     
    213273        }
    214274      }
    215 
    216275      // add a small diagonal matrix for numeric stability
    217276      for (int i = 0; i < nRows; i++) {
     
    219278      }
    220279
    221       // decompose
    222       alglib.trfac.spdmatrixcholesky(ref K, nRows, false);
    223 
    224       // sample u iid ~ N(0, 1)
    225       var u = Enumerable.Range(0, nRows).Select(_ => NormalDistributedRandom.NextDouble(random, 0, 1)).ToArray();
    226 
    227       // calc y = Lu
    228       var y = new double[u.Length];
    229       alglib.ablas.rmatrixmv(nRows, nRows, K, 0, 0, 0, u, 0, ref y, 0);
    230 
    231       // calculate variable relevance
    232       // as per Rasmussen and Williams "Gaussian Processes for Machine Learning" page 106:
    233       // ,,For the squared exponential covariance function [...] the l1, ..., lD hyperparameters
    234       // play the role of characteristic length scales [...]. Such a covariance function implements
    235       // automatic relevance determination (ARD) [Neal, 1996], since the inverse of the length-scale
    236       // determines how relevant an input is: if the length-scale has a very large value, the covariance
    237       // will become almost independent of that input, effectively removing it from inference.''
    238       relevance = l.Select(li => 1.0 / li).ToArray();
    239 
    240       return y;
     280      return K;
    241281    }
    242282  }
Note: See TracChangeset for help on using the changeset viewer.