source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Problem.cs @ 16126

Last change on this file since 16126 was 16126, checked in by lkammere, 2 years ago

#2925: added constant values to tree in result view and minor bugfix when initializing problem.

File size: 32.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Analysis;
27using HeuristicLab.Collections;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Optimization;
33using HeuristicLab.Parameters;
34using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Problems.DataAnalysis.Symbolic;
37using HeuristicLab.Problems.Instances;
38
39namespace HeuristicLab.Problems.DynamicalSystemsModelling {
40  public class Vector {
41    public readonly static Vector Zero = new Vector(new double[0]);
42
43    public static Vector operator +(Vector a, Vector b) {
44      if (a == Zero) return b;
45      if (b == Zero) return a;
46      Debug.Assert(a.arr.Length == b.arr.Length);
47      var res = new double[a.arr.Length];
48      for (int i = 0; i < res.Length; i++)
49        res[i] = a.arr[i] + b.arr[i];
50      return new Vector(res);
51    }
52    public static Vector operator -(Vector a, Vector b) {
53      if (b == Zero) return a;
54      if (a == Zero) return -b;
55      Debug.Assert(a.arr.Length == b.arr.Length);
56      var res = new double[a.arr.Length];
57      for (int i = 0; i < res.Length; i++)
58        res[i] = a.arr[i] - b.arr[i];
59      return new Vector(res);
60    }
61    public static Vector operator -(Vector v) {
62      if (v == Zero) return Zero;
63      for (int i = 0; i < v.arr.Length; i++)
64        v.arr[i] = -v.arr[i];
65      return v;
66    }
67
68    public static Vector operator *(double s, Vector v) {
69      if (v == Zero) return Zero;
70      if (s == 0.0) return Zero;
71      var res = new double[v.arr.Length];
72      for (int i = 0; i < res.Length; i++)
73        res[i] = s * v.arr[i];
74      return new Vector(res);
75    }
76    public static Vector operator *(Vector v, double s) {
77      return s * v;
78    }
79    public static Vector operator /(double s, Vector v) {
80      if (s == 0.0) return Zero;
81      if (v == Zero) throw new ArgumentException("Division by zero vector");
82      var res = new double[v.arr.Length];
83      for (int i = 0; i < res.Length; i++)
84        res[i] = 1.0 / v.arr[i];
85      return new Vector(res);
86    }
87    public static Vector operator /(Vector v, double s) {
88      return v * 1.0 / s;
89    }
90
91
92    private readonly double[] arr; // backing array;
93
94    public Vector(double[] v) {
95      this.arr = v;
96    }
97
98    public void CopyTo(double[] target) {
99      Debug.Assert(arr.Length <= target.Length);
100      Array.Copy(arr, target, arr.Length);
101    }
102  }
103
104  [Item("Dynamical Systems Modelling Problem", "TODO")]
105  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 900)]
106  [StorableClass]
107  public sealed class Problem : SingleObjectiveBasicProblem<MultiEncoding>, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
108
    #region parameter names
    // Keys used to register and look up the entries of the Parameters collection
    // (created in the constructor).
    private const string ProblemDataParameterName = "Data";
    private const string TargetVariablesParameterName = "Target variables";
    private const string FunctionSetParameterName = "Function set";
    private const string MaximumLengthParameterName = "Size limit";
    private const string MaximumParameterOptimizationIterationsParameterName = "Max. parameter optimization iterations";
    private const string NumberOfLatentVariablesParameterName = "Number of latent variables";
    private const string NumericIntegrationStepsParameterName = "Steps for numeric integration";
    #endregion

    #region Parameter Properties
    // Typed accessors for the parameters registered in the constructor.
    IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }

    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueParameter<ReadOnlyCheckedItemCollection<StringValue>> TargetVariablesParameter {
      get { return (IValueParameter<ReadOnlyCheckedItemCollection<StringValue>>)Parameters[TargetVariablesParameterName]; }
    }
    public IValueParameter<ReadOnlyCheckedItemCollection<StringValue>> FunctionSetParameter {
      get { return (IValueParameter<ReadOnlyCheckedItemCollection<StringValue>>)Parameters[FunctionSetParameterName]; }
    }
    public IFixedValueParameter<IntValue> MaximumLengthParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumLengthParameterName]; }
    }
    public IFixedValueParameter<IntValue> MaximumParameterOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumParameterOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> NumberOfLatentVariablesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfLatentVariablesParameterName]; }
    }
    public IFixedValueParameter<IntValue> NumericIntegrationStepsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumericIntegrationStepsParameterName]; }
    }
    #endregion
