#region License Information
/* HeuristicLab
* Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence;
namespace HeuristicLab.Algorithms.DataAnalysis.KernelRidgeRegression {
[StorableType("c6e3751a-1eab-4068-af73-e39f52cded26")]
public abstract class KernelBase : ParameterizedNamedItem, IKernel {
  #region Parameternames
  private const string DistanceParameterName = "Distance";
  #endregion

  #region Parameterproperties
  /// <summary>
  /// The parameter holding the distance function on which the kernel is evaluated.
  /// </summary>
  public ValueParameter<IDistance> DistanceParameter {
    get { return Parameters[DistanceParameterName] as ValueParameter<IDistance>; }
  }

  /// <summary>
  /// Optional fixed value for the kernel's single free parameter.
  /// When set, <see cref="GetNumberOfParameters"/> reports zero tunable parameters;
  /// when null, one parameter is expected and can be assigned via <see cref="SetParameter"/>.
  /// </summary>
  /// <remarks>
  /// NOTE(review): how Beta enters the kernel formula is up to the concrete subclass's
  /// norm-based evaluation (not visible here) — confirm per kernel.
  /// </remarks>
  [Storable]
  public double? Beta { get; set; }
  #endregion

  #region Properties
  /// <summary>
  /// Gets or sets the distance function stored in <see cref="DistanceParameter"/>.
  /// </summary>
  public IDistance Distance {
    get { return DistanceParameter.Value; }
    set { DistanceParameter.Value = value; }
  }
  #endregion

  [StorableConstructor]
  protected KernelBase(StorableConstructorFlag deserializing) : base(deserializing) { }

  // Intentionally empty: no post-deserialization fix-ups are currently performed.
  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() { }

  /// <summary>
  /// Cloning constructor; copies <see cref="Beta"/> in addition to the state cloned by the base class.
  /// </summary>
  protected KernelBase(KernelBase original, Cloner cloner)
    : base(original, cloner) {
    Beta = original.Beta;
  }

  /// <summary>
  /// Creates a kernel whose distance parameter defaults to <see cref="EuclideanDistance"/>.
  /// </summary>
  protected KernelBase() {
    Parameters.Add(new ValueParameter<IDistance>(DistanceParameterName, "The distance function used for kernel calculation"));
    DistanceParameter.Value = new EuclideanDistance();
  }

  /// <summary>
  /// Evaluates the kernel for two items: computes their distance with <see cref="Distance"/>
  /// and maps that value through the subclass-defined <c>Get(double)</c>.
  /// </summary>
  public double Get(object a, object b) {
    return Get(Distance.Get(a, b));
  }

  /// <summary>
  /// Maps a distance value to a kernel value. Implemented by concrete kernels.
  /// </summary>
  /// <param name="norm">The distance between two points, as produced by the distance function.</param>
  protected abstract double Get(double norm);

  /// <summary>
  /// Returns the number of tunable parameters: 0 when <see cref="Beta"/> is fixed, otherwise 1.
  /// </summary>
  /// <param name="numberOfVariables">Not used; the count does not depend on the input dimensionality.</param>
  public int GetNumberOfParameters(int numberOfVariables) {
    return Beta.HasValue ? 0 : 1;
  }

  /// <summary>
  /// Assigns <see cref="Beta"/> from <paramref name="p"/> when it holds exactly one value.
  /// Any other input (null, or a different length) is ignored and leaves <see cref="Beta"/> unchanged.
  /// </summary>
  public void SetParameter(double[] p) {
    if (p != null && p.Length == 1) Beta = new double?(p[0]);
  }

  /// <summary>
  /// Builds a covariance function that evaluates this kernel on selected columns of a data matrix.
  /// </summary>
  /// <param name="p">Parameter values; its length must equal <see cref="GetNumberOfParameters"/>.</param>
  /// <param name="columnIndices">Indices of the matrix columns that form each input vector.</param>
  /// <returns>
  /// A covariance function backed by a private clone of this kernel, so later changes to this
  /// instance (its <see cref="Beta"/> or <see cref="Distance"/>) do not affect it.
  /// </returns>
  /// <exception cref="ArgumentNullException"><paramref name="p"/> or <paramref name="columnIndices"/> is null.</exception>
  /// <exception cref="ArgumentException">The length of <paramref name="p"/> does not match the expected number of parameters.</exception>
  public ParameterizedCovarianceFunction GetParameterizedCovarianceFunction(double[] p, int[] columnIndices) {
    if (p == null) throw new ArgumentNullException("p");
    if (columnIndices == null) throw new ArgumentNullException("columnIndices");
    if (p.Length != GetNumberOfParameters(columnIndices.Length)) throw new ArgumentException("Illegal parametrization");
    var myClone = (KernelBase)Clone(new Cloner());
    myClone.SetParameter(p);
    // Every closure goes through the clone, including GetNorm (which reads Distance), so the
    // returned function is a true snapshot and does not observe later edits of this instance.
    var cov = new ParameterizedCovarianceFunction {
      Covariance = (x, i, j) => myClone.Get(myClone.GetNorm(x, x, i, j, columnIndices)),
      CrossCovariance = (x, xt, i, j) => myClone.Get(myClone.GetNorm(x, xt, i, j, columnIndices)),
      // NOTE(review): one gradient entry is produced even when Beta is fixed and
      // GetNumberOfParameters reports 0 — confirm callers expect this.
      CovarianceGradient = (x, i, j) => new List<double> { myClone.GetGradient(myClone.GetNorm(x, x, i, j, columnIndices)) }
    };
    return cov;
  }

  /// <summary>
  /// Gradient value reported by the covariance function for the given distance.
  /// Implemented by concrete kernels (presumably the derivative with respect to the kernel parameter — confirm per subclass).
  /// </summary>
  /// <param name="norm">The distance between two points, as produced by the distance function.</param>
  protected abstract double GetGradient(double norm);

  /// <summary>
  /// Computes the distance between row <paramref name="i"/> of <paramref name="x"/> and row
  /// <paramref name="j"/> of <paramref name="xt"/>, using only the columns in <paramref name="columnIndices"/>.
  /// </summary>
  /// <exception cref="ArgumentException"><see cref="Distance"/> does not operate on sequences of doubles.</exception>
  protected double GetNorm(double[,] x, double[,] xt, int i, int j, int[] columnIndices) {
    var dist = Distance as IDistance<IEnumerable<double>>;
    if (dist == null) throw new ArgumentException("The distance needs to apply to double vectors");
    var r1 = columnIndices.Select(c => x[i, c]);
    var r2 = columnIndices.Select(c => xt[j, c]);
    return dist.Get(r1, r2);
  }
}
}