Changeset 220 for branches/ExperimentalFunctionsBaking/BakedFunctionTree.cs
- Timestamp:
- 05/07/08 00:02:43 (16 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/ExperimentalFunctionsBaking/BakedFunctionTree.cs
r208 r220 30 30 namespace HeuristicLab.Functions { 31 31 class BakedFunctionTree : ItemBase, IFunctionTree { 32 private List<double> code; 33 private static double nextFunctionSymbol = 10000; 34 private static Dictionary<double, IFunction> symbolTable = new Dictionary<double, IFunction>(); 35 private static Dictionary<IFunction, double> reverseSymbolTable = new Dictionary<IFunction, double>(); 36 private static double additionSymbol = -1; 37 private static double substractionSymbol = -1; 38 private static double multiplicationSymbol = -1; 39 private static double divisionSymbol = -1; 40 private static double variableSymbol = -1; 41 private static double constantSymbol = -1; 32 private List<int> code; 33 private List<double> data; 34 private const int ADDITION = 10010; 35 private const int AND = 10020; 36 private const int AVERAGE = 10030; 37 private const int CONSTANT = 10040; 38 private const int COSINUS = 10050; 39 private const int DIVISION = 10060; 40 private const int EQU = 10070; 41 private const int EXP = 10080; 42 private const int GT = 10090; 43 private const int IFTE = 10100; 44 private const int LT = 10110; 45 private const int LOG = 10120; 46 private const int MULTIPLICATION = 10130; 47 private const int NOT = 10140; 48 private const int OR = 10150; 49 private const int POWER = 10160; 50 private const int SIGNUM = 10170; 51 private const int SINUS = 10180; 52 private const int SQRT = 10190; 53 private const int SUBSTRACTION = 10200; 54 private const int TANGENS = 10210; 55 private const int VARIABLE = 10220; 56 private const int XOR = 10230; 57 58 private static int nextFunctionSymbol = 10240; 59 private static Dictionary<int, IFunction> symbolTable; 60 private static Dictionary<IFunction, int> reverseSymbolTable; 61 private static Dictionary<Type, int> staticTypes; 62 63 static BakedFunctionTree() { 64 symbolTable = new Dictionary<int, IFunction>(); 65 reverseSymbolTable = new Dictionary<IFunction, int>(); 66 staticTypes = new Dictionary<Type, int>(); 67 staticTypes[typeof(Addition)] = ADDITION; 68 staticTypes[typeof(And)] = AND; 69 staticTypes[typeof(Average)] = AVERAGE; 70 staticTypes[typeof(Constant)] = CONSTANT; 71 staticTypes[typeof(Cosinus)] = COSINUS; 72 staticTypes[typeof(Division)] = DIVISION; 73 staticTypes[typeof(Equal)] = EQU; 74 staticTypes[typeof(Exponential)] = EXP; 75 staticTypes[typeof(GreaterThan)] = GT; 76 staticTypes[typeof(IfThenElse)] = IFTE; 77 staticTypes[typeof(LessThan)] = LT; 78 staticTypes[typeof(Logarithm)] = LOG; 79 staticTypes[typeof(Multiplication)] = MULTIPLICATION; 80 staticTypes[typeof(Not)] = NOT; 81 staticTypes[typeof(Or)] = OR; 82 staticTypes[typeof(Power)] = POWER; 83 staticTypes[typeof(Signum)] = SIGNUM; 84 staticTypes[typeof(Sinus)] = SINUS; 85 staticTypes[typeof(Sqrt)] = SQRT; 86 staticTypes[typeof(Substraction)] = SUBSTRACTION; 87 staticTypes[typeof(Tangens)] = TANGENS; 88 staticTypes[typeof(Variable)] = VARIABLE; 89 staticTypes[typeof(Xor)] = XOR; 90 } 42 91 43 92 internal BakedFunctionTree() { 44 code = new List<double>(); 93 code = new List<int>(); 94 data = new List<double>(); 45 95 } 46 96 … … 65 115 code.Add(0); 66 116 code.Add(MapFunction(tree.Function)); 67 code.Add( (byte)tree.LocalVariables.Count);117 code.Add(tree.LocalVariables.Count); 68 118 foreach(IVariable variable in tree.LocalVariables) { 69 119 IItem value = variable.Value; 70 code.Add(GetDoubleValue(value));120 data.Add(GetDoubleValue(value)); 71 121 } 72 122 foreach(IFunctionTree subTree in tree.SubTrees) { … … 87 137 } 88 138 89 private doubleMapFunction(IFunction function) {139 private int MapFunction(IFunction function) { 90 140 if(!reverseSymbolTable.ContainsKey(function)) { 91 reverseSymbolTable[function] = nextFunctionSymbol; 92 symbolTable[nextFunctionSymbol] = function; 93 if(function is Variable) { 94 variableSymbol = nextFunctionSymbol; 95 } else if(function is Constant) { 96 constantSymbol = nextFunctionSymbol; 97 } else if(function is Addition) { 98 additionSymbol = nextFunctionSymbol; 99 } else if(function is Substraction) { 100 substractionSymbol = nextFunctionSymbol; 101 } else if(function is Multiplication) { 102 multiplicationSymbol = nextFunctionSymbol; 103 } else if(function is Division) { 104 divisionSymbol = nextFunctionSymbol; 105 } else throw new NotSupportedException("Unsupported function " + function); 106 107 nextFunctionSymbol++; 141 int curFunctionSymbol; 142 if(staticTypes.ContainsKey(function.GetType())) curFunctionSymbol = staticTypes[function.GetType()]; 143 else { 144 curFunctionSymbol = nextFunctionSymbol; 145 nextFunctionSymbol++; 146 } 147 reverseSymbolTable[function] = curFunctionSymbol; 148 symbolTable[curFunctionSymbol] = function; 108 149 } 109 150 return reverseSymbolTable[function]; 110 151 } 111 152 112 private int BranchLength(int branchRoot) { 113 double arity = code[branchRoot]; 114 int nLocalVariables = (int)code[branchRoot + 2]; 115 int len = 3 + nLocalVariables; 116 int subBranchStart = branchRoot + len; 153 private void BranchLength(int branchRoot, out int codeLength, out int dataLength) { 154 int arity = code[branchRoot]; 155 int nLocalVariables = code[branchRoot + 2]; 156 codeLength = 3; 157 dataLength = nLocalVariables; 158 int subBranchStart = branchRoot + codeLength; 117 159 for(int i = 0; i < arity; i++) { 118 int branchLen = BranchLength(subBranchStart); 119 len += branchLen; 120 subBranchStart += branchLen; 121 } 122 return len; 160 int branchCodeLength; 161 int branchDataLength; 162 BranchLength(subBranchStart, out branchCodeLength, out branchDataLength); 163 subBranchStart += branchCodeLength; 164 codeLength += branchCodeLength; 165 dataLength += branchDataLength; 166 } 123 167 } 124 168 … … 130 174 subTree.FlattenTrees(); 131 175 code.AddRange(subTree.code); 176 data.AddRange(subTree.data); 132 177 } 133 178 treesExpanded = false; … … 139 184 if(variablesExpanded) { 140 185 code[2] = variables.Count; 141 int localVariableIndex = 3;142 186 foreach(IVariable variable in variables) { 143 code.Insert(localVariableIndex, GetDoubleValue(variable.Value)); 144 localVariableIndex++; 187 data.Add(GetDoubleValue(variable.Value)); 145 188 } 146 189 variablesExpanded = false; … … 155 198 if(!treesExpanded) { 156 199 subTrees = new List<IFunctionTree>(); 157 double arity = code[0]; 158 int nLocalVariables = (int)code[2]; 159 int branchIndex = 3 + nLocalVariables; 200 int arity = code[0]; 201 int nLocalVariables = code[2]; 202 int branchIndex = 3; 203 int dataIndex = nLocalVariables; // skip my local variables to reach the local variables of the first branch 160 204 for(int i = 0; i < arity; i++) { 161 205 BakedFunctionTree subTree = new BakedFunctionTree(); 162 int branchLen = BranchLength(branchIndex); 163 subTree.code = code.GetRange(branchIndex, branchLen); 164 branchIndex += branchLen; 206 int codeLength; 207 int dataLength; 208 BranchLength(branchIndex, out codeLength, out dataLength); 209 subTree.code = code.GetRange(branchIndex, codeLength); 210 subTree.data = data.GetRange(dataIndex, dataLength); 211 branchIndex += codeLength; 212 dataIndex += dataLength; 165 213 subTrees.Add(subTree); 166 214 } 167 215 treesExpanded = true; 168 code.RemoveRange(3 + nLocalVariables, code.Count - (3 + nLocalVariables));216 code.RemoveRange(3, code.Count - 3); 169 217 code[0] = 0; 218 data.RemoveRange(nLocalVariables, data.Count - nLocalVariables); 170 219 } 171 220 return subTrees; … … 180 229 variables = new List<IVariable>(); 181 230 IFunction function = symbolTable[code[1]]; 182 int localVariableIndex = 3;231 int localVariableIndex = 0; 183 232 foreach(IVariableInfo variableInfo in function.VariableInfos) { 184 233 if(variableInfo.Local) { … … 186 235 IItem value = clone.Value; 187 236 if(value is ConstrainedDoubleData) { 188 ((ConstrainedDoubleData)value).Data = code[localVariableIndex];237 ((ConstrainedDoubleData)value).Data = data[localVariableIndex]; 189 238 } else if(value is ConstrainedIntData) { 190 ((ConstrainedIntData)value).Data = (int) code[localVariableIndex];239 ((ConstrainedIntData)value).Data = (int)data[localVariableIndex]; 191 240 } else if(value is DoubleData) { 192 ((DoubleData)value).Data = code[localVariableIndex];241 ((DoubleData)value).Data = data[localVariableIndex]; 193 242 } else if(value is IntData) { 194 ((IntData)value).Data = (int) code[localVariableIndex];243 ((IntData)value).Data = (int)data[localVariableIndex]; 195 244 } else throw new NotSupportedException("Invalid local variable type for GP."); 196 245 variables.Add(clone); … … 200 249 variablesExpanded = true; 201 250 code[2] = 0; 202 code.RemoveRange(3, variables.Count);251 data.RemoveRange(0, variables.Count); 203 252 } 204 253 return variables; … … 211 260 212 261 public IVariable GetLocalVariable(string name) { 213 throw new NotImplementedException(); 262 foreach(IVariable var in LocalVariables) { 263 if(var.Name == name) return var; 264 } 265 return null; 214 266 } 215 267 216 268 public void AddVariable(IVariable variable) { 217 throw new Not ImplementedException();269 throw new NotSupportedException(); 218 270 } 219 271 220 272 public void RemoveVariable(string name) { 221 throw new Not ImplementedException();273 throw new NotSupportedException(); 222 274 } 223 275 224 276 public void AddSubTree(IFunctionTree tree) { 225 if(treesExpanded) { 226 subTrees.Add(tree); 227 } else { 228 //code.AddRange(((BakedFunctionTree)tree).code); 229 //code[0] = code[0] + 1; 230 throw new NotImplementedException(); 231 } 277 if(!treesExpanded) throw new InvalidOperationException(); 278 subTrees.Add(tree); 232 279 } 233 280 234 281 public void InsertSubTree(int index, IFunctionTree tree) { 235 if(treesExpanded) { 236 subTrees.Insert(index, tree); 237 } else { 238 //byte nLocalVariables = code[2]; 239 //// skip branches 240 //int currentBranchIndex = 3 + nLocalVariables; 241 //for(int i = 0; i < index; i++) { 242 // int branchLength = BranchLength(currentBranchIndex); 243 // currentBranchIndex += branchLength; 244 //} 245 246 //code.InsertRange(currentBranchIndex, ((BakedFunctionTree)tree).code); 247 //code[0] = code[0] + 1; 248 249 throw new NotImplementedException(); 250 } 282 if(!treesExpanded) throw new InvalidOperationException(); 283 subTrees.Insert(index, tree); 251 284 } 252 285 253 286 public void RemoveSubTree(int index) { 254 if(treesExpanded) { 255 subTrees.RemoveAt(index); 256 } else { 257 //int nLocalVariables = (int)code[2]; 258 //// skip branches 259 //int currentBranchIndex = 3 + nLocalVariables; 260 //for(int i = 0; i < index; i++) { 261 // int branchLength = BranchLength(currentBranchIndex); 262 // currentBranchIndex += branchLength; 263 //} 264 //int deletedBranchLength = BranchLength(currentBranchIndex); 265 //code.RemoveRange(currentBranchIndex, deletedBranchLength); 266 //code[0] = code[0] - 1; 267 throw new NotImplementedException(); 268 } 287 // sanity check 288 if(!treesExpanded) throw new InvalidOperationException(); 289 subTrees.RemoveAt(index); 269 290 } 270 291 271 292 private int PC; 272 private double[] codeArr; 293 private int DP; 294 private int[] codeArr; 295 private double[] dataArr; 296 private Dataset dataset; 297 private int sampleIndex; 273 298 public double Evaluate(Dataset dataset, int sampleIndex) { 274 299 PC = 0; 300 DP = 0; 275 301 FlattenVariables(); 276 302 FlattenTrees(); 277 303 if(codeArr == null) { 278 codeArr = new double[code.Count]; 304 codeArr = new int[code.Count]; 305 dataArr = new double[data.Count]; 279 306 code.CopyTo(codeArr); 280 } 281 return EvaluateBakedCode(dataset, sampleIndex); 282 } 283 284 private double EvaluateBakedCode(Dataset dataset, int sampleIndex) { 285 double arity = codeArr[PC++]; 286 double functionSymbol = codeArr[PC++]; 287 double nLocalVariables = codeArr[PC++]; 288 if(functionSymbol == variableSymbol) { 289 int var = (int)codeArr[PC++]; 290 double weight = codeArr[PC++]; 291 int offset = (int)codeArr[PC++]; 292 return weight * dataset.GetValue(sampleIndex+offset, var); 293 } else if(functionSymbol == constantSymbol) { 294 double value = codeArr[PC++]; 295 return value; 296 } else if(functionSymbol == additionSymbol) { 297 double sum = 0.0; 298 for(int i = 0; i < arity; i++) { 299 sum += EvaluateBakedCode(dataset, sampleIndex); 300 } 301 return sum; 302 } else if(functionSymbol == substractionSymbol) { 303 if(arity == 1) { 304 return -EvaluateBakedCode(dataset, sampleIndex); 305 } else { 306 double result = EvaluateBakedCode(dataset, sampleIndex); 307 for(int i = 1; i < arity; i++) { 308 result -= EvaluateBakedCode(dataset, sampleIndex); 309 } 310 return result; 311 } 312 } else if(functionSymbol == multiplicationSymbol) { 313 double result = 1.0; 314 for(int i = 0; i < arity; i++) { 315 result *= EvaluateBakedCode(dataset, sampleIndex); 316 } 317 return result; 318 } else if(functionSymbol == divisionSymbol) { 319 if(arity == 1) { 320 double divisor = EvaluateBakedCode(dataset, sampleIndex); 321 if(divisor == 0) return 0; 322 else return 1.0 / divisor; 323 } else { 324 double result = EvaluateBakedCode(dataset, sampleIndex); 325 for(int i = 1; i < arity; i++) { 326 double divisor = EvaluateBakedCode(dataset, sampleIndex); 327 if(divisor == 0) result = 0; 328 else result /= divisor; 329 } 330 return result; 331 } 332 } else { throw new NotSupportedException(); } 307 data.CopyTo(dataArr); 308 } 309 this.sampleIndex = sampleIndex; 310 this.dataset = dataset; 311 return EvaluateBakedCode(); 312 } 313 314 private double EvaluateBakedCode() { 315 int arity = codeArr[PC++]; 316 int functionSymbol = codeArr[PC++]; 317 int nLocalVariables = codeArr[PC++]; 318 switch(functionSymbol) { 319 case VARIABLE: { 320 int var = (int)dataArr[DP++]; 321 double weight = dataArr[DP++]; 322 int offset = (int)dataArr[DP++]; 323 return weight * dataset.GetValue(sampleIndex + offset, var); 324 } 325 case CONSTANT: { 326 double value = dataArr[DP++]; 327 return value; 328 } 329 case MULTIPLICATION: { 330 double result = 1.0; 331 for(int i = 0; i < arity; i++) { 332 result *= EvaluateBakedCode(); 333 } 334 return result; 335 } 336 case ADDITION: { 337 double sum = 0.0; 338 for(int i = 0; i < arity; i++) { 339 sum += EvaluateBakedCode(); 340 } 341 return sum; 342 } 343 case SUBSTRACTION: { 344 if(arity == 1) { 345 return -EvaluateBakedCode(); 346 } else { 347 double result = EvaluateBakedCode(); 348 for(int i = 1; i < arity; i++) { 349 result -= EvaluateBakedCode(); 350 } 351 return result; 352 } 353 } 354 case DIVISION: { 355 if(arity == 1) { 356 double divisor = EvaluateBakedCode(); 357 if(divisor == 0) return 0; 358 else return 1.0 / divisor; 359 } else { 360 double result = EvaluateBakedCode(); 361 for(int i = 1; i < arity; i++) { 362 double divisor = EvaluateBakedCode(); 363 if(divisor == 0) result = 0; 364 else result /= divisor; 365 } 366 return result; 367 } 368 } 369 case AVERAGE: { 370 double sum = 0.0; 371 for(int i = 0; i < arity; i++) { 372 sum += EvaluateBakedCode(); 373 } 374 return sum / arity; 375 } 376 case COSINUS: { 377 return Math.Cos(EvaluateBakedCode()); 378 } 379 case SINUS: { 380 return Math.Sin(EvaluateBakedCode()); 381 } 382 case EXP: { 383 return Math.Exp(EvaluateBakedCode()); 384 } 385 case LOG: { 386 return Math.Log(EvaluateBakedCode()); 387 } 388 case POWER: { 389 double x = EvaluateBakedCode(); 390 double p = EvaluateBakedCode(); 391 return Math.Pow(x, p); 392 } 393 case SIGNUM: { 394 // protected signum 395 double value = EvaluateBakedCode(); 396 if(value < 0) return -1; 397 if(value > 0) return 1; 398 return 0; 399 } 400 case SQRT: { 401 return Math.Sqrt(EvaluateBakedCode()); 402 } 403 case TANGENS: { 404 return Math.Tan(EvaluateBakedCode()); 405 } 406 case AND: { 407 double result = 1.0; 408 // have to evaluate all sub-trees, skipping would probably not lead to a big gain because 409 // we have to iterate over the linear structure anyway 410 for(int i = 0; i < arity; i++) { 411 double x = Math.Round(EvaluateBakedCode()); 412 if(x == 0) result *= 0; 413 else if(x == 1.0) result *= 1.0; 414 else result *= double.NaN; 415 } 416 return result; 417 } 418 case EQU: { 419 double x = EvaluateBakedCode(); 420 double y = EvaluateBakedCode(); 421 if(x == y) return 1.0; else return 0.0; 422 } 423 case GT: { 424 double x = EvaluateBakedCode(); 425 double y = EvaluateBakedCode(); 426 if(x > y) return 1.0; 427 else return 0.0; 428 } 429 case IFTE: { 430 double condition = Math.Round(EvaluateBakedCode()); 431 double x = EvaluateBakedCode(); 432 double y = EvaluateBakedCode(); 433 if(condition < .5) return x; 434 else if(condition >= .5) return y; 435 else return double.NaN; 436 } 437 case LT: { 438 double x = EvaluateBakedCode(); 439 double y = EvaluateBakedCode(); 440 if(x < y) return 1.0; 441 else return 0.0; 442 } 443 case NOT: { 444 double result = Math.Round(EvaluateBakedCode()); 445 if(result == 0.0) return 1.0; 446 else if(result == 1.0) return 0.0; 447 else return double.NaN; 448 } 449 case OR: { 450 double result = 0.0; // default is false 451 for(int i = 0; i < arity; i++) { 452 double x = Math.Round(EvaluateBakedCode()); 453 if(x == 1.0 && result == 0.0) result = 1.0; // found first true (1.0) => set to true 454 else if(x != 0.0) result = double.NaN; // if it was not true it can only be false (0.0) all other cases are undefined => (NaN) 455 } 456 return result; 457 } 458 case XOR: { 459 double x = Math.Round(EvaluateBakedCode()); 460 double y = Math.Round(EvaluateBakedCode()); 461 if(x == 0.0 && y == 0.0) return 0.0; 462 if(x == 1.0 && y == 0.0) return 1.0; 463 if(x == 0.0 && y == 1.0) return 1.0; 464 if(x == 1.0 && y == 1.0) return 0.0; 465 return double.NaN; 466 } 467 default: { 468 IFunction function = symbolTable[functionSymbol]; 469 double[] args = new double[nLocalVariables + arity]; 470 for(int i = 0; i < nLocalVariables; i++) { 471 args[i] = dataArr[DP++]; 472 } 473 for(int j = 0; j < arity; j++) { 474 args[nLocalVariables + j] = EvaluateBakedCode(); 475 } 476 return function.Apply(dataset, sampleIndex, args); 477 } 478 } 333 479 } 334 480 … … 343 489 public override object Clone(IDictionary<Guid, object> clonedObjects) { 344 490 BakedFunctionTree clone = new BakedFunctionTree(); 345 if(treesExpanded || variablesExpanded) throw new InvalidOperationException(); 346 //if(treesExpanded) FlattenTrees(); 347 //if(variablesExpanded) FlattenVariables(); 348 clone.code = new List<double>(code); 491 if(treesExpanded || variablesExpanded) throw new InvalidOperationException(); // sanity check 492 clone.code.AddRange(code); 493 clone.data.AddRange(data); 349 494 return clone; 495 } 496 497 public override IView CreateView() { 498 return new FunctionTreeView(this); 350 499 } 351 500 }
Note: See TracChangeset
for help on using the changeset viewer.