Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/05/16 08:25:08 (8 years ago)
Author:
gkronber
Message:

#2581:

  • added unit tests for the number of different expressions
  • fixed problems in Automaton and constraintHandler that lead to duplicate expressions
  • added possibility for MCTS to handle dead-ends in the search tree (when it is not possible to construct a valid new expression)
  • added statistics on function and gradient evaluations
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r13650 r13651  
    4444      double BestSolutionTrainingQuality { get; }
    4545      double BestSolutionTestQuality { get; }
     46      int TotalRollouts { get; }
     47      int EffectiveRollouts { get; }
     48      int FuncEvaluations { get; }
     49      int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
     50      // TODO other stats on LM optimizer might be interesting here
    4651    }
    4752
     
    5762      internal readonly List<Tree> bestChildrenBuf;
    5863      internal readonly Func<byte[], int, double> evalFun;
     64      // MCTS might get stuck. Track statistics on the number of effective rollouts
     65      internal int totalRollouts;
     66      internal int effectiveRollouts;
    5967
    6068
     
    7785      private int bestNParams;
    7886      private double[] bestConsts;
     87
     88      // stats
     89      private int funcEvaluations;
     90      private int gradEvaluations;
    7991
    8092      // buffers
     
    173185        }
    174186      }
     187
     188      public int TotalRollouts { get { return totalRollouts; } }
     189      public int EffectiveRollouts { get { return effectiveRollouts; } }
     190      public int FuncEvaluations { get { return funcEvaluations; } }
     191      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
     192
    175193      #endregion
    176194
     
    204222        Array.Copy(ones, constsBuf, nParams);
    205223        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
     224        funcEvaluations++;
    206225
    207226        // calc opt scaling (alpha*f(x) + beta)
     
    220239          // optimize constants using the starting point calculated above
    221240          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
     241
    222242          evaluator.Exec(code, x, constsBuf, predBuf);
     243          funcEvaluations++;
     244
    223245          rsq = RSq(y, predBuf);
    224246          optConsts = constsBuf;
     
    237259
    238260      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
    239         double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt
     261        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
    240262        Array.Copy(consts, optConsts, nParams);
    241263
     
    247269        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
    248270        alglib.minlmresults(state, out optConsts, out rep);
     271        funcEvaluations += rep.nfunc;
     272        gradEvaluations += rep.njac * nParams;
    249273
    250274        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);
     
    311335      var rand = mctsState.random;
    312336      double c = mctsState.c;
    313 
    314       automaton.Reset();
    315       return TreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf);
    316     }
    317 
    318     private static double TreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf) {
     337      double q = 0;
     338      bool success = false;
     339      do {
     340        automaton.Reset();
     341        success = TryTreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf, out q);
     342        mctsState.totalRollouts++;
     343      } while (!success && !tree.done);
     344      mctsState.effectiveRollouts++;
     345      return q;
     346    }
     347
     348    // tree search might fail because of constraints for expressions
     349    // in this case we get stuck we just restart
     350    // see ConstraintHandler.cs for more info
     351    private static bool TryTreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf,
     352      out double q) {
    319353      Tree selectedChild = null;
    320       double q;
    321354      Contract.Assert(tree.state == automaton.CurrentState);
    322355      Contract.Assert(!tree.done);
     
    332365          tree.visits++;
    333366          tree.sumQuality += q;
    334           return q;
     367          return true; // we reached a final state
    335368        } else {
    336369          // EXPAND
     
    338371          int nFs;
    339372          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    340 
     373          if (nFs == 0) {
     374            // stuck in a dead end (no final state and no allowed follow states)
     375            q = 0;
     376            tree.done = true;
     377            tree.children = null;
     378            return false;
     379          }
    341380          tree.children = new Tree[nFs];
    342           for (int i = 0; i < tree.children.Length; i++) tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };
     381          for (int i = 0; i < tree.children.Length; i++)
     382            tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };
    343383
    344384          selectedChild = SelectFinalOrRandom(automaton, tree, rand);
     
    351391      // make selected step and recurse
    352392      automaton.Goto(selectedChild.state);
    353       q = TreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf);
    354 
    355       tree.sumQuality += q;
    356       tree.visits++;
     393      var success = TryTreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf, out q);
     394      if (success) {
     395        // only update if successful
     396        tree.sumQuality += q;
     397        tree.visits++;
     398      }
    357399
    358400      // tree.done = tree.children.All(ch => ch.done);
    359401      tree.done = true; for (int i = 0; i < tree.children.Length && tree.done; i++) tree.done = tree.children[i].done;
    360402      if (tree.done) {
    361         tree.children = null; // cut of the sub-branch if it has been fully explored
     403        tree.children = null; // cut off the sub-branch if it has been fully explored
    362404        // TODO: update all qualities and visits to remove the information gained from this whole branch
    363405      }
    364       return q;
     406      return success;
    365407    }
    366408
Note: See TracChangeset for help on using the changeset viewer.