Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/09/17 18:53:34 (7 years ago)
Author:
gkronber
Message:

#2796 worked on MCTS

Location:
branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression
Files:
1 added
7 edited

Legend:

Unmodified
Added
Removed
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs

    r15410 r15414  
    398 398        v == StateLogTFEnd ||
    399 399        v == StateInvTFEnd ||
    400         v == StateExpFEnd;
     400        v == StateExpFEnd
     401        ;
    401 402    }
    402 403
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Disassembler.cs

    r15410 r15414  
    5858        switch (op) {
    5959          case (byte)OpCodes.Add: sb.Append(" + "); break;
    60           case (byte)OpCodes.Mul: sb.Append(" * "); break;
    61           case (byte)OpCodes.LoadConst1: sb.Append(" 1 "); break;
    62           case (byte)OpCodes.LoadConst0: sb.Append(" 0 "); break;
    63           case (byte)OpCodes.LoadParamN: sb.AppendFormat(" c "); break;
     60          case (byte)OpCodes.Mul: sb.Append(""); break;
     61          case (byte)OpCodes.LoadConst1: break;
     62          case (byte)OpCodes.LoadConst0: break;
     63          case (byte)OpCodes.LoadParamN: break;
    6464          case (byte)OpCodes.LoadVar: {
    6565              short arg = (short)((code[pc] << 8) | code[pc + 1]);
    6666              pc += 2;
    67               sb.AppendFormat(" var{0} ", arg); break;
     67              sb.AppendFormat("{0}", (char)('a'+arg)); break;
    6868            }
    6969          case (byte)OpCodes.Exp: sb.Append(" exp "); break;
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/ExprHash.cs

    r15404 r15414  
    2020#endregion
    2121using System;
     22using System.Collections.Generic;
    2223using System.Diagnostics.Contracts;
    2324using HeuristicLab.Random;
     25using System.Linq;
    2426
    2527namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
     
    4042  // We only need to identify equivalent structures. The numeric constants are irrelevant.
    4143  // Therefore, equivalent structures with different numeric constants map to the same hash code.
    42  
     44
    4345  public static class ExprHash {
    4446    const int MaxStackSize = 100;
    4547    const int MaxVariables = 1000;
    46     private static double[] varSymbValues; 
     48    private static double[] varSymbValues;
    4749    static ExprHash() {
    4850      var rand = new MersenneTwister();
    4951      varSymbValues = new double[MaxVariables];
    50       for(int i = 0;i < MaxVariables;i++) {
     52      for (int i = 0; i < MaxVariables; i++) {
    5153        varSymbValues[i] = rand.NextDouble();
    5254      }
     
    5456
    5557    public static ulong GetHash(byte[] code, int nParams) {
    56       var bits = (ulong)BitConverter.DoubleToInt64Bits(Exec(code, nParams));
     58      var bits = (ulong)BitConverter.DoubleToInt64Bits(Eval(code, nParams));
    5759      // clear last five bits (insignificant?)
    5860      bits = bits & 0xFFFFFFFFFFFFFFE0;
     
    6062    }
    6163
    62     private static double Exec(byte[] code, int nParams) {
    63       var stack = new double[MaxStackSize];
     64    private static double Eval(byte[] code, int nParams) {
     65      // The hash code calculation already preserves commutativity, associativity and distributivity of operations.
     66      // However, we also need to hash c1*x1 + c2*x1 to the same value as c3*x1.
     67      // Similarly for x1*x2 + x1*x2 or log(x1) + log(x1)!
     68
     69      // Calculate sums lazily. Keep all terms and only when the actual sum is necessary remove duplicate terms and calculate sum
     70      // think about speed later (TODO)
     71
     72      var stack = new ISet<double>[MaxStackSize];
     73      var terms = new HashSet<double>(new ApproximateDoubleEqualityComparer()); // the set of arguments for the current operator (+, *)
    6474      int topOfStack = -1;
    6575      int pc = 0;
     
    7383          case OpCodes.LoadConst0: {
    7484              ++topOfStack;
    75               stack[topOfStack] = 0.0;
     85              stack[topOfStack] = new HashSet<double>( new[] { 0.0 });
     86
     87              // terms.Add(0.0); // ignore numeric constants in expr-hash
     88
    7689              break;
    7790            }
    7891          case OpCodes.LoadConst1: {
    7992              ++topOfStack;
    80               stack[topOfStack] = 1.0;
     93              stack[topOfStack] = new HashSet<double>(new[] { 1.0 });
     94
     95              // args.Add(1.0); ignore numeric constants in expr-hash
     96
    8197              break;
    8298            }
    8399          case OpCodes.LoadParamN: {
    84100              ++topOfStack;
    85               stack[topOfStack] = 1.0;
     101              stack[topOfStack] =  new HashSet<double>(new[] { 1.0 });
    86102              break;
    87103            }
    88104          case OpCodes.LoadVar: {
    89105              ++topOfStack;
    90               stack[topOfStack] = varSymbValues[arg];
     106              stack[topOfStack] = new HashSet<double>(new[] { varSymbValues[arg] });
     107
     108              // args.Add(varSymbValues[arg]);
     109
    91110              break;
    92111            }
    93112          case OpCodes.Add: {
    94               var t1 = stack[topOfStack - 1];
    95               var t2 = stack[topOfStack];
     113              // take arguments from stack and put both terms into the set of terms
     114              // for every other operation we need to evaluate the sum of terms first and put it onto the stack (lazy eval of sums)
     115
     116              stack[topOfStack - 1].UnionWith(stack[topOfStack]);
    96117              topOfStack--;
    97               stack[topOfStack] = t1 + t2;
     118
     119              // stack[topOfStack] = t1 + t2; (later)
    98120              break;
    99121            }
     
    102124              var t2 = stack[topOfStack];
    103125              topOfStack--;
    104               stack[topOfStack] = t1 * t2;
     126              stack[topOfStack] = new HashSet<double>(new double[] { t1.Sum() * t2.Sum() });
    105127              break;
    106128            }
    107129          case OpCodes.Log: {
    108130              var v1 = stack[topOfStack];
    109               stack[topOfStack] = Math.Log(v1);
     131              stack[topOfStack] = new HashSet<double>(new double[] { Math.Log( v1.Sum()) });
    110132              break;
    111133            }
    112134          case OpCodes.Exp: {
    113135              var v1 = stack[topOfStack];
    114               stack[topOfStack] = Math.Exp(v1);
     136              stack[topOfStack] = new HashSet<double>(new double[] { Math.Exp(v1.Sum()) });
    115137              break;
    116138            }
    117139          case OpCodes.Inv: {
    118140              var v1 = stack[topOfStack];
    119               stack[topOfStack] = 1.0 / v1;
     141              stack[topOfStack] = new HashSet<double>(new double[] { 1.0 / v1.Sum() });
    120142              break;
    121143            }
    122144          case OpCodes.Exit:
    123145            Contract.Assert(topOfStack == 0);
    124             return stack[topOfStack];
     146            return stack[topOfStack].Sum();
    125147        }
    126148      }
     149    }
     150
     151    private static void EvalTerms(HashSet<double> terms, double[] stack, ref int topOfStack) {
     152      ++topOfStack;
     153      stack[topOfStack] = terms.Sum();
     154      terms.Clear();
    127155    }
    128156
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r15410 r15414  
    2525using System.Diagnostics.Contracts;
    2626using System.Linq;
     27using System.Text;
    2728using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    2829using HeuristicLab.Core;
     
    5859    //       weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
    5960    //       branch and less in the expected value. (--> Review "Extreme Bandit" literature again)
    60     // TODO: Constraint handling is too restrictive!  E.g. for Poly-10, if MCTS identifies the term x3*x4 first it is
    61     //       not possible to add the term x1*x2 later on. The same is true for individual terms after x2 it is not
    62     //       possible to multiply x1. It is easy to get stuck. Why do we actually need the current way of constraint handling?
    63     //       It would probably be easier to use some kind of hashing to identify equivalent expressions in the tree.
    64     // TODO: State unification (using hashing) is partially done. The hashcode calculation should be improved to also detect that
    65     //       c*x1 + c*x1*x1 + c*x1 is the same as c*x1 + c*x1*x1
     61    // TODO: Solve Poly-10
    6662    // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
     63    // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
    6764    // TODO: check if transformation of y is correct and works (Obj 2)
    6865    // TODO: The algorithm is not invariant to location and scale of variables.
     
    7168    // TODO: support e(-x) and possibly (1/-x) (Obj 1)
    7269    // TODO: is it OK to initialize all constants to 1 (Obj 2)?
     70    // TODO: improve memory usage
    7371    #region static API
    7472
     
    179177        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);
    180178
    181         this.automaton = new Automaton(x, new SimpleConstraintHandler(100), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
     179        this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    182180        this.treePolicy = treePolicy ?? new Ucb();
    183181        this.tree = new Tree() {
    184182          state = automaton.CurrentState,
    185183          actionStatistics = treePolicy.CreateActionStatistics(),
    186           expr = ""
     184          expr = "",
     185          level = 0
    187186        };
    188187
     
    497496      mctsState.effectiveRollouts++;
    498497
    499       if (mctsState.effectiveRollouts % 10 == 1) Console.WriteLine(WriteTree(tree));
     498      if (mctsState.effectiveRollouts % 10 == 1) {
     499        //Console.WriteLine(WriteTree(tree));
     500        //Console.WriteLine(TraceTree(tree));
     501      }
    500502      return q;
    501503    }
     
    520522      // We do this until we reach a complete expression (final state)
    521523
    522       // Loops in the graph are possible! (Problem?)
     524      // Loops in the graph are prevented by checking that the level of a child must be larger than the level of the parent
    523525      // Sub-graphs which have been completely searched are marked as done.
    524526      // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.
     
    526528      while (!automaton.IsFinalState(automaton.CurrentState)) {
    527529        if (children.ContainsKey(tree)) {
     530          if (children[tree].All(ch => ch.Done)) {
     531            tree.Done = true;
     532            break;
     533          }
    528534          // ROLLOUT INSIDE TREE
    529535          // UCT selection within tree
     
    540546          int nFs;
    541547          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    542           while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0])) {
     548          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
    543549            automaton.Goto(possibleFollowStates[0]);
    544550            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
     
    551557          int nFs;
    552558          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    553           while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0])) {
     559          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
    554560            // no alternatives -> just go to the next state
    555561            automaton.Goto(possibleFollowStates[0]);
     
    565571          for (int i = 0; i < nFs; i++) {
    566572            Tree child = null;
    567             // for selected states we introduce state unification (detection of equivalent states)
     573            // for selected states (EvalStates) we introduce state unification (detection of equivalent states)
    568574            if (automaton.IsEvalState(possibleFollowStates[i])) {
    569575              var hc = Hashcode(automaton);
     
    573579                  state = possibleFollowStates[i],
    574580                  actionStatistics = treePolicy.CreateActionStatistics(),
    575                   expr = ExprStr(automaton)
     581                  expr = string.Empty, // ExprStr(automaton),
     582                  level = tree.level + 1
    576583                };
    577584                nodes.Add(hc, child);
    578               } else {
     585              }
     586              // only allow forward edges (don't add the child if we would go back in the graph)
     587              else if (child.level > tree.level) {
    579588                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
    580589                // to all parents
    581590                BackpropagateStatistics(child.actionStatistics, tree);
     591              } else {
     592                // prevent cycles
     593                Debug.Assert(child.level <= tree.level);
     594                child = null;
    582595              }
    583596            } else {
     
    586599                state = possibleFollowStates[i],
    587600                actionStatistics = treePolicy.CreateActionStatistics(),
    588                 expr = ExprStr(automaton)
     601                expr = string.Empty, // ExprStr(automaton),
     602                level = tree.level + 1
    589603              };
    590604            }
    591             newChildren.Add(child);
     605            if (child != null)
     606              newChildren.Add(child);
     607          }
     608
     609          if (!newChildren.Any()) {
     610            // stuck in a dead end (no final state and no allowed follow states)
     611            tree.Done = true;
     612            break;
    592613          }
    593614
     
    599620          }
    600621
     622
    601623          // follow one of the children
    602624          tree = SelectFinalOrRandom2(automaton, tree, rand);
     
    610632      if (automaton.IsFinalState(automaton.CurrentState)) {
    611633        tree.Done = true;
     634        tree.expr = ExprStr(automaton);
    612635        byte[] code; int nParams;
    613636        automaton.GetCode(out code, out nParams);
     
    636659      // EXPERIMENTAL!
    637660      // optimal result: q = 1 -> return huge value
    638       if (q >= 1.0) return 1E16;
    639       // return number of 9s in R²
    640       return -Math.Log10(1 - q);
     661      // if (q >= 1.0) return 1E16;
     662      // // return number of 9s in R²
     663      // return -Math.Log10(1 - q);
    641664    }
    642665
     
    852875      return sb.ToString();
    853876    }
     877
     878    private static string TraceTree(Tree tree) {
     879      var sb = new StringBuilder();
     880      sb.Append(
     881@"digraph {
     882  ratio = fill;
     883  node [style=filled];
     884");
     885      int nodeId = 0;
     886
     887      TraceTreeRec(tree, 0, sb, ref nodeId);
     888      sb.Append("}");
     889      return sb.ToString();
     890    }
     891
     892    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId) {
     893      var avgNodeQ = tree.actionStatistics.AverageQuality;
     894      var tries = tree.actionStatistics.Tries;
     895      if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
     896      var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
     897
     898      sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();
     899
     900      var list = new List<Tuple<int, int, Tree>>();
     901      if (children.ContainsKey(tree)) {
     902        foreach (var ch in children[tree]) {
     903          nextId++;
     904          avgNodeQ = ch.actionStatistics.AverageQuality;
     905          tries = ch.actionStatistics.Tries;
     906          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
     907          hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
     908          sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
     909          sb.AppendFormat("{0} -> {1}", parentId, nextId, avgNodeQ).AppendLine();
     910          list.Add(Tuple.Create(tries, nextId, ch));
     911        }
     912        foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
     913          TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId);
     914        }
     915      }
     916    }
     917
    854918    private static string WriteTree(Tree tree) {
    855919      var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
     
    860924  node [style=filled];
    861925");
    862       foreach(var kvp in children) {
     926      int threshold = nodes.Count > 500 ? 10 : 0;
     927      foreach (var kvp in children) {
    863928        var parent = kvp.Key;
    864929        int parentId;
    865         if(!nodeIds.TryGetValue(parent, out parentId)) {
     930        if (!nodeIds.TryGetValue(parent, out parentId)) {
    866931          parentId = nodeIds.Count + 1;
    867           var avgNodeQ = parent.actionStatistics.AverageQuality; 
     932          var avgNodeQ = parent.actionStatistics.AverageQuality;
    868933          var tries = parent.actionStatistics.Tries;
    869934          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    870           var hue = (1 - avgNodeQ) / 255.0 * 240.0; // 0 equals red, 240 equals blue
    871           sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
     935          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
     936          if (parent.actionStatistics.Tries > threshold)
     937            sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
    872938          nodeIds.Add(parent, parentId);
    873939        }
    874         foreach(var child in kvp.Value) {
     940        foreach (var child in kvp.Value) {
    875941          int childId;
    876           if(!nodeIds.TryGetValue(child, out childId)) {
     942          if (!nodeIds.TryGetValue(child, out childId)) {
    877943            childId = nodeIds.Count + 1;
    878944            nodeIds.Add(child, childId);
     
    882948          if (tries < 1) continue;
    883949          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    884           var hue = (1 - avgNodeQ) / 255.0 * 240.0; // 0 equals red, 240 equals blue
    885           sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
    886           var edgeLabel = child.expr;
    887           if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
    888           sb.Write("{0} -> {1} [label=\"{3}\"]", parentId, childId, avgNodeQ, edgeLabel);
     950          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
     951          if (tries > threshold) {
     952            sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
     953            var edgeLabel = child.expr;
     954            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
     955            sb.Write("{0} -> {1} [label=\"{3}\"]", parentId, childId, avgNodeQ, edgeLabel);
     956          }
    889957        }
    890958      }
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs

    r15410 r15414  
    1 1 using System;
    2 2 using System.Collections.Generic;
     3using System.Diagnostics;
    3 4 using System.Diagnostics.Contracts;
    4 5 using System.Linq;
     
    81 82        return buf[rand.Next(buf.Count)];
    82 83      }
    83       Contract.Assert(totalTries > 0);
     84      Debug.Assert(totalTries > 0);
    84 85      double logTotalTries = Math.Log(totalTries);
    85 86      var bestQ = double.NegativeInfinity;
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/SimpleConstraintHandler.cs

    r15410 r15414  
    3131
    3232    public bool IsAllowedFollowState(int currentState, int followState) {
    33       return numVariables < maxVariables || (
    34         followState != Automaton.StateExpFactorStart &&
    35         followState != Automaton.StateFactorStart &&
    36         followState != Automaton.StateTermStart &&
    37         followState != Automaton.StateLogTStart &&
    38         followState != Automaton.StateLogTFStart &&
    39         followState != Automaton.StateInvTStart &&
    40         followState != Automaton.StateInvTFStart);
     33      return numVariables < maxVariables ||
     34        // going to the final state is always allowed (smaller states are closer to the final state)
     35        currentState > followState;
     36        // (
     37        // followState != Automaton.StateExpFactorStart &&
     38        // followState != Automaton.StateFactorStart &&
     39        // followState != Automaton.StateTermStart &&
     40        // followState != Automaton.StateLogTStart &&
     41        // followState != Automaton.StateLogTFStart &&
     42        // followState != Automaton.StateInvTStart &&
     43        // followState != Automaton.StateInvTFStart);
    4144    }
    4245
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Tree.cs

    r15410 r15414  
    26 26  internal class Tree {
    27 27    public int state;
     28    public int level;
    28 29    public string expr;
    29 30    public bool Done {
Note: See TracChangeset for help on using the changeset viewer.