Changeset 15425
- Timestamp:
- 10/20/17 12:38:06 (7 years ago)
- Location:
- branches/MCTS-SymbReg-2796
- Files:
-
- 2 added
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg.csproj
r15416 r15425 98 98 </ItemGroup> 99 99 <ItemGroup> 100 <Compile Include="Heuristics.cs" /> 100 101 <Compile Include="MctsSymbolicRegression\ApproximateDoubleEqualityComparer.cs" /> 101 102 <Compile Include="MctsSymbolicRegression\IConstraintHandler.cs" /> -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r15420 r15425 56 56 // 57 57 58 // TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should 59 // weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a 60 // branch and less in the expected value. (--> Review "Extreme Bandit" literature again) 58 // TODO: The samples of x1*... or x2*... do not give any information about the relevance of the interaction term x1*x2 in general! 59 // --> E.g. if x1, x2 ~ N(0, 1) or U(-1, 1) this is trivial to show 60 // --> Therefore, looking at rollout statistics for arm selection is useless in the general case! 61 // --> It is necessary to rely on other features for the arm selection. 62 // --> TODO: Which heuristics can we apply? 61 63 // TODO: Solve Poly-10 62 64 // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved? … … 72 74 // TODO: improve memory usage 73 75 // TODO: support empty test partition 76 // TODO: the algorithm should be invariant to linear transformations of the space (y = f(x') = f( Ax ) ) for invertible transformations A --> unit tests 74 77 #region static API 75 78 … … 143 146 // calculate for each level the number of alternatives the average 'inequality' of tries and 'inequality' of quality over the alternatives for each trie 144 147 // inequality can be calculated using the Gini coefficient 145 internal readonly double[] giniCoeffs = new double[100]; 146 148 internal readonly double[] pathGiniCoeffs = new double[100]; 149 internal readonly double[] pathQs = new double[100]; 150 internal readonly double[] levelBestQ = new double[100]; 151 // internal readonly double[] levelMaxTries = new double[100]; 152 internal readonly double[] pathBestQ = new double[100]; // as long as pathBestQs = levelBestQs we are following the correct path 153 internal readonly string[] levelBestAction = new string[100]; 154 internal readonly string[] curAction = new string[100]; 155 internal readonly double[] pathSelectedQ = new double[100]; 147 156 148 157 public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, … … 448 457 #if DEBUG 449 458 internal void ClearStats() { 450 for (int i = 0; i < giniCoeffs.Length; i++) giniCoeffs[i] = -1; 451 } 452 internal void WriteStats() { 453 Console.WriteLine(string.Join("\t", giniCoeffs.TakeWhile(x => x >= 0).Select(x => string.Format("{0:N3}", x)))); 454 } 459 for (int i = 0; i < pathGiniCoeffs.Length; i++) pathGiniCoeffs[i] = -1; 460 for (int i = 0; i < pathQs.Length; i++) pathGiniCoeffs[i] = -99; 461 for (int i = 0; i < pathBestQ.Length; i++) pathBestQ[i] = -99; 462 for (int i = 0; i < pathSelectedQ.Length; i++) pathSelectedQ[i] = -99; 463 } 464 internal void WriteGiniStats() { 465 Console.WriteLine(string.Join("\t", pathGiniCoeffs.TakeWhile(x => x >= 0).Select(x => string.Format("{0:N3}", x)))); 466 } 467 internal void WriteQs() { 468 // Console.WriteLine(string.Join("\t", pathQs.TakeWhile(x => x >= -100).Select(x => string.Format("{0:N3}", x)))); 469 var sb = new StringBuilder(); 470 // length 471 int i = 0; 472 while (i < pathBestQ.Length && pathBestQ[i] > -99 && pathBestQ[i] == levelBestQ[i]) { 473 i++; 474 } 475 sb.AppendFormat("{0,-3}",i); 476 477 i = 0; 478 // sb.AppendFormat("{0:N3}", levelBestQ[0]); 479 while (i < pathSelectedQ.Length && pathSelectedQ[i] > -99) { 480 sb.AppendFormat("\t{0:N3}", pathSelectedQ[i]); 481 i++; 482 } 483 Console.WriteLine(sb.ToString()); 484 sb.Clear(); 485 i = 0; 486 // sb.AppendFormat("{0:N3}", levelBestQ[0]); 487 while (i < pathBestQ.Length && pathBestQ[i] > -99) { 488 sb.AppendFormat("\t{0:N3}", pathBestQ[i]); 489 i++; 490 } 491 Console.WriteLine(sb.ToString()); 492 sb.Clear(); 493 i = 0; 494 while (i < pathBestQ.Length && pathBestQ[i] > -99) { 495 sb.AppendFormat("\t{0:N3}", levelBestQ[i]); 496 i++; 497 } 498 Console.WriteLine(sb.ToString()); 499 500 sb.Clear(); 501 i = 0; 502 while (i < pathBestQ.Length && pathBestQ[i] > -99) { 503 sb.AppendFormat("\t{0,-5}", (curAction[i] != null && curAction[i].Length > 5) ? curAction[i].Substring(0, 5) : curAction[i]); 504 i++; 505 } 506 Console.WriteLine(sb.ToString()); 507 sb.Clear(); 508 i = 0; 509 while (i < pathBestQ.Length && pathBestQ[i] > -99) { 510 sb.AppendFormat("\t{0,-5}", (levelBestAction[i] != null && levelBestAction[i].Length > 5) ? levelBestAction[i].Substring(0, 5) : levelBestAction[i]); 511 i++; 512 } 513 Console.WriteLine(sb.ToString()); 514 515 Console.WriteLine(); 516 } 517 455 518 456 519 #endif … … 525 588 526 589 #if DEBUG 527 mctsState.WriteStats(); 590 // mctsState.WriteGiniStats(); 591 Console.WriteLine(ExprStr(automaton)); 592 mctsState.WriteQs(); 593 // Console.WriteLine(WriteStatistics(tree, mctsState)); 594 528 595 #endif 529 596 //if (mctsState.effectiveRollouts % 100 == 1) { … … 564 631 selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand); 565 632 } 566 633 567 634 // STATS 568 state.giniCoeffs[tree.level] = InequalityCoefficient(state.children[tree].Select(ch => (double)ch.actionStatistics.AverageQuality)); 635 state.pathGiniCoeffs[tree.level] = InequalityCoefficient(state.children[tree].Select(ch => (double)ch.actionStatistics.AverageQuality)); 636 state.pathQs[tree.level] = tree.actionStatistics.AverageQuality; 569 637 570 638 tree = state.children[tree][selectedIdx]; … … 576 644 int nFs; 577 645 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 578 while (automaton.CurrentState != tree.state && nFs == 1 && 579 !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 580 automaton.Goto(possibleFollowStates[0]); 581 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 582 } 646 // TODO! 647 // while (possibleFollowStates[0] != tree.state && nFs == 1 && 648 // !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 649 // automaton.Goto(possibleFollowStates[0]); 650 // automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 651 // } 583 652 Debug.Assert(possibleFollowStates.Contains(tree.state)); 584 653 automaton.Goto(tree.state); … … 589 658 string actionString = ""; 590 659 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 591 while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 592 actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]); 593 // no alternatives -> just go to the next state 594 automaton.Goto(possibleFollowStates[0]); 595 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 596 } 660 // TODO 661 // while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 662 // actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]); 663 // // no alternatives -> just go to the next state 664 // automaton.Goto(possibleFollowStates[0]); 665 // automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 666 // } 597 667 if (nFs == 0) { 598 668 // stuck in a dead end (no final state and no allowed follow states) … … 672 742 q = TransformQuality(q); 673 743 success = true; 744 BackpropagateQuality(tree, q, treePolicy, state); 674 745 } else { 675 746 // we got stuck in roll-out (not evaluation necessary!) … … 682 753 // Update statistics 683 754 // Set branch to done if all children are done. 684 BackpropagateQuality(tree, q, treePolicy, state); 755 BackpropagateDone(tree, state); 756 BackpropagateDebugStats(tree, q, state); 757 685 758 686 759 return success; … … 703 776 private static double TransformQuality(double q) { 704 777 // no transformation 705 //return q;778 return q; 706 779 707 780 // EXPERIMENTAL! … … 738 811 739 812 private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) { 740 if (q > 0) policy.Update(tree.actionStatistics, q); 813 policy.Update(tree.actionStatistics, q); 814 815 if (state.parents.ContainsKey(tree)) { 816 foreach (var parent in state.parents[tree]) { 817 BackpropagateQuality(parent, q, policy, state); 818 } 819 } 820 } 821 822 private static void BackpropagateDone(Tree tree, State state) { 741 823 if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) { 742 824 tree.Done = true; … … 746 828 if (state.parents.ContainsKey(tree)) { 747 829 foreach (var parent in state.parents[tree]) { 748 BackpropagateQuality(parent, q, policy, state); 749 } 830 BackpropagateDone(parent, state); 831 } 832 } 833 } 834 835 private static void BackpropagateDebugStats(Tree tree, double q, State state) { 836 if (state.parents.ContainsKey(tree)) { 837 foreach (var parent in state.parents[tree]) { 838 BackpropagateDebugStats(parent, q, state); 839 } 840 } 841 842 state.pathSelectedQ[tree.level] = tree.actionStatistics.AverageQuality; 843 state.pathBestQ[tree.level] = tree.actionStatistics.BestQuality; 844 state.curAction[tree.level] = tree.expr; 845 if (state.levelBestQ[tree.level] < tree.actionStatistics.BestQuality) { 846 state.levelBestQ[tree.level] = tree.actionStatistics.BestQuality; 847 state.levelBestAction[tree.level] = tree.expr; 750 848 } 751 849 } … … 917 1015 private static string WriteStatistics(Tree tree, State state) { 918 1016 var sb = new System.IO.StringWriter(); 919 sb.Write Line("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);1017 sb.Write("{0}\t{1:N5}\t", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality); 920 1018 if (state.children.ContainsKey(tree)) { 921 1019 foreach (var ch in state.children[tree]) { 922 sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality); 923 } 924 } 1020 sb.Write("{0}\t{1:N5}\t", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality); 1021 } 1022 } 1023 sb.WriteLine(); 925 1024 return sb.ToString(); 926 1025 } -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/EpsGreedy.cs
r15410 r15425 18 18 public double SumQuality { get; set; } 19 19 public double AverageQuality { get { return SumQuality / Tries; } } 20 public double BestQuality { get; internal set; } 20 21 public int Tries { get; set; } 21 22 public bool Done { get; set; } … … 26 27 this.Tries += o.Tries; 27 28 this.SumQuality += o.SumQuality; 29 this.BestQuality = Math.Max(this.BestQuality, other.BestQuality); 28 30 } 29 31 } … … 60 62 var a = action as ActionStatistics; 61 63 a.SumQuality += q; 64 a.BestQuality = Math.Max(a.BestQuality, q); 62 65 a.Tries++; 63 66 } -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/IActionStatistics.cs
r15410 r15425 8 8 public interface IActionStatistics { 9 9 double AverageQuality { get; } 10 double BestQuality { get; } 10 11 int Tries { get; } 11 12 bool Done { get; set; } -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs
r15416 r15425 19 19 public double SumQuality { get; set; } 20 20 public double AverageQuality { get { return SumQuality / Tries; } } 21 public double BestQuality { get; internal set; } 21 22 public int Tries { get; set; } 22 23 public bool Done { get; set; } … … 26 27 this.Tries += o.Tries; 27 28 this.SumQuality += o.SumQuality; 29 this.BestQuality = Math.Max(this.BestQuality, other.BestQuality); 28 30 } 29 31 } … … 60 62 var a = action as ActionStatistics; 61 63 a.SumQuality += q; 64 a.BestQuality = Math.Max(a.BestQuality, q); 62 65 a.Tries++; 63 66 } … … 82 85 return buf[rand.Next(buf.Count)]; 83 86 } 87 88 Debug.Assert(actions.All(a => a.Done || a.Tries > 0)); 84 89 85 90 Debug.Assert(totalTries > 0); -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/UcbTuned.cs
r15410 r15425 19 19 public double SumSqrQuality { get; set; } 20 20 public double AverageQuality { get { return SumQuality / Tries; } } 21 public double BestQuality { get; internal set; } 21 22 public double QualityVariance { get { return SumSqrQuality / Tries - AverageQuality * AverageQuality; } } 22 23 public int Tries { get; set; } … … 29 30 this.SumQuality += o.SumQuality; 30 31 this.SumSqrQuality += o.SumSqrQuality; 32 this.BestQuality = Math.Max(this.BestQuality, other.BestQuality); 31 33 } 32 34 } … … 64 66 a.SumQuality += q; 65 67 a.SumSqrQuality += q * q; 68 a.BestQuality = Math.Max(a.BestQuality, q); 66 69 a.Tries++; 67 70 } -
branches/MCTS-SymbReg-2796/Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/MctsSymbolicRegressionTest.cs
r15420 r15425 4 4 using System.Threading; 5 5 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies; 6 using HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg; 6 7 using HeuristicLab.Data; 7 8 using HeuristicLab.Optimization; … … 14 15 [TestClass()] 15 16 public class MctsSymbolicRegressionTest { 17 #region heuristics 18 [TestMethod] 19 [TestCategory("Algorithms.DataAnalysis")] 20 [TestProperty("Time", "short")] 21 public void TestHeuristics() { 22 { 23 // a, b ~ U(0, 1) should be trivial 24 var nRand = new MersenneTwister(1234); 25 26 int n = 10000; // large sample so that we can use the thresholds below 27 var a = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 28 var b = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 29 var x = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 30 var y = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 31 32 var z = a.Zip(b, (ai, bi) => ai * bi).ToArray(); 33 34 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, z) > 0.05); // should be detected as relevant 35 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, x, z) > 0.05); // a and b > 0 so these should be detected as well 36 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, y, z) > 0.05); 37 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, x, z) > 0.05); 38 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, y, z) > 0.05); 39 Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z) < 0.05); 40 } 41 { 42 // a, b ~ U(1000, 2000) also trivial 43 var nRand = new UniformDistributedRandom(new MersenneTwister(1234), 1000, 2000); 44 45 int n = 10000; // large sample so that we can use the thresholds below 46 var a = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 47 var b = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 48 var x = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 49 var y = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 50 51 var z = a.Zip(b, (ai, bi) => ai * bi).ToArray(); 52 53 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, z) > 0.05); // should be detected as relevant 54 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, x, z) > 0.05); 55 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, y, z) > 0.05); 56 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, x, z) > 0.05); 57 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, y, z) > 0.05); 58 Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z) < 0.05); 59 } 60 { 61 // a, b ~ U(-1, 1) 62 var nRand = new UniformDistributedRandom(new MersenneTwister(1234), -1, 1); 63 64 int n = 10000; // large sample so that we can use the thresholds below 65 var a = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 66 var b = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 67 var x = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 68 var y = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 69 70 var z = a.Zip(b, (ai, bi) => ai * bi).ToArray(); 71 72 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, z) > 0.05); // should be detected as relevant 73 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, x, z) < 0.05); 74 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, y, z) < 0.05); 75 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, x, z) < 0.05); 76 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, y, z) < 0.05); 77 Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z) < 0.05); 78 } 79 { 80 // a, b ~ N(0, 1) 81 var nRand = new NormalDistributedRandom(new MersenneTwister(1234), 0, 1); 82 83 int n = 10000; // large sample so that we can use the thresholds below 84 var a = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 85 var b = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 86 var x = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 87 var y = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 88 89 var z = a.Zip(b, (ai, bi) => ai * bi).ToArray(); 90 91 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, z) > 0.05); // should be detected as relevant 92 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, x, z) < 0.05); 93 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, y, z) < 0.05); 94 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, x, z) < 0.05); 95 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, y, z) < 0.05); 96 Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z) < 0.05); 97 } 98 { 99 // a ~ N(100, 1), b ~ N(-100, 1) 100 var nRand = new NormalDistributedRandom(new MersenneTwister(1234), 0, 1); 101 var aRand = new NormalDistributedRandom(new MersenneTwister(1234), 100, 1); 102 var bRand = new NormalDistributedRandom(new MersenneTwister(1234), -100, 1); 103 104 int n = 10000; // large sample so that we can use the thresholds below 105 var a = Enumerable.Range(0, n).Select(_ => aRand.NextDouble()).ToArray(); 106 var b = Enumerable.Range(0, n).Select(_ => bRand.NextDouble()).ToArray(); 107 var x = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 108 var y = Enumerable.Range(0, n).Select(_ => nRand.NextDouble()).ToArray(); 109 110 var z = a.Zip(b, (ai, bi) => ai * bi).ToArray(); 111 112 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, z) > 0.05); // should be detected as relevant 113 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, x, z) > 0.05); // a > 0 114 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, y, z) > 0.05); 115 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, x, z) > 0.05); // b < 0 116 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, y, z) > 0.05); 117 Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z) < 0.05); 118 } 119 } 120 #endregion 121 122 16 123 #region expression hashing 17 124 [TestMethod] … … 790 897 } 791 898 792 var ds = new Dataset(new string[] { "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "y" }, 899 var ds = new Dataset(new string[] { "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "y" }, 793 900 new[] { x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, ys }); 794 901 795 902 796 var problemData = new RegressionProblemData(ds, new string[] { "a", "b", "c", "d", "e", "f", "g", "h", "i", "j" }, "y");903 var problemData = new RegressionProblemData(ds, new string[] { "a", "b", "c", "d", "e", "f", "g", "h", "i", "j" }, "y"); 797 904 798 905 problemData.TrainingPartition.Start = 0; … … 1118 1225 1119 1226 // UCB tuned 1120 //var ucbTuned = new UcbTuned();1121 // ucbTuned.C = 1.5;1122 //mctsSymbReg.Policy = ucbTuned;1227 var ucbTuned = new UcbTuned(); 1228 ucbTuned.C = 1; 1229 mctsSymbReg.Policy = ucbTuned; 1123 1230 1124 1231
Note: See TracChangeset
for help on using the changeset viewer.