Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/28/17 19:56:51 (6 years ago)
Author:
gkronber
Message:

#2796 refactoring to simplify the code

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r15437 r15438  
    2626using System.Linq;
    2727using System.Text;
    28 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    2928using HeuristicLab.Core;
    3029using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     
    9998      internal readonly Tree tree;
    10099      internal readonly Func<byte[], int, double> evalFun;
    101       internal readonly IPolicy treePolicy;
    102100      // MCTS might get stuck. Track statistics on the number of effective rollouts
    103101      internal int totalRollouts;
     
    145143      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
    146144        int constOptIterations, double lambda,
    147         IPolicy treePolicy = null,
    148145        bool collectParetoOptimalModels = false,
    149146        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     
    187184        this.testEvaluator = new ExpressionEvaluator(testY.Length, lowerEstimationLimit, upperEstimationLimit);
    188185
    189         this.automaton = new Automaton(x, new SimpleConstraintHandler(maxVariables), allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    190         this.treePolicy = treePolicy ?? new EpsilonGreedy();
     186        this.automaton = new Automaton(x, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms, maxVariables);
    191187        this.tree = new Tree() {
    192188          state = automaton.CurrentState,
    193           actionStatistics = treePolicy.CreateActionStatistics(),
    194189          expr = "",
    195190          level = 0
     
    469464    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
    470465      bool scaleVariables = true, int constOptIterations = -1, double lambda = 0.0,
    471       IPolicy policy = null,
    472466      bool collectParameterOptimalModels = false,
    473467      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     
    479473      ) {
    480474      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, lambda,
    481         policy, collectParameterOptimalModels,
     475        collectParameterOptimalModels,
    482476        lowerEstimationLimit, upperEstimationLimit,
    483477        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
     
    499493      var eval = mctsState.evalFun;
    500494      var rand = mctsState.random;
    501       var treePolicy = mctsState.treePolicy;
    502495      double q = 0;
    503496      bool success = false;
     
    505498
    506499        automaton.Reset();
    507         success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q);
     500        success = TryTreeSearchRec2(rand, tree, automaton, eval, mctsState, out q);
    508501        mctsState.totalRollouts++;
    509502      } while (!success && !tree.Done);
     
    517510
    518511    // search forward
    519     private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
     512    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton,
     513      Func<byte[], int, double> eval,
    520514      State state,
    521515      out double q) {
     
    545539          int selectedIdx = 0;
    546540          if (state.children[tree].Count > 1) {
    547             selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);
     541            selectedIdx = SelectInternal(state.children[tree], rand);
    548542          }
    549543
     
    579573              if (!state.nodes.TryGetValue(hc, out child)) {
    580574                child = new Tree() {
    581                   children = null,
    582575                  state = possibleFollowStates[i],
    583                   actionStatistics = treePolicy.CreateActionStatistics(),
    584576                  expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
    585577                  level = tree.level + 1
     
    591583                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
    592584                // to all parents
    593                 BackpropagateStatistics(child.actionStatistics, tree, state);
     585                BackpropagateStatistics(tree, state, child.visits);
    594586              } else {
    595587                // prevent cycles
     
    599591            } else {
    600592              child = new Tree() {
    601                 children = null,
    602593                state = possibleFollowStates[i],
    603                 actionStatistics = treePolicy.CreateActionStatistics(),
    604594                expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
    605595                level = tree.level + 1
     
    639629        automaton.GetCode(out code, out nParams);
    640630        q = eval(code, nParams);
    641         // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);
    642631        success = true;
    643         BackpropagateQuality(tree, q, treePolicy, state);
     632        BackpropagateQuality(tree, q, state);
    644633      } else {
    645634        // we got stuck in roll-out (no evaluation necessary!)
    646         // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");
    647635        q = 0.0;
    648636        success = false;
     
    659647    }
    660648
     649    private static int SelectInternal(List<Tree> list, IRandom rand) {
     650      // choose a random node.
     651      Debug.Assert(list.Any(t => !t.Done));
     652
     653      var idx = rand.Next(list.Count);
     654      while(list[idx].Done) { idx = rand.Next(list.Count); }
     655      return idx;
     656    }
     657
    661658    // backpropagate existing statistics to all parents
    662     private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) {
    663       tree.actionStatistics.Add(stats);
     659    private static void BackpropagateStatistics(Tree tree, State state, int numVisits) {
     660      tree.visits += numVisits;
     661
    664662      if (state.parents.ContainsKey(tree)) {
    665663        foreach (var parent in state.parents[tree]) {
    666           BackpropagateStatistics(stats, parent, state);
     664          BackpropagateStatistics(parent, state, numVisits);
    667665        }
    668666      }
     
    676674    }
    677675
    678     private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) {
    679       policy.Update(tree.actionStatistics, q);
     676    private static void BackpropagateQuality(Tree tree, double q, State state) {
     677      tree.visits++;
     678      // TODO: q is ignored for now
    680679
    681680      if (state.parents.ContainsKey(tree)) {
    682681        foreach (var parent in state.parents[tree]) {
    683           BackpropagateQuality(parent, q, policy, state);
     682          BackpropagateQuality(parent, q, state);
    684683        }
    685684      }
     
    718717      }
    719718      return children[selectedChildIdx];
    720     }
    721 
    722     // tree search might fail because of constraints for expressions
    723     // in this case we get stuck we just restart
    724     // see ConstraintHandler.cs for more info
    725     private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
    726       out double q) {
    727       Tree selectedChild = null;
    728       Contract.Assert(tree.state == automaton.CurrentState);
    729       Contract.Assert(!tree.Done);
    730       if (tree.children == null) {
    731         if (automaton.IsFinalState(tree.state)) {
    732           // final state
    733           tree.Done = true;
    734 
    735           // EVALUATE
    736           byte[] code; int nParams;
    737           automaton.GetCode(out code, out nParams);
    738           q = eval(code, nParams);
    739 
    740           treePolicy.Update(tree.actionStatistics, q);
    741           return true; // we reached a final state
    742         } else {
    743           // EXPAND
    744           int[] possibleFollowStates = new int[1000];
    745           int nFs;
    746           automaton.FollowStates(automaton.CurrentState, ref possibleFollowStates, out nFs);
    747           if (nFs == 0) {
    748             // stuck in a dead end (no final state and no allowed follow states)
    749             q = 0;
    750             tree.Done = true;
    751             tree.children = null;
    752             return false;
    753           }
    754           tree.children = new Tree[nFs];
    755           for (int i = 0; i < tree.children.Length; i++)
    756             tree.children[i] = new Tree() {
    757               children = null,
    758               state = possibleFollowStates[i],
    759               actionStatistics = treePolicy.CreateActionStatistics()
    760             };
    761 
    762           selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
    763         }
    764       } else {
    765         // tree.children != null
    766         // UCT selection within tree
    767         int selectedIdx = 0;
    768         if (tree.children.Length > 1) {
    769           selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
    770         }
    771         selectedChild = tree.children[selectedIdx];
    772       }
    773       // make selected step and recurse
    774       automaton.Goto(selectedChild.state);
    775       var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
    776       if (success) {
    777         // only update if successful
    778         treePolicy.Update(tree.actionStatistics, q);
    779       }
    780 
    781       tree.Done = tree.children.All(ch => ch.Done);
    782       if (tree.Done) {
    783         tree.children = null; // cut off the sub-branch if it has been fully explored
    784       }
    785       return success;
    786     }
    787 
    788     private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
    789       // if one of the new children leads to a final state then go there
    790       // otherwise choose a random child
    791       int selectedChildIdx = -1;
    792       // find first final state if there is one
    793       for (int i = 0; i < tree.children.Length; i++) {
    794         if (automaton.IsFinalState(tree.children[i].state)) {
    795           selectedChildIdx = i;
    796           break;
    797         }
    798       }
    799       // no final state -> select a the first child
    800       if (selectedChildIdx == -1) {
    801         selectedChildIdx = 0;
    802       }
    803       return tree.children[selectedChildIdx];
    804     }
     719    }                                           
    805720
    806721    // scales data and extracts values from dataset into arrays
     
    869784      automaton.GetCode(out code, out nParams);
    870785      return Disassembler.CodeToString(code);
    871     }
    872 
    873 
    874     private static string WriteStatistics(Tree tree, State state) {
    875       var sb = new System.IO.StringWriter();
    876       sb.Write("{0}\t{1:N5}\t", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
    877       if (state.children.ContainsKey(tree)) {
    878         foreach (var ch in state.children[tree]) {
    879           sb.Write("{0}\t{1:N5}\t", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
    880         }
    881       }
    882       sb.WriteLine();
    883       return sb.ToString();
    884786    }
    885787
     
    899801
    900802    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
    901       var avgNodeQ = tree.actionStatistics.AverageQuality;
    902       var tries = tree.actionStatistics.Tries;
    903       if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    904       var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    905       hue = 0.0;
    906 
    907       sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();
     803      var tries = tree.visits;
     804
     805      sb.AppendFormat("{0} [label=\"{1}\"]; ", parentId, tries).AppendLine();
    908806
    909807      var list = new List<Tuple<int, int, Tree>>();
     
    911809        foreach (var ch in state.children[tree]) {
    912810          nextId++;
    913           avgNodeQ = ch.actionStatistics.AverageQuality;
    914           tries = ch.actionStatistics.Tries;
    915           if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    916           hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    917           hue = 0.0;
    918           sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
    919           sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine();
     811          tries = ch.visits;
     812          sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
     813          sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", parentId, nextId, ch.expr).AppendLine();
    920814          list.Add(Tuple.Create(tries, nextId, ch));
    921815        }
     
    927821            var chch = state.children[ch].First();
    928822            nextId++;
    929             avgNodeQ = chch.actionStatistics.AverageQuality;
    930             tries = chch.actionStatistics.Tries;
    931             if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    932             hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    933             hue = 0.0;
    934             sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
    935             sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine();
     823            tries = chch.visits;
     824            sb.AppendFormat("{0} [label=\"{1}\"]; ", nextId, tries).AppendLine();
     825            sb.AppendFormat("{0} -> {1} [label=\"{2}\"]", chId, nextId, chch.expr).AppendLine();
    936826          }
    937827        }
     
    957847        if (!nodeIds.TryGetValue(parent, out parentId)) {
    958848          parentId = nodeIds.Count + 1;
    959           var avgNodeQ = parent.actionStatistics.AverageQuality;
    960           var tries = parent.actionStatistics.Tries;
    961           if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    962           var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    963           hue = 0.0;
    964           if (parent.actionStatistics.Tries > threshold)
    965             sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
     849          var tries = parent.visits;
     850          if (tries > threshold)
     851            sb.Write("{0} [label=\"{1}\"]; ", parentId, tries);
    966852          nodeIds.Add(parent, parentId);
    967853        }
     
    972858            nodeIds.Add(child, childId);
    973859          }
    974           var avgNodeQ = child.actionStatistics.AverageQuality;
    975           var tries = child.actionStatistics.Tries;
     860          var tries = child.visits;
    976861          if (tries < 1) continue;
    977           if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    978           var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    979           hue = 0.0;
    980862          if (tries > threshold) {
    981             sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
     863            sb.Write("{0} [label=\"{1}\"]; ", childId, tries);
    982864            var edgeLabel = child.expr;
    983865            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
    984             sb.Write("{0} -> {1} [label=\"{3}\"]", parentId, childId, avgNodeQ, edgeLabel);
     866            sb.Write("{0} -> {1} [label=\"{2}\"]", parentId, childId, edgeLabel);
    985867          }
    986868        }
Note: See TracChangeset for help on using the changeset viewer.