Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/06/17 17:52:36 (7 years ago)
Author:
gkronber
Message:

#2796 worked on MCTS (removing constraint handling and introducing state unification instead)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r15404 r15410  
    2222using System;
    2323using System.Collections.Generic;
     24using System.Diagnostics;
    2425using System.Diagnostics.Contracts;
    2526using System.Linq;
     
    5455    //   
    5556
     57    // TODO: Taking averages of R² values is probably not ideal as an improvement of R² from 0.99 to 0.999 should
     58    //       weight more than an improvement from 0.98 to 0.99. Also, we are more interested in the best value of a
     59    //       branch and less in the expected value. (--> Review "Extreme Bandit" literature again)
    5660    // TODO: Constraint handling is too restrictive!  E.g. for Poly-10, if MCTS identifies the term x3*x4 first it is
    5761    //       not possible to add the term x1*x2 later on. The same is true for individual terms after x2 it is not
    5862    //       possible to multiply x1. It is easy to get stuck. Why do we actually need the current way of constraint handling?
    5963    //       It would probably be easier to use some kind of hashing to identify equivalent expressions in the tree.
     64    // TODO: State unification (using hashing) is partially done. The hashcode calculation should be improved to also detect that
     65    //       c*x1 + c*x1*x1 + c*x1 is the same as c*x1 + c*x1*x1
     66    // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
    6067    // TODO: check if transformation of y is correct and works (Obj 2)
    6168    // TODO: The algorithm is not invariant to location and scale of variables.
     
    172179        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);
    173180
    174         this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
     181        this.automaton = new Automaton(x, new SimpleConstraintHandler(100), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    175182        this.treePolicy = treePolicy ?? new Ucb();
    176         this.tree = new Tree() { state = automaton.CurrentState, actionStatistics = treePolicy.CreateActionStatistics() };
     183        this.tree = new Tree() {
     184          state = automaton.CurrentState,
     185          actionStatistics = treePolicy.CreateActionStatistics(),
     186          expr = ""
     187        };
    177188
    178189        // reset best solution
     
    481492      do {
    482493        automaton.Reset();
    483         success = TryTreeSearchRec(rand, tree, automaton, eval, treePolicy, out q);
     494        success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, out q);
    484495        mctsState.totalRollouts++;
    485496      } while (!success && !tree.Done);
    486497      mctsState.effectiveRollouts++;
     498
     499      if (mctsState.effectiveRollouts % 10 == 1) Console.WriteLine(WriteTree(tree));
    487500      return q;
     501    }
     502
     503    private static Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
     504    private static Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
     505    private static Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
     506
     507
     508
     509    // search forward
     510    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
     511      out double q) {
     512      // ROLLOUT AND EXPANSION
     513      // We are navigating a graph (states might be reached via different paths) instead of a tree.
     514      // State equivalence is checked through ExprHash (based on the generated code through the path).
     515
     516      // We switch between rollout-mode and expansion mode
     517      // Rollout-mode means we are navigating an existing path through the tree (using a rollout policy, e.g. UCB)
     518      // Expansion mode means we expand the graph, creating new nodes and edges (using an expansion policy, e.g. shortest route to a complete expression)
     519      // In expansion mode we might re-enter the graph and switch back to rollout-mode
     520      // We do this until we reach a complete expression (final state)
     521
     522      // Loops in the graph are possible! (Problem?)
     523      // Sub-graphs which have been completely searched are marked as done.
     524      // Roll-out could lead to a state where all follow-states are done. In this case we call the rollout ineffective.
     525
     526      while (!automaton.IsFinalState(automaton.CurrentState)) {
     527        if (children.ContainsKey(tree)) {
     528          // ROLLOUT INSIDE TREE
     529          // UCT selection within tree
     530          int selectedIdx = 0;
     531          if (children[tree].Count > 1) {
     532            selectedIdx = treePolicy.Select(children[tree].Select(ch => ch.actionStatistics), rand);
     533          }
     534          tree = children[tree][selectedIdx];
     535
     536          // move the automaton forward until reaching the state
     537          // all steps where no alternatives are possible are immediately taken
     538          // TODO: simplification of the automaton
     539          int[] possibleFollowStates;
     540          int nFs;
     541          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
     542          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0])) {
     543            automaton.Goto(possibleFollowStates[0]);
     544            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
     545          }
     546          Debug.Assert(possibleFollowStates.Contains(tree.state));
     547          automaton.Goto(tree.state);
     548        } else {
     549          // EXPAND
     550          int[] possibleFollowStates;
     551          int nFs;
     552          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
     553          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0])) {
     554            // no alternatives -> just go to the next state
     555            automaton.Goto(possibleFollowStates[0]);
     556            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
     557          }
     558          if (nFs == 0) {
     559            // stuck in a dead end (no final state and no allowed follow states)
     560            tree.Done = true;
     561            break;
     562          }
     563          var newChildren = new List<Tree>(nFs);
     564          children.Add(tree, newChildren);
     565          for (int i = 0; i < nFs; i++) {
     566            Tree child = null;
     567            // for selected states we introduce state unification (detection of equivalent states)
     568            if (automaton.IsEvalState(possibleFollowStates[i])) {
     569              var hc = Hashcode(automaton);
     570              if (!nodes.TryGetValue(hc, out child)) {
     571                child = new Tree() {
     572                  children = null,
     573                  state = possibleFollowStates[i],
     574                  actionStatistics = treePolicy.CreateActionStatistics(),
     575                  expr = ExprStr(automaton)
     576                };
     577                nodes.Add(hc, child);
     578              } else {
     579                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
     580                // to all parents
     581                BackpropagateStatistics(child.actionStatistics, tree);
     582              }
     583            } else {
     584              child = new Tree() {
     585                children = null,
     586                state = possibleFollowStates[i],
     587                actionStatistics = treePolicy.CreateActionStatistics(),
     588                expr = ExprStr(automaton)
     589              };
     590            }
     591            newChildren.Add(child);
     592          }
     593
     594          foreach (var ch in newChildren) {
     595            if (!parents.ContainsKey(ch)) {
     596              parents.Add(ch, new List<Tree>());
     597            }
     598            parents[ch].Add(tree);
     599          }
     600
     601          // follow one of the children
     602          tree = SelectFinalOrRandom2(automaton, tree, rand);
     603          automaton.Goto(tree.state);
     604        }
     605      }
     606
     607      bool success;
     608
     609      // EVALUATE TREE
     610      if (automaton.IsFinalState(automaton.CurrentState)) {
     611        tree.Done = true;
     612        byte[] code; int nParams;
     613        automaton.GetCode(out code, out nParams);
     614        q = eval(code, nParams);
     615        q = TransformQuality(q);
     616        success = true;
     617      } else {
     618        // we got stuck in roll-out (not evaluation necessary!)
     619        q = 0.0;
     620        success = false;
     621      }
     622
     623      // RECURSIVELY BACKPROPAGATE RESULTS TO ALL PARENTS
     624      // Update statistics
     625      // Set branch to done if all children are done.
     626      BackpropagateQuality(tree, q, treePolicy);
     627
     628      return success;
     629    }
     630
     631
     632    private static double TransformQuality(double q) {
     633      // no transformation
     634      return q;
     635
     636      // EXPERIMENTAL!
     637      // optimal result: q = 1 -> return huge value
     638      if (q >= 1.0) return 1E16;
     639      // return number of 9s in R²
     640      return -Math.Log10(1 - q);
     641    }
     642
     643    // backpropagate existing statistics to all parents
     644    private static void BackpropagateStatistics(IActionStatistics stats, Tree tree) {
     645      tree.actionStatistics.Add(stats);
     646      if (parents.ContainsKey(tree)) {
     647        foreach (var parent in parents[tree]) {
     648          BackpropagateStatistics(stats, parent);
     649        }
     650      }
     651    }
     652
     653    private static ulong Hashcode(Automaton automaton) {
     654      byte[] code;
     655      int nParams;
     656      automaton.GetCode(out code, out nParams);
     657      return ExprHash.GetHash(code, nParams);
     658    }
     659
     660    private static void BackpropagateQuality(Tree tree, double q, IPolicy policy) {
     661      if (q > 0) policy.Update(tree.actionStatistics, q);
     662      if (children.ContainsKey(tree) && children[tree].All(ch => ch.Done)) {
     663        tree.Done = true;
     664        // children[tree] = null; keep all nodes
     665      }
     666
     667      if (parents.ContainsKey(tree)) {
     668        foreach (var parent in parents[tree]) {
     669          BackpropagateQuality(parent, q, policy);
     670        }
     671      }
     672    }
     673
     674    private static Tree SelectFinalOrRandom2(Automaton automaton, Tree tree, IRandom rand) {
     675      // if one of the new children leads to a final state then go there
     676      // otherwise choose a random child
     677      int selectedChildIdx = -1;
     678      // find first final state if there is one
     679      var children = MctsSymbolicRegressionStatic.children[tree];
     680      for (int i = 0; i < children.Count; i++) {
     681        if (automaton.IsFinalState(children[i].state)) {
     682          selectedChildIdx = i;
     683          break;
     684        }
     685      }
     686      // no final state -> select the first child
     687      if (selectedChildIdx == -1) {
     688        selectedChildIdx = 0;
     689      }
     690      return children[selectedChildIdx];
    488691    }
    489692
     
    522725          tree.children = new Tree[nFs];
    523726          for (int i = 0; i < tree.children.Length; i++)
    524             tree.children[i] = new Tree() { children = null, state = possibleFollowStates[i], actionStatistics = treePolicy.CreateActionStatistics() };
     727            tree.children[i] = new Tree() {
     728              children = null,
     729              state = possibleFollowStates[i],
     730              actionStatistics = treePolicy.CreateActionStatistics()
     731            };
    525732
    526733          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
     
    624831      }
    625832    }
     833
     834    // for debugging only
     835
     836
     837    private static string ExprStr(Automaton automaton) {
     838      byte[] code;
     839      int nParams;
     840      automaton.GetCode(out code, out nParams);
     841      return Disassembler.CodeToString(code);
     842    }
     843
     844    private static string WriteStatistics(Tree tree) {
     845      var sb = new System.IO.StringWriter();
     846      sb.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
     847      if (children.ContainsKey(tree)) {
     848        foreach (var ch in children[tree]) {
     849          sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
     850        }
     851      }
     852      return sb.ToString();
     853    }
     854    private static string WriteTree(Tree tree) {
     855      var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
     856      var nodeIds = new Dictionary<Tree, int>();
     857      sb.Write(
     858@"digraph {
     859  ratio = fill;
     860  node [style=filled];
     861");
     862      foreach(var kvp in children) {
     863        var parent = kvp.Key;
     864        int parentId;
     865        if(!nodeIds.TryGetValue(parent, out parentId)) {
     866          parentId = nodeIds.Count + 1;
     867          var avgNodeQ = parent.actionStatistics.AverageQuality;
     868          var tries = parent.actionStatistics.Tries;
     869          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
     870          var hue = (1 - avgNodeQ) / 255.0 * 240.0; // 0 equals red, 240 equals blue
     871          sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
     872          nodeIds.Add(parent, parentId);
     873        }
     874        foreach(var child in kvp.Value) {
     875          int childId;
     876          if(!nodeIds.TryGetValue(child, out childId)) {
     877            childId = nodeIds.Count + 1;
     878            nodeIds.Add(child, childId);
     879          }
     880          var avgNodeQ = child.actionStatistics.AverageQuality;
     881          var tries = child.actionStatistics.Tries;
     882          if (tries < 1) continue;
     883          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
     884          var hue = (1 - avgNodeQ) / 255.0 * 240.0; // 0 equals red, 240 equals blue
     885          sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
     886          var edgeLabel = child.expr;
     887          if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
     888          sb.Write("{0} -> {1} [label=\"{3}\"]", parentId, childId, avgNodeQ, edgeLabel);
     889        }
     890      }
     891
     892      sb.Write("}");
     893      return sb.ToString();
     894    }
    626895  }
    627896}
Note: See TracChangeset for help on using the changeset viewer.