Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GeneticProgramming.BloodGlucosePrediction/Interpreter.cs @ 14005

Last change on this file since 14005 was 13870, checked in by gkronber, 9 years ago

#2608: simplified grammar by removing unnecessary symbols: exprch, exprins, exprgluc and added constants

File size: 6.7 KB
Line 
1using System;
2using System.Collections;
3using System.Collections.Generic;
4using System.Collections.Specialized;
5using System.Drawing.Design;
6using System.Linq;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis;
9using HeuristicLab.Problems.DataAnalysis.Symbolic;
10
namespace HeuristicLab.Problems.GeneticProgramming.GlucosePrediction {
  /// <summary>
  /// Evaluates a symbolic expression tree row by row against the glucose,
  /// insulin and carbohydrate (CH) columns of a dataset and produces one
  /// predicted glucose value per requested row.
  /// </summary>
  public static class Interpreter {
    // Number of rows the "curved" CH / insulin terminals look back from the current row.
    private const int LookBackWindow = 48;

    // The raw model output is clamped to this closed interval.
    private const double MinPrediction = 0.0;
    private const double MaxPrediction = 400.0;

    /// <summary>
    /// Column data for the evaluated rows. All arrays are indexed by the
    /// position of the row within the requested row sequence (not by dataset row).
    /// </summary>
    private class Data {
      public double[] realGluc;  // values of column "Glucose_Interpol"
      public double[] realIns;   // values of column "Insuline"
      public double[] realCh;    // values of column "CH"
      public double[] predGluc;  // starts as a copy of realGluc and is overwritten with predictions
    }

    /// <summary>
    /// Evaluates <paramref name="model"/> for every row in <paramref name="rows"/>.
    /// </summary>
    /// <param name="model">Root node of the expression tree to interpret.</param>
    /// <param name="dataset">Dataset providing the columns "Glucose_target",
    /// "Glucose_Interpol", "Insuline" and "CH".</param>
    /// <param name="rows">Rows to evaluate; enumerated exactly once.</param>
    /// <returns>
    /// One value per row: <c>NaN</c> where "Glucose_target" is <c>NaN</c>,
    /// otherwise the model output limited to 0 ... 400. A model output of
    /// <c>NaN</c> stays <c>NaN</c> because Math.Min / Math.Max propagate it.
    /// </returns>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="InvalidProgramException">The tree contains an unknown symbol.</exception>
    public static IEnumerable<double> Apply(ISymbolicExpressionTreeNode model, IDataset dataset, IEnumerable<int> rows) {
      if (model == null) throw new ArgumentNullException("model");
      if (dataset == null) throw new ArgumentNullException("dataset");
      if (rows == null) throw new ArgumentNullException("rows");

      // Materialize once: the sequence is needed for four column reads and
      // a deferred query must not be re-evaluated (or yield different rows) each time.
      int[] rowIndices = rows.ToArray();

      // only for skipping rows for which we should not produce an output
      double[] targetGluc = dataset.GetDoubleValues("Glucose_target", rowIndices).ToArray();

      var data = new Data {
        realGluc = dataset.GetDoubleValues("Glucose_Interpol", rowIndices).ToArray(),
        realIns = dataset.GetDoubleValues("Insuline", rowIndices).ToArray(),
        realCh = dataset.GetDoubleValues("CH", rowIndices).ToArray(),
      };

      // Predictions start out as the real values, so a predicted-glucose terminal
      // that refers to a row not yet computed reads the real value of that row.
      data.predGluc = new double[data.realGluc.Length];
      Array.Copy(data.realGluc, data.predGluc, data.predGluc.Length);

      for (int k = 0; k < data.predGluc.Length; k++) {
        if (double.IsNaN(targetGluc[k])) {
          data.predGluc[k] = double.NaN;
        } else {
          var rawPred = InterpretRec(model, data, k);
          // limit output values of the model to 0 ... 400
          data.predGluc[k] = Math.Max(MinPrediction, Math.Min(MaxPrediction, rawPred));
        }
      }
      return data.predGluc;
    }

