Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Problem.cs @ 16153

Last change on this file since 16153 was 16153, checked in by gkronber, 6 years ago

#2925 added support for multiple training episodes, added simplification of models, fixed a bug in the normalization based on target variable variance

File size: 36.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Diagnostics;
25using System.Linq;
26using HeuristicLab.Analysis;
27using HeuristicLab.Collections;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Optimization;
33using HeuristicLab.Parameters;
34using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Problems.DataAnalysis.Symbolic;
37using HeuristicLab.Problems.Instances;
38using Variable = HeuristicLab.Problems.DataAnalysis.Symbolic.Variable;
39
40namespace HeuristicLab.Problems.DynamicalSystemsModelling {
41  public class Vector {
42    public readonly static Vector Zero = new Vector(new double[0]);
43
44    public static Vector operator +(Vector a, Vector b) {
45      if(a == Zero) return b;
46      if(b == Zero) return a;
47      Debug.Assert(a.arr.Length == b.arr.Length);
48      var res = new double[a.arr.Length];
49      for(int i = 0; i < res.Length; i++)
50        res[i] = a.arr[i] + b.arr[i];
51      return new Vector(res);
52    }
53    public static Vector operator -(Vector a, Vector b) {
54      if(b == Zero) return a;
55      if(a == Zero) return -b;
56      Debug.Assert(a.arr.Length == b.arr.Length);
57      var res = new double[a.arr.Length];
58      for(int i = 0; i < res.Length; i++)
59        res[i] = a.arr[i] - b.arr[i];
60      return new Vector(res);
61    }
62    public static Vector operator -(Vector v) {
63      if(v == Zero) return Zero;
64      for(int i = 0; i < v.arr.Length; i++)
65        v.arr[i] = -v.arr[i];
66      return v;
67    }
68
69    public static Vector operator *(double s, Vector v) {
70      if(v == Zero) return Zero;
71      if(s == 0.0) return Zero;
72      var res = new double[v.arr.Length];
73      for(int i = 0; i < res.Length; i++)
74        res[i] = s * v.arr[i];
75      return new Vector(res);
76    }
77    public static Vector operator *(Vector v, double s) {
78      return s * v;
79    }
80    public static Vector operator /(double s, Vector v) {
81      if(s == 0.0) return Zero;
82      if(v == Zero) throw new ArgumentException("Division by zero vector");
83      var res = new double[v.arr.Length];
84      for(int i = 0; i < res.Length; i++)
85        res[i] = 1.0 / v.arr[i];
86      return new Vector(res);
87    }
88    public static Vector operator /(Vector v, double s) {
89      return v * 1.0 / s;
90    }
91
92
93    private readonly double[] arr; // backing array;
94
95    public Vector(double[] v) {
96      this.arr = v;
97    }
98
99    public void CopyTo(double[] target) {
100      Debug.Assert(arr.Length <= target.Length);
101      Array.Copy(arr, target, arr.Length);
102    }
103  }
104
105  [Item("Dynamical Systems Modelling Problem", "TODO")]
106  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 900)]
107  [StorableClass]
108  public sealed class Problem : SingleObjectiveBasicProblem<MultiEncoding>, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
109
    #region parameter names
    // Keys for the entries in the Parameters collection; these strings are
    // also shown as parameter captions in the UI.
    private const string ProblemDataParameterName = "Data";
    private const string TargetVariablesParameterName = "Target variables";
    private const string FunctionSetParameterName = "Function set";
    private const string MaximumLengthParameterName = "Size limit";
    private const string MaximumParameterOptimizationIterationsParameterName = "Max. parameter optimization iterations";
    private const string NumberOfLatentVariablesParameterName = "Number of latent variables";
    private const string NumericIntegrationStepsParameterName = "Steps for numeric integration";
    private const string TrainingEpisodesParameterName = "Training episodes";
    #endregion
120
    #region Parameter Properties
    // Typed views onto the entries of the Parameters collection
    // (keyed by the *ParameterName constants).
    IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }

    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueParameter<ReadOnlyCheckedItemCollection<StringValue>> TargetVariablesParameter {
      get { return (IValueParameter<ReadOnlyCheckedItemCollection<StringValue>>)Parameters[TargetVariablesParameterName]; }
    }
    public IValueParameter<ReadOnlyCheckedItemCollection<StringValue>> FunctionSetParameter {
      get { return (IValueParameter<ReadOnlyCheckedItemCollection<StringValue>>)Parameters[FunctionSetParameterName]; }
    }
    public IFixedValueParameter<IntValue> MaximumLengthParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumLengthParameterName]; }
    }
    public IFixedValueParameter<IntValue> MaximumParameterOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumParameterOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> NumberOfLatentVariablesParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumberOfLatentVariablesParameterName]; }
    }
    public IFixedValueParameter<IntValue> NumericIntegrationStepsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[NumericIntegrationStepsParameterName]; }
    }
    public IValueParameter<ItemList<IntRange>> TrainingEpisodesParameter {
      get { return (IValueParameter<ItemList<IntRange>>)Parameters[TrainingEpisodesParameterName]; }
    }
    #endregion
