Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.GP.StructureIdentification/BakedTreeEvaluator.cs @ 1265

Last change on this file since 1265 was 1069, checked in by gkronber, 16 years ago

Fixed #447 (the Differential symbol should evaluate to 0 if the previous value is NaN, infinite, or out of range).

File size: 11.5 KB
RevLine 
[645]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using System.Xml;
28using System.Diagnostics;
29using HeuristicLab.DataAnalysis;
30
namespace HeuristicLab.GP.StructureIdentification {
  /// <summary>
  /// Evaluates function trees by interpreting the symbols of their linear representation.
  /// The tree is first translated into an array of instructions in prefix order (a node is
  /// immediately followed by its sub-trees); constant sub-expressions are folded once when
  /// the evaluator is reset.
  /// Not thread-safe!
  /// </summary>
  public class BakedTreeEvaluator {
    // tolerance used by the EQU symbol when comparing two values
    private const double EPSILON = 1.0e-7;
    // upper and lower bounds applied to every value returned by Evaluate()
    private double estimatedValueMax;
    private double estimatedValueMin;

    /// <summary>
    /// One node of the tree in its translated form.
    /// </summary>
    private class Instr {
      // weight of a variable/differential terminal, or the value of a constant
      public double d_arg0;
      // variable index of a variable/differential terminal
      public short i_arg0;
      // sample offset of a variable/differential terminal
      public short i_arg1;
      // number of sub-trees
      public byte arity;
      // symbol code as produced by EvaluatorSymbolTable.MapFunction
      public byte symbol;
      // number of instructions in the sub-tree rooted at this node (including the node itself)
      public ushort exprLength;
      // only set for symbols that map to EvaluatorSymbolTable.UNKNOWN
      public IFunction function;
    }

    private Instr[] codeArr;
    // index of the next instruction to be interpreted
    private int PC;
    private Dataset dataset;
    private int sampleIndex;
    // running instruction index used while patching expression lengths
    private ushort exprIndex;

    /// <summary>
    /// Prepares the evaluator for a new function tree and dataset. Translates the tree into
    /// instructions, computes the bounds for estimated values and folds constant sub-expressions.
    /// </summary>
    /// <param name="functionTree">The tree to evaluate; must contain between 1 and 65535 nodes.</param>
    /// <param name="dataset">The dataset from which variable values are read.</param>
    /// <param name="targetVariable">Index of the target variable, used to derive the bounds of estimated values.</param>
    /// <param name="start">First sample index, passed to <c>dataset.GetMean</c>.</param>
    /// <param name="end">End sample index; <c>end - 1</c> is passed to <c>dataset.GetMean</c>.</param>
    /// <param name="punishmentFactor">Multiplied with the range of the target variable to get the maximal
    /// allowed distance of an estimated value from the mean of the target variable.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="functionTree"/> or <paramref name="dataset"/> is null.</exception>
    /// <exception cref="ArgumentException">If the tree is empty or has more nodes than fit into the 16-bit expression lengths.</exception>
    public void ResetEvaluator(BakedFunctionTree functionTree, Dataset dataset, int targetVariable, int start, int end, double punishmentFactor) {
      if (functionTree == null) throw new ArgumentNullException("functionTree");
      if (dataset == null) throw new ArgumentNullException("dataset");

      List<LightWeightFunction> linearRepresentation = functionTree.LinearRepresentation;
      if (linearRepresentation == null || linearRepresentation.Count == 0) {
        throw new ArgumentException("The function tree does not contain any nodes.", "functionTree");
      }
      // instruction indexes and expression lengths are stored as ushort; larger trees would wrap around silently
      if (linearRepresentation.Count > ushort.MaxValue) {
        throw new ArgumentException("The function tree contains more than " + ushort.MaxValue + " nodes.", "functionTree");
      }

      // validate everything before touching any field so a failed reset leaves the previous state intact
      double maximumPunishment = punishmentFactor * dataset.GetRange(targetVariable);
      // get the mean of the values of the target variable to determine the max and min bounds of the estimated value
      double targetMean = dataset.GetMean(targetVariable, start, end - 1);

      this.dataset = dataset;
      estimatedValueMin = targetMean - maximumPunishment;
      estimatedValueMax = targetMean + maximumPunishment;

      codeArr = new Instr[linearRepresentation.Count];
      int i = 0;
      foreach (LightWeightFunction f in linearRepresentation) {
        codeArr[i++] = TranslateToInstr(f);
      }

      exprIndex = 0;
      ushort exprLength;
      bool constExpr;
      PatchExpressionLengthsAndConstants(0, out constExpr, out exprLength);
    }

