Changeset 9360


Ignore:
Timestamp:
04/15/13 15:07:16 (7 years ago)
Author:
gkronber
Message:

#1902: implemented neural network covariance function plus test case (comparison with GPML) for Gaussian processes

Location:
trunk/sources
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceFunctions/CovarianceNeuralNetwork.cs

    r9359 r9360  
    4040    }
    4141
    42     public IValueParameter<DoubleValue> InverseLengthParameter {
    43       get { return (IValueParameter<DoubleValue>)Parameters["InverseLength"]; }
     42    public IValueParameter<DoubleValue> LengthParameter {
     43      get { return (IValueParameter<DoubleValue>)Parameters["Length"]; }
    4444    }
    4545
     
    5959
    6060      Parameters.Add(new OptionalValueParameter<DoubleValue>("Scale", "The scale parameter."));
    61       Parameters.Add(new OptionalValueParameter<DoubleValue>("InverseLength", "The inverse length parameter."));
     61      Parameters.Add(new OptionalValueParameter<DoubleValue>("Length", "The length parameter."));
    6262    }
    6363
     
    6969      return
    7070        (ScaleParameter.Value != null ? 0 : 1) +
    71         (InverseLengthParameter.Value != null ? 0 : 1);
     71        (LengthParameter.Value != null ? 0 : 1);
    7272    }
    7373
    7474    public void SetParameter(double[] p) {
    75       double scale, inverseLength;
    76       GetParameterValues(p, out scale, out inverseLength);
     75      double scale, length;
     76      GetParameterValues(p, out scale, out length);
    7777      ScaleParameter.Value = new DoubleValue(scale);
    78       InverseLengthParameter.Value = new DoubleValue(inverseLength);
     78      LengthParameter.Value = new DoubleValue(length);
    7979    }
    8080
    8181
    82     private void GetParameterValues(double[] p, out double scale, out double inverseLength) {
     82    private void GetParameterValues(double[] p, out double scale, out double length) {
    8383      // gather parameter values
    8484      int c = 0;
    85       if (InverseLengthParameter.Value != null) {
    86         inverseLength = InverseLengthParameter.Value.Value;
     85      if (LengthParameter.Value != null) {
     86        length = LengthParameter.Value.Value;
    8787      } else {
    88         inverseLength = 1.0 / Math.Exp(p[c]);
     88        length = Math.Exp(2 * p[c]);
    8989        c++;
    9090      }
     
    105105    private static Func<Term, UnaryFunc> sqrt = UnaryFunc.Factory(
    106106      x => Math.Sqrt(x),
    107       x => 1 / 2 * Math.Sqrt(x));
     107      x => 1 / (2 * Math.Sqrt(x)));
    108108
    109109    public ParameterizedCovarianceFunction GetParameterizedCovarianceFunction(double[] p, IEnumerable<int> columnIndices) {
    110       double inverseLength, scale;
    111       GetParameterValues(p, out scale, out inverseLength);
     110      double length, scale;
     111      GetParameterValues(p, out scale, out length);
    112112      // create functions
    113113      AutoDiff.Variable p0 = new AutoDiff.Variable();
    114114      AutoDiff.Variable p1 = new AutoDiff.Variable();
    115       var invL = 1.0 / TermBuilder.Exp(p0);
    116       var s = TermBuilder.Exp(2 * p1);
     115      var l = TermBuilder.Exp(2.0 * p0);
     116      var s = TermBuilder.Exp(2.0 * p1);
    117117      AutoDiff.Variable[] x1 = new AutoDiff.Variable[columnIndices.Count()];
    118118      AutoDiff.Variable[] x2 = new AutoDiff.Variable[columnIndices.Count()];
    119       AutoDiff.Term sx = invL;
    120       AutoDiff.Term s1 = invL;
    121       AutoDiff.Term s2 = invL;
     119      AutoDiff.Term sx = 1;
     120      AutoDiff.Term s1 = 1;
     121      AutoDiff.Term s2 = 1;
    122122      foreach (var k in columnIndices) {
    123123        x1[k] = new AutoDiff.Variable();
    124124        x2[k] = new AutoDiff.Variable();
    125         sx += x1[k] * invL * x2[k];
    126         s1 += x1[k] * invL * x1[k];
    127         s2 += x2[k] * invL * x2[k];
     125        sx += x1[k] * x2[k];
     126        s1 += x1[k] * x1[k];
     127        s2 += x2[k] * x2[k];
    128128      }
    129129
    130130      var parameter = x1.Concat(x2).Concat(new AutoDiff.Variable[] { p0, p1 }).ToArray();
    131131      var values = new double[x1.Length + x2.Length + 2];
    132       var c = (s * asin(2 * sx / (sqrt((1 + 2 * s1) * (1 + 2 * s2))))).Compile(parameter);
     132      var c = (s * asin(sx / (sqrt((l + s1) * (l + s2))))).Compile(parameter);
    133133
    134134      var cov = new ParameterizedCovarianceFunction();
     
    143143          k++;
    144144        }
    145         values[k] = Math.Log(1.0 / inverseLength);
    146         values[k + 1] = Math.Log(scale) / 2.0;
     145        values[k] = Math.Log(Math.Sqrt(length));
     146        values[k + 1] = Math.Log(Math.Sqrt(scale));
    147147        return c.Evaluate(values);
    148148      };
     
    157157          k++;
    158158        }
    159         values[k] = Math.Log(1.0 / inverseLength);
    160         values[k + 1] = Math.Log(scale) / 2.0;
     159        values[k] = Math.Log(Math.Sqrt(length));
     160        values[k + 1] = Math.Log(Math.Sqrt(scale));
    161161        return c.Evaluate(values);
    162162      };
     
    171171          k++;
    172172        }
    173         values[k] = Math.Log(1.0 / inverseLength);
    174         values[k + 1] = Math.Log(scale) / 2.0;
     173        values[k] = Math.Log(Math.Sqrt(length));
     174        values[k + 1] = Math.Log(Math.Sqrt(scale));
    175175        return c.Differentiate(values).Item1.Skip(columnIndices.Count() * 2);
    176176      };
  • trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GaussianProcessFunctionsTest.cs

    r8982 r9360  
    264264
    265265    [TestMethod]
     266    public void CovNnTest() {
     267      TestCovarianceFunction(new CovarianceNeuralNetwork(), 0,
     268        new double[,]
     269          {
     270{    0.5930,    0.5896,    0.7951,    0.5808,    0.7787,    0.6411,    0.6672,    0.6814,    0.6297,    0.7338},
     271{    0.5214,    0.5987,    0.6763,    0.5053,    0.7289,    0.6077,    0.5909,    0.6230,    0.5621,    0.6472},
     272{    0.7299,    0.7305,    0.8081,    0.6837,    0.7039,    0.7994,    0.7756,    0.6668,    0.7145,    0.7665},
     273{    0.6399,    0.5347,    0.7261,    0.6044,    0.5836,    0.6549,    0.7250,    0.6815,    0.6720,    0.6819},
     274{    0.6627,    0.5300,    0.7045,    0.6665,    0.5340,    0.5659,    0.6509,    0.6692,    0.6600,    0.6747},
     275{    0.6151,    0.5719,    0.6465,    0.5881,    0.5593,    0.6189,    0.6585,    0.6397,    0.6364,    0.6382},
     276{    0.5978,    0.6929,    0.7292,    0.5719,    0.8209,    0.6695,    0.6469,    0.5966,    0.6160,    0.7203},
     277{    0.6944,    0.7128,    0.8241,    0.6566,    0.8002,    0.7548,    0.7503,    0.6494,    0.6961,    0.7875},
     278{    0.6443,    0.6893,    0.8074,    0.6258,    0.8018,    0.7049,    0.6885,    0.6633,    0.6530,    0.7602},
     279{    0.4829,    0.5970,    0.6259,    0.4461,    0.6737,    0.6484,    0.5912,    0.6067,    0.5329,    0.5928},
     280          },
     281        new double[][,]
     282          {
     283            new double[,] {
     284{   -0.5669,   -0.5220,   -0.3879,   -0.4304,   -0.4540,   -0.4460,   -0.4901,   -0.4465,   -0.5095,   -0.4340},
     285{   -0.5220,   -0.5969,   -0.3884,   -0.3843,   -0.4499,   -0.4737,   -0.4843,   -0.3961,   -0.5133,   -0.5540},
     286{   -0.3879,   -0.3884,   -0.5554,   -0.4160,   -0.4600,   -0.4671,   -0.4056,   -0.4603,   -0.4637,   -0.3810},
     287{   -0.4304,   -0.3843,   -0.4160,   -0.5895,   -0.4728,   -0.5384,   -0.3884,   -0.4288,   -0.3748,   -0.3672},
     288{   -0.4540,   -0.4499,   -0.4600,   -0.4728,   -0.5977,   -0.5296,   -0.3758,   -0.3954,   -0.4610,   -0.4165},
     289{   -0.4460,   -0.4737,   -0.4671,   -0.5384,   -0.5296,   -0.5987,   -0.4280,   -0.4285,   -0.4369,   -0.4802},
     290{   -0.4901,   -0.4843,   -0.4056,   -0.3884,   -0.3758,   -0.4280,   -0.5731,   -0.5003,   -0.4920,   -0.4043},
     291{   -0.4465,   -0.3961,   -0.4603,   -0.4288,   -0.3954,   -0.4285,   -0.5003,   -0.5362,   -0.4621,   -0.3360},
     292{   -0.5095,   -0.5133,   -0.4637,   -0.3748,   -0.4610,   -0.4369,   -0.4920,   -0.4621,   -0.5614,   -0.4463},
     293{   -0.4340,   -0.5540,   -0.3810,   -0.3672,   -0.4165,   -0.4802,   -0.4043,   -0.3360,   -0.4463,   -0.5987},
     294            },
     295            new double[,] {
     296{    1.6963,    1.4541,    1.3587,    1.3199,    1.3084,    1.1825,    1.5272,    1.5553,    1.6096,    1.2526},
     297{    1.4541,    1.4500,    1.2006,    1.0794,    1.1704,    1.1251,    1.3601,    1.2650,    1.4540,    1.3576},
     298{    1.3587,    1.2006,    1.7563,    1.3216,    1.3536,    1.2537,    1.3780,    1.6296,    1.5559,    1.1618},
     299{    1.3199,    1.0794,    1.3216,    1.5376,    1.2625,    1.2852,    1.2016,    1.4015,    1.2051,    1.0222},
     300{    1.3084,    1.1704,    1.3536,    1.2625,    1.4362,    1.2215,    1.1132,    1.2541,    1.3394,    1.0837},
     301{    1.1825,    1.1251,    1.2537,    1.2852,    1.2215,    1.2511,    1.1282,    1.2125,    1.1771,    1.1218},
     302{    1.5272,    1.3601,    1.3780,    1.2016,    1.1132,    1.1282,    1.6603,    1.6427,    1.5504,    1.1677},
     303{    1.5553,    1.2650,    1.6296,    1.4015,    1.2541,    1.2125,    1.6427,    1.8427,    1.6113,    1.0910},
     304{    1.6096,    1.4540,    1.5559,    1.2051,    1.3394,    1.1771,    1.5504,    1.6113,    1.7261,    1.2950},
     305{    1.2526,    1.3576,    1.1618,    1.0222,    1.0837,    1.1218,    1.1677,    1.0910,    1.2950,    1.4153},
     306          }
     307          }
     308      );
     309      TestCovarianceFunction(new CovarianceNeuralNetwork(), 1,
     310         new double[,]
     311           {
     312{    1.4436,    1.4866,    2.0692,    1.4105,    2.1077,    1.7712,    1.6764,    1.6030,    1.4898,    1.7857},
     313{    1.1652,    1.3662,    1.6384,    1.1271,    1.8076,    1.5312,    1.3659,    1.3446,    1.2210,    1.4545},
     314{    1.7710,    1.8348,    2.1497,    1.6684,    1.9875,    2.1872,    1.9499,    1.6132,    1.7025,    1.8978},
     315{    1.4480,    1.2766,    1.8006,    1.3710,    1.5424,    1.6920,    1.6866,    1.5035,    1.4784,    1.5742},
     316{    1.4350,    1.2175,    1.6874,    1.4354,    1.3683,    1.4290,    1.4793,    1.4231,    1.3994,    1.4995},
     317{    1.2557,    1.2181,    1.4634,    1.2013,    1.3324,    1.4450,    1.3951,    1.2781,    1.2662,    1.3344},
     318{    1.4328,    1.6864,    1.8991,    1.3709,    2.1658,    1.8123,    1.6081,    1.4065,    1.4398,    1.7323},
     319{    1.7618,    1.8647,    2.2657,    1.6713,    2.2933,    2.1661,    1.9689,    1.6354,    1.7276,    2.0126},
     320{    1.5724,    1.7252,    2.1209,    1.5259,    2.1852,    1.9465,    1.7436,    1.5858,    1.5568,    1.8614},
     321{    1.0716,    1.3447,    1.5116,    0.9906,    1.6687,    1.5994,    1.3485,    1.2963,    1.1483,    1.3287},
     322           },
     323         new double[][,]
     324          {
     325            new double[,] {
     326{   -3.1708,   -2.6488,   -2.6475,   -2.4774,   -2.4067,   -2.1099,   -2.8841,   -3.0223,   -3.0526,   -2.3041},
     327{   -2.6488,   -2.5111,   -2.2570,   -1.9637,   -2.0708,   -1.9172,   -2.4814,   -2.3998,   -2.6632,   -2.3537},
     328{   -2.6475,   -2.2570,   -3.3336,   -2.5066,   -2.5075,   -2.2490,   -2.6640,   -3.1778,   -2.9977,   -2.1743},
     329{   -2.4774,   -1.9637,   -2.5066,   -2.7415,   -2.2591,   -2.2029,   -2.2607,   -2.6778,   -2.2936,   -1.8530},
     330{   -2.4067,   -2.0708,   -2.5075,   -2.2591,   -2.4753,   -2.0592,   -2.0638,   -2.3743,   -2.4708,   -1.9151},
     331{   -2.1099,   -1.9172,   -2.2490,   -2.2029,   -2.0592,   -2.0176,   -2.0074,   -2.2134,   -2.1126,   -1.8976},
     332{   -2.8841,   -2.4814,   -2.6640,   -2.2607,   -2.0638,   -2.0074,   -3.0726,   -3.1450,   -2.9374,   -2.1472},
     333{   -3.0223,   -2.3998,   -3.1778,   -2.6778,   -2.3743,   -2.2134,   -3.1450,   -3.5652,   -3.1314,   -2.0795},
     334{   -3.0526,   -2.6632,   -2.9977,   -2.2936,   -2.4708,   -2.1126,   -2.9374,   -3.1314,   -3.2519,   -2.3870},
     335{   -2.3041,   -2.3537,   -2.1743,   -1.8530,   -1.9151,   -1.8976,   -2.1472,   -2.0795,   -2.3870,   -2.4217},
     336            },
     337            new double[,] {
     338{    4.3303,    3.4649,    3.7030,    3.3026,    3.1505,    2.6912,    3.9244,    4.3025,    4.2082,    3.0090},
     339{    3.4649,    3.1323,    3.0110,    2.5036,    2.5893,    2.3340,    3.2266,    3.2706,    3.5078,    2.9242},
     340{    3.7030,    3.0110,    4.6774,    3.3899,    3.3261,    2.9053,    3.6926,    4.5813,    4.1997,    2.8855},
     341{    3.3026,    2.5036,    3.3899,    3.5200,    2.8645,    2.7165,    2.9968,    3.6979,    3.0870,    2.3508},
     342{    3.1505,    2.5893,    3.3261,    2.8645,    3.0745,    2.4980,    2.6901,    3.2286,    3.2548,    2.3841},
     343{    2.6912,    2.3340,    2.9053,    2.7165,    2.4980,    2.3825,    2.5425,    2.9255,    2.7134,    2.2974},
     344{    3.9244,    3.2266,    3.6926,    2.9968,    2.6901,    2.5425,    4.1329,    4.4278,    4.0227,    2.7860},
     345{    4.3025,    3.2706,    4.5813,    3.6979,    3.2286,    2.9255,    4.4278,    5.2237,    4.4826,    2.8275},
     346{    4.2082,    3.5078,    4.1997,    3.0870,    3.2548,    2.7134,    4.0227,    4.4826,    4.4998,    3.1361},
     347{    3.0090,    2.9242,    2.8855,    2.3508,    2.3841,    2.2974,    2.7860,    2.8275,    3.1361,    2.9891},
     348            },
     349          }
     350       );
     351    }
     352
     353
     354    [TestMethod]
    266355    public void CovRQIsoTest() {
    267356      TestCovarianceFunction(new CovarianceRationalQuadraticIso(), 0,
Note: See TracChangeset for help on using the changeset viewer.