144
    #region Properties
    // Convenience accessors that unwrap the parameter values.
    public IRegressionProblemData ProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }

    public ReadOnlyCheckedItemCollection<StringValue> TargetVariables {
      get { return TargetVariablesParameter.Value; }
    }

    public ReadOnlyCheckedItemCollection<StringValue> FunctionSet {
      get { return FunctionSetParameter.Value; }
    }

    public int MaximumLength {
      get { return MaximumLengthParameter.Value.Value; }
    }
    public int MaximumParameterOptimizationIterations {
      get { return MaximumParameterOptimizationIterationsParameter.Value.Value; }
    }
    public int NumberOfLatentVariables {
      get { return NumberOfLatentVariablesParameter.Value.Value; }
    }
    public int NumericIntegrationSteps {
      get { return NumericIntegrationStepsParameter.Value.Value; }
    }

    #endregion

    // Raised whenever the problem data instance is replaced or its contents change.
    public event EventHandler ProblemDataChanged;

    public override bool Maximization {
      get { return false; } // we minimize NMSE
    }
180
    #region item cloning and persistence
    // persistence
    [StorableConstructor]
    private Problem(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // event subscriptions are not serialized and must be re-attached after loading
      RegisterEventHandlers();
    }

    // cloning
    private Problem(Problem original, Cloner cloner)
      : base(original, cloner) {
      // re-attach handlers to the cloned parameter instances
      RegisterEventHandlers();
    }
    public override IDeepCloneable Clone(Cloner cloner) { return new Problem(this, cloner); }
    #endregion
197
    // Default constructor: registers all parameters with their default values,
    // wires up the parameter dependency events and initializes the dependent
    // parameters (target variables -> grammar -> encoding).
    public Problem()
      : base() {
      var targetVariables = new CheckedItemCollection<StringValue>().AsReadOnly(); // HACK: it would be better to provide a new class derived from IDataAnalysisProblem
      var functions = CreateFunctionSet();
      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data captured from the dynamical system. Use CSV import functionality to import data.", new RegressionProblemData()));
      Parameters.Add(new ValueParameter<ReadOnlyCheckedItemCollection<StringValue>>(TargetVariablesParameterName, "Target variables (overrides setting in ProblemData)", targetVariables));
      Parameters.Add(new ValueParameter<ReadOnlyCheckedItemCollection<StringValue>>(FunctionSetParameterName, "The list of allowed functions", functions));
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumLengthParameterName, "The maximally allowed length of each expression. Set to a small value (5 - 25). Default = 10", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumParameterOptimizationIterationsParameterName, "The maximum number of iterations for optimization of parameters (using L-BFGS). More iterations makes the algorithm slower, fewer iterations might prevent convergence in the optimization scheme. Default = 100", new IntValue(100)));
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfLatentVariablesParameterName, "Latent variables (unobserved variables) allow us to produce expressions which are integrated up and can be used in other expressions. They are handled similarly to target variables in forward simulation / integration. The difference to target variables is that there are no data to which the calculated values of latent variables are compared. Set to a small value (0 .. 5) as necessary (default = 0)", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<IntValue>(NumericIntegrationStepsParameterName, "Number of steps in the numeric integration that are taken from one row to the next (set to 1 to 100). More steps makes the algorithm slower, less steps worsens the accuracy of the numeric integration scheme.", new IntValue(10)));

      RegisterEventHandlers();
      InitAllParameters();
    }
