#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence;

namespace HeuristicLab.Algorithms.DataAnalysis.KernelRidgeRegression {
  /// <summary>
  /// Base class for kernels whose value is computed from a distance between two objects.
  /// The distance is taken from the <see cref="Distance"/> parameter (Euclidean by default)
  /// and mapped to a kernel value by the derived class via <see cref="Get(double)"/>.
  /// </summary>
  [StorableType("c6e3751a-1eab-4068-af73-e39f52cded26")]
  public abstract class KernelBase : ParameterizedNamedItem, IKernel {

    #region Parameternames
    private const string DistanceParameterName = "Distance";
    #endregion

    #region Parameterproperties
    /// <summary>
    /// The parameter holding the distance function used for kernel calculation.
    /// </summary>
    public ValueParameter<IDistance> DistanceParameter {
      get { return Parameters[DistanceParameterName] as ValueParameter<IDistance>; }
    }

    /// <summary>
    /// The kernel's single free parameter. While it is <c>null</c> the kernel reports one
    /// tunable parameter (see <see cref="GetNumberOfParameters"/>), which is then supplied
    /// through <see cref="SetParameter"/>. How the value enters the kernel is defined by the
    /// derived class.
    /// </summary>
    [Storable]
    public double? Beta { get; set; }
    #endregion

    #region Properties
    /// <summary>
    /// Convenience accessor for the value of <see cref="DistanceParameter"/>.
    /// </summary>
    public IDistance Distance {
      get { return DistanceParameter.Value; }
      set { DistanceParameter.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected KernelBase(StorableConstructorFlag deserializing) : base(deserializing) { }

    // Currently a no-op; kept as the deserialization hook for this type.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    /// <summary>
    /// Cloning constructor. Parameters (including the distance) are cloned by the base class;
    /// <see cref="Beta"/> is copied here.
    /// </summary>
    protected KernelBase(KernelBase original, Cloner cloner)
      : base(original, cloner) {
      Beta = original.Beta;
    }

    /// <summary>
    /// Creates a kernel that uses <see cref="EuclideanDistance"/> by default.
    /// </summary>
    protected KernelBase() {
      Parameters.Add(new ValueParameter<IDistance>(DistanceParameterName, "The distance function used for kernel calculation"));
      DistanceParameter.Value = new EuclideanDistance();
    }

    /// <summary>
    /// Computes the kernel value for two objects by measuring their distance with
    /// <see cref="Distance"/> and passing it to <see cref="Get(double)"/>.
    /// </summary>
    public double Get(object a, object b) {
      return Get(Distance.Get(a, b));
    }

    /// <summary>
    /// Maps a distance value to the kernel value. Implemented by derived kernels.
    /// </summary>
    protected abstract double Get(double norm);

    /// <summary>
    /// Returns the number of tunable parameters: 0 if <see cref="Beta"/> is fixed, otherwise 1.
    /// </summary>
    /// <param name="numberOfVariables">Not used by this implementation.</param>
    public int GetNumberOfParameters(int numberOfVariables) {
      return Beta.HasValue ? 0 : 1;
    }

    /// <summary>
    /// Sets <see cref="Beta"/> from <paramref name="p"/> when it contains exactly one value;
    /// any other input (including <c>null</c>) leaves <see cref="Beta"/> unchanged.
    /// </summary>
    public void SetParameter(double[] p) {
      if (p != null && p.Length == 1) Beta = p[0];
    }

    /// <summary>
    /// Builds a covariance function for a copy of this kernel parameterized with <paramref name="p"/>.
    /// The returned delegates compare rows <c>i</c> and <c>j</c> of the given matrices, restricted to
    /// <paramref name="columnIndices"/>.
    /// </summary>
    /// <param name="p">Parameter values; length must equal <see cref="GetNumberOfParameters"/>.</param>
    /// <param name="columnIndices">Column indices of the matrices that enter the distance.</param>
    /// <exception cref="ArgumentNullException"><paramref name="p"/> or <paramref name="columnIndices"/> is null.</exception>
    /// <exception cref="ArgumentException">The length of <paramref name="p"/> does not match the number of parameters.</exception>
    public ParameterizedCovarianceFunction GetParameterizedCovarianceFunction(double[] p, int[] columnIndices) {
      if (p == null) throw new ArgumentNullException("p");
      if (columnIndices == null) throw new ArgumentNullException("columnIndices");
      if (p.Length != GetNumberOfParameters(columnIndices.Length)) throw new ArgumentException("Illegal parametrization");

      // Work on a private copy so the returned function is unaffected by later changes
      // to this kernel (its Beta or its Distance). All delegates use only the copy.
      var myClone = (KernelBase)Clone(new Cloner());
      myClone.SetParameter(p);
      var cov = new ParameterizedCovarianceFunction {
        Covariance = (x, i, j) => myClone.Get(myClone.GetNorm(x, x, i, j, columnIndices)),
        CrossCovariance = (x, xt, i, j) => myClone.Get(myClone.GetNorm(x, xt, i, j, columnIndices)),
        CovarianceGradient = (x, i, j) => new List<double> { myClone.GetGradient(myClone.GetNorm(x, x, i, j, columnIndices)) }
      };
      return cov;
    }

    /// <summary>
    /// Gradient of the kernel with respect to its parameter, given a distance value.
    /// Implemented by derived kernels.
    /// </summary>
    protected abstract double GetGradient(double norm);

    /// <summary>
    /// Distance between row <paramref name="i"/> of <paramref name="x"/> and row <paramref name="j"/>
    /// of <paramref name="xt"/>, using only the columns listed in <paramref name="columnIndices"/>.
    /// </summary>
    /// <exception cref="ArgumentException"><see cref="Distance"/> is not a distance over <c>IEnumerable&lt;double&gt;</c>.</exception>
    protected double GetNorm(double[,] x, double[,] xt, int i, int j, int[] columnIndices) {
      var dist = Distance as IDistance<IEnumerable<double>>;
      if (dist == null) throw new ArgumentException("The distance needs to apply to double vectors");
      var r1 = columnIndices.Select(c => x[i, c]);
      var r2 = columnIndices.Select(c => xt[j, c]);
      return dist.Get(r1, r2);
    }
  }
}