    /// <summary>
    /// Recursively walks the sub-tree rooted at <paramref name="index"/>, stores the length of every
    /// sub-expression in its root instruction and replaces constant sub-expressions by a single
    /// CONSTANT instruction holding the pre-computed value.
    /// </summary>
    /// <param name="index">Index of the root instruction of the sub-tree.</param>
    /// <param name="constExpr">True if the sub-tree was found to be constant.</param>
    /// <param name="exprLength">Number of instructions in the sub-tree, including its root.</param>
    private void PatchExpressionLengthsAndConstants(ushort index, out bool constExpr, out ushort exprLength) {
      exprLength = 1;
      if (codeArr[index].arity == 0) {
        // when no children then it's a constant expression only if the terminal is a constant
        constExpr = codeArr[index].symbol == EvaluatorSymbolTable.CONSTANT;
      } else {
        // when there are children it's a constant expression if all children are constant
        // NOTE(review): this also applies to UNKNOWN function nodes whose children are all constant;
        // confirm that freezing the result of Apply() is intended for them.
        constExpr = true;
      }
      for (int i = 0; i < codeArr[index].arity; i++) {
        // children follow their parent directly (prefix order), so the next unvisited instruction is the next child
        exprIndex++;
        ushort branchLength;
        bool branchConstExpr;
        PatchExpressionLengthsAndConstants(exprIndex, out branchConstExpr, out branchLength);
        exprLength += branchLength;
        constExpr &= branchConstExpr;
      }

      if (constExpr) {
        // fold: evaluate the sub-tree once and turn its root into a constant;
        // the CONSTANT case of EvaluateBakedCode() skips the remaining instructions via exprLength
        PC = index;
        codeArr[index].d_arg0 = EvaluateBakedCode();
        codeArr[index].symbol = EvaluatorSymbolTable.CONSTANT;
      }
      codeArr[index].exprLength = exprLength;
    }

    /// <summary>
    /// Translates one node of the linear representation into an instruction.
    /// </summary>
    /// <param name="f">The node to translate.</param>
    /// <returns>The instruction; for non-terminal symbols only arity and symbol are set here.</returns>
    private Instr TranslateToInstr(LightWeightFunction f) {
      Instr instr = new Instr();
      instr.arity = f.arity;
      instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType);
      switch (instr.symbol) {
        case EvaluatorSymbolTable.DIFFERENTIAL:
        case EvaluatorSymbolTable.VARIABLE: {
            instr.i_arg0 = (short)f.data[0]; // var
            instr.d_arg0 = f.data[1]; // weight
            instr.i_arg1 = (short)f.data[2]; // sample-offset
            instr.exprLength = 1;
            break;
          }
        case EvaluatorSymbolTable.CONSTANT: {
            instr.d_arg0 = f.data[0]; // value
            instr.exprLength = 1;
            break;
          }
        case EvaluatorSymbolTable.UNKNOWN: {
            // keep a reference to the function itself so it can be applied directly at evaluation time
            instr.function = f.functionType;
            instr.exprLength = 1;
            break;
          }
      }
      return instr;
    }

    /// <summary>
    /// Evaluates the tree for one sample of the dataset.
    /// </summary>
    /// <param name="sampleIndex">Row of the dataset for which the tree is evaluated.</param>
    /// <returns>The estimated value limited to the bounds computed in ResetEvaluator;
    /// NaN and infinite results are replaced by the upper bound.</returns>
    /// <exception cref="InvalidOperationException">If ResetEvaluator has not been called yet.</exception>
    public double Evaluate(int sampleIndex) {
      if (codeArr == null) {
        throw new InvalidOperationException("ResetEvaluator must be called before Evaluate.");
      }
      PC = 0;
      this.sampleIndex = sampleIndex;

      double estimated = EvaluateBakedCode();
      if (double.IsNaN(estimated) || double.IsInfinity(estimated)) {
        estimated = estimatedValueMax;
      } else if (estimated > estimatedValueMax) {
        estimated = estimatedValueMax;
      } else if (estimated < estimatedValueMin) {
        estimated = estimatedValueMin;
      }
      return estimated;
    }

    // skips a whole branch
    private void SkipBakedCode() {
      PC += codeArr[PC].exprLength;
    }