213
214
    // Evaluates one individual (one expression tree per target / latent variable):
    // the numeric parameters θ of the trees are optimized with L-BFGS (gradients
    // come from the AutoDiff interpreter) and the resulting normalized mean
    // squared error over the training partition is returned (minimized, see
    // Maximization).
    public override double Evaluate(Individual individual, IRandom random) {
      var trees = individual.Values.Select(v => v.Value).OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual

      var problemData = ProblemData;
      var rows = ProblemData.TrainingIndices.ToArray();
      var targetVars = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
      var latentVariables = Enumerable.Range(1, NumberOfLatentVariables).Select(i => "λ" + i).ToArray(); // TODO: must coincide with the variables which are actually defined in the grammar and also for which we actually have trees
      var targetValues = new double[rows.Length, targetVars.Length];

      // collect values of all target variables
      // NOTE: targetValues is indexed by a 0-based row counter, NOT by the
      // dataset row numbers in 'rows'.
      var colIdx = 0;
      foreach (var targetVar in targetVars) {
        int rowIdx = 0;
        foreach (var value in problemData.Dataset.GetDoubleValues(targetVar, rows)) {
          targetValues[rowIdx, colIdx] = value;
          rowIdx++;
        }
        colIdx++;
      }

      // assign a parameter index to every θ node over all trees (prefix order)
      var nodeIdx = new Dictionary<ISymbolicExpressionTreeNode, int>();

      foreach (var tree in trees) {
        foreach (var node in tree.Root.IterateNodesPrefix().Where(n => IsConstantNode(n))) {
          nodeIdx.Add(node, nodeIdx.Count);
        }
      }

      var theta = nodeIdx.Select(_ => random.NextDouble() * 2.0 - 1.0).ToArray(); // init params randomly from Unif(-1,1)

      double[] optTheta = new double[0];
      if (theta.Length > 0) {
        alglib.minlbfgsstate state;
        alglib.minlbfgsreport report;
        alglib.minlbfgscreate(Math.Min(theta.Length, 5), theta, out state);
        alglib.minlbfgssetcond(state, 0.0, 0.0, 0.0, MaximumParameterOptimizationIterations);
        alglib.minlbfgsoptimize(state, EvaluateObjectiveAndGradient, null,
          new object[] { trees, targetVars, problemData, nodeIdx, targetValues, rows, NumericIntegrationSteps, latentVariables }); //TODO: create a type
        alglib.minlbfgsresults(state, out optTheta, out report);

        /*
         *
         *         L-BFGS algorithm results

          INPUT PARAMETERS:
              State   -   algorithm state

          OUTPUT PARAMETERS:
              X       -   array[0..N-1], solution
              Rep     -   optimization report:
                          * Rep.TerminationType completion code:
                              * -7    gradient verification failed.
                                      See MinLBFGSSetGradientCheck() for more information.
                              * -2    rounding errors prevent further improvement.
                                      X contains best point found.
                              * -1    incorrect parameters were specified
                              *  1    relative function improvement is no more than
                                      EpsF.
                              *  2    relative step is no more than EpsX.
                              *  4    gradient norm is no more than EpsG
                              *  5    MaxIts steps was taken
                              *  7    stopping conditions are too stringent,
                                      further improvement is impossible
                          * Rep.IterationsCount contains iterations count
                          * NFEV contains number of function calculations
         */
        // a negative termination type signals failure -> penalize with a huge quality
        if (report.terminationtype < 0) return double.MaxValue;
      }

      // perform evaluation for optimal theta to get quality value
      double[] grad = new double[optTheta.Length];
      double optQuality = double.NaN;
      EvaluateObjectiveAndGradient(optTheta, ref optQuality, grad,
        new object[] { trees, targetVars, problemData, nodeIdx, targetValues, rows, NumericIntegrationSteps, latentVariables });
      if (double.IsNaN(optQuality) || double.IsInfinity(optQuality)) return 10E6; // return a large value (TODO: be consistent by using NMSE)

      individual["OptTheta"] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
      return optQuality;
    }
