Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Algorithms.GradientDescent/3.3/Lbfgs.cs @ 18066

Last change on this file since 18066 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 12.3 KB
RevLine 
[8396]1
2#region License Information
3/* HeuristicLab
[17181]4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[8396]5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
[9439]24using System.Linq;
25using HeuristicLab.Analysis;
[8396]26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
[9127]29using HeuristicLab.Encodings.RealVectorEncoding;
[8396]30using HeuristicLab.Operators;
31using HeuristicLab.Optimization;
32using HeuristicLab.Parameters;
[17097]33using HEAL.Attic;
[8396]34using HeuristicLab.Random;
35
namespace HeuristicLab.Algorithms.GradientDescent {
  /// <summary>
  /// Limited-Memory BFGS optimization algorithm.
  /// </summary>
  /// <remarks>
  /// The operator graph built in the default constructor is:
  /// RandomCreator -> (Solution Creator) -> LbfgsInitializer -> LbfgsMakeStep -> ConditionalBranch.
  /// While the make-step's termination criterion is false the loop continues with
  /// (Evaluator) -> LbfgsUpdateResults -> (Analyzer) -> LbfgsMakeStep; once it is true a final
  /// (Analyzer) run ends the execution.
  /// </remarks>
  [Item("LM-BFGS", "The limited-memory BFGS (Broyden–Fletcher–Goldfarb–Shanno) optimization algorithm.")]
  [Creatable(CreatableAttribute.Categories.SingleSolutionAlgorithms, Priority = 160)]
  [StorableType("55E85596-0FC7-41B5-9B90-9A8BF33B7C55")]
  public sealed class LbfgsAlgorithm : HeuristicOptimizationEngineAlgorithm, IStorableContent {
    public override Type ProblemType {
      get { return typeof(ISingleObjectiveHeuristicOptimizationProblem); }
    }

    /// <summary>
    /// The problem to optimize. Starting the algorithm requires a problem whose solution creator
    /// is an <see cref="IRealVectorCreator"/> (checked in <see cref="OnStarted"/>).
    /// </summary>
    public new ISingleObjectiveHeuristicOptimizationProblem Problem {
      get { return (ISingleObjectiveHeuristicOptimizationProblem)base.Problem; }
      set { base.Problem = value; }
    }

    public string Filename { get; set; }

    private const string AnalyzerParameterName = "Analyzer";
    private const string MaxIterationsParameterName = "MaxIterations";
    private const string ApproximateGradientsParameterName = "ApproximateGradients";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string GradientCheckStepSizeParameterName = "GradientCheckStepSize";

    #region parameter properties
    /// <summary>Multi-analyzer run after every iteration and once more after termination.</summary>
    public IValueParameter<IMultiAnalyzer> AnalyzerParameter {
      get { return (IValueParameter<IMultiAnalyzer>)Parameters[AnalyzerParameterName]; }
    }
    /// <summary>Maximal number of iterations; forwarded to the initializer's iterations parameter.</summary>
    public IValueParameter<IntValue> MaxIterationsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
    }
    /// <summary>Seed read by the <see cref="RandomCreator"/> at the start of the operator graph.</summary>
    public IValueParameter<IntValue> SeedParameter {
      get { return (IValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    /// <summary>Whether the <see cref="RandomCreator"/> should pick a random seed instead of <see cref="Seed"/>.</summary>
    public IValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    /// <summary>Optional step size for the gradient check (debugging aid; hidden by default).</summary>
    public IValueParameter<DoubleValue> GradientStepSizeParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GradientCheckStepSizeParameterName]; }
    }
    #endregion
    #region properties
    public IMultiAnalyzer Analyzer {
      get { return AnalyzerParameter.Value; }
      set { AnalyzerParameter.Value = value; }
    }
    public int MaxIterations {
      set { MaxIterationsParameter.Value.Value = value; }
      get { return MaxIterationsParameter.Value.Value; }
    }
    public int Seed { get { return SeedParameter.Value.Value; } set { SeedParameter.Value.Value = value; } }
    public bool SetSeedRandomly { get { return SetSeedRandomlyParameter.Value.Value; } set { SetSeedRandomlyParameter.Value.Value = value; } }
    #endregion

    // Operators of the graph that must be re-parameterized when the problem changes.
    // NOTE: field names are part of the persisted (Storable) state and must not be renamed.
    [Storable]
    private LbfgsInitializer initializer;
    [Storable]
    private LbfgsMakeStep makeStep;
    [Storable]
    private LbfgsUpdateResults updateResults;
    [Storable]
    private LbfgsAnalyzer analyzer;
    [Storable]
    private Placeholder solutionCreator;
    [Storable]
    private Placeholder evaluator;

    [StorableConstructor]
    private LbfgsAlgorithm(StorableConstructorFlag _) : base(_) { }
    private LbfgsAlgorithm(LbfgsAlgorithm original, Cloner cloner)
      : base(original, cloner) {
      initializer = cloner.Clone(original.initializer);
      makeStep = cloner.Clone(original.makeStep);
      updateResults = cloner.Clone(original.updateResults);
      analyzer = cloner.Clone(original.analyzer);
      solutionCreator = cloner.Clone(original.solutionCreator);
      evaluator = cloner.Clone(original.evaluator);
      RegisterEvents();
    }
    /// <summary>
    /// Creates the algorithm's parameters and wires up the LM-BFGS operator graph.
    /// </summary>
    public LbfgsAlgorithm()
      : base() {
      Parameters.Add(new ValueParameter<IMultiAnalyzer>(AnalyzerParameterName, "The analyzers that will be executed on the solution.", new MultiAnalyzer()));
      Parameters.Add(new ValueParameter<IntValue>(MaxIterationsParameterName, "The maximal number of iterations.", new IntValue(20)));
      Parameters.Add(new ValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new ValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ApproximateGradientsParameterName, "Indicates that gradients should be approximated.", new BoolValue(true)));
      Parameters.Add(new OptionalValueParameter<DoubleValue>(GradientCheckStepSizeParameterName, "Step size for the gradient check (should be used for debugging the gradient calculation only)."));
      // these parameters should usually not be changed
      Parameters[ApproximateGradientsParameterName].Hidden = true;
      Parameters[GradientCheckStepSizeParameterName].Hidden = true;

      var randomCreator = new RandomCreator();
      solutionCreator = new Placeholder();
      initializer = new LbfgsInitializer();
      makeStep = new LbfgsMakeStep();
      var branch = new ConditionalBranch();
      evaluator = new Placeholder();
      updateResults = new LbfgsUpdateResults();
      var analyzerPlaceholder = new Placeholder();
      var finalAnalyzerPlaceholder = new Placeholder();

      OperatorGraph.InitialOperator = randomCreator;

      // the random creator reads seed settings from the algorithm's parameters instead of its own values
      randomCreator.SeedParameter.ActualName = SeedParameterName;
      randomCreator.SeedParameter.Value = null;
      randomCreator.SetSeedRandomlyParameter.ActualName = SetSeedRandomlyParameterName;
      randomCreator.SetSeedRandomlyParameter.Value = null;
      randomCreator.Successor = solutionCreator;

      solutionCreator.Name = "(Solution Creator)";
      solutionCreator.Successor = initializer;

      initializer.IterationsParameter.ActualName = MaxIterationsParameterName;
      initializer.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      initializer.Successor = makeStep;

      makeStep.StateParameter.ActualName = initializer.StateParameter.Name;
      makeStep.Successor = branch;

      // loop until the make-step operator signals termination
      branch.ConditionParameter.ActualName = makeStep.TerminationCriterionParameter.Name;
      branch.FalseBranch = evaluator;
      branch.TrueBranch = finalAnalyzerPlaceholder;

      evaluator.Name = "(Evaluator)";
      evaluator.Successor = updateResults;

      updateResults.StateParameter.ActualName = initializer.StateParameter.Name;
      updateResults.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      updateResults.Successor = analyzerPlaceholder;

      analyzerPlaceholder.Name = "(Analyzer)";
      analyzerPlaceholder.OperatorParameter.ActualName = AnalyzerParameterName;
      analyzerPlaceholder.Successor = makeStep;

      finalAnalyzerPlaceholder.Name = "(Analyzer)";
      finalAnalyzerPlaceholder.OperatorParameter.ActualName = AnalyzerParameterName;
      finalAnalyzerPlaceholder.Successor = null;

      // not part of the graph directly; added to the multi-analyzer by UpdateAnalyzers
      analyzer = new LbfgsAnalyzer();
      analyzer.StateParameter.ActualName = initializer.StateParameter.Name;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEvents();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new LbfgsAlgorithm(this, cloner);
    }

    #region events
    /// <summary>
    /// Subscribes to name changes of the current problem's real-vector and quality parameters.
    /// Does nothing when no problem is set. Safe to call repeatedly.
    /// </summary>
    private void RegisterEvents() {
      if (Problem != null) {
        RegisterSolutionCreatorEvents();
        RegisterEvaluatorEvents();
      }
    }

    protected override void OnProblemChanged() {
      base.OnProblemChanged();
      if (Problem != null) {
        ConnectToProblem();
      }
    }

    protected override void Problem_SolutionCreatorChanged(object sender, EventArgs e) {
      base.Problem_SolutionCreatorChanged(sender, e);
      RegisterSolutionCreatorEvents();
      ParameterizeOperators();
    }

    protected override void Problem_EvaluatorChanged(object sender, EventArgs e) {
      base.Problem_EvaluatorChanged(sender, e);
      RegisterEvaluatorEvents();
      ParameterizeOperators();
    }

    protected override void Problem_OperatorsChanged(object sender, EventArgs e) {
      base.Problem_OperatorsChanged(sender, e);
      ConnectToProblem();
    }

    /// <summary>
    /// Points the solution-creator and evaluator placeholders at the current problem's operators,
    /// rebuilds the analyzer list and re-parameterizes the LM-BFGS operators.
    /// Requires <see cref="Problem"/> to be non-null.
    /// </summary>
    private void ConnectToProblem() {
      RegisterEvents();
      solutionCreator.OperatorParameter.ActualName = Problem.SolutionCreatorParameter.Name;
      solutionCreator.OperatorParameter.Hidden = true;
      evaluator.OperatorParameter.ActualName = Problem.EvaluatorParameter.Name;
      evaluator.OperatorParameter.Hidden = true;
      UpdateAnalyzers();
      ParameterizeOperators();
    }

    private void RegisterSolutionCreatorEvents() {
      var realVectorCreator = Problem.SolutionCreator as IRealVectorCreator;
      // ignore if we have a different kind of problem
      if (realVectorCreator != null) {
        // Detach first: this method runs again on every problem/operator change, and attaching
        // without detaching would stack duplicate handlers on the same parameter.
        realVectorCreator.RealVectorParameter.ActualNameChanged -= RealVectorParameter_ActualNameChanged;
        realVectorCreator.RealVectorParameter.ActualNameChanged += RealVectorParameter_ActualNameChanged;
      }
    }

    private void RegisterEvaluatorEvents() {
      // detach first to keep registration idempotent (see RegisterSolutionCreatorEvents)
      Problem.Evaluator.QualityParameter.ActualNameChanged -= QualityParameter_ActualNameChanged;
      Problem.Evaluator.QualityParameter.ActualNameChanged += QualityParameter_ActualNameChanged;
    }

    private void RealVectorParameter_ActualNameChanged(object sender, EventArgs e) {
      ParameterizeOperators();
    }

    private void QualityParameter_ActualNameChanged(object sender, EventArgs e) {
      ParameterizeOperators();
    }
    #endregion

    /// <summary>
    /// Verifies that the problem uses a real-vector encoding before execution starts.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// No problem is set, or its solution creator is not an <see cref="IRealVectorCreator"/>.
    /// </exception>
    protected override void OnStarted() {
      // must catch the case that user loaded an unsupported (or no) problem
      var realVectorCreator = Problem == null ? null : Problem.SolutionCreator as IRealVectorCreator;
      if (realVectorCreator == null)
        throw new InvalidOperationException("LM-BFGS only works with problems using a real-value encoding.");
      base.OnStarted();
    }

    /// <summary>Prepares the algorithm; skipped entirely while no problem is set.</summary>
    public override void Prepare() {
      if (Problem != null) base.Prepare();
    }

    /// <summary>
    /// Rebuilds the multi-analyzer from the problem's analyzers and appends the LM-BFGS analyzer.
    /// </summary>
    private void UpdateAnalyzers() {
      Analyzer.Operators.Clear();
      if (Problem != null) {
        foreach (var a in Problem.Operators.OfType<IAnalyzer>()) {
          // NOTE(review): depth 0 presumably because this single-solution algorithm keeps its
          // solution in the current scope rather than in sub-scopes — confirm.
          foreach (var param in a.Parameters.OfType<IScopeTreeLookupParameter>())
            param.Depth = 0;
          Analyzer.Operators.Add(a, a.EnabledByDefault);
        }
      }
      Analyzer.Operators.Add(analyzer, analyzer.EnabledByDefault);
    }

    /// <summary>
    /// Propagates the problem's real-vector and quality parameter names to the LM-BFGS operators.
    /// </summary>
    private void ParameterizeOperators() {
      // a handler attached to a previous problem's parameters may still fire after the problem was removed
      if (Problem == null) return;

      var realVectorCreator = Problem.SolutionCreator as IRealVectorCreator;
      // ignore if we have a different kind of problem
      if (realVectorCreator != null) {
        var realVectorParameterName = realVectorCreator.RealVectorParameter.ActualName;
        initializer.PointParameter.ActualName = realVectorParameterName;
        initializer.PointParameter.Hidden = true;
        makeStep.PointParameter.ActualName = realVectorParameterName;
        makeStep.PointParameter.Hidden = true;
        analyzer.PointParameter.ActualName = realVectorParameterName;
        analyzer.PointParameter.Hidden = true;
      }

      var qualityParameterName = Problem.Evaluator.QualityParameter.ActualName;
      updateResults.QualityParameter.ActualName = qualityParameterName;
      updateResults.QualityParameter.Hidden = true;
      analyzer.QualityParameter.ActualName = qualityParameterName;
      analyzer.QualityParameter.Hidden = true;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.