Changeset 8982 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/MeanFunctions/MeanLinear.cs
- Timestamp:
- 12/01/12 19:02:47 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/MeanFunctions/MeanLinear.cs
r8929 r8982 21 21 22 22 using System; 23 using System.Collections.Generic; 23 24 using System.Linq; 24 25 using HeuristicLab.Common; 25 26 using HeuristicLab.Core; 26 27 using HeuristicLab.Data; 28 using HeuristicLab.Parameters; 27 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 30 … … 31 33 [Item(Name = "MeanLinear", Description = "Linear mean function for Gaussian processes.")] 32 34 public sealed class MeanLinear : ParameterizedNamedItem, IMeanFunction { 33 [Storable] 34 private double[] weights; 35 [Storable] 36 private readonly HyperParameter<DoubleArray> weightsParameter; 37 public IValueParameter<DoubleArray> WeightsParameter { get { return weightsParameter; } } 35 public IValueParameter<DoubleArray> WeightsParameter { 36 get { return (IValueParameter<DoubleArray>)Parameters["Weights"]; } 37 } 38 38 39 39 [StorableConstructor] … … 41 41 private MeanLinear(MeanLinear original, Cloner cloner) 42 42 : base(original, cloner) { 43 if (original.weights != null) {44 this.weights = new double[original.weights.Length];45 Array.Copy(original.weights, weights, original.weights.Length);46 }47 weightsParameter = cloner.Clone(original.weightsParameter);48 RegisterEvents();49 43 } 50 44 public MeanLinear() 51 45 : base() { 52 this.weightsParameter = new HyperParameter<DoubleArray>("Weights", "The weights parameter for the linear mean function."); 53 Parameters.Add(weightsParameter); 54 RegisterEvents(); 46 Parameters.Add(new OptionalValueParameter<DoubleArray>("Weights", "The weights parameter for the linear mean function.")); 55 47 } 56 48 … … 59 51 } 60 52 61 [StorableHook(HookType.AfterDeserialization)] 62 private void AfterDeserialization() { 63 RegisterEvents(); 53 public int GetNumberOfParameters(int numberOfVariables) { 54 return WeightsParameter.Value != null ? 0 : numberOfVariables; 64 55 } 65 56 66 p rivate void RegisterEvents() {67 Util.AttachArrayChangeHandler<DoubleArray, double>(weightsParameter, () => {68 weights = weightsParameter.Value.ToArray();69 });57 public void SetParameter(double[] p) { 58 double[] weights; 59 GetParameter(p, out weights); 60 WeightsParameter.Value = new DoubleArray(weights); 70 61 } 71 62 72 public int GetNumberOfParameters(int numberOfVariables) { 73 return weightsParameter.Fixed ? 0 : numberOfVariables; 63 public void GetParameter(double[] p, out double[] weights) { 64 if (WeightsParameter.Value == null) { 65 weights = p; 66 } else { 67 if (p.Length != 0) throw new ArgumentException("The length of the parameter vector does not match the number of free parameters for the linear mean function.", "p"); 68 weights = WeightsParameter.Value.ToArray(); 69 } 74 70 } 75 71 76 public void SetParameter(double[] hyp) { 77 if (!weightsParameter.Fixed) { 78 weightsParameter.SetValue(new DoubleArray(hyp)); 79 } else if (hyp.Length != 0) throw new ArgumentException("The length of the parameter vector does not match the number of free parameters for the linear mean function.", "hyp"); 80 } 81 82 public double[] GetMean(double[,] x) { 83 // sanity check 84 if (weights.Length != x.GetLength(1)) throw new ArgumentException("The number of hyperparameters must match the number of variables for the linear mean function."); 85 int cols = x.GetLength(1); 86 int n = x.GetLength(0); 87 return (from i in Enumerable.Range(0, n) 88 let rowVector = Enumerable.Range(0, cols).Select(j => x[i, j]) 89 select Util.ScalarProd(weights, rowVector)) 90 .ToArray(); 91 } 92 93 public double[] GetGradients(int k, double[,] x) { 94 int cols = x.GetLength(1); 95 int n = x.GetLength(0); 96 if (k > cols) throw new ArgumentException(); 97 return (Enumerable.Range(0, n).Select(r => x[r, k])).ToArray(); 72 public ParameterizedMeanFunction GetParameterizedMeanFunction(double[] p, IEnumerable<int> columnIndices) { 73 double[] weights; 74 int[] columns = columnIndices.ToArray(); 75 GetParameter(p, out weights); 76 var mf = new ParameterizedMeanFunction(); 77 mf.Mean = (x, i) => { 78 // sanity check 79 if (weights.Length != columns.Length) throw new ArgumentException("The number of rparameters must match the number of variables for the linear mean function."); 80 return Util.ScalarProd(weights, Util.GetRow(x, i, columns)); 81 }; 82 mf.Gradient = (x, i, k) => { 83 if (k > columns.Length) throw new ArgumentException(); 84 return x[i, columns[k]]; 85 }; 86 return mf; 98 87 } 99 88 }
Note: See TracChangeset
for help on using the changeset viewer.