Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GeneticProgramming.BloodGlucosePrediction/Problem.cs @ 14360

Last change on this file since 14360 was 14311, checked in by gkronber, 8 years ago

simplification of grammar and problem and bug fixes related to precalculated smoothed features

File size: 9.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Linq;
26using HeuristicLab.Analysis;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using HeuristicLab.Problems.Instances;
36
37
38namespace HeuristicLab.Problems.GeneticProgramming.GlucosePrediction {
39  [Item("Blood Glucose Forecast", "See MedGEC Workshop at GECCO 2016")]
40  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 999)]
41  [StorableClass]
42  public sealed class Problem : SymbolicExpressionTreeProblem, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
43
#region parameter names
// Key under which the regression data parameter is stored in the Parameters collection.
private const string ProblemDataParameterName = "ProblemData";
#endregion

#region Parameter Properties
/// <summary>
/// Typed accessor for the parameter stored under <see cref="ProblemDataParameterName"/>.
/// </summary>
public IValueParameter<IRegressionProblemData> ProblemDataParameter {
  get {
    var parameter = Parameters[ProblemDataParameterName];
    return (IValueParameter<IRegressionProblemData>)parameter;
  }
}

/// <summary>
/// Explicit <see cref="IDataAnalysisProblem"/> implementation; forwards to <see cref="ProblemDataParameter"/>.
/// </summary>
IParameter IDataAnalysisProblem.ProblemDataParameter {
  get { return ProblemDataParameter; }
}
#endregion
55
#region Properties
/// <summary>
/// Explicit <see cref="IDataAnalysisProblem"/> implementation; returns the same object as the typed
/// <see cref="ProblemData"/> property.
/// </summary>
IDataAnalysisProblemData IDataAnalysisProblem.ProblemData {
  get { return ProblemData; }
}

/// <summary>
/// The regression data currently held by <see cref="ProblemDataParameter"/>.
/// Reading and writing go straight through to the parameter's Value.
/// </summary>
public IRegressionProblemData ProblemData {
  get {
    return ProblemDataParameter.Value;
  }
  set {
    ProblemDataParameter.Value = value;
  }
}
#endregion
63
/// <summary>
/// Raised from OnProblemDataChanged() after the grammar has been rebuilt.
/// </summary>
public event EventHandler ProblemDataChanged;

/// <summary>
/// Always true: the quality returned by Evaluate() (a squared correlation) is to be maximized.
/// </summary>
public override bool Maximization {
  get {
    return true;
  }
}
69
#region item cloning and persistence
// --- persistence ---

/// <summary>
/// Constructor used by the storable persistence mechanism; only forwards to the base class.
/// </summary>
[StorableConstructor]
private Problem(bool deserializing)
  : base(deserializing) {
}

/// <summary>
/// Re-attaches the event handlers once deserialization has finished.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  RegisterEventHandlers();
}

// --- cloning ---

/// <summary>
/// Copy constructor used by <see cref="Clone(Cloner)"/>; the base class copies the state and the
/// event handlers are attached to the new instance afterwards.
/// </summary>
private Problem(Problem original, Cloner cloner)
  : base(original, cloner) {
  RegisterEventHandlers();
}