149
    #region Properties
    // Convenience accessors that unwrap the current parameter values.
    public IRegressionProblemData ProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }

    // target variables selected by the user (overrides the target in ProblemData)
    public ReadOnlyCheckedItemCollection<StringValue> TargetVariables {
      get { return TargetVariablesParameter.Value; }
    }

    // the set of function symbols allowed in the expressions
    public ReadOnlyCheckedItemCollection<StringValue> FunctionSet {
      get { return FunctionSetParameter.Value; }
    }

    // maximum number of nodes per expression tree
    public int MaximumLength {
      get { return MaximumLengthParameter.Value.Value; }
    }
    public int MaximumParameterOptimizationIterations {
      get { return MaximumParameterOptimizationIterationsParameter.Value.Value; }
    }
    public int NumberOfLatentVariables {
      get { return NumberOfLatentVariablesParameter.Value.Value; }
    }
    public int NumericIntegrationSteps {
      get { return NumericIntegrationStepsParameter.Value.Value; }
    }
    // independent training episodes (overrides the TrainingSet in ProblemData)
    public IEnumerable<IntRange> TrainingEpisodes {
      get { return TrainingEpisodesParameter.Value; }
    }

    #endregion
182
    // Raised when the problem data is replaced (IDataAnalysisProblem contract).
    // NOTE(review): the code that raises this event is not visible in this part
    // of the file — confirm it is fired from the ProblemData change handlers.
    public event EventHandler ProblemDataChanged;

    public override bool Maximization {
      get { return false; } // we minimize NMSE
    }
188
    #region item cloning and persistence
    // persistence
    [StorableConstructor]
    private Problem(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // event handlers are not persisted -> re-register after loading
      RegisterEventHandlers();
    }

    // cloning
    private Problem(Problem original, Cloner cloner)
      : base(original, cloner) {
      // re-register handlers on the cloned parameter objects
      RegisterEventHandlers();
    }
    public override IDeepCloneable Clone(Cloner cloner) { return new Problem(this, cloner); }
    #endregion
205
    /// <summary>
    /// Creates the problem with default parameter values and registers the
    /// event handlers that keep grammar and encoding in sync with the
    /// parameter settings.
    /// </summary>
    public Problem()
      : base() {
      var targetVariables = new CheckedItemCollection<StringValue>().AsReadOnly(); // HACK: it would be better to provide a new class derived from IDataAnalysisProblem
      var functions = CreateFunctionSet();
      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data captured from the dynamical system. Use CSV import functionality to import data.", new RegressionProblemData()));
      Parameters.Add(new ValueParameter<ReadOnlyCheckedItemCollection<StringValue>>(TargetVariablesParameterName, "Target variables (overrides setting in ProblemData)", targetVariables));
      Parameters.Add(new ValueParameter<ReadOnlyCheckedItemCollection<StringValue>>(FunctionSetParameterName, "The list of allowed functions", functions));
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumLengthParameterName, "The maximally allowed length of each expression. Set to a small value (5 - 25). Default = 10", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumParameterOptimizationIterationsParameterName, "The maximum number of iterations for optimization of parameters (using L-BFGS). More iterations makes the algorithm slower, fewer iterations might prevent convergence in the optimization scheme. Default = 100", new IntValue(100)));
      Parameters.Add(new FixedValueParameter<IntValue>(NumberOfLatentVariablesParameterName, "Latent variables (unobserved variables) allow us to produce expressions which are integrated up and can be used in other expressions. They are handled similarly to target variables in forward simulation / integration. The difference to target variables is that there are no data to which the calculated values of latent variables are compared. Set to a small value (0 .. 5) as necessary (default = 0)", new IntValue(0)));
      Parameters.Add(new FixedValueParameter<IntValue>(NumericIntegrationStepsParameterName, "Number of steps in the numeric integration that are taken from one row to the next (set to 1 to 100). More steps makes the algorithm slower, less steps worsens the accuracy of the numeric integration scheme.", new IntValue(10)));
      Parameters.Add(new ValueParameter<ItemList<IntRange>>(TrainingEpisodesParameterName, "A list of ranges that should be used for training, each range represents an independent episode. This overrides the TrainingSet parameter in ProblemData.", new ItemList<IntRange>()));

      RegisterEventHandlers();
      InitAllParameters();

      // TODO: do not clear selection of target variables when the input variables are changed (keep selected target variables)
      // TODO: UI hangs when selecting / deselecting input variables because the encoding is updated on each item

    }
