Changeset 15420 for branches/MCTS-SymbReg-2796
- Timestamp:
- 10/17/17 09:56:30 (7 years ago)
- Location:
- branches/MCTS-SymbReg-2796
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs
r15416 r15420 385 385 386 386 public void Goto(int targetState) { 387 Debug.Assert(followStates[CurrentState].Contains(targetState)); 387 388 if (actions[CurrentState, targetState] != null) 388 389 actions[CurrentState, targetState].ForEach(a => a()); // execute all actions -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r15416 r15420 71 71 // TODO: is it OK to initialize all constants to 1 (Obj 2)? 72 72 // TODO: improve memory usage 73 // TODO: support empty test partition 73 74 #region static API 74 75 … … 138 139 private readonly double[] predBuf, testPredBuf; 139 140 private readonly double[][] gradBuf; 141 142 // debugging stats 143 // calculate for each level the number of alternatives the average 'inequality' of tries and 'inequality' of quality over the alternatives for each trie 144 // inequality can be calculated using the Gini coefficient 145 internal readonly double[] giniCoeffs = new double[100]; 146 140 147 141 148 public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, … … 436 443 } 437 444 } 445 438 446 #endregion 447 448 #if DEBUG 449 internal void ClearStats() { 450 for (int i = 0; i < giniCoeffs.Length; i++) giniCoeffs[i] = -1; 451 } 452 internal void WriteStats() { 453 Console.WriteLine(string.Join("\t", giniCoeffs.TakeWhile(x => x >= 0).Select(x => string.Format("{0:N3}", x)))); 454 } 455 456 #endif 457 439 458 } 440 459 … … 496 515 bool success = false; 497 516 do { 517 #if DEBUG 518 mctsState.ClearStats(); 519 #endif 498 520 automaton.Reset(); 499 521 success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q); … … 502 524 mctsState.effectiveRollouts++; 503 525 526 #if DEBUG 527 mctsState.WriteStats(); 528 #endif 504 529 //if (mctsState.effectiveRollouts % 100 == 1) { 505 506 530 // Console.WriteLine(WriteTree(tree, mctsState)); 531 // Console.WriteLine(TraceTree(tree, mctsState)); 507 532 //} 508 533 return q; 509 534 } 510 511 535 512 536 // search forward … … 540 564 selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand); 541 565 } 566 567 // STATS 568 state.giniCoeffs[tree.level] = InequalityCoefficient(state.children[tree].Select(ch => (double)ch.actionStatistics.AverageQuality)); 569 542 570 tree = state.children[tree][selectedIdx]; 543 571 … … 548 576 int nFs; 549 577 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 550 while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 578 while (automaton.CurrentState != tree.state && nFs == 1 && 579 !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 551 580 automaton.Goto(possibleFollowStates[0]); 552 581 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); … … 658 687 } 659 688 689 private static double InequalityCoefficient(IEnumerable<double> xs) { 690 var arr = xs.ToArray(); 691 var sad = 0.0; 692 var sum = 0.0; 693 694 for(int i=0;i<arr.Length;i++) { 695 for(int j=0;j<arr.Length;j++) { 696 sad += Math.Abs(arr[i] - arr[j]); 697 sum += arr[j]; 698 } 699 } 700 return 0.5 * sad / sum; 701 } 660 702 661 703 private static double TransformQuality(double q) { … … 871 913 return Disassembler.CodeToString(code); 872 914 } 915 873 916 874 917 private static string WriteStatistics(Tree tree, State state) { -
branches/MCTS-SymbReg-2796/Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/MctsSymbolicRegressionTest.cs
r15416 r15420 1 1 using System; 2 using System.Collections.Generic; 2 3 using System.Linq; 3 4 using System.Threading; … … 758 759 [TestMethod] 759 760 [TestCategory("Algorithms.DataAnalysis")] 760 [TestProperty("Time", " short")]761 public void MctsSymbReg_NoConstants_Poly10 () {761 [TestProperty("Time", "long")] 762 public void MctsSymbReg_NoConstants_Poly10_250rows() { 762 763 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234); 763 764 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10"))); 765 regProblem.TrainingPartition.Start = 0; 766 regProblem.TrainingPartition.End = regProblem.Dataset.Rows; 767 regProblem.TestPartition.Start = 0; 768 regProblem.TestPartition.End = 2; 764 769 TestMctsWithoutConstants(regProblem, nVarRefs: 15, iterations: 200000, allowExp: false, allowLog: false, allowInv: false); 770 } 771 [TestMethod] 772 [TestCategory("Algorithms.DataAnalysis")] 773 [TestProperty("Time", "long")] 774 public void MctsSymbReg_NoConstants_Poly10_10000rows() { 775 // as poly-10 but more rows 776 var rand = new FastRandom(1234); 777 var x1 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 778 var x2 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 779 var x3 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 780 var x4 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 781 var x5 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 782 var x6 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 783 var x7 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 784 var x8 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 785 var x9 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 786 var x10 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList(); 787 var ys = new List<double>(); 788 for (int i = 0; i < x1.Count; i++) { 789 ys.Add(x1[i] * x2[i] + x3[i] * x4[i] + x5[i] * x6[i] + x1[i] * x7[i] * x9[i] + x3[i] * x6[i] * x10[i]); 790 } 791 792 var ds = new Dataset(new string[] { "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "y" }, 793 new[] { x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, ys }); 794 795 796 var problemData = new RegressionProblemData(ds, new string[] { "a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}, "y"); 797 798 problemData.TrainingPartition.Start = 0; 799 problemData.TrainingPartition.End = problemData.Dataset.Rows; 800 problemData.TestPartition.Start = 0; 801 problemData.TestPartition.End = 2; // must not be empty 802 803 804 TestMctsWithoutConstants(problemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false); 765 805 } 766 806 … … 797 837 var @as = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList(); 798 838 var bs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList(); 799 var cs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() * 1.0e-3).ToList();800 var ds = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() 801 var es = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() 839 var cs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() * 1.0e-3).ToList(); 840 var ds = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList(); 841 var es = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList(); 802 842 var ys = new double[@as.Count]; 803 for (int i=0;i<ys.Length;i++)804 ys[i] = @as[i] + bs[i] + @as[i] *bs[i]*cs[i];843 for (int i = 0; i < ys.Length; i++) 844 ys[i] = @as[i] + bs[i] + @as[i] * bs[i] * cs[i]; 805 845 806 846 var dataset = new Dataset(new string[] { "a", "b", "c", "d", "e", "y" }, new[] { @as, bs, cs, ds, es, ys.ToList() }); 807 847 808 var problemData = new RegressionProblemData(dataset, new string[] { "a", "b", "c","d","e" }, "y");848 var problemData = new RegressionProblemData(dataset, new string[] { "a", "b", "c", "d", "e" }, "y"); 809 849 810 850 … … 854 894 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenInstanceProvider(seed: 1234); 855 895 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("F7 "))); 856 TestMcts(regProblem, maxVariableReferences: 5, allowExp: false, allowLog: true, allowInv: false);896 TestMcts(regProblem, maxVariableReferences: 5, allowExp: false, allowLog: true, allowInv: false); 857 897 } 858 898 [TestMethod] … … 895 935 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenInstanceProvider(seed: 1234); 896 936 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("F11 "))); 897 TestMcts(regProblem, maxVariableReferences: 5, allowExp: true, allowLog: true, allowInv: false); 937 TestMcts(regProblem, maxVariableReferences: 5, allowExp: true, allowLog: true, allowInv: false); 898 938 } 899 939 [TestMethod]
Note: See TracChangeset
for help on using the changeset viewer.