Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Functions/BakedTreeEvaluator.cs @ 251

Last change on this file since 251 was 239, checked in by gkronber, 17 years ago

fixed #148

File size: 9.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using HeuristicLab.DataAnalysis;
6
namespace HeuristicLab.Functions {
  /// <summary>
  /// Evaluates a function tree that has been flattened ("baked") into a linear code array
  /// and a linear data array.
  /// </summary>
  /// <remarks>
  /// Every node occupies three consecutive entries in the code array: its arity, its function
  /// symbol and its number of local variables. The sub-trees of a node follow directly after
  /// the node itself (prefix order). Values needed by a node (variable index, weight and sample
  /// offset for variables, the value of a constant, local variables of other functions) are
  /// read sequentially from the data array.
  /// </remarks>
  internal class BakedTreeEvaluator {
    // symbols of the functions that are evaluated directly by the switch in EvaluateBakedCode
    private const int ADDITION = 10010;
    private const int AND = 10020;
    private const int AVERAGE = 10030;
    private const int CONSTANT = 10040;
    private const int COSINUS = 10050;
    private const int DIVISION = 10060;
    private const int EQU = 10070;
    private const int EXP = 10080;
    private const int GT = 10090;
    private const int IFTE = 10100;
    private const int LT = 10110;
    private const int LOG = 10120;
    private const int MULTIPLICATION = 10130;
    private const int NOT = 10140;
    private const int OR = 10150;
    private const int POWER = 10160;
    private const int SIGNUM = 10170;
    private const int SINUS = 10180;
    private const int SQRT = 10190;
    private const int SUBSTRACTION = 10200;
    private const int TANGENS = 10210;
    private const int VARIABLE = 10220;
    private const int XOR = 10230;

    // first symbol handed out to functions whose type has no fixed symbol above
    private static int nextFunctionSymbol = 10240;
    private static readonly Dictionary<int, IFunction> symbolTable;
    private static readonly Dictionary<IFunction, int> reverseSymbolTable;
    private static readonly Dictionary<Type, int> staticTypes;

    // Each evaluator owns a private copy of its program. Sharing static buffers between
    // instances would let a newly created evaluator overwrite the program of an existing one
    // and would limit programs to the fixed size of the shared buffers.
    private readonly int[] codeArr;
    private readonly double[] dataArr;

    static BakedTreeEvaluator() {
      symbolTable = new Dictionary<int, IFunction>();
      reverseSymbolTable = new Dictionary<IFunction, int>();
      staticTypes = new Dictionary<Type, int>();
      staticTypes[typeof(Addition)] = ADDITION;
      staticTypes[typeof(And)] = AND;
      staticTypes[typeof(Average)] = AVERAGE;
      staticTypes[typeof(Constant)] = CONSTANT;
      staticTypes[typeof(Cosinus)] = COSINUS;
      staticTypes[typeof(Division)] = DIVISION;
      staticTypes[typeof(Equal)] = EQU;
      staticTypes[typeof(Exponential)] = EXP;
      staticTypes[typeof(GreaterThan)] = GT;
      staticTypes[typeof(IfThenElse)] = IFTE;
      staticTypes[typeof(LessThan)] = LT;
      staticTypes[typeof(Logarithm)] = LOG;
      staticTypes[typeof(Multiplication)] = MULTIPLICATION;
      staticTypes[typeof(Not)] = NOT;
      staticTypes[typeof(Or)] = OR;
      staticTypes[typeof(Power)] = POWER;
      staticTypes[typeof(Signum)] = SIGNUM;
      staticTypes[typeof(Sinus)] = SINUS;
      staticTypes[typeof(Sqrt)] = SQRT;
      staticTypes[typeof(Substraction)] = SUBSTRACTION;
      staticTypes[typeof(Tangens)] = TANGENS;
      staticTypes[typeof(Variable)] = VARIABLE;
      staticTypes[typeof(Xor)] = XOR;
    }

