- Timestamp:
- 10/20/17 12:38:06 (7 years ago)
- Location:
- branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4
- Files:
  - 2 added
  - 6 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg.csproj
r15416 r15425 98 98 </ItemGroup> 99 99 <ItemGroup> 100 <Compile Include="Heuristics.cs" /> 100 101 <Compile Include="MctsSymbolicRegression\ApproximateDoubleEqualityComparer.cs" /> 101 102 <Compile Include="MctsSymbolicRegression\IConstraintHandler.cs" /> -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r15420 r15425 56 56 // 57 57 58 // TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should 59 // weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a 60 // branch and less in the expected value. (--> Review "Extreme Bandit" literature again) 58 // TODO: The samples of x1*... or x2*... do not give any information about the relevance of the interaction term x1*x2 in general! 59 // --> E.g. if x1, x2 ~ N(0, 1) or U(-1, 1) this is trivial to show 60 // --> Therefore, looking at rollout statistics for arm selection is useless in the general case! 61 // --> It is necessary to rely on other features for the arm selection. 62 // --> TODO: Which heuristics can we apply? 61 63 // TODO: Solve Poly-10 62 64 // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved? … … 72 74 // TODO: improve memory usage 73 75 // TODO: support empty test partition 76 // TODO: the algorithm should be invariant to linear transformations of the space (y = f(x') = f( Ax ) ) for invertible transformations A --> unit tests 74 77 #region static API 75 78 … … 143 146 // calculate for each level the number of alternatives the average 'inequality' of tries and 'inequality' of quality over the alternatives for each trie 144 147 // inequality can be calculated using the Gini coefficient 145 internal readonly double[] giniCoeffs = new double[100]; 146 148 internal readonly double[] pathGiniCoeffs = new double[100]; 149 internal readonly double[] pathQs = new double[100]; 150 internal readonly double[] levelBestQ = new double[100]; 151 // internal readonly double[] levelMaxTries = new double[100]; 152 internal readonly double[] pathBestQ = new double[100]; // as long as pathBestQs = levelBestQs we are following the correct path 153 internal readonly string[] levelBestAction = new string[100]; 154 internal readonly string[] curAction = new 
string[100]; 155 internal readonly double[] pathSelectedQ = new double[100]; 147 156 148 157 public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, … … 448 457 #if DEBUG 449 458 internal void ClearStats() { 450 for (int i = 0; i < giniCoeffs.Length; i++) giniCoeffs[i] = -1; 451 } 452 internal void WriteStats() { 453 Console.WriteLine(string.Join("\t", giniCoeffs.TakeWhile(x => x >= 0).Select(x => string.Format("{0:N3}", x)))); 454 } 459 for (int i = 0; i < pathGiniCoeffs.Length; i++) pathGiniCoeffs[i] = -1; 460 for (int i = 0; i < pathQs.Length; i++) pathGiniCoeffs[i] = -99; 461 for (int i = 0; i < pathBestQ.Length; i++) pathBestQ[i] = -99; 462 for (int i = 0; i < pathSelectedQ.Length; i++) pathSelectedQ[i] = -99; 463 } 464 internal void WriteGiniStats() { 465 Console.WriteLine(string.Join("\t", pathGiniCoeffs.TakeWhile(x => x >= 0).Select(x => string.Format("{0:N3}", x)))); 466 } 467 internal void WriteQs() { 468 // Console.WriteLine(string.Join("\t", pathQs.TakeWhile(x => x >= -100).Select(x => string.Format("{0:N3}", x)))); 469 var sb = new StringBuilder(); 470 // length 471 int i = 0; 472 while (i < pathBestQ.Length && pathBestQ[i] > -99 && pathBestQ[i] == levelBestQ[i]) { 473 i++; 474 } 475 sb.AppendFormat("{0,-3}",i); 476 477 i = 0; 478 // sb.AppendFormat("{0:N3}", levelBestQ[0]); 479 while (i < pathSelectedQ.Length && pathSelectedQ[i] > -99) { 480 sb.AppendFormat("\t{0:N3}", pathSelectedQ[i]); 481 i++; 482 } 483 Console.WriteLine(sb.ToString()); 484 sb.Clear(); 485 i = 0; 486 // sb.AppendFormat("{0:N3}", levelBestQ[0]); 487 while (i < pathBestQ.Length && pathBestQ[i] > -99) { 488 sb.AppendFormat("\t{0:N3}", pathBestQ[i]); 489 i++; 490 } 491 Console.WriteLine(sb.ToString()); 492 sb.Clear(); 493 i = 0; 494 while (i < pathBestQ.Length && pathBestQ[i] > -99) { 495 sb.AppendFormat("\t{0:N3}", levelBestQ[i]); 496 i++; 497 } 498 Console.WriteLine(sb.ToString()); 499 500 sb.Clear(); 501 i = 0; 502 while (i < 
pathBestQ.Length && pathBestQ[i] > -99) { 503 sb.AppendFormat("\t{0,-5}", (curAction[i] != null && curAction[i].Length > 5) ? curAction[i].Substring(0, 5) : curAction[i]); 504 i++; 505 } 506 Console.WriteLine(sb.ToString()); 507 sb.Clear(); 508 i = 0; 509 while (i < pathBestQ.Length && pathBestQ[i] > -99) { 510 sb.AppendFormat("\t{0,-5}", (levelBestAction[i] != null && levelBestAction[i].Length > 5) ? levelBestAction[i].Substring(0, 5) : levelBestAction[i]); 511 i++; 512 } 513 Console.WriteLine(sb.ToString()); 514 515 Console.WriteLine(); 516 } 517 455 518 456 519 #endif … … 525 588 526 589 #if DEBUG 527 mctsState.WriteStats(); 590 // mctsState.WriteGiniStats(); 591 Console.WriteLine(ExprStr(automaton)); 592 mctsState.WriteQs(); 593 // Console.WriteLine(WriteStatistics(tree, mctsState)); 594 528 595 #endif 529 596 //if (mctsState.effectiveRollouts % 100 == 1) { … … 564 631 selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand); 565 632 } 566 633 567 634 // STATS 568 state.giniCoeffs[tree.level] = InequalityCoefficient(state.children[tree].Select(ch => (double)ch.actionStatistics.AverageQuality)); 635 state.pathGiniCoeffs[tree.level] = InequalityCoefficient(state.children[tree].Select(ch => (double)ch.actionStatistics.AverageQuality)); 636 state.pathQs[tree.level] = tree.actionStatistics.AverageQuality; 569 637 570 638 tree = state.children[tree][selectedIdx]; … … 576 644 int nFs; 577 645 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 578 while (automaton.CurrentState != tree.state && nFs == 1 && 579 !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 580 automaton.Goto(possibleFollowStates[0]); 581 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 582 } 646 // TODO! 
647 // while (possibleFollowStates[0] != tree.state && nFs == 1 && 648 // !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 649 // automaton.Goto(possibleFollowStates[0]); 650 // automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 651 // } 583 652 Debug.Assert(possibleFollowStates.Contains(tree.state)); 584 653 automaton.Goto(tree.state); … … 589 658 string actionString = ""; 590 659 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 591 while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 592 actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]); 593 // no alternatives -> just go to the next state 594 automaton.Goto(possibleFollowStates[0]); 595 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 596 } 660 // TODO 661 // while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 662 // actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]); 663 // // no alternatives -> just go to the next state 664 // automaton.Goto(possibleFollowStates[0]); 665 // automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 666 // } 597 667 if (nFs == 0) { 598 668 // stuck in a dead end (no final state and no allowed follow states) … … 672 742 q = TransformQuality(q); 673 743 success = true; 744 BackpropagateQuality(tree, q, treePolicy, state); 674 745 } else { 675 746 // we got stuck in roll-out (not evaluation necessary!) … … 682 753 // Update statistics 683 754 // Set branch to done if all children are done. 
684 BackpropagateQuality(tree, q, treePolicy, state); 755 BackpropagateDone(tree, state); 756 BackpropagateDebugStats(tree, q, state); 757 685 758 686 759 return success; … … 703 776 private static double TransformQuality(double q) { 704 777 // no transformation 705 //return q;778 return q; 706 779 707 780 // EXPERIMENTAL! … … 738 811 739 812 private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) { 740 if (q > 0) policy.Update(tree.actionStatistics, q); 813 policy.Update(tree.actionStatistics, q); 814 815 if (state.parents.ContainsKey(tree)) { 816 foreach (var parent in state.parents[tree]) { 817 BackpropagateQuality(parent, q, policy, state); 818 } 819 } 820 } 821 822 private static void BackpropagateDone(Tree tree, State state) { 741 823 if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) { 742 824 tree.Done = true; … … 746 828 if (state.parents.ContainsKey(tree)) { 747 829 foreach (var parent in state.parents[tree]) { 748 BackpropagateQuality(parent, q, policy, state); 749 } 830 BackpropagateDone(parent, state); 831 } 832 } 833 } 834 835 private static void BackpropagateDebugStats(Tree tree, double q, State state) { 836 if (state.parents.ContainsKey(tree)) { 837 foreach (var parent in state.parents[tree]) { 838 BackpropagateDebugStats(parent, q, state); 839 } 840 } 841 842 state.pathSelectedQ[tree.level] = tree.actionStatistics.AverageQuality; 843 state.pathBestQ[tree.level] = tree.actionStatistics.BestQuality; 844 state.curAction[tree.level] = tree.expr; 845 if (state.levelBestQ[tree.level] < tree.actionStatistics.BestQuality) { 846 state.levelBestQ[tree.level] = tree.actionStatistics.BestQuality; 847 state.levelBestAction[tree.level] = tree.expr; 750 848 } 751 849 } … … 917 1015 private static string WriteStatistics(Tree tree, State state) { 918 1016 var sb = new System.IO.StringWriter(); 919 sb.Write Line("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);1017 
sb.Write("{0}\t{1:N5}\t", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality); 920 1018 if (state.children.ContainsKey(tree)) { 921 1019 foreach (var ch in state.children[tree]) { 922 sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality); 923 } 924 } 1020 sb.Write("{0}\t{1:N5}\t", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality); 1021 } 1022 } 1023 sb.WriteLine(); 925 1024 return sb.ToString(); 926 1025 } -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/EpsGreedy.cs
r15410 r15425 18 18 public double SumQuality { get; set; } 19 19 public double AverageQuality { get { return SumQuality / Tries; } } 20 public double BestQuality { get; internal set; } 20 21 public int Tries { get; set; } 21 22 public bool Done { get; set; } … … 26 27 this.Tries += o.Tries; 27 28 this.SumQuality += o.SumQuality; 29 this.BestQuality = Math.Max(this.BestQuality, other.BestQuality); 28 30 } 29 31 } … … 60 62 var a = action as ActionStatistics; 61 63 a.SumQuality += q; 64 a.BestQuality = Math.Max(a.BestQuality, q); 62 65 a.Tries++; 63 66 } -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/IActionStatistics.cs
r15410 r15425 8 8 public interface IActionStatistics { 9 9 double AverageQuality { get; } 10 double BestQuality { get; } 10 11 int Tries { get; } 11 12 bool Done { get; set; } -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs
r15416 r15425 19 19 public double SumQuality { get; set; } 20 20 public double AverageQuality { get { return SumQuality / Tries; } } 21 public double BestQuality { get; internal set; } 21 22 public int Tries { get; set; } 22 23 public bool Done { get; set; } … … 26 27 this.Tries += o.Tries; 27 28 this.SumQuality += o.SumQuality; 29 this.BestQuality = Math.Max(this.BestQuality, other.BestQuality); 28 30 } 29 31 } … … 60 62 var a = action as ActionStatistics; 61 63 a.SumQuality += q; 64 a.BestQuality = Math.Max(a.BestQuality, q); 62 65 a.Tries++; 63 66 } … … 82 85 return buf[rand.Next(buf.Count)]; 83 86 } 87 88 Debug.Assert(actions.All(a => a.Done || a.Tries > 0)); 84 89 85 90 Debug.Assert(totalTries > 0); -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/UcbTuned.cs
r15410 r15425 19 19 public double SumSqrQuality { get; set; } 20 20 public double AverageQuality { get { return SumQuality / Tries; } } 21 public double BestQuality { get; internal set; } 21 22 public double QualityVariance { get { return SumSqrQuality / Tries - AverageQuality * AverageQuality; } } 22 23 public int Tries { get; set; } … … 29 30 this.SumQuality += o.SumQuality; 30 31 this.SumSqrQuality += o.SumSqrQuality; 32 this.BestQuality = Math.Max(this.BestQuality, other.BestQuality); 31 33 } 32 34 } … … 64 66 a.SumQuality += q; 65 67 a.SumSqrQuality += q * q; 68 a.BestQuality = Math.Max(a.BestQuality, q); 66 69 a.Tries++; 67 70 }
Note: See TracChangeset for help on using the changeset viewer.