Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/10/12 13:28:55 (12 years ago)
Author:
gkronber
Message:

#1902 implemented all mean and covariance functions with parameters as ParameterizedNamedItems

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/MeanLinear.cs

    r8473 r8612  
    1919 */
    2020#endregion
     21
    2122using System;
    2223using System.Linq;
    2324using HeuristicLab.Common;
    2425using HeuristicLab.Core;
     26using HeuristicLab.Data;
    2527using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2628
     
    2830  [StorableClass]
    2931  [Item(Name = "MeanLinear", Description = "Linear mean function for Gaussian processes.")]
    30   public class MeanLinear : Item, IMeanFunction {
     32  public sealed class MeanLinear : ParameterizedNamedItem, IMeanFunction {
    3133    [Storable]
    32     private double[] alpha;
    33     public double[] Weights {
    34       get {
    35         if (alpha == null) return new double[0];
    36         var copy = new double[alpha.Length];
    37         Array.Copy(alpha, copy, copy.Length);
    38         return copy;
     34    private double[] weights;
     35    [Storable]
     36    private readonly HyperParameter<DoubleArray> weightsParameter;
     37    public IValueParameter<DoubleArray> WeightsParameter { get { return weightsParameter; } }
     38
     39    [StorableConstructor]
     40    private MeanLinear(bool deserializing) : base(deserializing) { }
     41    private MeanLinear(MeanLinear original, Cloner cloner)
     42      : base(original, cloner) {
     43      if (original.weights != null) {
     44        this.weights = new double[original.weights.Length];
     45        Array.Copy(original.weights, weights, original.weights.Length);
    3946      }
    40     }
    41     public int GetNumberOfParameters(int numberOfVariables) {
    42       return numberOfVariables;
    43     }
    44     [StorableConstructor]
    45     protected MeanLinear(bool deserializing) : base(deserializing) { }
    46     protected MeanLinear(MeanLinear original, Cloner cloner)
    47       : base(original, cloner) {
    48       if (original.alpha != null) {
    49         this.alpha = new double[original.alpha.Length];
    50         Array.Copy(original.alpha, alpha, original.alpha.Length);
    51       }
     47      weightsParameter = cloner.Clone(original.weightsParameter);
     48      RegisterEvents();
    5249    }
    5350    public MeanLinear()
    5451      : base() {
     52      this.weightsParameter = new HyperParameter<DoubleArray>("Weights", "The weights parameter for the linear mean function.");
     53      Parameters.Add(weightsParameter);
     54      RegisterEvents();
     55    }
     56
     57    public override IDeepCloneable Clone(Cloner cloner) {
     58      return new MeanLinear(this, cloner);
     59    }
     60
     61    [StorableHook(HookType.AfterDeserialization)]
     62    private void AfterDeserialization() {
     63      RegisterEvents();
     64    }
     65
     66    private void RegisterEvents() {
     67      Util.AttachArrayChangeHandler<DoubleArray, double>(weightsParameter, () => {
     68        weights = weightsParameter.Value.ToArray();
     69      });
     70    }
     71
     72    public int GetNumberOfParameters(int numberOfVariables) {
     73      return weightsParameter.Fixed ? 0 : numberOfVariables;
    5574    }
    5675
    5776    public void SetParameter(double[] hyp) {
    58       this.alpha = new double[hyp.Length];
    59       Array.Copy(hyp, alpha, hyp.Length);
    60     }
    61     public void SetData(double[,] x) {
    62       // nothing to do
     77      if (!weightsParameter.Fixed) {
     78        weightsParameter.SetValue(new DoubleArray(hyp));
     79      } else if (hyp.Length != 0) throw new ArgumentException("The length of the parameter vector does not match the number of free parameters for the linear mean function.", "hyp");
    6380    }
    6481
    6582    public double[] GetMean(double[,] x) {
    6683      // sanity check
    67       if (alpha.Length != x.GetLength(1)) throw new ArgumentException("The number of hyperparameters must match the number of variables for the linear mean function.");
     84      if (weights.Length != x.GetLength(1)) throw new ArgumentException("The number of hyperparameters must match the number of variables for the linear mean function.");
    6885      int cols = x.GetLength(1);
    6986      int n = x.GetLength(0);
    7087      return (from i in Enumerable.Range(0, n)
    71               let rowVector = from j in Enumerable.Range(0, cols)
    72                               select x[i, j]
    73               select Util.ScalarProd(alpha, rowVector))
     88              let rowVector = Enumerable.Range(0, cols).Select(j => x[i, j])
     89              select Util.ScalarProd(weights, rowVector))
    7490        .ToArray();
    7591    }
     
    7995      int n = x.GetLength(0);
    8096      if (k > cols) throw new ArgumentException();
    81       return (from r in Enumerable.Range(0, n)
    82               select x[r, k]).ToArray();
    83     }
    84 
    85     public override IDeepCloneable Clone(Cloner cloner) {
    86       return new MeanLinear(this, cloner);
     97      return (Enumerable.Range(0, n).Select(r => x[r, k])).ToArray();
    8798    }
    8899  }
Note: See TracChangeset for help on using the changeset viewer.