    /// <summary>
    /// Creates an evaluator for the given baked program. Both lists are copied, so later
    /// changes to them do not affect this evaluator.
    /// </summary>
    /// <param name="code">Linear code: (arity, function symbol, number of local variables) per node.</param>
    /// <param name="data">Linear data consumed by the nodes in evaluation order.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="code"/> or <paramref name="data"/> is null.</exception>
    internal BakedTreeEvaluator(List<int> code, List<double> data) {
      if(code == null) throw new ArgumentNullException("code");
      if(data == null) throw new ArgumentNullException("data");
      codeArr = code.ToArray();
      dataArr = data.ToArray();
    }

    /// <summary>
    /// Returns the symbol of the given function and registers the function if it is not known yet.
    /// Functions whose type has a fixed symbol get that symbol, all other functions get the next
    /// free symbol.
    /// </summary>
    /// <param name="function">The function to map.</param>
    /// <returns>The symbol that represents <paramref name="function"/> in baked code.</returns>
    internal static int MapFunction(IFunction function) {
      int symbol;
      if(reverseSymbolTable.TryGetValue(function, out symbol)) return symbol;

      if(!staticTypes.TryGetValue(function.GetType(), out symbol)) {
        symbol = nextFunctionSymbol;
        nextFunctionSymbol++;
      }
      reverseSymbolTable[function] = symbol;
      symbolTable[symbol] = function;
      return symbol;
    }

    /// <summary>
    /// Returns the function that was registered for the given symbol via <see cref="MapFunction"/>.
    /// </summary>
    /// <param name="symbol">A symbol previously returned by <see cref="MapFunction"/>.</param>
    /// <returns>The function registered for <paramref name="symbol"/>.</returns>
    /// <exception cref="KeyNotFoundException">If no function is registered for <paramref name="symbol"/>.</exception>
    internal static IFunction MapSymbol(int symbol) {
      return symbolTable[symbol];
    }


    private int PC; // index of the next entry to read from the code array
    private int DP; // index of the next entry to read from the data array
    private Dataset dataset;
    private int sampleIndex;

    /// <summary>
    /// Evaluates the baked program from its first node for one sample of a dataset.
    /// </summary>
    /// <param name="_dataset">The dataset that variables and custom functions read from.</param>
    /// <param name="_sampleIndex">The row of the dataset for which the program is evaluated.</param>
    /// <returns>The value of the root node.</returns>
    internal double Evaluate(Dataset _dataset, int _sampleIndex) {
      PC = 0;
      DP = 0;
      sampleIndex = _sampleIndex;
      dataset = _dataset;
      return EvaluateBakedCode();
    }