226
227
    /// <summary>
    /// Evaluates an individual: extracts one expression tree per calculated
    /// variable, jointly optimizes the numeric parameters (constant placeholder
    /// nodes) of all trees with L-BFGS, and returns the achieved normalized MSE
    /// over the training episodes (minimized, see Maximization). The optimized
    /// parameters are stored in individual["OptTheta"] for reuse in Analyze.
    /// </summary>
    public override double Evaluate(Individual individual, IRandom random) {
      var trees = individual.Values.Select(v => v.Value).OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual

      var problemData = ProblemData;
      var rows = ProblemData.TrainingIndices.ToArray();
      var targetVars = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
      var latentVariables = Enumerable.Range(1, NumberOfLatentVariables).Select(i => "λ" + i).ToArray(); // TODO: must coincide with the variables which are actually defined in the grammar and also for which we actually have trees
      var targetValues = new double[rows.Length, targetVars.Length];

      // collect values of all target variables (one column per target variable)
      var colIdx = 0;
      foreach(var targetVar in targetVars) {
        int rowIdx = 0;
        foreach(var value in problemData.Dataset.GetDoubleValues(targetVar, rows)) {
          targetValues[rowIdx, colIdx] = value;
          rowIdx++;
        }
        colIdx++;
      }

      // assign a parameter index to each constant placeholder node (prefix
      // order over all trees); this mapping also addresses the gradient
      // components during interpretation
      var nodeIdx = new Dictionary<ISymbolicExpressionTreeNode, int>();

      foreach(var tree in trees) {
        foreach(var node in tree.Root.IterateNodesPrefix().Where(n => IsConstantNode(n))) {
          nodeIdx.Add(node, nodeIdx.Count);
        }
      }

      var theta = nodeIdx.Select(_ => random.NextDouble() * 2.0 - 1.0).ToArray(); // init params randomly from Unif(-1,1)

      double[] optTheta = new double[0];
      if(theta.Length > 0) {
        // L-BFGS with at most min(n, 5) correction pairs; all epsilons are 0.0,
        // so only the iteration limit stops the optimization
        alglib.minlbfgsstate state;
        alglib.minlbfgsreport report;
        alglib.minlbfgscreate(Math.Min(theta.Length, 5), theta, out state);
        alglib.minlbfgssetcond(state, 0.0, 0.0, 0.0, MaximumParameterOptimizationIterations);
        alglib.minlbfgsoptimize(state, EvaluateObjectiveAndGradient, null,
          new object[] { trees, targetVars, problemData, nodeIdx, targetValues, TrainingEpisodes.ToArray(), NumericIntegrationSteps, latentVariables }); //TODO: create a type
        alglib.minlbfgsresults(state, out optTheta, out report);

        // Rep.TerminationType < 0 signals failures (-7 gradient verification
        // failed, -2 rounding errors prevent further improvement, -1 incorrect
        // parameters); positive codes are regular stops (1 EpsF, 2 EpsX,
        // 4 EpsG, 5 MaxIts reached, 7 stopping conditions too stringent).
        if(report.terminationtype < 0) return double.MaxValue;
      }

      // perform evaluation for optimal theta to get quality value
      double[] grad = new double[optTheta.Length];
      double optQuality = double.NaN;
      EvaluateObjectiveAndGradient(optTheta, ref optQuality, grad,
        new object[] { trees, targetVars, problemData, nodeIdx, targetValues, TrainingEpisodes.ToArray(), NumericIntegrationSteps, latentVariables });
      if(double.IsNaN(optQuality) || double.IsInfinity(optQuality)) return 10E6; // return a large value (TODO: be consistent by using NMSE)

      individual["OptTheta"] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
      return optQuality;
    }
