source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis/3.4/KernelRidgeRegression/KernelFunctions/KernelBase.cs @ 17209

Last change on this file since 17209 was 17209, checked in by gkronber, 5 weeks ago

#2994: merged r17132:17198 from trunk to branch

File size: 4.5 KB
Line 
1#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Parameters;
28using HEAL.Attic;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Base class for kernels whose value is computed from the distance between two objects.
  /// The distance function is held in the "Distance" parameter (a <see cref="EuclideanDistance"/> by default);
  /// subclasses map the resulting distance to a kernel value, optionally using the scalar <see cref="Beta"/>.
  /// </summary>
  [StorableType("3449B830-E1E5-4176-B56D-AA32235F061B")]
  public abstract class KernelBase : ParameterizedNamedItem, IKernel {

    private const string DistanceParameterName = "Distance";

    /// <summary>The parameter holding the distance function used for kernel calculation.</summary>
    public IValueParameter<IDistance> DistanceParameter {
      get { return (IValueParameter<IDistance>)Parameters[DistanceParameterName]; }
    }

    [Storable]
    private double? beta;

    /// <summary>
    /// Optional scalar kernel parameter. While it is <c>null</c> the kernel reports one free parameter
    /// (see <see cref="GetNumberOfParameters"/>), which can be assigned through <see cref="SetParameter"/>.
    /// Assigning a different value raises <see cref="BetaChanged"/>.
    /// </summary>
    public double? Beta {
      get { return beta; }
      set {
        if (value != beta) {
          beta = value;
          RaiseBetaChanged();
        }
      }
    }

    /// <summary>
    /// The distance function stored in <see cref="DistanceParameter"/>.
    /// The parameter is only written when the new instance differs from the current one.
    /// </summary>
    public IDistance Distance {
      get { return DistanceParameter.Value; }
      set {
        if (DistanceParameter.Value != value) {
          DistanceParameter.Value = value;
        }
      }
    }

    [StorableConstructor]
    protected KernelBase(StorableConstructorFlag _) : base(_) { }

    protected KernelBase(KernelBase original, Cloner cloner)
      : base(original, cloner) {
      beta = original.beta;
      // Wire up the Distance change handler on this (cloned) instance.
      RegisterEvents();
    }

    protected KernelBase() {
      Parameters.Add(new ValueParameter<IDistance>(DistanceParameterName, "The distance function used for kernel calculation"));
      DistanceParameter.Value = new EuclideanDistance();
      RegisterEvents();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Re-attach the Distance change handler after loading.
      RegisterEvents();
    }

    /// <summary>Forwards value changes of <see cref="DistanceParameter"/> as <see cref="DistanceChanged"/> events.</summary>
    private void RegisterEvents() {
      DistanceParameter.ValueChanged += (sender, args) => RaiseDistanceChanged();
    }

    /// <summary>Computes the kernel value for two objects from the distance between them.</summary>
    /// <param name="a">First object, passed unchanged to <see cref="Distance"/>.</param>
    /// <param name="b">Second object, passed unchanged to <see cref="Distance"/>.</param>
    /// <returns>The kernel value for the distance between <paramref name="a"/> and <paramref name="b"/>.</returns>
    public double Get(object a, object b) {
      return Get(Distance.Get(a, b));
    }

    /// <summary>Maps a distance value produced by <see cref="Distance"/> to the kernel value.</summary>
    /// <param name="norm">The distance between two objects.</param>
    protected abstract double Get(double norm);

    /// <summary>
    /// Returns the number of free parameters: one (for <see cref="Beta"/>) while Beta is unset, otherwise zero.
    /// </summary>
    /// <param name="numberOfVariables">Not used by this implementation.</param>
    public int GetNumberOfParameters(int numberOfVariables) {
      return Beta.HasValue ? 0 : 1;
    }

    /// <summary>
    /// Assigns <see cref="Beta"/> when <paramref name="p"/> holds exactly one value; any other input is ignored.
    /// </summary>
    public void SetParameter(double[] p) {
      if (p != null && p.Length == 1) Beta = new double?(p[0]);
    }

    /// <summary>
    /// Creates covariance, cross-covariance and covariance-gradient functions over the given columns of data matrices.
    /// The functions operate on a clone of this kernel with <paramref name="p"/> applied. Later changes to this
    /// instance (Beta, Distance) therefore do not affect them.
    /// </summary>
    /// <param name="p">Values of the free parameters; its length must equal <see cref="GetNumberOfParameters"/>.</param>
    /// <param name="columnIndices">Column indices of the data matrices used to compute distances.</param>
    /// <returns>The parameterized covariance function.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="p"/> or <paramref name="columnIndices"/> is null.</exception>
    /// <exception cref="ArgumentException">The length of <paramref name="p"/> does not match the number of free parameters.</exception>
    public ParameterizedCovarianceFunction GetParameterizedCovarianceFunction(double[] p, int[] columnIndices) {
      if (p == null) throw new ArgumentNullException("p");
      if (columnIndices == null) throw new ArgumentNullException("columnIndices");
      if (p.Length != GetNumberOfParameters(columnIndices.Length)) throw new ArgumentException("Illegal parametrization");
      var myClone = (KernelBase)Clone();
      myClone.SetParameter(p);
      // Beta is the only parameter this kernel exposes; it is free exactly when p carries a value for it.
      // The gradient must contain one entry per free parameter, so it is empty when Beta is fixed.
      bool betaIsFree = p.Length > 0;
      var cov = new ParameterizedCovarianceFunction {
        // Norms are computed on the clone so that its (cloned) Distance is used, not this instance's current one.
        Covariance = (x, i, j) => myClone.Get(myClone.GetNorm(x, x, i, j, columnIndices)),
        CrossCovariance = (x, xt, i, j) => myClone.Get(myClone.GetNorm(x, xt, i, j, columnIndices)),
        CovarianceGradient = (x, i, j) => betaIsFree
          ? new List<double> { myClone.GetGradient(myClone.GetNorm(x, x, i, j, columnIndices)) }
          : new List<double>()
      };
      return cov;
    }

    /// <summary>
    /// Gradient of the kernel value for the given distance. This is presumably the derivative with respect to
    /// <see cref="Beta"/>; confirm in subclasses.
    /// </summary>
    /// <param name="norm">The distance between two objects.</param>
    protected abstract double GetGradient(double norm);

    /// <summary>
    /// Computes the distance between row <paramref name="i"/> of <paramref name="x"/> and row <paramref name="j"/>
    /// of <paramref name="xt"/>, restricted to the columns listed in <paramref name="columnIndices"/>.
    /// </summary>
    /// <exception cref="ArgumentException"><see cref="Distance"/> does not operate on sequences of doubles.</exception>
    protected double GetNorm(double[,] x, double[,] xt, int i, int j, int[] columnIndices) {
      var dist = Distance as IDistance<IEnumerable<double>>;
      if (dist == null) throw new ArgumentException("The distance needs to apply to double vectors");
      var r1 = columnIndices.Select(c => x[i, c]);
      var r2 = columnIndices.Select(c => xt[j, c]);
      return dist.Get(r1, r2);
    }

    #region events
    /// <summary>Raised when <see cref="Beta"/> is assigned a different value.</summary>
    public event EventHandler BetaChanged;
    /// <summary>Raised when the value of <see cref="DistanceParameter"/> changes.</summary>
    public event EventHandler DistanceChanged;

    protected void RaiseBetaChanged() {
      var handler = BetaChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }

    protected void RaiseDistanceChanged() {
      var handler = DistanceChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.