using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Drawing.Design;
using System.Linq;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;

namespace HeuristicLab.Problems.GeneticProgramming.GlucosePrediction {
  /// <summary>
  /// Evaluates a symbolic expression tree row by row against the glucose, insulin and
  /// carbohydrate columns of a dataset and produces one predicted glucose value per row.
  /// </summary>
  public static class Interpreter {
    // Number of rows the "curved" terminals look back. The same value normalizes the
    // distance handed to Beta(), so a distance of CurveWindow maps to x = 1.
    private const int CurveWindow = 48;

    // Bounds applied to every raw model output.
    private const double MinPrediction = 0.0;
    private const double MaxPrediction = 400.0;

    /// <summary>Bundle of the series that the recursive interpreter reads from.</summary>
    private class Data {
      public double[] realGluc;
      public double[] realIns;
      public double[] realCh;
      // Starts as a copy of realGluc and is overwritten in place, index by index, by Apply.
      public double[] predGluc;
    }

    /// <summary>
    /// Runs <paramref name="model"/> once per entry of <paramref name="rows"/> and returns
    /// the resulting series.
    /// </summary>
    /// <param name="model">Root node of the expression tree to evaluate.</param>
    /// <param name="dataset">
    /// Source of the columns "Glucose_target", "Glucose_Interpol", "Insuline" and "CH".
    /// </param>
    /// <param name="rows">Rows of the dataset to read, in evaluation order.</param>
    /// <returns>
    /// A double array (typed as IEnumerable) with one element per row: NaN where the
    /// "Glucose_target" value is NaN, otherwise the model output limited to 0 ... 400
    /// via Math.Min / Math.Max.
    /// </returns>
    public static IEnumerable Apply(ISymbolicExpressionTreeNode model, IDataset dataset, IEnumerable rows) {
      // The target column is only consulted to decide which rows must not get an output.
      double[] target = dataset.GetDoubleValues("Glucose_target", rows).ToArray();

      var series = new Data();
      series.realGluc = dataset.GetDoubleValues("Glucose_Interpol", rows).ToArray();
      series.realIns = dataset.GetDoubleValues("Insuline", rows).ToArray();
      series.realCh = dataset.GetDoubleValues("CH", rows).ToArray();

      // Seed the predictions with the real values; entries that have not been
      // overwritten yet therefore still hold the measured glucose.
      series.predGluc = (double[])series.realGluc.Clone();

      double[] output = series.predGluc;
      for (int row = 0; row < output.Length; row++) {
        if (double.IsNaN(target[row])) {
          output[row] = double.NaN;
          continue;
        }

        double raw = InterpretRec(model, series, row);
        // limit output values of the model to 0 ... 400
        output[row] = Math.Max(MinPrediction, Math.Min(MaxPrediction, raw));
      }

      return output;
    }

    /// <summary>
    /// Evaluates <paramref name="node"/> (and, recursively, its subtrees) for index
    /// <paramref name="k"/>.
    /// </summary>
    /// <remarks>
    /// The symbol type checks run in a fixed order and the first match wins.
    /// Variable terminals index their series at k + RowOffset without a bounds check.
    /// NOTE(review): presumably the grammar restricts RowOffset so that the index stays
    /// inside the series - confirm; otherwise the array access throws.
    /// </remarks>
    /// <exception cref="InvalidProgramException">The node's symbol is not recognized.</exception>
    private static double InterpretRec(ISymbolicExpressionTreeNode node, Data data, int k) {
      var symbol = node.Symbol;

      if (symbol is SimpleSymbol) {
        return EvaluateSimpleSymbol(node, data, k);
      }

      if (symbol is PredictedGlucoseVariableSymbol) {
        var predictedNode = (PredictedGlucoseVariableTreeNode)node;
        return predictedNode.Weight * data.predGluc[k + predictedNode.RowOffset];
      }

      if (symbol is RealGlucoseVariableSymbol) {
        var glucoseNode = (RealGlucoseVariableTreeNode)node;
        return glucoseNode.Weight * data.realGluc[k + glucoseNode.RowOffset];
      }

      if (symbol is CurvedChVariableSymbol) {
        var chNode = (CurvedChVariableTreeNode)node;
        double lastCh;
        int lastChDistance;
        GetPrevDataAndDistance(data.realCh, k, out lastCh, out lastChDistance, maxDistance: CurveWindow);
        // NOTE(review): when nothing is found, lastCh is 0 and x is 1; if Beta() yields
        // Infinity there (Beta parameter below 1) the product is NaN - confirm that the
        // parameter ranges rule this out.
        return chNode.Weight * lastCh * Beta(lastChDistance / (double)CurveWindow, chNode.Alpha, chNode.Beta);
      }

      if (symbol is RealInsulineVariableSymbol) {
        var insulinNode = (RealInsulineVariableTreeNode)node;
        return insulinNode.Weight * data.realIns[k + insulinNode.RowOffset];
      }

      if (symbol is CurvedInsVariableSymbol) {
        var curvedInsNode = (CurvedInsVariableTreeNode)node;
        double peak;
        int peakDistance;
        double windowSum = GetSumOfValues(CurveWindow, k, data.realIns);
        GetMaxValueAndDistance(data.realIns, k, out peak, out peakDistance, maxDistance: CurveWindow);
        return curvedInsNode.Weight * (windowSum - peak) * peak
               * Beta(peakDistance / (double)CurveWindow, curvedInsNode.Alpha, curvedInsNode.Beta);
      }

      if (symbol is Constant) {
        var constantNode = (ConstantTreeNode)node;
        return constantNode.Value;
      }

      throw new InvalidProgramException("found unknown symbol " + node.Symbol);
    }