307
    // Objective function and gradient callback for alglib.minlbfgsoptimize.
    //   x:    current parameter vector (one entry per constant placeholder node)
    //   f:    out - normalized MSE over all episodes and target variables
    //   grad: out - gradient df/dx (filled from the accumulated Vector)
    //   obj:  the object[] packed by Evaluate/Analyze; the element order below
    //         must stay in sync with those callers (see the TODO there)
    private static void EvaluateObjectiveAndGradient(double[] x, ref double f, double[] grad, object obj) {
      var trees = (ISymbolicExpressionTree[])((object[])obj)[0];
      var targetVariables = (string[])((object[])obj)[1];
      var problemData = (IRegressionProblemData)((object[])obj)[2];
      var nodeIdx = (Dictionary<ISymbolicExpressionTreeNode, int>)((object[])obj)[3];
      var targetValues = (double[,])((object[])obj)[4];
      var episodes = (IntRange[])((object[])obj)[5];
      var numericIntegrationSteps = (int)((object[])obj)[6];
      var latentVariables = (string[])((object[])obj)[7];

      // simulate the system forward over all episodes; each element holds one
      // (value, gradient) tuple per target variable
      var predicted = Integrate(
        trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
        problemData.Dataset,
        problemData.AllowedInputVariables.ToArray(),
        targetVariables,
        latentVariables,
        episodes,
        nodeIdx,
        x, numericIntegrationSteps).ToArray();


      // for normalized MSE = 1/variance(t) * MSE(t, pred)
      // TODO: Perf. (by standardization of target variables before evaluation of all trees)
      // NOTE(review): a constant target column has variance 0, which makes the
      // corresponding invVar entry infinite; Evaluate guards the final result
      // against NaN/Infinity but intermediate L-BFGS steps see the raw value
      var invVar = Enumerable.Range(0, targetVariables.Length)
        .Select(c => Enumerable.Range(0, targetValues.GetLength(0)).Select(row => targetValues[row, c])) // column vectors
        .Select(vec => vec.Variance())
        .Select(v => 1.0 / v)
        .ToArray();

      // objective function is NMSE; accumulate value and gradient side by side
      f = 0.0;
      int n = predicted.Length;
      double invN = 1.0 / n;
      var g = Vector.Zero;
      int r = 0;
      foreach(var y_pred in predicted) {
        for(int c = 0; c < y_pred.Length; c++) {

          var y_pred_f = y_pred[c].Item1; // predicted value
          var y = targetValues[r, c];     // observed value

          var res = (y - y_pred_f);
          var ressq = res * res;
          f += ressq * invN * invVar[c];
          // d/dθ (y - y_pred)² = -2 * (y - y_pred) * d y_pred / dθ
          g += -2.0 * res * y_pred[c].Item2 * invN * invVar[c];
        }
        r++;
      }

      g.CopyTo(grad);
    }
359
    /// <summary>
    /// Adds/updates result entries for the best individual: line charts of
    /// actual vs. predicted values for training and test, plus the original
    /// and simplified model tree per target variable.
    /// </summary>
    public override void Analyze(Individual[] individuals, double[] qualities, ResultCollection results, IRandom random) {
      base.Analyze(individuals, qualities, results, random);

      // create result slots on first call
      if(!results.ContainsKey("Prediction (training)")) {
        results.Add(new Result("Prediction (training)", typeof(ReadOnlyItemList<DataTable>)));
      }
      if(!results.ContainsKey("Prediction (test)")) {
        results.Add(new Result("Prediction (test)", typeof(ReadOnlyItemList<DataTable>)));
      }
      if(!results.ContainsKey("Models")) {
        results.Add(new Result("Models", typeof(VariableCollection)));
      }

      // TODO extract common functionality from Evaluate and Analyze
      var bestIndividualAndQuality = this.GetBestIndividual(individuals, qualities);
      var optTheta = ((DoubleArray)bestIndividualAndQuality.Item1["OptTheta"]).ToArray(); // see evaluate
      var trees = bestIndividualAndQuality.Item1.Values.Select(v => v.Value).OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual
      var nodeIdx = new Dictionary<ISymbolicExpressionTreeNode, int>();

      // rebuild the same constant-node -> parameter-index mapping as in Evaluate
      foreach(var tree in trees) {
        foreach(var node in tree.Root.IterateNodesPrefix().Where(n => IsConstantNode(n))) {
          nodeIdx.Add(node, nodeIdx.Count);
        }
      }
      var problemData = ProblemData;
      var targetVars = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
      var latentVariables = Enumerable.Range(1, NumberOfLatentVariables).Select(i => "λ" + i).ToArray(); // TODO: must coincide with the variables which are actually defined in the grammar and also for which we actually have trees

      var trainingList = new ItemList<DataTable>();
      var trainingPrediction = Integrate(
       trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
       problemData.Dataset,
       problemData.AllowedInputVariables.ToArray(),
       targetVars,
       latentVariables,
       TrainingEpisodes,
       nodeIdx,
       optTheta,
       NumericIntegrationSteps).ToArray();

      // only for actual target values
      var trainingRows = TrainingEpisodes.SelectMany(e => Enumerable.Range(e.Start, e.End - e.Start));
      for(int colIdx = 0; colIdx < targetVars.Length; colIdx++) {
        var targetVar = targetVars[colIdx];
        var trainingDataTable = new DataTable(targetVar + " prediction (training)");
        var actualValuesRow = new DataRow(targetVar, "The values of " + targetVar, problemData.Dataset.GetDoubleValues(targetVar, trainingRows));
        var predictedValuesRow = new DataRow(targetVar + " pred.", "Predicted values for " + targetVar, trainingPrediction.Select(arr => arr[colIdx].Item1).ToArray());
        trainingDataTable.Rows.Add(actualValuesRow);
        trainingDataTable.Rows.Add(predictedValuesRow);
        trainingList.Add(trainingDataTable);
      }

      // TODO: DRY for training and test
      var testList = new ItemList<DataTable>();
      var testRows = ProblemData.TestIndices.ToArray();
      // NOTE(review): actual values are read for TestIndices while predictions
      // are integrated over the TestPartition range — confirm these coincide
      var testPrediction = Integrate(
       trees,  // we assume trees contain expressions for the change of each target variable over time y'(t)
       problemData.Dataset,
       problemData.AllowedInputVariables.ToArray(),
       targetVars,
       latentVariables,
       new IntRange[] { ProblemData.TestPartition },
       nodeIdx,
       optTheta,
       NumericIntegrationSteps).ToArray();

      for(int colIdx = 0; colIdx < targetVars.Length; colIdx++) {
        var targetVar = targetVars[colIdx];
        var testDataTable = new DataTable(targetVar + " prediction (test)");
        var actualValuesRow = new DataRow(targetVar, "The values of " + targetVar, problemData.Dataset.GetDoubleValues(targetVar, testRows));
        var predictedValuesRow = new DataRow(targetVar + " pred.", "Predicted values for " + targetVar, testPrediction.Select(arr => arr[colIdx].Item1).ToArray());
        testDataTable.Rows.Add(actualValuesRow);
        testDataTable.Rows.Add(predictedValuesRow);
        testList.Add(testDataTable);
      }

      results["Prediction (training)"].Value = trainingList.AsReadOnly();
      results["Prediction (test)"].Value = testList.AsReadOnly();

      #region simplification of models
      // TODO the dependency of HeuristicLab.Problems.DataAnalysis.Symbolic is not ideal
      var models = new VariableCollection();    // to store target var names and original version of tree

      foreach(var tup in targetVars.Zip(trees, Tuple.Create)) {
        var targetVarName = tup.Item1;
        var tree = tup.Item2;

        // translate the custom grammar into standard symbolic-regression
        // symbols (constants replaced by their optimized values) so the
        // standard simplifier can process the tree
        int nextParIdx = 0;
        var shownTree = new SymbolicExpressionTree(TranslateTreeNode(tree.Root, optTheta, ref nextParIdx));

        var origTreeVar = new HeuristicLab.Core.Variable(targetVarName + "(original)");
        origTreeVar.Value = (ISymbolicExpressionTree)tree.Clone();
        models.Add(origTreeVar);
        var simplifiedTreeVar = new HeuristicLab.Core.Variable(targetVarName + "(simplified)");
        simplifiedTreeVar.Value = TreeSimplifier.Simplify(shownTree);
        models.Add(simplifiedTreeVar);

      }
      results["Models"].Value = models;
      #endregion
    }