294
295    private static void EvaluateObjectiveAndGradient(double[] x, ref double f, double[] grad, object obj) {
296      var trees = (ISymbolicExpressionTree[])((object[])obj)[0];
297      var targetVariables = (string[])((object[])obj)[1];
298      var problemData = (IRegressionProblemData)((object[])obj)[2];
299      var nodeIdx = (Dictionary<ISymbolicExpressionTreeNode, int>)((object[])obj)[3];
300      var targetValues = (double[,])((object[])obj)[4];
301      var rows = (int[])((object[])obj)[5];
302      var numericIntegrationSteps = (int)((object[])obj)[6];
303      var latentVariables = (string[])((object[])obj)[7];
304
305      var predicted = Integrate(
306        trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
307        problemData.Dataset,
308        problemData.AllowedInputVariables.ToArray(),
309        targetVariables,
310        latentVariables,
311        rows,
312        nodeIdx,                // TODO: is it Ok to use rows here ?
313        x, numericIntegrationSteps).ToArray();
314
315
316      // for normalized MSE = 1/variance(t) * MSE(t, pred)
317      // TODO: Perf. (by standardization of target variables before evaluation of all trees)
318      var invVar = Enumerable.Range(0, targetVariables.Length)
319        .Select(c => rows.Select(row => targetValues[row, c])) // colums vectors
320        .Select(vec => vec.Variance())
321        .Select(v => 1.0 / v)
322        .ToArray();
323
324      // objective function is NMSE
325      f = 0.0;
326      int n = predicted.Length;
327      double invN = 1.0 / n;
328      var g = Vector.Zero;
329      int r = 0;
330      foreach (var y_pred in predicted) {
331        for (int c = 0; c < y_pred.Length; c++) {
332
333          var y_pred_f = y_pred[c].Item1;
334          var y = targetValues[r, c];
335
336          var res = (y - y_pred_f);
337          var ressq = res * res;
338          f += ressq * invN * invVar[c];
339          g += -2.0 * res * y_pred[c].Item2 * invN * invVar[c];
340        }
341        r++;
342      }
343
344      g.CopyTo(grad);
345    }
346
    // Produces result entries for the best individual of the current generation:
    // actual-vs-predicted charts for training and test partitions and the models
    // with the optimized θ values substituted back into the trees.
    public override void Analyze(Individual[] individuals, double[] qualities, ResultCollection results, IRandom random) {
      base.Analyze(individuals, qualities, results, random);

      // lazily create the result slots on first call
      if (!results.ContainsKey("Prediction (training)")) {
        results.Add(new Result("Prediction (training)", typeof(ReadOnlyItemList<DataTable>)));
      }
      if (!results.ContainsKey("Prediction (test)")) {
        results.Add(new Result("Prediction (test)", typeof(ReadOnlyItemList<DataTable>)));
      }
      if (!results.ContainsKey("Models")) {
        results.Add(new Result("Models", typeof(ReadOnlyItemList<ISymbolicExpressionTree>)));
      }

      // TODO extract common functionality from Evaluate and Analyze
      var bestIndividualAndQuality = this.GetBestIndividual(individuals, qualities);
      var optTheta = ((DoubleArray)bestIndividualAndQuality.Item1["OptTheta"]).ToArray(); // see evaluate
      var trees = bestIndividualAndQuality.Item1.Values.Select(v => v.Value).OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual
      var nodeIdx = new Dictionary<ISymbolicExpressionTreeNode, int>();


      // rebuild the θ-node -> parameter-index mapping exactly as in Evaluate
      // (prefix order over all trees) so optTheta lines up with the nodes
      foreach (var tree in trees) {
        foreach (var node in tree.Root.IterateNodesPrefix().Where(n => IsConstantNode(n))) {
          nodeIdx.Add(node, nodeIdx.Count);
        }
      }
      var problemData = ProblemData;
      var targetVars = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
      var latentVariables = Enumerable.Range(1, NumberOfLatentVariables).Select(i => "λ" + i).ToArray(); // TODO: must coincide with the variables which are actually defined in the grammar and also for which we actually have trees

      var trainingList = new ItemList<DataTable>();
      var trainingRows = ProblemData.TrainingIndices.ToArray();
      var trainingPrediction = Integrate(
       trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
       problemData.Dataset,
       problemData.AllowedInputVariables.ToArray(),
       targetVars,
       latentVariables,
       trainingRows,
       nodeIdx,
       optTheta,
       NumericIntegrationSteps).ToArray();

      // one chart per target variable: actual values and integrated predictions
      for (int colIdx = 0; colIdx < targetVars.Length; colIdx++) {
        var targetVar = targetVars[colIdx];
        var trainingDataTable = new DataTable(targetVar + " prediction (training)");
        var actualValuesRow = new DataRow(targetVar, "The values of " + targetVar, problemData.Dataset.GetDoubleValues(targetVar, trainingRows));
        var predictedValuesRow = new DataRow(targetVar + " pred.", "Predicted values for " + targetVar, trainingPrediction.Select(arr => arr[colIdx].Item1).ToArray());
        trainingDataTable.Rows.Add(actualValuesRow);
        trainingDataTable.Rows.Add(predictedValuesRow);
        trainingList.Add(trainingDataTable);
      }

      // TODO: DRY for training and test
      var testList = new ItemList<DataTable>();
      var testRows = ProblemData.TestIndices.ToArray();
      var testPrediction = Integrate(
       trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
       problemData.Dataset,
       problemData.AllowedInputVariables.ToArray(),
       targetVars,
       latentVariables,
       testRows,
       nodeIdx,
       optTheta,
       NumericIntegrationSteps).ToArray();

      for (int colIdx = 0; colIdx < targetVars.Length; colIdx++) {
        var targetVar = targetVars[colIdx];
        var testDataTable = new DataTable(targetVar + " prediction (test)");
        var actualValuesRow = new DataRow(targetVar, "The values of " + targetVar, problemData.Dataset.GetDoubleValues(targetVar, testRows));
        var predictedValuesRow = new DataRow(targetVar + " pred.", "Predicted values for " + targetVar, testPrediction.Select(arr => arr[colIdx].Item1).ToArray());
        testDataTable.Rows.Add(actualValuesRow);
        testDataTable.Rows.Add(predictedValuesRow);
        testList.Add(testDataTable);
      }

      results["Prediction (training)"].Value = trainingList.AsReadOnly();
      results["Prediction (test)"].Value = testList.AsReadOnly();

      // clone each tree and replace the θ placeholder nodes by constant nodes
      // holding the optimized values so the result view shows concrete numbers
      var modelList = new ItemList<ISymbolicExpressionTree>();
      foreach (var tree in trees) {
        var shownTree = (ISymbolicExpressionTree)tree.Clone();
        // NOTE(review): assumes tree.IterateNodesPrefix() visits the θ nodes in
        // the same order as tree.Root.IterateNodesPrefix() above — confirm.
        var constantsNodeOrig = tree.IterateNodesPrefix().Where(IsConstantNode);
        var constantsNodeShown = shownTree.IterateNodesPrefix().Where(IsConstantNode);

        foreach (var n in constantsNodeOrig.Zip(constantsNodeShown, (original, shown) => new { original, shown })) {
          double constantsVal = optTheta[nodeIdx[n.original]];

          ConstantTreeNode replacementNode = new ConstantTreeNode(new Constant()) { Value = constantsVal };

          var parentNode = n.shown.Parent;
          int replacementIndex = parentNode.IndexOfSubtree(n.shown);
          parentNode.RemoveSubtree(replacementIndex);
          parentNode.InsertSubtree(replacementIndex, replacementNode);
        }
        // TODO: simplify trees

        modelList.Add(shownTree);
      }
      results["Models"].Value = modelList.AsReadOnly();
    }
