#region License Information /* HeuristicLab * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) * * This file is part of HeuristicLab. * * HeuristicLab is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * HeuristicLab is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with HeuristicLab. If not, see . */ #endregion using System; using System.Collections.Generic; using System.Linq; using HeuristicLab.Common; using HeuristicLab.Core; using HeuristicLab.Parameters; using HEAL.Attic; namespace HeuristicLab.Algorithms.DataAnalysis { [StorableType("3449B830-E1E5-4176-B56D-AA32235F061B")] public abstract class KernelBase : ParameterizedNamedItem, IKernel { private const string DistanceParameterName = "Distance"; public IValueParameter DistanceParameter { get { return (IValueParameter)Parameters[DistanceParameterName]; } } [Storable] private double? beta; public double? Beta { get { return beta; } set { if (value != beta) { beta = value; RaiseBetaChanged(); } } } public IDistance Distance { get { return DistanceParameter.Value; } set { if (DistanceParameter.Value != value) { DistanceParameter.Value = value; } } } [StorableConstructor] protected KernelBase(StorableConstructorFlag _) : base(_) { } protected KernelBase(KernelBase original, Cloner cloner) : base(original, cloner) { beta = original.beta; RegisterEvents(); } protected KernelBase() { Parameters.Add(new ValueParameter(DistanceParameterName, "The distance function used for kernel calculation")); DistanceParameter.Value = new EuclideanDistance(); RegisterEvents(); } [StorableHook(HookType.AfterDeserialization)] private void AfterDeserialization() { RegisterEvents(); } private void RegisterEvents() { DistanceParameter.ValueChanged += (sender, args) => RaiseDistanceChanged(); } public double Get(object a, object b) { return Get(Distance.Get(a, b)); } protected abstract double Get(double norm); public int GetNumberOfParameters(int numberOfVariables) { return Beta.HasValue ? 0 : 1; } public void SetParameter(double[] p) { if (p != null && p.Length == 1) Beta = new double?(p[0]); } public ParameterizedCovarianceFunction GetParameterizedCovarianceFunction(double[] p, int[] columnIndices) { if (p.Length != GetNumberOfParameters(columnIndices.Length)) throw new ArgumentException("Illegal parametrization"); var myClone = (KernelBase)Clone(); myClone.SetParameter(p); var cov = new ParameterizedCovarianceFunction { Covariance = (x, i, j) => myClone.Get(GetNorm(x, x, i, j, columnIndices)), CrossCovariance = (x, xt, i, j) => myClone.Get(GetNorm(x, xt, i, j, columnIndices)), CovarianceGradient = (x, i, j) => new List { myClone.GetGradient(GetNorm(x, x, i, j, columnIndices)) } }; return cov; } protected abstract double GetGradient(double norm); protected double GetNorm(double[,] x, double[,] xt, int i, int j, int[] columnIndices) { var dist = Distance as IDistance>; if (dist == null) throw new ArgumentException("The distance needs to apply to double vectors"); var r1 = columnIndices.Select(c => x[i, c]); var r2 = columnIndices.Select(c => xt[j, c]); return dist.Get(r1, r2); } #region events public event EventHandler BetaChanged; public event EventHandler DistanceChanged; protected void RaiseBetaChanged() { var handler = BetaChanged; if (handler != null) handler(this, EventArgs.Empty); } protected void RaiseDistanceChanged() { var handler = DistanceChanged; if (handler != null) handler(this, EventArgs.Empty); } #endregion } }