/// <summary>
/// Creates a copy of this problem via the private copy constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new Problem(this, cloner);
  return copy;
}
#endregion
86
/// <summary>
/// Creates a problem with an empty default <see cref="RegressionProblemData"/>, installs the
/// tree encoding and grammar, and attaches the event handlers.
/// </summary>
public Problem()
  : base() {
  var problemDataParameter = new ValueParameter<IRegressionProblemData>(
    ProblemDataParameterName,
    "The data for the glucose prediction problem",
    new RegressionProblemData());
  Parameters.Add(problemDataParameter);

  // Placeholder grammar only; UpdateGrammar() below swaps in the real one.
  // NOTE(review): 100 and 17 are presumably the maximum tree length and depth - confirm against
  // the SymbolicExpressionTreeEncoding constructor.
  var placeholderGrammar = new SimpleSymbolicExpressionGrammar();
  base.Encoding = new SymbolicExpressionTreeEncoding(placeholderGrammar, 100, 17);

  UpdateGrammar();
  RegisterEventHandlers();
}
97
98
/// <summary>
/// Scores a candidate tree by the squared Pearson correlation between its predictions and the
/// target variable on the training rows (see <see cref="Rsq"/>).
/// </summary>
/// <param name="tree">Candidate tree; the node at Root.GetSubtree(0).GetSubtree(0) is interpreted.</param>
/// <param name="random">Not used by this implementation.</param>
/// <returns>The value computed by <see cref="Rsq"/> for the training rows.</returns>
public override double Evaluate(ISymbolicExpressionTree tree, IRandom random) {
  var data = ProblemData;
  var trainingTargets = data.Dataset.GetDoubleValues(data.TargetVariable, data.TrainingIndices);

  // The interpreter is run over all rows; the training rows are picked out of that result
  // afterwards, using each row index as a position in the array.
  var expression = tree.Root.GetSubtree(0).GetSubtree(0);
  var predictionsForAllRows = Interpreter.Apply(expression, data.Dataset, data.AllIndices).ToArray();
  var trainingPredictions = data.TrainingIndices.Select(row => predictionsForAllRows[row]);

  // Note: separate scores for the child subtrees 1 and 2 of the expression node were
  // experimented with earlier (pred1_rsq / pred2_rsq) and are currently disabled.
  return Rsq(trainingPredictions, trainingTargets);
}
113
/// <summary>
/// Squared Pearson correlation between predictions and targets, ignoring every pair whose
/// target value is NaN.
/// </summary>
/// <param name="predicted">Predicted values, aligned element by element with <paramref name="target"/>.</param>
/// <param name="target">Target values; NaN entries mark pairs to be skipped.</param>
/// <returns>r * r, or 0 when the correlation calculator reports an error.</returns>
private double Rsq(IEnumerable<double> predicted, IEnumerable<double> target) {
  // Pair up (target, prediction) and drop the pairs without a usable target.
  var validPairs = target
    .Zip(predicted, Tuple.Create)
    .Where(pair => !double.IsNaN(pair.Item1))
    .ToArray();
  var validTargets = validPairs.Select(pair => pair.Item1);
  var validPredictions = validPairs.Select(pair => pair.Item2);

  OnlineCalculatorError errorState;
  var correlation = OnlinePearsonsRCalculator.Calculate(validTargets, validPredictions, out errorState);

  // A failed calculation counts as "no correlation".
  if (errorState != OnlineCalculatorError.None) return 0;
  return correlation * correlation;
}
125
/// <summary>
/// Publishes the best tree of the current population as a linearly scaled regression solution.
/// </summary>
/// <remarks>
/// The tree with the highest quality is cloned, and the first child of its expression node
/// (Root.GetSubtree(0).GetSubtree(0)) is wrapped as <c>original * beta + alpha</c>, where alpha and
/// beta come from <see cref="OnlineLinearScalingParameterCalculator"/> computed over all rows
/// whose target is not NaN. The results "Solution" and "ScaledTree" are created on first use
/// and overwritten on every call.
/// </remarks>
/// <param name="trees">Trees of the current population.</param>
/// <param name="qualities">Quality of each tree, aligned with <paramref name="trees"/>.</param>
/// <param name="results">Result collection that receives "Solution" and "ScaledTree".</param>
/// <param name="random">Only forwarded to the base implementation.</param>
public override void Analyze(ISymbolicExpressionTree[] trees, double[] qualities, ResultCollection results,
  IRandom random) {
  base.Analyze(trees, qualities, results, random);

  if (!results.ContainsKey("Solution")) {
    results.Add(new Result("Solution", typeof(IRegressionSolution)));
  }
  if (!results.ContainsKey("ScaledTree")) {
    results.Add(new Result("ScaledTree", typeof(ISymbolicExpressionTree)));
  }

  // Nothing to report for an empty population (previously this threw from First()).
  if (trees.Length == 0) return;

  // Pick the tree with the highest quality; ties keep the earliest tree.
  var bestIndex = 0;
  for (int i = 1; i < trees.Length; i++) {
    if (qualities[i] > qualities[bestIndex]) bestIndex = i;
  }

  // Work on a clone so the individual in the population is not modified by the scaling below.
  var bestTree = (ISymbolicExpressionTree)trees[bestIndex].Clone();
  var expressionNode = bestTree.Root.GetSubtree(0).GetSubtree(0);

  var problemData = ProblemData;
  var rows = problemData.AllIndices.ToArray();
  var target = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
  var predicted =
    Interpreter.Apply(expressionNode.GetSubtree(0), problemData.Dataset, rows)
      .ToArray();

  // Keep only the pairs with a known target. target and predicted are both ordered like rows,
  // so they are addressed by position here rather than by row id.
  var filteredPredicted = new List<double>(target.Length);
  var filteredTarget = new List<double>(target.Length);
  for (int i = 0; i < target.Length; i++) {
    if (double.IsNaN(target[i])) continue;
    filteredPredicted.Add(predicted[i]);
    filteredTarget.Add(target[i]);
  }

  OnlineCalculatorError error;
  double alpha;
  double beta;
  OnlineLinearScalingParameterCalculator.Calculate(filteredPredicted, filteredTarget, out alpha, out beta, out error);
  if (error != OnlineCalculatorError.None) {
    // The calculator could not produce reliable parameters; fall back to the identity
    // transformation (x * 1 + 0) instead of baking unusable constants into the model.
    alpha = 0.0;
    beta = 1.0;
  }

  // Build (original * beta) + alpha and put it in place of the original subtree.
  var prod = new SimpleSymbol("*", "*", 2, 2).CreateTreeNode();
  var sum = new SimpleSymbol("+", "+", 2, 2).CreateTreeNode();
  var constAlpha = (ConstantTreeNode)(new Constant()).CreateTreeNode();
  constAlpha.Value = alpha;
  var constBeta = (ConstantTreeNode)(new Constant()).CreateTreeNode();
  constBeta.Value = beta;

  var originalTree = expressionNode.GetSubtree(0);
  expressionNode.RemoveSubtree(0);
  expressionNode.AddSubtree(sum);
  sum.AddSubtree(prod);
  sum.AddSubtree(constAlpha);
  prod.AddSubtree(originalTree);
  prod.AddSubtree(constBeta);

  var model = new Model(bestTree, problemData.TargetVariable, problemData.AllowedInputVariables.ToArray());
  model.Name = "Scaled Model";
  model.Description = "Scaled Model";
  results["Solution"].Value = model.CreateRegressionSolution(problemData);
  results["ScaledTree"].Value = bestTree;
}
189
190    #region events
/// <summary>
/// Subscribes to value changes of the problem data parameter and, if a value is currently set,
/// to change notifications of that value.
/// </summary>
private void RegisterEventHandlers() {
  ProblemDataParameter.ValueChanged += ProblemDataParameter_ValueChanged;

  var currentData = ProblemDataParameter.Value;
  if (currentData != null) {
    currentData.Changed += ProblemData_Changed;
  }
}
195
/// <summary>
/// Handles replacement of the problem data object: subscribes to the new object's Changed event,
/// then raises the problem-data-changed notification and resets the problem.
/// </summary>
/// <param name="sender">Event source (unused).</param>
/// <param name="e">Event arguments (unused).</param>
private void ProblemDataParameter_ValueChanged(object sender, EventArgs e) {
  // The parameter value may be null; RegisterEventHandlers() guards against this as well.
  var newData = ProblemDataParameter.Value;
  if (newData != null) newData.Changed += new EventHandler(ProblemData_Changed);

  // NOTE(review): the previous data object is not known here, so its handler is never
  // detached - confirm whether that matters for long-lived problem data instances.
  OnProblemDataChanged();
  OnReset();
}
201
/// <summary>
/// Handles changes inside the current problem data object by resetting the problem.
/// </summary>
/// <param name="sender">Event source (unused).</param>
/// <param name="e">Event arguments (unused).</param>
private void ProblemData_Changed(object sender, EventArgs e) {
  this.OnReset();
}
205
/// <summary>
/// Rebuilds the grammar and then notifies subscribers of <see cref="ProblemDataChanged"/>.
/// </summary>
private void OnProblemDataChanged() {
  UpdateGrammar();

  // Copy the delegate first so the null check and the invocation see the same subscriber list.
  var subscribers = ProblemDataChanged;
  if (subscribers == null) return;
  subscribers(this, EventArgs.Empty);
}
212
/// <summary>
/// Replaces the encoding's grammar with a freshly constructed <see cref="Grammar"/>.
/// Called from the constructor and whenever the problem data object is replaced.
/// </summary>
private void UpdateGrammar() {
  Encoding.Grammar = new Grammar();
}
218    #endregion
219
220    #region Import & Export
/// <summary>
/// Imports a regression problem instance: takes over its name and description and installs it
/// as the current problem data.
/// </summary>
/// <param name="data">The problem instance to load; must not be null.</param>
/// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
public void Load(IRegressionProblemData data) {
  // Fail with a descriptive exception instead of a bare NullReferenceException.
  if (data == null) throw new ArgumentNullException("data");

  Name = data.Name;
  Description = data.Description;
  ProblemData = data;
}
226
/// <summary>
/// Exports the current problem data object as is (no copy is made).
/// </summary>
/// <returns>The object currently held by <see cref="ProblemData"/>.</returns>
public IRegressionProblemData Export() {
  var current = ProblemData;
  return current;
}
230    #endregion
231  }
232}
Note: See TracBrowser for help on using the repository browser.