Changeset 327 for branches/BottomUpTreeEvaluation
- Timestamp:
- 06/20/08 00:35:02 (17 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/BottomUpTreeEvaluation/BakedTreeEvaluator.cs
r322 r327 31 31 internal static class BakedTreeEvaluator { 32 32 private const int MAX_TREE_SIZE = 4096; 33 33 private const int MAX_TREE_DEPTH = 20; 34 34 private class Instr { 35 35 public double d_arg0; … … 40 40 } 41 41 42 private static Instr[] codeArr;43 private static int PC;42 private static int[] nInstr; 43 private static Instr[,] evaluationTable; 44 44 private static Dataset dataset; 45 45 private static int sampleIndex; … … 47 47 48 48 static BakedTreeEvaluator() { 49 codeArr = new Instr[MAX_TREE_SIZE]; 50 for(int i = 0; i < MAX_TREE_SIZE; i++) { 51 codeArr[i] = new Instr(); 49 evaluationTable = new Instr[MAX_TREE_SIZE, MAX_TREE_DEPTH]; 50 nInstr = new int[MAX_TREE_DEPTH]; 51 for(int j = 0; j < MAX_TREE_DEPTH; j++) { 52 for(int i = 0; i < MAX_TREE_SIZE; i++) { 53 evaluationTable[i, j] = new Instr(); 54 } 52 55 } 53 56 } 54 57 55 58 public static void ResetEvaluator(List<LightWeightFunction> linearRepresentation) { 56 int i = 0; 57 foreach(LightWeightFunction f in linearRepresentation) { 58 TranslateToInstr(f, codeArr[i++]); 59 } 60 } 61 62 private static Instr TranslateToInstr(LightWeightFunction f, Instr instr) { 59 int length; 60 for(int i = 0; i < MAX_TREE_DEPTH; i++) nInstr[i] = 0; 61 TranslateToInstr(0, linearRepresentation, out length); 62 } 63 64 private static int TranslateToInstr(int pos, List<LightWeightFunction> linearRepresentation, out int branchLength) { 65 int height = 0; 66 int length = 1; 67 LightWeightFunction f = linearRepresentation[pos]; 68 for(int i = 0; i < f.arity; i++) { 69 int curBranchLength; 70 int curBranchHeight = TranslateToInstr(pos + length, linearRepresentation, out curBranchLength); 71 if(curBranchHeight > height) height = curBranchHeight; 72 length += curBranchLength; 73 } 74 Instr instr = evaluationTable[nInstr[height], height]; 63 75 instr.arity = f.arity; 64 76 instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType); … … 75 87 } 76 88 } 77 return instr; 78 } 89 nInstr[height]++; 90 branchLength = length; 91 return height; 92 } 93 94 //private static Instr TranslateToInstr(LightWeightFunction f, Instr instr) { 95 // instr.arity = f.arity; 96 // instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType); 97 // switch(instr.symbol) { 98 // case EvaluatorSymbolTable.VARIABLE: { 99 // instr.i_arg0 = (int)f.data[0]; // var 100 // instr.d_arg0 = f.data[1]; // weight 101 // instr.i_arg1 = (int)f.data[2]; // sample-offset 102 // break; 103 // } 104 // case EvaluatorSymbolTable.CONSTANT: { 105 // instr.d_arg0 = f.data[0]; // value 106 // break; 107 // } 108 // } 109 // return instr; 110 //} 79 111 80 112 internal static double Evaluate(Dataset dataset, int sampleIndex) { 81 PC = 0;82 113 BakedTreeEvaluator.sampleIndex = sampleIndex; 83 114 BakedTreeEvaluator.dataset = dataset; 84 return EvaluateBakedCode(); 85 } 86 87 private static double EvaluateBakedCode() { 88 Instr currInstr = codeArr[PC++]; 89 switch(currInstr.symbol) { 90 case EvaluatorSymbolTable.VARIABLE: { 91 int row = sampleIndex + currInstr.i_arg1; 92 if(row < 0 || row >= dataset.Rows) return double.NaN; 93 else return currInstr.d_arg0 * dataset.GetValue(row, currInstr.i_arg0); 115 return EvaluateTable(); 116 } 117 118 private static double EvaluateTable() { 119 int terminalP = 0; 120 for(; terminalP < nInstr[0]; terminalP += 2) { 121 Instr curInstr0 = evaluationTable[terminalP, 0]; 122 Instr curInstr1 = evaluationTable[terminalP + 1, 0]; 123 if(curInstr0.symbol == EvaluatorSymbolTable.VARIABLE) { 124 int row = sampleIndex + curInstr0.i_arg1; 125 if(row < 0 || row >= dataset.Rows) curInstr0.d_arg0 = double.NaN; 126 else curInstr0.d_arg0 = curInstr0.d_arg0 * dataset.GetValue(row, curInstr0.i_arg0); 127 } 128 if(curInstr1.symbol == EvaluatorSymbolTable.VARIABLE) { 129 int row = sampleIndex + curInstr1.i_arg1; 130 if(row < 0 || row >= dataset.Rows) curInstr1.d_arg0 = double.NaN; 131 else curInstr1.d_arg0 = curInstr1.d_arg0 * dataset.GetValue(row, curInstr1.i_arg0); 132 } 133 } 134 135 int curLevel = 1; 136 while(nInstr[curLevel] > 0) { 137 int lastLayerInstrP = 0; 138 for(int curLayerInstrP = 0; curLayerInstrP < nInstr[curLevel]; curLayerInstrP++) { 139 Instr curInstr = evaluationTable[curLayerInstrP, curLevel]; 140 switch(curInstr.symbol) { 141 case EvaluatorSymbolTable.MULTIPLICATION: { 142 curInstr.d_arg0 = evaluationTable[lastLayerInstrP, curLevel - 1].d_arg0; 143 for(int i = 1; i < curInstr.arity; i++) { 144 curInstr.d_arg0 *= evaluationTable[lastLayerInstrP + i, curLevel - 1].d_arg0; 145 } 146 lastLayerInstrP += curInstr.arity; 147 break; 148 } 149 case EvaluatorSymbolTable.ADDITION: { 150 curInstr.d_arg0 = evaluationTable[lastLayerInstrP, curLevel - 1].d_arg0; 151 for(int i = 1; i < curInstr.arity; i++) { 152 curInstr.d_arg0 += evaluationTable[lastLayerInstrP + i, curLevel - 1].d_arg0; 153 } 154 lastLayerInstrP += curInstr.arity; 155 break; 156 } 157 case EvaluatorSymbolTable.SUBTRACTION: { 158 if(curInstr.arity == 1) { 159 curInstr.d_arg0 = -evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0; 160 } else { 161 curInstr.d_arg0 = evaluationTable[lastLayerInstrP, curLevel - 1].d_arg0; 162 for(int i = 1; i < curInstr.arity; i++) { 163 curInstr.d_arg0 -= evaluationTable[lastLayerInstrP + i, curLevel - 1].d_arg0; 164 } 165 lastLayerInstrP += curInstr.arity; 166 } 167 break; 168 } 169 case EvaluatorSymbolTable.DIVISION: { 170 if(curInstr.arity == 1) { 171 curInstr.d_arg0 = 1.0 / evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0; 172 } else { 173 curInstr.d_arg0 = evaluationTable[lastLayerInstrP, curLevel - 1].d_arg0; 174 for(int i = 1; i < curInstr.arity; i++) { 175 curInstr.d_arg0 /= evaluationTable[lastLayerInstrP + i, curLevel - 1].d_arg0; 176 } 177 lastLayerInstrP += curInstr.arity; 178 } 179 if(double.IsInfinity(curInstr.d_arg0)) curInstr.d_arg0 = 0.0; 180 break; 181 } 182 case EvaluatorSymbolTable.AVERAGE: { 183 curInstr.d_arg0 = evaluationTable[lastLayerInstrP, curLevel - 1].d_arg0; 184 for(int i = 1; i < curInstr.arity; i++) { 185 curInstr.d_arg0 += evaluationTable[lastLayerInstrP + i, curLevel - 1].d_arg0; 186 } 187 lastLayerInstrP += curInstr.arity; 188 curInstr.d_arg0 /= curInstr.arity; 189 break; 190 } 191 case EvaluatorSymbolTable.COSINUS: { 192 curInstr.d_arg0 = Math.Cos(evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0); 193 break; 194 } 195 case EvaluatorSymbolTable.SINUS: { 196 curInstr.d_arg0 = Math.Sin(evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0); 197 break; 198 } 199 case EvaluatorSymbolTable.EXP: { 200 curInstr.d_arg0 = Math.Exp(evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0); 201 break; 202 } 203 case EvaluatorSymbolTable.LOG: { 204 curInstr.d_arg0 = Math.Log(evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0); 205 break; 206 } 207 case EvaluatorSymbolTable.POWER: { 208 double x = evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0; 209 double p = evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0; 210 curInstr.d_arg0 = Math.Pow(x, p); 211 break; 212 } 213 case EvaluatorSymbolTable.SIGNUM: { 214 double value = evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0; 215 if(double.IsNaN(value)) curInstr.d_arg0 = double.NaN; 216 else curInstr.d_arg0 = Math.Sign(value); 217 break; 218 } 219 case EvaluatorSymbolTable.SQRT: { 220 curInstr.d_arg0 = Math.Sqrt(evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0); 221 break; 222 } 223 case EvaluatorSymbolTable.TANGENS: { 224 curInstr.d_arg0 = Math.Tan(evaluationTable[lastLayerInstrP++, curLevel - 1].d_arg0); 225 break; 226 } 227 //case EvaluatorSymbolTable.AND: { 228 // double result = 1.0; 229 // // have to evaluate all sub-trees, skipping would probably not lead to a big gain because 230 // // we have to iterate over the linear structure anyway 231 // for(int i = 0; i < currInstr.arity; i++) { 232 // double x = Math.Round(EvaluateBakedCode()); 233 // if(x == 0 || x == 1.0) result *= x; 234 // else result = double.NaN; 235 // } 236 // return result; 237 // } 238 //case EvaluatorSymbolTable.EQU: { 239 // double x = EvaluateBakedCode(); 240 // double y = EvaluateBakedCode(); 241 // if(x == y) return 1.0; else return 0.0; 242 // } 243 //case EvaluatorSymbolTable.GT: { 244 // double x = EvaluateBakedCode(); 245 // double y = EvaluateBakedCode(); 246 // if(x > y) return 1.0; 247 // else return 0.0; 248 // } 249 //case EvaluatorSymbolTable.IFTE: { 250 // double condition = Math.Round(EvaluateBakedCode()); 251 // double x = EvaluateBakedCode(); 252 // double y = EvaluateBakedCode(); 253 // if(condition < .5) return x; 254 // else if(condition >= .5) return y; 255 // else return double.NaN; 256 // } 257 //case EvaluatorSymbolTable.LT: { 258 // double x = EvaluateBakedCode(); 259 // double y = EvaluateBakedCode(); 260 // if(x < y) return 1.0; 261 // else return 0.0; 262 // } 263 //case EvaluatorSymbolTable.NOT: { 264 // double result = Math.Round(EvaluateBakedCode()); 265 // if(result == 0.0) return 1.0; 266 // else if(result == 1.0) return 0.0; 267 // else return double.NaN; 268 // } 269 //case EvaluatorSymbolTable.OR: { 270 // double result = 0.0; // default is false 271 // for(int i = 0; i < currInstr.arity; i++) { 272 // double x = Math.Round(EvaluateBakedCode()); 273 // if(x == 1.0 && result == 0.0) result = 1.0; // found first true (1.0) => set to true 274 // else if(x != 0.0) result = double.NaN; // if it was not true it can only be false (0.0) all other cases are undefined => (NaN) 275 // } 276 // return result; 277 // } 278 //case EvaluatorSymbolTable.XOR: { 279 // double x = Math.Round(EvaluateBakedCode()); 280 // double y = Math.Round(EvaluateBakedCode()); 281 // if(x == 0.0 && y == 0.0) return 0.0; 282 // if(x == 1.0 && y == 0.0) return 1.0; 283 // if(x == 0.0 && y == 1.0) return 1.0; 284 // if(x == 1.0 && y == 1.0) return 0.0; 285 // return double.NaN; 286 // } 287 default: { 288 throw new NotImplementedException(); 289 } 94 290 } 95 case EvaluatorSymbolTable.CONSTANT: { 96 return currInstr.d_arg0; 97 } 98 case EvaluatorSymbolTable.MULTIPLICATION: { 99 double result = EvaluateBakedCode(); 100 for(int i = 1; i < currInstr.arity; i++) { 101 result *= EvaluateBakedCode(); 102 } 103 return result; 104 } 105 case EvaluatorSymbolTable.ADDITION: { 106 double sum = EvaluateBakedCode(); 107 for(int i = 1; i < currInstr.arity; i++) { 108 sum += EvaluateBakedCode(); 109 } 110 return sum; 111 } 112 case EvaluatorSymbolTable.SUBTRACTION: { 113 if(currInstr.arity == 1) { 114 return -EvaluateBakedCode(); 115 } else { 116 double result = EvaluateBakedCode(); 117 for(int i = 1; i < currInstr.arity; i++) { 118 result -= EvaluateBakedCode(); 119 } 120 return result; 121 } 122 } 123 case EvaluatorSymbolTable.DIVISION: { 124 double result; 125 if(currInstr.arity == 1) { 126 result = 1.0 / EvaluateBakedCode(); 127 } else { 128 result = EvaluateBakedCode(); 129 for(int i = 1; i < currInstr.arity; i++) { 130 result /= EvaluateBakedCode(); 131 } 132 } 133 if(double.IsInfinity(result)) return 0.0; 134 else return result; 135 } 136 case EvaluatorSymbolTable.AVERAGE: { 137 double sum = EvaluateBakedCode(); 138 for(int i = 1; i < currInstr.arity; i++) { 139 sum += EvaluateBakedCode(); 140 } 141 return sum / currInstr.arity; 142 } 143 case EvaluatorSymbolTable.COSINUS: { 144 return Math.Cos(EvaluateBakedCode()); 145 } 146 case EvaluatorSymbolTable.SINUS: { 147 return Math.Sin(EvaluateBakedCode()); 148 } 149 case EvaluatorSymbolTable.EXP: { 150 return Math.Exp(EvaluateBakedCode()); 151 } 152 case EvaluatorSymbolTable.LOG: { 153 return Math.Log(EvaluateBakedCode()); 154 } 155 case EvaluatorSymbolTable.POWER: { 156 double x = EvaluateBakedCode(); 157 double p = EvaluateBakedCode(); 158 return Math.Pow(x, p); 159 } 160 case EvaluatorSymbolTable.SIGNUM: { 161 double value = EvaluateBakedCode(); 162 if(double.IsNaN(value)) return double.NaN; 163 else return Math.Sign(value); 164 } 165 case EvaluatorSymbolTable.SQRT: { 166 return Math.Sqrt(EvaluateBakedCode()); 167 } 168 case EvaluatorSymbolTable.TANGENS: { 169 return Math.Tan(EvaluateBakedCode()); 170 } 171 case EvaluatorSymbolTable.AND: { 172 double result = 1.0; 173 // have to evaluate all sub-trees, skipping would probably not lead to a big gain because 174 // we have to iterate over the linear structure anyway 175 for(int i = 0; i < currInstr.arity; i++) { 176 double x = Math.Round(EvaluateBakedCode()); 177 if(x == 0 || x == 1.0) result *= x; 178 else result = double.NaN; 179 } 180 return result; 181 } 182 case EvaluatorSymbolTable.EQU: { 183 double x = EvaluateBakedCode(); 184 double y = EvaluateBakedCode(); 185 if(x == y) return 1.0; else return 0.0; 186 } 187 case EvaluatorSymbolTable.GT: { 188 double x = EvaluateBakedCode(); 189 double y = EvaluateBakedCode(); 190 if(x > y) return 1.0; 191 else return 0.0; 192 } 193 case EvaluatorSymbolTable.IFTE: { 194 double condition = Math.Round(EvaluateBakedCode()); 195 double x = EvaluateBakedCode(); 196 double y = EvaluateBakedCode(); 197 if(condition < .5) return x; 198 else if(condition >= .5) return y; 199 else return double.NaN; 200 } 201 case EvaluatorSymbolTable.LT: { 202 double x = EvaluateBakedCode(); 203 double y = EvaluateBakedCode(); 204 if(x < y) return 1.0; 205 else return 0.0; 206 } 207 case EvaluatorSymbolTable.NOT: { 208 double result = Math.Round(EvaluateBakedCode()); 209 if(result == 0.0) return 1.0; 210 else if(result == 1.0) return 0.0; 211 else return double.NaN; 212 } 213 case EvaluatorSymbolTable.OR: { 214 double result = 0.0; // default is false 215 for(int i = 0; i < currInstr.arity; i++) { 216 double x = Math.Round(EvaluateBakedCode()); 217 if(x == 1.0 && result == 0.0) result = 1.0; // found first true (1.0) => set to true 218 else if(x != 0.0) result = double.NaN; // if it was not true it can only be false (0.0) all other cases are undefined => (NaN) 219 } 220 return result; 221 } 222 case EvaluatorSymbolTable.XOR: { 223 double x = Math.Round(EvaluateBakedCode()); 224 double y = Math.Round(EvaluateBakedCode()); 225 if(x == 0.0 && y == 0.0) return 0.0; 226 if(x == 1.0 && y == 0.0) return 1.0; 227 if(x == 0.0 && y == 1.0) return 1.0; 228 if(x == 1.0 && y == 1.0) return 0.0; 229 return double.NaN; 230 } 231 default: { 232 throw new NotImplementedException(); 233 } 234 } 291 } 292 // copy remaining results from previous layer to current layer (identiy function) 293 int r = 0; 294 for(; lastLayerInstrP < nInstr[curLevel - 1]; lastLayerInstrP++) { 295 evaluationTable[nInstr[curLevel] + r, curLevel].d_arg0 = evaluationTable[lastLayerInstrP, curLevel - 1].d_arg0; 296 r++; 297 } 298 curLevel++; 299 } 300 return evaluationTable[0, curLevel - 1].d_arg0; 235 301 } 236 302 }
Note: See TracChangeset
for help on using the changeset viewer.