Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Functions/BakedTreeEvaluator.cs @ 229

Last change on this file since 229 was 229, checked in by gkronber, 16 years ago

merged changes (r201 r203 r206 r208 r220 r223 r224 r225 r226 r227) from branch ExperimentalFunctionsBaking into the trunk. (ticket #139)

File size: 9.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using HeuristicLab.DataAnalysis;
6
namespace HeuristicLab.Functions {
  /// <summary>
  /// Interpreter for baked (linearized) function trees.
  /// </summary>
  /// <remarks>
  /// A baked tree is passed in as two parallel streams. The code stream holds, for every node in prefix
  /// order, three ints: its arity, its function symbol (see <see cref="MapFunction"/>) and its number of
  /// local variables; the code of the node's sub-trees follows. The data stream holds node-local values,
  /// consumed in the same order: one value for a constant node, three (variable index, weight, sample
  /// offset) for a variable node, and the local variables of any non-built-in function.
  ///
  /// All state (symbol tables, code pointer PC, data pointer DP, cached arrays, current dataset and
  /// sample) lives in static fields that are mutated during mapping and evaluation, so this class must
  /// not be used from several threads at the same time.
  /// </remarks>
  static class BakedTreeEvaluator {
    // Fixed symbols for the built-in function types that EvaluateBakedCode handles directly.
    private const int ADDITION = 10010;
    private const int AND = 10020;
    private const int AVERAGE = 10030;
    private const int CONSTANT = 10040;
    private const int COSINUS = 10050;
    private const int DIVISION = 10060;
    private const int EQU = 10070;
    private const int EXP = 10080;
    private const int GT = 10090;
    private const int IFTE = 10100;
    private const int LT = 10110;
    private const int LOG = 10120;
    private const int MULTIPLICATION = 10130;
    private const int NOT = 10140;
    private const int OR = 10150;
    private const int POWER = 10160;
    private const int SIGNUM = 10170;
    private const int SINUS = 10180;
    private const int SQRT = 10190;
    private const int SUBSTRACTION = 10200;
    private const int TANGENS = 10210;
    private const int VARIABLE = 10220;
    private const int XOR = 10230;

    // Next symbol to hand out to a function whose type is not one of the built-in types above.
    private static int nextFunctionSymbol = 10240;
    // symbol -> function; read by MapSymbol and by the default case of EvaluateBakedCode.
    private static readonly Dictionary<int, IFunction> symbolTable;
    // function -> symbol; memoizes the result of MapFunction.
    private static readonly Dictionary<IFunction, int> reverseSymbolTable;
    // built-in function type -> its fixed symbol.
    private static readonly Dictionary<Type, int> staticTypes;

    static BakedTreeEvaluator() {
      symbolTable = new Dictionary<int, IFunction>();
      reverseSymbolTable = new Dictionary<IFunction, int>();
      staticTypes = new Dictionary<Type, int>();
      staticTypes[typeof(Addition)] = ADDITION;
      staticTypes[typeof(And)] = AND;
      staticTypes[typeof(Average)] = AVERAGE;
      staticTypes[typeof(Constant)] = CONSTANT;
      staticTypes[typeof(Cosinus)] = COSINUS;
      staticTypes[typeof(Division)] = DIVISION;
      staticTypes[typeof(Equal)] = EQU;
      staticTypes[typeof(Exponential)] = EXP;
      staticTypes[typeof(GreaterThan)] = GT;
      staticTypes[typeof(IfThenElse)] = IFTE;
      staticTypes[typeof(LessThan)] = LT;
      staticTypes[typeof(Logarithm)] = LOG;
      staticTypes[typeof(Multiplication)] = MULTIPLICATION;
      staticTypes[typeof(Not)] = NOT;
      staticTypes[typeof(Or)] = OR;
      staticTypes[typeof(Power)] = POWER;
      staticTypes[typeof(Signum)] = SIGNUM;
      staticTypes[typeof(Sinus)] = SINUS;
      staticTypes[typeof(Sqrt)] = SQRT;
      staticTypes[typeof(Substraction)] = SUBSTRACTION;
      staticTypes[typeof(Tangens)] = TANGENS;
      staticTypes[typeof(Variable)] = VARIABLE;
      staticTypes[typeof(Xor)] = XOR;
    }

    /// <summary>
    /// Returns the symbol used in baked code for <paramref name="function"/>, assigning one on first use.
    /// </summary>
    /// <remarks>
    /// A function whose exact type is a built-in type gets that type's fixed symbol; any other function
    /// gets the next free symbol. The mapping is remembered in both directions so that
    /// <see cref="MapSymbol"/> can translate the symbol back. Two functions of the same built-in type
    /// share one symbol, and <see cref="MapSymbol"/> then returns the one mapped most recently.
    /// </remarks>
    /// <param name="function">The function to map.</param>
    /// <returns>The function's symbol.</returns>
    internal static int MapFunction(IFunction function) {
      int symbol;
      if(reverseSymbolTable.TryGetValue(function, out symbol)) return symbol;

      if(!staticTypes.TryGetValue(function.GetType(), out symbol)) {
        symbol = nextFunctionSymbol++;
      }
      reverseSymbolTable[function] = symbol;
      symbolTable[symbol] = function;
      return symbol;
    }

    /// <summary>
    /// Returns the function previously registered for <paramref name="symbol"/> by <see cref="MapFunction"/>.
    /// </summary>
    /// <exception cref="KeyNotFoundException">No function has been mapped to <paramref name="symbol"/>.</exception>
    internal static IFunction MapSymbol(int symbol) {
      return symbolTable[symbol];
    }

    // Evaluation state: initialized by Evaluate, advanced by EvaluateBakedCode.
    private static int PC;            // index of the next int to read from codeArr
    private static int DP;            // index of the next value to read from dataArr
    private static int[] codeArr;     // copy of the code stream being evaluated
    private static double[] dataArr;  // copy of the data stream being evaluated
    private static Dataset dataset;   // dataset read by variable nodes
    private static int sampleIndex;   // row of the sample being evaluated

    /// <summary>
    /// Evaluates a baked tree for a single sample.
    /// </summary>
    /// <param name="_dataset">Dataset from which variable nodes read their values.</param>
    /// <param name="_sampleIndex">Row to evaluate; variable offsets are added to this row.</param>
    /// <param name="code">Baked code stream (see the class remarks for its layout).</param>
    /// <param name="data">Baked data stream (see the class remarks for its layout).</param>
    /// <returns>The value of the tree's root node.</returns>
    internal static double Evaluate(Dataset _dataset, int _sampleIndex, List<int> code, List<double> data) {
      PC = 0;
      DP = 0;
      // The arrays are reused to avoid allocating per call, but each one is resized independently and its
      // contents are refreshed on every call. Otherwise a different tree, or the same tree with changed
      // constants, that happens to have the same code length would be evaluated with stale code/data.
      if(codeArr == null || codeArr.Length != code.Count) codeArr = new int[code.Count];
      if(dataArr == null || dataArr.Length != data.Count) dataArr = new double[data.Count];
      code.CopyTo(codeArr);
      data.CopyTo(dataArr);
      sampleIndex = _sampleIndex;
      dataset = _dataset;
      return EvaluateBakedCode();
    }

    /// <summary>
    /// Evaluates the node at the current code pointer (PC) together with its sub-trees, advancing PC and
    /// the data pointer (DP) past everything it consumes.
    /// </summary>
    /// <remarks>
    /// AND, OR, NOT and XOR round their operands and treat 0.0 as false and 1.0 as true; any other operand
    /// makes the result NaN. Division by zero yields 0 (protected division). Out-of-range variable rows
    /// yield NaN.
    /// </remarks>
    private static double EvaluateBakedCode() {
      int arity = codeArr[PC++];
      int functionSymbol = codeArr[PC++];
      int nLocalVariables = codeArr[PC++];
      switch(functionSymbol) {
        case VARIABLE: {
            int varIndex = (int)dataArr[DP++];
            double weight = dataArr[DP++];
            int offset = (int)dataArr[DP++];
            int row = sampleIndex + offset;
            if(row < 0 || row >= dataset.Rows) return double.NaN;
            else return weight * dataset.GetValue(row, varIndex);
          }
        case CONSTANT: {
            return dataArr[DP++];
          }
        case MULTIPLICATION: {
            double result = 1.0;
            for(int i = 0; i < arity; i++) {
              result *= EvaluateBakedCode();
            }
            return result;
          }
        case ADDITION: {
            double sum = 0.0;
            for(int i = 0; i < arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum;
          }
        case SUBSTRACTION: {
            // unary minus for a single operand, otherwise first - second - third ...
            if(arity == 1) {
              return -EvaluateBakedCode();
            } else {
              double result = EvaluateBakedCode();
              for(int i = 1; i < arity; i++) {
                result -= EvaluateBakedCode();
              }
              return result;
            }
          }
        case DIVISION: {
            // protected division: reciprocal for a single operand, otherwise first / second / ...;
            // a zero divisor sets the result to 0
            if(arity == 1) {
              double divisor = EvaluateBakedCode();
              if(divisor == 0) return 0;
              else return 1.0 / divisor;
            } else {
              double result = EvaluateBakedCode();
              for(int i = 1; i < arity; i++) {
                double divisor = EvaluateBakedCode();
                if(divisor == 0) result = 0;
                else result /= divisor;
              }
              return result;
            }
          }
        case AVERAGE: {
            double sum = 0.0;
            for(int i = 0; i < arity; i++) {
              sum += EvaluateBakedCode();
            }
            return sum / arity;
          }
        case COSINUS: {
            return Math.Cos(EvaluateBakedCode());
          }
        case SINUS: {
            return Math.Sin(EvaluateBakedCode());
          }
        case EXP: {
            return Math.Exp(EvaluateBakedCode());
          }
        case LOG: {
            return Math.Log(EvaluateBakedCode());
          }
        case POWER: {
            double x = EvaluateBakedCode();
            double p = EvaluateBakedCode();
            return Math.Pow(x, p);
          }
        case SIGNUM: {
            // protected signum: NaN falls through both comparisons and yields 0
            double value = EvaluateBakedCode();
            if(value < 0) return -1;
            if(value > 0) return 1;
            return 0;
          }
        case SQRT: {
            return Math.Sqrt(EvaluateBakedCode());
          }
        case TANGENS: {
            return Math.Tan(EvaluateBakedCode());
          }
        case AND: {
            double result = 1.0;
            // have to evaluate all sub-trees, skipping would probably not lead to a big gain because
            // we have to iterate over the linear structure anyway
            for(int i = 0; i < arity; i++) {
              double x = Math.Round(EvaluateBakedCode());
              if(x == 0) result *= 0;
              else if(x == 1.0) result *= 1.0;
              else result *= double.NaN;
            }
            return result;
          }
        case EQU: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(x == y) return 1.0; else return 0.0;
          }
        case GT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(x > y) return 1.0;
            else return 0.0;
          }
        case IFTE: {
            // Both branches are evaluated: evaluation is the only way to move PC/DP past a sub-tree.
            // NOTE(review): a condition that rounds to 0 (false for AND/OR/NOT/XOR) selects x, the first
            // branch; confirm this polarity matches the IfThenElse function.
            double condition = Math.Round(EvaluateBakedCode());
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(condition < .5) return x;
            else if(condition >= .5) return y;
            else return double.NaN; // condition is NaN
          }
        case LT: {
            double x = EvaluateBakedCode();
            double y = EvaluateBakedCode();
            if(x < y) return 1.0;
            else return 0.0;
          }
        case NOT: {
            double result = Math.Round(EvaluateBakedCode());
            if(result == 0.0) return 1.0;
            else if(result == 1.0) return 0.0;
            else return double.NaN;
          }
        case OR: {
            double result = 0.0; // default is false
            for(int i = 0; i < arity; i++) {
              double x = Math.Round(EvaluateBakedCode());
              if(x == 1.0) {
                // true operand: result becomes true unless an undefined operand already made it NaN
                // (further true operands must not turn a true result into NaN)
                if(result == 0.0) result = 1.0;
              } else if(x != 0.0) {
                // neither true (1.0) nor false (0.0) => undefined (NaN); NaN stays sticky
                result = double.NaN;
              }
            }
            return result;
          }
        case XOR: {
            double x = Math.Round(EvaluateBakedCode());
            double y = Math.Round(EvaluateBakedCode());
            if(x == 0.0 && y == 0.0) return 0.0;
            if(x == 1.0 && y == 0.0) return 1.0;
            if(x == 0.0 && y == 1.0) return 1.0;
            if(x == 1.0 && y == 1.0) return 0.0;
            return double.NaN;
          }
        default: {
            // Non-built-in function: its local variables are read from the data stream first, followed by
            // the values of its sub-trees, and both are passed to the function's own Apply.
            IFunction function = symbolTable[functionSymbol];
            double[] args = new double[nLocalVariables + arity];
            for(int i = 0; i < nLocalVariables; i++) {
              args[i] = dataArr[DP++];
            }
            for(int j = 0; j < arity; j++) {
              args[nLocalVariables + j] = EvaluateBakedCode();
            }
            return function.Apply(dataset, sampleIndex, args);
          }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.