1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  *


5  * This file is part of HeuristicLab.


6  *


7  * HeuristicLab is free software: you can redistribute it and/or modify


8  * it under the terms of the GNU General Public License as published by


9  * the Free Software Foundation, either version 3 of the License, or


10  * (at your option) any later version.


11  *


12  * HeuristicLab is distributed in the hope that it will be useful,


13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


15  * GNU General Public License for more details.


16  *


17  * You should have received a copy of the GNU General Public License


18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


19  */


20  #endregion


21  using System.Linq;


22  using HeuristicLab.Common;


23  using HeuristicLab.Core;


24  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


25 


namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Mean function for Gaussian processes that is the element-wise product of
  /// the mean functions held in <see cref="Factors"/>.
  /// </summary>
  /// <remarks>
  /// The hyper-parameter vector of the product is the concatenation of the
  /// hyper-parameter vectors of the factors, in list order.
  /// <see cref="SetParameter"/> and <see cref="GetGradients"/> use the number
  /// of variables remembered by the most recent call to
  /// <see cref="GetNumberOfParameters"/>, so that method has to be called first.
  /// </remarks>
  [StorableClass]
  [Item(Name = "MeanProduct", Description = "Product of mean functions for Gaussian processes.")]
  public sealed class MeanProduct : Item, IMeanFunction {
    // The mean functions that are multiplied together.
    [Storable]
    private ItemList<IMeanFunction> factors;

    // Number of input variables passed to the last GetNumberOfParameters call;
    // needed to ask each factor how many hyper-parameters it owns.
    [Storable]
    private int numberOfVariables;

    /// <summary>Gets the list of mean functions that form the product.</summary>
    public ItemList<IMeanFunction> Factors {
      get { return factors; }
    }

    /// <summary>Constructor used by the persistence framework.</summary>
    [StorableConstructor]
    private MeanProduct(bool deserializing)
      : base(deserializing) {
    }

    /// <summary>
    /// Cloning constructor: clones the factor list through
    /// <paramref name="cloner"/> and copies the remembered number of variables.
    /// </summary>
    private MeanProduct(MeanProduct original, Cloner cloner)
      : base(original, cloner) {
      this.factors = cloner.Clone(original.factors);
      this.numberOfVariables = original.numberOfVariables;
    }

    /// <summary>Creates a product with an empty factor list.</summary>
    public MeanProduct() {
      this.factors = new ItemList<IMeanFunction>();
    }

    /// <summary>Creates a clone of this instance using <paramref name="cloner"/>.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new MeanProduct(this, cloner);
    }

    /// <summary>
    /// Remembers <paramref name="numberOfVariables"/> for later calls and returns
    /// the total number of hyper-parameters, i.e. the sum over all factors.
    /// </summary>
    /// <param name="numberOfVariables">Number of input variables, forwarded to every factor.</param>
    /// <returns>Sum of the parameter counts reported by the factors.</returns>
    public int GetNumberOfParameters(int numberOfVariables) {
      this.numberOfVariables = numberOfVariables;
      return factors.Select(t => t.GetNumberOfParameters(numberOfVariables)).Sum();
    }

    /// <summary>
    /// Distributes <paramref name="hyp"/> over the factors: each factor receives
    /// the next consecutive slice whose length is that factor's parameter count.
    /// </summary>
    /// <param name="hyp">Concatenated hyper-parameters of all factors, in list order.</param>
    public void SetParameter(double[] hyp) {
      int offset = 0;
      foreach (var t in factors) {
        var numberOfParameters = t.GetNumberOfParameters(numberOfVariables);
        t.SetParameter(hyp.Skip(offset).Take(numberOfParameters).ToArray());
        offset += numberOfParameters;
      }
    }

    /// <summary>
    /// Evaluates every factor on <paramref name="x"/> and returns the
    /// element-wise product of the results.
    /// </summary>
    /// <param name="x">Input matrix, forwarded unchanged to every factor.</param>
    /// <returns>A new array holding the element-wise product of all factor means.</returns>
    /// <exception cref="System.InvalidOperationException">The factor list is empty.</exception>
    public double[] GetMean(double[,] x) {
      if (factors.Count == 0) {
        throw new System.InvalidOperationException(
          "MeanProduct requires at least one factor to calculate the mean.");
      }

      // Work on a copy so that the in-place multiplication below can never
      // modify an array that the first factor might still be holding on to.
      var res = (double[])factors[0].GetMean(x).Clone();
      for (int f = 1; f < factors.Count; f++) {
        var a = factors[f].GetMean(x);
        for (int i = 0; i < res.Length; i++) {
          res[i] *= a[i];
        }
      }
      return res;
    }

    /// <summary>
    /// Returns the gradient of the product with respect to hyper-parameter
    /// <paramref name="k"/>: the gradient of the factor owning that parameter
    /// multiplied element-wise by the means of all other factors.
    /// </summary>
    /// <param name="k">Index into the concatenated hyper-parameter vector of all factors.</param>
    /// <param name="x">Input matrix; the result has one entry per row (first dimension).</param>
    /// <returns>Array of length <c>x.GetLength(0)</c> with the gradient values.</returns>
    /// <exception cref="System.ArgumentOutOfRangeException">
    /// <paramref name="k"/> is not smaller than the total number of parameters of all factors.
    /// </exception>
    public double[] GetGradients(int k, double[,] x) {
      // Find the factor j that owns parameter k and turn k into the index
      // local to that factor by subtracting the sizes of all preceding factors.
      int j = 0;
      while (true) {
        if (j >= factors.Count) {
          throw new System.ArgumentOutOfRangeException(
            "k", "Parameter index exceeds the total number of parameters of all factors.");
        }
        int parametersOfFactor = factors[j].GetNumberOfParameters(numberOfVariables);
        if (k < parametersOfFactor) {
          break;
        }
        k -= parametersOfFactor;
        j++;
      }

      double[] res = Enumerable.Repeat(1.0, x.GetLength(0)).ToArray();
      for (int i = 0; i < factors.Count; i++) {
        var f = factors[i];
        // Product rule: the owning factor contributes its gradient,
        // every other factor contributes its mean.
        var term = (i == j) ? f.GetGradients(k, x) : f.GetMean(x);
        for (int ii = 0; ii < res.Length; ii++) {
          res[ii] *= term[ii];
        }
      }
      return res;
    }
  }
}