    /// <summary>
    /// Evaluates the node at the current code position including all of its sub-trees and
    /// advances <c>PC</c> and <c>DP</c> past everything that belongs to that node.
    /// </summary>
    /// <remarks>
    /// Every sub-tree of a node must always be evaluated, even if its value cannot change the
    /// result, because evaluating it is what moves <c>PC</c> and <c>DP</c> to the next sibling.
    /// </remarks>
    private double EvaluateBakedCode() {
      int arity = codeArr[PC];
      int functionSymbol = codeArr[PC + 1];
      int nLocalVariables = codeArr[PC + 2];
      PC += 3;
      switch(functionSymbol) {
        case VARIABLE: {
            // data layout: variable index, weight, offset relative to the current sample
            int variableIndex = (int)dataArr[DP];
            double weight = dataArr[DP + 1];
            int row = sampleIndex + (int)dataArr[DP + 2];
            DP += 3;
            // rows outside of the dataset have no defined value
            if(row < 0 || row >= dataset.Rows) return double.NaN;
            else return weight * dataset.GetValue(row, variableIndex);
          }
        case CONSTANT: {
            return dataArr[DP++];
          }
        case MULTIPLICATION: {
            double result = EvaluateBakedCode();
            for(int i = 1; i < arity; i++) {
              result *= EvaluateBakedCode();
            }
            return result;
          }
        case ADDITION: {
            double sum = EvaluateBakedCode();
            for(int i = 1; i < arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum;
          }
        case SUBSTRACTION: {
            // a single argument is negated, otherwise all further arguments are subtracted from the first
            if(arity == 1) {
              return -EvaluateBakedCode();
            } else {
              double result = EvaluateBakedCode();
              for(int i = 1; i < arity; i++) {
                result -= EvaluateBakedCode();
              }
              return result;
            }
          }
        case DIVISION: {
            // a single argument is inverted, otherwise the first argument is divided by all further arguments
            double result;
            if(arity == 1) {
              result = 1.0 / EvaluateBakedCode();
            } else {
              result = EvaluateBakedCode();
              for(int i = 1; i < arity; i++) {
                result /= EvaluateBakedCode();
              }
            }
            // infinite results (e.g. division by zero) are mapped to 0.0
            if(double.IsInfinity(result)) return 0.0;
            else return result;
          }
        case AVERAGE: {
            double sum = EvaluateBakedCode();
            for(int i = 1; i < arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum / arity;
          }
        case COSINUS: {
            return Math.Cos(EvaluateBakedCode());
          }
        case SINUS: {
            return Math.Sin(EvaluateBakedCode());
          }
        case EXP: {
            return Math.Exp(EvaluateBakedCode());
          }
        case LOG: {
            return Math.Log(EvaluateBakedCode());
          }
        case POWER: {
            double x = EvaluateBakedCode();
            double p = EvaluateBakedCode();
            return Math.Pow(x, p);
          }
        case SIGNUM: {
            double value = EvaluateBakedCode();
            // Math.Sign throws for NaN, so NaN is passed through explicitly
            if(double.IsNaN(value)) return double.NaN;
            else return Math.Sign(value);
          }
        case SQRT: {
            return Math.Sqrt(EvaluateBakedCode());
          }
        case TANGENS: {
            return Math.Tan(EvaluateBakedCode());
          }
        case AND: {
            double result = 1.0;
            // have to evaluate all sub-trees, skipping would probably not lead to a big gain because
            // we have to iterate over the linear structure anyway
            for(int i = 0; i < arity; i++) {
              double x = Math.Round(EvaluateBakedCode());
              // operands other than 0 and 1 make the result undefined (NaN); NaN survives later multiplications
              if(x == 0 || x == 1.0) result *= x;
              else result = double.NaN;
            }
            return result;
          }
        case EQU: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(x == y) return 1.0;
            else return 0.0;
          }
        case GT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(x > y) return 1.0;
            else return 0.0;
          }
        case IFTE: {
            // both branches are always evaluated to keep PC and DP consistent
            double condition = Math.Round(EvaluateBakedCode());
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            // the second sub-tree is returned for a condition below 0.5, the third sub-tree otherwise
            if(condition < .5) return x;
            else if(condition >= .5) return y;
            else return double.NaN; // only reached when the condition is NaN
          }
        case LT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(x < y) return 1.0;
            else return 0.0;
          }
        case NOT: {
            double result = Math.Round(EvaluateBakedCode());
            if(result == 0.0) return 1.0;
            else if(result == 1.0) return 0.0;
            else return double.NaN;
          }
        case OR: {
            double result = 0.0; // default is false
            for(int i = 0; i < arity; i++) {
              double x = Math.Round(EvaluateBakedCode());
              if(x == 1.0) {
                // a true operand turns false into true; a result that is already true stays true
                // and an undefined (NaN) result stays undefined
                if(result == 0.0) result = 1.0;
              } else if(x != 0.0) {
                // operands other than 0 and 1 (including NaN) make the result undefined
                result = double.NaN;
              }
            }
            return result;
          }
        case XOR: {
            double x = Math.Round(EvaluateBakedCode());
            double y = Math.Round(EvaluateBakedCode());
            if(x == 0.0 && y == 0.0) return 0.0;
            if(x == 1.0 && y == 0.0) return 1.0;
            if(x == 0.0 && y == 1.0) return 1.0;
            if(x == 1.0 && y == 1.0) return 0.0;
            return double.NaN;
          }
        default: {
            // functions without a fixed symbol are applied via their registered IFunction;
            // arguments are the local variables followed by the values of the sub-trees
            IFunction function = symbolTable[functionSymbol];
            double[] args = new double[nLocalVariables + arity];
            for(int i = 0; i < nLocalVariables; i++) {
              args[i] = dataArr[DP++];
            }
            for(int j = 0; j < arity; j++) {
              args[nLocalVariables + j] = EvaluateBakedCode();
            }
            return function.Apply(dataset, sampleIndex, args);
          }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.