448
449
450    #region interpretation
    // Numerically integrates the ODE system defined by the trees (explicit Euler
    // with numericIntegrationSteps sub-steps per dataset row) and yields, for
    // every row, a tuple (value, gradient w.r.t. θ) per target variable.
    // NOTE(review): 'rows' is enumerated multiple times (First() and Skip(1));
    // callers currently pass arrays, so this is safe — confirm if that changes.
    private static IEnumerable<Tuple<double, Vector>[]> Integrate(
      ISymbolicExpressionTree[] trees, IDataset dataset, string[] inputVariables, string[] targetVariables, string[] latentVariables, IEnumerable<int> rows,
      Dictionary<ISymbolicExpressionTreeNode, int> nodeIdx, double[] parameterValues, int numericIntegrationSteps = 100) {

      int NUM_STEPS = numericIntegrationSteps;
      double h = 1.0 / NUM_STEPS; // Euler step width within one row interval

      // return first value as stored in the dataset
      yield return targetVariables
        .Select(targetVar => Tuple.Create(dataset.GetDoubleValue(targetVar, rows.First()), Vector.Zero))
        .ToArray();

      // integrate forward starting with known values for the target in t0

      var variableValues = new Dictionary<string, Tuple<double, Vector>>();
      var t0 = rows.First();
      foreach (var varName in inputVariables) {
        variableValues.Add(varName, Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero));
      }
      foreach (var varName in targetVariables) {
        variableValues.Add(varName, Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero));
      }
      // add value entries for latent variables which are also integrated
      foreach (var latentVar in latentVariables) {
        variableValues.Add(latentVar, Tuple.Create(0.0, Vector.Zero)); // we don't have observations for latent variables -> assume zero as starting value
      }
      var calculatedVariables = targetVariables.Concat(latentVariables); // TODO: must conincide with the order of trees in the encoding

      foreach (var t in rows.Skip(1)) {
        for (int step = 0; step < NUM_STEPS; step++) {
          // evaluate all derivatives first, then apply them, so that the trees
          // all see the variable values of the same (previous) sub-step
          var deltaValues = new Dictionary<string, Tuple<double, Vector>>();
          foreach (var tup in trees.Zip(calculatedVariables, Tuple.Create)) {
            var tree = tup.Item1;
            var targetVarName = tup.Item2;
            // skip programRoot and startSymbol
            var res = InterpretRec(tree.Root.GetSubtree(0).GetSubtree(0), variableValues, nodeIdx, parameterValues);
            deltaValues.Add(targetVarName, res);
          }

          // update variableValues for next step (Euler: x += h * x')
          foreach (var kvp in deltaValues) {
            var oldVal = variableValues[kvp.Key];
            variableValues[kvp.Key] = Tuple.Create(
              oldVal.Item1 + h * kvp.Value.Item1,
              oldVal.Item2 + h * kvp.Value.Item2
            );
          }
        }

        // only return the target variables for calculation of errors
        yield return targetVariables
          .Select(targetVar => variableValues[targetVar])
          .ToArray();

        // update for next time step: refresh the observed input variables;
        // target and latent variables keep their integrated values
        foreach (var varName in inputVariables) {
          variableValues[varName] = Tuple.Create(dataset.GetDoubleValue(varName, t), Vector.Zero);
        }
      }
    }
