Changeset 15438 for branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
- Timestamp:
- 10/28/17 19:56:51 (6 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r15437 r15438 26 26 using System.Linq; 27 27 using System.Text; 28 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;29 28 using HeuristicLab.Core; 30 29 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; … … 99 98 internal readonly Tree tree; 100 99 internal readonly Func<byte[], int, double> evalFun; 101 internal readonly IPolicy treePolicy;102 100 // MCTS might get stuck. Track statistics on the number of effective rollouts 103 101 internal int totalRollouts; … … 145 143 public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, 146 144 int constOptIterations, double lambda, 147 IPolicy treePolicy = null,148 145 bool collectParetoOptimalModels = false, 149 146 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, … … 187 184 this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit); 188 185 189 this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms); 190 this.treePolicy = treePolicy ?? new EpsilonGreedy(); 186 this.automaton = new Automaton(x, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms, maxVariables); 191 187 this.tree = new Tree() { 192 188 state = automaton.CurrentState, 193 actionStatistics = treePolicy.CreateActionStatistics(),194 189 expr = "", 195 190 level = 0 … … 469 464 public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, 470 465 bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0, 471 IPolicy policy = null,472 466 bool collectParameterOptimalModels = false, 473 467 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, … … 479 473 ) { 480 474 return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda, 481 policy,collectParameterOptimalModels,475 collectParameterOptimalModels, 482 476 lowerEstimationLimit, upperEstimationLimit, 483 477 allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms); … … 499 493 var eval = mctsState.evalFun; 500 494 var rand = mctsState.random; 501 var treePolicy = mctsState.treePolicy;502 495 double q = 0; 503 496 bool success = false; … … 505 498 506 499 automaton.Reset(); 507 success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy,mctsState, out q);500 success = TryTreeSearchRec2(rand, tree, automaton, eval, mctsState, out q); 508 501 mctsState.totalRollouts++; 509 502 } while (!success && !tree.Done); … … 517 510 518 511 // search forward 519 private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy, 512 private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, 513 Func<byte[], int, double> eval, 520 514 State state, 521 515 out double q) { … … 545 539 int selectedIdx = 0; 546 540 if (state.children[tree].Count > 1) { 547 selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);541 selectedIdx = SelectInternal(state.children[tree], rand); 548 542 } 549 543 … … 579 573 if (!state.nodes.TryGetValue(hc, out child)) { 580 574 child = new Tree() { 581 children = null,582 575 state = possibleFollowStates[i], 583 actionStatistics = treePolicy.CreateActionStatistics(),584 576 expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]), 585 577 level = tree.level + 1 … … 591 583 // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link 592 584 // to all parents 593 BackpropagateStatistics( child.actionStatistics, tree, state);585 BackpropagateStatistics(tree, state, child.visits); 594 586 } else { 595 587 // prevent cycles … … 599 591 } else { 600 592 child = new Tree() { 601 children = null,602 593 state = possibleFollowStates[i], 603 actionStatistics = treePolicy.CreateActionStatistics(),604 594 expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]), 605 595 level = tree.level + 1 … … 639 629 automaton.GetCode(out code, out nParams); 640 630 q = eval(code, nParams); 641 // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);642 631 success = true; 643 BackpropagateQuality(tree, q, treePolicy, state);632 BackpropagateQuality(tree, q, state); 644 633 } else { 645 634 // we got stuck in roll-out (not evaluation necessary!) 646 // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");647 635 q = 0.0; 648 636 success = false; … … 659 647 } 660 648 649 private static int SelectInternal(List<Tree> list, IRandom rand) { 650 // choose a random node. 651 Debug.Assert(list.Any(t => !t.Done)); 652 653 var idx = rand.Next(list.Count); 654 while(list[idx].Done) { idx = rand.Next(list.Count); } 655 return idx; 656 } 657 661 658 // backpropagate existing statistics to all parents 662 private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) { 663 tree.actionStatistics.Add(stats); 659 private static void BackpropagateStatistics(Tree tree, State state, int numVisits) { 660 tree.visits += numVisits; 661 664 662 if (state.parents.ContainsKey(tree)) { 665 663 foreach (var parent in state.parents[tree]) { 666 BackpropagateStatistics( stats, parent, state);664 BackpropagateStatistics(parent, state, numVisits); 667 665 } 668 666 } … … 676 674 } 677 675 678 private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) { 679 policy.Update(tree.actionStatistics, q); 676 private static void BackpropagateQuality(Tree tree, double q, State state) { 677 tree.visits++; 678 // TODO: q is ignored for now 680 679 681 680 if (state.parents.ContainsKey(tree)) { 682 681 foreach (var parent in state.parents[tree]) { 683 BackpropagateQuality(parent, q, policy,state);682 BackpropagateQuality(parent, q, state); 684 683 } 685 684 } … … 718 717 } 719 718 return children[selectedChildIdx]; 720 } 721 722 // tree search might fail because of constraints for expressions 723 // in this case we get stuck we just restart 724 // see ConstraintHandler.cs for more info 725 private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy, 726 out double q) { 727 Tree selectedChild = null; 728 Contract.Assert(tree.state == automaton.CurrentState); 729 Contract.Assert(!tree.Done); 730 if (tree.children == null) { 731 if (automaton.IsFinalState(tree.state)) { 732 // final state 733 tree.Done = true; 734 735 // EVALUATE 736 byte[] code; int nParams; 737 automaton.GetCode(out code, out nParams); 738 q = eval(code, nParams); 739 740 treePolicy.Update(tree.actionStatistics, q); 741 return true; // we reached a final state 742 } else { 743 // EXPAND 744 int[] possibleFollowStates = new int[1000]; 745 int nFs; 746 automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs); 747 if (nFs == 0) { 748 // stuck in a dead end (no final state and no allowed follow states) 749 q = 0; 750 tree.Done = true; 751 tree.children = null; 752 return false; 753 } 754 tree.children = new Tree[nFs]; 755 for (int i = 0; i < tree.children.Length; i++) 756 tree.children[i] = new Tree() { 757 children = null, 758 state = possibleFollowStates[i], 759 actionStatistics = treePolicy.CreateActionStatistics() 760 }; 761 762 selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0]; 763 } 764 } else { 765 // tree.children != null 766 // UCT selection within tree 767 int selectedIdx = 0; 768 if (tree.children.Length > 1) { 769 selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand); 770 } 771 selectedChild = tree.children[selectedIdx]; 772 } 773 // make selected step and recurse 774 automaton.Goto(selectedChild.state); 775 var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q); 776 if (success) { 777 // only update if successful 778 treePolicy.Update(tree.actionStatistics, q); 779 } 780 781 tree.Done = tree.children.All(ch => ch.Done); 782 if (tree.Done) { 783 tree.children = null; // cut off the sub-branch if it has been fully explored 784 } 785 return success; 786 } 787 788 private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) { 789 // if one of the new children leads to a final state then go there 790 // otherwise choose a random child 791 int selectedChildIdx = -1; 792 // find first final state if there is one 793 for (int i = 0; i < tree.children.Length; i++) { 794 if (automaton.IsFinalState(tree.children[i].state)) { 795 selectedChildIdx = i; 796 break; 797 } 798 } 799 // no final state -> select a the first child 800 if (selectedChildIdx == -1) { 801 selectedChildIdx = 0; 802 } 803 return tree.children[selectedChildIdx]; 804 } 719 } 805 720 806 721 // scales data and extracts values from dataset into arrays … … 869 784 automaton.GetCode(out code, out nParams); 870 785 return Disassembler.CodeToString(code); 871 }872 873 874 private static string WriteStatistics(Tree tree, State state) {875 var sb = new System.IO.StringWriter();876 sb.Write("{0}\t{1:N5}\t", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);877 if (state.children.ContainsKey(tree)) {878 foreach (var ch in state.children[tree]) {879 sb.Write("{0}\t{1:N5}\t", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);880 }881 }882 sb.WriteLine();883 return sb.ToString();884 786 } 885 787 … … 899 801 900 802 private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) { 901 var avgNodeQ = tree.actionStatistics.AverageQuality; 902 var tries = tree.actionStatistics.Tries; 903 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 904 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 905 hue = 0.0; 906 907 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine(); 803 var tries = tree.visits; 804 805 sb.AppendFormat("{0} [label=\"{1}\"]; ", parentId, tries).AppendLine(); 908 806 909 807 var list = new List<Tuple<int, int, Tree>>(); … … 911 809 foreach (var ch in state.children[tree]) { 912 810 nextId++; 913 avgNodeQ = ch.actionStatistics.AverageQuality; 914 tries = ch.actionStatistics.Tries; 915 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 916 hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 917 hue = 0.0; 918 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 919 sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine(); 811 tries = ch.visits; 812 sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine(); 813 sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", parentId, nextId, ch.expr).AppendLine(); 920 814 list.Add(Tuple.Create(tries, nextId, ch)); 921 815 } … … 927 821 var chch = state.children[ch].First(); 928 822 nextId++; 929 avgNodeQ = chch.actionStatistics.AverageQuality; 930 tries = chch.actionStatistics.Tries; 931 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 932 hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 933 hue = 0.0; 934 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 935 sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine(); 823 tries = chch.visits; 824 sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine(); 825 sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", chId, nextId, chch.expr).AppendLine(); 936 826 } 937 827 } … … 957 847 if (!nodeIds.TryGetValue(parent, out parentId)) { 958 848 parentId = nodeIds.Count + 1; 959 var avgNodeQ = parent.actionStatistics.AverageQuality; 960 var tries = parent.actionStatistics.Tries; 961 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 962 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 963 hue = 0.0; 964 if (parent.actionStatistics.Tries > threshold) 965 sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue); 849 var tries = parent.visits; 850 if (tries > threshold) 851 sb.Write("{0} [label=\"{1}\"]; ", parentId, tries); 966 852 nodeIds.Add(parent, parentId); 967 853 } … … 972 858 nodeIds.Add(child, childId); 973 859 } 974 var avgNodeQ = child.actionStatistics.AverageQuality; 975 var tries = child.actionStatistics.Tries; 860 var tries = child.visits; 976 861 if (tries < 1) continue; 977 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;978 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue979 hue = 0.0;980 862 if (tries > threshold) { 981 sb.Write("{0} [label=\"{1 :E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);863 sb.Write("{0} [label=\"{1}\"]; ", childId, tries); 982 864 var edgeLabel = child.expr; 983 865 // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, ""); 984 sb.Write("{0} -> {1} [label=\"{ 3}\"]", parentId, childId, avgNodeQ, edgeLabel);866 sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel); 985 867 } 986 868 }
Note: See TracChangeset
for help on using the changeset viewer.