    /// <summary>
    /// Interprets the instruction at PC together with all of its sub-trees and leaves PC on the
    /// instruction that follows the interpreted expression.
    /// </summary>
    /// <returns>The value of the expression for the current sample.</returns>
    /// <exception cref="NotImplementedException">If the symbol code of an instruction is not handled.</exception>
    private double EvaluateBakedCode() {
      Instr currInstr = codeArr[PC++];
      switch (currInstr.symbol) {
        case EvaluatorSymbolTable.VARIABLE: {
            // weighted value of the variable at the (offset) row; NaN if that row lies outside of the dataset
            int row = sampleIndex + currInstr.i_arg1;
            if (row < 0 || row >= dataset.Rows) return double.NaN;
            else return currInstr.d_arg0 * dataset.GetValue(row, currInstr.i_arg0);
          }
        case EvaluatorSymbolTable.CONSTANT: {
            // a folded constant still occupies exprLength instructions; jump over the rest of them
            PC += currInstr.exprLength - 1;
            return currInstr.d_arg0;
          }
        case EvaluatorSymbolTable.DIFFERENTIAL: {
            // weighted difference between the value at the (offset) row and the value of the row before it
            int row = sampleIndex + currInstr.i_arg1;
            if (row < 0 || row >= dataset.Rows) return double.NaN;
            else if (row < 1) return 0.0; // there is no previous row
            else {
              double prevValue = dataset.GetValue(row - 1, currInstr.i_arg0);
              if (double.IsNaN(prevValue) || double.IsInfinity(prevValue)) return 0.0;
              else return currInstr.d_arg0 * (dataset.GetValue(row, currInstr.i_arg0) - prevValue);
            }
          }
        case EvaluatorSymbolTable.MULTIPLICATION: {
            double result = EvaluateBakedCode();
            for (int i = 1; i < currInstr.arity; i++) {
              result *= EvaluateBakedCode();
            }
            return result;
          }
        case EvaluatorSymbolTable.ADDITION: {
            double sum = EvaluateBakedCode();
            for (int i = 1; i < currInstr.arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum;
          }
        case EvaluatorSymbolTable.SUBTRACTION: {
            // with a single argument the symbol acts as negation
            if (currInstr.arity == 1) {
              return -EvaluateBakedCode();
            } else {
              double result = EvaluateBakedCode();
              for (int i = 1; i < currInstr.arity; i++) {
                result -= EvaluateBakedCode();
              }
              return result;
            }
          }
        case EvaluatorSymbolTable.DIVISION: {
            // with a single argument the symbol acts as reciprocal
            double result;
            if (currInstr.arity == 1) {
              result = 1.0 / EvaluateBakedCode();
            } else {
              result = EvaluateBakedCode();
              for (int i = 1; i < currInstr.arity; i++) {
                result /= EvaluateBakedCode();
              }
            }
            // infinite results (e.g. from a division by zero) are mapped to zero
            if (double.IsInfinity(result)) return 0.0;
            else return result;
          }
        case EvaluatorSymbolTable.AVERAGE: {
            double sum = EvaluateBakedCode();
            for (int i = 1; i < currInstr.arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum / currInstr.arity;
          }
        case EvaluatorSymbolTable.COSINUS: {
            return Math.Cos(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.SINUS: {
            return Math.Sin(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.EXP: {
            return Math.Exp(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.LOG: {
            return Math.Log(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.POWER: {
            double x = EvaluateBakedCode();
            double p = EvaluateBakedCode();
            return Math.Pow(x, p);
          }
        case EvaluatorSymbolTable.SIGNUM: {
            double value = EvaluateBakedCode();
            // Math.Sign throws an ArithmeticException for NaN, so NaN is passed through instead
            if (double.IsNaN(value)) return double.NaN;
            else return Math.Sign(value);
          }
        case EvaluatorSymbolTable.SQRT: {
            return Math.Sqrt(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.TANGENS: {
            return Math.Tan(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.AND: { // only defined for inputs 1 and 0
            double result = EvaluateBakedCode();
            for (int i = 1; i < currInstr.arity; i++) {
              // short-circuit: once the result is 0 the remaining branches are skipped without evaluation
              if (result == 0.0) SkipBakedCode();
              else {
                result = EvaluateBakedCode();
              }
              Debug.Assert(result == 0.0 || result == 1.0);
            }
            return result;
          }
        case EvaluatorSymbolTable.EQU: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if (Math.Abs(x - y) < EPSILON) return 1.0; else return 0.0;
          }
        case EvaluatorSymbolTable.GT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if (x > y) return 1.0;
            else return 0.0;
          }
        case EvaluatorSymbolTable.IFTE: { // only defined for condition 0 or 1
            double condition = EvaluateBakedCode();
            Debug.Assert(condition == 0.0 || condition == 1.0);
            double result;
            // condition 0 selects the first of the two branches, any other value the second;
            // the branch that is not selected is skipped without evaluation
            if (condition == 0.0) {
              result = EvaluateBakedCode(); SkipBakedCode();
            } else {
              SkipBakedCode(); result = EvaluateBakedCode();
            }
            return result;
          }
        case EvaluatorSymbolTable.LT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if (x < y) return 1.0;
            else return 0.0;
          }
        case EvaluatorSymbolTable.NOT: { // only defined for inputs 0 or 1
            double result = EvaluateBakedCode();
            Debug.Assert(result == 0.0 || result == 1.0);
            return Math.Abs(result - 1.0);
          }
        case EvaluatorSymbolTable.OR: { // only defined for inputs 0 or 1
            double result = EvaluateBakedCode();
            for (int i = 1; i < currInstr.arity; i++) {
              // short-circuit: once the result is positive the remaining branches are skipped without evaluation
              if (result > 0.0) SkipBakedCode();
              else {
                result = EvaluateBakedCode();
                Debug.Assert(result == 0.0 || result == 1.0);
              }
            }
            return result;
          }
        case EvaluatorSymbolTable.XOR: { // only defined for inputs 0 or 1
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            return Math.Abs(x - y);
          }
        case EvaluatorSymbolTable.UNKNOWN: { // evaluate functions which are not statically defined directly
            // NOTE(review): PC is not advanced past any sub-trees of this node here and Apply() takes
            // no arguments; confirm that such functions either have no sub-trees or handle them themselves.
            return currInstr.function.Apply();
          }
        default: {
            throw new NotImplementedException();
          }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.