1  #region License Information


2  /* HeuristicLab


 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  *


5  * This file is part of HeuristicLab.


6  *


7  * HeuristicLab is free software: you can redistribute it and/or modify


8  * it under the terms of the GNU General Public License as published by


9  * the Free Software Foundation, either version 3 of the License, or


10  * (at your option) any later version.


11  *


12  * HeuristicLab is distributed in the hope that it will be useful,


13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


15  * GNU General Public License for more details.


16  *


17  * You should have received a copy of the GNU General Public License


18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


19  */


20  #endregion


21 


22  using System;


23  using System.Linq;


24  using HeuristicLab.Common;


25  using HeuristicLab.Core;


26  using HeuristicLab.Data;


27  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


28 


namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Linear mean function for Gaussian processes: the mean of an input row x is the
  /// scalar product sum_j w_j * x_j, with one weight w_j per input variable.
  /// </summary>
  [StorableClass]
  [Item(Name = "MeanLinear", Description = "Linear mean function for Gaussian processes.")]
  public sealed class MeanLinear : ParameterizedNamedItem, IMeanFunction {
    // Cached copy of the current weight vector. It is written only by the copy constructor,
    // by deserialization and by the change handler attached in RegisterEvents, so it stays
    // null until a weight vector has been copied into it.
    [Storable]
    private double[] weights;
    [Storable]
    private readonly HyperParameter<DoubleArray> weightsParameter;

    /// <summary>
    /// The "Weights" hyperparameter of the linear mean function (one weight per input variable).
    /// </summary>
    public IValueParameter<DoubleArray> WeightsParameter { get { return weightsParameter; } }

    [StorableConstructor]
    private MeanLinear(bool deserializing) : base(deserializing) { }

    private MeanLinear(MeanLinear original, Cloner cloner)
      : base(original, cloner) {
      // Copy the array so the clone never shares its weight vector with the original.
      if (original.weights != null) {
        this.weights = (double[])original.weights.Clone();
      }
      weightsParameter = cloner.Clone(original.weightsParameter);
      RegisterEvents();
    }

    /// <summary>
    /// Creates a linear mean function and registers its "Weights" hyperparameter.
    /// </summary>
    public MeanLinear()
      : base() {
      this.weightsParameter = new HyperParameter<DoubleArray>("Weights", "The weights parameter for the linear mean function.");
      Parameters.Add(weightsParameter);
      RegisterEvents();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new MeanLinear(this, cloner);
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Event handlers are not persisted, so they must be re-attached after loading.
      RegisterEvents();
    }

    // Keeps the cached 'weights' array in sync with the value of weightsParameter.
    // NOTE(review): which events trigger the callback is defined by Util.AttachArrayChangeHandler
    // (presumably value replacement and element changes) — confirm there.
    private void RegisterEvents() {
      Util.AttachArrayChangeHandler<DoubleArray, double>(weightsParameter, () => {
        weights = weightsParameter.Value.ToArray();
      });
    }

    /// <summary>
    /// Returns the number of free parameters: 0 when the weights are fixed,
    /// otherwise one weight per input variable.
    /// </summary>
    /// <param name="numberOfVariables">The number of input variables.</param>
    public int GetNumberOfParameters(int numberOfVariables) {
      return weightsParameter.Fixed ? 0 : numberOfVariables;
    }

    /// <summary>
    /// Sets the weights from the free-parameter vector <paramref name="hyp"/>.
    /// </summary>
    /// <param name="hyp">The free parameters; must be empty when the weights are fixed.</param>
    /// <exception cref="ArgumentNullException"><paramref name="hyp"/> is null.</exception>
    /// <exception cref="ArgumentException">The weights are fixed and <paramref name="hyp"/> is not empty.</exception>
    public void SetParameter(double[] hyp) {
      if (hyp == null) throw new ArgumentNullException("hyp");
      if (weightsParameter.Fixed) {
        // Fixed weights have no free parameters, so only an empty vector is acceptable.
        if (hyp.Length != 0) throw new ArgumentException("The length of the parameter vector does not match the number of free parameters for the linear mean function.", "hyp");
        return;
      }
      weightsParameter.SetValue(new DoubleArray(hyp));
    }

    /// <summary>
    /// Computes the linear mean sum_j w_j * x[i, j] for every row i of <paramref name="x"/>.
    /// </summary>
    /// <param name="x">Input matrix with one row per sample and one column per variable.</param>
    /// <returns>An array with one mean value per row of <paramref name="x"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
    /// <exception cref="InvalidOperationException">No weight vector has been set yet.</exception>
    /// <exception cref="ArgumentException">The number of columns of <paramref name="x"/> differs from the number of weights.</exception>
    public double[] GetMean(double[,] x) {
      if (x == null) throw new ArgumentNullException("x");
      // The cached weights are only populated once a weight vector has been assigned.
      if (weights == null) throw new InvalidOperationException("The weights of the linear mean function have not been set.");
      int n = x.GetLength(0);
      int cols = x.GetLength(1);
      // sanity check
      if (weights.Length != cols) throw new ArgumentException("The number of hyperparameters must match the number of variables for the linear mean function.");

      // Plain loops avoid allocating an enumerable and closure per row.
      var mean = new double[n];
      for (int i = 0; i < n; i++) {
        double s = 0.0;
        for (int j = 0; j < cols; j++) {
          s += weights[j] * x[i, j];
        }
        mean[i] = s;
      }
      return mean;
    }

    /// <summary>
    /// Computes the derivative of the mean with respect to weight <paramref name="k"/> for every
    /// row of <paramref name="x"/>. For m(x) = sum_j w_j * x_j this derivative is x_k, so the
    /// result is column <paramref name="k"/> of <paramref name="x"/>.
    /// </summary>
    /// <param name="k">Zero-based weight index; must satisfy 0 &lt;= k &lt; number of columns of <paramref name="x"/>.</param>
    /// <param name="x">Input matrix with one row per sample and one column per variable.</param>
    /// <returns>An array with one gradient entry per row of <paramref name="x"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="k"/> is not a valid column index of <paramref name="x"/>.</exception>
    public double[] GetGradients(int k, double[,] x) {
      if (x == null) throw new ArgumentNullException("x");
      int n = x.GetLength(0);
      int cols = x.GetLength(1);
      // Valid indices are 0..cols-1; k == cols must be rejected too, otherwise reading
      // x[r, k] would run past the last column.
      if (k < 0 || k >= cols) throw new ArgumentOutOfRangeException("k", "The index of the parameter must be a valid column index of the input matrix.");

      var gradient = new double[n];
      for (int r = 0; r < n; r++) {
        gradient[r] = x[r, k];
      }
      return gradient;
    }
  }
}