511
512    private static Tuple<double, Vector> InterpretRec(
513      ISymbolicExpressionTreeNode node,
514      Dictionary<string, Tuple<double, Vector>> variableValues,
515      Dictionary<ISymbolicExpressionTreeNode, int> nodeIdx,
516      double[] parameterValues
517        ) {
518
519      switch (node.Symbol.Name) {
520        case "+": {
521            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues); // TODO capture all parameters into a state type for interpretation
522            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);
523
524            return Tuple.Create(l.Item1 + r.Item1, l.Item2 + r.Item2);
525          }
526        case "*": {
527            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
528            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);
529
530            return Tuple.Create(l.Item1 * r.Item1, l.Item2 * r.Item1 + l.Item1 * r.Item2);
531          }
532
533        case "-": {
534            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
535            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);
536
537            return Tuple.Create(l.Item1 - r.Item1, l.Item2 - r.Item2);
538          }
539        case "%": {
540            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
541            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);
542
543            // protected division
544            if (r.Item1.IsAlmost(0.0)) {
545              return Tuple.Create(0.0, Vector.Zero);
546            } else {
547              return Tuple.Create(
548                l.Item1 / r.Item1,
549                l.Item1 * -1.0 / (r.Item1 * r.Item1) * r.Item2 + 1.0 / r.Item1 * l.Item2 // (f/g)' = f * (1/g)' + 1/g * f' = f * -1/g² * g' + 1/g * f'
550                );
551            }
552          }
553        default: {
554            // distinguish other cases
555            if (IsConstantNode(node)) {
556              var vArr = new double[parameterValues.Length]; // backing array for vector
557              vArr[nodeIdx[node]] = 1.0;
558              var g = new Vector(vArr);
559              return Tuple.Create(parameterValues[nodeIdx[node]], g);
560            } else {
561              // assume a variable name
562              var varName = node.Symbol.Name;
563              return variableValues[varName];
564            }
565          }
566      }
567    }
568    #endregion
569
570    #region events
571    /*
572     * Dependencies between parameters:
573     *
574     * ProblemData
575     *    |
576     *    V
577     * TargetVariables   FunctionSet    MaximumLength    NumberOfLatentVariables
578     *               |   |                 |                   |
579     *               V   V                 |                   |
580     *             Grammar <---------------+-------------------
581     *                |
582     *                V
583     *            Encoding
584     */
    // Wires up the parameter dependency chain shown in the diagram above so
    // that grammar and encoding stay in sync with the user-editable settings.
    private void RegisterEventHandlers() {
      ProblemDataParameter.ValueChanged += ProblemDataParameter_ValueChanged;
      if (ProblemDataParameter.Value != null) ProblemDataParameter.Value.Changed += ProblemData_Changed;

      TargetVariablesParameter.ValueChanged += TargetVariablesParameter_ValueChanged;
      if (TargetVariablesParameter.Value != null) TargetVariablesParameter.Value.CheckedItemsChanged += CheckedTargetVariablesChanged;

      FunctionSetParameter.ValueChanged += FunctionSetParameter_ValueChanged;
      if (FunctionSetParameter.Value != null) FunctionSetParameter.Value.CheckedItemsChanged += CheckedFunctionsChanged;

      MaximumLengthParameter.Value.ValueChanged += MaximumLengthChanged;

      NumberOfLatentVariablesParameter.Value.ValueChanged += NumLatentVariablesChanged;
    }

    private void NumLatentVariablesChanged(object sender, EventArgs e) {
      UpdateGrammarAndEncoding();
    }

    private void MaximumLengthChanged(object sender, EventArgs e) {
      UpdateGrammarAndEncoding();
    }

    private void FunctionSetParameter_ValueChanged(object sender, EventArgs e) {
      // NOTE(review): subscribes to the new collection but never unsubscribes
      // from the previously assigned one, and does not refresh the grammar for
      // the replaced collection itself — confirm whether that is intended.
      FunctionSetParameter.Value.CheckedItemsChanged += CheckedFunctionsChanged;
    }

    private void CheckedFunctionsChanged(object sender, CollectionItemsChangedEventArgs<StringValue> e) {
      UpdateGrammarAndEncoding();
    }

    private void TargetVariablesParameter_ValueChanged(object sender, EventArgs e) {
      // NOTE(review): same re-subscription pattern as for the function set; the
      // old collection's handler is never detached — potential leak, confirm.
      TargetVariablesParameter.Value.CheckedItemsChanged += CheckedTargetVariablesChanged;
    }

    private void CheckedTargetVariablesChanged(object sender, CollectionItemsChangedEventArgs<StringValue> e) {
      UpdateGrammarAndEncoding();
    }

    private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) {
      // attach to the new problem data instance and propagate the change
      ProblemDataParameter.Value.Changed += ProblemData_Changed;
      OnProblemDataChanged();
      OnReset();
    }

    private void ProblemData_Changed(object sender, EventArgs e) {
      OnProblemDataChanged();
      OnReset();
    }

    private void OnProblemDataChanged() {
      UpdateTargetVariables();        // implicitly updates other dependent parameters
      UpdateGrammarAndEncoding();
      var handler = ProblemDataChanged;
      if (handler != null) handler(this, EventArgs.Empty);
    }
