Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GeneticProgramming.BloodGlucosePrediction/Interpreter.cs @ 14559

Last change on this file since 14559 was 14311, checked in by gkronber, 8 years ago

simplification of grammar and problem and bug fixes related to precalculated smoothed features

File size: 8.1 KB
Line 
1using System;
2using System.Collections;
3using System.Collections.Generic;
4using System.Collections.ObjectModel;
5using System.Collections.Specialized;
6using System.Drawing.Design;
7using System.Linq;
8using HeuristicLab.Common;
9using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
10using HeuristicLab.Problems.DataAnalysis;
11using HeuristicLab.Problems.DataAnalysis.Symbolic;
12
namespace HeuristicLab.Problems.GeneticProgramming.GlucosePrediction {
  /// <summary>
  /// Evaluates glucose prediction expression trees on a dataset.
  /// </summary>
  public static class Interpreter {
    // Evaluation context shared by all recursive calls of one Apply() invocation.
    // Every array covers the WHOLE dataset and is indexed by dataset row, so that
    // the same row index can be used for raw and for precalculated values.
    private class Data {
      public double[] realGluc;
      public Dictionary<ISymbolicExpressionTreeNode, double[]> precalculatedValues;
    }

    /// <summary>
    /// Evaluates <paramref name="model"/> for each of the given dataset rows.
    /// </summary>
    /// <param name="model">Root node of the expression tree to evaluate.</param>
    /// <param name="dataset">Dataset providing the variables "Glucose_target", "Glucose_Interpol",
    /// "HourMin" and, depending on the nodes in the tree, "Insuline" and "CH".</param>
    /// <param name="rows">Dataset row indices for which predictions are produced.</param>
    /// <returns>One value per entry of <paramref name="rows"/> (in the same order);
    /// NaN where "Glucose_target" is NaN for that row.</returns>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    public static IEnumerable<double> Apply(ISymbolicExpressionTreeNode model, IDataset dataset, IEnumerable<int> rows) {
      if (model == null) throw new ArgumentNullException("model");
      if (dataset == null) throw new ArgumentNullException("dataset");
      if (rows == null) throw new ArgumentNullException("rows");

      // materialize once: rows may be a lazy sequence and is needed for the target lookup and the loop below
      int[] rowIndices = rows.ToArray();

      // only for skipping rows for which we should not produce an output
      double[] targetGluc = dataset.GetDoubleValues("Glucose_target", rowIndices).ToArray();

      var data = new Data {
        // loaded for the whole dataset (not only for rows) because InterpretRec addresses it
        // with dataset row indices, exactly like the precalculated values
        realGluc = dataset.GetReadOnlyDoubleValues("Glucose_Interpol").ToArray(),
        precalculatedValues = CreatePrecalculatedValues(model, dataset)
      };

      var predictions = new double[rowIndices.Length];
      for (int k = 0; k < predictions.Length; k++) {
        predictions[k] = double.IsNaN(targetGluc[k])
          ? double.NaN
          : InterpretRec(model, data, rowIndices[k]);
      }
      return predictions;
    }

    // Builds, for every curved insulin / carbohydrate node below root, the smoothed input
    // series over the whole dataset (one value per dataset row).
    private static Dictionary<ISymbolicExpressionTreeNode, double[]> CreatePrecalculatedValues(ISymbolicExpressionTreeNode root, IDataset dataset) {
      var dict = new Dictionary<ISymbolicExpressionTreeNode, double[]>();
      // here we integrate ins or ch inputs over the whole day to generate smoothed ins/ch values with the same number of rows
      // the integrated values are reset to zero whenever a new evaluation period starts
      foreach (var node in root.IterateNodesPrefix()) {
        var curvedInsNode = node as CurvedInsVariableTreeNode;
        if (curvedInsNode != null) {
          dict.Add(curvedInsNode, Integrate(curvedInsNode, dataset));
          continue;
        }
        var curvedChNode = node as CurvedChVariableTreeNode;
        if (curvedChNode != null) {
          dict.Add(curvedChNode, Integrate(curvedChNode, dataset));
        }
      }
      return dict;
    }

    // Smoothed "Insuline" series using the node's Alpha and Beta.
    private static double[] Integrate(CurvedInsVariableTreeNode node, IDataset dataset) {
      return Integrate("Insuline", node.Alpha, node.Beta, dataset);
    }

    // Smoothed "CH" series using the node's Alpha and Beta.
    private static double[] Integrate(CurvedChVariableTreeNode node, IDataset dataset) {
      return Integrate("CH", node.Alpha, node.Beta, dataset);
    }

