source: branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.GradientDescent/3.3/LbfgsAnalyzer.cs @ 16057

Last change on this file since 16057 was 16057, checked in by jkarder, 15 months ago

#2839:

File size: 8.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Linq;
23using HeuristicLab.Analysis;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.RealVectorEncoding;
28using HeuristicLab.Operators;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
namespace HeuristicLab.Algorithms.GradientDescent {
  /// <summary>
  /// Analyzer for the LM-BFGS algorithm. Whenever the optimizer state reports an updated point it
  /// records the current quality (as a result and as a row in a qualities table) and appends every
  /// component of the current point and of the current gradient to a points table and a gradients
  /// table, respectively.
  /// </summary>
  [StorableClass]
  [Item(Name = "LBFGS Analyzer", Description = "Analyzer to collect results for the LM-BFGS algorithm.")]
  public sealed class LbfgsAnalyzer : SingleSuccessorOperator, IAnalyzer {
    private const string PointParameterName = "Point";
    private const string QualityGradientsParameterName = "QualityGradients";
    private const string QualityParameterName = "Quality";
    private const string ResultCollectionParameterName = "Results";
    private const string QualitiesTableParameterName = "Qualities";
    private const string PointsTableParameterName = "PointTable";
    private const string QualityGradientsTableParameterName = "QualityGradientsTable";
    private const string StateParameterName = "State";
    private const string ApproximateGradientsParameterName = "ApproximateGradients";

