Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessBase.cs @ 17712

Last change on this file since 17712 was 17427, checked in by fholzing, 5 years ago

#2812: Changed type of MeanFunctionParameter and CovarianceFunctionParameter from IValueParameter to IConstrainedValueParameter (+ initialization)

File size: 12.5 KB
RevLine 
[9096]1
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
22
[14434]23using System.Linq;
[17427]24using HEAL.Attic;
[9096]25using HeuristicLab.Algorithms.GradientDescent;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Operators;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
[17427]32using HeuristicLab.PluginInfrastructure;
[9096]33using HeuristicLab.Problems.DataAnalysis;
34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Base class for Gaussian process data analysis algorithms (regression and classification).
  /// </summary>
  /// <remarks>
  /// The constructor builds an operator graph that creates a random number generator, initializes the
  /// hyperparameters of the selected mean and covariance functions and then optimizes them with LM-BFGS,
  /// feeding the negative log likelihood and its gradients back to the optimizer in every iteration.
  /// Model and solution creation are <see cref="Placeholder"/> operators bound to
  /// <see cref="ModelCreatorParameterName"/> and <see cref="SolutionCreatorParameterName"/>.
  /// NOTE(review): presumably derived classes provide those operators — confirm.
  /// </remarks>
  [StorableType("A5070F15-8E44-44DC-92E1-000826E933D3")]
  public abstract class GaussianProcessBase : EngineAlgorithm {
    protected const string MeanFunctionParameterName = "MeanFunction";
    protected const string CovarianceFunctionParameterName = "CovarianceFunction";
    protected const string MinimizationIterationsParameterName = "Iterations";
    protected const string ApproximateGradientsParameterName = "ApproximateGradients";
    protected const string SeedParameterName = "Seed";
    protected const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    protected const string ModelCreatorParameterName = "GaussianProcessModelCreator";
    protected const string NegativeLogLikelihoodParameterName = "NegativeLogLikelihood";
    protected const string HyperparameterParameterName = "Hyperparameter";
    protected const string HyperparameterGradientsParameterName = "HyperparameterGradients";
    protected const string SolutionCreatorParameterName = "GaussianProcessSolutionCreator";
    protected const string ScaleInputValuesParameterName = "ScaleInputValues";
    // Hidden flag (created as false) read by LbfgsUpdateResults. Kept in a single constant because the
    // operator's ActualName has to match the parameter name exactly.
    private const string BfgsMaximizationParameterName = "Maximization (BFGS)";
    // Misspelled name that an earlier version of the deserialization hook assigned to LbfgsUpdateResults.
    private const string MisspelledBfgsMaximizationParameterName = "Maximization BFGS";

    /// <summary>The data analysis problem this algorithm is applied to.</summary>
    public new IDataAnalysisProblem Problem {
      get { return (IDataAnalysisProblem)base.Problem; }
      set { base.Problem = value; }
    }

    #region parameter properties
    /// <summary>Choice of mean function; valid values are discovered via the plugin infrastructure.</summary>
    public IConstrainedValueParameter<IMeanFunction> MeanFunctionParameter {
      get { return (IConstrainedValueParameter<IMeanFunction>)Parameters[MeanFunctionParameterName]; }
    }
    /// <summary>Choice of covariance function; valid values are discovered via the plugin infrastructure.</summary>
    public IConstrainedValueParameter<ICovarianceFunction> CovarianceFunctionParameter {
      get { return (IConstrainedValueParameter<ICovarianceFunction>)Parameters[CovarianceFunctionParameterName]; }
    }
    /// <summary>Number of LM-BFGS iterations used for likelihood optimization.</summary>
    public IValueParameter<IntValue> MinimizationIterationsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MinimizationIterationsParameterName]; }
    }
    /// <summary>Seed for the pseudo random number generator.</summary>
    public IValueParameter<IntValue> SeedParameter {
      get { return (IValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    /// <summary>Whether the seed is replaced by a random value.</summary>
    public IValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    /// <summary>Whether input variable values are scaled to [0..1] for training.</summary>
    public IFixedValueParameter<BoolValue> ScaleInputValuesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[ScaleInputValuesParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>The selected mean function.</summary>
    /// <remarks>
    /// NOTE(review): the parameter is constrained; assigning an instance that is not contained in
    /// <c>MeanFunctionParameter.ValidValues</c> is presumably rejected by ConstrainedValueParameter — confirm.
    /// </remarks>
    public IMeanFunction MeanFunction {
      set { MeanFunctionParameter.Value = value; }
      get { return MeanFunctionParameter.Value; }
    }
    /// <summary>The selected covariance function.</summary>
    /// <remarks>
    /// NOTE(review): the parameter is constrained; assigning an instance that is not contained in
    /// <c>CovarianceFunctionParameter.ValidValues</c> is presumably rejected by ConstrainedValueParameter — confirm.
    /// </remarks>
    public ICovarianceFunction CovarianceFunction {
      set { CovarianceFunctionParameter.Value = value; }
      get { return CovarianceFunctionParameter.Value; }
    }
    /// <summary>Number of LM-BFGS iterations used for likelihood optimization.</summary>
    public int MinimizationIterations {
      set { MinimizationIterationsParameter.Value.Value = value; }
      get { return MinimizationIterationsParameter.Value.Value; }
    }
    /// <summary>Seed for the pseudo random number generator.</summary>
    public int Seed { get { return SeedParameter.Value.Value; } set { SeedParameter.Value.Value = value; } }
    /// <summary>Whether the seed is replaced by a random value.</summary>
    public bool SetSeedRandomly { get { return SetSeedRandomlyParameter.Value.Value; } set { SetSeedRandomlyParameter.Value.Value = value; } }

    /// <summary>Whether input variable values are scaled to [0..1] for training.</summary>
    public bool ScaleInputValues {
      get { return ScaleInputValuesParameter.Value.Value; }
      set { ScaleInputValuesParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected GaussianProcessBase(StorableConstructorFlag _) : base(_) { }
    protected GaussianProcessBase(GaussianProcessBase original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the parameters and the LM-BFGS hyperparameter optimization operator graph.
    /// </summary>
    /// <param name="problem">The problem to solve; its ProblemDataParameter name is wired into the graph, so it must not be null.</param>
    protected GaussianProcessBase(IDataAnalysisProblem problem)
      : base() {
      Problem = problem;
      Parameters.Add(CreateMeanFunctionParameter(null));
      Parameters.Add(CreateCovarianceFunctionParameter(null));
      Parameters.Add(new ValueParameter<IntValue>(MinimizationIterationsParameterName, "The number of iterations for likelihood optimization with LM-BFGS.", new IntValue(20)));
      Parameters.Add(new ValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new ValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));

      Parameters.Add(new ValueParameter<BoolValue>(ApproximateGradientsParameterName, "Indicates that gradients should not be approximated (necessary for LM-BFGS).", new BoolValue(false)));
      Parameters[ApproximateGradientsParameterName].Hidden = true; // should not be changed

      Parameters.Add(new FixedValueParameter<BoolValue>(ScaleInputValuesParameterName,
        "Determines if the input variable values are scaled to the range [0..1] for training.", new BoolValue(true)));
      Parameters[ScaleInputValuesParameterName].Hidden = true;

      // necessary for BFGS: LbfgsUpdateResults reads this flag (see updateResults below)
      Parameters.Add(new FixedValueParameter<BoolValue>(BfgsMaximizationParameterName, new BoolValue(false)));
      Parameters[BfgsMaximizationParameterName].Hidden = true;

      var randomCreator = new HeuristicLab.Random.RandomCreator();
      var gpInitializer = new GaussianProcessHyperparameterInitializer();
      var bfgsInitializer = new LbfgsInitializer();
      var makeStep = new LbfgsMakeStep();
      var branch = new ConditionalBranch();
      var modelCreator = new Placeholder();
      var updateResults = new LbfgsUpdateResults();
      var analyzer = new LbfgsAnalyzer();
      var finalModelCreator = new Placeholder();
      var finalAnalyzer = new LbfgsAnalyzer();
      var solutionCreator = new Placeholder();

      // random creator reads seed settings from the algorithm parameters (local values cleared)
      OperatorGraph.InitialOperator = randomCreator;
      randomCreator.SeedParameter.ActualName = SeedParameterName;
      randomCreator.SeedParameter.Value = null;
      randomCreator.SetSeedRandomlyParameter.ActualName = SetSeedRandomlyParameterName;
      randomCreator.SetSeedRandomlyParameter.Value = null;
      randomCreator.Successor = gpInitializer;

      gpInitializer.CovarianceFunctionParameter.ActualName = CovarianceFunctionParameterName;
      gpInitializer.MeanFunctionParameter.ActualName = MeanFunctionParameterName;
      gpInitializer.ProblemDataParameter.ActualName = Problem.ProblemDataParameter.Name;
      gpInitializer.HyperparameterParameter.ActualName = HyperparameterParameterName;
      gpInitializer.RandomParameter.ActualName = randomCreator.RandomParameter.Name;
      gpInitializer.Successor = bfgsInitializer;

      bfgsInitializer.IterationsParameter.ActualName = MinimizationIterationsParameterName;
      bfgsInitializer.PointParameter.ActualName = HyperparameterParameterName;
      bfgsInitializer.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      bfgsInitializer.Successor = makeStep;

      makeStep.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
      makeStep.PointParameter.ActualName = HyperparameterParameterName;
      makeStep.Successor = branch;

      // loop: model -> update results -> analyze -> next step, until the termination criterion is true
      branch.ConditionParameter.ActualName = makeStep.TerminationCriterionParameter.Name;
      branch.FalseBranch = modelCreator;
      branch.TrueBranch = finalModelCreator;

      modelCreator.OperatorParameter.ActualName = ModelCreatorParameterName;
      modelCreator.Successor = updateResults;

      updateResults.MaximizationParameter.ActualName = BfgsMaximizationParameterName;
      updateResults.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
      updateResults.QualityParameter.ActualName = NegativeLogLikelihoodParameterName;
      updateResults.QualityGradientsParameter.ActualName = HyperparameterGradientsParameterName;
      updateResults.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      updateResults.Successor = analyzer;

      analyzer.QualityParameter.ActualName = NegativeLogLikelihoodParameterName;
      analyzer.PointParameter.ActualName = HyperparameterParameterName;
      analyzer.QualityGradientsParameter.ActualName = HyperparameterGradientsParameterName;
      analyzer.StateParameter.ActualName = bfgsInitializer.StateParameter.Name;
      analyzer.PointsTableParameter.ActualName = "Hyperparameter table";
      analyzer.QualityGradientsTableParameter.ActualName = "Gradients table";
      analyzer.QualitiesTableParameter.ActualName = "Negative log likelihood table";
      analyzer.Successor = makeStep;

      finalModelCreator.OperatorParameter.ActualName = ModelCreatorParameterName;
      finalModelCreator.Successor = finalAnalyzer;

      finalAnalyzer.QualityParameter.ActualName = NegativeLogLikelihoodParameterName;
      finalAnalyzer.PointParameter.ActualName = HyperparameterParameterName;
      finalAnalyzer.QualityGradientsParameter.ActualName = HyperparameterGradientsParameterName;
      finalAnalyzer.PointsTableParameter.ActualName = analyzer.PointsTableParameter.ActualName;
      finalAnalyzer.QualityGradientsTableParameter.ActualName = analyzer.QualityGradientsTableParameter.ActualName;
      finalAnalyzer.QualitiesTableParameter.ActualName = analyzer.QualitiesTableParameter.ActualName;
      finalAnalyzer.Successor = solutionCreator;

      solutionCreator.OperatorParameter.ActualName = SolutionCreatorParameterName;
    }

    /// <summary>
    /// Creates the mean function choice parameter. Valid values are one instance of every
    /// <see cref="IMeanFunction"/> returned by the plugin infrastructure, ordered by item name.
    /// </summary>
    /// <param name="selected">
    /// Function to select. When non-null it takes the place of the discovered instance of the same type,
    /// so its configuration is kept. When null, a <see cref="MeanConst"/> is selected if one was discovered.
    /// </param>
    private static ConstrainedValueParameter<IMeanFunction> CreateMeanFunctionParameter(IMeanFunction selected) {
      var parameter = new ConstrainedValueParameter<IMeanFunction>(MeanFunctionParameterName, "The mean function to use.");
      foreach (var meanFunction in ApplicationManager.Manager.GetInstances<IMeanFunction>().OrderBy(s => s.ItemName)) {
        bool replacedBySelection = selected != null && meanFunction.GetType() == selected.GetType();
        parameter.ValidValues.Add(replacedBySelection ? selected : meanFunction);
      }
      // the selected function's type may not have been discovered; it must still be a valid value
      if (selected != null && !parameter.ValidValues.Contains(selected))
        parameter.ValidValues.Add(selected);

      var value = selected ?? parameter.ValidValues.OfType<MeanConst>().FirstOrDefault();
      if (value != null) {
        parameter.Value = value;
      }
      return parameter;
    }

    /// <summary>
    /// Creates the covariance function choice parameter. Valid values are one instance of every
    /// <see cref="ICovarianceFunction"/> returned by the plugin infrastructure, ordered by item name.
    /// </summary>
    /// <param name="selected">
    /// Function to select. When non-null it takes the place of the discovered instance of the same type,
    /// so its configuration is kept. When null, a <see cref="CovarianceSquaredExponentialIso"/> is selected
    /// if one was discovered.
    /// </param>
    private static ConstrainedValueParameter<ICovarianceFunction> CreateCovarianceFunctionParameter(ICovarianceFunction selected) {
      var parameter = new ConstrainedValueParameter<ICovarianceFunction>(CovarianceFunctionParameterName, "The covariance function to use.");
      foreach (var covarianceFunction in ApplicationManager.Manager.GetInstances<ICovarianceFunction>().OrderBy(s => s.ItemName)) {
        bool replacedBySelection = selected != null && covarianceFunction.GetType() == selected.GetType();
        parameter.ValidValues.Add(replacedBySelection ? selected : covarianceFunction);
      }
      // the selected function's type may not have been discovered; it must still be a valid value
      if (selected != null && !parameter.ValidValues.Contains(selected))
        parameter.ValidValues.Add(selected);

      var value = selected ?? parameter.ValidValues.OfType<CovarianceSquaredExponentialIso>().FirstOrDefault();
      if (value != null) {
        parameter.Value = value;
      }
      return parameter;
    }

    /// <summary>
    /// Replaces a mean function parameter that is not constrained (as persisted before the switch to
    /// IConstrainedValueParameter) with a constrained one, keeping the previously selected function.
    /// Without this, <see cref="MeanFunctionParameter"/> would fail its cast on such algorithms.
    /// </summary>
    private void UpgradeMeanFunctionParameter() {
      if (Parameters.ContainsKey(MeanFunctionParameterName) &&
          Parameters[MeanFunctionParameterName] is IConstrainedValueParameter<IMeanFunction>) return;

      IMeanFunction previous = null;
      if (Parameters.ContainsKey(MeanFunctionParameterName)) {
        var oldParameter = Parameters[MeanFunctionParameterName] as IValueParameter<IMeanFunction>;
        if (oldParameter != null) previous = oldParameter.Value;
        Parameters.Remove(MeanFunctionParameterName);
      }
      Parameters.Add(CreateMeanFunctionParameter(previous));
    }

    /// <summary>
    /// Replaces a covariance function parameter that is not constrained (as persisted before the switch to
    /// IConstrainedValueParameter) with a constrained one, keeping the previously selected function.
    /// Without this, <see cref="CovarianceFunctionParameter"/> would fail its cast on such algorithms.
    /// </summary>
    private void UpgradeCovarianceFunctionParameter() {
      if (Parameters.ContainsKey(CovarianceFunctionParameterName) &&
          Parameters[CovarianceFunctionParameterName] is IConstrainedValueParameter<ICovarianceFunction>) return;

      ICovarianceFunction previous = null;
      if (Parameters.ContainsKey(CovarianceFunctionParameterName)) {
        var oldParameter = Parameters[CovarianceFunctionParameterName] as IValueParameter<ICovarianceFunction>;
        if (oldParameter != null) previous = oldParameter.Value;
        Parameters.Remove(CovarianceFunctionParameterName);
      }
      Parameters.Add(CreateCovarianceFunctionParameter(previous));
    }

    /// <summary>
    /// Upgrades algorithms persisted by older versions: removes the obsolete "Maximization" parameter,
    /// adds missing hidden parameters, repairs the LbfgsUpdateResults wiring and converts the mean and
    /// covariance function parameters to constrained parameters.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.4
      #region Backwards compatible code, remove with 3.5
      if (Parameters.ContainsKey("Maximization")) {
        Parameters.Remove("Maximization");
      }

      if (!Parameters.ContainsKey(BfgsMaximizationParameterName)) {
        Parameters.Add(new FixedValueParameter<BoolValue>(BfgsMaximizationParameterName, new BoolValue(false)));
        Parameters[BfgsMaximizationParameterName].Hidden = true;
        foreach (var updateResults in OperatorGraph.Operators.OfType<LbfgsUpdateResults>())
          updateResults.MaximizationParameter.ActualName = BfgsMaximizationParameterName;
      }

      // An earlier version of this hook wired LbfgsUpdateResults to a name without parentheses, which matches
      // no parameter; repair algorithms that were upgraded by it and saved again.
      foreach (var updateResults in OperatorGraph.Operators.OfType<LbfgsUpdateResults>()) {
        if (updateResults.MaximizationParameter.ActualName == MisspelledBfgsMaximizationParameterName)
          updateResults.MaximizationParameter.ActualName = BfgsMaximizationParameterName;
      }

      if (!Parameters.ContainsKey(ScaleInputValuesParameterName)) {
        Parameters.Add(new FixedValueParameter<BoolValue>(ScaleInputValuesParameterName,
          "Determines if the input variable values are scaled to the range [0..1] for training.", new BoolValue(true)));
        Parameters[ScaleInputValuesParameterName].Hidden = true;
      }

      UpgradeMeanFunctionParameter();
      UpgradeCovarianceFunctionParameter();
      #endregion
    }
  }
}
Note: See TracBrowser for help on using the repository browser.