Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/KernelRidgeRegression/KernelFunctions/KernelBase.cs @ 18190

Last change on this file since 18190 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 4.5 KB
RevLine 
[14386]1#region License Information
2/* HeuristicLab
[17181]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[14386]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
[14887]24using System.Linq;
[14386]25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Parameters;
[17097]28using HEAL.Attic;
[14386]29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Abstract base for kernels that are evaluated on the distance between two points.
  /// The distance measure is exposed as a parameter; an optional <see cref="Beta"/> value can
  /// either be fixed up front or supplied later through <see cref="SetParameter"/>.
  /// </summary>
  [StorableType("3449B830-E1E5-4176-B56D-AA32235F061B")]
  public abstract class KernelBase : ParameterizedNamedItem, IKernel {

    private const string DistanceParameterName = "Distance";

    /// <summary>The parameter that holds the distance function of this kernel.</summary>
    public IValueParameter<IDistance> DistanceParameter {
      get { return (IValueParameter<IDistance>)Parameters[DistanceParameterName]; }
    }

    [Storable]
    private double? beta;

    /// <summary>
    /// Optional kernel parameter. While it is <c>null</c> the kernel reports one free parameter
    /// (see <see cref="GetNumberOfParameters"/>). <see cref="BetaChanged"/> is raised whenever the
    /// stored value changes.
    /// </summary>
    /// <remarks>NOTE(review): how beta enters the kernel is defined by the derived classes and is not visible here.</remarks>
    public double? Beta {
      get { return beta; }
      set {
        if (value != beta) {
          beta = value;
          RaiseBetaChanged();
        }
      }
    }

    /// <summary>
    /// Convenience accessor for the value of <see cref="DistanceParameter"/>.
    /// The parameter is only written when the new value differs from the current one.
    /// </summary>
    public IDistance Distance {
      get { return DistanceParameter.Value; }
      set {
        if (DistanceParameter.Value != value) {
          DistanceParameter.Value = value;
        }
      }
    }

    [StorableConstructor]
    protected KernelBase(StorableConstructorFlag _) : base(_) { }

    /// <summary>Cloning constructor; copies beta and re-attaches the parameter event handler on the clone.</summary>
    protected KernelBase(KernelBase original, Cloner cloner)
      : base(original, cloner) {
      beta = original.beta;
      RegisterEvents();
    }

    /// <summary>Creates the kernel with a <see cref="EuclideanDistance"/> as its initial distance function.</summary>
    protected KernelBase() {
      Parameters.Add(new ValueParameter<IDistance>(DistanceParameterName, "The distance function used for kernel calculation"));
      DistanceParameter.Value = new EuclideanDistance();
      RegisterEvents();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Event subscriptions are not part of the stored state, so they are re-established here.
      RegisterEvents();
    }

    /// <summary>Forwards value changes of the distance parameter as <see cref="DistanceChanged"/>.</summary>
    private void RegisterEvents() {
      DistanceParameter.ValueChanged += (sender, args) => RaiseDistanceChanged();
    }

    /// <summary>Evaluates the kernel for two objects by applying it to their distance.</summary>
    /// <param name="a">First object; passed to <see cref="Distance"/>.</param>
    /// <param name="b">Second object; passed to <see cref="Distance"/>.</param>
    /// <returns>The kernel value for the distance between <paramref name="a"/> and <paramref name="b"/>.</returns>
    public double Get(object a, object b) {
      return Get(Distance.Get(a, b));
    }

    /// <summary>Evaluates the kernel for an already computed distance.</summary>
    /// <param name="norm">The distance between the two points.</param>
    protected abstract double Get(double norm);

    /// <summary>
    /// Returns the number of free parameters: 0 when <see cref="Beta"/> is set, 1 otherwise.
    /// </summary>
    /// <param name="numberOfVariables">Not used by this implementation.</param>
    public int GetNumberOfParameters(int numberOfVariables) {
      return Beta.HasValue ? 0 : 1;
    }

    /// <summary>
    /// Sets <see cref="Beta"/> from a parameter vector. The call has no effect unless
    /// <paramref name="p"/> is non-null and contains exactly one element.
    /// </summary>
    /// <param name="p">Parameter vector; its single element becomes the new beta.</param>
    public void SetParameter(double[] p) {
      if (p != null && p.Length == 1) Beta = p[0];
    }

    /// <summary>
    /// Creates covariance, cross-covariance and gradient delegates for the given parametrization.
    /// The delegates operate on a private clone of this kernel, so changes made to this instance
    /// afterwards do not affect them.
    /// </summary>
    /// <param name="p">Parameter values; the length must equal <see cref="GetNumberOfParameters"/>.</param>
    /// <param name="columnIndices">Columns of the data matrices that are passed to the distance function.</param>
    /// <returns>The delegates bundled in a <see cref="ParameterizedCovarianceFunction"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="p"/> or <paramref name="columnIndices"/> is null.</exception>
    /// <exception cref="ArgumentException">The length of <paramref name="p"/> does not match the number of free parameters.</exception>
    public ParameterizedCovarianceFunction GetParameterizedCovarianceFunction(double[] p, int[] columnIndices) {
      if (p == null) throw new ArgumentNullException("p");
      if (columnIndices == null) throw new ArgumentNullException("columnIndices");
      if (p.Length != GetNumberOfParameters(columnIndices.Length)) throw new ArgumentException("Illegal parametrization");
      var myClone = (KernelBase)Clone();
      myClone.SetParameter(p);
      // Both the kernel evaluation and the distance calculation go through the clone; using this
      // instance for the distance would let later changes of its Distance leak into the delegates.
      var cov = new ParameterizedCovarianceFunction {
        Covariance = (x, i, j) => myClone.Get(myClone.GetNorm(x, x, i, j, columnIndices)),
        CrossCovariance = (x, xt, i, j) => myClone.Get(myClone.GetNorm(x, xt, i, j, columnIndices)),
        CovarianceGradient = (x, i, j) => new List<double> { myClone.GetGradient(myClone.GetNorm(x, x, i, j, columnIndices)) }
      };
      return cov;
    }

    /// <summary>Evaluates the kernel gradient for an already computed distance.</summary>
    /// <param name="norm">The distance between the two points.</param>
    /// <remarks>NOTE(review): presumably the derivative with respect to beta - confirm against the derived kernels.</remarks>
    protected abstract double GetGradient(double norm);

    /// <summary>
    /// Computes the distance between row <paramref name="i"/> of <paramref name="x"/> and row
    /// <paramref name="j"/> of <paramref name="xt"/>, restricted to the given columns.
    /// </summary>
    /// <param name="x">Matrix providing the first row.</param>
    /// <param name="xt">Matrix providing the second row.</param>
    /// <param name="i">Row index into <paramref name="x"/>.</param>
    /// <param name="j">Row index into <paramref name="xt"/>.</param>
    /// <param name="columnIndices">Columns that are read from both rows.</param>
    /// <exception cref="ArgumentException">The configured distance does not operate on sequences of doubles.</exception>
    protected double GetNorm(double[,] x, double[,] xt, int i, int j, int[] columnIndices) {
      var dist = Distance as IDistance<IEnumerable<double>>;
      if (dist == null) throw new ArgumentException("The distance needs to apply to double vectors");
      // The row projections are lazy sequences; they are enumerated by the distance function.
      var r1 = columnIndices.Select(c => x[i, c]);
      var r2 = columnIndices.Select(c => xt[j, c]);
      return dist.Get(r1, r2);
    }

    #region events
    /// <summary>Raised after <see cref="Beta"/> has been assigned a different value.</summary>
    public event EventHandler BetaChanged;
    /// <summary>Raised when the value of <see cref="DistanceParameter"/> changes.</summary>
    public event EventHandler DistanceChanged;

    protected void RaiseBetaChanged() {
      var handler = BetaChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }

    protected void RaiseDistanceChanged() {
      var handler = DistanceChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.