Changeset 8612 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/MeanLinear.cs
- Timestamp:
- 09/10/12 13:28:55 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/MeanLinear.cs
r8473 r8612 19 19 */ 20 20 #endregion 21 21 22 using System; 22 23 using System.Linq; 23 24 using HeuristicLab.Common; 24 25 using HeuristicLab.Core; 26 using HeuristicLab.Data; 25 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 26 28 … … 28 30 [StorableClass] 29 31 [Item(Name = "MeanLinear", Description = "Linear mean function for Gaussian processes.")] 30 public class MeanLinear :Item, IMeanFunction {32 public sealed class MeanLinear : ParameterizedNamedItem, IMeanFunction { 31 33 [Storable] 32 private double[] alpha; 33 public double[] Weights { 34 get { 35 if (alpha == null) return new double[0]; 36 var copy = new double[alpha.Length]; 37 Array.Copy(alpha, copy, copy.Length); 38 return copy; 34 private double[] weights; 35 [Storable] 36 private readonly HyperParameter<DoubleArray> weightsParameter; 37 public IValueParameter<DoubleArray> WeightsParameter { get { return weightsParameter; } } 38 39 [StorableConstructor] 40 private MeanLinear(bool deserializing) : base(deserializing) { } 41 private MeanLinear(MeanLinear original, Cloner cloner) 42 : base(original, cloner) { 43 if (original.weights != null) { 44 this.weights = new double[original.weights.Length]; 45 Array.Copy(original.weights, weights, original.weights.Length); 39 46 } 40 } 41 public int GetNumberOfParameters(int numberOfVariables) { 42 return numberOfVariables; 43 } 44 [StorableConstructor] 45 protected MeanLinear(bool deserializing) : base(deserializing) { } 46 protected MeanLinear(MeanLinear original, Cloner cloner) 47 : base(original, cloner) { 48 if (original.alpha != null) { 49 this.alpha = new double[original.alpha.Length]; 50 Array.Copy(original.alpha, alpha, original.alpha.Length); 51 } 47 weightsParameter = cloner.Clone(original.weightsParameter); 48 RegisterEvents(); 52 49 } 53 50 public MeanLinear() 54 51 : base() { 52 this.weightsParameter = new HyperParameter<DoubleArray>("Weights", "The weights parameter for the linear mean function."); 53 Parameters.Add(weightsParameter); 54 RegisterEvents(); 55 } 56 57 public override IDeepCloneable Clone(Cloner cloner) { 58 return new MeanLinear(this, cloner); 59 } 60 61 [StorableHook(HookType.AfterDeserialization)] 62 private void AfterDeserialization() { 63 RegisterEvents(); 64 } 65 66 private void RegisterEvents() { 67 Util.AttachArrayChangeHandler<DoubleArray, double>(weightsParameter, () => { 68 weights = weightsParameter.Value.ToArray(); 69 }); 70 } 71 72 public int GetNumberOfParameters(int numberOfVariables) { 73 return weightsParameter.Fixed ? 0 : numberOfVariables; 55 74 } 56 75 57 76 public void SetParameter(double[] hyp) { 58 this.alpha = new double[hyp.Length]; 59 Array.Copy(hyp, alpha, hyp.Length); 60 } 61 public void SetData(double[,] x) { 62 // nothing to do 77 if (!weightsParameter.Fixed) { 78 weightsParameter.SetValue(new DoubleArray(hyp)); 79 } else if (hyp.Length != 0) throw new ArgumentException("The length of the parameter vector does not match the number of free parameters for the linear mean function.", "hyp"); 63 80 } 64 81 65 82 public double[] GetMean(double[,] x) { 66 83 // sanity check 67 if ( alpha.Length != x.GetLength(1)) throw new ArgumentException("The number of hyperparameters must match the number of variables for the linear mean function.");84 if (weights.Length != x.GetLength(1)) throw new ArgumentException("The number of hyperparameters must match the number of variables for the linear mean function."); 68 85 int cols = x.GetLength(1); 69 86 int n = x.GetLength(0); 70 87 return (from i in Enumerable.Range(0, n) 71 let rowVector = from j in Enumerable.Range(0, cols) 72 select x[i, j] 73 select Util.ScalarProd(alpha, rowVector)) 88 let rowVector = Enumerable.Range(0, cols).Select(j => x[i, j]) 89 select Util.ScalarProd(weights, rowVector)) 74 90 .ToArray(); 75 91 } … … 79 95 int n = x.GetLength(0); 80 96 if (k > cols) throw new ArgumentException(); 81 return (from r in Enumerable.Range(0, n) 82 select x[r, k]).ToArray(); 83 } 84 85 public override IDeepCloneable Clone(Cloner cloner) { 86 return new MeanLinear(this, cloner); 97 return (Enumerable.Range(0, n).Select(r => x[r, k])).ToArray(); 87 98 } 88 99 }
Note: See TracChangeset
for help on using the changeset viewer.