    /// <summary>
    /// Recursively evaluates <paramref name="node"/> for the row at position <paramref name="k"/>.
    /// </summary>
    private static double InterpretRec(ISymbolicExpressionTreeNode node, Data data, int k) {
      if (node.Symbol is SimpleSymbol) {
        switch (node.Symbol.Name) {
          case "+":
          case "+Ins":
          case "+Ch": {
              return InterpretRec(node.GetSubtree(0), data, k) + InterpretRec(node.GetSubtree(1), data, k);
            }
          case "-":
          case "-Ins":
          case "-Ch": {
              return InterpretRec(node.GetSubtree(0), data, k) - InterpretRec(node.GetSubtree(1), data, k);
            }
          case "*":
          case "*Ins":
          case "*Ch": {
              return InterpretRec(node.GetSubtree(0), data, k) * InterpretRec(node.GetSubtree(1), data, k);
            }
          case "/":
          case "/Ins":
          case "/Ch": {
              return InterpretRec(node.GetSubtree(0), data, k) / InterpretRec(node.GetSubtree(1), data, k);
            }
          case "Exp":
          case "ExpIns":
          case "ExpCh": {
              return Math.Exp(InterpretRec(node.GetSubtree(0), data, k));
            }
          case "Sin":
          case "SinIns":
          case "SinCh": {
              return Math.Sin(InterpretRec(node.GetSubtree(0), data, k));
            }
          case "Cos":
          case "CosIns":
          case "CosCh": {
              return Math.Cos(InterpretRec(node.GetSubtree(0), data, k));
            }
          case "Log":
          case "LogIns":
          case "LogCh": {
              return Math.Log(InterpretRec(node.GetSubtree(0), data, k));
            }
          case "Func": {
              // <exprgluc> + <exprch> - <exprins>
              return InterpretRec(node.GetSubtree(0), data, k)
                     + InterpretRec(node.GetSubtree(1), data, k)
                     - InterpretRec(node.GetSubtree(2), data, k);
            }
          // Pure pass-through wrappers around a single sub-expression.
          case "ExprGluc":
          case "ExprCh":
          case "ExprIns": {
              return InterpretRec(node.GetSubtree(0), data, k);
            }
          default: {
              throw new InvalidProgramException("Found an unknown symbol " + node.Symbol);
            }
        }
      } else if (node.Symbol is PredictedGlucoseVariableSymbol) {
        var n = (PredictedGlucoseVariableTreeNode)node;
        return n.Weight * GetValueOrNaN(data.predGluc, k + n.RowOffset);
      } else if (node.Symbol is RealGlucoseVariableSymbol) {
        var n = (RealGlucoseVariableTreeNode)node;
        return n.Weight * GetValueOrNaN(data.realGluc, k + n.RowOffset);
      } else if (node.Symbol is CurvedChVariableSymbol) {
        var n = (CurvedChVariableTreeNode)node;
        double prevVal;
        int prevValDistance;
        GetPrevDataAndDistance(data.realCh, k, out prevVal, out prevValDistance, maxDistance: LookBackWindow);
        // No value above the threshold in the window: the helper reports 0 at the
        // window boundary (x = 1), where the density below may be infinite
        // (presumably possible when n.Beta < 1 - TODO confirm parameter ranges).
        // 0 * Infinity would be NaN, so return the intended zero contribution.
        if (prevVal == 0.0) return 0.0;
        return n.Weight * prevVal * Beta(prevValDistance / (double)LookBackWindow, n.Alpha, n.Beta);
      } else if (node.Symbol is RealInsulineVariableSymbol) {
        var n = (RealInsulineVariableTreeNode)node;
        return n.Weight * GetValueOrNaN(data.realIns, k + n.RowOffset);
      } else if (node.Symbol is CurvedInsVariableSymbol) {
        var n = (CurvedInsVariableTreeNode)node;
        double maxVal;
        int maxValDistance;
        var sum = GetSumOfValues(LookBackWindow, k, data.realIns);

        GetMaxValueAndDistance(data.realIns, k, out maxVal, out maxValDistance, maxDistance: LookBackWindow);
        // A zero maximum makes the product zero; return it directly so an infinite
        // density (x = 0, presumably when n.Alpha < 1 - TODO confirm) cannot turn it into NaN.
        if (maxVal == 0.0) return 0.0;
        return n.Weight * (sum - maxVal) * maxVal * Beta(maxValDistance / (double)LookBackWindow, n.Alpha, n.Beta);
      } else if (node.Symbol is Constant) {
        var n = (ConstantTreeNode)node;
        return n.Value;
      } else {
        throw new InvalidProgramException("found unknown symbol " + node.Symbol);
      }
    }

    /// <summary>
    /// Returns <c>vals[idx]</c>, or <c>NaN</c> when <paramref name="idx"/> lies outside the array
    /// (a row offset pointing before the first or after the last evaluated row).
    /// </summary>
    private static double GetValueOrNaN(double[] vals, int idx) {
      return (idx >= 0 && idx < vals.Length) ? vals[idx] : double.NaN;
    }

    /// <summary>
    /// Computes x^(alpha-1) * (1-x)^(beta-1) / B(alpha, beta), where B is alglib.beta.
    /// </summary>
    private static double Beta(double x, double alpha, double beta) {
      return 1.0 / alglib.beta(alpha, beta) * Math.Pow(x, alpha - 1) * Math.Pow(1 - x, beta - 1);
    }

    /// <summary>
    /// Looks backward from index <paramref name="k"/> (inclusive) over at most
    /// <paramref name="maxDistance"/> steps and reports the first value above
    /// <paramref name="threshold"/> together with its distance to k. If there is
    /// none, reports value 0 and distance <paramref name="maxDistance"/>.
    /// </summary>
    private static void GetPrevDataAndDistance(double[] vals, int k, out double val, out int dist, int maxDistance = 48, double threshold = 0.0) {
      for (int i = k; i >= 0 && i >= (k - maxDistance); i--) {
        if (vals[i] > threshold) {
          val = vals[i];
          dist = k - i;
          return;
        }
      }
      val = 0;
      dist = maxDistance;
    }

    /// <summary>
    /// Sums vals[k - windowSize .. k] (both ends inclusive), truncated at index 0.
    /// </summary>
    private static double GetSumOfValues(int windowSize, int k, double[] vals) {
      var sum = 0.0;
      for (int i = k; i >= 0 && i >= k - windowSize; i--) {
        sum += vals[i];
      }
      return sum;
    }

    /// <summary>
    /// Looks backward from index <paramref name="k"/> over at most
    /// <paramref name="maxDistance"/> steps and reports the maximum value and its
    /// distance to k. On ties the occurrence closest to k is kept.
    /// </summary>
    private static void GetMaxValueAndDistance(double[] vals, int k, out double maxVal, out int dist, int maxDistance = 48) {
      maxVal = vals[k];
      dist = 0;
      for (int i = k; i >= 0 && i >= (k - maxDistance); i--) {
        if (vals[i] > maxVal) {
          maxVal = vals[i];
          dist = k - i;
        }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.