Changeset 16692 for branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessModel.cs
- Timestamp:
- 03/18/19 17:24:30 (5 years ago)
- Location:
- branches/2521_ProblemRefactoring
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2521_ProblemRefactoring
- Property svn:ignore
-
old new 24 24 protoc.exe 25 25 obj 26 .vs
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
-
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4
-
Property
svn:mergeinfo
set to
(toggle deleted branches)
/stable/HeuristicLab.Algorithms.DataAnalysis/3.4 merged eligible /branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4 10321-10322 /branches/Async/HeuristicLab.Algorithms.DataAnalysis/3.4 13329-15286 /branches/Benchmarking/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 6917-7005 /branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.4 9070-13099 /branches/CloningRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4 4656-4721 /branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4 5471-5808 /branches/DataAnalysis SolutionEnsembles/HeuristicLab.Algorithms.DataAnalysis/3.4 5815-6180 /branches/DataAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.4 4458-4459,4462,4464 /branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis/3.4 10085-11101 /branches/GP.Grammar.Editor/HeuristicLab.Algorithms.DataAnalysis/3.4 6284-6795 /branches/GP.Symbols (TimeLag, Diff, Integral)/HeuristicLab.Algorithms.DataAnalysis/3.4 5060 /branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 11570-12508 /branches/HeuristicLab.Problems.Orienteering/HeuristicLab.Algorithms.DataAnalysis/3.4 11130-12721 /branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4 13819-14091 /branches/HeuristicLab.TimeSeries/HeuristicLab.Algorithms.DataAnalysis/3.4 8116-8789 /branches/LogResidualEvaluator/HeuristicLab.Algorithms.DataAnalysis/3.4 10202-10483 /branches/NET40/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 5138-5162 /branches/ParallelEngine/HeuristicLab.Algorithms.DataAnalysis/3.4 5175-5192 /branches/ProblemInstancesRegressionAndClassification/HeuristicLab.Algorithms.DataAnalysis/3.4 7773-7810 /branches/QAPAlgorithms/HeuristicLab.Algorithms.DataAnalysis/3.4 6350-6627 /branches/Restructure trunk solution/HeuristicLab.Algorithms.DataAnalysis/3.4 6828 /branches/SpectralKernelForGaussianProcesses/HeuristicLab.Algorithms.DataAnalysis/3.4 10204-10479 /branches/SuccessProgressAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.4 5370-5682 /branches/Trunk/HeuristicLab.Algorithms.DataAnalysis/3.4 6829-6865 /branches/VNS/HeuristicLab.Algorithms.DataAnalysis/3.4 5594-5752 /branches/Weighted TSNE/3.4 15451-15531 /branches/histogram/HeuristicLab.Algorithms.DataAnalysis/3.4 5959-6341 /branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4 14232-14825 /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 13331-15681
-
Property
svn:mergeinfo
set to
(toggle deleted branches)
-
branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessModel.cs
r13160 r16692 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-201 5Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 34 34 [StorableClass] 35 35 [Item("GaussianProcessModel", "Represents a Gaussian process posterior.")] 36 public sealed class GaussianProcessModel : NamedItem, IGaussianProcessModel { 36 public sealed class GaussianProcessModel : RegressionModel, IGaussianProcessModel { 37 public override IEnumerable<string> VariablesUsedForPrediction { 38 get { return allowedInputVariables; } 39 } 40 37 41 [Storable] 38 42 private double negativeLogLikelihood; 39 43 public double NegativeLogLikelihood { 40 44 get { return negativeLogLikelihood; } 45 } 46 47 [Storable] 48 private double loocvNegLogPseudoLikelihood; 49 public double LooCvNegativeLogPseudoLikelihood { 50 get { return loocvNegLogPseudoLikelihood; } 41 51 } 42 52 … … 61 71 get { return meanFunction; } 62 72 } 63 [Storable] 64 private string targetVariable; 65 public string TargetVariable { 66 get { return targetVariable; } 67 } 73 68 74 [Storable] 69 75 private string[] allowedInputVariables; … … 128 134 this.trainingDataset = cloner.Clone(original.trainingDataset); 129 135 this.negativeLogLikelihood = original.negativeLogLikelihood; 130 this. targetVariable = original.targetVariable;136 this.loocvNegLogPseudoLikelihood = original.loocvNegLogPseudoLikelihood; 131 137 this.sqrSigmaNoise = original.sqrSigmaNoise; 132 138 if (original.meanParameter != null) { … … 147 153 IEnumerable<double> hyp, IMeanFunction meanFunction, ICovarianceFunction covarianceFunction, 148 154 bool scaleInputs = true) 149 : base( ) {155 : base(targetVariable) { 150 156 this.name = ItemName; 151 157 this.description = ItemDescription; 152 158 this.meanFunction = (IMeanFunction)meanFunction.Clone(); 153 159 this.covarianceFunction = (ICovarianceFunction)covarianceFunction.Clone(); 154 this.targetVariable = targetVariable;155 160 this.allowedInputVariables = allowedInputVariables.ToArray(); 156 161 … … 181 186 182 187 IEnumerable<double> y; 183 y = ds.GetDoubleValues( targetVariable, rows);188 y = ds.GetDoubleValues(TargetVariable, rows); 184 189 185 190 int n = x.GetLength(0); 186 191 192 var columns = Enumerable.Range(0, x.GetLength(1)).ToArray(); 187 193 // calculate cholesky decomposed (lower triangular) covariance matrix 188 var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, Enumerable.Range(0, x.GetLength(1)));194 var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, columns); 189 195 this.l = CalculateL(x, cov, sqrSigmaNoise); 190 196 191 197 // calculate mean 192 var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, Enumerable.Range(0, x.GetLength(1)));198 var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, columns); 193 199 double[] m = Enumerable.Range(0, x.GetLength(0)) 194 200 .Select(r => mean.Mean(x, r)) … … 218 224 alglib.spdmatrixcholeskyinverse(ref lCopy, n, false, out info, out matInvRep); 219 225 if (info != 1) throw new ArgumentException("Can't invert matrix to calculate gradients."); 226 227 // LOOCV log pseudo-likelihood (or log predictive probability) (GPML page 116 and 117) 228 var sumLoo = 0.0; 229 var ki = new double[n]; 230 for (int i = 0; i < n; i++) { 231 for (int j = 0; j < n; j++) ki[j] = cov.Covariance(x, i, j); 232 ki[i] += sqrSigmaNoise; 233 234 var yi = Util.ScalarProd(ki, alpha); 235 var yi_loo = yi - alpha[i] / (lCopy[i, i] / sqrSigmaNoise); 236 var s2_loo = 1.0 / (lCopy[i, i] / sqrSigmaNoise); 237 var err = ym[i] - yi_loo; 238 var nll_loo = 0.5 * Math.Log(2 * Math.PI * s2_loo) + 0.5 * err * err / s2_loo; 239 sumLoo += nll_loo; 240 } 241 loocvNegLogPseudoLikelihood = sumLoo; 242 220 243 for (int i = 0; i < n; i++) { 221 244 for (int j = 0; j <= i; j++) … … 227 250 double[] meanGradients = new double[meanFunction.GetNumberOfParameters(nAllowedVariables)]; 228 251 for (int k = 0; k < meanGradients.Length; k++) { 229 var meanGrad = Enumerable.Range(0, alpha.Length) 230 .Select(r => mean.Gradient(x, r, k)); 252 var meanGrad = new double[alpha.Length]; 253 for (int g = 0; g < meanGrad.Length; g++) 254 meanGrad[g] = mean.Gradient(x, g, k); 231 255 meanGradients[k] = -Util.ScalarProd(meanGrad, alpha); 232 256 } … … 236 260 for (int i = 0; i < n; i++) { 237 261 for (int j = 0; j < i; j++) { 238 var g = cov.CovarianceGradient(x, i, j) .ToArray();262 var g = cov.CovarianceGradient(x, i, j); 239 263 for (int k = 0; k < covGradients.Length; k++) { 240 264 covGradients[k] += lCopy[i, j] * g[k]; … … 242 266 } 243 267 244 var gDiag = cov.CovarianceGradient(x, i, i) .ToArray();268 var gDiag = cov.CovarianceGradient(x, i, i); 245 269 for (int k = 0; k < covGradients.Length; k++) { 246 270 // diag … … 259 283 private static double[,] GetData(IDataset ds, IEnumerable<string> allowedInputs, IEnumerable<int> rows, Scaling scaling) { 260 284 if (scaling != null) { 261 return AlglibUtil.PrepareAndScaleInputMatrix(ds, allowedInputs, rows, scaling); 285 // BackwardsCompatibility3.3 286 #region Backwards compatible code, remove with 3.4 287 // TODO: completely remove Scaling class 288 List<string> variablesList = allowedInputs.ToList(); 289 List<int> rowsList = rows.ToList(); 290 291 double[,] matrix = new double[rowsList.Count, variablesList.Count]; 292 293 int col = 0; 294 foreach (string column in variablesList) { 295 var values = scaling.GetScaledValues(ds, column, rowsList); 296 int row = 0; 297 foreach (var value in values) { 298 matrix[row, col] = value; 299 row++; 300 } 301 col++; 302 } 303 return matrix; 304 #endregion 262 305 } else { 263 return AlglibUtil.PrepareInputMatrix(ds,allowedInputs, rows);306 return ds.ToArray(allowedInputs, rows); 264 307 } 265 308 } … … 298 341 299 342 #region IRegressionModel Members 300 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {343 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 301 344 return GetEstimatedValuesHelper(dataset, rows); 302 345 } 303 public GaussianProcessRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {346 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 304 347 return new GaussianProcessRegressionSolution(this, new RegressionProblemData(problemData)); 305 }306 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {307 return CreateRegressionSolution(problemData);308 348 } 309 349 #endregion … … 320 360 int newN = newX.GetLength(0); 321 361 322 var Ks = new double[newN, n]; 323 var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, Enumerable.Range(0, newX.GetLength(1))); 362 var Ks = new double[newN][]; 363 var columns = Enumerable.Range(0, newX.GetLength(1)).ToArray(); 364 var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, columns); 324 365 var ms = Enumerable.Range(0, newX.GetLength(0)) 325 366 .Select(r => mean.Mean(newX, r)) 326 367 .ToArray(); 327 var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, Enumerable.Range(0, newX.GetLength(1)));368 var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, columns); 328 369 for (int i = 0; i < newN; i++) { 370 Ks[i] = new double[n]; 329 371 for (int j = 0; j < n; j++) { 330 Ks[i ,j] = cov.CrossCovariance(x, newX, j, i);372 Ks[i][j] = cov.CrossCovariance(x, newX, j, i); 331 373 } 332 374 } 333 375 334 376 return Enumerable.Range(0, newN) 335 .Select(i => ms[i] + Util.ScalarProd( Util.GetRow(Ks, i), alpha));377 .Select(i => ms[i] + Util.ScalarProd(Ks[i], alpha)); 336 378 } catch (alglib.alglibexception ae) { 337 379 // wrap exception so that calling code doesn't have to know about alglib implementation … … 340 382 } 341 383 342 public IEnumerable<double> GetEstimatedVariance (IDataset dataset, IEnumerable<int> rows) {384 public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) { 343 385 try { 344 386 if (x == null) { … … 352 394 var kss = new double[newN]; 353 395 double[,] sWKs = new double[n, newN]; 354 var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, Enumerable.Range(0, x.GetLength(1))); 396 var columns = Enumerable.Range(0, newX.GetLength(1)).ToArray(); 397 var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, columns); 355 398 356 399 if (l == null) { … … 372 415 373 416 for (int i = 0; i < newN; i++) { 374 var sumV = Util.ScalarProd(Util.GetCol(sWKs, i), Util.GetCol(sWKs, i)); 417 var col = Util.GetCol(sWKs, i).ToArray(); 418 var sumV = Util.ScalarProd(col, col); 375 419 kss[i] += sqrSigmaNoise; // kss is V(f), add noise variance of predictive distibution to get V(y) 376 420 kss[i] -= sumV; … … 383 427 } 384 428 } 429 385 430 } 386 431 }
Note: See TracChangeset
for help on using the changeset viewer.