    // Integrates the input variable u with one explicit step per dataset row:
    //   d Q1 / dt = u(t) - alpha * Q1(t)
    //   d Q2 / dt = alpha * (Q1(t) - Q2(t))
    //   d Q3 / dt = alpha * Q2(t) - beta * Q3(t)
    // Returns Q3 for every dataset row; the state is reset to zero at the start of a new period.
    private static double[] Integrate(string inputVariable, double alpha, double beta, IDataset dataset) {
      var input = dataset.GetReadOnlyDoubleValues(inputVariable);
      var time = dataset.GetReadOnlyDoubleValues("HourMin").ToArray();

      // starting values: zeros
      double q1Prev = 0, q2Prev = 0, q3Prev = 0;
      double[] s = new double[dataset.Rows];

      // NOTE(review): the loop starts at row 1, so the input of row 0 is never integrated and s[0] stays 0.
      // Kept as is because changing it would alter the results of existing models - confirm whether intended.
      for (int t = 1; t < dataset.Rows; t++) {
        if (IsStartOfNewPeriod(time, t)) {
          q1Prev = q2Prev = q3Prev = 0;
        }
        // all three updates use the state of the previous step only
        double q1 = q1Prev + input[t] - alpha * q1Prev;
        double q2 = q2Prev + alpha * (q1Prev - q2Prev);
        double q3 = q3Prev + alpha * q2Prev - beta * q3Prev;
        s[t] = q3;
        q1Prev = q1;
        q2Prev = q2;
        q3Prev = q3;
      }
      return s;
    }

    // True for row 0 and for rows where "HourMin" is (almost) 2005 while the preceding row is not (almost) 2000.
    // Presumably HourMin encodes the time of day as HHMM (20:05 / 20:00) - TODO confirm against the data files.
    private static bool IsStartOfNewPeriod(double[] time, int t) {
      return t == 0 ||
             (time[t].IsAlmost(2005) && !time[t - 1].IsAlmost(2000));
    }

    // Recursively evaluates node for the dataset row with index row.
    // Throws InvalidProgramException for unknown symbols and NotSupportedException for
    // predicted-glucose and real-insulin variable symbols.
    private static double InterpretRec(ISymbolicExpressionTreeNode node, Data data, int row) {
      if (node.Symbol is SimpleSymbol) {
        switch (node.Symbol.Name) {
          case "+":
          case "+Ins":
          case "+Ch": {
              return InterpretRec(node.GetSubtree(0), data, row) + InterpretRec(node.GetSubtree(1), data, row);
            }
          case "-":
          case "-Ins":
          case "-Ch": {
              return InterpretRec(node.GetSubtree(0), data, row) - InterpretRec(node.GetSubtree(1), data, row);
            }
          case "*":
          case "*Ins":
          case "*Ch": {
              return InterpretRec(node.GetSubtree(0), data, row) * InterpretRec(node.GetSubtree(1), data, row);
            }
          case "/":
          case "/Ins":
          case "/Ch": {
              return InterpretRec(node.GetSubtree(0), data, row) / InterpretRec(node.GetSubtree(1), data, row);
            }
          case "Exp":
          case "ExpIns":
          case "ExpCh": {
              return Math.Exp(InterpretRec(node.GetSubtree(0), data, row));
            }
          case "Sin":
          case "SinIns":
          case "SinCh": {
              return Math.Sin(InterpretRec(node.GetSubtree(0), data, row));
            }
          case "Cos":
          case "CosIns":
          case "CosCh": {
              return Math.Cos(InterpretRec(node.GetSubtree(0), data, row));
            }
          case "Log":
          case "LogIns":
          case "LogCh": {
              return Math.Log(InterpretRec(node.GetSubtree(0), data, row));
            }
          case "Func": {
              // <exprgluc> + <exprch> - <exprins>
              return InterpretRec(node.GetSubtree(0), data, row)
                     + InterpretRec(node.GetSubtree(1), data, row)
                     - InterpretRec(node.GetSubtree(2), data, row);
            }
          case "ExprGluc":
          case "ExprCh":
          case "ExprIns": {
              // pure wrapper symbols: the value is the value of the single subtree
              return InterpretRec(node.GetSubtree(0), data, row);
            }
          default: {
              throw new InvalidProgramException("Found an unknown symbol " + node.Symbol);
            }
        }
      } else if (node.Symbol is PredictedGlucoseVariableSymbol) {
        throw new NotSupportedException();
      } else if (node.Symbol is RealGlucoseVariableSymbol) {
        var n = (RealGlucoseVariableTreeNode)node;
        int shiftedRow = row + n.RowOffset;
        // offsets pointing outside of the dataset yield NaN instead of an exception
        if (shiftedRow < 0 || shiftedRow >= data.realGluc.Length) return double.NaN;
        return data.realGluc[shiftedRow];
      } else if (node.Symbol is CurvedChVariableSymbol) {
        return data.precalculatedValues[node][row];
      } else if (node.Symbol is RealInsulineVariableSymbol) {
        throw new NotSupportedException();
      } else if (node.Symbol is CurvedInsVariableSymbol) {
        return data.precalculatedValues[node][row];
      } else if (node.Symbol is Constant) {
        var n = (ConstantTreeNode)node;
        return n.Value;
      } else {
        throw new InvalidProgramException("found unknown symbol " + node.Symbol);
      }
    }

  }
}
Note: See TracBrowser for help on using the repository browser.