Changeset 15438
- Timestamp:
- 10/28/17 19:56:51 (7 years ago)
- Location:
- branches/MCTS-SymbReg-2796
- Files:
-
- 5 deleted
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg.csproj
r15437 r15438 100 100 <Compile Include="Heuristics.cs" /> 101 101 <Compile Include="MctsSymbolicRegression\ApproximateDoubleEqualityComparer.cs" /> 102 <Compile Include="MctsSymbolicRegression\IConstraintHandler.cs" />103 102 <Compile Include="MctsSymbolicRegression\Automaton.cs" /> 104 103 <Compile Include="MctsSymbolicRegression\CodeGenerator.cs" /> … … 108 107 <Compile Include="MctsSymbolicRegression\MctsSymbolicRegressionStatic.cs" /> 109 108 <Compile Include="MctsSymbolicRegression\OpCodes.cs" /> 110 <Compile Include="MctsSymbolicRegression\Policies\EpsGreedy.cs" />111 <Compile Include="MctsSymbolicRegression\Policies\IActionStatistics.cs" />112 <Compile Include="MctsSymbolicRegression\Policies\IPolicy.cs" />113 <Compile Include="MctsSymbolicRegression\Policies\PolicyBase.cs" />114 109 <Compile Include="MctsSymbolicRegression\ExprHash.cs" /> 115 <Compile Include="MctsSymbolicRegression\EmptyConstraintHandler.cs" />116 <Compile Include="MctsSymbolicRegression\SimpleConstraintHandler.cs" />117 110 <Compile Include="MctsSymbolicRegression\SymbolicExpressionGenerator.cs" /> 118 111 <Compile Include="MctsSymbolicRegression\Tree.cs" /> … … 122 115 <ItemGroup> 123 116 <None Include="Plugin.cs.frame" /> 117 </ItemGroup> 118 <ItemGroup> 119 <Folder Include="MctsSymbolicRegression\Policies\" /> 124 120 </ItemGroup> 125 121 <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" /> -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/Heuristics.cs
r15437 r15438 22 22 public static class Heuristics { 23 23 public static double CorrelationForInteraction(double[] a, double[] b, double[] c, double[] target) { 24 return 0.0;25 }26 public static double CorrelationForInteraction(double[] a, double[] b, double[] z) {27 //28 24 var am = a.Average(); 29 25 var bm = b.Average(); 26 var cm = c.Average(); 30 27 var p1 = Enumerable.Range(0, a.Length).Where(i => a[i] < am); 31 28 var p2 = Enumerable.Range(0, a.Length).Where(i => a[i] > am); 32 29 var p3 = Enumerable.Range(0, a.Length).Where(i => b[i] < bm); 33 30 var p4 = Enumerable.Range(0, a.Length).Where(i => b[i] > bm); 31 var p5 = Enumerable.Range(0, a.Length).Where(i => c[i] < cm); 32 var p6 = Enumerable.Range(0, a.Length).Where(i => c[i] > cm); 33 34 return 1.0 / (p1.Count() + p2.Count() + p3.Count() + p4.Count() + p5.Count() + p6.Count()) * 35 ( 36 p1.Count() * CorrelationForInteraction(b, c, target, p1) + 37 p2.Count() * CorrelationForInteraction(b, c, target, p2) + 38 p3.Count() * CorrelationForInteraction(a, c, target, p3) + 39 p4.Count() * CorrelationForInteraction(a, c, target, p3) + 40 p5.Count() * CorrelationForInteraction(a, b, target, p5) + 41 p6.Count() * CorrelationForInteraction(a, b, target, p6) 42 ); 43 } 44 public static double CorrelationForInteraction(double[] a, double[] b, double[] z) { 45 return CorrelationForInteraction(a, b, z, Enumerable.Range(0, a.Length)); 46 } 47 public static double CorrelationForInteraction(double[] a, double[] b, double[] z, IEnumerable<int> idx) { 48 // 49 var am = a.Average(); 50 var bm = b.Average(); 51 var p1 = idx.Where(i => a[i] < am); 52 var p2 = idx.Where(i => a[i] > am); 53 var p3 = idx.Where(i => b[i] < bm); 54 var p4 = idx.Where(i => b[i] > bm); 34 55 35 56 return 1.0 / (p1.Count() + p2.Count() + p3.Count() + p4.Count()) * -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs
r15437 r15438 118 118 private List<string>[,] actionStrings; // just for printing 119 119 private readonly CodeGenerator codeGenerator; 120 private IConstraintHandler constraintHandler; 121 122 public Automaton(double[][] vars, IConstraintHandler constraintHandler, 120 private int numVarRefs; 121 private int maximumNumberOfVariables; 122 123 public Automaton(double[][] vars, 123 124 bool allowProdOfVars = true, 124 125 bool allowExp = true, 125 126 bool allowLog = true, 126 127 bool allowInv = true, 127 bool allowMultipleTerms = false) { 128 bool allowMultipleTerms = false, 129 int maxNumberOfVariables = 5) { 128 130 int nVars = vars.Length; 131 this.maximumNumberOfVariables = maxNumberOfVariables; 129 132 codeGenerator = new CodeGenerator(); 130 this.constraintHandler = constraintHandler;131 133 BuildAutomaton(nVars, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms); 132 134 … … 165 167 codeGenerator.Reset(); 166 168 codeGenerator.Emit1(OpCodes.LoadConst0); 167 constraintHandler.Reset();169 numVarRefs = 0; 168 170 }, "0"); 169 171 AddTransition(StateTermEnd, StateExprEnd, () => { … … 180 182 () => { 181 183 codeGenerator.Emit1(OpCodes.LoadParamN); 182 constraintHandler.StartTerm();183 184 }, 184 185 "c"); … … 186 187 () => { 187 188 codeGenerator.Emit1(OpCodes.Mul); 188 constraintHandler.EndTerm();189 189 }, 190 190 "*"); … … 198 198 if (allowProdOfVars) 199 199 AddTransition(StateFactorStart, StateVariableFactorStart, () => { 200 constraintHandler.StartFactor(StateVariableFactorStart);201 200 }, ""); 202 201 if (allowExp) 203 202 AddTransition(StateFactorStart, StateExpFactorStart, () => { 204 constraintHandler.StartFactor(StateExpFactorStart);205 203 }, ""); 206 204 if (allowLog) 207 205 AddTransition(StateFactorStart, StateLogFactorStart, () => { 208 constraintHandler.StartFactor(StateLogFactorStart);209 206 }, ""); 210 207 if (allowInv) 211 208 AddTransition(StateFactorStart, StateInvFactorStart, () => { 212 constraintHandler.StartFactor(StateInvFactorStart);213 209 }, ""); 214 AddTransition(StateVariableFactorEnd, StateFactorEnd , () => { constraintHandler.EndFactor(); }, "");215 AddTransition(StateExpFactorEnd, StateFactorEnd , () => { constraintHandler.EndFactor(); }, "");216 AddTransition(StateLogFactorEnd, StateFactorEnd , () => { constraintHandler.EndFactor(); }, "");217 AddTransition(StateInvFactorEnd, StateFactorEnd , () => { constraintHandler.EndFactor(); }, "");210 AddTransition(StateVariableFactorEnd, StateFactorEnd); 211 AddTransition(StateExpFactorEnd, StateFactorEnd); 212 AddTransition(StateLogFactorEnd, StateFactorEnd); 213 AddTransition(StateInvFactorEnd, StateFactorEnd); 218 214 219 215 // VarFact -> var_1 ... var_n … … 227 223 () => { 228 224 codeGenerator.Emit2(OpCodes.LoadVar, varIdx); 229 constraintHandler.AddVarToCurrentFactor(varState);225 numVarRefs++; 230 226 }, 231 227 "var_" + varIdx + ""); … … 259 255 () => { 260 256 codeGenerator.Emit2(OpCodes.LoadVar, varIdx); 261 constraintHandler.AddVarToCurrentFactor(varState);257 numVarRefs++; 262 258 }, 263 259 "var_" + varIdx + ""); … … 274 270 () => { 275 271 codeGenerator.Emit1(OpCodes.LoadConst0); 276 constraintHandler.StartNewTermInPoly();277 272 }, 278 273 "0"); … … 314 309 () => { 315 310 codeGenerator.Emit2(OpCodes.LoadVar, varIdx); 316 constraintHandler.AddVarToCurrentFactor(varState);311 numVarRefs++; 317 312 }, 318 313 "var_" + varIdx + ""); … … 325 320 () => { 326 321 codeGenerator.Emit1(OpCodes.LoadConst1); 327 constraintHandler.StartNewTermInPoly();328 322 }, 329 323 "c"); … … 363 357 () => { 364 358 codeGenerator.Emit2(OpCodes.LoadVar, varIdx); 365 constraintHandler.AddVarToCurrentFactor(varState);359 numVarRefs++; 366 360 }, 367 361 "var_" + varIdx + ""); … … 401 395 for (int i = 0; i < fs.Count; i++) { 402 396 var s = fs[i]; 403 if ( constraintHandler.IsAllowedFollowState(state, s)) {397 if (IsAllowedFollowState(state, s)) { 404 398 buf[j++] = s; 405 399 } … … 408 402 } 409 403 404 private bool IsAllowedFollowState(int state, int nextState) { 405 // any state is allowed if we have not reached the max number of variable references 406 // otherwise we can only go towards the final state (smaller state numbers) 407 if (numVarRefs < maximumNumberOfVariables) return true; 408 else return state > nextState; 409 } 410 410 411 411 public void Goto(int targetState) { … … 417 417 418 418 public bool IsFinalState(int s) { 419 return s == StateExprEnd && !constraintHandler.IsInvalidExpression;419 return s == StateExprEnd && numVarRefs <= maximumNumberOfVariables; 420 420 } 421 421 … … 434 434 // After that state of the automaton is restored to the current state. 435 435 public void GetCode(out byte[] code, out int nParams) { 436 IConstraintHandler storedConstraintHandler = null;437 436 int storedState = CurrentState; 438 437 int storedPC = codeGenerator.ProgramCounter; … … 440 439 if (!IsFinalState(CurrentState)) { 441 440 // save state and code, 442 // constraints are ignored while completing the expression443 storedConstraintHandler = constraintHandler;444 constraintHandler = new EmptyConstraintHandler();445 441 storedState = CurrentState; 446 442 storedPC = codeGenerator.ProgramCounter; … … 457 453 458 454 // restore 459 if (storedConstraintHandler != null) { 460 constraintHandler = storedConstraintHandler; 461 CurrentState = storedState; 462 codeGenerator.ProgramCounter = storedPC; 463 } 455 codeGenerator.ProgramCounter = storedPC; 456 CurrentState = storedState; 464 457 } 465 458 … … 467 460 CurrentState = StartState; 468 461 codeGenerator.Reset(); 469 constraintHandler.Reset();462 numVarRefs = 0; 470 463 } 471 464 -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs
r15437 r15438 23 23 using System.Linq; 24 24 using System.Threading; 25 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;26 25 using HeuristicLab.Analysis; 27 26 using HeuristicLab.Common; … … 75 74 public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter { 76 75 get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; } 77 } 78 public IValueParameter<IPolicy> PolicyParameter { 79 get { return (IValueParameter<IPolicy>)Parameters[PolicyParameterName]; } 80 } 76 } 81 77 public IFixedValueParameter<DoubleValue> PunishmentFactorParameter { 82 78 get { return (IFixedValueParameter<DoubleValue>)Parameters[PunishmentFactorParameterName]; } … … 121 117 get { return MaxVariableReferencesParameter.Value.Value; } 122 118 set { MaxVariableReferencesParameter.Value.Value = value; } 123 }124 public IPolicy Policy {125 get { return PolicyParameter.Value; }126 set { PolicyParameter.Value = value; }127 119 } 128 120 public double PunishmentFactor { … … 183 175 Parameters.Add(new FixedValueParameter<IntValue>(MaxVariablesParameterName, 184 176 "Maximal number of variables references in the symbolic regression models (multiple usages of the same variable are counted)", new IntValue(5))); 185 // Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName,186 // "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0)));187 Parameters.Add(new ValueParameter<IPolicy>(PolicyParameterName,188 "The policy to use for selecting nodes in MCTS", new EpsilonGreedy()));189 PolicyParameter.Hidden = true;190 177 Parameters.Add(new ValueParameter<ICheckedItemList<StringValue>>(AllowedFactorsParameterName, 191 178 "Choose which expressions are allowed as factors in the model.", defaultFactorsList)); … … 275 262 var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, ScaleVariables, 276 263 ConstantOptimizationIterations, Lambda, 277 Policy,collectPareto,264 collectPareto, 278 265 lowerLimit, upperLimit, 279 266 allowProdOfVars: AllowedFactors.CheckedItems.Any(s => s.Value.Value == VariableProductFactorName), -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r15437 r15438 26 26 using System.Linq; 27 27 using System.Text; 28 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;29 28 using HeuristicLab.Core; 30 29 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; … … 99 98 internal readonly Tree tree; 100 99 internal readonly Func<byte[], int, double> evalFun; 101 internal readonly IPolicy treePolicy;102 100 // MCTS might get stuck. Track statistics on the number of effective rollouts 103 101 internal int totalRollouts; … … 145 143 public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, 146 144 int constOptIterations, double lambda, 147 IPolicy treePolicy = null,148 145 bool collectParetoOptimalModels = false, 149 146 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, … … 187 184 this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit); 188 185 189 this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms); 190 this.treePolicy = treePolicy ?? new EpsilonGreedy(); 186 this.automaton = new Automaton(x, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms, maxVariables); 191 187 this.tree = new Tree() { 192 188 state = automaton.CurrentState, 193 actionStatistics = treePolicy.CreateActionStatistics(),194 189 expr = "", 195 190 level = 0 … … 469 464 public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, 470 465 bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0, 471 IPolicy policy = null,472 466 bool collectParameterOptimalModels = false, 473 467 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, … … 479 473 ) { 480 474 return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda, 481 policy,collectParameterOptimalModels,475 collectParameterOptimalModels, 482 476 lowerEstimationLimit, upperEstimationLimit, 483 477 allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms); … … 499 493 var eval = mctsState.evalFun; 500 494 var rand = mctsState.random; 501 var treePolicy = mctsState.treePolicy;502 495 double q = 0; 503 496 bool success = false; … … 505 498 506 499 automaton.Reset(); 507 success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy,mctsState, out q);500 success = TryTreeSearchRec2(rand, tree, automaton, eval, mctsState, out q); 508 501 mctsState.totalRollouts++; 509 502 } while (!success && !tree.Done); … … 517 510 518 511 // search forward 519 private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy, 512 private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, 513 Func<byte[], int, double> eval, 520 514 State state, 521 515 out double q) { … … 545 539 int selectedIdx = 0; 546 540 if (state.children[tree].Count > 1) { 547 selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);541 selectedIdx = SelectInternal(state.children[tree], rand); 548 542 } 549 543 … … 579 573 if (!state.nodes.TryGetValue(hc, out child)) { 580 574 child = new Tree() { 581 children = null,582 575 state = possibleFollowStates[i], 583 actionStatistics = treePolicy.CreateActionStatistics(),584 576 expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]), 585 577 level = tree.level + 1 … … 591 583 // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link 592 584 // to all parents 593 BackpropagateStatistics( child.actionStatistics, tree, state);585 BackpropagateStatistics(tree, state, child.visits); 594 586 } else { 595 587 // prevent cycles … … 599 591 } else { 600 592 child = new Tree() { 601 children = null,602 593 state = possibleFollowStates[i], 603 actionStatistics = treePolicy.CreateActionStatistics(),604 594 expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]), 605 595 level = tree.level + 1 … … 639 629 automaton.GetCode(out code, out nParams); 640 630 q = eval(code, nParams); 641 // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);642 631 success = true; 643 BackpropagateQuality(tree, q, treePolicy, state);632 BackpropagateQuality(tree, q, state); 644 633 } else { 645 634 // we got stuck in roll-out (not evaluation necessary!) 646 // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");647 635 q = 0.0; 648 636 success = false; … … 659 647 } 660 648 649 private static int SelectInternal(List<Tree> list, IRandom rand) { 650 // choose a random node. 651 Debug.Assert(list.Any(t => !t.Done)); 652 653 var idx = rand.Next(list.Count); 654 while(list[idx].Done) { idx = rand.Next(list.Count); } 655 return idx; 656 } 657 661 658 // backpropagate existing statistics to all parents 662 private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) { 663 tree.actionStatistics.Add(stats); 659 private static void BackpropagateStatistics(Tree tree, State state, int numVisits) { 660 tree.visits += numVisits; 661 664 662 if (state.parents.ContainsKey(tree)) { 665 663 foreach (var parent in state.parents[tree]) { 666 BackpropagateStatistics( stats, parent, state);664 BackpropagateStatistics(parent, state, numVisits); 667 665 } 668 666 } … … 676 674 } 677 675 678 private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) { 679 policy.Update(tree.actionStatistics, q); 676 private static void BackpropagateQuality(Tree tree, double q, State state) { 677 tree.visits++; 678 // TODO: q is ignored for now 680 679 681 680 if (state.parents.ContainsKey(tree)) { 682 681 foreach (var parent in state.parents[tree]) { 683 BackpropagateQuality(parent, q, policy,state);682 BackpropagateQuality(parent, q, state); 684 683 } 685 684 } … … 718 717 } 719 718 return children[selectedChildIdx]; 720 } 721 722 // tree search might fail because of constraints for expressions 723 // in this case we get stuck we just restart 724 // see ConstraintHandler.cs for more info 725 private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy, 726 out double q) { 727 Tree selectedChild = null; 728 Contract.Assert(tree.state == automaton.CurrentState); 729 Contract.Assert(!tree.Done); 730 if (tree.children == null) { 731 if (automaton.IsFinalState(tree.state)) { 732 // final state 733 tree.Done = true; 734 735 // EVALUATE 736 byte[] code; int nParams; 737 automaton.GetCode(out code, out nParams); 738 q = eval(code, nParams); 739 740 treePolicy.Update(tree.actionStatistics, q); 741 return true; // we reached a final state 742 } else { 743 // EXPAND 744 int[] possibleFollowStates = new int[1000]; 745 int nFs; 746 automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs); 747 if (nFs == 0) { 748 // stuck in a dead end (no final state and no allowed follow states) 749 q = 0; 750 tree.Done = true; 751 tree.children = null; 752 return false; 753 } 754 tree.children = new Tree[nFs]; 755 for (int i = 0; i < tree.children.Length; i++) 756 tree.children[i] = new Tree() { 757 children = null, 758 state = possibleFollowStates[i], 759 actionStatistics = treePolicy.CreateActionStatistics() 760 }; 761 762 selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0]; 763 } 764 } else { 765 // tree.children != null 766 // UCT selection within tree 767 int selectedIdx = 0; 768 if (tree.children.Length > 1) { 769 selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand); 770 } 771 selectedChild = tree.children[selectedIdx]; 772 } 773 // make selected step and recurse 774 automaton.Goto(selectedChild.state); 775 var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q); 776 if (success) { 777 // only update if successful 778 treePolicy.Update(tree.actionStatistics, q); 779 } 780 781 tree.Done = tree.children.All(ch => ch.Done); 782 if (tree.Done) { 783 tree.children = null; // cut off the sub-branch if it has been fully explored 784 } 785 return success; 786 } 787 788 private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) { 789 // if one of the new children leads to a final state then go there 790 // otherwise choose a random child 791 int selectedChildIdx = -1; 792 // find first final state if there is one 793 for (int i = 0; i < tree.children.Length; i++) { 794 if (automaton.IsFinalState(tree.children[i].state)) { 795 selectedChildIdx = i; 796 break; 797 } 798 } 799 // no final state -> select a the first child 800 if (selectedChildIdx == -1) { 801 selectedChildIdx = 0; 802 } 803 return tree.children[selectedChildIdx]; 804 } 719 } 805 720 806 721 // scales data and extracts values from dataset into arrays … … 869 784 automaton.GetCode(out code, out nParams); 870 785 return Disassembler.CodeToString(code); 871 }872 873 874 private static string WriteStatistics(Tree tree, State state) {875 var sb = new System.IO.StringWriter();876 sb.Write("{0}\t{1:N5}\t", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);877 if (state.children.ContainsKey(tree)) {878 foreach (var ch in state.children[tree]) {879 sb.Write("{0}\t{1:N5}\t", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);880 }881 }882 sb.WriteLine();883 return sb.ToString();884 786 } 885 787 … … 899 801 900 802 private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) { 901 var avgNodeQ = tree.actionStatistics.AverageQuality; 902 var tries = tree.actionStatistics.Tries; 903 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 904 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 905 hue = 0.0; 906 907 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine(); 803 var tries = tree.visits; 804 805 sb.AppendFormat("{0} [label=\"{1}\"]; ", parentId, tries).AppendLine(); 908 806 909 807 var list = new List<Tuple<int, int, Tree>>(); … … 911 809 foreach (var ch in state.children[tree]) { 912 810 nextId++; 913 avgNodeQ = ch.actionStatistics.AverageQuality; 914 tries = ch.actionStatistics.Tries; 915 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 916 hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 917 hue = 0.0; 918 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 919 sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine(); 811 tries = ch.visits; 812 sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine(); 813 sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", parentId, nextId, ch.expr).AppendLine(); 920 814 list.Add(Tuple.Create(tries, nextId, ch)); 921 815 } … … 927 821 var chch = state.children[ch].First(); 928 822 nextId++; 929 avgNodeQ = chch.actionStatistics.AverageQuality; 930 tries = chch.actionStatistics.Tries; 931 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 932 hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 933 hue = 0.0; 934 sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine(); 935 sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine(); 823 tries = chch.visits; 824 sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine(); 825 sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", chId, nextId, chch.expr).AppendLine(); 936 826 } 937 827 } … … 957 847 if (!nodeIds.TryGetValue(parent, out parentId)) { 958 848 parentId = nodeIds.Count + 1; 959 var avgNodeQ = parent.actionStatistics.AverageQuality; 960 var tries = parent.actionStatistics.Tries; 961 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0; 962 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue 963 hue = 0.0; 964 if (parent.actionStatistics.Tries > threshold) 965 sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue); 849 var tries = parent.visits; 850 if (tries > threshold) 851 sb.Write("{0} [label=\"{1}\"]; ", parentId, tries); 966 852 nodeIds.Add(parent, parentId); 967 853 } … … 972 858 nodeIds.Add(child, childId); 973 859 } 974 var avgNodeQ = child.actionStatistics.AverageQuality; 975 var tries = child.actionStatistics.Tries; 860 var tries = child.visits; 976 861 if (tries < 1) continue; 977 if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;978 var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue979 hue = 0.0;980 862 if (tries > threshold) { 981 sb.Write("{0} [label=\"{1 :E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);863 sb.Write("{0} [label=\"{1}\"]; ", childId, tries); 982 864 var edgeLabel = child.expr; 983 865 // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, ""); 984 sb.Write("{0} -> {1} [label=\"{ 3}\"]", parentId, childId, avgNodeQ, edgeLabel);866 sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel); 985 867 } 986 868 } -
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Tree.cs
r15414 r15438 19 19 */ 20 20 #endregion 21 22 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies; 23 21 24 22 namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression { 25 // represents tree nodes for the search tree in MCTS26 23 internal class Tree { 27 24 public int state; 28 25 public int level; 29 26 public string expr; 30 public bool Done { 31 get { return actionStatistics.Done; } 32 set { actionStatistics.Done = value; } 33 } 34 public IActionStatistics actionStatistics; 35 public Tree[] children; 27 public bool Done { get; set; } 28 public int visits; 29 // { 30 // get { return actionStatistics.Done; } 31 // set { actionStatistics.Done = value; } 32 // } 33 // public IActionStatistics actionStatistics; 34 // public Tree[] children; 36 35 } 37 36 } -
branches/MCTS-SymbReg-2796/Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/MctsSymbolicRegressionTest.cs
r15437 r15438 3 3 using System.Linq; 4 4 using System.Threading; 5 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;6 5 using HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg; 7 6 using HeuristicLab.Data; … … 260 259 } 261 260 262 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, c, z) > 0.05);263 Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z, z) < 0.05);261 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, c, t) > 0.05); 262 Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z, t) < 0.05); 264 263 265 264 /* we might see correlations when only using one of the two relevant factors. … … 271 270 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, y, z) < 0.05); 272 271 */ 273 Console.WriteLine("a,b: {0:N3}\tx,y: {1:N3}\ta,x: {2:N3}\tb,x: {3:N3}\ta,y: {4:N3}\tb,y: {5:N3}\tcov(a,b): {6:N3}", 274 Heuristics.CorrelationForInteraction(a, b, z), 275 Heuristics.CorrelationForInteraction(x, y, z), 276 Heuristics.CorrelationForInteraction(a, x, z), 277 Heuristics.CorrelationForInteraction(b, x, z), 278 Heuristics.CorrelationForInteraction(a, y, z), 279 Heuristics.CorrelationForInteraction(b, y, z), 280 alglib.cov2(a, b) 272 Console.WriteLine("a,b,c: {0:N3}\tx,y,z: {1:N3}\ta,b,x: {2:N3}\tb,c,x: {3:N3}", 273 Heuristics.CorrelationForInteraction(a, b, c, t), 274 Heuristics.CorrelationForInteraction(x, y, z, t), 275 Heuristics.CorrelationForInteraction(a, b, x, t), 276 Heuristics.CorrelationForInteraction(b, c, x, t) 281 277 ); 282 278 } 279 } 280 } 281 282 [TestMethod] 283 [TestCategory("Algorithms.DataAnalysis")] 284 [TestProperty("Time", "short")] 285 public void TestPoly10Interactions() { 286 { 287 alglib.hqrndstate randState; 288 alglib.hqrndseed(1234, 31415, out randState); 289 290 int N = 25000; // large sample size to make sure the test thresholds hold 291 double[] a = new double[N]; 292 double[] b = new double[N]; 293 double[] c = new double[N]; 294 double[] d = new double[N]; 295 double[] e = new double[N]; 296 double[] f = new double[N]; 297 double[] g = new double[N]; 298 double[] h = new double[N]; 299 double[] i = new double[N]; 300 double[] j = new double[N]; 301 double[] y = new double[N]; 302 303 for(int k=0;k<N;k++) { 304 a[k] = alglib.hqrnduniformr(randState) * 2 - 1; 305 b[k] = alglib.hqrnduniformr(randState) * 2 - 1; 306 c[k] = alglib.hqrnduniformr(randState) * 2 - 1; 307 d[k] = alglib.hqrnduniformr(randState) * 2 - 1; 308 e[k] = alglib.hqrnduniformr(randState) * 2 - 1; 309 f[k] = alglib.hqrnduniformr(randState) * 2 - 1; 310 g[k] = alglib.hqrnduniformr(randState) * 2 - 1; 311 h[k] = alglib.hqrnduniformr(randState) * 2 - 1; 312 i[k] = alglib.hqrnduniformr(randState) * 2 - 1; 313 j[k] = alglib.hqrnduniformr(randState) * 2 - 1; 314 y[k] = a[k] * b[k] + c[k] * d[k] + e[k] * f[k] + a[k] * g[k] * i[k] + c[k] * f[k] * j[k]; 315 } 316 317 var x = new[] { a, b, c, d, e, f, g, h, i, j }; 318 var all2Combinations = HeuristicLab.Common.EnumerableExtensions.Combinations(new[] {1,2,3,4,5,6,7,8,9,10}, 2); 319 320 var resultList = new List<Tuple<string, double>>(); 321 foreach(var entry in all2Combinations) { 322 var aIdx = entry.First(); 323 var bIdx = entry.Skip(1).First(); 324 resultList.Add(Tuple.Create(aIdx + " " + bIdx, Heuristics.CorrelationForInteraction(x[aIdx - 1], x[bIdx - 1], y))); 325 } 326 327 foreach(var entry in resultList.OrderByDescending(t => t.Item2)) { 328 Console.WriteLine("{0} {1:N3}", entry.Item1, entry.Item2); 329 } 330 331 var all3Combinations = HeuristicLab.Common.EnumerableExtensions.Combinations(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, 3); 332 333 resultList = new List<Tuple<string, double>>(); 334 foreach (var entry in all3Combinations) { 335 var aIdx = entry.First(); 336 var bIdx = entry.Skip(1).First(); 337 var cIdx = entry.Skip(2).First(); 338 resultList.Add(Tuple.Create(aIdx + " " + bIdx + " " + cIdx, Heuristics.CorrelationForInteraction(x[aIdx - 1], x[bIdx - 1], x[cIdx - 1], y))); 339 } 340 341 // Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10 342 343 foreach (var entry in resultList.OrderByDescending(t => t.Item2)) { 344 Console.WriteLine("{0} {1:N3}", entry.Item1, entry.Item2); 345 } 346 347 348 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, y) > 0.01); 349 Assert.IsTrue(Heuristics.CorrelationForInteraction(b, a, y) > 0.01); 350 Assert.IsTrue(Heuristics.CorrelationForInteraction(c, d, y) > 0.01); 351 Assert.IsTrue(Heuristics.CorrelationForInteraction(d, c, y) > 0.01); 352 Assert.IsTrue(Heuristics.CorrelationForInteraction(e, f, y) > 0.01); 353 Assert.IsTrue(Heuristics.CorrelationForInteraction(f, e, y) > 0.01); 354 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, g, i, y) > 0.01); 355 Assert.IsTrue(Heuristics.CorrelationForInteraction(a, i, g, y) > 0.01); 356 Assert.IsTrue(Heuristics.CorrelationForInteraction(g, a, i, y) > 0.01); 357 Assert.IsTrue(Heuristics.CorrelationForInteraction(g, i, a, y) > 0.01); 358 Assert.IsTrue(Heuristics.CorrelationForInteraction(i, g, a, y) > 0.01); 359 Assert.IsTrue(Heuristics.CorrelationForInteraction(i, a, g, y) > 0.01); 360 361 Assert.IsTrue(Heuristics.CorrelationForInteraction(c, f, j, y) > 0.01); 362 Assert.IsTrue(Heuristics.CorrelationForInteraction(c, j, f, y) > 0.01); 363 Assert.IsTrue(Heuristics.CorrelationForInteraction(f, c, j, y) > 0.01); 364 Assert.IsTrue(Heuristics.CorrelationForInteraction(f, j, c, y) > 0.01); 365 Assert.IsTrue(Heuristics.CorrelationForInteraction(j, c, f, y) > 0.01); 366 Assert.IsTrue(Heuristics.CorrelationForInteraction(j, f, c, y) > 0.01); 283 367 } 284 368 }
Note: See TracChangeset
for help on using the changeset viewer.