Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Hive_Milestone2/sources/HeuristicLab.GP.StructureIdentification/3.3/HL2TreeEvaluator.cs @ 1835

Last change on this file since 1835 was 1817, checked in by gkronber, 15 years ago

Implemented an evaluator that matches the semantics of the standard function library of HL2. #615 (Evaluation of HL3 function trees should be equivalent to evaluation in HL2)

File size: 10.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using System.Xml;
28using System.Diagnostics;
29using HeuristicLab.DataAnalysis;
30
namespace HeuristicLab.GP.StructureIdentification {
  /// <summary>
  /// Evaluates FunctionTrees recursively by interpretation of the function symbols in each node,
  /// following the semantics of the standard function library of HL2.
  /// Not thread-safe!
  /// </summary>
  /// <remarks>
  /// Typical use: call <see cref="ResetEvaluator"/> once to bind a dataset and compute the clamping
  /// bounds, then call <see cref="Evaluate"/> for each sample. Evaluate decodes the tree's linear
  /// representation into an instruction array and interprets it with a program counter (PC). The
  /// interpreter reads a node and then its sub-trees one after another, so the linear representation
  /// is expected to list each node directly followed by its sub-trees (prefix order).
  /// The evaluator keeps per-call state in instance fields (codeArr, PC, sampleIndex), which is why
  /// one instance must not be shared between threads.
  /// </remarks>
  public class HL2TreeEvaluator : ItemBase, ITreeEvaluator {
    // Not referenced by this class; division uses its own 10e-20 threshold (see EvaluateBakedCode).
    private const double EPSILON = 1.0e-10;
    // Bounds used to clamp every estimated value; computed by ResetEvaluator.
    private double estimatedValueMax;
    private double estimatedValueMin;

    /// <summary>
    /// One decoded node of the linear tree representation.
    /// </summary>
    private class Instr {
      public double d_arg0;      // weight (VARIABLE, DIFFERENTIAL) or value (CONSTANT)
      public short i_arg0;       // variable index (VARIABLE, DIFFERENTIAL)
      public short i_arg1;       // sample offset relative to the evaluated row (VARIABLE, DIFFERENTIAL)
      public byte arity;         // number of sub-trees that follow this node
      public byte symbol;        // symbol code from EvaluatorSymbolTable
      public IFunction function; // original function type; only set for UNKNOWN symbols
    }

    private Instr[] codeArr;   // decoded program of the tree currently being evaluated
    private int PC;            // index of the next instruction in codeArr
    private Dataset dataset;   // dataset bound by ResetEvaluator
    private int sampleIndex;   // row currently being evaluated

    /// <summary>
    /// Binds the evaluator to <paramref name="dataset"/> and computes the bounds used to clamp estimated
    /// values: mean(target) - punishmentFactor * range(target) .. mean(target) + punishmentFactor * range(target),
    /// where mean and range are taken from Dataset.GetMean / Dataset.GetRange over start..end.
    /// </summary>
    /// <param name="dataset">Dataset read by variable and differential nodes.</param>
    /// <param name="targetVariable">Index of the target variable in <paramref name="dataset"/>.</param>
    /// <param name="start">Start of the row range passed to GetMean/GetRange.</param>
    /// <param name="end">End of the row range passed to GetMean/GetRange (inclusivity as defined by Dataset).</param>
    /// <param name="punishmentFactor">Multiplier of the target range that defines the allowed deviation from the mean.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dataset"/> is null.</exception>
    public void ResetEvaluator(Dataset dataset, int targetVariable, int start, int end, double punishmentFactor) {
      if (dataset == null) throw new ArgumentNullException("dataset");
      this.dataset = dataset;
      double maximumPunishment = punishmentFactor * dataset.GetRange(targetVariable, start, end);

      // get the mean of the values of the target variable to determine the max and min bounds of the estimated value
      double targetMean = dataset.GetMean(targetVariable, start, end);
      estimatedValueMin = targetMean - maximumPunishment;
      estimatedValueMax = targetMean + maximumPunishment;
    }

    /// <summary>
    /// Decodes one LightWeightFunction into an <see cref="Instr"/>. Operands are copied only for the
    /// symbols that use them; for UNKNOWN symbols the original function is kept for error reporting.
    /// </summary>
    private Instr TranslateToInstr(LightWeightFunction f) {
      Instr instr = new Instr();
      instr.arity = f.arity;
      instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType);
      switch (instr.symbol) {
        case EvaluatorSymbolTable.DIFFERENTIAL:
        case EvaluatorSymbolTable.VARIABLE: {
            instr.i_arg0 = (short)f.data[0]; // var
            instr.d_arg0 = f.data[1]; // weight
            instr.i_arg1 = (short)f.data[2]; // sample-offset
            break;
          }
        case EvaluatorSymbolTable.CONSTANT: {
            instr.d_arg0 = f.data[0]; // value
            break;
          }
        case EvaluatorSymbolTable.UNKNOWN: {
            instr.function = f.functionType;
            break;
          }
      }
      return instr;
    }

