source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Solution.cs @ 16660

Last change on this file since 16660 was 16660, checked in by gkronber, 2 years ago

#2925: re-introduced partial support for latent variables

File size: 4.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Data;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Problems.DataAnalysis.Symbolic;
11using HeuristicLab.Random;
12
namespace HeuristicLab.Problems.DynamicalSystemsModelling {
  /// <summary>
  /// Result of a dynamical-systems modelling run: one symbolic expression tree per
  /// target (and latent) variable — the right-hand sides of the ODE system — together
  /// with the problem data and the numeric-integration settings needed to reproduce
  /// predictions via <see cref="Predict"/>.
  /// </summary>
  [StorableClass]
  public class Solution : Item {
    [Storable]
    private ISymbolicExpressionTree[] trees;
    // One tree per modelled variable (targets first, then latent variables —
    // NOTE(review): ordering is assumed from usage in Predict; confirm against Problem.
    public ISymbolicExpressionTree[] Trees {
      get { return trees; }
    }

    [Storable]
    private IRegressionProblemData problemData;
    public IRegressionProblemData ProblemData {
      get { return problemData; }
    }
    [Storable]
    private string[] targetVars;
    public string[] TargetVariables {
      get { return targetVars; }
    }
    [Storable]
    private string[] latentVariables;
    public string[] LatentVariables {
      get { return latentVariables; }
    }
    [Storable]
    private IEnumerable<IntRange> trainingEpisodes;
    public IEnumerable<IntRange> TrainingEpisodes {
      get { return trainingEpisodes; }
    }
    [Storable]
    private string odeSolver;
    [Storable]
    private int numericIntegrationSteps;

    [StorableConstructor]
    private Solution(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // intentionally empty: all storable fields deserialize without post-processing
    }

    // cloning constructor (HeuristicLab deep-clone pattern)
    private Solution(Solution original, Cloner cloner)
      : base(original, cloner) {
      this.trees = original.trees.Select(t => cloner.Clone(t)).ToArray();
      this.problemData = cloner.Clone(original.problemData);
      // string arrays are copied (strings are immutable, no cloner needed)
      this.targetVars = original.targetVars.ToArray();
      this.latentVariables = original.latentVariables.ToArray();
      this.trainingEpisodes = original.trainingEpisodes.Select(te => cloner.Clone(te)).ToArray();
      this.odeSolver = original.odeSolver;
      this.numericIntegrationSteps = original.numericIntegrationSteps;
    }

    /// <summary>
    /// Creates a solution from the given trees and the settings used to produce them.
    /// The arrays and the episode sequence are stored by reference (not copied).
    /// </summary>
    public Solution(ISymbolicExpressionTree[] trees,
      IRegressionProblemData problemData,
      string[] targetVars, string[] latentVariables, IEnumerable<IntRange> trainingEpisodes,
      string odeSolver, int numericIntegrationSteps) : base() {
      this.trees = trees;
      this.problemData = problemData;
      this.targetVars = targetVars;
      this.latentVariables = latentVariables;
      this.trainingEpisodes = trainingEpisodes;
      this.odeSolver = odeSolver;
      this.numericIntegrationSteps = numericIntegrationSteps;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new Solution(this, cloner);
    }

    /// <summary>
    /// Integrates the ODE system over <paramref name="episode"/> extended by
    /// <paramref name="forecastHorizon"/> additional rows and lazily yields, for each
    /// integration step, an array containing the predicted target-variable values
    /// followed by the latent-variable values.
    /// </summary>
    /// <param name="episode">Row range of the dataset to integrate over.</param>
    /// <param name="forecastHorizon">Number of rows to forecast past the episode end.</param>
    /// <returns>
    /// One array per step of length TargetVariables.Length + LatentVariables.Length.
    /// </returns>
    public IEnumerable<double[]> Predict(IntRange episode, int forecastHorizon) {
      var forecastEpisode = new IntRange(episode.Start, episode.End + forecastHorizon);

      // Input variables = every variable referenced by any tree that is neither a
      // target nor a latent variable.
      var inputVariables = trees.SelectMany(t => t.IterateNodesPrefix().OfType<VariableTreeNode>().Select(n => n.VariableName))
        .Except(targetVars)
        .Except(latentVariables)
        .Distinct();

      var optimizationData = new Problem.OptimizationData(trees, targetVars, inputVariables.ToArray(), problemData, null, new[] { forecastEpisode }, numericIntegrationSteps, latentVariables, odeSolver);

      // fi is laid out row-major: fi[step * #targets + targetIdx].
      // jac is required by Problem.Integrate but its contents are not used here.
      var fi = new double[forecastEpisode.Size * targetVars.Length];
      var jac = new double[forecastEpisode.Size * targetVars.Length, optimizationData.nodeValueLookup.ParameterCount];
      var latentValues = new double[forecastEpisode.Size, latentVariables.Length];
      Problem.Integrate(optimizationData, fi, jac, latentValues);
      for (int i = 0; i < forecastEpisode.Size; i++) {
        var res = new double[targetVars.Length + latentVariables.Length];
        for (int j = 0; j < targetVars.Length; j++) {
          res[j] = fi[i * targetVars.Length + j];
        }
        for (int j = 0; j < latentVariables.Length; j++) {
          res[targetVars.Length + j] = latentValues[i, j];
        }
        yield return res;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.