478
479    private ISymbolicExpressionTreeNode TranslateTreeNode(ISymbolicExpressionTreeNode n, double[] parameterValues, ref int nextParIdx) {
480      ISymbolicExpressionTreeNode translatedNode = null;
481      if(n.Symbol is StartSymbol) {
482        translatedNode = new StartSymbol().CreateTreeNode();
483      } else if(n.Symbol is ProgramRootSymbol) {
484        translatedNode = new ProgramRootSymbol().CreateTreeNode();
485      } else if(n.Symbol.Name == "+") {
486        translatedNode = new Addition().CreateTreeNode();
487      } else if(n.Symbol.Name == "-") {
488        translatedNode = new Subtraction().CreateTreeNode();
489      } else if(n.Symbol.Name == "*") {
490        translatedNode = new Multiplication().CreateTreeNode();
491      } else if(n.Symbol.Name == "%") {
492        translatedNode = new Division().CreateTreeNode();
493      } else if(IsConstantNode(n)) {
494        var constNode = (ConstantTreeNode)new Constant().CreateTreeNode();
495        constNode.Value = parameterValues[nextParIdx];
496        nextParIdx++;
497        translatedNode = constNode;
498      } else {
499        // assume a variable name
500        var varName = n.Symbol.Name;
501        var varNode = (VariableTreeNode)new Variable().CreateTreeNode();
502        varNode.Weight = 1.0;
503        varNode.VariableName = varName;
504        translatedNode = varNode;
505      }
506      foreach(var child in n.Subtrees) {
507        translatedNode.AddSubtree(TranslateTreeNode(child, parameterValues, ref nextParIdx));
508      }
509      return translatedNode;
510    }
511
    #region interpretation
    // Simulates the dynamical system described by the trees using explicit
    // (forward) Euler integration with numericIntegrationSteps sub-steps per
    // dataset row (h = 1/steps, i.e. one row corresponds to one time unit).
    // For each episode the observed values of the first row are yielded first,
    // followed by one element per subsequent row; each element contains one
    // (predicted value, gradient w.r.t. parameterValues) tuple per target
    // variable.
    private static IEnumerable<Tuple<double, Vector>[]> Integrate(
      ISymbolicExpressionTree[] trees, IDataset dataset, string[] inputVariables, string[] targetVariables, string[] latentVariables, IEnumerable<IntRange> episodes,
      Dictionary<ISymbolicExpressionTreeNode, int> nodeIdx, double[] parameterValues, int numericIntegrationSteps = 100) {

      int NUM_STEPS = numericIntegrationSteps;
      double h = 1.0 / NUM_STEPS;

      foreach(var episode in episodes) {
        var rows = Enumerable.Range(episode.Start, episode.End - episode.Start);
        // return first value as stored in the dataset (zero gradient: the
        // initial value does not depend on the parameters)
        yield return targetVariables
          .Select(targetVar => Tuple.Create(dataset.GetDoubleValue(targetVar, rows.First()), Vector.Zero))
          .ToArray();

        // integrate forward starting with known values for the target in t0

        var variableValues = new Dictionary<string, Tuple<double, Vector>>();
        var t0 = rows.First();
        foreach(var varName in inputVariables) {
          variableValues.Add(varName, Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero));
        }
        foreach(var varName in targetVariables) {
          variableValues.Add(varName, Tuple.Create(dataset.GetDoubleValue(varName, t0), Vector.Zero));
        }
        // add value entries for latent variables which are also integrated
        foreach(var latentVar in latentVariables) {
          variableValues.Add(latentVar, Tuple.Create(0.0, Vector.Zero)); // we don't have observations for latent variables -> assume zero as starting value
        }
        var calculatedVariables = targetVariables.Concat(latentVariables); // TODO: must conincide with the order of trees in the encoding

        foreach(var t in rows.Skip(1)) {
          for(int step = 0; step < NUM_STEPS; step++) {
            // evaluate each tree to get the change (dy/dt) of its variable
            var deltaValues = new Dictionary<string, Tuple<double, Vector>>();
            foreach(var tup in trees.Zip(calculatedVariables, Tuple.Create)) {
              var tree = tup.Item1;
              var targetVarName = tup.Item2;
              // skip programRoot and startSymbol
              var res = InterpretRec(tree.Root.GetSubtree(0).GetSubtree(0), variableValues, nodeIdx, parameterValues);
              deltaValues.Add(targetVarName, res);
            }

            // update variableValues for next step (Euler step of size h for
            // both the value and its gradient)
            foreach(var kvp in deltaValues) {
              var oldVal = variableValues[kvp.Key];
              variableValues[kvp.Key] = Tuple.Create(
                oldVal.Item1 + h * kvp.Value.Item1,
                oldVal.Item2 + h * kvp.Value.Item2
              );
            }
          }

          // only return the target variables for calculation of errors
          yield return targetVariables
            .Select(targetVar => variableValues[targetVar])
            .ToArray();

          // update for next time step; the inputs of row t are assigned only
          // after integrating towards t, i.e. the integration from t-1 to t
          // uses the input values of row t-1
          foreach(var varName in inputVariables) {
            variableValues[varName] = Tuple.Create(dataset.GetDoubleValue(varName, t), Vector.Zero);
          }
        }
      }
    }
