Free cookie consent management tool by TermsFeed Policy Generator

source: branches/RBFRegression/HeuristicLab.Algorithms.DataAnalysis/3.4/RadialBasisFunctions/RadialBasisRegression.cs @ 14872

Last change on this file since 14872 was 14872, checked in by gkronber, 7 years ago

#2699: made a number of changes mainly to RBF regression model

File size: 5.3 KB
Line 
1#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
20#endregion
21
22using System;
23using System.Linq;
24using System.Threading;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Radial basis function regression data analysis algorithm.
  /// </summary>
  /// <remarks>
  /// Fits a <see cref="RadialBasisFunctionModel"/> on the training partition of the problem data and
  /// reports the resulting solution, the leave-one-out cross-validation RMSE on the training partition
  /// and the RMSE on the test partition.
  /// </remarks>
  [Item("Radial Basis Function Regression (RBF-R)", "Radial basis function regression data analysis algorithm.")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 100)]
  [StorableClass]
  public sealed class RadialBasisRegression : BasicAlgorithm {
    #region result names
    private const string RBFRegressionSolutionResultName = "RBF regression solution";
    private const string LeaveOneOutRmseResultName = "LOOCVRMSE";
    private const string TestRmseResultName = "RMSE (test)";
    #endregion

    /// <summary>This algorithm cannot be paused.</summary>
    public override bool SupportsPause {
      get { return false; }
    }

    /// <summary>The algorithm accepts regression problems.</summary>
    public override Type ProblemType {
      get { return typeof(IRegressionProblem); }
    }

    /// <summary>The regression problem to solve (a typed view of the base class problem).</summary>
    public new IRegressionProblem Problem {
      get { return (IRegressionProblem)base.Problem; }
      set { base.Problem = value; }
    }

    #region parameter names
    private const string KernelParameterName = "Kernel";
    private const string ScaleInputVariablesParameterName = "ScaleInputVariables";
    #endregion

    #region parameter properties
    /// <summary>Parameter holding the radial basis function used by the model.</summary>
    public ValueParameter<ICovarianceFunction> KernelParameter {
      get { return (ValueParameter<ICovarianceFunction>)Parameters[KernelParameterName]; }
    }

    /// <summary>Parameter controlling whether input variables are scaled to [0..1].</summary>
    public IFixedValueParameter<BoolValue> ScaleInputVariablesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[ScaleInputVariablesParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// The radial basis function used by the model. A <see cref="GaussianKernel"/> is set by the default constructor.
    /// </summary>
    public ICovarianceFunction Kernel {
      get { return KernelParameter.Value; }
    }

    /// <summary>Set to true if the input variables should be scaled to the interval [0..1] (default: true).</summary>
    public bool ScaleInputVariables {
      get { return ScaleInputVariablesParameter.Value.Value; }
      set { ScaleInputVariablesParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private RadialBasisRegression(bool deserializing) : base(deserializing) { }

    private RadialBasisRegression(RadialBasisRegression original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm with an empty <see cref="RegressionProblem"/>, a Gaussian kernel and input scaling enabled.
    /// </summary>
    public RadialBasisRegression() {
      Problem = new RegressionProblem();
      Parameters.Add(new ValueParameter<ICovarianceFunction>(KernelParameterName, "The radial basis function"));
      Parameters.Add(new FixedValueParameter<BoolValue>(ScaleInputVariablesParameterName, "Set to true if the input variables should be scaled to the interval [0..1]", new BoolValue(true)));
      var kernel = new GaussianKernel();
      KernelParameter.Value = kernel;
    }

    // Nothing to restore after deserialization; kept as an explicit hook for future versions.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RadialBasisRegression(this, cloner);
    }

    /// <summary>
    /// Fits the RBF model on the current problem data and publishes the solution, the leave-one-out
    /// cross-validation RMSE and the test RMSE as results.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      double loocvRmsError, testRmsError;
      var solution = CreateRadialBasisRegressionSolution(Problem.ProblemData, Kernel, ScaleInputVariables, out loocvRmsError, out testRmsError);
      Results.Add(new Result(RBFRegressionSolutionResultName, "The RBF regression solution.", solution));
      Results.Add(new Result(LeaveOneOutRmseResultName, "The root mean squared error of a leave-one-out-cross-validation on the training set", new DoubleValue(loocvRmsError)));
      Results.Add(new Result(TestRmseResultName, "The root mean squared error of the solution on the test set.", new DoubleValue(testRmsError)));
    }

    /// <summary>
    /// Builds an RBF regression model on the training partition of <paramref name="problemData"/> and wraps it in a solution.
    /// </summary>
    /// <param name="problemData">The problem data; the model is built from its training partition.</param>
    /// <param name="kernel">The radial basis function used by the model.</param>
    /// <param name="scaleInputs">Whether the model should scale the input variables.</param>
    /// <param name="loocvRmsError">Receives the leave-one-out cross-validation RMSE reported by the model.</param>
    /// <param name="rmsError">
    /// Receives the RMSE of the model on the test partition, or <see cref="double.NaN"/> if the test partition is empty.
    /// </param>
    /// <returns>A regression solution created on a clone of <paramref name="problemData"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> or <paramref name="kernel"/> is null.</exception>
    public static IRegressionSolution CreateRadialBasisRegressionSolution(IRegressionProblemData problemData, ICovarianceFunction kernel, bool scaleInputs, out double loocvRmsError, out double rmsError) {
      if (problemData == null) {
        throw new ArgumentNullException("problemData");
      }
      if (kernel == null) {
        throw new ArgumentNullException("kernel");
      }

      var model = new RadialBasisFunctionModel(problemData.Dataset, problemData.TargetVariable, problemData.AllowedInputVariables, problemData.TrainingIndices, scaleInputs, kernel);
      loocvRmsError = model.LeaveOneOutCrossValidationRootMeanSquaredError();

      // Materialize the squared errors so an empty test partition can be detected:
      // Enumerable.Average throws InvalidOperationException on an empty sequence, which would
      // otherwise abort the whole run even though the model itself was built successfully.
      var squaredTestErrors = model.GetEstimatedValues(problemData.Dataset, problemData.TestIndices)
        .Zip(problemData.TargetVariableTestValues, (estimated, target) => (estimated - target) * (estimated - target))
        .ToArray();
      rmsError = squaredTestErrors.Length > 0 ? Math.Sqrt(squaredTestErrors.Average()) : double.NaN;

      // The solution is created on a clone of the problem data rather than on the caller's instance.
      var solution = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
      solution.Model.Name = "RBF Regression Model";
      solution.Name = "RBF Regression Solution";
      return solution;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.