source: trunk/sources/HeuristicLab.Algorithms.GradientDescent/3.3/Lbfgs.cs @ 9127

Last change on this file since 9127 was 9127, checked in by gkronber, 9 years ago

#1423 changed LBFGS to take any single-objective optimization algorithm and throw an exception if the solution creator does not have the correct type in OnStarted().
Also added wiring for parameter names.

File size: 10.2 KB
Line 
1
2#region License Information
3/* HeuristicLab
4 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.RealVectorEncoding;
28using HeuristicLab.Operators;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.TestFunctions;
33using HeuristicLab.Random;
34
namespace HeuristicLab.Algorithms.GradientDescent {
  /// <summary>
  /// Limited-Memory BFGS optimization algorithm.
  /// </summary>
  /// <remarks>
  /// The algorithm accepts any <see cref="ISingleObjectiveHeuristicOptimizationProblem"/>. It refuses to start
  /// unless the problem's solution creator is a <see cref="RealVectorCreator"/>.
  /// </remarks>
  [Item("LM-BFGS", "The limited-memory BFGS (Broyden–Fletcher–Goldfarb–Shanno) optimization algorithm.")]
  [Creatable("Algorithms")]
  [StorableClass]
  public sealed class LbfgsAlgorithm : HeuristicOptimizationEngineAlgorithm, IStorableContent {
    /// <summary>The problem type this algorithm can be loaded with.</summary>
    public override Type ProblemType {
      get { return typeof(ISingleObjectiveHeuristicOptimizationProblem); }
    }

    /// <summary>The loaded problem, typed as a single-objective heuristic optimization problem.</summary>
    public new ISingleObjectiveHeuristicOptimizationProblem Problem {
      get { return (ISingleObjectiveHeuristicOptimizationProblem)base.Problem; }
      set { base.Problem = value; }
    }

    /// <summary>File name of this item when it is stored as content (<see cref="IStorableContent"/>).</summary>
    public string Filename { get; set; }

    private const string MaxIterationsParameterName = "MaxIterations";
    private const string ApproximateGradientsParameterName = "ApproximateGradients";
    private const string SeedParameterName = "Seed";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";

    #region parameter properties
    public IValueParameter<IntValue> MaxIterationsParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
    }
    public IValueParameter<IntValue> SeedParameter {
      get { return (IValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }
    public IValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }
    #endregion
    #region properties
    public int MaxIterations {
      set { MaxIterationsParameter.Value.Value = value; }
      get { return MaxIterationsParameter.Value.Value; }
    }
    public int Seed { get { return SeedParameter.Value.Value; } set { SeedParameter.Value.Value = value; } }
    public bool SetSeedRandomly { get { return SetSeedRandomlyParameter.Value.Value; } set { SetSeedRandomlyParameter.Value.Value = value; } }
    #endregion

    [Storable]
    private LbfgsInitializer initializer;
    [Storable]
    private LbfgsMakeStep makeStep;
    [Storable]
    private LbfgsUpdateResults updateResults;
    [Storable]
    private LbfgsAnalyzer analyzer;
    [Storable]
    private LbfgsAnalyzer finalAnalyzer;
    [Storable]
    private Placeholder solutionCreator;
    [Storable]
    private Placeholder evaluator;

    // Callbacks that undo the ActualNameChanged subscriptions made by RegisterSolutionCreatorEvents and
    // RegisterEvaluatorEvents, so that replaced or removed operators no longer reference this algorithm.
    // They are deliberately not [Storable]; RegisterEvents() rebuilds them after cloning and deserialization.
    private Action detachSolutionCreatorEvents;
    private Action detachEvaluatorEvents;

    [StorableConstructor]
    private LbfgsAlgorithm(bool deserializing) : base(deserializing) { }
    private LbfgsAlgorithm(LbfgsAlgorithm original, Cloner cloner)
      : base(original, cloner) {
      initializer = cloner.Clone(original.initializer);
      makeStep = cloner.Clone(original.makeStep);
      updateResults = cloner.Clone(original.updateResults);
      analyzer = cloner.Clone(original.analyzer);
      finalAnalyzer = cloner.Clone(original.finalAnalyzer);
      solutionCreator = cloner.Clone(original.solutionCreator);
      evaluator = cloner.Clone(original.evaluator);
      RegisterEvents();
    }

    /// <summary>
    /// Creates the algorithm with its parameters and builds the operator graph:
    /// RandomCreator -> solution creator (placeholder) -> LbfgsInitializer -> LbfgsMakeStep -> branch on the
    /// termination criterion. While the criterion is false, the loop runs evaluator (placeholder) ->
    /// LbfgsUpdateResults -> LbfgsAnalyzer and returns to LbfgsMakeStep. Once it is true, control goes to the
    /// final LbfgsAnalyzer.
    /// </summary>
    public LbfgsAlgorithm()
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;

      Parameters.Add(new ValueParameter<IntValue>(MaxIterationsParameterName, "The maximal number of iterations.", new IntValue(20)));
      Parameters.Add(new ValueParameter<IntValue>(SeedParameterName, "The random seed used to initialize the new pseudo random number generator.", new IntValue(0)));
      Parameters.Add(new ValueParameter<BoolValue>(SetSeedRandomlyParameterName, "True if the random seed should be set to a random value, otherwise false.", new BoolValue(true)));
      Parameters.Add(new ValueParameter<BoolValue>(ApproximateGradientsParameterName, "Indicates that gradients should be approximated.", new BoolValue(true)));
      Parameters[ApproximateGradientsParameterName].Hidden = true; // should not be changed

      var randomCreator = new RandomCreator();
      solutionCreator = new Placeholder();
      initializer = new LbfgsInitializer();
      makeStep = new LbfgsMakeStep();
      var branch = new ConditionalBranch();
      evaluator = new Placeholder();
      updateResults = new LbfgsUpdateResults();
      analyzer = new LbfgsAnalyzer();
      finalAnalyzer = new LbfgsAnalyzer();

      OperatorGraph.InitialOperator = randomCreator;

      // Point the random creator at the algorithm-level Seed / SetSeedRandomly parameters and clear its
      // operator-local values.
      randomCreator.SeedParameter.ActualName = SeedParameterName;
      randomCreator.SeedParameter.Value = null;
      randomCreator.SetSeedRandomlyParameter.ActualName = SetSeedRandomlyParameterName;
      randomCreator.SetSeedRandomlyParameter.Value = null;
      randomCreator.Successor = solutionCreator;

      solutionCreator.Name = "Solution Creator (placeholder)";
      solutionCreator.Successor = initializer;

      initializer.IterationsParameter.ActualName = MaxIterationsParameterName;
      initializer.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      initializer.Successor = makeStep;

      makeStep.StateParameter.ActualName = initializer.StateParameter.Name;
      makeStep.Successor = branch;

      branch.ConditionParameter.ActualName = makeStep.TerminationCriterionParameter.Name;
      branch.FalseBranch = evaluator;
      branch.TrueBranch = finalAnalyzer;

      evaluator.Name = "Evaluator (placeholder)";
      evaluator.Successor = updateResults;

      updateResults.StateParameter.ActualName = initializer.StateParameter.Name;
      updateResults.ApproximateGradientsParameter.ActualName = ApproximateGradientsParameterName;
      updateResults.Successor = analyzer;

      analyzer.StateParameter.ActualName = initializer.StateParameter.Name;
      analyzer.Successor = makeStep;

      // The final analyzer writes to the same result tables as the per-iteration analyzer.
      finalAnalyzer.PointsTableParameter.ActualName = analyzer.PointsTableParameter.ActualName;
      finalAnalyzer.QualityGradientsTableParameter.ActualName = analyzer.QualityGradientsTableParameter.ActualName;
      finalAnalyzer.QualitiesTableParameter.ActualName = analyzer.QualitiesTableParameter.ActualName;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEvents();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new LbfgsAlgorithm(this, cloner);
    }

    #region events
    /// <summary>
    /// Subscribes to name changes of the current problem's solution creator and evaluator parameters.
    /// Any subscriptions made earlier for other operators are removed first.
    /// </summary>
    private void RegisterEvents() {
      RegisterSolutionCreatorEvents();
      RegisterEvaluatorEvents();
    }

    protected override void OnProblemChanged() {
      base.OnProblemChanged();
      // Always re-register. This also detaches from the previous problem's operators, including the case
      // where the problem has been removed (Problem == null).
      RegisterEvents();
      if (Problem != null) {
        solutionCreator.OperatorParameter.ActualName = Problem.SolutionCreatorParameter.Name;
        evaluator.OperatorParameter.ActualName = Problem.EvaluatorParameter.Name;
        // Wire the point and quality names for the new problem now, rather than waiting for the next
        // operator or name change.
        ParameterizeOperators();
      }
    }

    protected override void Problem_SolutionCreatorChanged(object sender, EventArgs e) {
      base.Problem_SolutionCreatorChanged(sender, e);
      RegisterSolutionCreatorEvents();
      ParameterizeOperators();
    }

    protected override void Problem_EvaluatorChanged(object sender, EventArgs e) {
      base.Problem_EvaluatorChanged(sender, e);
      RegisterEvaluatorEvents();
      ParameterizeOperators();
    }

    /// <summary>
    /// Removes any earlier solution-creator subscription. If the current problem's solution creator is a
    /// <see cref="RealVectorCreator"/>, subscribes to name changes of its real-vector parameter.
    /// </summary>
    private void RegisterSolutionCreatorEvents() {
      if (detachSolutionCreatorEvents != null) {
        detachSolutionCreatorEvents();
        detachSolutionCreatorEvents = null;
      }
      if (Problem == null) return;

      var realVectorCreator = Problem.SolutionCreator as RealVectorCreator;
      // ignore if we have a different kind of problem
      if (realVectorCreator == null) return;

      var realVectorParameter = realVectorCreator.RealVectorParameter;
      realVectorParameter.ActualNameChanged += LookupParameter_ActualNameChanged;
      detachSolutionCreatorEvents = () => { realVectorParameter.ActualNameChanged -= LookupParameter_ActualNameChanged; };
    }

    /// <summary>
    /// Removes any earlier evaluator subscription, then subscribes to name changes of the current
    /// evaluator's quality parameter.
    /// </summary>
    private void RegisterEvaluatorEvents() {
      if (detachEvaluatorEvents != null) {
        detachEvaluatorEvents();
        detachEvaluatorEvents = null;
      }
      if (Problem == null || Problem.Evaluator == null) return;

      var qualityParameter = Problem.Evaluator.QualityParameter;
      qualityParameter.ActualNameChanged += LookupParameter_ActualNameChanged;
      detachEvaluatorEvents = () => { qualityParameter.ActualNameChanged -= LookupParameter_ActualNameChanged; };
    }

    // Shared handler for the observed point and quality parameters. It is a named method so it can be
    // unsubscribed later.
    private void LookupParameter_ActualNameChanged(object sender, EventArgs e) {
      ParameterizeOperators();
    }
    #endregion

    /// <summary>
    /// Checks that a problem is loaded and that it uses a real-vector encoding before execution starts.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// No problem is loaded, or the problem's solution creator is not a <see cref="RealVectorCreator"/>.
    /// </exception>
    protected override void OnStarted() {
      if (Problem == null)
        throw new InvalidOperationException("LM-BFGS requires a problem to be loaded before it can be started.");
      var realVectorCreator = Problem.SolutionCreator as RealVectorCreator;
      // must catch the case that user loaded an unsupported problem
      if (realVectorCreator == null)
        throw new InvalidOperationException("LM-BFGS only works with problems using a real-value encoding.");
      base.OnStarted();
    }

    /// <summary>Prepares the algorithm. Does nothing while no problem is loaded.</summary>
    public override void Prepare() {
      if (Problem != null) base.Prepare();
    }

    /// <summary>
    /// Copies the problem's real-vector parameter name (real-vector problems only) and its evaluator's quality
    /// parameter name onto the LM-BFGS operators that read them. Does nothing while no problem is loaded.
    /// </summary>
    private void ParameterizeOperators() {
      if (Problem == null) return;

      var realVectorCreator = Problem.SolutionCreator as RealVectorCreator;
      // ignore if we have a different kind of problem
      if (realVectorCreator != null) {
        var realVectorParameterName = realVectorCreator.RealVectorParameter.ActualName;
        initializer.PointParameter.ActualName = realVectorParameterName;
        makeStep.PointParameter.ActualName = realVectorParameterName;
        analyzer.PointParameter.ActualName = realVectorParameterName;
        finalAnalyzer.PointParameter.ActualName = realVectorParameterName;
      }

      if (Problem.Evaluator != null) {
        var qualityParameterName = Problem.Evaluator.QualityParameter.ActualName;
        updateResults.QualityParameter.ActualName = qualityParameterName;
        analyzer.QualityParameter.ActualName = qualityParameterName;
        finalAnalyzer.QualityParameter.ActualName = qualityParameterName;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.