576
    // Recursively interprets a tree while simultaneously propagating partial
    // derivatives w.r.t. the numeric parameters (forward-mode automatic
    // differentiation). Returns a tuple (value, gradient of the value w.r.t.
    // parameterValues).
    private static Tuple<double, Vector> InterpretRec(
      ISymbolicExpressionTreeNode node,
      Dictionary<string, Tuple<double, Vector>> variableValues,  // current (value, gradient) per variable name
      Dictionary<ISymbolicExpressionTreeNode, int> nodeIdx,      // constant placeholder node -> parameter index
      double[] parameterValues
        ) {

      switch(node.Symbol.Name) {
        case "+": {
            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues); // TODO capture all parameters into a state type for interpretation
            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);

            return Tuple.Create(l.Item1 + r.Item1, l.Item2 + r.Item2);
          }
        case "*": {
            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);

            // product rule: (f*g)' = f'*g + f*g'
            return Tuple.Create(l.Item1 * r.Item1, l.Item2 * r.Item1 + l.Item1 * r.Item2);
          }

        case "-": {
            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);

            return Tuple.Create(l.Item1 - r.Item1, l.Item2 - r.Item2);
          }
        case "%": {
            var l = InterpretRec(node.GetSubtree(0), variableValues, nodeIdx, parameterValues);
            var r = InterpretRec(node.GetSubtree(1), variableValues, nodeIdx, parameterValues);

            // protected division: (almost) zero denominator yields 0 with zero gradient
            if(r.Item1.IsAlmost(0.0)) {
              return Tuple.Create(0.0, Vector.Zero);
            } else {
              return Tuple.Create(
                l.Item1 / r.Item1,
                l.Item1 * -1.0 / (r.Item1 * r.Item1) * r.Item2 + 1.0 / r.Item1 * l.Item2 // (f/g)' = f * (1/g)' + 1/g * f' = f * -1/g² * g' + 1/g * f'
                );
            }
          }
        default: {
            // distinguish other cases
            if(IsConstantNode(node)) {
              // parameter leaf: value from parameterValues, gradient is the
              // unit vector that seeds this parameter's derivative
              var vArr = new double[parameterValues.Length]; // backing array for vector
              vArr[nodeIdx[node]] = 1.0;
              var g = new Vector(vArr);
              return Tuple.Create(parameterValues[nodeIdx[node]], g);
            } else {
              // assume a variable name; (value, gradient) tracked by the caller
              var varName = node.Symbol.Name;
              return variableValues[varName];
            }
          }
      }
    }