    #region Parameter Properties
    public ILookupParameter<RealVector> QualityGradientsParameter {
      get { return (ILookupParameter<RealVector>)Parameters[QualityGradientsParameterName]; }
    }
    public ILookupParameter<RealVector> PointParameter {
      get { return (ILookupParameter<RealVector>)Parameters[PointParameterName]; }
    }
    public ILookupParameter<DoubleValue> QualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultCollectionParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultCollectionParameterName]; }
    }
    public ILookupParameter<DataTable> QualitiesTableParameter {
      get { return (ILookupParameter<DataTable>)Parameters[QualitiesTableParameterName]; }
    }
    public ILookupParameter<DataTable> PointsTableParameter {
      get { return (ILookupParameter<DataTable>)Parameters[PointsTableParameterName]; }
    }
    public ILookupParameter<DataTable> QualityGradientsTableParameter {
      get { return (ILookupParameter<DataTable>)Parameters[QualityGradientsTableParameterName]; }
    }
    public ILookupParameter<LbfgsState> StateParameter {
      get { return (ILookupParameter<LbfgsState>)Parameters[StateParameterName]; }
    }
    public ILookupParameter<BoolValue> ApproximateGradientsParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[ApproximateGradientsParameterName]; }
    }
    #endregion

    #region Properties
    private RealVector QualityGradients { get { return QualityGradientsParameter.ActualValue; } }
    private RealVector Point { get { return PointParameter.ActualValue; } }
    private DoubleValue Quality { get { return QualityParameter.ActualValue; } }
    private ResultCollection ResultCollection { get { return ResultCollectionParameter.ActualValue; } }
    private BoolValue ApproximateGradients { get { return ApproximateGradientsParameter.ActualValue; } }

    public bool EnabledByDefault {
      get { return true; }
    }

    #endregion

    [StorableConstructor]
    private LbfgsAnalyzer(bool deserializing) : base(deserializing) { }
    private LbfgsAnalyzer(LbfgsAnalyzer original, Cloner cloner) : base(original, cloner) { }
    public LbfgsAnalyzer()
      : base() {
      // in
      Parameters.Add(new LookupParameter<RealVector>(PointParameterName, "The current point of the function to optimize."));
      Parameters.Add(new LookupParameter<RealVector>(QualityGradientsParameterName, "The current gradients of the function to optimize."));
      Parameters.Add(new LookupParameter<DoubleValue>(QualityParameterName, "The current value of the function to optimize."));
      Parameters.Add(new LookupParameter<DataTable>(QualitiesTableParameterName, "The table of all visited quality values."));
      Parameters.Add(new LookupParameter<DataTable>(PointsTableParameterName, "The table of all visited points."));
      Parameters.Add(new LookupParameter<DataTable>(QualityGradientsTableParameterName, "The table of all visited gradient values."));
      Parameters.Add(new LookupParameter<LbfgsState>(StateParameterName, "The state of the LM-BFGS optimization algorithm."));
      Parameters.Add(new LookupParameter<BoolValue>(ApproximateGradientsParameterName,
                                              "Flag that indicates if gradients should be approximated."));

      // in & out
      Parameters.Add(new LookupParameter<ResultCollection>(ResultCollectionParameterName, "The result collection of the algorithm."));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new LbfgsAnalyzer(this, cloner);
    }

    /// <summary>
    /// Records the current quality, point and gradient when the optimizer state's <c>xupdated</c>
    /// flag is set, then continues with the successor operator.
    /// </summary>
    /// <remarks>
    /// If no qualities table exists yet, the qualities, points and gradients tables are created,
    /// stored through their lookup parameters, and added to the result collection together with the
    /// current quality. A <see cref="ResultCollection"/> entry with one of these names must
    /// therefore not already exist.
    /// </remarks>
    /// <returns>The operation returned by the base <c>Apply</c>.</returns>
    public override IOperation Apply() {
      var state = StateParameter.ActualValue.State;
      if (state.xupdated) {
        var quality = Quality;
        var f = quality.Value;
        // When gradients are approximated, the gradient held by the optimizer state is recorded
        // instead of the QualityGradients parameter value.
        double[] g = ApproximateGradients.Value ? state.g : QualityGradients.ToArray();
        var x = Point.ToArray();
        var resultCollection = ResultCollection;

        DataTable qualitiesTable = QualitiesTableParameter.ActualValue;
        DataTable pointsTable;
        DataTable gradientsTable;
        if (qualitiesTable == null) {
          // create and add tables on the first time
          qualitiesTable = new DataTable(QualityParameter.ActualName);
          QualitiesTableParameter.ActualValue = qualitiesTable;
          pointsTable = new DataTable(PointParameter.ActualName);
          PointsTableParameter.ActualValue = pointsTable;
          gradientsTable = new DataTable(QualityGradientsParameter.ActualName);
          QualityGradientsTableParameter.ActualValue = gradientsTable;

          qualitiesTable.Rows.Add(new DataRow(QualityParameter.ActualName));

          resultCollection.Add(new Result(QualitiesTableParameter.ActualName, qualitiesTable));
          resultCollection.Add(new Result(PointsTableParameter.ActualName, pointsTable));
          resultCollection.Add(new Result(QualityGradientsTableParameter.ActualName, gradientsTable));
          resultCollection.Add(new Result(QualityParameter.ActualName, quality));
        } else {
          pointsTable = PointsTableParameter.ActualValue;
          gradientsTable = QualityGradientsTableParameter.ActualValue;
        }

        // update
        var functionValueRow = qualitiesTable.Rows[QualityParameter.ActualName];
        resultCollection[QualityParameter.ActualName].Value = quality;
        functionValueRow.Values.Add(f);

        AddValues(g, gradientsTable);
        AddValues(x, pointsTable);
      }
      return base.Apply();
    }

    /// <summary>
    /// Appends one value per component of <paramref name="x"/> to <paramref name="dataTable"/>.
    /// If the table has no rows yet, one row named "x0", "x1", ... is created per component, each
    /// holding its first value; otherwise the i-th component is appended to the i-th row in the
    /// table's row enumeration order.
    /// </summary>
    /// <param name="x">The values to append, one per row.</param>
    /// <param name="dataTable">The table receiving the values.</param>
    /// <exception cref="System.InvalidOperationException">
    /// The table already has rows, but fewer than <paramref name="x"/> has components.
    /// </exception>
    private void AddValues(double[] x, DataTable dataTable) {
      if (!dataTable.Rows.Any()) {
        for (int i = 0; i < x.Length; i++) {
          var newRow = new DataRow("x" + i);
          newRow.Values.Add(x[i]);
          dataTable.Rows.Add(newRow);
        }
        return;
      }

      // Materialize the rows once. Calling ElementAt(i) per component on a collection that is not an
      // IList re-walks the enumeration from the start each time, making this quadratic in x.Length.
      var rows = dataTable.Rows.ToArray();
      if (rows.Length < x.Length) {
        throw new System.InvalidOperationException(string.Format(
          "Cannot append {0} values to table '{1}' because it only has {2} rows.",
          x.Length, dataTable.Name, rows.Length));
      }
      for (int i = 0; i < x.Length; i++) {
        rows[i].Values.Add(x[i]);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.