Changeset 15416 for branches/MCTS-SymbReg-2796
- Timestamp: 10/11/17 08:15:19 (7 years ago)
- Location: branches/MCTS-SymbReg-2796
- Files: 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg.csproj
r15414 r15416
  20   20   <DebugType>full</DebugType>
  21   21   <Optimize>false</Optimize>
  22      - <OutputPath> bin\</OutputPath>
       22 + <OutputPath>..\..\..\..\trunk\sources\bin\</OutputPath>
  23   23   <DefineConstants>DEBUG;TRACE</DefineConstants>
  24   24   <ErrorReport>prompt</ErrorReport>
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs
r15414 r15416
 139  139   codeGenerator.Emit1(OpCodes.LoadConst0);
 140  140   constraintHandler.Reset();
 141      - }, "0 , Reset");
      141 + }, "0");
 142  142   AddTransition(StateTermEnd, StateExprEnd, () => {
 143  143   codeGenerator.Emit1(OpCodes.Add);
   …    …
 155  155   constraintHandler.StartTerm();
 156  156   },
 157      - "c , StartTerm");
      157 + "c");
 158  158   AddTransition(StateFactorEnd, StateTermEnd,
 159  159   () => {
   …    …
 161  161   constraintHandler.EndTerm();
 162  162   },
 163      - "* , EndTerm");
      163 + "*");
 164  164
 165  165   AddTransition(StateFactorEnd, StateFactorStart,
   …    …
 172  172   AddTransition(StateFactorStart, StateVariableFactorStart, () => {
 173  173   constraintHandler.StartFactor(StateVariableFactorStart);
 174      - }, " StartFactor");
      174 + }, "");
 175  175   if (allowExp)
 176  176   AddTransition(StateFactorStart, StateExpFactorStart, () => {
 177  177   constraintHandler.StartFactor(StateExpFactorStart);
 178      - }, " StartFactor");
      178 + }, "");
 179  179   if (allowLog)
 180  180   AddTransition(StateFactorStart, StateLogFactorStart, () => {
 181  181   constraintHandler.StartFactor(StateLogFactorStart);
 182      - }, " StartFactor");
      182 + }, "");
 183  183   if (allowInv)
 184  184   AddTransition(StateFactorStart, StateInvFactorStart, () => {
 185  185   constraintHandler.StartFactor(StateInvFactorStart);
 186      - }, " StartFactor");
 187      - AddTransition(StateVariableFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, " EndFactor");
 188      - AddTransition(StateExpFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, " EndFactor");
 189      - AddTransition(StateLogFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, " EndFactor");
 190      - AddTransition(StateInvFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, " EndFactor");
      186 + }, "");
      187 + AddTransition(StateVariableFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
      188 + AddTransition(StateExpFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
      189 + AddTransition(StateLogFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
      190 + AddTransition(StateInvFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
 191  191
 192  192   // VarFact -> var_1 ... var_n
   …    …
 202  202   constraintHandler.AddVarToCurrentFactor(varState);
 203  203   },
 204      - "var_" + varIdx + " , AddVar");
      204 + "var_" + varIdx + "");
 205  205   AddTransition(curDynVarState, StateVariableFactorEnd);
 206  206   curDynVarState++;
   …    …
 234  234   constraintHandler.AddVarToCurrentFactor(varState);
 235  235   },
 236      - "var_" + varIdx + " , AddVar");
      236 + "var_" + varIdx + "");
 237  237   AddTransition(curDynVarState, StateExpFEnd,
 238  238   () => {
   …    …
 249  249   constraintHandler.StartNewTermInPoly();
 250  250   },
 251      - "0 , StartTermInPoly");
      251 + "0");
 252  252   AddTransition(StateLogTEnd, StateLogFactorEnd,
 253  253   () => {
   …    …
 289  289   constraintHandler.AddVarToCurrentFactor(varState);
 290  290   },
 291      - "var_" + varIdx + " , AddVar");
      291 + "var_" + varIdx + "");
 292  292   AddTransition(curDynVarState, StateLogTFEnd);
 293  293   curDynVarState++;
   …    …
 300  300   constraintHandler.StartNewTermInPoly();
 301  301   },
 302      - "c , StartTermInPoly");
      302 + "c");
 303  303   AddTransition(StateInvTEnd, StateInvFactorEnd,
 304  304   () => {
   …    …
 338  338   constraintHandler.AddVarToCurrentFactor(varState);
 339  339   },
 340      - "var_" + varIdx + " , AddVar");
      340 + "var_" + varIdx + "");
 341  341   AddTransition(curDynVarState, StateInvTFEnd);
 342  342   curDynVarState++;
   …    …
 442  442   codeGenerator.Reset();
 443  443   constraintHandler.Reset();
      444 + }
      445 +
      446 + internal string GetActionString(int fromState, int toState) {
      447 + return actionStrings[fromState,toState] != null ? string.Join(" , ", actionStrings[fromState, toState]) : "";
 444  448   }
 445  449
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs
r15403 r15416
 304  304   if (bestQ > bestQuality.Value) {
 305  305   bestSolutionIteration.Value = i;
      306 + if (state.BestSolutionTrainingQuality > 0.99999) break;
 306  307   }
 307  308   bestQuality.Value = bestQ;
   …    …
 329  330
 330  331   // final results (assumes that at least one iteration was calculated)
 331      - if (n > 0) {
      332 + if (n > 0) {
 332  333   if (bestQ > bestQuality.Value) {
 333  334   bestSolutionIteration.Value = iterations.Value + n;
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r15414 r15416 61 61 // TODO: Solve Poly-10 62 62 // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved? 63 // TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)? 64 // TODO: check if we can use a quality measure with range [-1..1] in policies 63 65 // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants 64 66 // TODO: check if transformation of y is correct and works (Obj 2) … … 117 119 private readonly ExpressionEvaluator evaluator, testEvaluator; 118 120 121 internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>(); 122 internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>(); 123 internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>(); 124 119 125 // values for best solution 120 private double bestR Sq;126 private double bestR; 121 127 private byte[] bestCode; 122 128 private int bestNParams; … … 187 193 188 194 // reset best solution 189 this.bestR Sq= 0;195 this.bestR = 0; 190 196 // code for default solution (constant model) 191 197 this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit }; … … 208 214 get { 209 215 evaluator.Exec(bestCode, x, bestConsts, predBuf); 210 return R Sq(y, predBuf);216 return Rho(y, predBuf); 211 217 } 212 218 } … … 215 221 get { 216 222 testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf); 217 return R Sq(testY, testPredBuf);223 return Rho(testY, testPredBuf); 218 224 } 219 225 } … … 248 254 249 255 // single objective best 250 if (q > bestR Sq) {251 bestR Sq= q;256 if (q > bestR) { 257 bestR = q; 252 258 bestNParams = nParams; 253 259 this.bestCode = new byte[code.Length]; … … 266 272 } 267 273 268 private void Eval(byte[] code, int nParams, out double r sq, out double[] optConsts) {274 private void Eval(byte[] code, int nParams, out double rho, out double[] 
optConsts) { 269 275 // we make a first pass to determine a valid starting configuration for all constants 270 276 // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator) … … 280 286 // if we don't need to optimize parameters then we are done 281 287 // changing scale and offset does not influence r² 282 r sq = RSq(y, predBuf);288 rho = Rho(y, predBuf); 283 289 optConsts = constsBuf; 284 290 } else { … … 289 295 funcEvaluations++; 290 296 291 r sq = RSq(y, predBuf);297 rho = Rho(y, predBuf); 292 298 optConsts = constsBuf; 293 299 } … … 297 303 298 304 #region helpers 299 private static double R Sq(IEnumerable<double> x, IEnumerable<double> y) {305 private static double Rho(IEnumerable<double> x, IEnumerable<double> y) { 300 306 OnlineCalculatorError error; 301 307 double r = OnlinePearsonsRCalculator.Calculate(x, y, out error); 302 return error == OnlineCalculatorError.None ? r * r: 0.0;308 return error == OnlineCalculatorError.None ? r : 0.0; 303 309 } 304 310 … … 491 497 do { 492 498 automaton.Reset(); 493 success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, out q);499 success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q); 494 500 mctsState.totalRollouts++; 495 501 } while (!success && !tree.Done); 496 502 mctsState.effectiveRollouts++; 497 503 498 if (mctsState.effectiveRollouts % 10 == 1) {499 // Console.WriteLine(WriteTree(tree));500 // Console.WriteLine(TraceTree(tree));501 }504 //if (mctsState.effectiveRollouts % 100 == 1) { 505 // Console.WriteLine(WriteTree(tree, mctsState)); 506 // Console.WriteLine(TraceTree(tree, mctsState)); 507 //} 502 508 return q; 503 509 } 504 505 private static Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();506 private static Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();507 private static Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();508 509 510 510 511 511 512 // 
search forward 512 513 private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy, 514 State state, 513 515 out double q) { 514 516 // ROLLOUT AND EXPANSION … … 527 529 528 530 while (!automaton.IsFinalState(automaton.CurrentState)) { 529 if ( children.ContainsKey(tree)) {530 if ( children[tree].All(ch => ch.Done)) {531 if (state.children.ContainsKey(tree)) { 532 if (state.children[tree].All(ch => ch.Done)) { 531 533 tree.Done = true; 532 534 break; … … 535 537 // UCT selection within tree 536 538 int selectedIdx = 0; 537 if ( children[tree].Count > 1) {538 selectedIdx = treePolicy.Select( children[tree].Select(ch => ch.actionStatistics), rand);539 } 540 tree = children[tree][selectedIdx];539 if (state.children[tree].Count > 1) { 540 selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand); 541 } 542 tree = state.children[tree][selectedIdx]; 541 543 542 544 // move the automaton forward until reaching the state … … 556 558 int[] possibleFollowStates; 557 559 int nFs; 560 string actionString = ""; 558 561 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 559 562 while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) { 563 actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]); 560 564 // no alternatives -> just go to the next state 561 565 automaton.Goto(possibleFollowStates[0]); … … 568 572 } 569 573 var newChildren = new List<Tree>(nFs); 570 children.Add(tree, newChildren);574 state.children.Add(tree, newChildren); 571 575 for (int i = 0; i < nFs; i++) { 572 576 Tree child = null; … … 574 578 if (automaton.IsEvalState(possibleFollowStates[i])) { 575 579 var hc = Hashcode(automaton); 576 if (! 
nodes.TryGetValue(hc, out child)) {580 if (!state.nodes.TryGetValue(hc, out child)) { 577 581 child = new Tree() { 578 582 children = null, 579 583 state = possibleFollowStates[i], 580 584 actionStatistics = treePolicy.CreateActionStatistics(), 581 expr = string.Empty, // ExprStr(automaton),585 expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]), 582 586 level = tree.level + 1 583 587 }; 584 nodes.Add(hc, child);588 state.nodes.Add(hc, child); 585 589 } 586 590 // only allow forward edges (don't add the child if we would go back in the graph) 587 else if (child.level > tree.level) {591 else if (child.level > tree.level) { 588 592 // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link 589 593 // to all parents 590 BackpropagateStatistics(child.actionStatistics, tree );594 BackpropagateStatistics(child.actionStatistics, tree, state); 591 595 } else { 592 596 // prevent cycles 593 597 Debug.Assert(child.level <= tree.level); 594 598 child = null; 595 } 599 } 596 600 } else { 597 601 child = new Tree() { … … 599 603 state = possibleFollowStates[i], 600 604 actionStatistics = treePolicy.CreateActionStatistics(), 601 expr = string.Empty, // ExprStr(automaton),605 expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]), 602 606 level = tree.level + 1 603 607 }; … … 614 618 615 619 foreach (var ch in newChildren) { 616 if (! 
parents.ContainsKey(ch)) {617 parents.Add(ch, new List<Tree>());620 if (!state.parents.ContainsKey(ch)) { 621 state.parents.Add(ch, new List<Tree>()); 618 622 } 619 parents[ch].Add(tree);620 } 621 622 623 // follow one of the children 624 tree = Select FinalOrRandom2(automaton, tree, rand);623 state.parents[ch].Add(tree); 624 } 625 626 627 // follow one of the children 628 tree = SelectStateLeadingToFinal(automaton, tree, rand, state); 625 629 automaton.Goto(tree.state); 626 630 } … … 636 640 automaton.GetCode(out code, out nParams); 637 641 q = eval(code, nParams); 642 // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr); 638 643 q = TransformQuality(q); 639 644 success = true; 640 645 } else { 641 646 // we got stuck in roll-out (not evaluation necessary!) 647 // Console.WriteLine("\t" + ExprStr(automaton) + " STOP"); 642 648 q = 0.0; 643 649 success = false; … … 647 653 // Update statistics 648 654 // Set branch to done if all children are done. 649 BackpropagateQuality(tree, q, treePolicy );655 BackpropagateQuality(tree, q, treePolicy, state); 650 656 651 657 return success; … … 655 661 private static double TransformQuality(double q) { 656 662 // no transformation 657 return q;663 // return q; 658 664 659 665 // EXPERIMENTAL! 
666 667 // Fisher transformation 668 // (assumes q is Correl(pred, target) 669 670 q = Math.Min(q, 0.99999999); 671 q = Math.Max(q, -0.99999999); 672 return 0.5 * Math.Log((1 + q) / (1 - q)); 673 660 674 // optimal result: q = 1 -> return huge value 661 675 // if (q >= 1.0) return 1E16; … … 665 679 666 680 // backpropagate existing statistics to all parents 667 private static void BackpropagateStatistics(IActionStatistics stats, Tree tree ) {681 private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) { 668 682 tree.actionStatistics.Add(stats); 669 if ( parents.ContainsKey(tree)) {670 foreach (var parent in parents[tree]) {671 BackpropagateStatistics(stats, parent );683 if (state.parents.ContainsKey(tree)) { 684 foreach (var parent in state.parents[tree]) { 685 BackpropagateStatistics(stats, parent, state); 672 686 } 673 687 } … … 681 695 } 682 696 683 private static void BackpropagateQuality(Tree tree, double q, IPolicy policy ) {697 private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) { 684 698 if (q > 0) policy.Update(tree.actionStatistics, q); 685 if ( children.ContainsKey(tree) &&children[tree].All(ch => ch.Done)) {699 if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) { 686 700 tree.Done = true; 687 701 // children[tree] = null; keep all nodes 688 702 } 689 703 690 if (parents.ContainsKey(tree)) { 691 foreach (var parent in parents[tree]) { 692 BackpropagateQuality(parent, q, policy); 693 } 694 } 695 } 696 697 private static Tree SelectFinalOrRandom2(Automaton automaton, Tree tree, IRandom rand) { 698 // if one of the new children leads to a final state then go there 699 // otherwise choose a random child 700 int selectedChildIdx = -1; 701 // find first final state if there is one 702 var children = MctsSymbolicRegressionStatic.children[tree]; 703 for (int i = 0; i < children.Count; i++) { 704 if (automaton.IsFinalState(children[i].state)) { 704 if 
(state.parents.ContainsKey(tree)) { 705 foreach (var parent in state.parents[tree]) { 706 BackpropagateQuality(parent, q, policy, state); 707 } 708 } 709 } 710 711 private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) { 712 // find the child with the smallest state value (smaller values are closer to the final state) 713 int selectedChildIdx = 0; 714 var children = state.children[tree]; 715 Tree minChild = children.First(); 716 for (int i = 1; i < children.Count; i++) { 717 if(children[i].state < minChild.state) 705 718 selectedChildIdx = i; 706 break;707 }708 }709 // no final state -> select the first child710 if (selectedChildIdx == -1) {711 selectedChildIdx = 0;712 719 } 713 720 return children[selectedChildIdx]; … … 865 872 } 866 873 867 private static string WriteStatistics(Tree tree ) {874 private static string WriteStatistics(Tree tree, State state) { 868 875 var sb = new System.IO.StringWriter(); 869 876 sb.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality); 870 if ( children.ContainsKey(tree)) {871 foreach (var ch in children[tree]) {877 if (state.children.ContainsKey(tree)) { 878 foreach (var ch in state.children[tree]) { 872 879 sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality); 873 880 } … … 876 883 } 877 884 878 private static string TraceTree(Tree tree ) {885 private static string TraceTree(Tree tree, State state) { 879 886 var sb = new StringBuilder(); 880 887 sb.Append( … … 885 892 int nodeId = 0; 886 893 887 TraceTreeRec(tree, 0, sb, ref nodeId );894 TraceTreeRec(tree, 0, sb, ref nodeId, state); 888 895 sb.Append("}"); 889 896 return sb.ToString(); 890 897 } 891 898 892 private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId ) {899 private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) { 893 900 var avgNodeQ = 
tree.actionStatistics.AverageQuality; 894 901 var tries = tree.actionStatistics.Tries; 895 902 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 896 903 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 897 898 sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine(); 904 hue = 0.0; 905 906 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine(); 899 907 900 908 var list = new List<Tuple<int, int, Tree>>(); 901 if ( children.ContainsKey(tree)) {902 foreach (var ch in children[tree]) {909 if (state.children.ContainsKey(tree)) { 910 foreach (var ch in state.children[tree]) { 903 911 nextId++; 904 912 avgNodeQ = ch.actionStatistics.AverageQuality; … … 906 914 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 907 915 hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 908 sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 909 sb.AppendFormat("{0} -> {1}", parentId, nextId, avgNodeQ).AppendLine(); 916 hue = 0.0; 917 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 918 sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine(); 910 919 list.Add(Tuple.Create(tries, nextId, ch)); 911 920 } 921 922 foreach(var tup in list) { 923 var ch = tup.Item3; 924 var chId = tup.Item2; 925 if(state.children.ContainsKey(ch) && state.children[ch].Count == 1) { 926 var chch = state.children[ch].First(); 927 nextId++; 928 avgNodeQ = chch.actionStatistics.AverageQuality; 929 tries = chch.actionStatistics.Tries; 930 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 931 hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 932 hue = 0.0; 933 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, 
hue).AppendLine(); 934 sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine(); 935 } 936 } 937 912 938 foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) { 913 TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId );914 } 915 } 916 } 917 918 private static string WriteTree(Tree tree ) {939 TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state); 940 } 941 } 942 } 943 944 private static string WriteTree(Tree tree, State state) { 919 945 var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture); 920 946 var nodeIds = new Dictionary<Tree, int>(); … … 924 950 node [style=filled]; 925 951 "); 926 int threshold = nodes.Count > 500 ? 10 :0;927 foreach (var kvp in children) {952 int threshold = /* state.nodes.Count > 500 ? 10 : */ 0; 953 foreach (var kvp in state.children) { 928 954 var parent = kvp.Key; 929 955 int parentId; … … 934 960 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 935 961 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 962 hue = 0.0; 936 963 if (parent.actionStatistics.Tries > threshold) 937 sb.Write("{0} [label=\"{1: N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);964 sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue); 938 965 nodeIds.Add(parent, parentId); 939 966 } … … 949 976 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 950 977 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 978 hue = 0.0; 951 979 if (tries > threshold) { 952 sb.Write("{0} [label=\"{1: N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);980 sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue); 953 981 var edgeLabel = child.expr; 954 982 // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, ""); -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs
r15414 r15416
  78   78   else totalTries += a.Tries;
  79   79   }
  80      - // if there are unvisited actions select a random action
       80 + // if there are unvisited actions select a random unvisited action
  81   81   if (buf.Any()) {
  82   82   return buf[rand.Next(buf.Count)];
  83   83   }
  84   84
       85 + Debug.Assert(totalTries > 0);
  85   86   double logTotalTries = Math.Log(totalTries);
-
branches/MCTS-SymbReg-2796/Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/MctsSymbolicRegressionTest.cs
r15414 r15416 7 7 using HeuristicLab.Problems.DataAnalysis; 8 8 using HeuristicLab.Problems.Instances.DataAnalysis; 9 using HeuristicLab.Random; 9 10 using Microsoft.VisualStudio.TestTools.UnitTesting; 10 11 … … 550 551 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenInstanceProvider(seed: 1234); 551 552 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("F7 "))); 552 TestMctsWithoutConstants(regProblem, nVarRefs: 10, iterations: 1000000, allowExp: false, allowLog: true, allowInv: false); 553 } 553 TestMctsWithoutConstants(regProblem, nVarRefs: 10, iterations: 100000, allowExp: false, allowLog: true, allowInv: false); 554 } 555 556 [TestMethod] 557 [TestCategory("Algorithms.DataAnalysis")] 558 [TestProperty("Time", "short")] 559 public void MctsSymbReg_NoConstants_Poly10_Part1() { 560 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234); 561 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10"))); 562 563 // Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10 564 // Y' = X1*X2 + X3*X4 + X5*X6 565 // simplify problem by changing target 566 var ds = ((Dataset)regProblem.Dataset).ToModifiable(); 567 var ys = ds.GetDoubleValues("Y").ToArray(); 568 var x1 = ds.GetDoubleValues("X1").ToArray(); 569 var x2 = ds.GetDoubleValues("X2").ToArray(); 570 var x3 = ds.GetDoubleValues("X3").ToArray(); 571 var x4 = ds.GetDoubleValues("X4").ToArray(); 572 var x5 = ds.GetDoubleValues("X5").ToArray(); 573 var x6 = ds.GetDoubleValues("X6").ToArray(); 574 var x7 = ds.GetDoubleValues("X7").ToArray(); 575 var x8 = ds.GetDoubleValues("X8").ToArray(); 576 var x9 = ds.GetDoubleValues("X9").ToArray(); 577 var x10 = ds.GetDoubleValues("X10").ToArray(); 578 for (int i = 0; i < ys.Length; i++) { 579 ys[i] -= x1[i] * x7[i] * x9[i]; 580 ys[i] -= x3[i] * x6[i] * x10[i]; 581 } 582 ds.ReplaceVariable("Y", ys.ToList()); 583 584 var 
modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable); 585 586 587 TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false); 588 } 589 590 [TestMethod] 591 [TestCategory("Algorithms.DataAnalysis")] 592 [TestProperty("Time", "short")] 593 public void MctsSymbReg_NoConstants_Poly10_Part2() { 594 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234); 595 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10"))); 596 597 // Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10 598 // Y' = X1*X7*X9 + X3*X6*X10 599 // simplify problem by changing target 600 var ds = ((Dataset)regProblem.Dataset).ToModifiable(); 601 var ys = ds.GetDoubleValues("Y").ToArray(); 602 var x1 = ds.GetDoubleValues("X1").ToArray(); 603 var x2 = ds.GetDoubleValues("X2").ToArray(); 604 var x3 = ds.GetDoubleValues("X3").ToArray(); 605 var x4 = ds.GetDoubleValues("X4").ToArray(); 606 var x5 = ds.GetDoubleValues("X5").ToArray(); 607 var x6 = ds.GetDoubleValues("X6").ToArray(); 608 var x7 = ds.GetDoubleValues("X7").ToArray(); 609 var x8 = ds.GetDoubleValues("X8").ToArray(); 610 var x9 = ds.GetDoubleValues("X9").ToArray(); 611 var x10 = ds.GetDoubleValues("X10").ToArray(); 612 for (int i = 0; i < ys.Length; i++) { 613 ys[i] -= x1[i] * x2[i]; 614 ys[i] -= x3[i] * x4[i]; 615 ys[i] -= x5[i] * x6[i]; 616 } 617 ds.ReplaceVariable("Y", ys.ToList()); 618 619 var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable); 620 621 622 TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false); 623 } 624 625 [TestMethod] 626 [TestCategory("Algorithms.DataAnalysis")] 627 [TestProperty("Time", "short")] 628 public void MctsSymbReg_NoConstants_Poly10_Part3() { 629 var 
provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234); 630 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10"))); 631 632 // Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10 633 // Y' = X1*X2 + X1*X7*X9 634 // simplify problem by changing target 635 var ds = ((Dataset)regProblem.Dataset).ToModifiable(); 636 var ys = ds.GetDoubleValues("Y").ToArray(); 637 var x1 = ds.GetDoubleValues("X1").ToArray(); 638 var x2 = ds.GetDoubleValues("X2").ToArray(); 639 var x3 = ds.GetDoubleValues("X3").ToArray(); 640 var x4 = ds.GetDoubleValues("X4").ToArray(); 641 var x5 = ds.GetDoubleValues("X5").ToArray(); 642 var x6 = ds.GetDoubleValues("X6").ToArray(); 643 var x7 = ds.GetDoubleValues("X7").ToArray(); 644 var x8 = ds.GetDoubleValues("X8").ToArray(); 645 var x9 = ds.GetDoubleValues("X9").ToArray(); 646 var x10 = ds.GetDoubleValues("X10").ToArray(); 647 for (int i = 0; i < ys.Length; i++) { 648 ys[i] -= x3[i] * x4[i]; 649 ys[i] -= x5[i] * x6[i]; 650 ys[i] -= x3[i] * x6[i] * x10[i]; 651 } 652 ds.ReplaceVariable("Y", ys.ToList()); 653 654 var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable); 655 656 657 TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false); 658 } 659 660 [TestMethod] 661 [TestCategory("Algorithms.DataAnalysis")] 662 [TestProperty("Time", "short")] 663 public void MctsSymbReg_NoConstants_Poly10_Part4() { 664 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234); 665 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10"))); 666 667 // Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10 668 // Y' = X3*X4 + X5*X6 + X3*X6*X10 669 // simplify problem by changing target 670 var ds = ((Dataset)regProblem.Dataset).ToModifiable(); 671 var ys = 
ds.GetDoubleValues("Y").ToArray(); 672 var x1 = ds.GetDoubleValues("X1").ToArray(); 673 var x2 = ds.GetDoubleValues("X2").ToArray(); 674 var x3 = ds.GetDoubleValues("X3").ToArray(); 675 var x4 = ds.GetDoubleValues("X4").ToArray(); 676 var x5 = ds.GetDoubleValues("X5").ToArray(); 677 var x6 = ds.GetDoubleValues("X6").ToArray(); 678 var x7 = ds.GetDoubleValues("X7").ToArray(); 679 var x8 = ds.GetDoubleValues("X8").ToArray(); 680 var x9 = ds.GetDoubleValues("X9").ToArray(); 681 var x10 = ds.GetDoubleValues("X10").ToArray(); 682 for (int i = 0; i < ys.Length; i++) { 683 ys[i] -= x1[i] * x2[i]; 684 ys[i] -= x1[i] * x7[i] * x9[i]; 685 } 686 ds.ReplaceVariable("Y", ys.ToList()); 687 var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable); 688 689 690 TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false); 691 } 692 693 [TestMethod] 694 [TestCategory("Algorithms.DataAnalysis")] 695 [TestProperty("Time", "short")] 696 public void MctsSymbReg_NoConstants_Poly10_Part5() { 697 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234); 698 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10"))); 699 700 // Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10 701 // Y' = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 702 // simplify problem by changing target 703 var ds = ((Dataset)regProblem.Dataset).ToModifiable(); 704 var ys = ds.GetDoubleValues("Y").ToArray(); 705 var x1 = ds.GetDoubleValues("X1").ToArray(); 706 var x2 = ds.GetDoubleValues("X2").ToArray(); 707 var x3 = ds.GetDoubleValues("X3").ToArray(); 708 var x4 = ds.GetDoubleValues("X4").ToArray(); 709 var x5 = ds.GetDoubleValues("X5").ToArray(); 710 var x6 = ds.GetDoubleValues("X6").ToArray(); 711 var x7 = ds.GetDoubleValues("X7").ToArray(); 712 var x8 = ds.GetDoubleValues("X8").ToArray(); 713 var x9 = 
ds.GetDoubleValues("X9").ToArray(); 714 var x10 = ds.GetDoubleValues("X10").ToArray(); 715 for (int i = 0; i < ys.Length; i++) { 716 ys[i] -= x3[i] * x6[i] * x10[i]; 717 } 718 ds.ReplaceVariable("Y", ys.ToList()); 719 var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable); 720 721 722 TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false); 723 } 724 725 [TestMethod] 726 [TestCategory("Algorithms.DataAnalysis")] 727 [TestProperty("Time", "short")] 728 public void MctsSymbReg_NoConstants_Poly10_Part6() { 729 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234); 730 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10"))); 731 732 // Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10 733 // Y' = X1*X2 + X3*X4 + X5*X6 + X3*X6*X10 734 // simplify problem by changing target 735 var ds = ((Dataset)regProblem.Dataset).ToModifiable(); 736 var ys = ds.GetDoubleValues("Y").ToArray(); 737 var x1 = ds.GetDoubleValues("X1").ToArray(); 738 var x2 = ds.GetDoubleValues("X2").ToArray(); 739 var x3 = ds.GetDoubleValues("X3").ToArray(); 740 var x4 = ds.GetDoubleValues("X4").ToArray(); 741 var x5 = ds.GetDoubleValues("X5").ToArray(); 742 var x6 = ds.GetDoubleValues("X6").ToArray(); 743 var x7 = ds.GetDoubleValues("X7").ToArray(); 744 var x8 = ds.GetDoubleValues("X8").ToArray(); 745 var x9 = ds.GetDoubleValues("X9").ToArray(); 746 var x10 = ds.GetDoubleValues("X10").ToArray(); 747 for (int i = 0; i < ys.Length; i++) { 748 ys[i] -= x1[i] * x7[i] * x9[i]; 749 } 750 ds.ReplaceVariable("Y", ys.ToList()); 751 var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable); 752 753 754 TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 9, iterations: 100000, allowExp: false, allowLog: false, allowInv: 
false); 755 } 756 554 757 555 758 [TestMethod] … … 559 762 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234); 560 763 var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10"))); 561 TestMctsWithoutConstants(regProblem, nVarRefs: 15, iterations: 1000000, allowExp: false, allowLog: false, allowInv: false); 764 TestMctsWithoutConstants(regProblem, nVarRefs: 15, iterations: 200000, allowExp: false, allowLog: false, allowInv: false); 765 } 766 767 [TestMethod] 768 [TestCategory("Algorithms.DataAnalysis")] 769 [TestProperty("Time", "short")] 770 public void MctsSymbReg_NoConstants_TwoVars() { 771 772 // y = x1 + x2 + x1*x2 + x1*x2*x2 + x1*x1*x2 773 var rand = new FastRandom(1234); 774 var x1 = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList(); 775 var x2 = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList(); 776 var ys = x1.Zip(x2, (x1i, x2i) => x1i + x2i + x1i * x2i + x1i * x2i * x2i + x1i * x1i * x2i).ToList(); 777 778 var ds = new Dataset(new string[] { "a", "b", "y" }, new[] { x1, x2, ys }); 779 780 var problemData = new RegressionProblemData(ds, new string[] { "a", "b" }, "y"); 781 782 783 TestMctsWithoutConstants(problemData, nVarRefs: 10, iterations: 10000, allowExp: false, allowLog: false, allowInv: false); 784 } 785 786 [TestMethod] 787 [TestCategory("Algorithms.DataAnalysis")] 788 [TestProperty("Time", "short")] 789 public void MctsSymbReg_NoConstants_Misleading() { 790 791 // y = a + baaaaa (the effect of the second term should be very small) 792 // the alg will quickly find that a has big effect and will search below a 793 // since we prevent a + a... the algorithm must find the correct expression via a + b... 
794 // however b has a small effect so the branch might not be identified as relevant 795 796 var rand = new FastRandom(1234); 797 var @as = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList(); 798 var bs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList(); 799 var cs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() *1.0e-3).ToList(); 800 var ds = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() ).ToList(); 801 var es = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() ).ToList(); 802 var ys = new double[@as.Count]; 803 for(int i=0;i<ys.Length;i++) 804 ys[i] = @as[i] + bs[i] + @as[i]*bs[i]*cs[i]; 805 806 var dataset = new Dataset(new string[] { "a", "b", "c", "d", "e", "y" }, new[] { @as, bs, cs, ds, es, ys.ToList() }); 807 808 var problemData = new RegressionProblemData(dataset, new string[] { "a", "b","c","d","e" }, "y"); 809 810 811 TestMctsWithoutConstants(problemData, nVarRefs: 10, iterations: 10000, allowExp: false, allowLog: false, allowInv: false); 562 812 } 563 813 #endregion … … 822 1072 mctsSymbReg.ConstantOptimizationIterations = -1; 823 1073 1074 // random policy 1075 // var epsPolicy = new EpsilonGreedy(); 1076 // epsPolicy.Eps = 1.0; 1077 // mctsSymbReg.Policy = epsPolicy; 1078 1079 // UCB tuned 1080 // var ucbTuned = new UcbTuned(); 1081 // ucbTuned.C = 1.5; 1082 // mctsSymbReg.Policy = ucbTuned; 1083 1084 824 1085 #endregion 825 1086 RunAlgorithm(mctsSymbReg); -
branches/MCTS-SymbReg-2796/Tests/Test.csproj
r15403 r15416
  63   63   <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.Instances.DataAnalysis-3.3.dll</HintPath>
  64   64   </Reference>
       65 + <Reference Include="HeuristicLab.Random-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
       66 + <SpecificVersion>False</SpecificVersion>
       67 + <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Random-3.3.dll</HintPath>
       68 + </Reference>
  65   69   <Reference Include="System" />
  66   70   </ItemGroup>
Note: See TracChangeset for help on using the changeset viewer.