633    #endregion
634
635    #region events
636    /*
637     * Dependencies between parameters:
638     *
639     * ProblemData
640     *    |
641     *    V
642     * TargetVariables   FunctionSet    MaximumLength    NumberOfLatentVariables
643     *               |   |                 |                   |
644     *               V   V                 |                   |
645     *             Grammar <---------------+-------------------
646     *                |
647     *                V
648     *            Encoding
649     */
    // Wires up the dependency chain documented in the diagram above:
    // a change in ProblemData / TargetVariables / FunctionSet / MaximumLength /
    // NumberOfLatentVariables eventually triggers a grammar + encoding rebuild.
    // NOTE(review): handlers are only added, never removed — assumes this is
    // called exactly once per instance (e.g. after construction/cloning/deserialization).
    private void RegisterEventHandlers() {
      ProblemDataParameter.ValueChanged += ProblemDataParameter_ValueChanged;
      if(ProblemDataParameter.Value != null) ProblemDataParameter.Value.Changed += ProblemData_Changed;

      TargetVariablesParameter.ValueChanged += TargetVariablesParameter_ValueChanged;
      if(TargetVariablesParameter.Value != null) TargetVariablesParameter.Value.CheckedItemsChanged += CheckedTargetVariablesChanged;

      FunctionSetParameter.ValueChanged += FunctionSetParameter_ValueChanged;
      if(FunctionSetParameter.Value != null) FunctionSetParameter.Value.CheckedItemsChanged += CheckedFunctionsChanged;

      // no null guard here: the length and latent-variable parameters are
      // presumably always initialized with a value — TODO confirm in ctor
      MaximumLengthParameter.Value.ValueChanged += MaximumLengthChanged;

      NumberOfLatentVariablesParameter.Value.ValueChanged += NumLatentVariablesChanged;
    }
664
    // A different number of latent variables changes both the terminal symbols
    // (λ1..λn) and the number of trees in the encoding — rebuild both.
    private void NumLatentVariablesChanged(object sender, EventArgs e) {
      UpdateGrammarAndEncoding();
    }
668
    // The maximum length is passed to each SymbolicExpressionTreeEncoding,
    // so the encoding must be recreated when it changes.
    private void MaximumLengthChanged(object sender, EventArgs e) {
      UpdateGrammarAndEncoding();
    }
672
673    private void FunctionSetParameter_ValueChanged(object sender, EventArgs e) {
674      FunctionSetParameter.Value.CheckedItemsChanged += CheckedFunctionsChanged;
675    }
676
    // (Un)checking a function symbol changes the grammar's function set.
    private void CheckedFunctionsChanged(object sender, CollectionItemsChangedEventArgs<StringValue> e) {
      UpdateGrammarAndEncoding();
    }
680
681    private void TargetVariablesParameter_ValueChanged(object sender, EventArgs e) {
682      TargetVariablesParameter.Value.CheckedItemsChanged += CheckedTargetVariablesChanged;
683    }
684
    // (Un)checking a target variable changes the terminal symbols and the
    // number of trees (one per checked target) in the encoding.
    private void CheckedTargetVariablesChanged(object sender, CollectionItemsChangedEventArgs<StringValue> e) {
      UpdateGrammarAndEncoding();
    }
688
    // Called when a new problem-data object is assigned to the parameter.
    private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) {
      // subscribe to change notifications of the newly assigned problem data
      // NOTE(review): the previous value's Changed handler is never detached —
      // potential leak if the old problem data object is kept alive elsewhere.
      ProblemDataParameter.Value.Changed += ProblemData_Changed;
      OnProblemDataChanged();
      OnReset();
    }
694
    // Called when the current problem-data object raises its Changed event
    // (same reaction as replacing the object entirely).
    private void ProblemData_Changed(object sender, EventArgs e) {
      OnProblemDataChanged();
      OnReset();
    }
