Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Functions/BakedTreeEvaluator.cs @ 627

Last change on this file since 627 was 523, checked in by gkronber, 16 years ago

minor speed tuning of BakedTreeEvaluator

File size: 8.9 KB
RevLine 
[266]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
[223]23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.DataAnalysis;
[259]27using HeuristicLab.Core;
28using System.Xml;
[424]29using System.Diagnostics;
[223]30
31namespace HeuristicLab.Functions {
[396]32  internal class BakedTreeEvaluator : IEvaluator {
[322]33    private const int MAX_TREE_SIZE = 4096;
[424]34    private const double EPSILON = 1.0e-7;
[322]35
[321]36    private class Instr {
[318]37      public double d_arg0;
38      public int i_arg0;
39      public int i_arg1;
40      public int arity;
41      public int symbol;
42    }
43
[396]44    private Instr[] codeArr;
45    private int PC;
46    private Dataset dataset;
47    private int sampleIndex;
[223]48
[322]49
[483]50    public BakedTreeEvaluator() {
[322]51      codeArr = new Instr[MAX_TREE_SIZE];
52      for(int i = 0; i < MAX_TREE_SIZE; i++) {
53        codeArr[i] = new Instr();
54      }
55    }
56
[483]57    public void ResetEvaluator(IFunctionTree functionTree, Dataset dataset) {
58      this.dataset = dataset;
[396]59      List<LightWeightFunction> linearRepresentation = ((BakedFunctionTree)functionTree).LinearRepresentation;
[318]60      int i = 0;
61      foreach(LightWeightFunction f in linearRepresentation) {
[322]62        TranslateToInstr(f, codeArr[i++]);
[317]63      }
[223]64    }
65
[396]66    private Instr TranslateToInstr(LightWeightFunction f, Instr instr) {
[318]67      instr.arity = f.arity;
[319]68      instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType);
[318]69      switch(instr.symbol) {
[523]70        case EvaluatorSymbolTable.DIFFERENTIAL:
[318]71        case EvaluatorSymbolTable.VARIABLE: {
72            instr.i_arg0 = (int)f.data[0]; // var
73            instr.d_arg0 = f.data[1]; // weight
74            instr.i_arg1 = (int)f.data[2]; // sample-offset
75            break;
76          }
77        case EvaluatorSymbolTable.CONSTANT: {
78            instr.d_arg0 = f.data[0]; // value
79            break;
80          }
81      }
82      return instr;
83    }
84
[396]85    public double Evaluate(int sampleIndex) {
[223]86      PC = 0;
[396]87      this.sampleIndex = sampleIndex;
[223]88      return EvaluateBakedCode();
89    }
90
[523]91    // skips a whole branch
92    private void SkipBakedCode() {
93      int i = 1;
94      while(i > 0) {
95        i+=codeArr[PC++].arity;
96        i--;
97      }
98    }
99
[396]100    private double EvaluateBakedCode() {
[318]101      Instr currInstr = codeArr[PC++];
102      switch(currInstr.symbol) {
[260]103        case EvaluatorSymbolTable.VARIABLE: {
[318]104            int row = sampleIndex + currInstr.i_arg1;
[227]105            if(row < 0 || row >= dataset.Rows) return double.NaN;
[318]106            else return currInstr.d_arg0 * dataset.GetValue(row, currInstr.i_arg0);
[223]107          }
[260]108        case EvaluatorSymbolTable.CONSTANT: {
[318]109            return currInstr.d_arg0;
[223]110          }
[365]111        case EvaluatorSymbolTable.DIFFERENTIAL: {
112            int row = sampleIndex + currInstr.i_arg1;
113            if(row < 1 || row >= dataset.Rows) return double.NaN;
114            else return currInstr.d_arg0 * (dataset.GetValue(row, currInstr.i_arg0) - dataset.GetValue(row - 1, currInstr.i_arg0));
115          }
[260]116        case EvaluatorSymbolTable.MULTIPLICATION: {
[236]117            double result = EvaluateBakedCode();
[318]118            for(int i = 1; i < currInstr.arity; i++) {
[223]119              result *= EvaluateBakedCode();
120            }
121            return result;
122          }
[260]123        case EvaluatorSymbolTable.ADDITION: {
[236]124            double sum = EvaluateBakedCode();
[318]125            for(int i = 1; i < currInstr.arity; i++) {
[223]126              sum += EvaluateBakedCode();
127            }
128            return sum;
129          }
[308]130        case EvaluatorSymbolTable.SUBTRACTION: {
[318]131            if(currInstr.arity == 1) {
[223]132              return -EvaluateBakedCode();
133            } else {
134              double result = EvaluateBakedCode();
[318]135              for(int i = 1; i < currInstr.arity; i++) {
[223]136                result -= EvaluateBakedCode();
137              }
138              return result;
139            }
140          }
[260]141        case EvaluatorSymbolTable.DIVISION: {
[236]142            double result;
[318]143            if(currInstr.arity == 1) {
[236]144              result = 1.0 / EvaluateBakedCode();
[223]145            } else {
[236]146              result = EvaluateBakedCode();
[318]147              for(int i = 1; i < currInstr.arity; i++) {
[236]148                result /= EvaluateBakedCode();
[223]149              }
150            }
[236]151            if(double.IsInfinity(result)) return 0.0;
152            else return result;
[223]153          }
[260]154        case EvaluatorSymbolTable.AVERAGE: {
[236]155            double sum = EvaluateBakedCode();
[318]156            for(int i = 1; i < currInstr.arity; i++) {
[223]157              sum += EvaluateBakedCode();
158            }
[318]159            return sum / currInstr.arity;
[223]160          }
[260]161        case EvaluatorSymbolTable.COSINUS: {
[223]162            return Math.Cos(EvaluateBakedCode());
163          }
[260]164        case EvaluatorSymbolTable.SINUS: {
[223]165            return Math.Sin(EvaluateBakedCode());
166          }
[260]167        case EvaluatorSymbolTable.EXP: {
[223]168            return Math.Exp(EvaluateBakedCode());
169          }
[260]170        case EvaluatorSymbolTable.LOG: {
[223]171            return Math.Log(EvaluateBakedCode());
172          }
[260]173        case EvaluatorSymbolTable.POWER: {
[223]174            double x = EvaluateBakedCode();
175            double p = EvaluateBakedCode();
176            return Math.Pow(x, p);
177          }
[260]178        case EvaluatorSymbolTable.SIGNUM: {
[223]179            double value = EvaluateBakedCode();
[236]180            if(double.IsNaN(value)) return double.NaN;
181            else return Math.Sign(value);
[223]182          }
[260]183        case EvaluatorSymbolTable.SQRT: {
[223]184            return Math.Sqrt(EvaluateBakedCode());
185          }
[260]186        case EvaluatorSymbolTable.TANGENS: {
[223]187            return Math.Tan(EvaluateBakedCode());
188          }
[424]189        case EvaluatorSymbolTable.AND: { // only defined for inputs 1 and 0
[523]190            double result = EvaluateBakedCode();
191            for(int i = 1; i < currInstr.arity; i++) {
192              if(result == 0.0) SkipBakedCode();
193              else {
194                result = EvaluateBakedCode();
195              }
196              Debug.Assert(result == 0.0 || result == 1.0);
[223]197            }
198            return result;
199          }
[260]200        case EvaluatorSymbolTable.EQU: {
[223]201            double x = EvaluateBakedCode();
202            double y = EvaluateBakedCode();
[424]203            if(Math.Abs(x - y) < EPSILON) return 1.0; else return 0.0;
[223]204          }
[260]205        case EvaluatorSymbolTable.GT: {
[223]206            double x = EvaluateBakedCode();
207            double y = EvaluateBakedCode();
208            if(x > y) return 1.0;
209            else return 0.0;
210          }
[424]211        case EvaluatorSymbolTable.IFTE: { // only defined for condition 0 or 1
212            double condition = EvaluateBakedCode();
213            Debug.Assert(condition == 0.0 || condition == 1.0);
[523]214            double result;
215            if(condition == 0.0) {
216              result = EvaluateBakedCode(); SkipBakedCode();
217            } else {
218              SkipBakedCode(); result = EvaluateBakedCode();
219            }
220            return result;
[223]221          }
[260]222        case EvaluatorSymbolTable.LT: {
[223]223            double x = EvaluateBakedCode();
224            double y = EvaluateBakedCode();
225            if(x < y) return 1.0;
226            else return 0.0;
227          }
[424]228        case EvaluatorSymbolTable.NOT: { // only defined for inputs 0 or 1
229            double result = EvaluateBakedCode();
230            Debug.Assert(result == 0.0 || result == 1.0);
[460]231            return Math.Abs(result - 1.0);
[223]232          }
[424]233        case EvaluatorSymbolTable.OR: { // only defined for inputs 0 or 1
[523]234            double result = EvaluateBakedCode();
235            for(int i = 1; i < currInstr.arity; i++) {
236              if(result > 0.0) SkipBakedCode();
237              else {
238                result = EvaluateBakedCode();
239                Debug.Assert(result == 0.0 || result == 1.0);
240              }
[223]241            }
[523]242            return result;
[223]243          }
[424]244        case EvaluatorSymbolTable.XOR: { // only defined for inputs 0 or 1
245            double x = EvaluateBakedCode();
246            double y = EvaluateBakedCode();
247            return Math.Abs(x - y);
[223]248          }
[523]249        case EvaluatorSymbolTable.UNKNOWN:
[223]250        default: {
[318]251            throw new NotImplementedException();
[223]252          }
253      }
254    }
255  }
256}
Note: See TracBrowser for help on using the repository browser.