Changeset 15410 for branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
- Timestamp:
- 10/06/17 17:52:36 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
// Changeset r15404 -> r15410, shown as the resulting (r15410) code.
// Spans the changeset viewer elided are marked with "// … …".
// Lines that r15410 replaced are quoted in comments prefixed with "r15404:".
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.Linq;
// … …
  //

  // TODO: Taking averages of R² values is probably not ideal, as an improvement of R² from 0.99 to 0.999 should
  //       weigh more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
  //       branch and less in the expected value. (--> Review "Extreme Bandit" literature again)
  // TODO: Constraint handling is too restrictive! E.g. for Poly-10, if MCTS identifies the term x3*x4 first it is
  //       not possible to add the term x1*x2 later on. The same is true for individual terms: after x2 it is not
  //       possible to multiply x1. It is easy to get stuck. Why do we actually need the current way of constraint handling?
  //       It would probably be easier to use some kind of hashing to identify equivalent expressions in the tree.
  // TODO: State unification (using hashing) is partially done. The hashcode calculation should be improved to also detect that
  //       c*x1 + c*x1*x1 + c*x1 is the same as c*x1 + c*x1*x1
  // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
  // TODO: check if transformation of y is correct and works (Obj 2)
  // TODO: The algorithm is not invariant to location and scale of variables.
  // … …
        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);

        // r15404: this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
        // NOTE(review): the literal 100 replaces the former maxVariables argument; confirm that this is intended
        // and what SimpleConstraintHandler expects as its argument.
        this.automaton = new Automaton(x, new SimpleConstraintHandler(100), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
        this.treePolicy = treePolicy ?? new Ucb();
        // r15404: this.tree = new Tree() { state = automaton.CurrentState, actionStatistics = treePolicy.CreateActionStatistics() };
        // Use the field (never null) instead of the parameter, which is allowed to be null (see the fallback above).
        this.tree = new Tree() {
          state = automaton.CurrentState,
          actionStatistics = this.treePolicy.CreateActionStatistics(),
          expr = ""
        };

        // reset best solution
        // … …
      do {
        automaton.Reset();
        // r15404: success = TryTreeSearchRec(rand, tree, automaton, eval, treePolicy, out q);
        success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, out q);
        mctsState.totalRollouts++;
      } while (!success && !tree.Done);
      mctsState.effectiveRollouts++;

      // NOTE(review): debugging output; WriteTree serializes the whole search graph on every 10th effective rollout.
      if (mctsState.effectiveRollouts % 10 == 1) Console.WriteLine(WriteTree(tree));
      return q;
    }

    // Search graph used by TryTreeSearchRec2. Edges are stored outside of the Tree nodes because a node can
    // have several parents after state unification.
    // NOTE(review): these are static and never cleared in the code shown here, so they appear to be shared by
    // all searches within the process - confirm that this is acceptable.
    private static readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
    private static readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
    private static readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();

    /// <summary>
    /// Performs one forward search from <paramref name="tree"/> until the automaton reaches a final state or a dead end,
    /// evaluates the resulting expression and backpropagates the quality to all ancestors.
    /// </summary>
    /// <param name="rand">Random number generator handed to the tree policy.</param>
    /// <param name="tree">Node at which the search starts.</param>
    /// <param name="automaton">Automaton generating the expression; it is advanced by this method.</param>
    /// <param name="eval">Evaluation function for the generated code and its number of parameters.</param>
    /// <param name="treePolicy">Policy used to select among existing children and to create node statistics.</param>
    /// <param name="q">Quality of the evaluated expression, or 0.0 if no final state was reached.</param>
    /// <returns>True if a final state was reached and evaluated, false if the search got stuck.</returns>
    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
      out double q) {
      // ROLLOUT AND EXPANSION
      // We are navigating a graph (states might be reached via different paths) instead of a tree.
      // State equivalence is checked through ExprHash (based on the generated code through the path).

      // We switch between rollout-mode and expansion mode
      // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
      // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
      // In expansion mode we might re-enter the graph and switch back to rollout-mode
      // We do this until we reach a complete expression (final state)

      // Loops in the graph are possible! (Problem?)
      // Sub-graphs which have been completely searched are marked as done.
      // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.

      while (!automaton.IsFinalState(automaton.CurrentState)) {
        int[] possibleFollowStates;
        int nFs;
        List<Tree> knownChildren;
        if (children.TryGetValue(tree, out knownChildren)) {
          // ROLLOUT INSIDE TREE
          // UCT selection within tree
          int selectedIdx = 0;
          if (knownChildren.Count > 1) {
            selectedIdx = treePolicy.Select(knownChildren.Select(ch => ch.actionStatistics), rand);
          }
          tree = knownChildren[selectedIdx];

          // move the automaton forward until reaching the state
          // all steps where no alternatives are possible are immediately taken
          // TODO: simplification of the automaton
          SkipForcedTransitions(automaton, out possibleFollowStates, out nFs);
          // only the first nFs entries are reported as follow states
          Debug.Assert(possibleFollowStates.Take(nFs).Contains(tree.state));
          automaton.Goto(tree.state);
        } else {
          // EXPAND
          SkipForcedTransitions(automaton, out possibleFollowStates, out nFs);
          if (nFs == 0) {
            // stuck in a dead end (no final state and no allowed follow states)
            tree.Done = true;
            break;
          }
          var newChildren = new List<Tree>(nFs);
          children.Add(tree, newChildren);
          for (int i = 0; i < nFs; i++) {
            Tree child = null;
            // for selected states we introduce state unification (detection of equivalent states)
            if (automaton.IsEvalState(possibleFollowStates[i])) {
              // NOTE(review): the hash is taken from the code generated so far and does not depend on i,
              // so several eval follow states of the same node would be unified with each other - confirm.
              var hc = Hashcode(automaton);
              if (!nodes.TryGetValue(hc, out child)) {
                child = CreateNode(possibleFollowStates[i], automaton, treePolicy);
                nodes.Add(hc, child);
              } else {
                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
                // to all parents
                BackpropagateStatistics(child.actionStatistics, tree);
              }
            } else {
              child = CreateNode(possibleFollowStates[i], automaton, treePolicy);
            }
            newChildren.Add(child);
          }

          // register the expanded node as parent of all its children
          foreach (var ch in newChildren) {
            List<Tree> parentsOfChild;
            if (!parents.TryGetValue(ch, out parentsOfChild)) {
              parentsOfChild = new List<Tree>();
              parents.Add(ch, parentsOfChild);
            }
            parentsOfChild.Add(tree);
          }

          // follow one of the children
          tree = SelectFinalOrRandom2(automaton, tree, rand);
          automaton.Goto(tree.state);
        }
      }

      bool success;

      // EVALUATE TREE
      if (automaton.IsFinalState(automaton.CurrentState)) {
        tree.Done = true;
        byte[] code; int nParams;
        automaton.GetCode(out code, out nParams);
        q = eval(code, nParams);
        q = TransformQuality(q);
        success = true;
      } else {
        // we got stuck in roll-out (no evaluation necessary!)
        q = 0.0;
        success = false;
      }

      // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
      // Update statistics
      // Set branch to done if all children are done.
      BackpropagateQuality(tree, q, treePolicy);

      return success;
    }

    // Advances the automaton as long as exactly one follow state exists and that state is not an eval state.
    // Returns the follow states of the state in which the automaton stops.
    private static void SkipForcedTransitions(Automaton automaton, out int[] possibleFollowStates, out int nFs) {
      automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0])) {
        // no alternatives -> just go to the next state
        automaton.Goto(possibleFollowStates[0]);
        automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
      }
    }

    // Creates a graph node for the given follow state with fresh statistics; expr is the code generated so far (debugging aid).
    private static Tree CreateNode(int state, Automaton automaton, IPolicy treePolicy) {
      return new Tree() {
        children = null,
        state = state,
        actionStatistics = treePolicy.CreateActionStatistics(),
        expr = ExprStr(automaton)
      };
    }

    /// <summary>
    /// Maps the raw quality returned by the evaluator to the value that is backpropagated.
    /// Currently the identity.
    /// </summary>
    /// <remarks>
    /// EXPERIMENTAL alternative that was present (but unreachable) in the code: return 1E16 for q >= 1.0 and
    /// -Math.Log10(1 - q) otherwise, i.e. the number of 9s in R².
    /// </remarks>
    private static double TransformQuality(double q) {
      // no transformation
      return q;
    }

    /// <summary>
    /// Adds existing statistics to <paramref name="tree"/> and, recursively, to all of its ancestors.
    /// Used when a new edge to an already existing node is created.
    /// </summary>
    private static void BackpropagateStatistics(IActionStatistics stats, Tree tree) {
      BackpropagateStatistics(stats, tree, new HashSet<Tree>());
    }

    // path contains the nodes on the current recursion path; a node that is already on the path is not entered
    // again so that loops in the graph cannot cause unbounded recursion. Acyclic graphs are handled as before.
    private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, HashSet<Tree> path) {
      if (!path.Add(tree)) return;
      tree.actionStatistics.Add(stats);
      List<Tree> parentsOfTree;
      if (parents.TryGetValue(tree, out parentsOfTree)) {
        foreach (var parent in parentsOfTree) {
          BackpropagateStatistics(stats, parent, path);
        }
      }
      path.Remove(tree);
    }

    // Hash of the code generated by the automaton so far; used to detect equivalent states.
    private static ulong Hashcode(Automaton automaton) {
      byte[] code;
      int nParams;
      automaton.GetCode(out code, out nParams);
      return ExprHash.GetHash(code, nParams);
    }

    /// <summary>
    /// Updates the statistics of <paramref name="tree"/> and of all its ancestors with the quality
    /// <paramref name="q"/> (only if q > 0) and marks nodes as done when all of their children are done.
    /// </summary>
    private static void BackpropagateQuality(Tree tree, double q, IPolicy policy) {
      BackpropagateQuality(tree, q, policy, new HashSet<Tree>());
    }

    // See BackpropagateStatistics for the purpose of path.
    private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, HashSet<Tree> path) {
      if (!path.Add(tree)) return;
      if (q > 0) policy.Update(tree.actionStatistics, q);
      List<Tree> childrenOfTree;
      if (children.TryGetValue(tree, out childrenOfTree) && childrenOfTree.All(ch => ch.Done)) {
        tree.Done = true;
        // children[tree] = null; keep all nodes
      }

      List<Tree> parentsOfTree;
      if (parents.TryGetValue(tree, out parentsOfTree)) {
        foreach (var parent in parentsOfTree) {
          BackpropagateQuality(parent, q, policy, path);
        }
      }
      path.Remove(tree);
    }

    /// <summary>
    /// Selects the child of a freshly expanded node that the search follows next: the first child whose
    /// state is final, or the first child if none is final.
    /// </summary>
    /// <remarks>Despite the name no random choice is made; <paramref name="rand"/> is currently unused.</remarks>
    private static Tree SelectFinalOrRandom2(Automaton automaton, Tree tree, IRandom rand) {
      var candidates = children[tree];
      // find first final state if there is one
      for (int i = 0; i < candidates.Count; i++) {
        if (automaton.IsFinalState(candidates[i].state)) {
          return candidates[i];
        }
      }
      // no final state -> select the first child
      return candidates[0];
    }

    // … …
        tree.children = new Tree[nFs];
        for (int i = 0; i < tree.children.Length; i++)
          tree.children[i] = new Tree() {
            children = null,
            state = possibleFollowStates[i],
            actionStatistics = treePolicy.CreateActionStatistics()
          };

        selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
        // … …
      }
    }

    // for debugging only

    // Disassembles the code generated by the automaton so far into a string.
    private static string ExprStr(Automaton automaton) {
      byte[] code;
      int nParams;
      automaton.GetCode(out code, out nParams);
      return Disassembler.CodeToString(code);
    }

    // Writes one line "<tries> <average quality>" for the node and one for each of its children.
    private static string WriteStatistics(Tree tree) {
      using (var sw = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture)) {
        sw.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
        List<Tree> childrenOfTree;
        if (children.TryGetValue(tree, out childrenOfTree)) {
          foreach (var ch in childrenOfTree) {
            sw.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
          }
        }
        return sw.ToString();
      }
    }

    /// <summary>
    /// Renders the whole search graph in Graphviz DOT format. Nodes are labelled with average quality and
    /// number of tries and colored from blue (quality 0) to red (quality 1); children without tries are omitted.
    /// </summary>
    /// <remarks>The graph is taken from the static edge dictionary; <paramref name="tree"/> is currently unused.</remarks>
    private static string WriteTree(Tree tree) {
      using (var sw = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture)) {
        var nodeIds = new Dictionary<Tree, int>();
        sw.WriteLine("digraph {");
        sw.WriteLine("ratio = fill;");
        sw.WriteLine("node [style=filled];");
        foreach (var kvp in children) {
          var parent = kvp.Key;
          int parentId;
          if (!nodeIds.TryGetValue(parent, out parentId)) {
            parentId = nodeIds.Count + 1;
            WriteNode(sw, parentId, parent);
            nodeIds.Add(parent, parentId);
          }
          foreach (var child in kvp.Value) {
            int childId;
            if (!nodeIds.TryGetValue(child, out childId)) {
              childId = nodeIds.Count + 1;
              nodeIds.Add(child, childId);
            }
            if (child.actionStatistics.Tries < 1) continue;
            WriteNode(sw, childId, child);
            // label the edge with the part of the expression that the child adds to the parent
            var edgeLabel = child.expr ?? "";
            if (!string.IsNullOrEmpty(parent.expr)) edgeLabel = edgeLabel.Replace(parent.expr, "");
            // escape characters that would terminate the quoted DOT string
            edgeLabel = edgeLabel.Replace("\\", "\\\\").Replace("\"", "\\\"");
            sw.Write("{0} -> {1} [label=\"{2}\"]; ", parentId, childId, edgeLabel);
          }
        }

        sw.Write("}");
        return sw.ToString();
      }
    }

    // Writes the DOT statement for a single node (label: average quality and tries, fill color: quality).
    private static void WriteNode(System.IO.TextWriter writer, int id, Tree node) {
      var avgNodeQ = node.actionStatistics.AverageQuality;
      var tries = node.actionStatistics.Tries;
      if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
      writer.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", id, avgNodeQ, tries, QualityToHue(avgNodeQ));
    }

    // Graphviz expects the HSV hue as a fraction of the full circle (0..1): quality 1 -> 0° (red), quality 0 -> 240° (blue).
    private static double QualityToHue(double quality) {
      var clamped = Math.Min(1.0, Math.Max(0.0, quality));
      return (1.0 - clamped) * 240.0 / 360.0;
    }
  }
}
Note: See TracChangeset
for help on using the changeset viewer.