699
700    private void OnProblemDataChanged() {
701      UpdateTargetVariables();        // implicitly updates other dependent parameters
702      var handler = ProblemDataChanged;
703      if(handler != null) handler(this, EventArgs.Empty);
704    }
705
706    #endregion
707
708    #region  helper
709
    // One-time initialization entry point: updating the target variables
    // triggers the grammar and encoding updates via the registered event handlers.
    private void InitAllParameters() {
      UpdateTargetVariables(); // implicitly updates the grammar and the encoding     
    }
713
714    private ReadOnlyCheckedItemCollection<StringValue> CreateFunctionSet() {
715      var l = new CheckedItemCollection<StringValue>();
716      l.Add(new StringValue("+").AsReadOnly());
717      l.Add(new StringValue("*").AsReadOnly());
718      l.Add(new StringValue("%").AsReadOnly());
719      l.Add(new StringValue("-").AsReadOnly());
720      return l.AsReadOnly();
721    }
722
723    private static bool IsConstantNode(ISymbolicExpressionTreeNode n) {
724      return n.Symbol.Name.StartsWith("θ");
725    }
726    private static bool IsLatentVariableNode(ISymbolicExpressionTreeNode n) {
727      return n.Symbol.Name.StartsWith("λ");
728    }
729
730
731    private void UpdateTargetVariables() {
732      var currentlySelectedVariables = TargetVariables.CheckedItems.Select(i => i.Value).ToArray();
733
734      var newVariablesList = new CheckedItemCollection<StringValue>(ProblemData.Dataset.VariableNames.Select(str => new StringValue(str).AsReadOnly()).ToArray()).AsReadOnly();
735      var matchingItems = newVariablesList.Where(item => currentlySelectedVariables.Contains(item.Value)).ToArray();
736      foreach(var matchingItem in matchingItems) {
737        newVariablesList.SetItemCheckedState(matchingItem, true);
738      }
739      TargetVariablesParameter.Value = newVariablesList;
740    }
741
742    private void UpdateGrammarAndEncoding() {
743      var encoding = new MultiEncoding();
744      var g = CreateGrammar();
745      foreach(var targetVar in TargetVariables.CheckedItems) {
746        encoding = encoding.Add(new SymbolicExpressionTreeEncoding(targetVar + "_tree", g, MaximumLength, MaximumLength)); // only limit by length
747      }
748      for(int i = 1; i <= NumberOfLatentVariables; i++) {
749        encoding = encoding.Add(new SymbolicExpressionTreeEncoding("λ" + i + "_tree", g, MaximumLength, MaximumLength));
750      }
751      Encoding = encoding;
752    }
753
754    private ISymbolicExpressionGrammar CreateGrammar() {
755      // whenever ProblemData is changed we create a new grammar with the necessary symbols
756      var g = new SimpleSymbolicExpressionGrammar();
757      g.AddSymbols(FunctionSet.CheckedItems.Select(i => i.Value).ToArray(), 2, 2);
758
759      // TODO
760      //g.AddSymbols(new[] {
761      //  "exp",
762      //  "log", // log( <expr> ) // TODO: init a theta to ensure the value is always positive
763      //  "exp_minus" // exp((-1) * <expr>
764      //}, 1, 1);
765
766      foreach(var variableName in ProblemData.AllowedInputVariables.Union(TargetVariables.CheckedItems.Select(i => i.Value)))
767        g.AddTerminalSymbol(variableName);
768
769      // generate symbols for numeric parameters for which the value is optimized using AutoDiff
770      // we generate multiple symbols to balance the probability for selecting a numeric parameter in the generation of random trees
771      var numericConstantsFactor = 2.0;
772      for(int i = 0; i < numericConstantsFactor * (ProblemData.AllowedInputVariables.Count() + TargetVariables.CheckedItems.Count()); i++) {
773        g.AddTerminalSymbol("θ" + i); // numeric parameter for which the value is optimized using AutoDiff
774      }
775
776      // generate symbols for latent variables
777      for(int i = 1; i <= NumberOfLatentVariables; i++) {
778        g.AddTerminalSymbol("λ" + i); // numeric parameter for which the value is optimized using AutoDiff
779      }
780
781      return g;
782    }
783
784    #endregion
785
786    #region Import & Export
    // IProblemInstanceConsumer: imports a regression problem instance by taking
    // over its name, description and data (assigning ProblemData triggers the
    // dependent parameter updates via the registered event handlers).
    public void Load(IRegressionProblemData data) {
      Name = data.Name;
      Description = data.Description;
      ProblemData = data;
    }
792
    // IProblemInstanceExporter: exposes the current problem data for export.
    public IRegressionProblemData Export() {
      return ProblemData;
    }
796    #endregion
797
798  }
799}
Note: See TracBrowser for help on using the repository browser.