Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Problem.cs @ 15964

Last change on this file since 15964 was 15964, checked in by gkronber, 6 years ago

#2925 first exploratory implementation of AutoDiff for modelling of dynamical systems

File size: 16.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Problems.Instances;
33
34namespace HeuristicLab.Problems.DynamicalSystemsModelling {
/// <summary>
/// Minimal dense vector used to carry partial derivatives (with respect to the numeric
/// parameters) alongside scalar values during forward-mode differentiation.
/// </summary>
/// <remarks>
/// <see cref="Zero"/> is a sentinel, compared by reference, meaning "all entries are zero".
/// It has an empty backing array so it can be combined with vectors of any length.
/// Operators never modify their operands; non-trivial results use a freshly allocated array.
/// A result may be the same instance as an operand when the other operand is <see cref="Zero"/>.
/// </remarks>
public class Vector {
  public readonly static Vector Zero = new Vector(new double[0]);

  /// <summary>Element-wise sum; <see cref="Zero"/> is the neutral element.</summary>
  public static Vector operator +(Vector a, Vector b) {
    if (a == Zero) return b;
    if (b == Zero) return a;
    Debug.Assert(a.arr.Length == b.arr.Length);
    var res = new double[a.arr.Length];
    for (int i = 0; i < res.Length; i++)
      res[i] = a.arr[i] + b.arr[i];
    return new Vector(res);
  }

  /// <summary>Element-wise difference; <see cref="Zero"/> is the neutral element.</summary>
  public static Vector operator -(Vector a, Vector b) {
    if (b == Zero) return a;
    if (a == Zero) return -b;
    Debug.Assert(a.arr.Length == b.arr.Length);
    var res = new double[a.arr.Length];
    for (int i = 0; i < res.Length; i++)
      res[i] = a.arr[i] - b.arr[i];
    return new Vector(res);
  }

  /// <summary>Element-wise negation. Returns a new vector; <paramref name="v"/> is left unchanged.</summary>
  public static Vector operator -(Vector v) {
    if (v == Zero) return Zero;
    // Allocate a new array: negating in place would silently corrupt every other
    // reference to v (e.g. gradients that are still stored elsewhere).
    var res = new double[v.arr.Length];
    for (int i = 0; i < res.Length; i++)
      res[i] = -v.arr[i];
    return new Vector(res);
  }

  /// <summary>Scales every entry by <paramref name="s"/>. Returns <see cref="Zero"/> if either factor is zero.</summary>
  public static Vector operator *(double s, Vector v) {
    if (v == Zero) return Zero;
    if (s == 0.0) return Zero;
    var res = new double[v.arr.Length];
    for (int i = 0; i < res.Length; i++)
      res[i] = s * v.arr[i];
    return new Vector(res);
  }

  /// <summary>Scales every entry by <paramref name="s"/> (commutative form of the operator above).</summary>
  public static Vector operator *(Vector v, double s) {
    return s * v;
  }

  /// <summary>
  /// Element-wise <c>s / v[i]</c>. Returns <see cref="Zero"/> when <paramref name="s"/> is 0.
  /// </summary>
  /// <exception cref="ArgumentException">If <paramref name="v"/> is <see cref="Zero"/> and s != 0.</exception>
  public static Vector operator /(double s, Vector v) {
    if (s == 0.0) return Zero;
    if (v == Zero) throw new ArgumentException("Division by zero vector");
    var res = new double[v.arr.Length];
    for (int i = 0; i < res.Length; i++)
      res[i] = s / v.arr[i]; // previously ignored s and always computed 1.0 / v[i]
    return new Vector(res);
  }

  /// <summary>Element-wise <c>v[i] / s</c>; <see cref="Zero"/> stays <see cref="Zero"/>.</summary>
  public static Vector operator /(Vector v, double s) {
    // The former `v * 1.0 / s` parsed as `(v * 1.0) / s` and recursed into this operator forever.
    if (v == Zero) return Zero;
    var res = new double[v.arr.Length];
    for (int i = 0; i < res.Length; i++)
      res[i] = v.arr[i] / s;
    return new Vector(res);
  }


  private readonly double[] arr; // backing array (not copied; the caller hands over ownership)

  /// <summary>Wraps <paramref name="v"/> without copying it.</summary>
  public Vector(double[] v) {
    this.arr = v;
  }

  /// <summary>
  /// Copies all entries into the beginning of <paramref name="target"/>.
  /// Copies nothing for <see cref="Zero"/>.
  /// </summary>
  public void CopyTo(double[] target) {
    Debug.Assert(arr.Length <= target.Length);
    Array.Copy(arr, target, arr.Length);
  }
}
98
[Item("Dynamical Systems Modelling Problem", "TODO")]
[Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 900)]
[StorableClass]
public sealed class Problem : SymbolicExpressionTreeProblem, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {

