Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.GradientDescent/3.3/Lbfgs.cs @ 16671

Last change on this file since 16671 was 16565, checked in by gkronber, 6 years ago

#2520: merged changes from PersistenceOverhaul branch (r16451:16564) into trunk

File size: 12.4 KB
Line 
1
2#region License Information
3/* HeuristicLab
4 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Linq;
25using HeuristicLab.Analysis;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.RealVectorEncoding;
30using HeuristicLab.Operators;
31using HeuristicLab.Optimization;
32using HeuristicLab.Parameters;
33using HEAL.Attic;
34using HeuristicLab.Random;
35
namespace HeuristicLab.Algorithms.GradientDescent {
  /// <summary>
  /// Limited-Memory BFGS optimization algorithm.
  /// </summary>
  /// <remarks>
  /// Operator graph built in the constructor:
  /// RandomCreator -> (Solution Creator) -> LbfgsInitializer -> LbfgsMakeStep -> ConditionalBranch.
  /// While the termination criterion written by LbfgsMakeStep is false the loop
  /// (Evaluator) -> LbfgsUpdateResults -> (Analyzer) -> LbfgsMakeStep is executed;
  /// once it is true the final (Analyzer) runs and the graph ends.
  /// Only problems whose solution creator is an <see cref="IRealVectorCreator"/> can be run
  /// (enforced in <see cref="OnStarted"/>).
  /// </remarks>
  [Item("LM-BFGS", "The limited-memory BFGS (Broyden–Fletcher–Goldfarb–Shanno) optimization algorithm.")]
  [Creatable(CreatableAttribute.Categories.SingleSolutionAlgorithms, Priority = 160)]
  [StorableType("55E85596-0FC7-41B5-9B90-9A8BF33B7C55")]
  public sealed class LbfgsAlgorithm : HeuristicOptimizationEngineAlgorithm, IStorableContent {
    public override Type ProblemType {
      get { return typeof(ISingleObjectiveHeuristicOptimizationProblem); }
    }

    public new ISingleObjectiveHeuristicOptimizationProblem Problem {
      get { return (ISingleObjectiveHeuristicOptimizationProblem)base.Problem; }
      set { base.Problem = value; }
    }

    public string Filename { get; set; }

    private const string AnalyzerParameterName = "Analyzer";
    private const string MaxIterationsParameterName = "MaxIterations";
    private const string ApproximateGradientsParameterName = "ApproximateGradients";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string GradientCheckStepSizeParameterName = "GradientCheckStepSize";

    #region parameter properties
    public IValueParameter<IMultiAnalyzer> AnalyzerParameter {
      get { return (IValueParameter<IMultiAnalyzer>)Parameters[AnalyzerParameterName]; }
    }
    public IValueParameter<IntValue> MaxIterationsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
    }
    public IValueParameter<IntValue> SeedParameter {
      get { return (IValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    public IValueParameter<DoubleValue> GradientStepSizeParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[GradientCheckStepSizeParameterName]; }
    }
    #endregion
    #region properties
    public IMultiAnalyzer Analyzer {
      get { return AnalyzerParameter.Value; }
      set { AnalyzerParameter.Value = value; }
    }
    public int MaxIterations {
      set { MaxIterationsParameter.Value.Value = value; }
      get { return MaxIterationsParameter.Value.Value; }
    }
    public int Seed { get { return SeedParameter.Value.Value; } set { SeedParameter.Value.Value = value; } }
    public bool SetSeedRandomly { get { return SetSeedRandomlyParameter.Value.Value; } set { SetSeedRandomlyParameter.Value.Value = value; } }
    #endregion

    [Storable]
    private LbfgsInitializer initializer;
    [Storable]
    private LbfgsMakeStep makeStep;
    [Storable]
    private LbfgsUpdateResults updateResults;
    [Storable]
    private LbfgsAnalyzer analyzer;
    [Storable]
    private Placeholder solutionCreator;
    [Storable]
    private Placeholder evaluator;

    // Lookup parameters whose ActualNameChanged event this instance is currently subscribed to.
    // Deliberately not [Storable]: RegisterEvents re-establishes them after cloning and deserialization.
    // Tracking them lets us unsubscribe before re-subscribing, so handlers are neither duplicated
    // nor left attached to operators that no longer belong to the current problem.
    private ILookupParameter subscribedRealVectorParameter;
    private ILookupParameter subscribedQualityParameter;

    [StorableConstructor]
    private LbfgsAlgorithm(StorableConstructorFlag _) : base(_) { }
    private LbfgsAlgorithm(LbfgsAlgorithm original, Cloner cloner)
      : base(original, cloner) {
      initializer = cloner.Clone(original.initializer);
      makeStep = cloner.Clone(original.makeStep);
      updateResults = cloner.Clone(original.updateResults);
      analyzer = cloner.Clone(original.analyzer);
      solutionCreator = cloner.Clone(original.solutionCreator);
      evaluator = cloner.Clone(original.evaluator);
      RegisterEvents();
    }
    public LbfgsAlgorithm()
      : base() {
      Parameters.Add(new ValueParameter<IMultiAnalyzer>(AnalyzerParameterName, "The analyzers that will be executed on the solution.", new MultiAnalyzer()));
      Parameters.Add(new ValueParameter<IntValue>(MaxIterationsParameterName, "The maximal number of iterations.", new IntValue(20)));
      Parameters.Add(new ValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new ValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ApproximateGradientsParameterName, "Indicates that gradients should be approximated.", new BoolValue(true)));
      Parameters.Add(new OptionalValueParameter<DoubleValue>(GradientCheckStepSizeParameterName, "Step size for the gradient check (should be used for debugging the gradient calculation only)."));
      // These parameters should usually not be changed, so they are hidden from the UI.
      Parameters[ApproximateGradientsParameterName].Hidden = true;
      Parameters[GradientCheckStepSizeParameterName].Hidden = true;

      var randomCreator = new RandomCreator();
      solutionCreator = new Placeholder();
      initializer = new LbfgsInitializer();
      makeStep = new LbfgsMakeStep();
      var branch = new ConditionalBranch();
      evaluator = new Placeholder();
      updateResults = new LbfgsUpdateResults();
      var analyzerPlaceholder = new Placeholder();
      var finalAnalyzerPlaceholder = new Placeholder();

      OperatorGraph.InitialOperator = randomCreator;

      // Seed and SetSeedRandomly are read from the algorithm's parameters instead of operator-local values.
      randomCreator.SeedParameter.ActualName = SeedParameterName;
      randomCreator.SeedParameter.Value = null;
      randomCreator.SetSeedRandomlyParameter.ActualName = SetSeedRandomlyParameterName;
      randomCreator.SetSeedRandomlyParameter.Value = null;
      randomCreator.Successor = solutionCreator;

      solutionCreator.Name = "(Solution Creator)";
      solutionCreator.Successor = initializer;

      initializer.IterationsParameter.ActualName = MaxIterationsParameterName;
      initializer.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      initializer.Successor = makeStep;

      // All L-BFGS operators share the state object created by the initializer.
      makeStep.StateParameter.ActualName = initializer.StateParameter.Name;
      makeStep.Successor = branch;

      // Loop until makeStep signals termination, then run the final analyzer once.
      branch.ConditionParameter.ActualName = makeStep.TerminationCriterionParameter.Name;
      branch.FalseBranch = evaluator;
      branch.TrueBranch = finalAnalyzerPlaceholder;

      evaluator.Name = "(Evaluator)";
      evaluator.Successor = updateResults;

      updateResults.StateParameter.ActualName = initializer.StateParameter.Name;
      updateResults.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      updateResults.Successor = analyzerPlaceholder;

      analyzerPlaceholder.Name = "(Analyzer)";
      analyzerPlaceholder.OperatorParameter.ActualName = AnalyzerParameterName;
      analyzerPlaceholder.Successor = makeStep;

      finalAnalyzerPlaceholder.Name = "(Analyzer)";
      finalAnalyzerPlaceholder.OperatorParameter.ActualName = AnalyzerParameterName;
      finalAnalyzerPlaceholder.Successor = null;

      // Added to the multi-analyzer in UpdateAnalyzers once a problem is set.
      analyzer = new LbfgsAnalyzer();
      analyzer.StateParameter.ActualName = initializer.StateParameter.Name;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEvents();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new LbfgsAlgorithm(this, cloner);
    }

    #region events
    /// <summary>
    /// Subscribes to the current problem's real-vector and quality parameter name changes.
    /// Any previous subscriptions are removed first. When no problem is set, all subscriptions are dropped.
    /// </summary>
    private void RegisterEvents() {
      if (Problem != null) {
        RegisterSolutionCreatorEvents();
        RegisterEvaluatorEvents();
      } else {
        DeregisterSolutionCreatorEvents();
        DeregisterEvaluatorEvents();
      }
    }

    protected override void OnProblemChanged() {
      base.OnProblemChanged();
      // Also called when the problem is cleared, so handlers on the old problem are released.
      RegisterEvents();
      if (Problem != null) {
        ParameterizeForProblem();
      }
    }

    protected override void Problem_SolutionCreatorChanged(object sender, EventArgs e) {
      base.Problem_SolutionCreatorChanged(sender, e);
      RegisterSolutionCreatorEvents();
      ParameterizeOperators();
    }

    protected override void Problem_EvaluatorChanged(object sender, EventArgs e) {
      base.Problem_EvaluatorChanged(sender, e);
      RegisterEvaluatorEvents();
      ParameterizeOperators();
    }

    protected override void Problem_OperatorsChanged(object sender, EventArgs e) {
      base.Problem_OperatorsChanged(sender, e);
      RegisterEvents();
      if (Problem != null) {
        ParameterizeForProblem();
      }
    }

    /// <summary>
    /// Re-subscribes to the RealVectorParameter of the problem's solution creator.
    /// Does nothing beyond unsubscribing if the creator is not an <see cref="IRealVectorCreator"/>.
    /// </summary>
    private void RegisterSolutionCreatorEvents() {
      DeregisterSolutionCreatorEvents();
      var realVectorCreator = Problem.SolutionCreator as IRealVectorCreator;
      // ignore if we have a different kind of problem
      if (realVectorCreator != null) {
        subscribedRealVectorParameter = realVectorCreator.RealVectorParameter;
        subscribedRealVectorParameter.ActualNameChanged += RealVectorParameter_ActualNameChanged;
      }
    }

    private void DeregisterSolutionCreatorEvents() {
      if (subscribedRealVectorParameter == null) return;
      subscribedRealVectorParameter.ActualNameChanged -= RealVectorParameter_ActualNameChanged;
      subscribedRealVectorParameter = null;
    }

    /// <summary>
    /// Re-subscribes to the QualityParameter of the problem's evaluator (if an evaluator is set).
    /// </summary>
    private void RegisterEvaluatorEvents() {
      DeregisterEvaluatorEvents();
      if (Problem.Evaluator == null) return;
      subscribedQualityParameter = Problem.Evaluator.QualityParameter;
      subscribedQualityParameter.ActualNameChanged += QualityParameter_ActualNameChanged;
    }

    private void DeregisterEvaluatorEvents() {
      if (subscribedQualityParameter == null) return;
      subscribedQualityParameter.ActualNameChanged -= QualityParameter_ActualNameChanged;
      subscribedQualityParameter = null;
    }

    private void RealVectorParameter_ActualNameChanged(object sender, EventArgs e) {
      ParameterizeOperators();
    }

    private void QualityParameter_ActualNameChanged(object sender, EventArgs e) {
      ParameterizeOperators();
    }
    #endregion

    /// <summary>
    /// Verifies that a problem with a real-vector encoding is set before the engine starts.
    /// </summary>
    /// <exception cref="InvalidOperationException">No problem is set, or its solution creator is not an <see cref="IRealVectorCreator"/>.</exception>
    protected override void OnStarted() {
      if (Problem == null)
        throw new InvalidOperationException("LM-BFGS requires a problem to be set before it can be started.");
      var realVectorCreator = Problem.SolutionCreator as IRealVectorCreator;
      // must catch the case that user loaded an unsupported problem
      if (realVectorCreator == null)
        throw new InvalidOperationException("LM-BFGS only works with problems using a real-value encoding.");
      base.OnStarted();
    }

    public override void Prepare() {
      if (Problem != null) base.Prepare();
    }

    /// <summary>
    /// Points the solution-creator and evaluator placeholders at the problem's operators,
    /// then rebuilds the analyzer list and re-wires operator parameter names. Requires a non-null <see cref="Problem"/>.
    /// </summary>
    private void ParameterizeForProblem() {
      solutionCreator.OperatorParameter.ActualName = Problem.SolutionCreatorParameter.Name;
      solutionCreator.OperatorParameter.Hidden = true;
      evaluator.OperatorParameter.ActualName = Problem.EvaluatorParameter.Name;
      evaluator.OperatorParameter.Hidden = true;
      UpdateAnalyzers();
      ParameterizeOperators();
    }

    /// <summary>
    /// Rebuilds the multi-analyzer from the problem's analyzers plus the L-BFGS analyzer.
    /// </summary>
    private void UpdateAnalyzers() {
      Analyzer.Operators.Clear();
      if (Problem != null) {
        foreach (var a in Problem.Operators.OfType<IAnalyzer>()) {
          // Scope-tree lookups are set to depth 0 because the analyzers are applied directly in the
          // algorithm's scope (see the (Analyzer) placeholders).
          foreach (var param in a.Parameters.OfType<IScopeTreeLookupParameter>())
            param.Depth = 0;
          Analyzer.Operators.Add(a, a.EnabledByDefault);
        }
      }
      Analyzer.Operators.Add(analyzer, analyzer.EnabledByDefault);
    }

    /// <summary>
    /// Copies the problem's current real-vector and quality parameter names into the L-BFGS operators.
    /// Does nothing when no problem is set.
    /// </summary>
    private void ParameterizeOperators() {
      // A stale handler may still fire after the problem was cleared.
      if (Problem == null) return;

      var realVectorCreator = Problem.SolutionCreator as IRealVectorCreator;
      // ignore if we have a different kind of problem
      if (realVectorCreator != null) {
        var realVectorParameterName = realVectorCreator.RealVectorParameter.ActualName;
        initializer.PointParameter.ActualName = realVectorParameterName;
        initializer.PointParameter.Hidden = true;
        makeStep.PointParameter.ActualName = realVectorParameterName;
        makeStep.PointParameter.Hidden = true;
        analyzer.PointParameter.ActualName = realVectorParameterName;
        analyzer.PointParameter.Hidden = true;
      }

      if (Problem.Evaluator == null) return;
      var qualityParameterName = Problem.Evaluator.QualityParameter.ActualName;
      updateResults.QualityParameter.ActualName = qualityParameterName;
      updateResults.QualityParameter.Hidden = true;
      analyzer.QualityParameter.ActualName = qualityParameterName;
      analyzer.QualityParameter.Hidden = true;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.