Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Functions/BakedTreeEvaluator.cs @ 236

Last change on this file since 236 was 236, checked in by gkronber, 16 years ago

minor changes

File size: 9.5 KB
Rev Line
[223]1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using HeuristicLab.DataAnalysis;
6
7namespace HeuristicLab.Functions {
/// <summary>
/// Evaluates function trees that have been "baked" into a flat code stream (ints) and a
/// parallel data stream (doubles). Built-in function types are evaluated inline by a switch;
/// every other function is dispatched to its <c>IFunction.Apply</c> method.
/// </summary>
/// <remarks>
/// Each node in the code stream is a three-int header (arity, function symbol, number of local
/// variables) immediately followed by the nodes of its sub-trees, in order. Node operands that
/// are not sub-trees (constant values, variable parameters, local variables of non-built-in
/// functions) are consumed sequentially from the data stream.
/// All evaluation state (code/data buffers, PC, DP, dataset, sample index) is kept in static
/// fields, so this class is neither thread-safe nor reentrant.
/// </remarks>
static class BakedTreeEvaluator {
  // Fixed symbols of the built-in function types (evaluated inline by EvaluateBakedCode).
  // Symbols for all other functions are allocated from nextFunctionSymbol upwards.
  private const int ADDITION = 10010;
  private const int AND = 10020;
  private const int AVERAGE = 10030;
  private const int CONSTANT = 10040;
  private const int COSINUS = 10050;
  private const int DIVISION = 10060;
  private const int EQU = 10070;
  private const int EXP = 10080;
  private const int GT = 10090;
  private const int IFTE = 10100;
  private const int LT = 10110;
  private const int LOG = 10120;
  private const int MULTIPLICATION = 10130;
  private const int NOT = 10140;
  private const int OR = 10150;
  private const int POWER = 10160;
  private const int SIGNUM = 10170;
  private const int SINUS = 10180;
  private const int SQRT = 10190;
  private const int SUBSTRACTION = 10200;
  private const int TANGENS = 10210;
  private const int VARIABLE = 10220;
  private const int XOR = 10230;

  // next symbol handed out to a function whose type is not registered in staticTypes
  private static int nextFunctionSymbol = 10240;
  // symbol -> function instance; used to dispatch non-built-in functions
  private static readonly Dictionary<int, IFunction> symbolTable;
  // function instance -> symbol
  private static readonly Dictionary<IFunction, int> reverseSymbolTable;
  // built-in function type -> its fixed symbol
  private static readonly Dictionary<Type, int> staticTypes;

  static BakedTreeEvaluator() {
    symbolTable = new Dictionary<int, IFunction>();
    reverseSymbolTable = new Dictionary<IFunction, int>();
    staticTypes = new Dictionary<Type, int>();
    staticTypes[typeof(Addition)] = ADDITION;
    staticTypes[typeof(And)] = AND;
    staticTypes[typeof(Average)] = AVERAGE;
    staticTypes[typeof(Constant)] = CONSTANT;
    staticTypes[typeof(Cosinus)] = COSINUS;
    staticTypes[typeof(Division)] = DIVISION;
    staticTypes[typeof(Equal)] = EQU;
    staticTypes[typeof(Exponential)] = EXP;
    staticTypes[typeof(GreaterThan)] = GT;
    staticTypes[typeof(IfThenElse)] = IFTE;
    staticTypes[typeof(LessThan)] = LT;
    staticTypes[typeof(Logarithm)] = LOG;
    staticTypes[typeof(Multiplication)] = MULTIPLICATION;
    staticTypes[typeof(Not)] = NOT;
    staticTypes[typeof(Or)] = OR;
    staticTypes[typeof(Power)] = POWER;
    staticTypes[typeof(Signum)] = SIGNUM;
    staticTypes[typeof(Sinus)] = SINUS;
    staticTypes[typeof(Sqrt)] = SQRT;
    staticTypes[typeof(Substraction)] = SUBSTRACTION;
    staticTypes[typeof(Tangens)] = TANGENS;
    staticTypes[typeof(Variable)] = VARIABLE;
    staticTypes[typeof(Xor)] = XOR;
  }

  /// <summary>
  /// Returns the symbol for <paramref name="function"/>, registering it on first use.
  /// Instances of a built-in type get that type's fixed symbol; any other instance gets a
  /// freshly allocated symbol.
  /// </summary>
  /// <param name="function">The function instance to map.</param>
  /// <returns>The symbol that identifies <paramref name="function"/> in baked code.</returns>
  internal static int MapFunction(IFunction function) {
    int symbol;
    if(reverseSymbolTable.TryGetValue(function, out symbol)) return symbol;
    if(!staticTypes.TryGetValue(function.GetType(), out symbol)) {
      symbol = nextFunctionSymbol++;
    }
    reverseSymbolTable[function] = symbol;
    // for built-in types this overwrites the entry with the most recently mapped instance
    symbolTable[symbol] = function;
    return symbol;
  }

  /// <summary>
  /// Returns the function instance registered for <paramref name="symbol"/> by
  /// <see cref="MapFunction"/>.
  /// </summary>
  /// <exception cref="KeyNotFoundException">No function was registered for the symbol.</exception>
  internal static IFunction MapSymbol(int symbol) {
    return symbolTable[symbol];
  }

  // program counter: index of the next node header in codeArr
  private static int PC;
  // data pointer: index of the next operand in dataArr
  private static int DP;
  private const int INITIAL_CODE_LENGTH = 4096;
  private const int INITIAL_DATA_LENGTH = 4096;
  // shared evaluation buffers, grown on demand by Evaluate
  private static int[] codeArr = new int[INITIAL_CODE_LENGTH];
  private static double[] dataArr = new double[INITIAL_DATA_LENGTH];
  private static Dataset dataset;
  private static int sampleIndex;

  /// <summary>
  /// Evaluates a baked tree for one sample of a dataset.
  /// </summary>
  /// <param name="_dataset">Dataset that variable nodes read from.</param>
  /// <param name="_sampleIndex">Row index of the sample to evaluate.</param>
  /// <param name="code">Baked code stream (three-int node headers).</param>
  /// <param name="data">Baked data stream (node operands).</param>
  /// <returns>The value of the tree's root node.</returns>
  internal static double Evaluate(Dataset _dataset, int _sampleIndex, List<int> code, List<double> data) {
    PC = 0;
    DP = 0;
    // List<T>.CopyTo throws if the destination is too small, so grow the shared buffers first
    if(code.Count > codeArr.Length) codeArr = new int[Math.Max(code.Count, 2 * codeArr.Length)];
    if(data.Count > dataArr.Length) dataArr = new double[Math.Max(data.Count, 2 * dataArr.Length)];
    code.CopyTo(codeArr);
    data.CopyTo(dataArr);
    sampleIndex = _sampleIndex;
    dataset = _dataset;
    return EvaluateBakedCode();
  }

  /// <summary>
  /// Evaluates the node at PC and, recursively, all of its sub-trees. Advances PC past the
  /// node and its sub-trees, and DP past every operand they consume.
  /// </summary>
  /// <returns>The value of the sub-tree rooted at the current node.</returns>
  private static double EvaluateBakedCode() {
    int arity = codeArr[PC];
    int functionSymbol = codeArr[PC + 1];
    int nLocalVariables = codeArr[PC + 2];
    PC += 3;
    // Every sub-tree must be evaluated even when the result is already determined,
    // because evaluation is what advances PC/DP past it in the linear code stream.
    switch(functionSymbol) {
      case VARIABLE: {
          // data operands: variable index, weight, sample offset relative to sampleIndex
          int variableIndex = (int)dataArr[DP];
          double weight = dataArr[DP + 1];
          int row = sampleIndex + (int)dataArr[DP + 2];
          DP += 3;
          if(row < 0 || row >= dataset.Rows) return double.NaN;
          else return weight * dataset.GetValue(row, variableIndex);
        }
      case CONSTANT: {
          return dataArr[DP++];
        }
      case MULTIPLICATION: {
          double result = EvaluateBakedCode();
          for(int i = 1; i < arity; i++) {
            result *= EvaluateBakedCode();
          }
          return result;
        }
      case ADDITION: {
          double sum = EvaluateBakedCode();
          for(int i = 1; i < arity; i++) {
            sum += EvaluateBakedCode();
          }
          return sum;
        }
      case SUBSTRACTION: {
          // a single operand is negated
          if(arity == 1) {
            return -EvaluateBakedCode();
          } else {
            double result = EvaluateBakedCode();
            for(int i = 1; i < arity; i++) {
              result -= EvaluateBakedCode();
            }
            return result;
          }
        }
      case DIVISION: {
          // a single operand is inverted
          double result;
          if(arity == 1) {
            result = 1.0 / EvaluateBakedCode();
          } else {
            result = EvaluateBakedCode();
            for(int i = 1; i < arity; i++) {
              result /= EvaluateBakedCode();
            }
          }
          // +/-infinity (e.g. division by zero) is mapped to 0.0; NaN (e.g. 0/0) is passed through
          if(double.IsInfinity(result)) return 0.0;
          else return result;
        }
      case AVERAGE: {
          double sum = EvaluateBakedCode();
          for(int i = 1; i < arity; i++) {
            sum += EvaluateBakedCode();
          }
          return sum / arity;
        }
      case COSINUS: {
          return Math.Cos(EvaluateBakedCode());
        }
      case SINUS: {
          return Math.Sin(EvaluateBakedCode());
        }
      case EXP: {
          return Math.Exp(EvaluateBakedCode());
        }
      case LOG: {
          return Math.Log(EvaluateBakedCode());
        }
      case POWER: {
          // binary: assumes arity == 2
          double x = EvaluateBakedCode();
          double p = EvaluateBakedCode();
          return Math.Pow(x, p);
        }
      case SIGNUM: {
          double value = EvaluateBakedCode();
          // Math.Sign throws ArithmeticException for NaN
          if(double.IsNaN(value)) return double.NaN;
          else return Math.Sign(value);
        }
      case SQRT: {
          return Math.Sqrt(EvaluateBakedCode());
        }
      case TANGENS: {
          return Math.Tan(EvaluateBakedCode());
        }
      case AND: {
          // operands are rounded; 0.0 = false, 1.0 = true, anything else makes the result NaN
          double result = 1.0;
          for(int i = 0; i < arity; i++) {
            double x = Math.Round(EvaluateBakedCode());
            if(x == 0.0 || x == 1.0) result *= x;
            else result = double.NaN;
          }
          return result;
        }
      case EQU: {
          // binary: assumes arity == 2; exact floating-point comparison
          double x = EvaluateBakedCode();
          double y = EvaluateBakedCode();
          if(x == y) return 1.0;
          else return 0.0;
        }
      case GT: {
          // binary: assumes arity == 2
          double x = EvaluateBakedCode();
          double y = EvaluateBakedCode();
          if(x > y) return 1.0;
          else return 0.0;
        }
      case IFTE: {
          // both branches are evaluated so PC/DP stay in sync
          // NOTE(review): the first branch (x) is returned when the rounded condition is < .5
          // (false); confirm this matches IfThenElse.Apply
          double condition = Math.Round(EvaluateBakedCode());
          double x = EvaluateBakedCode();
          double y = EvaluateBakedCode();
          if(condition < .5) return x;
          else if(condition >= .5) return y;
          else return double.NaN; // condition is NaN
        }
      case LT: {
          // binary: assumes arity == 2
          double x = EvaluateBakedCode();
          double y = EvaluateBakedCode();
          if(x < y) return 1.0;
          else return 0.0;
        }
      case NOT: {
          double result = Math.Round(EvaluateBakedCode());
          if(result == 0.0) return 1.0;
          else if(result == 1.0) return 0.0;
          else return double.NaN;
        }
      case OR: {
          double result = 0.0; // default is false
          for(int i = 0; i < arity; i++) {
            double x = Math.Round(EvaluateBakedCode());
            if(x == 1.0) {
              // a true operand turns false into true; an already true (1.0) or undefined (NaN)
              // result is kept (previously a second true operand wrongly produced NaN)
              if(result == 0.0) result = 1.0;
            } else if(x != 0.0) {
              // operands other than false (0.0) or true (1.0) are undefined => NaN
              result = double.NaN;
            }
          }
          return result;
        }
      case XOR: {
          // binary: assumes arity == 2
          double x = Math.Round(EvaluateBakedCode());
          double y = Math.Round(EvaluateBakedCode());
          if(x == 0.0 && y == 0.0) return 0.0;
          if(x == 1.0 && y == 0.0) return 1.0;
          if(x == 0.0 && y == 1.0) return 1.0;
          if(x == 1.0 && y == 1.0) return 0.0;
          return double.NaN;
        }
      default: {
          // non-built-in function: its local variables come from the data stream,
          // followed by the values of its sub-trees
          IFunction function = symbolTable[functionSymbol];
          double[] args = new double[nLocalVariables + arity];
          for(int i = 0; i < nLocalVariables; i++) {
            args[i] = dataArr[DP++];
          }
          for(int j = 0; j < arity; j++) {
            args[nLocalVariables + j] = EvaluateBakedCode();
          }
          return function.Apply(dataset, sampleIndex, args);
        }
    }
  }
}
269}
Note: See TracBrowser for help on using the repository browser.