Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/KernelRidgeRegression/KernelRidgeRegression.cs @ 17877

Last change on this file since 17877 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 6.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using System.Threading;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HEAL.Attic;
31using HeuristicLab.PluginInfrastructure;
32using HeuristicLab.Problems.DataAnalysis;
33
34namespace HeuristicLab.Algorithms.DataAnalysis {
35  [Item("Kernel Ridge Regression", "Kernelized ridge regression e.g. for radial basis function (RBF) regression.")]
36  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 100)]
37  [StorableType("8AD45266-68CA-4710-A99C-59952132AA9D")]
38  public sealed class KernelRidgeRegression : BasicAlgorithm, IDataAnalysisAlgorithm<IRegressionProblem> {
39    private const string SolutionResultName = "Kernel ridge regression solution";
40
41    public override bool SupportsPause {
42      get { return false; }
43    }
44    public override Type ProblemType {
45      get { return typeof(IRegressionProblem); }
46    }
47    public new IRegressionProblem Problem {
48      get { return (IRegressionProblem)base.Problem; }
49      set { base.Problem = value; }
50    }
51
52    #region parameter names
53    private const string KernelParameterName = "Kernel";
54    private const string ScaleInputVariablesParameterName = "ScaleInputVariables";
55    private const string LambdaParameterName = "LogLambda";
56    private const string BetaParameterName = "Beta";
57    #endregion
58
59    #region parameter properties
60    public IConstrainedValueParameter<IKernel> KernelParameter {
61      get { return (IConstrainedValueParameter<IKernel>)Parameters[KernelParameterName]; }
62    }
63
64    public IFixedValueParameter<BoolValue> ScaleInputVariablesParameter {
65      get { return (IFixedValueParameter<BoolValue>)Parameters[ScaleInputVariablesParameterName]; }
66    }
67
68    public IFixedValueParameter<DoubleValue> LogLambdaParameter {
69      get { return (IFixedValueParameter<DoubleValue>)Parameters[LambdaParameterName]; }
70    }
71
72    public IFixedValueParameter<DoubleValue> BetaParameter {
73      get { return (IFixedValueParameter<DoubleValue>)Parameters[BetaParameterName]; }
74    }
75    #endregion
76
77    #region properties
78    public IKernel Kernel {
79      get { return KernelParameter.Value; }
80    }
81
82    public bool ScaleInputVariables {
83      get { return ScaleInputVariablesParameter.Value.Value; }
84      set { ScaleInputVariablesParameter.Value.Value = value; }
85    }
86
87    public double LogLambda {
88      get { return LogLambdaParameter.Value.Value; }
89      set { LogLambdaParameter.Value.Value = value; }
90    }
91
92    public double Beta {
93      get { return BetaParameter.Value.Value; }
94      set { BetaParameter.Value.Value = value; }
95    }
96    #endregion
97
98    [StorableConstructor]
99    private KernelRidgeRegression(StorableConstructorFlag _) : base(_) { }
100    private KernelRidgeRegression(KernelRidgeRegression original, Cloner cloner)
101      : base(original, cloner) {
102    }
103    public KernelRidgeRegression() {
104      Problem = new RegressionProblem();
105      var values = new ItemSet<IKernel>(ApplicationManager.Manager.GetInstances<IKernel>());
106      Parameters.Add(new ConstrainedValueParameter<IKernel>(KernelParameterName, "The kernel", values, values.OfType<GaussianKernel>().FirstOrDefault()));
107      Parameters.Add(new FixedValueParameter<BoolValue>(ScaleInputVariablesParameterName, "Set to true if the input variables should be scaled to the interval [0..1]", new BoolValue(true)));
108      Parameters.Add(new FixedValueParameter<DoubleValue>(LambdaParameterName, "The log10-transformed weight for the regularization term lambda [-inf..+inf]. Small values produce more complex models, large values produce models with larger errors. Set to very small value (e.g. -1.0e15) for almost exact approximation", new DoubleValue(-2)));
109      Parameters.Add(new FixedValueParameter<DoubleValue>(BetaParameterName, "The inverse width of the kernel ]0..+inf]. The distance between points is divided by this value before being plugged into the kernel.", new DoubleValue(2)));
110    }
111
112    public override IDeepCloneable Clone(Cloner cloner) {
113      return new KernelRidgeRegression(this, cloner);
114    }
115
116    protected override void Run(CancellationToken cancellationToken) {
117      double rmsError, looCvRMSE;
118      var kernel = Kernel;
119      kernel.Beta = Beta;
120      var solution = CreateRadialBasisRegressionSolution(Problem.ProblemData, kernel, Math.Pow(10, LogLambda), ScaleInputVariables, out rmsError, out looCvRMSE);
121      Results.Add(new Result(SolutionResultName, "The kernel ridge regression solution.", solution));
122      Results.Add(new Result("RMSE (test)", "The root mean squared error of the solution on the test set.", new DoubleValue(rmsError)));
123      Results.Add(new Result("RMSE (LOO-CV)", "The leave-one-out-cross-validation root mean squared error", new DoubleValue(looCvRMSE)));
124    }
125
126    public static IRegressionSolution CreateRadialBasisRegressionSolution(IRegressionProblemData problemData, ICovarianceFunction kernel, double lambda, bool scaleInputs, out double rmsError, out double looCvRMSE) {
127      var model = KernelRidgeRegressionModel.Create(problemData.Dataset, problemData.TargetVariable, problemData.AllowedInputVariables, problemData.TrainingIndices, scaleInputs, kernel, lambda);
128      rmsError = double.NaN;
129      if (problemData.TestIndices.Any()) {
130        rmsError = Math.Sqrt(model.GetEstimatedValues(problemData.Dataset, problemData.TestIndices)
131          .Zip(problemData.TargetVariableTestValues, (a, b) => (a - b) * (a - b))
132          .Average());
133      }
134      var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
135      solution.Model.Name = "Kernel ridge regression model";
136      solution.Name = SolutionResultName;
137      looCvRMSE = model.LooCvRMSE;
138      return solution;
139    }
140  }
141}
Note: See TracBrowser for help on using the repository browser.