    /// <summary>
    /// Evaluates a node whose symbol is a SimpleSymbol, dispatching on the symbol name.
    /// The plain, "Ins" and "Ch" variants of an operator are evaluated identically.
    /// </summary>
    /// <exception cref="InvalidProgramException">The symbol name is not recognized.</exception>
    private static double EvaluateSimpleSymbol(ISymbolicExpressionTreeNode node, Data data, int k) {
      switch (node.Symbol.Name) {
        case "+":
        case "+Ins":
        case "+Ch": {
            double left = InterpretRec(node.GetSubtree(0), data, k);
            double right = InterpretRec(node.GetSubtree(1), data, k);
            return left + right;
          }
        case "-":
        case "-Ins":
        case "-Ch": {
            double left = InterpretRec(node.GetSubtree(0), data, k);
            double right = InterpretRec(node.GetSubtree(1), data, k);
            return left - right;
          }
        case "*":
        case "*Ins":
        case "*Ch": {
            double left = InterpretRec(node.GetSubtree(0), data, k);
            double right = InterpretRec(node.GetSubtree(1), data, k);
            return left * right;
          }
        case "/Ch":
        case "/Ins":
        case "/": {
            // Plain floating-point division: a zero divisor is not treated specially.
            double numerator = InterpretRec(node.GetSubtree(0), data, k);
            double denominator = InterpretRec(node.GetSubtree(1), data, k);
            return numerator / denominator;
          }
        case "Exp":
        case "ExpIns":
        case "ExpCh": {
            double argument = InterpretRec(node.GetSubtree(0), data, k);
            return Math.Exp(argument);
          }
        case "Sin":
        case "SinIns":
        case "SinCh": {
            double argument = InterpretRec(node.GetSubtree(0), data, k);
            return Math.Sin(argument);
          }
        case "CosCh":
        case "CosIns":
        case "Cos": {
            double argument = InterpretRec(node.GetSubtree(0), data, k);
            return Math.Cos(argument);
          }
        case "LogCh":
        case "LogIns":
        case "Log": {
            double argument = InterpretRec(node.GetSubtree(0), data, k);
            return Math.Log(argument);
          }
        case "Func": {
            // first subtree + second subtree - third subtree
            double first = InterpretRec(node.GetSubtree(0), data, k);
            double second = InterpretRec(node.GetSubtree(1), data, k);
            double third = InterpretRec(node.GetSubtree(2), data, k);
            return first + second - third;
          }
        case "ExprGluc":
        case "ExprCh":
        case "ExprIns": {
            // Wrapper symbols: the value is that of the single subtree.
            return InterpretRec(node.GetSubtree(0), data, k);
          }
        default: {
            throw new InvalidProgramException("Found an unknown symbol " + node.Symbol);
          }
      }
    }

    /// <summary>
    /// Computes x^(alpha-1) * (1-x)^(beta-1) scaled by 1 / alglib.beta(alpha, beta).
    /// </summary>
    /// <remarks>
    /// NOTE(review): presumably alglib.beta is the Beta function B(alpha, beta), which
    /// would make this the Beta distribution density - confirm against ALGLIB.
    /// </remarks>
    private static double Beta(double x, double alpha, double beta) {
      double scale = 1.0 / alglib.beta(alpha, beta);
      double risingPart = Math.Pow(x, alpha - 1);
      double fallingPart = Math.Pow(1 - x, beta - 1);
      return scale * risingPart * fallingPart;
    }

    /// <summary>
    /// Scans backward from index <paramref name="k"/> (inclusive) over at most
    /// <paramref name="maxDistance"/> earlier entries and reports the first value that is
    /// strictly greater than <paramref name="threshold"/>.
    /// </summary>
    /// <param name="vals">Series to scan.</param>
    /// <param name="k">Index at which the scan starts.</param>
    /// <param name="val">The value found, or 0 when none qualifies.</param>
    /// <param name="dist">
    /// How many entries before k the value was found, or <paramref name="maxDistance"/>
    /// when none qualifies.
    /// </param>
    /// <param name="maxDistance">Furthest distance from k that is still inspected.</param>
    /// <param name="threshold">Exclusive lower limit a value has to exceed.</param>
    private static void GetPrevDataAndDistance(double[] vals, int k, out double val, out int dist, int maxDistance = CurveWindow, double threshold = 0.0) {
      int oldest = Math.Max(0, k - maxDistance);
      int idx = k;
      while (idx >= oldest) {
        if (vals[idx] > threshold) {
          val = vals[idx];
          dist = k - idx;
          return;
        }
        idx--;
      }

      val = 0;
      dist = maxDistance;
    }

    /// <summary>
    /// Adds up the entries from index <paramref name="k"/> down to k - windowSize
    /// (both inclusive), stopping at index 0. Summation runs from k toward older entries.
    /// </summary>
    private static double GetSumOfValues(int windowSize, int k, double[] vals) {
      int oldest = Math.Max(0, k - windowSize);
      double total = 0.0;
      int idx = k;
      while (idx >= oldest) {
        total += vals[idx];
        idx--;
      }
      return total;
    }

    /// <summary>
    /// Scans backward from index <paramref name="k"/> (inclusive) over at most
    /// <paramref name="maxDistance"/> earlier entries and reports the largest value and
    /// its distance from k. On ties the entry closest to k is kept, because only a
    /// strictly greater value replaces the current maximum.
    /// </summary>
    private static void GetMaxValueAndDistance(double[] vals, int k, out double maxVal, out int dist, int maxDistance = CurveWindow) {
      maxVal = vals[k];
      dist = 0;

      int oldest = Math.Max(0, k - maxDistance);
      int idx = k;
      while (idx >= oldest) {
        double candidate = vals[idx];
        if (candidate > maxVal) {
          maxVal = candidate;
          dist = k - idx;
        }
        idx--;
      }
    }
  }
}