    /// <summary>
    /// Evaluates <paramref name="functionTree"/> for row <paramref name="sampleIndex"/> of the dataset bound
    /// by <see cref="ResetEvaluator"/>. The result is clamped to the bounds computed there; NaN and
    /// positive or negative infinity are mapped to the upper bound.
    /// </summary>
    /// <param name="functionTree">The tree to evaluate; must be a BakedFunctionTree.</param>
    /// <param name="sampleIndex">Row of the dataset to evaluate the tree on.</param>
    /// <returns>The clamped estimated value.</returns>
    /// <exception cref="ArgumentException"><paramref name="functionTree"/> is not a BakedFunctionTree.</exception>
    /// <exception cref="NotImplementedException">The tree contains a symbol this evaluator cannot interpret.</exception>
    public double Evaluate(IFunctionTree functionTree, int sampleIndex) {
      BakedFunctionTree bakedTree = functionTree as BakedFunctionTree;
      if (bakedTree == null) throw new ArgumentException("HL2TreeEvaluator can only evaluate BakedFunctionTrees", "functionTree");

      // The tree is re-decoded on every call, so changes to the tree between calls are always picked up.
      List<LightWeightFunction> linearRepresentation = bakedTree.LinearRepresentation;
      codeArr = new Instr[linearRepresentation.Count];
      int i = 0;
      foreach (LightWeightFunction f in linearRepresentation) {
        codeArr[i++] = TranslateToInstr(f);
      }

      PC = 0;
      this.sampleIndex = sampleIndex;

      double estimated = EvaluateBakedCode();
      if (double.IsNaN(estimated) || double.IsInfinity(estimated)) {
        estimated = estimatedValueMax;
      } else if (estimated > estimatedValueMax) {
        estimated = estimatedValueMax;
      } else if (estimated < estimatedValueMin) {
        estimated = estimatedValueMin;
      }
      return estimated;
    }

    /// <summary>
    /// Advances PC past one complete sub-tree (the node at PC and all of its descendants) without
    /// evaluating it. Not called anywhere in this class at present.
    /// </summary>
    private void SkipBakedCode() {
      // i = number of nodes still to skip; each node adds its children and removes itself.
      int i = 1;
      while (i > 0) {
        i += codeArr[PC++].arity;
        i--;
      }
    }

