Changeset 8484 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceRQiso.cs
- Timestamp:
- 08/14/12 13:25:17 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceRQiso.cs
r8473 r8484 21 21 22 22 using System; 23 using System.Collections.Generic; 23 24 using System.Linq; 24 25 using HeuristicLab.Common; … … 32 33 public class CovarianceRQiso : Item, ICovarianceFunction { 33 34 [Storable] 34 private double[,] x;35 [Storable]36 private double[,] xt;37 [Storable]38 35 private double sf2; 39 36 public double Scale { get { return sf2; } } … … 44 41 private double alpha; 45 42 public double Shape { get { return alpha; } } 46 [Storable]47 private bool symmetric;48 private double[,] d2;49 43 50 44 [StorableConstructor] … … 55 49 protected CovarianceRQiso(CovarianceRQiso original, Cloner cloner) 56 50 : base(original, cloner) { 57 if (original.x != null) {58 this.x = new double[original.x.GetLength(0), original.x.GetLength(1)];59 Array.Copy(original.x, this.x, x.Length);60 61 this.xt = new double[original.xt.GetLength(0), original.xt.GetLength(1)];62 Array.Copy(original.xt, this.xt, xt.Length);63 64 this.d2 = new double[original.d2.GetLength(0), original.d2.GetLength(1)];65 Array.Copy(original.d2, this.d2, d2.Length);66 this.sf2 = original.sf2;67 }68 51 this.sf2 = original.sf2; 69 52 this.l = original.l; 70 53 this.alpha = original.alpha; 71 this.symmetric = original.symmetric;72 54 } 73 55 … … 85 67 86 68 public void SetParameter(double[] hyp) { 69 if (hyp.Length != 3) throw new ArgumentException("CovarianceRQiso has three hyperparameters", "k"); 87 70 this.l = Math.Exp(hyp[0]); 88 71 this.sf2 = Math.Exp(2 * hyp[1]); 89 72 this.alpha = Math.Exp(hyp[2]); 90 d2 = null;91 }92 public void SetData(double[,] x) {93 SetData(x, x);94 this.symmetric = true;95 73 } 96 74 97 75 98 public void SetData(double[,] x, double[,] xt) { 99 this.symmetric = false; 100 this.x = x; 101 this.xt = xt; 102 d2 = null; 76 public double GetCovariance(double[,] x, int i, int j) { 77 double lInv = 1.0 / l; 78 double d = i == j 79 ? 0.0 80 : Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(x, j).Select(e => e * lInv)); 81 return sf2 * Math.Pow(1 + 0.5 * d / alpha, -alpha); 103 82 } 104 83 105 public double GetCovariance(int i, int j) { 106 if (d2 == null) CalculateSquaredDistances(); 107 return sf2 * Math.Pow(1 + 0.5 * d2[i, j] / alpha, -alpha); 84 public IEnumerable<double> GetGradient(double[,] x, int i, int j) { 85 double lInv = 1.0 / l; 86 double d = i == j 87 ? 0.0 88 : Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(x, j).Select(e => e * lInv)); 89 90 double b = 1 + 0.5 * d / alpha; 91 yield return sf2 * Math.Pow(b, -alpha - 1) * d; 92 yield return 2 * sf2 * Math.Pow(b, -alpha); 93 yield return sf2 * Math.Pow(b, -alpha) * (0.5 * d / b - alpha * Math.Log(b)); 108 94 } 109 95 110 public double GetGradient(int i, int j, int k) { 111 switch (k) { 112 case 0: return sf2 * Math.Pow(1 + 0.5 * d2[i, j] / alpha, -alpha - 1) * d2[i, j]; 113 case 1: return 2 * sf2 * Math.Pow((1 + 0.5 * d2[i, j] / alpha), (-alpha)); 114 case 2: { 115 double g = (1 + 0.5 * d2[i, j] / alpha); 116 g = sf2 * Math.Pow(g, -alpha) * (0.5 * d2[i, j] / g - alpha * Math.Log(g)); 117 return g; 118 } 119 default: throw new ArgumentException("CovarianceRQiso has three hyperparameters", "k"); 120 } 121 } 122 123 private void CalculateSquaredDistances() { 124 if (x.GetLength(1) != xt.GetLength(1)) throw new InvalidOperationException(); 125 int rows = x.GetLength(0); 126 int cols = xt.GetLength(0); 127 d2 = new double[rows, cols]; 96 public double GetCrossCovariance(double[,] x, double[,] xt, int i, int j) { 128 97 double lInv = 1.0 / l; 129 if (symmetric) { 130 for (int i = 0; i < rows; i++) { 131 for (int j = i; j < rows; j++) { 132 d2[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(xt, j).Select(e => e * lInv)); 133 d2[j, i] = d2[i, j]; 134 } 135 } 136 } else { 137 for (int i = 0; i < rows; i++) { 138 for (int j = 0; j < cols; j++) { 139 d2[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(xt, j).Select(e => e * lInv)); 140 } 141 } 142 } 98 double d = Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(xt, j).Select(e => e * lInv)); 99 return sf2 * Math.Pow(1 + 0.5 * d / alpha, -alpha); 143 100 } 144 101 }
Note: See TracChangeset
for help on using the changeset viewer.