Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.GP.StructureIdentification/BakedTreeEvaluator.cs @ 727

Last change on this file since 727 was 702, checked in by gkronber, 16 years ago

fixed #328 by restructuring evaluation operators to remove state in evaluation operators.

File size: 10.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using System.Xml;
28using System.Diagnostics;
29using HeuristicLab.DataAnalysis;
30
namespace HeuristicLab.GP.StructureIdentification {
  /// <summary>
  /// Evaluates FunctionTrees recursively by interpretation of the function symbols in each node.
  /// Not thread-safe: the program counter and the current sample index are instance fields
  /// that are mutated while an evaluation is in progress.
  /// </summary>
  public class BakedTreeEvaluator {
    // Tolerance used by the EQU symbol when comparing two values for equality.
    private const double EPSILON = 1.0e-7;
    // Bounds applied to every value returned by Evaluate; set by ResetEvaluator.
    private double estimatedValueMax;
    private double estimatedValueMin;

    /// <summary>
    /// One instruction of the linearized tree. Which argument fields carry meaning
    /// depends on <c>symbol</c> (see TranslateToInstr).
    /// </summary>
    private class Instr {
      public double d_arg0;
      public int i_arg0;
      public int i_arg1;
      public int arity;
      public int symbol;
      public IFunction function;
    }

    // Scratch list that is refilled on every ResetEvaluator call.
    private List<Instr> code;
    // Array snapshot of 'code'; this is what the interpreter actually reads.
    private Instr[] codeArr;
    // Index of the next instruction to be consumed from codeArr.
    private int PC;
    private Dataset dataset;
    private int sampleIndex;

    /// <summary>
    /// Creates an evaluator without any code. ResetEvaluator must be called before Evaluate.
    /// </summary>
    public BakedTreeEvaluator() {
      code = new List<Instr>();
    }

    /// <summary>
    /// Prepares the evaluator for a new tree and dataset: computes the bounds that
    /// Evaluate clamps its results to and translates the linear representation of
    /// <paramref name="functionTree"/> into instructions.
    /// </summary>
    /// <param name="functionTree">Tree whose linear representation is translated.</param>
    /// <param name="dataset">Dataset that variable and differential terminals read from.</param>
    /// <param name="targetVariable">Index of the target variable, used for its range and mean.</param>
    /// <param name="start">First row passed to the mean calculation.</param>
    /// <param name="end">The mean is taken up to <c>end - 1</c>, so this looks like an exclusive
    /// end row - TODO confirm against Dataset.GetMean.</param>
    /// <param name="punishmentFactor">Multiplied with the range of the target variable to get the
    /// maximal allowed distance of an estimated value from the target mean.</param>
    public void ResetEvaluator(BakedFunctionTree functionTree, Dataset dataset, int targetVariable, int start, int end, double punishmentFactor) {
      this.dataset = dataset;
      double maximumPunishment = punishmentFactor * dataset.GetRange(targetVariable);

      // get the mean of the values of the target variable to determine the max and min bounds of the estimated value
      double targetMean = dataset.GetMean(targetVariable, start, end - 1);
      estimatedValueMin = targetMean - maximumPunishment;
      estimatedValueMax = targetMean + maximumPunishment;

      List<LightWeightFunction> linearRepresentation = functionTree.LinearRepresentation;
      code.Clear();
      foreach(LightWeightFunction f in linearRepresentation) {
        Instr curInstr = new Instr();
        TranslateToInstr(f, curInstr);
        code.Add(curInstr);
      }

      // List<T>.ToArray() instead of the LINQ extension: same result, no enumerator round trip
      codeArr = code.ToArray();
    }