    /// <summary>
    /// Interprets the instruction at PC together with its sub-trees and returns its value. Every case
    /// evaluates all sub-trees it consumes, even those that cannot change the result, so that PC ends
    /// up directly after the evaluated node and NaN in any consumed operand can be detected.
    /// </summary>
    private double EvaluateBakedCode() {
      Instr currInstr = codeArr[PC++];
      switch (currInstr.symbol) {
        case EvaluatorSymbolTable.VARIABLE: {
            // weight * value at (sampleIndex + offset); rows outside [0, Rows) yield NaN.
            int row = sampleIndex + currInstr.i_arg1;
            if (row < 0 || row >= dataset.Rows) return double.NaN;
            else return currInstr.d_arg0 * dataset.GetValue(row, currInstr.i_arg0);
          }
        case EvaluatorSymbolTable.CONSTANT: {
            return currInstr.d_arg0;
          }
        case EvaluatorSymbolTable.DIFFERENTIAL: {
            // weight * (value[row] - value[row - 1]); row 0 has no predecessor, hence the lower bound of 1.
            int row = sampleIndex + currInstr.i_arg1;
            if (row < 1 || row >= dataset.Rows) return double.NaN;
            else {
              double prevValue = dataset.GetValue(row - 1, currInstr.i_arg0);
              return currInstr.d_arg0 * (dataset.GetValue(row, currInstr.i_arg0) - prevValue);
            }
          }
        case EvaluatorSymbolTable.MULTIPLICATION: {
            double result = EvaluateBakedCode();
            for (int i = 1; i < currInstr.arity; i++) {
              result *= EvaluateBakedCode();
            }
            return result;
          }
        case EvaluatorSymbolTable.ADDITION: {
            double sum = EvaluateBakedCode();
            for (int i = 1; i < currInstr.arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum;
          }
        case EvaluatorSymbolTable.SUBTRACTION: {
            // Binary only: exactly two sub-trees are consumed regardless of arity.
            // NOTE(review): a node with another arity would leave PC misaligned.
            double minuend = EvaluateBakedCode();
            double subtrahend = EvaluateBakedCode();
            return minuend - subtrahend;
          }
        case EvaluatorSymbolTable.DIVISION: {
            // Binary, protected division: a divisor with magnitude below 10e-20 (= 1e-19) yields 0.
            double arg0 = EvaluateBakedCode();
            double arg1 = EvaluateBakedCode();
            if (double.IsNaN(arg0) || double.IsNaN(arg1)) return double.NaN;
            if (Math.Abs(arg1) < (10e-20)) return 0.0; else return (arg0 / arg1);
          }
        case EvaluatorSymbolTable.COSINUS: {
            return Math.Cos(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.SINUS: {
            return Math.Sin(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.EXP: {
            return Math.Exp(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.LOG: {
            return Math.Log(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.POWER: {
            // Binary: base first, exponent second.
            double x = EvaluateBakedCode();
            double p = EvaluateBakedCode();
            return Math.Pow(x, p);
          }
        case EvaluatorSymbolTable.SIGNUM: {
            double value = EvaluateBakedCode();
            if (double.IsNaN(value)) return double.NaN;
            if (value < 0.0) return -1.0;
            if (value > 0.0) return 1.0;
            return 0.0;
          }
        case EvaluatorSymbolTable.SQRT: {
            return Math.Sqrt(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.TANGENS: {
            return Math.Tan(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.AND: { // only defined for inputs 1 and 0
            // Once an operand is false (< 0.5) or NaN, the remaining operands are still evaluated
            // (to advance PC and to detect NaN) but no longer change the result.
            double result = EvaluateBakedCode();
            bool hasNaNBranch = false;
            for (int i = 1; i < currInstr.arity; i++) {
              // '|=' on bool does not short-circuit, so EvaluateBakedCode() is always called here.
              if (result < 0.5 || double.IsNaN(result)) hasNaNBranch |= double.IsNaN(EvaluateBakedCode());
              else {
                result = EvaluateBakedCode();
              }
            }
            if (hasNaNBranch || double.IsNaN(result)) return double.NaN;
            if (result < 0.5) return 0.0;
            return 1.0;
          }
        case EvaluatorSymbolTable.EQU: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if (double.IsNaN(x) || double.IsNaN(y)) return double.NaN;
            // direct comparison of double values is most likely incorrect but
            // that's the way how it is implemented in the standard HL2 function library
            if (x == y) return 1.0; else return 0.0;
          }
        case EvaluatorSymbolTable.GT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if (double.IsNaN(x) || double.IsNaN(y)) return double.NaN;
            if (x > y) return 1.0;
            return 0.0;
          }
        case EvaluatorSymbolTable.IFTE: { // only defined for condition 0 or 1
            // Both branches are always evaluated: this advances PC past the whole node, and a NaN
            // in the branch that is not taken also makes the result NaN.
            double condition = EvaluateBakedCode();
            double result;
            bool hasNaNBranch = false;
            if (double.IsNaN(condition)) return double.NaN;
            if (condition > 0.5) {
              result = EvaluateBakedCode(); hasNaNBranch = double.IsNaN(EvaluateBakedCode());
            } else {
              hasNaNBranch = double.IsNaN(EvaluateBakedCode()); result = EvaluateBakedCode();
            }
            if (hasNaNBranch) return double.NaN;
            return result;
          }
        case EvaluatorSymbolTable.LT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if (double.IsNaN(x) || double.IsNaN(y)) return double.NaN;
            if (x < y) return 1.0;
            return 0.0;
          }
        case EvaluatorSymbolTable.NOT: { // only defined for inputs 0 or 1
            double result = EvaluateBakedCode();
            if (double.IsNaN(result)) return double.NaN;
            if (result < 0.5) return 1.0;
            return 0.0;
          }
        case EvaluatorSymbolTable.OR: { // only defined for inputs 0 or 1
            // Mirror image of AND: once an operand is true (> 0.5) or NaN, the remaining operands
            // are evaluated only to advance PC and to detect NaN.
            double result = EvaluateBakedCode();
            bool hasNaNBranch = false;
            for (int i = 1; i < currInstr.arity; i++) {
              // '|=' on bool does not short-circuit, so EvaluateBakedCode() is always called here.
              if (double.IsNaN(result) || result > 0.5) hasNaNBranch |= double.IsNaN(EvaluateBakedCode());
              else
                result = EvaluateBakedCode();
            }
            if (hasNaNBranch || double.IsNaN(result)) return double.NaN;
            if (result > 0.5) return 1.0;
            return 0.0;
          }
        default: {
            // UNKNOWN symbols carry their original function; report it to make the failure diagnosable.
            string what = currInstr.function != null
              ? "function " + currInstr.function.GetType().Name
              : "symbol " + currInstr.symbol;
            throw new NotImplementedException("HL2TreeEvaluator cannot evaluate " + what + ".");
          }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.