Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GeneticProgramming.BloodGlucosePrediction/Interpreter.cs @ 13867

Last change on this file since 13867 was 13867, checked in by gkronber, 8 years ago

#2608 worked on glucose prediction problem

File size: 6.5 KB
Line 
1using System;
2using System.Collections;
3using System.Collections.Generic;
4using System.Collections.Specialized;
5using System.Drawing.Design;
6using System.Linq;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis;
9
10namespace HeuristicLab.Problems.GeneticProgramming.GlucosePrediction {
11  public static class Interpreter {
12    private class Data {
13      public double[] realGluc;
14      public double[] realIns;
15      public double[] realCh;
16      public double[] predGluc;
17    }
18
19    public static IEnumerable<double> Apply(ISymbolicExpressionTreeNode model, IDataset dataset, IEnumerable<int> rows) {
20      double[] targetGluc = dataset.GetDoubleValues("Glucose_target", rows).ToArray(); // only for skipping rows for which we should not produce an output
21
22      var data = new Data {
23        realGluc = dataset.GetDoubleValues("Glucose_Interpol", rows).ToArray(),
24        realIns = dataset.GetDoubleValues("Insuline", rows).ToArray(),
25        realCh = dataset.GetDoubleValues("CH", rows).ToArray(),
26      };
27      data.predGluc = new double[data.realGluc.Length];
28      for (int k = 0; k < data.predGluc.Length; k++) {
29        if (double.IsNaN(targetGluc[k])) {
30          data.predGluc[k] = double.NaN;
31        } else {
32          var rawPred = InterpretRec(model, data, k);
33          data.predGluc[k] = Math.Max(0, Math.Min(400, rawPred)); // limit output values of the model to 0 ... 400
34        }
35      }
36      return data.predGluc;
37    }
38
39    private static double InterpretRec(ISymbolicExpressionTreeNode node, Data data, int k) {
40      if (node.Symbol is SimpleSymbol) {
41        switch (node.Symbol.Name) {
42          case "+":
43          case "+Ins":
44          case "+Ch": {
45              return InterpretRec(node.GetSubtree(0), data, k) + InterpretRec(node.GetSubtree(1), data, k);
46            }
47          case "-":
48          case "-Ins":
49          case "-Ch": {
50              return InterpretRec(node.GetSubtree(0), data, k) - InterpretRec(node.GetSubtree(1), data, k);
51            }
52          case "*":
53          case "*Ins":
54          case "*Ch": {
55              return InterpretRec(node.GetSubtree(0), data, k) * InterpretRec(node.GetSubtree(1), data, k);
56            }
57          case "/Ch":
58          case "/Ins":
59          case "/": {
60              return InterpretRec(node.GetSubtree(0), data, k) / InterpretRec(node.GetSubtree(1), data, k);
61            }
62          case "Exp":
63          case "ExpIns":
64          case "ExpCh": {
65              return Math.Exp(InterpretRec(node.GetSubtree(0), data, k));
66            }
67          case "Sin":
68          case "SinIns":
69          case "SinCh": {
70              return Math.Sin(InterpretRec(node.GetSubtree(0), data, k));
71            }
72          case "CosCh":
73          case "CosIns":
74          case "Cos": {
75              return Math.Cos(InterpretRec(node.GetSubtree(0), data, k));
76            }
77          case "LogCh":
78          case "LogIns":
79          case "Log": {
80              return Math.Log(InterpretRec(node.GetSubtree(0), data, k));
81            }
82          case "Func": {
83              // <exprgluc> + <exprch> - <exprins>
84              return InterpretRec(node.GetSubtree(0), data, k)
85                     + InterpretRec(node.GetSubtree(1), data, k)
86                     - InterpretRec(node.GetSubtree(2), data, k);
87            }
88          case "ExprGluc": {
89              return InterpretRec(node.GetSubtree(0), data, k);
90            }
91          case "ExprCh": {
92              return InterpretRec(node.GetSubtree(0), data, k);
93            }
94          case "ExprIns": {
95              return InterpretRec(node.GetSubtree(0), data, k);
96            }
97          default: {
98              throw new InvalidProgramException("Found an unknown symbol " + node.Symbol);
99            }
100        }
101      } else if (node.Symbol is PredictedGlucoseVariableSymbol) {
102        var n = (PredictedGlucoseVariableTreeNode)node;
103        return n.Weight * data.predGluc[k + n.RowOffset];
104      } else if (node.Symbol is RealGlucoseVariableSymbol) {
105        var n = (RealGlucoseVariableTreeNode)node;
106        return n.Weight * data.realGluc[k + n.RowOffset];
107      } else if (node.Symbol is CurvedChVariableSymbol) {
108        var n = (CurvedChVariableTreeNode)node;
109        double prevVal;
110        int prevValDistance;
111        GetPrevDataAndDistance(data.realCh, k, out prevVal, out prevValDistance, maxDistance: 48);
112        return n.Weight * prevVal * Beta(prevValDistance / 48.0, n.Alpha, n.Beta);
113      } else if (node.Symbol is RealInsulineVariableSymbol) {
114        var n = (RealInsulineVariableTreeNode)node;
115        return n.Weight * data.realIns[k + n.RowOffset];
116      } else if (node.Symbol is CurvedInsVariableSymbol) {
117        var n = (CurvedInsVariableTreeNode)node;
118        double maxVal;
119        int maxValDistance;
120        var sum = GetSumOfValues(48, k, data.realIns);
121
122        GetMaxValueAndDistance(data.realIns, k, out maxVal, out maxValDistance, maxDistance: 48);
123        return n.Weight * (sum - maxVal) * maxVal * Beta(maxValDistance / 48.0, n.Alpha, n.Beta);
124      } else {
125        throw new InvalidProgramException("found unknown symbol " + node.Symbol);
126      }
127    }
128
129    private static double Beta(double x, double alpha, double beta) {
130      return 1.0 / alglib.beta(alpha, beta) * Math.Pow(x, alpha - 1) * Math.Pow(1 - x, beta - 1);
131    }
132
133    private static void GetPrevDataAndDistance(double[] vals, int k, out double val, out int dist, int maxDistance = 48, double threshold = 0.0) {
134      // look backward from the current idx k and find the first value above the threshold
135      for (int i = k; i >= 0 && i >= (k - maxDistance); i--) {
136        if (vals[i] > threshold) {
137          val = vals[i];
138          dist = k - i;
139          return;
140        }
141      }
142      val = 0;
143      dist = maxDistance;
144    }
145
146    private static double GetSumOfValues(int windowSize, int k, double[] vals) {
147      var sum = 0.0;
148      for (int i = k; i >= 0 && i >= k - windowSize; i--)
149        sum += vals[i];
150      return sum;
151    }
152
153    private static void GetMaxValueAndDistance(double[] vals, int k, out double maxVal, out int dist, int maxDistance = 48) {
154      // look backward from the current idx k and find the max value and it's distance
155      maxVal = vals[k];
156      dist = 0;
157      for (int i = k; i >= 0 && i >= (k - maxDistance); i--) {
158        if (vals[i] > maxVal) {
159          maxVal = vals[i];
160          dist = k - i;
161        }
162      }
163    }
164  }
165}
Note: See TracBrowser for help on using the repository browser.