    /// <summary>
    /// Fills <paramref name="instr"/> from <paramref name="f"/>: arity, mapped symbol id and
    /// the symbol-specific arguments taken from <c>f.data</c>.
    /// </summary>
    private void TranslateToInstr(LightWeightFunction f, Instr instr) {
      instr.arity = f.arity;
      instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType);
      switch(instr.symbol) {
        case EvaluatorSymbolTable.DIFFERENTIAL:
        case EvaluatorSymbolTable.VARIABLE: {
            instr.i_arg0 = (int)f.data[0]; // var
            instr.d_arg0 = f.data[1]; // weight
            instr.i_arg1 = (int)f.data[2]; // sample-offset
            break;
          }
        case EvaluatorSymbolTable.CONSTANT: {
            instr.d_arg0 = f.data[0]; // value
            break;
          }
        case EvaluatorSymbolTable.UNKNOWN: {
            // keep the function object so that it can be applied directly during evaluation
            instr.function = f.functionType;
            break;
          }
      }
    }

    /// <summary>
    /// Evaluates the current code for the given row of the dataset.
    /// </summary>
    /// <param name="sampleIndex">Row index; terminals add their sample offset to it.</param>
    /// <returns>
    /// The estimated value limited to [estimatedValueMin, estimatedValueMax]. NaN and
    /// infinite results are replaced by estimatedValueMax.
    /// </returns>
    /// <exception cref="InvalidOperationException">ResetEvaluator has not been called yet.</exception>
    public double Evaluate(int sampleIndex) {
      if(codeArr == null) {
        throw new InvalidOperationException("ResetEvaluator must be called before Evaluate.");
      }
      PC = 0;
      this.sampleIndex = sampleIndex;

      double estimated = EvaluateBakedCode();
      if(double.IsNaN(estimated) || double.IsInfinity(estimated)) {
        estimated = estimatedValueMax;
      } else if(estimated > estimatedValueMax) {
        estimated = estimatedValueMax;
      } else if(estimated < estimatedValueMin) {
        estimated = estimatedValueMin;
      }
      return estimated;
    }

    /// <summary>
    /// Skips a whole branch: advances PC past the instruction at PC and all of its
    /// descendants without evaluating them.
    /// </summary>
    private void SkipBakedCode() {
      // number of sub-trees that still have to be passed; every visited instruction
      // closes one of them and opens 'arity' new ones
      int i = 1;
      while(i > 0) {
        i += codeArr[PC++].arity;
        i--;
      }
    }

