Changeset 15416 for branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
- Timestamp:
- 10/11/17 08:15:19 (5 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r15414 r15416 61 61 // TODO: Solve Poly-10 62 62 // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved? 63 // TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)? 64 // TODO: check if we can use a quality measure with range [-1..1] in policies 63 65 // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants 64 66 // TODO: check if transformation of y is correct and works (Obj 2) … … 117 119 private readonly ExpressionEvaluator evaluator, testEvaluator; 118 120 121 internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>(); 122 internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>(); 123 internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>(); 124 119 125 // values for best solution 120 private double bestR Sq;126 private double bestR; 121 127 private byte[] bestCode; 122 128 private int bestNParams; … … 187 193 188 194 // reset best solution 189 this.bestR Sq= 0;195 this.bestR = 0; 190 196 // code for default solution (constant model) 191 197 this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit }; … … 208 214 get { 209 215 evaluator.Exec(bestCode, x, bestConsts, predBuf); 210 return R Sq(y, predBuf);216 return Rho(y, predBuf); 211 217 } 212 218 } … … 215 221 get { 216 222 testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf); 217 return R Sq(testY, testPredBuf);223 return Rho(testY, testPredBuf); 218 224 } 219 225 } … … 248 254 249 255 // single objective best 250 if (q > bestR Sq) {251 bestR Sq= q;256 if (q > bestR) { 257 bestR = q; 252 258 bestNParams = nParams; 253 259 this.bestCode = new byte[code.Length]; … … 266 272 } 267 273 268 private void Eval(byte[] code, int nParams, out double r sq, out double[] optConsts) {274 private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) { 269 275 // we make a first pass to determine a valid starting configuration for all constants 270 276 // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator) … … 280 286 // if we don't need to optimize parameters then we are done 281 287 // changing scale and offset does not influence r² 282 r sq = RSq(y, predBuf);288 rho = Rho(y, predBuf); 283 289 optConsts = constsBuf; 284 290 } else { … … 289 295 funcEvaluations++; 290 296 291 r sq = RSq(y, predBuf);297 rho = Rho(y, predBuf); 292 298 optConsts = constsBuf; 293 299 } … … 297 303 298 304 #region helpers 299 private static double R Sq(IEnumerable<double> x, IEnumerable<double> y) {305 private static double Rho(IEnumerable<double> x, IEnumerable<double> y) { 300 306 OnlineCalculatorError error; 301 307 double r = OnlinePearsonsRCalculator.Calculate(x, y, out error); 302 return error == OnlineCalculatorError.None ? r * r: 0.0;308 return error == OnlineCalculatorError.None ? r : 0.0; 303 309 } 304 310 … … 491 497 do { 492 498 automaton.Reset(); 493 success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, out q);499 success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q); 494 500 mctsState.totalRollouts++; 495 501 } while (!success && !tree.Done); 496 502 mctsState.effectiveRollouts++; 497 503 498 if (mctsState.effectiveRollouts % 10 == 1) {499 // Console.WriteLine(WriteTree(tree));500 // Console.WriteLine(TraceTree(tree));501 }504 //if (mctsState.effectiveRollouts % 100 == 1) { 505 // Console.WriteLine(WriteTree(tree, mctsState)); 506 // Console.WriteLine(TraceTree(tree, mctsState)); 507 //} 502 508 return q; 503 509 } 504 505 private static Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();506 private static Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();507 private static Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();508 509 510 510 511 511 512 // search forward 512 513 private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy, 514 State state, 513 515 out double q) { 514 516 // ROLLOUT AND EXPANSION … … 527 529 528 530 while (!automaton.IsFinalState(automaton.CurrentState)) { 529 if ( children.ContainsKey(tree)) {530 if ( children[tree].All(ch => ch.Done)) {531 if (state.children.ContainsKey(tree)) { 532 if (state.children[tree].All(ch => ch.Done)) { 531 533 tree.Done = true; 532 534 break; … … 535 537 // UCT selection within tree 536 538 int selectedIdx = 0; 537 if ( children[tree].Count > 1) {538 selectedIdx = treePolicy.Select( children[tree].Select(ch => ch.actionStatistics), rand);539 } 540 tree = children[tree][selectedIdx];539 if (state.children[tree].Count > 1) { 540 selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand); 541 } 542 tree = state.children[tree][selectedIdx]; 541 543 542 544 // move the automaton forward until reaching the state … … 556 558 int[] possibleFollowStates; 557 559 int nFs; 560 string actionString = ""; 558 561 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 559 562 while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 563 actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]); 560 564 // no alternatives -> just go to the next state 561 565 automaton.Goto(possibleFollowStates[0]); … … 568 572 } 569 573 var newChildren = new List<Tree>(nFs); 570 children.Add(tree, newChildren);574 state.children.Add(tree, newChildren); 571 575 for (int i = 0; i < nFs; i++) { 572 576 Tree child = null; … … 574 578 if (automaton.IsEvalState(possibleFollowStates[i])) { 575 579 var hc = Hashcode(automaton); 576 if (! nodes.TryGetValue(hc, out child)) {580 if (!state.nodes.TryGetValue(hc, out child)) { 577 581 child = new Tree() { 578 582 children = null, 579 583 state = possibleFollowStates[i], 580 584 actionStatistics = treePolicy.CreateActionStatistics(), 581 expr = string.Empty, // ExprStr(automaton),585 expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]), 582 586 level = tree.level + 1 583 587 }; 584 nodes.Add(hc, child);588 state.nodes.Add(hc, child); 585 589 } 586 590 // only allow forward edges (don't add the child if we would go back in the graph) 587 else if (child.level > tree.level) {591 else if (child.level > tree.level) { 588 592 // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link 589 593 // to all parents 590 BackpropagateStatistics(child.actionStatistics, tree );594 BackpropagateStatistics(child.actionStatistics, tree, state); 591 595 } else { 592 596 // prevent cycles 593 597 Debug.Assert(child.level <= tree.level); 594 598 child = null; 595 } 599 } 596 600 } else { 597 601 child = new Tree() { … … 599 603 state = possibleFollowStates[i], 600 604 actionStatistics = treePolicy.CreateActionStatistics(), 601 expr = string.Empty, // ExprStr(automaton),605 expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]), 602 606 level = tree.level + 1 603 607 }; … … 614 618 615 619 foreach (var ch in newChildren) { 616 if (! parents.ContainsKey(ch)) {617 parents.Add(ch, new List<Tree>());620 if (!state.parents.ContainsKey(ch)) { 621 state.parents.Add(ch, new List<Tree>()); 618 622 } 619 parents[ch].Add(tree);620 } 621 622 623 // follow one of the children 624 tree = Select FinalOrRandom2(automaton, tree, rand);623 state.parents[ch].Add(tree); 624 } 625 626 627 // follow one of the children 628 tree = SelectStateLeadingToFinal(automaton, tree, rand, state); 625 629 automaton.Goto(tree.state); 626 630 } … … 636 640 automaton.GetCode(out code, out nParams); 637 641 q = eval(code, nParams); 642 // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr); 638 643 q = TransformQuality(q); 639 644 success = true; 640 645 } else { 641 646 // we got stuck in roll-out (not evaluation necessary!) 647 // Console.WriteLine("\t" + ExprStr(automaton) + " STOP"); 642 648 q = 0.0; 643 649 success = false; … … 647 653 // Update statistics 648 654 // Set branch to done if all children are done. 649 BackpropagateQuality(tree, q, treePolicy );655 BackpropagateQuality(tree, q, treePolicy, state); 650 656 651 657 return success; … … 655 661 private static double TransformQuality(double q) { 656 662 // no transformation 657 return q;663 // return q; 658 664 659 665 // EXPERIMENTAL! 666 667 // Fisher transformation 668 // (assumes q is Correl(pred, target) 669 670 q = Math.Min(q, 0.99999999); 671 q = Math.Max(q, -0.99999999); 672 return 0.5 * Math.Log((1 + q) / (1 - q)); 673 660 674 // optimal result: q = 1 -> return huge value 661 675 // if (q >= 1.0) return 1E16; … … 665 679 666 680 // backpropagate existing statistics to all parents 667 private static void BackpropagateStatistics(IActionStatistics stats, Tree tree ) {681 private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) { 668 682 tree.actionStatistics.Add(stats); 669 if ( parents.ContainsKey(tree)) {670 foreach (var parent in parents[tree]) {671 BackpropagateStatistics(stats, parent );683 if (state.parents.ContainsKey(tree)) { 684 foreach (var parent in state.parents[tree]) { 685 BackpropagateStatistics(stats, parent, state); 672 686 } 673 687 } … … 681 695 } 682 696 683 private static void BackpropagateQuality(Tree tree, double q, IPolicy policy ) {697 private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) { 684 698 if (q > 0) policy.Update(tree.actionStatistics, q); 685 if ( children.ContainsKey(tree) &&children[tree].All(ch => ch.Done)) {699 if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) { 686 700 tree.Done = true; 687 701 // children[tree] = null; keep all nodes 688 702 } 689 703 690 if (parents.ContainsKey(tree)) { 691 foreach (var parent in parents[tree]) { 692 BackpropagateQuality(parent, q, policy); 693 } 694 } 695 } 696 697 private static Tree SelectFinalOrRandom2(Automaton automaton, Tree tree, IRandom rand) { 698 // if one of the new children leads to a final state then go there 699 // otherwise choose a random child 700 int selectedChildIdx = -1; 701 // find first final state if there is one 702 var children = MctsSymbolicRegressionStatic.children[tree]; 703 for (int i = 0; i < children.Count; i++) { 704 if (automaton.IsFinalState(children[i].state)) { 704 if (state.parents.ContainsKey(tree)) { 705 foreach (var parent in state.parents[tree]) { 706 BackpropagateQuality(parent, q, policy, state); 707 } 708 } 709 } 710 711 private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) { 712 // find the child with the smallest state value (smaller values are closer to the final state) 713 int selectedChildIdx = 0; 714 var children = state.children[tree]; 715 Tree minChild = children.First(); 716 for (int i = 1; i < children.Count; i++) { 717 if(children[i].state < minChild.state) 705 718 selectedChildIdx = i; 706 break;707 }708 }709 // no final state -> select the first child710 if (selectedChildIdx == -1) {711 selectedChildIdx = 0;712 719 } 713 720 return children[selectedChildIdx]; … … 865 872 } 866 873 867 private static string WriteStatistics(Tree tree ) {874 private static string WriteStatistics(Tree tree, State state) { 868 875 var sb = new System.IO.StringWriter(); 869 876 sb.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality); 870 if ( children.ContainsKey(tree)) {871 foreach (var ch in children[tree]) {877 if (state.children.ContainsKey(tree)) { 878 foreach (var ch in state.children[tree]) { 872 879 sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality); 873 880 } … … 876 883 } 877 884 878 private static string TraceTree(Tree tree ) {885 private static string TraceTree(Tree tree, State state) { 879 886 var sb = new StringBuilder(); 880 887 sb.Append( … … 885 892 int nodeId = 0; 886 893 887 TraceTreeRec(tree, 0, sb, ref nodeId );894 TraceTreeRec(tree, 0, sb, ref nodeId, state); 888 895 sb.Append("}"); 889 896 return sb.ToString(); 890 897 } 891 898 892 private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId ) {899 private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) { 893 900 var avgNodeQ = tree.actionStatistics.AverageQuality; 894 901 var tries = tree.actionStatistics.Tries; 895 902 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 896 903 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 897 898 sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine(); 904 hue = 0.0; 905 906 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine(); 899 907 900 908 var list = new List<Tuple<int, int, Tree>>(); 901 if ( children.ContainsKey(tree)) {902 foreach (var ch in children[tree]) {909 if (state.children.ContainsKey(tree)) { 910 foreach (var ch in state.children[tree]) { 903 911 nextId++; 904 912 avgNodeQ = ch.actionStatistics.AverageQuality; … … 906 914 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 907 915 hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 908 sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 909 sb.AppendFormat("{0} -> {1}", parentId, nextId, avgNodeQ).AppendLine(); 916 hue = 0.0; 917 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 918 sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine(); 910 919 list.Add(Tuple.Create(tries, nextId, ch)); 911 920 } 921 922 foreach(var tup in list) { 923 var ch = tup.Item3; 924 var chId = tup.Item2; 925 if(state.children.ContainsKey(ch) && state.children[ch].Count == 1) { 926 var chch = state.children[ch].First(); 927 nextId++; 928 avgNodeQ = chch.actionStatistics.AverageQuality; 929 tries = chch.actionStatistics.Tries; 930 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 931 hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 932 hue = 0.0; 933 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 934 sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine(); 935 } 936 } 937 912 938 foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) { 913 TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId );914 } 915 } 916 } 917 918 private static string WriteTree(Tree tree ) {939 TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state); 940 } 941 } 942 } 943 944 private static string WriteTree(Tree tree, State state) { 919 945 var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture); 920 946 var nodeIds = new Dictionary<Tree, int>(); … … 924 950 node [style=filled]; 925 951 "); 926 int threshold = nodes.Count > 500 ? 10 :0;927 foreach (var kvp in children) {952 int threshold = /* state.nodes.Count > 500 ? 10 : */ 0; 953 foreach (var kvp in state.children) { 928 954 var parent = kvp.Key; 929 955 int parentId; … … 934 960 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 935 961 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 962 hue = 0.0; 936 963 if (parent.actionStatistics.Tries > threshold) 937 sb.Write("{0} [label=\"{1: N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);964 sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue); 938 965 nodeIds.Add(parent, parentId); 939 966 } … … 949 976 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 950 977 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 978 hue = 0.0; 951 979 if (tries > threshold) { 952 sb.Write("{0} [label=\"{1: N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);980 sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue); 953 981 var edgeLabel = child.expr; 954 982 // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
Note: See TracChangeset
for help on using the changeset viewer.