Changeset 15438


Ignore:
Timestamp:
10/28/17 19:56:51 (2 years ago)
Author:
gkronber
Message:

#2796 refactoring to simplify the code

Location:
branches/MCTS-SymbReg-2796
Files:
5 deleted
7 edited

Legend:

Unmodified
Added
Removed
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg.csproj

    r15437 r15438  
    100100    <Compile Include="Heuristics.cs" />
    101101    <Compile Include="MctsSymbolicRegression\ApproximateDoubleEqualityComparer.cs" />
    102     <Compile Include="MctsSymbolicRegression\IConstraintHandler.cs" />
    103102    <Compile Include="MctsSymbolicRegression\Automaton.cs" />
    104103    <Compile Include="MctsSymbolicRegression\CodeGenerator.cs" />
     
    108107    <Compile Include="MctsSymbolicRegression\MctsSymbolicRegressionStatic.cs" />
    109108    <Compile Include="MctsSymbolicRegression\OpCodes.cs" />
    110     <Compile Include="MctsSymbolicRegression\Policies\EpsGreedy.cs" />
    111     <Compile Include="MctsSymbolicRegression\Policies\IActionStatistics.cs" />
    112     <Compile Include="MctsSymbolicRegression\Policies\IPolicy.cs" />
    113     <Compile Include="MctsSymbolicRegression\Policies\PolicyBase.cs" />
    114109    <Compile Include="MctsSymbolicRegression\ExprHash.cs" />
    115     <Compile Include="MctsSymbolicRegression\EmptyConstraintHandler.cs" />
    116     <Compile Include="MctsSymbolicRegression\SimpleConstraintHandler.cs" />
    117110    <Compile Include="MctsSymbolicRegression\SymbolicExpressionGenerator.cs" />
    118111    <Compile Include="MctsSymbolicRegression\Tree.cs" />
     
    122115  <ItemGroup>
    123116    <None Include="Plugin.cs.frame" />
     117  </ItemGroup>
     118  <ItemGroup>
     119    <Folder Include="MctsSymbolicRegression\Policies\" />
    124120  </ItemGroup>
    125121  <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/Heuristics.cs

    r15437 r15438  
    2222  public static class Heuristics {
    2323    public static double CorrelationForInteraction(double[] a, double[] b, double[] c, double[] target) {
    24       return 0.0;
    25     }
    26     public static double CorrelationForInteraction(double[] a, double[] b, double[] z) {
    27       //
    2824      var am = a.Average();
    2925      var bm = b.Average();
     26      var cm = c.Average();
    3027      var p1 = Enumerable.Range(0, a.Length).Where(i => a[i] < am);
    3128      var p2 = Enumerable.Range(0, a.Length).Where(i => a[i] > am);
    3229      var p3 = Enumerable.Range(0, a.Length).Where(i => b[i] < bm);
    3330      var p4 = Enumerable.Range(0, a.Length).Where(i => b[i] > bm);
     31      var p5 = Enumerable.Range(0, a.Length).Where(i => c[i] < cm);
     32      var p6 = Enumerable.Range(0, a.Length).Where(i => c[i] > cm);
     33
     34      return 1.0 / (p1.Count() + p2.Count() + p3.Count() + p4.Count() + p5.Count() + p6.Count()) *
     35        (
     36        p1.Count() * CorrelationForInteraction(b, c, target, p1) +
     37        p2.Count() * CorrelationForInteraction(b, c, target, p2) +
     38        p3.Count() * CorrelationForInteraction(a, c, target, p3) +
     39        p4.Count() * CorrelationForInteraction(a, c, target, p3) +
     40        p5.Count() * CorrelationForInteraction(a, b, target, p5) +
     41        p6.Count() * CorrelationForInteraction(a, b, target, p6)
     42      );
     43    }
     44    public static double CorrelationForInteraction(double[] a, double[] b, double[] z) {
     45      return CorrelationForInteraction(a, b, z, Enumerable.Range(0, a.Length));
     46    }
     47    public static double CorrelationForInteraction(double[] a, double[] b, double[] z, IEnumerable<int> idx) {
     48      //
     49      var am = a.Average();
     50      var bm = b.Average();
     51      var p1 = idx.Where(i => a[i] < am);
     52      var p2 = idx.Where(i => a[i] > am);
     53      var p3 = idx.Where(i => b[i] < bm);
     54      var p4 = idx.Where(i => b[i] > bm);
    3455
    3556      return 1.0 / (p1.Count() + p2.Count() + p3.Count() + p4.Count()) *
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs

    r15437 r15438  
    118118    private List<string>[,] actionStrings; // just for printing
    119119    private readonly CodeGenerator codeGenerator;
    120     private IConstraintHandler constraintHandler;
    121 
    122     public Automaton(double[][] vars, IConstraintHandler constraintHandler,
     120    private int numVarRefs;
     121    private int maximumNumberOfVariables;
     122
     123    public Automaton(double[][] vars,
    123124       bool allowProdOfVars = true,
    124125       bool allowExp = true,
    125126       bool allowLog = true,
    126127       bool allowInv = true,
    127        bool allowMultipleTerms = false) {
     128       bool allowMultipleTerms = false,
     129       int maxNumberOfVariables = 5) {
    128130      int nVars = vars.Length;
     131      this.maximumNumberOfVariables = maxNumberOfVariables;
    129132      codeGenerator = new CodeGenerator();
    130       this.constraintHandler = constraintHandler;
    131133      BuildAutomaton(nVars, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    132134
     
    165167        codeGenerator.Reset();
    166168        codeGenerator.Emit1(OpCodes.LoadConst0);
    167         constraintHandler.Reset();
     169        numVarRefs = 0;
    168170      }, "0");
    169171      AddTransition(StateTermEnd, StateExprEnd, () => {
     
    180182        () => {
    181183          codeGenerator.Emit1(OpCodes.LoadParamN);
    182           constraintHandler.StartTerm();
    183184        },
    184185        "c");
     
    186187        () => {
    187188          codeGenerator.Emit1(OpCodes.Mul);
    188           constraintHandler.EndTerm();
    189189        },
    190190        "*");
     
    198198      if (allowProdOfVars)
    199199        AddTransition(StateFactorStart, StateVariableFactorStart, () => {
    200           constraintHandler.StartFactor(StateVariableFactorStart);
    201200        }, "");
    202201      if (allowExp)
    203202        AddTransition(StateFactorStart, StateExpFactorStart, () => {
    204           constraintHandler.StartFactor(StateExpFactorStart);
    205203        }, "");
    206204      if (allowLog)
    207205        AddTransition(StateFactorStart, StateLogFactorStart, () => {
    208           constraintHandler.StartFactor(StateLogFactorStart);
    209206        }, "");
    210207      if (allowInv)
    211208        AddTransition(StateFactorStart, StateInvFactorStart, () => {
    212           constraintHandler.StartFactor(StateInvFactorStart);
    213209        }, "");
    214       AddTransition(StateVariableFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
    215       AddTransition(StateExpFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
    216       AddTransition(StateLogFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
    217       AddTransition(StateInvFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
     210      AddTransition(StateVariableFactorEnd, StateFactorEnd);
     211      AddTransition(StateExpFactorEnd, StateFactorEnd);
     212      AddTransition(StateLogFactorEnd, StateFactorEnd);
     213      AddTransition(StateInvFactorEnd, StateFactorEnd);
    218214
    219215      // VarFact -> var_1 ... var_n
     
    227223          () => {
    228224            codeGenerator.Emit2(OpCodes.LoadVar, varIdx);
    229             constraintHandler.AddVarToCurrentFactor(varState);
     225            numVarRefs++;
    230226          },
    231227          "var_" + varIdx + "");
     
    259255          () => {
    260256            codeGenerator.Emit2(OpCodes.LoadVar, varIdx);
    261             constraintHandler.AddVarToCurrentFactor(varState);
     257            numVarRefs++;
    262258          },
    263259          "var_" + varIdx + "");
     
    274270        () => {
    275271          codeGenerator.Emit1(OpCodes.LoadConst0);
    276           constraintHandler.StartNewTermInPoly();
    277272        },
    278273        "0");
     
    314309          () => {
    315310            codeGenerator.Emit2(OpCodes.LoadVar, varIdx);
    316             constraintHandler.AddVarToCurrentFactor(varState);
     311            numVarRefs++;
    317312          },
    318313          "var_" + varIdx + "");
     
    325320        () => {
    326321          codeGenerator.Emit1(OpCodes.LoadConst1);
    327           constraintHandler.StartNewTermInPoly();
    328322        },
    329323        "c");
     
    363357          () => {
    364358            codeGenerator.Emit2(OpCodes.LoadVar, varIdx);
    365             constraintHandler.AddVarToCurrentFactor(varState);
     359            numVarRefs++;
    366360          },
    367361          "var_" + varIdx + "");
     
    401395      for (int i = 0; i < fs.Count; i++) {
    402396        var s = fs[i];
    403         if (constraintHandler.IsAllowedFollowState(state, s)) {
     397        if (IsAllowedFollowState(state, s)) {
    404398          buf[j++] = s;
    405399        }
     
    408402    }
    409403
     404    private bool IsAllowedFollowState(int state, int nextState) {
     405      // any state is allowed if we have not reached the max number of variable references
     406      // otherwise we can only go towards the final state (smaller state numbers)
     407      if (numVarRefs < maximumNumberOfVariables) return true;
     408      else return state > nextState;
     409    }
    410410
    411411    public void Goto(int targetState) {
     
    417417
    418418    public bool IsFinalState(int s) {
    419       return s == StateExprEnd && !constraintHandler.IsInvalidExpression;
     419      return s == StateExprEnd && numVarRefs <= maximumNumberOfVariables;
    420420    }
    421421
     
    434434    // After that state of the automaton is restored to the current state.
    435435    public void GetCode(out byte[] code, out int nParams) {
    436       IConstraintHandler storedConstraintHandler = null;
    437436      int storedState = CurrentState;
    438437      int storedPC = codeGenerator.ProgramCounter;
     
    440439      if (!IsFinalState(CurrentState)) {
    441440        // save state and code,
    442         // constraints are ignored while completing the expression
    443         storedConstraintHandler = constraintHandler;
    444         constraintHandler = new EmptyConstraintHandler();
    445441        storedState = CurrentState;
    446442        storedPC = codeGenerator.ProgramCounter;
     
    457453
    458454      // restore
    459       if (storedConstraintHandler != null) {
    460         constraintHandler = storedConstraintHandler;
    461         CurrentState = storedState;
    462         codeGenerator.ProgramCounter = storedPC;
    463       }
     455      codeGenerator.ProgramCounter = storedPC;
     456      CurrentState = storedState;
    464457    }
    465458
     
    467460      CurrentState = StartState;
    468461      codeGenerator.Reset();
    469       constraintHandler.Reset();
     462      numVarRefs = 0;
    470463    }
    471464
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs

    r15437 r15438  
    2323using System.Linq;
    2424using System.Threading;
    25 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    2625using HeuristicLab.Analysis;
    2726using HeuristicLab.Common;
     
    7574    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
    7675      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    77     }
    78     public IValueParameter<IPolicy> PolicyParameter {
    79       get { return (IValueParameter<IPolicy>)Parameters[PolicyParameterName]; }
    80     }
     76    }                 
    8177    public IFixedValueParameter<DoubleValue> PunishmentFactorParameter {
    8278      get { return (IFixedValueParameter<DoubleValue>)Parameters[PunishmentFactorParameterName]; }
     
    121117      get { return MaxVariableReferencesParameter.Value.Value; }
    122118      set { MaxVariableReferencesParameter.Value.Value = value; }
    123     }
    124     public IPolicy Policy {
    125       get { return PolicyParameter.Value; }
    126       set { PolicyParameter.Value = value; }
    127119    }
    128120    public double PunishmentFactor {
     
    183175      Parameters.Add(new FixedValueParameter<IntValue>(MaxVariablesParameterName,
    184176        "Maximal number of variables references in the symbolic regression models (multiple usages of the same variable are counted)", new IntValue(5)));
    185       // Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName,
    186       //   "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0)));
    187       Parameters.Add(new ValueParameter<IPolicy>(PolicyParameterName,
    188         "The policy to use for selecting nodes in MCTS", new EpsilonGreedy()));
    189       PolicyParameter.Hidden = true;
    190177      Parameters.Add(new ValueParameter<ICheckedItemList<StringValue>>(AllowedFactorsParameterName,
    191178        "Choose which expressions are allowed as factors in the model.", defaultFactorsList));
     
    275262      var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, ScaleVariables,
    276263        ConstantOptimizationIterations, Lambda,
    277         Policy, collectPareto,
     264        collectPareto,
    278265        lowerLimit, upperLimit,
    279266        allowProdOfVars: AllowedFactors.CheckedItems.Any(s => s.Value.Value == VariableProductFactorName),
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r15437 r15438  
    2626using System.Linq;
    2727using System.Text;
    28 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    2928using HeuristicLab.Core;
    3029using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     
    9998      internal readonly Tree tree;
    10099      internal readonly Func<byte[], int, double> evalFun;
    101       internal readonly IPolicy treePolicy;
    102100      // MCTS might get stuck. Track statistics on the number of effective rollouts
    103101      internal int totalRollouts;
     
    145143      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
    146144        int constOptIterations, double lambda,
    147         IPolicy treePolicy = null,
    148145        bool collectParetoOptimalModels = false,
    149146        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     
    187184        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);
    188185
    189         this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    190         this.treePolicy = treePolicy ?? new EpsilonGreedy();
     186        this.automaton = new Automaton(x, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms, maxVariables);
    191187        this.tree = new Tree() {
    192188          state = automaton.CurrentState,
    193           actionStatistics = treePolicy.CreateActionStatistics(),
    194189          expr = "",
    195190          level = 0
     
    469464    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
    470465      bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
    471       IPolicy policy = null,
    472466      bool collectParameterOptimalModels = false,
    473467      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     
    479473      ) {
    480474      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
    481         policy, collectParameterOptimalModels,
     475        collectParameterOptimalModels,
    482476        lowerEstimationLimit, upperEstimationLimit,
    483477        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
     
    499493      var eval = mctsState.evalFun;
    500494      var rand = mctsState.random;
    501       var treePolicy = mctsState.treePolicy;
    502495      double q = 0;
    503496      bool success = false;
     
    505498
    506499        automaton.Reset();
    507         success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q);
     500        success = TryTreeSearchRec2(rand, tree, automaton, eval, mctsState, out q);
    508501        mctsState.totalRollouts++;
    509502      } while (!success && !tree.Done);
     
    517510
    518511    // search forward
    519     private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
     512    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton,
     513      Func<byte[], int, double> eval,
    520514      State state,
    521515      out double q) {
     
    545539          int selectedIdx = 0;
    546540          if (state.children[tree].Count > 1) {
    547             selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);
     541            selectedIdx = SelectInternal(state.children[tree], rand);
    548542          }
    549543
     
    579573              if (!state.nodes.TryGetValue(hc, out child)) {
    580574                child = new Tree() {
    581                   children = null,
    582575                  state = possibleFollowStates[i],
    583                   actionStatistics = treePolicy.CreateActionStatistics(),
    584576                  expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
    585577                  level = tree.level + 1
     
    591583                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
    592584                // to all parents
    593                 BackpropagateStatistics(child.actionStatistics, tree, state);
     585                BackpropagateStatistics(tree, state, child.visits);
    594586              } else {
    595587                // prevent cycles
     
    599591            } else {
    600592              child = new Tree() {
    601                 children = null,
    602593                state = possibleFollowStates[i],
    603                 actionStatistics = treePolicy.CreateActionStatistics(),
    604594                expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
    605595                level = tree.level + 1
     
    639629        automaton.GetCode(out code, out nParams);
    640630        q = eval(code, nParams);
    641         // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);
    642631        success = true;
    643         BackpropagateQuality(tree, q, treePolicy, state);
     632        BackpropagateQuality(tree, q, state);
    644633      } else {
    645634        // we got stuck in roll-out (not evaluation necessary!)
    646         // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");
    647635        q = 0.0;
    648636        success = false;
     
    659647    }
    660648
     649    private static int SelectInternal(List<Tree> list, IRandom rand) {
     650      // choose a random node.
     651      Debug.Assert(list.Any(t => !t.Done));
     652
     653      var idx = rand.Next(list.Count);
     654      while(list[idx].Done) { idx = rand.Next(list.Count); }
     655      return idx;
     656    }
     657
    661658    // backpropagate existing statistics to all parents
    662     private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) {
    663       tree.actionStatistics.Add(stats);
     659    private static void BackpropagateStatistics(Tree tree, State state, int numVisits) {
     660      tree.visits += numVisits;
     661
    664662      if (state.parents.ContainsKey(tree)) {
    665663        foreach (var parent in state.parents[tree]) {
    666           BackpropagateStatistics(stats, parent, state);
     664          BackpropagateStatistics(parent, state, numVisits);
    667665        }
    668666      }
     
    676674    }
    677675
    678     private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) {
    679       policy.Update(tree.actionStatistics, q);
     676    private static void BackpropagateQuality(Tree tree, double q, State state) {
     677      tree.visits++;
     678      // TODO: q is ignored for now
    680679
    681680      if (state.parents.ContainsKey(tree)) {
    682681        foreach (var parent in state.parents[tree]) {
    683           BackpropagateQuality(parent, q, policy, state);
     682          BackpropagateQuality(parent, q, state);
    684683        }
    685684      }
     
    718717      }
    719718      return children[selectedChildIdx];
    720     }
    721 
    722     // tree search might fail because of constraints for expressions
    723     // in this case we get stuck we just restart
    724     // see ConstraintHandler.cs for more info
    725     private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
    726       out double q) {
    727       Tree selectedChild = null;
    728       Contract.Assert(tree.state == automaton.CurrentState);
    729       Contract.Assert(!tree.Done);
    730       if (tree.children == null) {
    731         if (automaton.IsFinalState(tree.state)) {
    732           // final state
    733           tree.Done = true;
    734 
    735           // EVALUATE
    736           byte[] code; int nParams;
    737           automaton.GetCode(out code, out nParams);
    738           q = eval(code, nParams);
    739 
    740           treePolicy.Update(tree.actionStatistics, q);
    741           return true; // we reached a final state
    742         } else {
    743           // EXPAND
    744           int[] possibleFollowStates = new int[1000];
    745           int nFs;
    746           automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);
    747           if (nFs == 0) {
    748             // stuck in a dead end (no final state and no allowed follow states)
    749             q = 0;
    750             tree.Done = true;
    751             tree.children = null;
    752             return false;
    753           }
    754           tree.children = new Tree[nFs];
    755           for (int i = 0; i < tree.children.Length; i++)
    756             tree.children[i] = new Tree() {
    757               children = null,
    758               state = possibleFollowStates[i],
    759               actionStatistics = treePolicy.CreateActionStatistics()
    760             };
    761 
    762           selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
    763         }
    764       } else {
    765         // tree.children != null
    766         // UCT selection within tree
    767         int selectedIdx = 0;
    768         if (tree.children.Length > 1) {
    769           selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
    770         }
    771         selectedChild = tree.children[selectedIdx];
    772       }
    773       // make selected step and recurse
    774       automaton.Goto(selectedChild.state);
    775       var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
    776       if (success) {
    777         // only update if successful
    778         treePolicy.Update(tree.actionStatistics, q);
    779       }
    780 
    781       tree.Done = tree.children.All(ch => ch.Done);
    782       if (tree.Done) {
    783         tree.children = null; // cut off the sub-branch if it has been fully explored
    784       }
    785       return success;
    786     }
    787 
    788     private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
    789       // if one of the new children leads to a final state then go there
    790       // otherwise choose a random child
    791       int selectedChildIdx = -1;
    792       // find first final state if there is one
    793       for (int i = 0; i < tree.children.Length; i++) {
    794         if (automaton.IsFinalState(tree.children[i].state)) {
    795           selectedChildIdx = i;
    796           break;
    797         }
    798       }
    799       // no final state -> select a the first child
    800       if (selectedChildIdx == -1) {
    801         selectedChildIdx = 0;
    802       }
    803       return tree.children[selectedChildIdx];
    804     }
     719    }                                           
    805720
    806721    // scales data and extracts values from dataset into arrays
     
    869784      automaton.GetCode(out code, out nParams);
    870785      return Disassembler.CodeToString(code);
    871     }
    872 
    873 
    874     private static string WriteStatistics(Tree tree, State state) {
    875       var sb = new System.IO.StringWriter();
    876       sb.Write("{0}\t{1:N5}\t", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
    877       if (state.children.ContainsKey(tree)) {
    878         foreach (var ch in state.children[tree]) {
    879           sb.Write("{0}\t{1:N5}\t", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
    880         }
    881       }
    882       sb.WriteLine();
    883       return sb.ToString();
    884786    }
    885787
     
    899801
    900802    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
    901       var avgNodeQ = tree.actionStatistics.AverageQuality;
    902       var tries = tree.actionStatistics.Tries;
    903       if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    904       var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    905       hue = 0.0;
    906 
    907       sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();
     803      var tries = tree.visits;
     804
     805      sb.AppendFormat("{0} [label=\"{1}\"]; ", parentId, tries).AppendLine();
    908806
    909807      var list = new List<Tuple<int, int, Tree>>();
     
    911809        foreach (var ch in state.children[tree]) {
    912810          nextId++;
    913           avgNodeQ = ch.actionStatistics.AverageQuality;
    914           tries = ch.actionStatistics.Tries;
    915           if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    916           hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    917           hue = 0.0;
    918           sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
    919           sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine();
     811          tries = ch.visits;
     812          sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
     813          sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", parentId, nextId, ch.expr).AppendLine();
    920814          list.Add(Tuple.Create(tries, nextId, ch));
    921815        }
     
    927821            var chch = state.children[ch].First();
    928822            nextId++;
    929             avgNodeQ = chch.actionStatistics.AverageQuality;
    930             tries = chch.actionStatistics.Tries;
    931             if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    932             hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    933             hue = 0.0;
    934             sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
    935             sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine();
     823            tries = chch.visits;
     824            sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
     825            sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", chId, nextId, chch.expr).AppendLine();
    936826          }
    937827        }
     
    957847        if (!nodeIds.TryGetValue(parent, out parentId)) {
    958848          parentId = nodeIds.Count + 1;
    959           var avgNodeQ = parent.actionStatistics.AverageQuality;
    960           var tries = parent.actionStatistics.Tries;
    961           if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    962           var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    963           hue = 0.0;
    964           if (parent.actionStatistics.Tries > threshold)
    965             sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
     849          var tries = parent.visits;
     850          if (tries > threshold)
     851            sb.Write("{0} [label=\"{1}\"]; ", parentId, tries);
    966852          nodeIds.Add(parent, parentId);
    967853        }
     
    972858            nodeIds.Add(child, childId);
    973859          }
    974           var avgNodeQ = child.actionStatistics.AverageQuality;
    975           var tries = child.actionStatistics.Tries;
     860          var tries = child.visits;
    976861          if (tries < 1) continue;
    977           if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    978           var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    979           hue = 0.0;
    980862          if (tries > threshold) {
    981             sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
     863            sb.Write("{0} [label=\"{1}\"]; ", childId, tries);
    982864            var edgeLabel = child.expr;
    983865            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
    984             sb.Write("{0} -> {1} [label=\"{3}\"]", parentId, childId, avgNodeQ, edgeLabel);
     866            sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel);
    985867          }
    986868        }
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Tree.cs

    r15414 r15438  
    1919 */
    2020#endregion
    21 
    22 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    23 
     21                                                     
    2422namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    25   // represents tree nodes for the search tree in MCTS
    2623  internal class Tree {
    2724    public int state;
    2825    public int level;
    2926    public string expr;
    30     public bool Done {
    31       get { return actionStatistics.Done; }
    32       set { actionStatistics.Done = value; }
    33     }
    34     public IActionStatistics actionStatistics;
    35     public Tree[] children;
     27    public bool Done { get; set; }
     28    public int visits;
     29    //   {
     30    //   get { return actionStatistics.Done; }
     31    //   set { actionStatistics.Done = value; }
     32    // }
     33    // public IActionStatistics actionStatistics;
     34    // public Tree[] children;
    3635  }
    3736}
  • branches/MCTS-SymbReg-2796/Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/MctsSymbolicRegressionTest.cs

    r15437 r15438  
    33using System.Linq;
    44using System.Threading;
    5 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    65using HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg;
    76using HeuristicLab.Data;
     
    260259          }
    261260
    262           Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, c, z) > 0.05);
    263           Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z, z) < 0.05);
     261          Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, c, t) > 0.05);
     262          Assert.IsTrue(Heuristics.CorrelationForInteraction(x, y, z, t) < 0.05);
    264263
    265264          /* we might see correlations when only using one of the two relevant factors.
     
    271270          Assert.IsTrue(Heuristics.CorrelationForInteraction(b, y, z) < 0.05);
    272271          */
    273           Console.WriteLine("a,b: {0:N3}\tx,y: {1:N3}\ta,x: {2:N3}\tb,x: {3:N3}\ta,y: {4:N3}\tb,y: {5:N3}\tcov(a,b): {6:N3}",
    274             Heuristics.CorrelationForInteraction(a, b, z),
    275             Heuristics.CorrelationForInteraction(x, y, z),
    276             Heuristics.CorrelationForInteraction(a, x, z),
    277             Heuristics.CorrelationForInteraction(b, x, z),
    278             Heuristics.CorrelationForInteraction(a, y, z),
    279             Heuristics.CorrelationForInteraction(b, y, z),
    280             alglib.cov2(a, b)
     272          Console.WriteLine("a,b,c: {0:N3}\tx,y,z: {1:N3}\ta,b,x: {2:N3}\tb,c,x: {3:N3}",
     273            Heuristics.CorrelationForInteraction(a, b, c, t),
     274            Heuristics.CorrelationForInteraction(x, y, z, t),
     275            Heuristics.CorrelationForInteraction(a, b, x, t),
     276            Heuristics.CorrelationForInteraction(b, c, x, t)
    281277            );
    282278        }
     279      }
     280    }
     281
     282    [TestMethod]
     283    [TestCategory("Algorithms.DataAnalysis")]
     284    [TestProperty("Time", "short")]
     285    public void TestPoly10Interactions() {
     286      {
     287        alglib.hqrndstate randState;
     288        alglib.hqrndseed(1234, 31415, out randState);
     289
     290        int N = 25000; // large sample size to make sure the test thresholds hold
     291        double[] a = new double[N];
     292        double[] b = new double[N];
     293        double[] c = new double[N];
     294        double[] d = new double[N];
     295        double[] e = new double[N];
     296        double[] f = new double[N];
     297        double[] g = new double[N];
     298        double[] h = new double[N];
     299        double[] i = new double[N];
     300        double[] j = new double[N];
     301        double[] y = new double[N];
     302
     303        for(int k=0;k<N;k++) {
     304          a[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     305          b[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     306          c[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     307          d[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     308          e[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     309          f[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     310          g[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     311          h[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     312          i[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     313          j[k] = alglib.hqrnduniformr(randState) * 2 - 1;
     314          y[k] = a[k] * b[k] + c[k] * d[k] + e[k] * f[k] + a[k] * g[k] * i[k] + c[k] * f[k] * j[k];
     315        }
     316
     317        var x = new[] { a, b, c, d, e, f, g, h, i, j };
     318        var all2Combinations = HeuristicLab.Common.EnumerableExtensions.Combinations(new[] {1,2,3,4,5,6,7,8,9,10}, 2);
     319
     320        var resultList = new List<Tuple<string, double>>();
     321        foreach(var entry in all2Combinations) {
     322          var aIdx = entry.First();
     323          var bIdx = entry.Skip(1).First();
     324          resultList.Add(Tuple.Create(aIdx + " " + bIdx, Heuristics.CorrelationForInteraction(x[aIdx - 1], x[bIdx - 1], y)));
     325        }
     326
     327        foreach(var entry in resultList.OrderByDescending(t => t.Item2)) {
     328          Console.WriteLine("{0} {1:N3}", entry.Item1, entry.Item2);
     329        }
     330
     331        var all3Combinations = HeuristicLab.Common.EnumerableExtensions.Combinations(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, 3);
     332
     333        resultList = new List<Tuple<string, double>>();
     334        foreach (var entry in all3Combinations) {
     335          var aIdx = entry.First();
     336          var bIdx = entry.Skip(1).First();
     337          var cIdx = entry.Skip(2).First();
     338          resultList.Add(Tuple.Create(aIdx + " " + bIdx + " " + cIdx, Heuristics.CorrelationForInteraction(x[aIdx - 1], x[bIdx - 1], x[cIdx - 1], y)));
     339        }
     340
     341        //  Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10
     342
     343        foreach (var entry in resultList.OrderByDescending(t => t.Item2)) {
     344          Console.WriteLine("{0} {1:N3}", entry.Item1, entry.Item2);
     345        }
     346
     347
     348        Assert.IsTrue(Heuristics.CorrelationForInteraction(a, b, y) > 0.01);
     349        Assert.IsTrue(Heuristics.CorrelationForInteraction(b, a, y) > 0.01);
     350        Assert.IsTrue(Heuristics.CorrelationForInteraction(c, d, y) > 0.01);
     351        Assert.IsTrue(Heuristics.CorrelationForInteraction(d, c, y) > 0.01);
     352        Assert.IsTrue(Heuristics.CorrelationForInteraction(e, f, y) > 0.01);
     353        Assert.IsTrue(Heuristics.CorrelationForInteraction(f, e, y) > 0.01);
     354        Assert.IsTrue(Heuristics.CorrelationForInteraction(a, g, i, y) > 0.01);
     355        Assert.IsTrue(Heuristics.CorrelationForInteraction(a, i, g, y) > 0.01);
     356        Assert.IsTrue(Heuristics.CorrelationForInteraction(g, a, i, y) > 0.01);
     357        Assert.IsTrue(Heuristics.CorrelationForInteraction(g, i, a, y) > 0.01);
     358        Assert.IsTrue(Heuristics.CorrelationForInteraction(i, g, a, y) > 0.01);
     359        Assert.IsTrue(Heuristics.CorrelationForInteraction(i, a, g, y) > 0.01);
     360
     361        Assert.IsTrue(Heuristics.CorrelationForInteraction(c, f, j, y) > 0.01);
     362        Assert.IsTrue(Heuristics.CorrelationForInteraction(c, j, f, y) > 0.01);
     363        Assert.IsTrue(Heuristics.CorrelationForInteraction(f, c, j, y) > 0.01);
     364        Assert.IsTrue(Heuristics.CorrelationForInteraction(f, j, c, y) > 0.01);
     365        Assert.IsTrue(Heuristics.CorrelationForInteraction(j, c, f, y) > 0.01);
     366        Assert.IsTrue(Heuristics.CorrelationForInteraction(j, f, c, y) > 0.01);
    283367      }
    284368    }
Note: See TracChangeset for help on using the changeset viewer.