Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/06/12 15:02:34 (12 years ago)
Author:
gkronber
Message:

#1902 worked on sum and product covariance functions and fixed a few bugs.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceSum.cs

    r8366 r8416  
    1 using System;
     1#region License Information
     2/* HeuristicLab
     3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     4 *
     5 * This file is part of HeuristicLab.
     6 *
     7 * HeuristicLab is free software: you can redistribute it and/or modify
     8 * it under the terms of the GNU General Public License as published by
     9 * the Free Software Foundation, either version 3 of the License, or
     10 * (at your option) any later version.
     11 *
     12 * HeuristicLab is distributed in the hope that it will be useful,
     13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
     14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     15 * GNU General Public License for more details.
     16 *
     17 * You should have received a copy of the GNU General Public License
     18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
     19 */
     20#endregion
     21
    222using System.Linq;
    323using HeuristicLab.Common;
     
    525using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    626
    7 namespace HeuristicLab.Algorithms.DataAnalysis.GaussianProcess {
     27namespace HeuristicLab.Algorithms.DataAnalysis {
    828  [StorableClass]
    929  [Item(Name = "CovarianceSum",
     
    2646    protected CovarianceSum(CovarianceSum original, Cloner cloner)
    2747      : base(original, cloner) {
    28       this.terms = cloner.Clone(terms);
     48      this.terms = cloner.Clone(original.terms);
     49      this.numberOfVariables = original.numberOfVariables;
    2950    }
    3051
    3152    public CovarianceSum()
    3253      : base() {
     54      this.terms = new ItemList<ICovarianceFunction>();
    3355    }
    3456
     
    4264    }
    4365
    44     public void SetParameter(double[] hyp, double[,] x) {
     66    public void SetParameter(double[] hyp) {
    4567      int offset = 0;
    4668      foreach (var t in terms) {
    47         t.SetParameter(hyp.Skip(offset).Take(t.GetNumberOfParameters(numberOfVariables)), x);
    48         offset += numberOfVariables;
     69        var numberOfParameters = t.GetNumberOfParameters(numberOfVariables);
     70        t.SetParameter(hyp.Skip(offset).Take(numberOfParameters).ToArray());
     71        offset += numberOfParameters;
     72      }
     73    }
     74    public void SetData(double[,] x) {
     75      SetData(x, x);
     76    }
     77
     78    public void SetData(double[,] x, double[,] xt) {
     79      foreach (var t in terms) {
     80        t.SetData(x, xt);
    4981      }
    5082    }
    5183
    52 
    53     public void SetParameter(double[] hyp, double[,] x, double[,] xt) {
    54       this.l = Math.Exp(hyp[0]);
    55       this.sf2 = Math.Exp(2 * hyp[1]);
    56 
    57       this.symmetric = false;
    58       this.x = x;
    59       this.xt = xt;
    60       sd = null;
     84    public double GetCovariance(int i, int j) {
     85      return terms.Select(t => t.GetCovariance(i, j)).Sum();
    6186    }
    6287
    63     public double GetCovariance(int i, int j) {
    64       if (sd == null) CalculateSquaredDistances();
    65       return sf2 * Math.Exp(-sd[i, j] / 2.0);
    66     }
    67 
    68 
    69     public double[] GetDiagonalCovariances() {
    70       if (x != xt) throw new InvalidOperationException();
    71       int rows = x.GetLength(0);
    72       var sd = new double[rows];
    73       for (int i = 0; i < rows; i++) {
    74         sd[i] = Util.SqrDist(Util.GetRow(x, i).Select(e => e / l), Util.GetRow(xt, i).Select(e => e / l));
    75       }
    76       return sd.Select(d => sf2 * Math.Exp(-d / 2.0)).ToArray();
    77     }
    78 
    79 
    8088    public double[] GetGradient(int i, int j) {
    81       var res = new double[2];
    82       res[0] = sf2 * Math.Exp(-sd[i, j] / 2.0) * sd[i, j];
    83       res[1] = 2.0 * sf2 * Math.Exp(-sd[i, j] / 2.0);
    84       return res;
    85     }
    86 
    87     private void CalculateSquaredDistances() {
    88       if (x.GetLength(1) != xt.GetLength(1)) throw new InvalidOperationException();
    89       int rows = x.GetLength(0);
    90       int cols = xt.GetLength(0);
    91       sd = new double[rows, cols];
    92       if (symmetric) {
    93         for (int i = 0; i < rows; i++) {
    94           for (int j = i; j < rows; j++) {
    95             sd[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e / l), Util.GetRow(xt, j).Select(e => e / l));
    96             sd[j, i] = sd[i, j];
    97           }
    98         }
    99       } else {
    100         for (int i = 0; i < rows; i++) {
    101           for (int j = 0; j < cols; j++) {
    102             sd[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e / l), Util.GetRow(xt, j).Select(e => e / l));
    103           }
    104         }
    105       }
     89      return terms.Select(t => t.GetGradient(i, j)).SelectMany(seq => seq).ToArray();
    10690    }
    10791  }
Note: See TracChangeset for help on using the changeset viewer.