Changeset 13658


Ignore:
Timestamp:
03/07/16 14:50:02 (4 years ago)
Author:
gkronber
Message:

#2581: extracted policies from MCTS to allow experimentation with different policies for MCTS

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r13653 r13658  
    263263    <Compile Include="MctsSymbolicRegression\MctsSymbolicRegressionStatic.cs" />
    264264    <Compile Include="MctsSymbolicRegression\OpCodes.cs" />
     265    <Compile Include="MctsSymbolicRegression\Policies\EpsGreedy.cs" />
     266    <Compile Include="MctsSymbolicRegression\Policies\UcbTuned.cs" />
     267    <Compile Include="MctsSymbolicRegression\Policies\IActionStatistics.cs" />
     268    <Compile Include="MctsSymbolicRegression\Policies\IPolicy.cs" />
     269    <Compile Include="MctsSymbolicRegression\Policies\PolicyBase.cs" />
     270    <Compile Include="MctsSymbolicRegression\Policies\Ucb.cs" />
    265271    <Compile Include="MctsSymbolicRegression\SymbolicExpressionGenerator.cs" />
    266272    <Compile Include="MctsSymbolicRegression\Tree.cs" />
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs

    r13652 r13658  
    2424using System.Runtime.CompilerServices;
    2525using System.Threading;
     26using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    2627using HeuristicLab.Analysis;
    2728using HeuristicLab.Common;
     
    5253    private const string AllowedFactorsParameterName = "Allowed factors";
    5354    private const string ConstantOptimizationIterationsParameterName = "Iterations (constant optimization)";
    54     private const string CParameterName = "C";
     55    private const string PolicyParameterName = "Policy";
    5556    private const string SeedParameterName = "Seed";
    5657    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     
    7980      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    8081    }
    81     public IFixedValueParameter<DoubleValue> CParameter {
    82       get { return (IFixedValueParameter<DoubleValue>)Parameters[CParameterName]; }
     82    public IValueParameter<IPolicy> PolicyParameter {
     83      get { return (IValueParameter<IPolicy>)Parameters[PolicyParameterName]; }
    8384    }
    8485    public IFixedValueParameter<DoubleValue> PunishmentFactorParameter {
     
    119120      set { MaxVariableReferencesParameter.Value.Value = value; }
    120121    }
    121     public double C {
    122       get { return CParameter.Value.Value; }
    123       set { CParameter.Value.Value = value; }
    124     }
    125 
     122    public IPolicy Policy {
     123      get { return PolicyParameter.Value; }
     124      set { PolicyParameter.Value = value; }
     125    }
    126126    public double PunishmentFactor {
    127127      get { return PunishmentFactorParameter.Value.Value; }
     
    173173      Parameters.Add(new FixedValueParameter<IntValue>(MaxVariablesParameterName,
    174174        "Maximal number of variables references in the symbolic regression models (multiple usages of the same variable are counted)", new IntValue(5)));
    175       Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName,
    176         "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0)));
     175      // Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName,
     176      //   "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0)));
     177      Parameters.Add(new ValueParameter<IPolicy>(PolicyParameterName,
     178        "The policy to use for selecting nodes in MCTS (e.g. Ucb)", new Ucb()));
     179      PolicyParameter.Hidden = true;
    177180      Parameters.Add(new ValueParameter<ICheckedItemList<StringValue>>(AllowedFactorsParameterName,
    178181        "Choose which expressions are allowed as factors in the model.", defaultFactorsList));
     
    244247      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    245248      if (!AllowedFactors.CheckedItems.Any()) throw new ArgumentException("At least on type of factor must be allowed");
    246       var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, C, ScaleVariables, ConstantOptimizationIterations,
     249      var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, ScaleVariables, ConstantOptimizationIterations,
     250        Policy,
    247251        lowerLimit, upperLimit,
    248252        allowProdOfVars: AllowedFactors.CheckedItems.Any(s => s.Value.Value == VariableProductFactorName),
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r13657 r13658  
    2424using System.Diagnostics.Contracts;
    2525using System.Linq;
     26using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    2627using HeuristicLab.Common;
    2728using HeuristicLab.Core;
     
    5859      internal readonly Automaton automaton;
    5960      internal IRandom random { get; private set; }
    60       internal readonly double c;
    6161      internal readonly Tree tree;
    62       internal readonly List<Tree> bestChildrenBuf;
    6362      internal readonly Func<byte[], int, double> evalFun;
     63      internal readonly IPolicy treePolicy;
    6464      // MCTS might get stuck. Track statistics on the number of effective rollouts
    6565      internal int totalRollouts;
     
    9696      private readonly double[][] gradBuf;
    9797
    98       public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, double c, bool scaleVariables, int constOptIterations,
     98      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, int constOptIterations,
     99        IPolicy treePolicy = null,
    99100        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
    100101        bool allowProdOfVars = true,
     
    105106
    106107        this.problemData = problemData;
    107         this.c = c;
    108108        this.constOptIterations = constOptIterations;
    109109        this.evalFun = this.Eval;
     
    134134
    135135        this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    136         this.tree = new Tree() { state = automaton.CurrentState };
     136        this.treePolicy = treePolicy ?? new Ucb();
     137        this.tree = new Tree() { state = automaton.CurrentState, actionStatistics = treePolicy.CreateActionStatistics() };
    137138
    138139        // reset best solution
     
    146147        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
    147148        constsBuf = new double[MaxParams];
    148         this.bestChildrenBuf = new List<Tree>(2 * x.Length); // the number of follow states in the automaton is O(number of variables) 2 * number of variables should be sufficient (capacity is increased if necessary anyway)
    149149        this.predBuf = new double[y.Length];
    150150        this.testPredBuf = new double[testY.Length];
     
    154154
    155155      #region IState inferface
    156       public bool Done { get { return tree != null && tree.done; } }
     156      public bool Done { get { return tree != null && tree.Done; } }
    157157
    158158      public double BestSolutionTrainingQuality {
     
    302302    }
    303303
    304     public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, double c = 1.0,
    305       bool scaleVariables = true, int constOptIterations = 0, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     304    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
     305      bool scaleVariables = true, int constOptIterations = 0,
     306      IPolicy policy = null,
     307      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
    306308      bool allowProdOfVars = true,
    307309      bool allowExp = true,
     
    310312      bool allowMultipleTerms = false
    311313      ) {
    312       return new State(problemData, randSeed, maxVariables, c, scaleVariables, constOptIterations,
     314      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations,
     315        policy,
    313316        lowerEstimationLimit, upperEstimationLimit,
    314317        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
     
    329332      var tree = mctsState.tree;
    330333      var eval = mctsState.evalFun;
    331       var bestChildrenBuf = mctsState.bestChildrenBuf;
    332334      var rand = mctsState.random;
    333       double c = mctsState.c;
     335      var treePolicy = mctsState.treePolicy;
    334336      double q = 0;
    335       double deltaQ = 0;
    336       double deltaSqrQ = 0;
    337       int deltaVisits = 0;
    338337      bool success = false;
    339338      do {
    340339        automaton.Reset();
    341         success = TryTreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf, out q, out deltaQ, out deltaSqrQ, out deltaVisits);
     340        success = TryTreeSearchRec(rand, tree, automaton, eval, treePolicy, out q);
    342341        mctsState.totalRollouts++;
    343       } while (!success && !tree.done);
     342      } while (!success && !tree.Done);
    344343      mctsState.effectiveRollouts++;
    345344      return q;
     
    349348    // in this case we get stuck we just restart
    350349    // see ConstraintHandler.cs for more info
    351     private static bool TryTreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf,
    352       out double q, // quality of the expression
    353       out double deltaQ, out double deltaSqrQ, out int deltaVisits // the updates for total quality and number of visits (can be negative if branches have been fully explored)
    354       ) {
     350    private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
     351      out double q) {
    355352      Tree selectedChild = null;
    356353      Contract.Assert(tree.state == automaton.CurrentState);
    357       Contract.Assert(!tree.done);
     354      Contract.Assert(!tree.Done);
    358355      if (tree.children == null) {
    359356        if (automaton.IsFinalState(tree.state)) {
    360357          // final state
    361           tree.done = true;
     358          tree.Done = true;
    362359
    363360          // EVALUATE
     
    365362          automaton.GetCode(out code, out nParams);
    366363          q = eval(code, nParams);
    367           tree.visits += 1;
    368           tree.sumQuality += q;
    369           tree.sumSqrQuality += q * q;
    370           deltaQ = q;
    371           deltaVisits = 1;
    372           deltaSqrQ = q * q;
     364
     365          treePolicy.Update(tree.actionStatistics, q);
    373366          return true; // we reached a final state
    374367        } else {
     
    380373            // stuck in a dead end (no final state and no allowed follow states)
    381374            q = 0;
    382             deltaQ = 0;
    383             deltaSqrQ = 0.0;
    384             deltaVisits = 0;
    385             tree.done = true;
     375            tree.Done = true;
    386376            tree.children = null;
    387             tree.visits = 1;
    388377            return false;
    389378          }
    390379          tree.children = new Tree[nFs];
    391380          for (int i = 0; i < tree.children.Length; i++)
    392             tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };
     381            tree.children[i] = new Tree() { children = null, state = possibleFollowStates[i], actionStatistics = treePolicy.CreateActionStatistics() };
    393382
    394383          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
     
    397386        // tree.children != null
    398387        // UCT selection within tree
    399         selectedChild = tree.children.Length > 1 ? SelectUctTuned(tree, rand, c, bestChildrenBuf) : tree.children[0];
     388        int selectedIdx = 0;
     389        if (tree.children.Length > 1) {
     390          selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
     391        }
     392        selectedChild = tree.children[selectedIdx];
    400393      }
    401394      // make selected step and recurse
    402395      automaton.Goto(selectedChild.state);
    403       var success = TryTreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf,
    404         out q, out deltaQ, out deltaSqrQ, out deltaVisits);
     396      var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
    405397      if (success) {
    406398        // only update if successful
    407         tree.sumQuality += deltaQ;
    408         tree.sumSqrQuality += deltaSqrQ;
    409         tree.visits += deltaVisits;
    410       }
    411 
    412       if (tree.children.All(ch => ch.done)) {
    413         tree.done = true;
    414         // update parent nodes to remove information from this branch
    415         if (tree.children.Length > 1) {
    416           deltaQ = -(tree.sumQuality - deltaQ);
    417           deltaSqrQ = -(tree.sumSqrQuality - deltaSqrQ);
    418           deltaVisits = -(tree.visits - deltaVisits);
    419         }
     399        treePolicy.Update(tree.actionStatistics, q);
     400      }
     401
     402      tree.Done = tree.children.All(ch => ch.Done);
     403      if (tree.Done) {
    420404        tree.children = null; // cut off the sub-branch if it has been fully explored
    421405      }
    422406      return success;
    423     }
    424 
    425     private static Tree SelectUct(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {
    426       // determine total tries of still active children
    427       int totalTries = 0;
    428       bestChildrenBuf.Clear();
    429       for (int i = 0; i < tree.children.Length; i++) {
    430         var ch = tree.children[i];
    431         if (ch.done) continue;
    432         if (ch.visits == 0) bestChildrenBuf.Add(ch);
    433         else totalTries += tree.children[i].visits;
    434       }
    435       // if there are unvisited children select a random child
    436       if (bestChildrenBuf.Any()) {
    437         return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
    438       }
    439       Contract.Assert(totalTries > 0); // the tree is not done yet so there is at least on child that is not done
    440       double logTotalTries = Math.Log(totalTries);
    441       var bestQ = double.NegativeInfinity;
    442       for (int i = 0; i < tree.children.Length; i++) {
    443         var ch = tree.children[i];
    444         if (ch.done) continue;
    445         var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits);
    446         if (childQ > bestQ) {
    447           bestChildrenBuf.Clear();
    448           bestChildrenBuf.Add(ch);
    449           bestQ = childQ;
    450         } else if (childQ >= bestQ) {
    451           bestChildrenBuf.Add(ch);
    452         }
    453       }
    454       return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
    455     }
    456 
    457     private static Tree SelectUctTuned(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {
    458       // determine total tries of still active children
    459       int totalTries = 0;
    460       bestChildrenBuf.Clear();
    461       for (int i = 0; i < tree.children.Length; i++) {
    462         var ch = tree.children[i];
    463         if (ch.done) continue;
    464         if (ch.visits == 0) bestChildrenBuf.Add(ch);
    465         else totalTries += tree.children[i].visits;
    466       }
    467       // if there are unvisited children select a random child
    468       if (bestChildrenBuf.Any()) {
    469         return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
    470       }
    471       Contract.Assert(totalTries > 0); // the tree is not done yet so there is at least on child that is not done
    472       double logTotalTries = Math.Log(totalTries);
    473       var bestQ = double.NegativeInfinity;
    474       for (int i = 0; i < tree.children.Length; i++) {
    475         var ch = tree.children[i];
    476         if (ch.done) continue;
    477         var varianceBound = ch.QualityVariance + Math.Sqrt(2.0 * logTotalTries / ch.visits);
    478         if (varianceBound > 0.25) varianceBound = 0.25;
    479         var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits * varianceBound);
    480         if (childQ > bestQ) {
    481           bestChildrenBuf.Clear();
    482           bestChildrenBuf.Add(ch);
    483           bestQ = childQ;
    484         } else if (childQ >= bestQ) {
    485           bestChildrenBuf.Add(ch);
    486         }
    487       }
    488       return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
    489407    }
    490408
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Tree.cs

    r13657 r13658  
    2020#endregion
    2121
     22using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
     23
    2224namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    2325  // represents tree nodes for the search tree in MCTS
    2426  internal class Tree {
    2527    public int state;
    26     public int visits;
    27     public double sumQuality;
    28     public double sumSqrQuality; // for variance
    29     public double AverageQuality { get { return sumQuality / (double)visits; } }
    30     public double QualityVariance { get { return sumSqrQuality / (double)visits - AverageQuality * AverageQuality; } }
    31     public bool done;
     28    public bool Done {
     29      get { return actionStatistics.Done; }
     30      set { actionStatistics.Done = value; }
     31    }
     32    public IActionStatistics actionStatistics;
    3233    public Tree[] children;
    3334  }
Note: See TracChangeset for help on using the changeset viewer.