Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
02/01/17 09:58:06 (8 years ago)
Author:
gkronber
Message:

#2288: introduced base class for variable network instance description and implemented GRR and Linear variable networks as specific classes

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.Instances.DataAnalysis/3.3/Regression/VariableNetworks/VariableNetwork.cs

    r14623 r14630  
    3030
    3131namespace HeuristicLab.Problems.Instances.DataAnalysis {
    32   public class VariableNetwork : ArtificialRegressionDataDescriptor {
     32  public abstract class VariableNetwork : ArtificialRegressionDataDescriptor {
    3333    private int nTrainingSamples;
    3434    private int nTestSamples;
     
    3838    private IRandom random;
    3939
    40     public override string Name { get { return string.Format("VariableNetwork-{0:0%} ({1} dim)", noiseRatio, numberOfFeatures); } }
    4140    private string networkDefinition;
    4241    public string NetworkDefinition { get { return networkDefinition; } }
     
    4746    }
    4847
    49     public VariableNetwork(int numberOfFeatures, double noiseRatio,
    50       IRandom rand)
    51       : this(250, 250, numberOfFeatures, noiseRatio, rand) { }
    52 
    53     public VariableNetwork(int nTrainingSamples, int nTestSamples,
     48    protected VariableNetwork(int nTrainingSamples, int nTestSamples,
    5449      int numberOfFeatures, double noiseRatio, IRandom rand) {
    5550      this.nTrainingSamples = nTrainingSamples;
     
    105100
    106101      var nrand = new NormalDistributedRandom(random, 0, 1);
    107       for (int c = 0; c < numLvl0; c++) {
     102      for(int c = 0; c < numLvl0; c++) {
    108103        inputVarNames.Add(new string[] { });
    109104        relevances.Add(new double[] { });
    110         description.Add(" ~ N(0, 1)");
    111         lvl0.Add(Enumerable.Range(0, TestPartitionEnd).Select(_ => nrand.NextDouble()).ToList());
     105        description.Add(" ~ N(0, 1 + noiseLvl)");
     106        // use same generation procedure for all variables
     107        var x = Enumerable.Range(0, TestPartitionEnd).Select(_ => nrand.NextDouble()).ToList();
     108        var sigma = x.StandardDeviationPop();
     109        var mean = x.Average();
     110        for(int i = 0; i < x.Count; i++) x[i] = (x[i] - mean) / sigma;
     111        var noisePrng = new NormalDistributedRandom(random, 0, Math.Sqrt(noiseRatio / (1.0 - noiseRatio)));
     112        lvl0.Add(x.Select(t => t + noisePrng.NextDouble()).ToList());
    112113      }
    113114
     
    125126
    126127      this.variableRelevances.Clear();
    127       for (int i = 0; i < variableNames.Length; i++) {
     128      for(int i = 0; i < variableNames.Length; i++) {
    128129        var targetVarName = variableNames[i];
    129130        var targetRelevantInputs =
     
    136137      // for graphviz
    137138      networkDefinition += Environment.NewLine + "digraph G {";
    138       for (int i = 0; i < variableNames.Length; i++) {
     139      for(int i = 0; i < variableNames.Length; i++) {
    139140        var name = variableNames[i];
    140141        var selectedVarNames = inputVarNames[i];
    141142        var selectedRelevances = relevances[i];
    142         for (int j = 0; j < selectedVarNames.Length; j++) {
     143        for(int j = 0; j < selectedVarNames.Length; j++) {
    143144          var selectedVarName = selectedVarNames[j];
    144145          var selectedRelevance = selectedRelevances[j];
     
    157158
    158159    private List<List<double>> CreateVariables(List<List<double>> allowedInputs, int numVars, List<string[]> inputVarNames, List<string> description, List<double[]> relevances) {
    159       var res = new List<List<double>>();
    160       for (int c = 0; c < numVars; c++) {
     160      var newVariables = new List<List<double>>();
     161      for(int c = 0; c < numVars; c++) {
    161162        string[] selectedVarNames;
    162163        double[] relevance;
    163         var x = GenerateRandomFunction(random, allowedInputs, out selectedVarNames, out relevance);
     164        var x = GenerateRandomFunction(random, allowedInputs, out selectedVarNames, out relevance).ToArray();
     165        // standardize x
    164166        var sigma = x.StandardDeviation();
    165         var noisePrng = new NormalDistributedRandom(random, 0, sigma * Math.Sqrt(noiseRatio / (1.0 - noiseRatio)));
    166         res.Add(x.Select(t => t + noisePrng.NextDouble()).ToList());
     167        var mean = x.Average();
     168        for(int i = 0; i < x.Length; i++) x[i] = (x[i] - mean) / sigma;
     169
     170        var noisePrng = new NormalDistributedRandom(random, 0, Math.Sqrt(noiseRatio / (1.0 - noiseRatio)));
     171        newVariables.Add(x.Select(t => t + noisePrng.NextDouble()).ToList());
    167172        Array.Sort(selectedVarNames, relevance);
    168173        inputVarNames.Add(selectedVarNames);
     
    176181        description.Add(string.Format(" ~ N({0}, {1:N3}) [Relevances: {2}]", desc, noisePrng.Sigma, relevanceStr));
    177182      }
    178       return res;
     183      return newVariables;
    179184    }
    180185
    181     // sample the input variables that are actually used and sample from a Gaussian process
    182     private IEnumerable<double> GenerateRandomFunction(IRandom rand, List<List<double>> xs, out string[] selectedVarNames, out double[] relevance) {
     186    public int SampleNumberOfVariables(IRandom rand, int maxNumberOfVariables) {
    183187      double r = -Math.Log(1.0 - rand.NextDouble()) * 2.0; // r is exponentially distributed with lambda = 2
    184188      int nl = (int)Math.Floor(1.5 + r); // number of selected vars is likely to be between three and four
    185       if (nl > xs.Count) nl = xs.Count; // limit max
    186 
    187       var selectedIdx = Enumerable.Range(0, xs.Count).Shuffle(random)
    188         .Take(nl).ToArray();
    189 
    190       var selectedVars = selectedIdx.Select(i => xs[i]).ToArray();
    191       selectedVarNames = selectedIdx.Select(i => VariableNames[i]).ToArray();
    192       return SampleGaussianProcess(random, selectedVars, out relevance);
     189      return Math.Min(maxNumberOfVariables, nl);
    193190    }
    194191
    195     private IEnumerable<double> SampleGaussianProcess(IRandom random, List<double>[] xs, out double[] relevance) {
    196       int nl = xs.Length;
    197       int nRows = xs.First().Count;
    198 
    199       // sample u iid ~ N(0, 1)
    200       var u = Enumerable.Range(0, nRows).Select(_ => NormalDistributedRandom.NextDouble(random, 0, 1)).ToArray();
    201 
    202       // sample actual length-scales
    203       var l = Enumerable.Range(0, nl)
    204         .Select(_ => random.NextDouble() * 2 + 0.5)
    205         .ToArray();
    206 
    207       double[,] K = CalculateCovariance(xs, l);
    208 
    209       // decompose
    210       alglib.trfac.spdmatrixcholesky(ref K, nRows, false);
    211 
    212 
    213       // calc y = Lu
    214       var y = new double[u.Length];
    215       alglib.ablas.rmatrixmv(nRows, nRows, K, 0, 0, 0, u, 0, ref y, 0);
    216 
    217       // calculate relevance by removing dimensions
    218       relevance = CalculateRelevance(y, u, xs, l);
    219 
    220       return y;
    221     }
    222 
    223     // calculate variable relevance based on removal of variables
    224     //  1) to remove a variable we set it's length scale to infinity (no relation of the variable value to the target)
    225     //  2) calculate MSE of the original target values (y) to the updated targes y' (after variable removal)
    226     //  3) relevance is larger if MSE(y,y') is large
    227     //  4) scale impacts so that the most important variable has impact = 1
    228     private double[] CalculateRelevance(double[] y, double[] u, List<double>[] xs, double[] l) {
    229       int nRows = xs.First().Count;
    230       var changedL = new double[l.Length];
    231       var relevance = new double[l.Length];
    232       for (int i = 0; i < l.Length; i++) {
    233         Array.Copy(l, changedL, changedL.Length);
    234         changedL[i] = double.MaxValue;
    235         var changedK = CalculateCovariance(xs, changedL);
    236 
    237         var yChanged = new double[u.Length];
    238         alglib.ablas.rmatrixmv(nRows, nRows, changedK, 0, 0, 0, u, 0, ref yChanged, 0);
    239 
    240         OnlineCalculatorError error;
    241         var mse = OnlineMeanSquaredErrorCalculator.Calculate(y, yChanged, out error);
    242         if (error != OnlineCalculatorError.None) mse = double.MaxValue;
    243         relevance[i] = mse;
    244       }
    245       // scale so that max relevance is 1.0
    246       var maxRel = relevance.Max();
    247       for (int i = 0; i < relevance.Length; i++) relevance[i] /= maxRel;
    248       return relevance;
    249     }
    250 
    251     private double[,] CalculateCovariance(List<double>[] xs, double[] l) {
    252       int nRows = xs.First().Count;
    253       double[,] K = new double[nRows, nRows];
    254       for (int r = 0; r < nRows; r++) {
    255         double[] xi = xs.Select(x => x[r]).ToArray();
    256         for (int c = 0; c <= r; c++) {
    257           double[] xj = xs.Select(x => x[c]).ToArray();
    258           double dSqr = xi.Zip(xj, (xik, xjk) => (xik - xjk))
    259             .Select(dk => dk * dk)
    260             .Zip(l, (dk, lk) => dk / lk)
    261             .Sum();
    262           K[r, c] = Math.Exp(-dSqr);
    263         }
    264       }
    265       // add a small diagonal matrix for numeric stability
    266       for (int i = 0; i < nRows; i++) {
    267         K[i, i] += 1.0E-7;
    268       }
    269 
    270       return K;
    271     }
     192    // sample a random function and calculate the variable relevances
     193    protected abstract IEnumerable<double> GenerateRandomFunction(IRandom rand, List<List<double>> xs, out string[] selectedVarNames, out double[] relevance);
    272194  }
    273195}
Note: See TracChangeset for help on using the changeset viewer.