  #region parameter names
  private const string ProblemDataParameterName = "ProblemData";
  #endregion

  #region Parameter Properties
  IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }

  public IValueParameter<IRegressionProblemData> ProblemDataParameter {
    get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
  }
  #endregion

  #region Properties
  public IRegressionProblemData ProblemData {
    get { return ProblemDataParameter.Value; }
    set { ProblemDataParameter.Value = value; }
  }
  IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }
  #endregion

  public event EventHandler ProblemDataChanged;

  public override bool Maximization {
    // the quality is the MSE of the integrated target trajectory (see EvaluateObjectiveAndGradient)
    get { return false; }
  }

  #region item cloning and persistence
  // persistence
  [StorableConstructor]
  private Problem(bool deserializing) : base(deserializing) { }
  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
    RegisterEventHandlers();
  }

  // cloning
  private Problem(Problem original, Cloner cloner)
    : base(original, cloner) {
    RegisterEventHandlers();
  }
  public override IDeepCloneable Clone(Cloner cloner) { return new Problem(this, cloner); }
  #endregion

  public Problem()
    : base() {
    Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data captured from the dynamical system", new RegressionProblemData()));

    // TODO: support multiple target variables

    var g = new SimpleSymbolicExpressionGrammar(); // empty grammar is replaced in UpdateGrammar()
    base.Encoding = new SymbolicExpressionTreeEncoding(g, 10, 5);         // small for testing

    UpdateGrammar();
    RegisterEventHandlers();
  }

  /// <summary>
  /// Evaluates a tree that models y'(t), the change of the target variable per row.
  /// The numeric parameters (θ nodes) are first fitted with L-BFGS on the training rows.
  /// The quality is then the MSE of the resulting integrated trajectory.
  /// </summary>
  /// <returns>
  /// The MSE for the fitted parameters.
  /// Returns double.MaxValue if L-BFGS reports an error.
  /// Returns 10E6 (= 1e7) if the MSE is NaN or infinite.
  /// </returns>
  public override double Evaluate(ISymbolicExpressionTree tree, IRandom random) {
    var problemData = ProblemData;

    // assign each numeric parameter node a fixed position in the parameter vector θ (prefix order)
    var nodeIdx = new Dictionary<ISymbolicExpressionTreeNode, int>();
    foreach (var node in tree.Root.IterateNodesPrefix().Where(n => IsConstantNode(n))) {
      nodeIdx.Add(node, nodeIdx.Count);
    }

    var theta = nodeIdx.Select(_ => random.NextDouble() * 2.0 - 1.0).ToArray(); // init params randomly from Unif(-1,1)
    var objectiveArgs = new object[] { tree, problemData, nodeIdx };

    double[] optTheta = new double[0];
    if (theta.Length > 0) {
      alglib.minlbfgsstate state;
      alglib.minlbfgsreport report;
      alglib.minlbfgscreate(Math.Min(theta.Length, 5), theta, out state);
      alglib.minlbfgssetcond(state, 0.0, 0.0, 0.0, 100); // stop only after 100 iterations
      alglib.minlbfgsoptimize(state, EvaluateObjectiveAndGradient, null, objectiveArgs);
      alglib.minlbfgsresults(state, out optTheta, out report);

      /*
       * L-BFGS results (alglib.minlbfgsresults):
       *   X   - array[0..N-1], solution
       *   Rep - optimization report:
       *         * Rep.TerminationType completion code:
       *             * -7  gradient verification failed (see MinLBFGSSetGradientCheck()).
       *             * -2  rounding errors prevent further improvement; X contains best point found.
       *             * -1  incorrect parameters were specified
       *             *  1  relative function improvement is no more than EpsF.
       *             *  2  relative step is no more than EpsX.
       *             *  4  gradient norm is no more than EpsG
       *             *  5  MaxIts steps were taken
       *             *  7  stopping conditions are too stringent, further improvement is impossible
       *         * Rep.IterationsCount contains the iteration count
       *         * NFEV contains the number of function evaluations
       */
      if (report.terminationtype < 0) return double.MaxValue;
    }

    // perform evaluation for optimal theta to get quality value
    double[] grad = new double[optTheta.Length];
    double optQuality = double.NaN;
    EvaluateObjectiveAndGradient(optTheta, ref optQuality, grad, objectiveArgs);
    if (double.IsNaN(optQuality) || double.IsInfinity(optQuality)) return 10E6; // return a large value (TODO: be consistent by using NMSE)
    // TODO: write back values
    return optQuality;
  }

  /// <summary>
  /// L-BFGS callback: computes the training MSE of the integrated trajectory for the
  /// parameters <paramref name="x"/>, and its gradient with respect to <paramref name="x"/>.
  /// </summary>
  /// <param name="x">Current parameter values θ.</param>
  /// <param name="f">Receives the MSE.</param>
  /// <param name="grad">Receives the gradient; must have at least x.Length entries.</param>
  /// <param name="obj">object[] { ISymbolicExpressionTree, IRegressionProblemData, Dictionary&lt;node, int&gt; }.</param>
  private static void EvaluateObjectiveAndGradient(double[] x, ref double f, double[] grad, object obj) {
    var args = (object[])obj;
    var tree = (ISymbolicExpressionTree)args[0];
    var problemData = (IRegressionProblemData)args[1];
    var nodeIdx = (Dictionary<ISymbolicExpressionTreeNode, int>)args[2];

    var predicted = Integrate(
      new[] { tree },  // we assume tree contains an expression for the change of the target variable over time y'(t)
      problemData.Dataset,
      problemData.AllowedInputVariables.ToArray(),
      new[] { problemData.TargetVariable },
      problemData.TrainingIndices,
      nodeIdx,
      x).ToArray();

    // objective function is MSE
    f = 0.0;
    int n = predicted.Length;
    double invN = 1.0 / n;
    var g = Vector.Zero;
    foreach (var pair in predicted.Zip(problemData.TargetVariableTrainingValues, Tuple.Create)) {
      var yPred = pair.Item1;
      var y = pair.Item2;

      var res = y - yPred.Item1;
      f += res * res * invN;
      g += -2.0 * res * yPred.Item2 * invN;
    }

    // CopyTo writes nothing for Vector.Zero, and the caller's buffer may still hold values
    // from a previous call (presumably alglib reuses it), so clear it to report a true zero gradient.
    Array.Clear(grad, 0, grad.Length);
    g.CopyTo(grad);
  }

  /// <summary>
  /// Integrates the ODEs given by <paramref name="trees"/> forward with the explicit Euler method.
  /// It starts from the observed target values in the first row.
  /// </summary>
  /// <remarks>
  /// Each row advances time by one unit, split into numSteps Euler steps.
  /// Input variables are read from the dataset and held constant within a row.
  /// For every row, one (value, gradient w.r.t. parameterValues) tuple is yielded per target variable.
  /// The values for the first row are the observed ones and carry a zero gradient.
  /// </remarks>
  private static IEnumerable<Tuple<double, Vector>> Integrate(
    ISymbolicExpressionTree[] trees, IDataset dataset, string[] inputVariables, string[] targetVariables, IEnumerable<int> rows,
    Dictionary<ISymbolicExpressionTreeNode, int> nodeIdx, double[] parameterValues) {

    const int numSteps = 1;
    double h = 1.0 / numSteps;

    var rowArr = rows.ToArray(); // enumerate the row source only once
    if (rowArr.Length == 0) throw new ArgumentException("At least one row is required for integration.", "rows");
    var t0 = rowArr[0];

    // return the first value of every target as stored in the dataset (independent of θ)
    foreach (var varName in targetVariables) {
      yield return Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero);
    }

    // integrate forward starting with known values for the target in t0
    var variableValues = new Dictionary<string, Tuple<double, Vector>>();
    foreach (var varName in inputVariables) {
      variableValues.Add(varName, Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero));
    }
    foreach (var varName in targetVariables) {
      variableValues.Add(varName, Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero));
    }

    foreach (var t in rowArr.Skip(1)) {
      for (int step = 0; step < numSteps; step++) {
        // evaluate all derivatives before updating, so every target sees the same state
        var deltaValues = new Dictionary<string, Tuple<double, Vector>>();
        foreach (var tup in trees.Zip(targetVariables, Tuple.Create)) {
          var tree = tup.Item1;
          var targetVarName = tup.Item2;
          // skip programRoot and startSymbol
          var res = InterpretRec(tree.Root.GetSubtree(0).GetSubtree(0), variableValues, nodeIdx, parameterValues);
          deltaValues.Add(targetVarName, res);
        }

        // Euler step: x += h * x'
        foreach (var kvp in deltaValues) {
          var oldVal = variableValues[kvp.Key];
          variableValues[kvp.Key] = Tuple.Create(
            oldVal.Item1 + h * kvp.Value.Item1,
            oldVal.Item2 + h * kvp.Value.Item2
          );
        }
      }

      // yield target values
      foreach (var varName in targetVariables) {
        yield return variableValues[varName];
      }

      // update inputs with the observed values of this row for the next time step
      foreach (var varName in inputVariables) {
        variableValues[varName] = Tuple.Create(dataset.GetDoubleValue(varName, t), Vector.Zero);
      }
    }
  }

  /// <summary>
  /// Recursively evaluates <paramref name="node"/> in forward-mode AD.
  /// Returns the value and its gradient with respect to <paramref name="parameterValues"/>.
  /// </summary>
  private static Tuple<double, Vector> InterpretRec(
    ISymbolicExpressionTreeNode node,
    Dictionary<string, Tuple<double, Vector>> variableValues,
    Dictionary<ISymbolicExpressionTreeNode, int> nodeIdx,
    double[] parameterValues
      ) {

    switch (node.Symbol.Name) {
      case "+": {
          var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
          var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);

          return Tuple.Create(l.Item1 + r.Item1, l.Item2 + r.Item2);
        }
      case "*": {
          var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
          var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);

          // product rule: (f*g)' = f'*g + f*g'
          return Tuple.Create(l.Item1 * r.Item1, l.Item2 * r.Item1 + l.Item1 * r.Item2);
        }

      case "-": {
          var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
          var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);

          // Written as l + (-1)*r rather than `l.Item2 - r.Item2`.
          // For l.Item2 == Vector.Zero, the latter goes through Vector's unary minus,
          // which has negated its operand in place. r.Item2 may be a gradient still stored
          // in variableValues, so mutating it would corrupt the integration state.
          return Tuple.Create(l.Item1 - r.Item1, l.Item2 + -1.0 * r.Item2);
        }
      case "%": {
          var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
          var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);

          // protected division
          if (r.Item1.IsAlmost(0.0)) {
            return Tuple.Create(0.0, Vector.Zero);
          } else {
            return Tuple.Create(
              l.Item1 / r.Item1,
              l.Item1 * -1.0 / (r.Item1 * r.Item1) * r.Item2 + 1.0 / r.Item1 * l.Item2 // (f/g)' = f * (1/g)' + 1/g * f' = f * -1/g² * g' + 1/g * f'
              );
          }
        }
      default: {
          // distinguish other cases
          if (IsConstantNode(node)) {
            // numeric parameter: d θ_k / d θ = unit vector e_k
            var vArr = new double[parameterValues.Length]; // backing array for vector
            vArr[nodeIdx[node]] = 1.0;
            var g = new Vector(vArr);
            return Tuple.Create(parameterValues[nodeIdx[node]], g);
          } else {
            // assume a variable name
            var varName = node.Symbol.Name;
            return variableValues[varName];
          }
        }
    }
  }


  #region events
  private void RegisterEventHandlers() {
    ProblemDataParameter.ValueChanged += new EventHandler(ProblemDataParameter_ValueChanged);
    if (ProblemDataParameter.Value != null) ProblemDataParameter.Value.Changed += new EventHandler(ProblemData_Changed);
  }

  private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) {
    ProblemDataParameter.Value.Changed += new EventHandler(ProblemData_Changed);
    OnProblemDataChanged();
    OnReset();
  }

  private void ProblemData_Changed(object sender, EventArgs e) {
    OnReset();
  }

  private void OnProblemDataChanged() {
    UpdateGrammar();

    var handler = ProblemDataChanged;
    if (handler != null) handler(this, EventArgs.Empty);
  }

  /// <summary>
  /// Creates a new grammar whose terminals match the current ProblemData. The grammar contains
  /// the input variables, the target variable, and numbered θ parameter symbols.
  /// </summary>
  private void UpdateGrammar() {
    // whenever ProblemData is changed we create a new grammar with the necessary symbols
    var g = new SimpleSymbolicExpressionGrammar();
    g.AddSymbols(new[] {
      "+",
      "*",
//        "%", // % is protected division 1/0 := 0 // removed for testing
      "-",
    }, 2, 2);

    // TODO
    //g.AddSymbols(new[] {
    //  "exp",
    //  "log", // log( <expr> ) // TODO: init a theta to ensure the value is always positive
    //  "exp_minus" // exp((-1) * <expr>
    //}, 1, 1);

    foreach (var variableName in ProblemData.AllowedInputVariables)
      g.AddTerminalSymbol(variableName);
    foreach (var variableName in new string[] { ProblemData.TargetVariable }) // TODO: multiple target variables
      g.AddTerminalSymbol(variableName);

    // generate symbols for numeric parameters for which the value is optimized using AutoDiff
    // we generate multiple symbols to balance the probability for selecting a numeric parameter in the generation of random trees
    var numericConstantsFactor = 2.0;
    int numInputs = ProblemData.AllowedInputVariables.Count(); // loop-invariant
    for (int i = 0; i < numericConstantsFactor * (numInputs + 1); i++) {
      g.AddTerminalSymbol("θ" + i); // numeric parameter for which the value is optimized using AutoDiff
    }
    Encoding.Grammar = g;
  }
  #endregion

  #region Import & Export
  public void Load(IRegressionProblemData data) {
    Name = data.Name;
    Description = data.Description;
    ProblemData = data;
  }

  public IRegressionProblemData Export() {
    return ProblemData;
  }
  #endregion


  #region  helper

  /// <summary>True for the numeric parameter symbols ("θ0", "θ1", ...) created in UpdateGrammar.</summary>
  private static bool IsConstantNode(ISymbolicExpressionTreeNode n) {
    // ordinal: symbol names are identifiers, not linguistic text
    return n.Symbol.Name.StartsWith("θ", StringComparison.Ordinal);
  }

  #endregion

}
444}
Note: See TracBrowser for help on using the repository browser.