641
642    #endregion
643
644    #region  helper
645
    // Initializes all dependent parameters once after construction.
    private void InitAllParameters() {
      UpdateTargetVariables(); // implicitly updates the grammar and the encoding
    }
649
650    private ReadOnlyCheckedItemCollection<StringValue> CreateFunctionSet() {
651      var l = new CheckedItemCollection<StringValue>();
652      l.Add(new StringValue("+").AsReadOnly());
653      l.Add(new StringValue("*").AsReadOnly());
654      l.Add(new StringValue("%").AsReadOnly());
655      l.Add(new StringValue("-").AsReadOnly());
656      return l.AsReadOnly();
657    }
658
    // θ-prefixed terminal symbols are numeric parameters optimized via AutoDiff.
    private static bool IsConstantNode(ISymbolicExpressionTreeNode n) {
      return n.Symbol.Name.StartsWith("θ");
    }
    // λ-prefixed terminal symbols denote latent (unobserved) variables.
    private static bool IsLatentVariableNode(ISymbolicExpressionTreeNode n) {
      return n.Symbol.Name.StartsWith("λ");
    }
665
666
667    private void UpdateTargetVariables() {
668      var currentlySelectedVariables = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
669
670      var newVariablesList = new CheckedItemCollection<StringValue>(ProblemData.Dataset.VariableNames.Select(str => new StringValue(str).AsReadOnly()).ToArray()).AsReadOnly();
671      var matchingItems = newVariablesList.Where(item => currentlySelectedVariables.Contains(item.Value)).ToArray();
672      foreach (var matchingItem in matchingItems) {
673        newVariablesList.SetItemCheckedState(matchingItem, true);
674      }
675      TargetVariablesParameter.Value = newVariablesList;
676    }
677
678    private void UpdateGrammarAndEncoding() {
679      var encoding = new MultiEncoding();
680      var g = CreateGrammar();
681      foreach (var targetVar in TargetVariables.CheckedItems) {
682        encoding = encoding.Add(new SymbolicExpressionTreeEncoding(targetVar + "_tree", g, MaximumLength, MaximumLength)); // only limit by length
683      }
684      for (int i = 1; i <= NumberOfLatentVariables; i++) {
685        encoding = encoding.Add(new SymbolicExpressionTreeEncoding("λ" + i + "_tree", g, MaximumLength, MaximumLength));
686      }
687      Encoding = encoding;
688    }
689
    // Builds the grammar from the checked function set, the allowed input and
    // target variables, θ parameter placeholders and λ latent variables.
    private ISymbolicExpressionGrammar CreateGrammar() {
      // whenever ProblemData is changed we create a new grammar with the necessary symbols
      var g = new SimpleSymbolicExpressionGrammar();
      // all selected functions are binary (arity 2)
      g.AddSymbols(FunctionSet.CheckedItems.Select(i => i.Value).ToArray(), 2, 2);

      // TODO
      //g.AddSymbols(new[] {
      //  "exp",
      //  "log", // log( <expr> ) // TODO: init a theta to ensure the value is always positive
      //  "exp_minus" // exp((-1) * <expr>
      //}, 1, 1);

      foreach (var variableName in ProblemData.AllowedInputVariables.Union(TargetVariables.CheckedItems.Select(i => i.Value)))
        g.AddTerminalSymbol(variableName);

      // generate symbols for numeric parameters for which the value is optimized using AutoDiff
      // we generate multiple symbols to balance the probability for selecting a numeric parameter in the generation of random trees
      var numericConstantsFactor = 2.0;
      for (int i = 0; i < numericConstantsFactor * (ProblemData.AllowedInputVariables.Count() + TargetVariables.CheckedItems.Count()); i++) {
        g.AddTerminalSymbol("θ" + i); // numeric parameter for which the value is optimized using AutoDiff
      }

      // generate symbols for latent variables
      for (int i = 1; i <= NumberOfLatentVariables; i++) {
        g.AddTerminalSymbol("λ" + i); // latent variable, integrated like a target but without observations
      }

      return g;
    }
719
720    #endregion
721
    #region Import & Export
    // IProblemInstanceConsumer: adopt an imported regression problem instance.
    public void Load(IRegressionProblemData data) {
      Name = data.Name;
      Description = data.Description;
      ProblemData = data;
    }

    // IProblemInstanceExporter: expose the current problem data for export.
    public IRegressionProblemData Export() {
      return ProblemData;
    }
    #endregion
733
734  }
735}
Note: See TracBrowser for help on using the repository browser.