    /// <summary>
    /// Evaluates the sub-tree that starts at PC and leaves PC behind that sub-tree.
    /// </summary>
    /// <returns>Value of the sub-tree for the current sample; may be NaN or infinite.</returns>
    /// <exception cref="NotImplementedException">The instruction's symbol has no case here.</exception>
    private double EvaluateBakedCode() {
      Instr currInstr = codeArr[PC++];
      switch(currInstr.symbol) {
        case EvaluatorSymbolTable.VARIABLE: {
            int row = sampleIndex + currInstr.i_arg1;
            // rows outside of the dataset have no value
            if(row < 0 || row >= dataset.Rows) return double.NaN;
            else return currInstr.d_arg0 * dataset.GetValue(row, currInstr.i_arg0);
          }
        case EvaluatorSymbolTable.CONSTANT: {
            return currInstr.d_arg0;
          }
        case EvaluatorSymbolTable.DIFFERENTIAL: {
            int row = sampleIndex + currInstr.i_arg1;
            // the previous row is read as well, so row 0 cannot be evaluated
            if(row < 1 || row >= dataset.Rows) return double.NaN;
            else return currInstr.d_arg0 * (dataset.GetValue(row, currInstr.i_arg0) - dataset.GetValue(row - 1, currInstr.i_arg0));
          }
        case EvaluatorSymbolTable.MULTIPLICATION: {
            double result = EvaluateBakedCode();
            for(int i = 1; i < currInstr.arity; i++) {
              result *= EvaluateBakedCode();
            }
            return result;
          }
        case EvaluatorSymbolTable.ADDITION: {
            double sum = EvaluateBakedCode();
            for(int i = 1; i < currInstr.arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum;
          }
        case EvaluatorSymbolTable.SUBTRACTION: {
            if(currInstr.arity == 1) {
              // unary minus
              return -EvaluateBakedCode();
            } else {
              double result = EvaluateBakedCode();
              for(int i = 1; i < currInstr.arity; i++) {
                result -= EvaluateBakedCode();
              }
              return result;
            }
          }
        case EvaluatorSymbolTable.DIVISION: {
            double result;
            if(currInstr.arity == 1) {
              // reciprocal
              result = 1.0 / EvaluateBakedCode();
            } else {
              result = EvaluateBakedCode();
              for(int i = 1; i < currInstr.arity; i++) {
                result /= EvaluateBakedCode();
              }
            }
            // an infinite quotient is mapped to 0; NaN is passed on unchanged
            if(double.IsInfinity(result)) return 0.0;
            else return result;
          }
        case EvaluatorSymbolTable.AVERAGE: {
            double sum = EvaluateBakedCode();
            for(int i = 1; i < currInstr.arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum / currInstr.arity;
          }
        case EvaluatorSymbolTable.COSINUS: {
            return Math.Cos(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.SINUS: {
            return Math.Sin(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.EXP: {
            return Math.Exp(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.LOG: {
            return Math.Log(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.POWER: {
            double x = EvaluateBakedCode();
            double p = EvaluateBakedCode();
            return Math.Pow(x, p);
          }
        case EvaluatorSymbolTable.SIGNUM: {
            double value = EvaluateBakedCode();
            // Math.Sign throws for NaN, so NaN is passed through explicitly
            if(double.IsNaN(value)) return double.NaN;
            else return Math.Sign(value);
          }
        case EvaluatorSymbolTable.SQRT: {
            return Math.Sqrt(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.TANGENS: {
            return Math.Tan(EvaluateBakedCode());
          }
        case EvaluatorSymbolTable.AND: { // only defined for inputs 1 and 0
            double result = EvaluateBakedCode();
            for(int i = 1; i < currInstr.arity; i++) {
              // once the result is 0 the remaining branches are skipped instead of evaluated
              if(result == 0.0) SkipBakedCode();
              else {
                result = EvaluateBakedCode();
              }
              Debug.Assert(result == 0.0 || result == 1.0);
            }
            return result;
          }
        case EvaluatorSymbolTable.EQU: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(Math.Abs(x - y) < EPSILON) return 1.0; else return 0.0;
          }
        case EvaluatorSymbolTable.GT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(x > y) return 1.0;
            else return 0.0;
          }
        case EvaluatorSymbolTable.IFTE: { // only defined for condition 0 or 1
            double condition = EvaluateBakedCode();
            Debug.Assert(condition == 0.0 || condition == 1.0);
            double result;
            // a condition of 0.0 selects the first branch, any other value the second one;
            // the branch that is not selected is skipped so that PC ends behind both
            if(condition == 0.0) {
              result = EvaluateBakedCode(); SkipBakedCode();
            } else {
              SkipBakedCode(); result = EvaluateBakedCode();
            }
            return result;
          }
        case EvaluatorSymbolTable.LT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(x < y) return 1.0;
            else return 0.0;
          }
        case EvaluatorSymbolTable.NOT: { // only defined for inputs 0 or 1
            double result = EvaluateBakedCode();
            Debug.Assert(result == 0.0 || result == 1.0);
            return Math.Abs(result - 1.0);
          }
        case EvaluatorSymbolTable.OR: { // only defined for inputs 0 or 1
            double result = EvaluateBakedCode();
            for(int i = 1; i < currInstr.arity; i++) {
              // once the result is positive the remaining branches are skipped instead of evaluated
              if(result > 0.0) SkipBakedCode();
              else {
                result = EvaluateBakedCode();
                Debug.Assert(result == 0.0 || result == 1.0);
              }
            }
            return result;
          }
        case EvaluatorSymbolTable.XOR: { // only defined for inputs 0 or 1
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            return Math.Abs(x - y);
          }
        case EvaluatorSymbolTable.UNKNOWN: { // evaluate functions which are not statically defined directly
            return currInstr.function.Apply();
          }
        default: {
            throw new NotImplementedException("Evaluation of symbol " + currInstr.symbol + " is not implemented.");
          }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.