Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessBase.cs @ 12962

Last change on this file since 12962 was 12797, checked in by gkronber, 9 years ago

#2439: fixed problem in GaussianProcess algorithm after changes in BFGS

File size: 10.2 KB
Rev Line
[9096]1
2#region License Information
3/* HeuristicLab
[12012]4 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[9096]5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Linq;
25using HeuristicLab.Algorithms.GradientDescent;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Operators;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.PluginInfrastructure;
34using HeuristicLab.Problems.DataAnalysis;
35
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Base class for Gaussian process data analysis algorithms (regression and classification).
  /// </summary>
  /// <remarks>
  /// The operator graph built by the constructor creates a random number generator, initializes the
  /// hyperparameters of the mean and covariance functions and then iterates LM-BFGS steps on the
  /// negative log likelihood until the LM-BFGS termination criterion holds. The model creator and the
  /// solution creator are resolved through placeholders that look up the operators stored under
  /// "GaussianProcessModelCreator" and "GaussianProcessSolutionCreator". Those operators are not
  /// registered in this class; presumably derived classes provide them.
  /// </remarks>
  [StorableClass]
  public abstract class GaussianProcessBase : EngineAlgorithm {
    protected const string MeanFunctionParameterName = "MeanFunction";
    protected const string CovarianceFunctionParameterName = "CovarianceFunction";
    protected const string MinimizationIterationsParameterName = "Iterations";
    protected const string ApproximateGradientsParameterName = "ApproximateGradients";
    protected const string SeedParameterName = "Seed";
    protected const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    protected const string ModelCreatorParameterName = "GaussianProcessModelCreator";
    protected const string NegativeLogLikelihoodParameterName = "NegativeLogLikelihood";
    protected const string HyperparameterParameterName = "Hyperparameter";
    protected const string HyperparameterGradientsParameterName = "HyperparameterGradients";
    protected const string SolutionCreatorParameterName = "GaussianProcessSolutionCreator";
    // Name of the hidden flag required by the BFGS operators. Kept private so derived classes are unaffected.
    private const string MaximizationParameterName = "Maximization";

    /// <summary>
    /// The data analysis problem this algorithm is applied to.
    /// </summary>
    public new IDataAnalysisProblem Problem {
      get { return (IDataAnalysisProblem)base.Problem; }
      set { base.Problem = value; }
    }

    #region parameter properties
    public IValueParameter<IMeanFunction> MeanFunctionParameter {
      get { return (IValueParameter<IMeanFunction>)Parameters[MeanFunctionParameterName]; }
    }
    public IValueParameter<ICovarianceFunction> CovarianceFunctionParameter {
      get { return (IValueParameter<ICovarianceFunction>)Parameters[CovarianceFunctionParameterName]; }
    }
    public IValueParameter<IntValue> MinimizationIterationsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MinimizationIterationsParameterName]; }
    }
    public IValueParameter<IntValue> SeedParameter {
      get { return (IValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>The mean function of the Gaussian process.</summary>
    public IMeanFunction MeanFunction {
      set { MeanFunctionParameter.Value = value; }
      get { return MeanFunctionParameter.Value; }
    }
    /// <summary>The covariance function of the Gaussian process.</summary>
    public ICovarianceFunction CovarianceFunction {
      set { CovarianceFunctionParameter.Value = value; }
      get { return CovarianceFunctionParameter.Value; }
    }
    /// <summary>The number of LM-BFGS iterations used to optimize the likelihood.</summary>
    public int MinimizationIterations {
      set { MinimizationIterationsParameter.Value.Value = value; }
      get { return MinimizationIterationsParameter.Value.Value; }
    }
    /// <summary>The seed for the pseudo random number generator.</summary>
    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }
    /// <summary>Whether <see cref="Seed"/> is replaced by a random value when the algorithm runs.</summary>
    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected GaussianProcessBase(bool deserializing) : base(deserializing) { }
    protected GaussianProcessBase(GaussianProcessBase original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the algorithm for <paramref name="problem"/>.
    /// Registers all parameters and builds the hyperparameter-optimization operator graph.
    /// </summary>
    /// <param name="problem">The problem to solve. Its problem data parameter name is wired into the operator graph.</param>
    /// <exception cref="ArgumentNullException"><paramref name="problem"/> is null.</exception>
    protected GaussianProcessBase(IDataAnalysisProblem problem)
      : base() {
      // Fail fast: the operator graph below dereferences Problem.ProblemDataParameter.
      if (problem == null) throw new ArgumentNullException("problem");
      Problem = problem;
      Parameters.Add(new ValueParameter<IMeanFunction>(MeanFunctionParameterName, "The mean function to use.", new MeanConst()));
      Parameters.Add(new ValueParameter<ICovarianceFunction>(CovarianceFunctionParameterName, "The covariance function to use.", new CovarianceSquaredExponentialIso()));
      Parameters.Add(new ValueParameter<IntValue>(MinimizationIterationsParameterName, "The number of iterations for likelihood optimization with LM-BFGS.", new IntValue(20)));
      Parameters.Add(new ValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new ValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));

      Parameters.Add(new ValueParameter<BoolValue>(ApproximateGradientsParameterName, "Indicates whether gradients should be approximated. Must be false because the hyperparameter gradients are supplied to LM-BFGS.", new BoolValue(false)));
      Parameters[ApproximateGradientsParameterName].Hidden = true; // should not be changed

      // necessary for BFGS
      AddMaximizationParameter();

      CreateOperatorGraph();
    }

    /// <summary>
    /// Adds the hidden, always-false "Maximization" parameter that the BFGS operators need.
    /// </summary>
    private void AddMaximizationParameter() {
      Parameters.Add(new ValueParameter<BoolValue>(MaximizationParameterName, "Indicates whether the LM-BFGS quality is maximized. Always false because the negative log likelihood is minimized.", new BoolValue(false)));
      Parameters[MaximizationParameterName].Hidden = true;
    }

    /// <summary>
    /// Builds the operator chain:
    /// random creator -> hyperparameter initializer -> LM-BFGS initializer -> LM-BFGS step loop.
    /// Each non-terminating step runs model creator -> update results -> analyzer and loops back.
    /// The terminating step runs final model creator -> final analyzer -> solution creator.
    /// </summary>
    private void CreateOperatorGraph() {
      var randomCreator = new HeuristicLab.Random.RandomCreator();
      var gpInitializer = new GaussianProcessHyperparameterInitializer();
      var bfgsInitializer = new LbfgsInitializer();
      var makeStep = new LbfgsMakeStep();
      var branch = new ConditionalBranch();
      var modelCreator = new Placeholder();
      var updateResults = new LbfgsUpdateResults();
      var analyzer = new LbfgsAnalyzer();
      var finalModelCreator = new Placeholder();
      var finalAnalyzer = new LbfgsAnalyzer();
      var solutionCreator = new Placeholder();

      OperatorGraph.InitialOperator = randomCreator;
      // Clearing the local values makes the random creator read Seed/SetSeedRandomly from the algorithm parameters.
      randomCreator.SeedParameter.ActualName = SeedParameterName;
      randomCreator.SeedParameter.Value = null;
      randomCreator.SetSeedRandomlyParameter.ActualName = SetSeedRandomlyParameterName;
      randomCreator.SetSeedRandomlyParameter.Value = null;
      randomCreator.Successor = gpInitializer;

      gpInitializer.CovarianceFunctionParameter.ActualName = CovarianceFunctionParameterName;
      gpInitializer.MeanFunctionParameter.ActualName = MeanFunctionParameterName;
      gpInitializer.ProblemDataParameter.ActualName = Problem.ProblemDataParameter.Name;
      gpInitializer.HyperparameterParameter.ActualName = HyperparameterParameterName;
      gpInitializer.RandomParameter.ActualName = randomCreator.RandomParameter.Name;
      gpInitializer.Successor = bfgsInitializer;

      bfgsInitializer.IterationsParameter.ActualName = MinimizationIterationsParameterName;
      bfgsInitializer.PointParameter.ActualName = HyperparameterParameterName;
      bfgsInitializer.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      bfgsInitializer.Successor = makeStep;

      makeStep.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
      makeStep.PointParameter.ActualName = HyperparameterParameterName;
      makeStep.Successor = branch;

      // Continue the optimization loop until the LM-BFGS step reports termination.
      branch.ConditionParameter.ActualName = makeStep.TerminationCriterionParameter.Name;
      branch.FalseBranch = modelCreator;
      branch.TrueBranch = finalModelCreator;

      modelCreator.OperatorParameter.ActualName = ModelCreatorParameterName;
      modelCreator.Successor = updateResults;

      updateResults.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
      updateResults.QualityParameter.ActualName = NegativeLogLikelihoodParameterName;
      updateResults.QualityGradientsParameter.ActualName = HyperparameterGradientsParameterName;
      updateResults.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      updateResults.Successor = analyzer;

      analyzer.QualityParameter.ActualName = NegativeLogLikelihoodParameterName;
      analyzer.PointParameter.ActualName = HyperparameterParameterName;
      analyzer.QualityGradientsParameter.ActualName = HyperparameterGradientsParameterName;
      analyzer.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
      analyzer.PointsTableParameter.ActualName = "Hyperparameter table";
      analyzer.QualityGradientsTableParameter.ActualName = "Gradients table";
      analyzer.QualitiesTableParameter.ActualName = "Negative log likelihood table";
      analyzer.Successor = makeStep;

      finalModelCreator.OperatorParameter.ActualName = ModelCreatorParameterName;
      finalModelCreator.Successor = finalAnalyzer;

      // The final analyzer writes into the same tables and reads the same LM-BFGS state as the loop analyzer.
      finalAnalyzer.QualityParameter.ActualName = NegativeLogLikelihoodParameterName;
      finalAnalyzer.PointParameter.ActualName = HyperparameterParameterName;
      finalAnalyzer.QualityGradientsParameter.ActualName = HyperparameterGradientsParameterName;
      finalAnalyzer.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
      finalAnalyzer.PointsTableParameter.ActualName = analyzer.PointsTableParameter.ActualName;
      finalAnalyzer.QualityGradientsTableParameter.ActualName = analyzer.QualityGradientsTableParameter.ActualName;
      finalAnalyzer.QualitiesTableParameter.ActualName = analyzer.QualitiesTableParameter.ActualName;
      finalAnalyzer.Successor = solutionCreator;

      solutionCreator.OperatorParameter.ActualName = SolutionCreatorParameterName;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.4
      #region Backwards compatible code, remove with 3.5
      // Files saved before the BFGS change lack the hidden Maximization parameter.
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        AddMaximizationParameter();
      }
      #endregion
    }
  }
}
Note: See TracBrowser for help on using the repository browser.