source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Problem.cs @ 16152

Last change on this file since 16152 was 16152, checked in by gkronber, 2 years ago

#2925: removed obsolete comment and a statement introduced in r16126

File size: 33.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Analysis;
27using HeuristicLab.Collections;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Optimization;
33using HeuristicLab.Parameters;
34using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Problems.DataAnalysis.Symbolic;
37using HeuristicLab.Problems.Instances;
38
39namespace HeuristicLab.Problems.DynamicalSystemsModelling {
40  public class Vector {
41    public readonly static Vector Zero = new Vector(new double[0]);
42
43    public static Vector operator +(Vector a, Vector b) {
44      if (a == Zero) return b;
45      if (b == Zero) return a;
46      Debug.Assert(a.arr.Length == b.arr.Length);
47      var res = new double[a.arr.Length];
48      for (int i = 0; i < res.Length; i++)
49        res[i] = a.arr[i] + b.arr[i];
50      return new Vector(res);
51    }
52    public static Vector operator -(Vector a, Vector b) {
53      if (b == Zero) return a;
54      if (a == Zero) return -b;
55      Debug.Assert(a.arr.Length == b.arr.Length);
56      var res = new double[a.arr.Length];
57      for (int i = 0; i < res.Length; i++)
58        res[i] = a.arr[i] - b.arr[i];
59      return new Vector(res);
60    }
61    public static Vector operator -(Vector v) {
62      if (v == Zero) return Zero;
63      for (int i = 0; i < v.arr.Length; i++)
64        v.arr[i] = -v.arr[i];
65      return v;
66    }
67
68    public static Vector operator *(double s, Vector v) {
69      if (v == Zero) return Zero;
70      if (s == 0.0) return Zero;
71      var res = new double[v.arr.Length];
72      for (int i = 0; i < res.Length; i++)
73        res[i] = s * v.arr[i];
74      return new Vector(res);
75    }
76    public static Vector operator *(Vector v, double s) {
77      return s * v;
78    }
79    public static Vector operator /(double s, Vector v) {
80      if (s == 0.0) return Zero;
81      if (v == Zero) throw new ArgumentException("Division by zero vector");
82      var res = new double[v.arr.Length];
83      for (int i = 0; i < res.Length; i++)
84        res[i] = 1.0 / v.arr[i];
85      return new Vector(res);
86    }
87    public static Vector operator /(Vector v, double s) {
88      return v * 1.0 / s;
89    }
90
91
92    private readonly double[] arr; // backing array;
93
94    public Vector(double[] v) {
95      this.arr = v;
96    }
97
98    public void CopyTo(double[] target) {
99      Debug.Assert(arr.Length <= target.Length);
100      Array.Copy(arr, target, arr.Length);
101    }
102  }
103
104  [Item("Dynamical Systems Modelling Problem", "TODO")]
105  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 900)]
106  [StorableClass]
107  public sealed class Problem : SingleObjectiveBasicProblem<MultiEncoding>, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
108
109    #region parameter names
110    private const string ProblemDataParameterName = "Data";
111    private const string TargetVariablesParameterName = "Target variables";
112    private const string FunctionSetParameterName = "Function set";
113    private const string MaximumLengthParameterName = "Size limit";
114    private const string MaximumParameterOptimizationIterationsParameterName = "Max. parameter optimization iterations";
115    private const string NumberOfLatentVariablesParameterName = "Number of latent variables";
116    private const string NumericIntegrationStepsParameterName = "Steps for numeric integration";
117    #endregion
118
119    #region Parameter Properties
120    IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }
121
122    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
123      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
124    }
125    public IValueParameter<ReadOnlyCheckedItemCollection<StringValue>> TargetVariablesParameter {
126      get { return (IValueParameter<ReadOnlyCheckedItemCollection<StringValue>>)Parameters[TargetVariablesParameterName]; }
127    }
128    public IValueParameter<ReadOnlyCheckedItemCollection<StringValue>> FunctionSetParameter {
129      get { return (IValueParameter<ReadOnlyCheckedItemCollection<StringValue>>)Parameters[FunctionSetParameterName]; }
130    }
131    public IFixedValueParameter<IntValue> MaximumLengthParameter {
132      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumLengthParameterName]; }
133    }
134    public IFixedValueParameter<IntValue> MaximumParameterOptimizationIterationsParameter {
135      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumParameterOptimizationIterationsParameterName]; }
136    }
137    public IFixedValueParameter<IntValue> NumberOfLatentVariablesParameter {
138      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfLatentVariablesParameterName]; }
139    }
140    public IFixedValueParameter<IntValue> NumericIntegrationStepsParameter {
141      get { return (IFixedValueParameter<IntValue>)Parameters[NumericIntegrationStepsParameterName]; }
142    }
143    #endregion
144
145    #region Properties
146    public IRegressionProblemData ProblemData {
147      get { return ProblemDataParameter.Value; }
148      set { ProblemDataParameter.Value = value; }
149    }
150    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }
151
152    public ReadOnlyCheckedItemCollection<StringValue> TargetVariables {
153      get { return TargetVariablesParameter.Value; }
154    }
155
156    public ReadOnlyCheckedItemCollection<StringValue> FunctionSet {
157      get { return FunctionSetParameter.Value; }
158    }
159
160    public int MaximumLength {
161      get { return MaximumLengthParameter.Value.Value; }
162    }
163    public int MaximumParameterOptimizationIterations {
164      get { return MaximumParameterOptimizationIterationsParameter.Value.Value; }
165    }
166    public int NumberOfLatentVariables {
167      get { return NumberOfLatentVariablesParameter.Value.Value; }
168    }
169    public int NumericIntegrationSteps {
170      get { return NumericIntegrationStepsParameter.Value.Value; }
171    }
172
173    #endregion                                                                                     
174
175    public event EventHandler ProblemDataChanged;
176
177    public override bool Maximization {
178      get { return false; } // we minimize NMSE
179    }
180
181    #region item cloning and persistence
182    // persistence
183    [StorableConstructor]
184    private Problem(bool deserializing) : base(deserializing) { }
185    [StorableHook(HookType.AfterDeserialization)]
186    private void AfterDeserialization() {
187      RegisterEventHandlers();
188    }
189
190    // cloning
191    private Problem(Problem original, Cloner cloner)
192      : base(original, cloner) {
193      RegisterEventHandlers();
194    }
195    public override IDeepCloneable Clone(Cloner cloner) { return new Problem(this, cloner); }
196    #endregion
197
198    public Problem()
199      : base() {
200      var targetVariables = new CheckedItemCollection<StringValue>().AsReadOnly(); // HACK: it would be better to provide a new class derived from IDataAnalysisProblem
201      var functions = CreateFunctionSet();
202      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data captured from the dynamical system. Use CSV import functionality to import data.", new RegressionProblemData()));
203      Parameters.Add(new ValueParameter<ReadOnlyCheckedItemCollection<StringValue>>(TargetVariablesParameterName, "Target variables (overrides setting in ProblemData)", targetVariables));
204      Parameters.Add(new ValueParameter<ReadOnlyCheckedItemCollection<StringValue>>(FunctionSetParameterName, "The list of allowed functions", functions));
205      Parameters.Add(new FixedValueParameter<IntValue>(MaximumLengthParameterName, "The maximally allowed length of each expression. Set to a small value (5 - 25). Default = 10", new IntValue(10)));
206      Parameters.Add(new FixedValueParameter<IntValue>(MaximumParameterOptimizationIterationsParameterName, "The maximum number of iterations for optimization of parameters (using L-BFGS). More iterations makes the algorithm slower, fewer iterations might prevent convergence in the optimization scheme. Default = 100", new IntValue(100)));
207      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfLatentVariablesParameterName, "Latent variables (unobserved variables) allow us to produce expressions which are integrated up and can be used in other expressions. They are handled similarly to target variables in forward simulation / integration. The difference to target variables is that there are no data to which the calculated values of latent variables are compared. Set to a small value (0 .. 5) as necessary (default = 0)", new IntValue(0)));
208      Parameters.Add(new FixedValueParameter<IntValue>(NumericIntegrationStepsParameterName, "Number of steps in the numeric integration that are taken from one row to the next (set to 1 to 100). More steps makes the algorithm slower, less steps worsens the accuracy of the numeric integration scheme.", new IntValue(10)));
209
210      RegisterEventHandlers();
211      InitAllParameters();
212
213      // TODO: do not clear selection of target variables when the input variables are changed
214      // TODO: UI hangs when selecting / deselecting input variables because the encoding is updated on each item
215    }
216
217
218    public override double Evaluate(Individual individual, IRandom random) {
219      var trees = individual.Values.Select(v => v.Value).OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual
220
221      var problemData = ProblemData;
222      var rows = ProblemData.TrainingIndices.ToArray();
223      var targetVars = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
224      var latentVariables = Enumerable.Range(1, NumberOfLatentVariables).Select(i => "λ" + i).ToArray(); // TODO: must coincide with the variables which are actually defined in the grammar and also for which we actually have trees
225      var targetValues = new double[rows.Length, targetVars.Length];
226
227      // collect values of all target variables
228      var colIdx = 0;
229      foreach (var targetVar in targetVars) {
230        int rowIdx = 0;
231        foreach (var value in problemData.Dataset.GetDoubleValues(targetVar, rows)) {
232          targetValues[rowIdx, colIdx] = value;
233          rowIdx++;
234        }
235        colIdx++;
236      }
237
238      var nodeIdx = new Dictionary<ISymbolicExpressionTreeNode, int>();
239
240      foreach (var tree in trees) {
241        foreach (var node in tree.Root.IterateNodesPrefix().Where(n => IsConstantNode(n))) {
242          nodeIdx.Add(node, nodeIdx.Count);
243        }
244      }
245
246      var theta = nodeIdx.Select(_ => random.NextDouble() * 2.0 - 1.0).ToArray(); // init params randomly from Unif(-1,1)
247
248      double[] optTheta = new double[0];
249      if (theta.Length > 0) {
250        alglib.minlbfgsstate state;
251        alglib.minlbfgsreport report;
252        alglib.minlbfgscreate(Math.Min(theta.Length, 5), theta, out state);
253        alglib.minlbfgssetcond(state, 0.0, 0.0, 0.0, MaximumParameterOptimizationIterations);
254        alglib.minlbfgsoptimize(state, EvaluateObjectiveAndGradient, null,
255          new object[] { trees, targetVars, problemData, nodeIdx, targetValues, rows, NumericIntegrationSteps, latentVariables }); //TODO: create a type
256        alglib.minlbfgsresults(state, out optTheta, out report);
257
258        /*
259         *
260         *         L-BFGS algorithm results
261
262          INPUT PARAMETERS:
263              State   -   algorithm state
264
265          OUTPUT PARAMETERS:
266              X       -   array[0..N-1], solution
267              Rep     -   optimization report:
268                          * Rep.TerminationType completetion code:
269                              * -7    gradient verification failed.
270                                      See MinLBFGSSetGradientCheck() for more information.
271                              * -2    rounding errors prevent further improvement.
272                                      X contains best point found.
273                              * -1    incorrect parameters were specified
274                              *  1    relative function improvement is no more than
275                                      EpsF.
276                              *  2    relative step is no more than EpsX.
277                              *  4    gradient norm is no more than EpsG
278                              *  5    MaxIts steps was taken
279                              *  7    stopping conditions are too stringent,
280                                      further improvement is impossible
281                          * Rep.IterationsCount contains iterations count
282                          * NFEV countains number of function calculations
283         */
284        if (report.terminationtype < 0) return double.MaxValue;
285      }
286
287      // perform evaluation for optimal theta to get quality value
288      double[] grad = new double[optTheta.Length];
289      double optQuality = double.NaN;
290      EvaluateObjectiveAndGradient(optTheta, ref optQuality, grad,
291        new object[] { trees, targetVars, problemData, nodeIdx, targetValues, rows, NumericIntegrationSteps, latentVariables });
292      if (double.IsNaN(optQuality) || double.IsInfinity(optQuality)) return 10E6; // return a large value (TODO: be consistent by using NMSE)
293
294      individual["OptTheta"] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
295      return optQuality;
296    }
297
298    private static void EvaluateObjectiveAndGradient(double[] x, ref double f, double[] grad, object obj) {
299      var trees = (ISymbolicExpressionTree[])((object[])obj)[0];
300      var targetVariables = (string[])((object[])obj)[1];
301      var problemData = (IRegressionProblemData)((object[])obj)[2];
302      var nodeIdx = (Dictionary<ISymbolicExpressionTreeNode, int>)((object[])obj)[3];
303      var targetValues = (double[,])((object[])obj)[4];
304      var rows = (int[])((object[])obj)[5];
305      var numericIntegrationSteps = (int)((object[])obj)[6];
306      var latentVariables = (string[])((object[])obj)[7];
307
308      var predicted = Integrate(
309        trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
310        problemData.Dataset,
311        problemData.AllowedInputVariables.ToArray(),
312        targetVariables,
313        latentVariables,
314        rows,
315        nodeIdx,                // TODO: is it Ok to use rows here ?
316        x, numericIntegrationSteps).ToArray();
317
318
319      // for normalized MSE = 1/variance(t) * MSE(t, pred)
320      // TODO: Perf. (by standardization of target variables before evaluation of all trees)
321      var invVar = Enumerable.Range(0, targetVariables.Length)
322        .Select(c => rows.Select(row => targetValues[row, c])) // colums vectors
323        .Select(vec => vec.Variance())
324        .Select(v => 1.0 / v)
325        .ToArray();
326
327      // objective function is NMSE
328      f = 0.0;
329      int n = predicted.Length;
330      double invN = 1.0 / n;
331      var g = Vector.Zero;
332      int r = 0;
333      foreach (var y_pred in predicted) {
334        for (int c = 0; c < y_pred.Length; c++) {
335
336          var y_pred_f = y_pred[c].Item1;
337          var y = targetValues[r, c];
338
339          var res = (y - y_pred_f);
340          var ressq = res * res;
341          f += ressq * invN * invVar[c];
342          g += -2.0 * res * y_pred[c].Item2 * invN * invVar[c];
343        }
344        r++;
345      }
346
347      g.CopyTo(grad);
348    }
349
350    public override void Analyze(Individual[] individuals, double[] qualities, ResultCollection results, IRandom random) {
351      base.Analyze(individuals, qualities, results, random);
352
353      if (!results.ContainsKey("Prediction (training)")) {
354        results.Add(new Result("Prediction (training)", typeof(ReadOnlyItemList<DataTable>)));
355      }
356      if (!results.ContainsKey("Prediction (test)")) {
357        results.Add(new Result("Prediction (test)", typeof(ReadOnlyItemList<DataTable>)));
358      }
359      if (!results.ContainsKey("Models")) {
360        results.Add(new Result("Models", typeof(ReadOnlyItemList<ISymbolicExpressionTree>)));
361      }
362
363      // TODO extract common functionality from Evaluate and Analyze
364      var bestIndividualAndQuality = this.GetBestIndividual(individuals, qualities);
365      var optTheta = ((DoubleArray)bestIndividualAndQuality.Item1["OptTheta"]).ToArray(); // see evaluate
366      var trees = bestIndividualAndQuality.Item1.Values.Select(v => v.Value).OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual
367      var nodeIdx = new Dictionary<ISymbolicExpressionTreeNode, int>();
368
369
370      foreach (var tree in trees) {
371        foreach (var node in tree.Root.IterateNodesPrefix().Where(n => IsConstantNode(n))) {
372          nodeIdx.Add(node, nodeIdx.Count);
373        }
374      }
375      var problemData = ProblemData;
376      var targetVars = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
377      var latentVariables = Enumerable.Range(1, NumberOfLatentVariables).Select(i => "λ" + i).ToArray(); // TODO: must coincide with the variables which are actually defined in the grammar and also for which we actually have trees
378
379      var trainingList = new ItemList<DataTable>();
380      var trainingRows = ProblemData.TrainingIndices.ToArray();
381      var trainingPrediction = Integrate(
382       trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
383       problemData.Dataset,
384       problemData.AllowedInputVariables.ToArray(),
385       targetVars,
386       latentVariables,
387       trainingRows,
388       nodeIdx,
389       optTheta,
390       NumericIntegrationSteps).ToArray();
391
392      for (int colIdx = 0; colIdx < targetVars.Length; colIdx++) {
393        var targetVar = targetVars[colIdx];
394        var trainingDataTable = new DataTable(targetVar + " prediction (training)");
395        var actualValuesRow = new DataRow(targetVar, "The values of " + targetVar, problemData.Dataset.GetDoubleValues(targetVar, trainingRows));
396        var predictedValuesRow = new DataRow(targetVar + " pred.", "Predicted values for " + targetVar, trainingPrediction.Select(arr => arr[colIdx].Item1).ToArray());
397        trainingDataTable.Rows.Add(actualValuesRow);
398        trainingDataTable.Rows.Add(predictedValuesRow);
399        trainingList.Add(trainingDataTable);
400      }
401
402      // TODO: DRY for training and test
403      var testList = new ItemList<DataTable>();
404      var testRows = ProblemData.TestIndices.ToArray();
405      var testPrediction = Integrate(
406       trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
407       problemData.Dataset,
408       problemData.AllowedInputVariables.ToArray(),
409       targetVars,
410       latentVariables,
411       testRows,
412       nodeIdx,
413       optTheta,
414       NumericIntegrationSteps).ToArray();
415
416      for (int colIdx = 0; colIdx < targetVars.Length; colIdx++) {
417        var targetVar = targetVars[colIdx];
418        var testDataTable = new DataTable(targetVar + " prediction (test)");
419        var actualValuesRow = new DataRow(targetVar, "The values of " + targetVar, problemData.Dataset.GetDoubleValues(targetVar, testRows));
420        var predictedValuesRow = new DataRow(targetVar + " pred.", "Predicted values for " + targetVar, testPrediction.Select(arr => arr[colIdx].Item1).ToArray());
421        testDataTable.Rows.Add(actualValuesRow);
422        testDataTable.Rows.Add(predictedValuesRow);
423        testList.Add(testDataTable);
424      }
425
426      results["Prediction (training)"].Value = trainingList.AsReadOnly();
427      results["Prediction (test)"].Value = testList.AsReadOnly();
428
429      #region simplification of models
430      // TODO the dependency of HeuristicLab.Problems.DataAnalysis.Symbolic is not ideal
431      var modelList = new ItemList<ISymbolicExpressionTree>();
432      foreach (var tree in trees) {
433        var shownTree = (ISymbolicExpressionTree)tree.Clone();
434        var constantsNodeOrig = tree.IterateNodesPrefix().Where(IsConstantNode);
435        var constantsNodeShown = shownTree.IterateNodesPrefix().Where(IsConstantNode);
436
437        foreach (var n in constantsNodeOrig.Zip(constantsNodeShown, (original, shown) => new { original, shown })) {
438          double constantsVal = optTheta[nodeIdx[n.original]];
439
440          ConstantTreeNode replacementNode = new ConstantTreeNode(new Constant()) { Value = constantsVal };
441
442          var parentNode = n.shown.Parent;
443          int replacementIndex = parentNode.IndexOfSubtree(n.shown);
444          parentNode.RemoveSubtree(replacementIndex);
445          parentNode.InsertSubtree(replacementIndex, replacementNode);
446        }
447
448        modelList.Add(shownTree);
449      }
450      results["Models"].Value = modelList.AsReadOnly();
451      #endregion
452    }
453
454
455    #region interpretation
456    private static IEnumerable<Tuple<double, Vector>[]> Integrate(
457      ISymbolicExpressionTree[] trees, IDataset dataset, string[] inputVariables, string[] targetVariables, string[] latentVariables, IEnumerable<int> rows,
458      Dictionary<ISymbolicExpressionTreeNode, int> nodeIdx, double[] parameterValues, int numericIntegrationSteps = 100) {
459
460      int NUM_STEPS = numericIntegrationSteps ;
461      double h = 1.0 / NUM_STEPS;
462
463      // return first value as stored in the dataset
464      yield return targetVariables
465        .Select(targetVar => Tuple.Create(dataset.GetDoubleValue(targetVar, rows.First()), Vector.Zero))
466        .ToArray();
467
468      // integrate forward starting with known values for the target in t0
469
470      var variableValues = new Dictionary<string, Tuple<double, Vector>>();
471      var t0 = rows.First();
472      foreach (var varName in inputVariables) {
473        variableValues.Add(varName, Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero));
474      }
475      foreach (var varName in targetVariables) {
476        variableValues.Add(varName, Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero));
477      }
478      // add value entries for latent variables which are also integrated
479      foreach(var latentVar in latentVariables) {
480        variableValues.Add(latentVar, Tuple.Create(0.0, Vector.Zero)); // we don't have observations for latent variables -> assume zero as starting value
481      }
482      var calculatedVariables = targetVariables.Concat(latentVariables); // TODO: must conincide with the order of trees in the encoding
483
484      foreach (var t in rows.Skip(1)) {
485        for (int step = 0; step < NUM_STEPS; step++) {
486          var deltaValues = new Dictionary<string, Tuple<double, Vector>>();
487          foreach (var tup in trees.Zip(calculatedVariables, Tuple.Create)) {
488            var tree = tup.Item1;
489            var targetVarName = tup.Item2;
490            // skip programRoot and startSymbol
491            var res = InterpretRec(tree.Root.GetSubtree(0).GetSubtree(0), variableValues, nodeIdx, parameterValues);
492            deltaValues.Add(targetVarName, res);
493          }
494
495          // update variableValues for next step
496          foreach (var kvp in deltaValues) {
497            var oldVal = variableValues[kvp.Key];
498            variableValues[kvp.Key] = Tuple.Create(
499              oldVal.Item1 + h * kvp.Value.Item1,
500              oldVal.Item2 + h * kvp.Value.Item2
501            );
502          }
503        }
504
505        // only return the target variables for calculation of errors
506        yield return targetVariables
507          .Select(targetVar => variableValues[targetVar])
508          .ToArray();
509
510        // update for next time step
511        foreach (var varName in inputVariables) {
512          variableValues[varName] = Tuple.Create(dataset.GetDoubleValue(varName, t), Vector.Zero);
513        }
514      }
515    }
516
517    private static Tuple<double, Vector> InterpretRec(
518      ISymbolicExpressionTreeNode node,
519      Dictionary<string, Tuple<double, Vector>> variableValues,
520      Dictionary<ISymbolicExpressionTreeNode, int> nodeIdx,
521      double[] parameterValues
522        ) {
523
524      switch (node.Symbol.Name) {
525        case "+": {
526            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues); // TODO capture all parameters into a state type for interpretation
527            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);
528
529            return Tuple.Create(l.Item1 + r.Item1, l.Item2 + r.Item2);
530          }
531        case "*": {
532            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
533            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);
534
535            return Tuple.Create(l.Item1 * r.Item1, l.Item2 * r.Item1 + l.Item1 * r.Item2);
536          }
537
538        case "-": {
539            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
540            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);
541
542            return Tuple.Create(l.Item1 - r.Item1, l.Item2 - r.Item2);
543          }
544        case "%": {
545            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
546            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);
547
548            // protected division
549            if (r.Item1.IsAlmost(0.0)) {
550              return Tuple.Create(0.0, Vector.Zero);
551            } else {
552              return Tuple.Create(
553                l.Item1 / r.Item1,
554                l.Item1 * -1.0 / (r.Item1 * r.Item1) * r.Item2 + 1.0 / r.Item1 * l.Item2 // (f/g)' = f * (1/g)' + 1/g * f' = f * -1/g² * g' + 1/g * f'
555                );
556            }
557          }
558        default: {
559            // distinguish other cases
560            if (IsConstantNode(node)) {
561              var vArr = new double[parameterValues.Length]; // backing array for vector
562              vArr[nodeIdx[node]] = 1.0;
563              var g = new Vector(vArr);
564              return Tuple.Create(parameterValues[nodeIdx[node]], g);
565            } else {
566              // assume a variable name
567              var varName = node.Symbol.Name;
568              return variableValues[varName];
569            }
570          }
571      }
572    }
573    #endregion
574
575    #region events
576    /*
577     * Dependencies between parameters:
578     *
579     * ProblemData
580     *    |
581     *    V
582     * TargetVariables   FunctionSet    MaximumLength    NumberOfLatentVariables
583     *               |   |                 |                   |
584     *               V   V                 |                   |
585     *             Grammar <---------------+-------------------
586     *                |
587     *                V
588     *            Encoding
589     */
590    private void RegisterEventHandlers() {
591      ProblemDataParameter.ValueChanged += ProblemDataParameter_ValueChanged;
592      if (ProblemDataParameter.Value != null) ProblemDataParameter.Value.Changed += ProblemData_Changed;
593
594      TargetVariablesParameter.ValueChanged += TargetVariablesParameter_ValueChanged;
595      if (TargetVariablesParameter.Value != null) TargetVariablesParameter.Value.CheckedItemsChanged += CheckedTargetVariablesChanged;
596
597      FunctionSetParameter.ValueChanged += FunctionSetParameter_ValueChanged;
598      if (FunctionSetParameter.Value != null) FunctionSetParameter.Value.CheckedItemsChanged += CheckedFunctionsChanged;
599
600      MaximumLengthParameter.Value.ValueChanged += MaximumLengthChanged;
601
602      NumberOfLatentVariablesParameter.Value.ValueChanged += NumLatentVariablesChanged;
603    }
604
605    private void NumLatentVariablesChanged(object sender, EventArgs e) {
606      UpdateGrammarAndEncoding();
607    }
608
609    private void MaximumLengthChanged(object sender, EventArgs e) {
610      UpdateGrammarAndEncoding();
611    }
612
613    private void FunctionSetParameter_ValueChanged(object sender, EventArgs e) {
614      FunctionSetParameter.Value.CheckedItemsChanged += CheckedFunctionsChanged;
615    }
616
617    private void CheckedFunctionsChanged(object sender, CollectionItemsChangedEventArgs<StringValue> e) {
618      UpdateGrammarAndEncoding();
619    }
620
621    private void TargetVariablesParameter_ValueChanged(object sender, EventArgs e) {
622      TargetVariablesParameter.Value.CheckedItemsChanged += CheckedTargetVariablesChanged;
623    }
624
625    private void CheckedTargetVariablesChanged(object sender, CollectionItemsChangedEventArgs<StringValue> e) {
626      UpdateGrammarAndEncoding();
627    }
628
629    private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) {
630      ProblemDataParameter.Value.Changed += ProblemData_Changed;
631      OnProblemDataChanged();
632      OnReset();
633    }
634
635    private void ProblemData_Changed(object sender, EventArgs e) {
636      OnProblemDataChanged();
637      OnReset();
638    }
639
640    private void OnProblemDataChanged() {
641      UpdateTargetVariables();        // implicitly updates other dependent parameters
642      var handler = ProblemDataChanged;
643      if (handler != null) handler(this, EventArgs.Empty);
644    }
645
646    #endregion
647
648    #region  helper
649
650    private void InitAllParameters() {
651      UpdateTargetVariables(); // implicitly updates the grammar and the encoding     
652    }
653
654    private ReadOnlyCheckedItemCollection<StringValue> CreateFunctionSet() {
655      var l = new CheckedItemCollection<StringValue>();
656      l.Add(new StringValue("+").AsReadOnly());
657      l.Add(new StringValue("*").AsReadOnly());
658      l.Add(new StringValue("%").AsReadOnly());
659      l.Add(new StringValue("-").AsReadOnly());
660      return l.AsReadOnly();
661    }
662
663    private static bool IsConstantNode(ISymbolicExpressionTreeNode n) {
664      return n.Symbol.Name.StartsWith("θ");
665    }
666    private static bool IsLatentVariableNode(ISymbolicExpressionTreeNode n) {
667      return n.Symbol.Name.StartsWith("λ");
668    }
669
670
671    private void UpdateTargetVariables() {
672      var currentlySelectedVariables = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
673
674      var newVariablesList = new CheckedItemCollection<StringValue>(ProblemData.Dataset.VariableNames.Select(str => new StringValue(str).AsReadOnly()).ToArray()).AsReadOnly();
675      var matchingItems = newVariablesList.Where(item => currentlySelectedVariables.Contains(item.Value)).ToArray();
676      foreach (var matchingItem in matchingItems) {
677        newVariablesList.SetItemCheckedState(matchingItem, true);
678      }
679      TargetVariablesParameter.Value = newVariablesList;
680    }
681
682    private void UpdateGrammarAndEncoding() {
683      var encoding = new MultiEncoding();
684      var g = CreateGrammar();
685      foreach (var targetVar in TargetVariables.CheckedItems) {
686        encoding = encoding.Add(new SymbolicExpressionTreeEncoding(targetVar + "_tree", g, MaximumLength, MaximumLength)); // only limit by length
687      }
688      for (int i = 1; i <= NumberOfLatentVariables; i++) {
689        encoding = encoding.Add(new SymbolicExpressionTreeEncoding("λ" + i + "_tree", g, MaximumLength, MaximumLength));
690      }
691      Encoding = encoding;
692    }
693
694    private ISymbolicExpressionGrammar CreateGrammar() {
695      // whenever ProblemData is changed we create a new grammar with the necessary symbols
696      var g = new SimpleSymbolicExpressionGrammar();
697      g.AddSymbols(FunctionSet.CheckedItems.Select(i => i.Value).ToArray(), 2, 2);
698
699      // TODO
700      //g.AddSymbols(new[] {
701      //  "exp",
702      //  "log", // log( <expr> ) // TODO: init a theta to ensure the value is always positive
703      //  "exp_minus" // exp((-1) * <expr>
704      //}, 1, 1);
705
706      foreach (var variableName in ProblemData.AllowedInputVariables.Union(TargetVariables.CheckedItems.Select(i => i.Value)))
707        g.AddTerminalSymbol(variableName);
708
709      // generate symbols for numeric parameters for which the value is optimized using AutoDiff
710      // we generate multiple symbols to balance the probability for selecting a numeric parameter in the generation of random trees
711      var numericConstantsFactor = 2.0;
712      for (int i = 0; i < numericConstantsFactor * (ProblemData.AllowedInputVariables.Count() + TargetVariables.CheckedItems.Count()); i++) {
713        g.AddTerminalSymbol("θ" + i); // numeric parameter for which the value is optimized using AutoDiff
714      }
715
716      // generate symbols for latent variables
717      for (int i = 1; i <= NumberOfLatentVariables; i++) {
718        g.AddTerminalSymbol("λ" + i); // numeric parameter for which the value is optimized using AutoDiff
719      }
720
721      return g;
722    }
723
724    #endregion
725
726    #region Import & Export
727    public void Load(IRegressionProblemData data) {
728      Name = data.Name;
729      Description = data.Description;
730      ProblemData = data;
731    }
732
733    public IRegressionProblemData Export() {
734      return ProblemData;
735    }
736    #endregion
737
738  }
739}
Note: See TracBrowser for help on using the repository browser.