Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/05/16 08:25:08 (9 years ago)
Author:
gkronber
Message:

#2581:

  • added unit tests for the number of different expressions
  • fixed problems in Automaton and constraintHandler that led to duplicate expressions
  • added possibility for MCTS to handle dead-ends in the search tree (when it is not possible to construct a valid new expression)
  • added statistics on function and gradient evaluations
Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression
Files:
6 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs

    r13650 r13651  
    2626namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    2727  // this is the core class for generating expressions.
    28   // the automaton determines which expressions are allowed
     28  // it represents a finite state automaton, each state transition can be associated with an action (e.g. to produce code).
     29  // the automaton determines the possible structures for expressions.
     30  //
     31  // to understand this code it is worthwhile to generate a graphical visualization of the automaton (see PrintAutomaton).
     32  // If the code is compiled in debug mode the automaton produces a Graphviz file into the folder of the application
     33  // whenever an instance of the automaton is constructed.
     34  //
     35  // This class relies on two other classes:
     36  // - CodeGenerator to produce code for a stack-based evaluator and
     37  // - ConstraintHandler to restrict the allowed set of expressions.
     38  //
     39  // The ConstraintHandler extends the automaton and adds semantic restrictions for expressions produced by the automaton.
     40  //
     41  //
    2942  internal class Automaton {
    3043    public const int StateExpr = 1;
     
    5265    public const int StateInvTFStart = 23;
    5366    public const int StateInvTFEnd = 24;
    54     private const int FirstDynamicState = 25;
     67    public const int FirstDynamicState = 25;
     68    // more states for individual variables are created dynamically
    5569
    5670    private const int StartState = StateExpr;
     
    221235        () => {
    222236          codeGenerator.Emit1(OpCodes.LoadConst0);
    223         },
    224         "0");
     237          constraintHandler.StartNewTermInPoly();
     238        },
     239        "0, StartTermInPoly");
    225240      AddTransition(StateLogTEnd, StateLogFactorEnd,
    226241        () => {
     
    271286        () => {
    272287          codeGenerator.Emit1(OpCodes.LoadConst1);
    273         },
    274         "c");
     288          constraintHandler.StartNewTermInPoly();
     289        },
     290        "c, StartTermInPoly");
    275291      AddTransition(StateInvTEnd, StateInvFactorEnd,
    276292        () => {
     
    337353    private readonly int[] followStatesBuf = new int[1000];
    338354    public void FollowStates(int state, out int[] buf, out int nElements) {
    339       // return followStates[state]
    340       //   .Where(s => s < FirstDynamicState || s >= minVarIdx) // for variables we only allow non-decreasing state sequences
    341       //   // the following states imply an additional variable being added to the expression
    342       //   // F, Sum, Prod
    343       //   .Where(s => (s != StateF && s != StateSum && s != StateProd) || variablesRemaining > 0);
    344 
    345355      // for loop instead of where iterator
    346356      var fs = followStates[state];
    347357      int j = 0;
    348       //Console.Write(stateNames[CurrentState] + " allowed: ");
    349358      for (int i = 0; i < fs.Count; i++) {
    350359        var s = fs[i];
    351360        if (constraintHandler.IsAllowedFollowState(state, s)) {
    352           //Console.Write(s + " ");
    353361          followStatesBuf[j++] = s;
    354362        }
    355363      }
    356       //Console.WriteLine();
    357364      buf = followStatesBuf;
    358365      nElements = j;
     
    361368
    362369    public void Goto(int targetState) {
    363       //Console.WriteLine("->{0}", stateNames[targetState]);
    364       // Contract.Assert(FollowStates(CurrentState).Contains(targetState));
    365 
    366370      if (actions[CurrentState, targetState] != null)
    367371        actions[CurrentState, targetState].ForEach(a => a()); // execute all actions
     
    370374
    371375    public bool IsFinalState(int s) {
    372       return s == StateExprEnd;
     376      return s == StateExprEnd && !constraintHandler.IsInvalidExpression;
    373377    }
    374378
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/ConstraintHandler.cs

    r13645 r13651  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    2120#endregion
    2221
    23 
     22using System;
     23using System.Collections.Generic;
    2424using System.Diagnostics.Contracts;
     25using System.Linq;
    2526
    2627namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    2728
    28   // more states for individual variables are created dynamically
     29  // This class restricts the set of allowed transitions of the automaton to prevent exploration of duplicate expressions.
     30  // It would be possible to implement this class in such a way that the search never visits a duplicate expression. However,
     31  // it seems very intricate to detect this robustly and in all cases while generating an expression because
     32  // some form of lookahead is necessary.
     33  // Instead the constraint handler only catches the obvious duplicates directly, but does not guarantee that the search always produces a valid expression.
     34  // The ratio of unsuccessful searches that need backtracking should be tracked in the MCTS alg (MctsSymbolicRegressionStatic)
     35
     36  // All changes to this class should be tested through unit tests. It is important that the ConstraintHandler is not too restrictive.
     37
     38  // the constraints are derived from a canonical form for expressions.
     39  // overall we can enforce a limited number of variable references
     40  //
     41  // an expression is a sum of terms t_1 ... t_n where terms are ordered according to a relation t_i (<=)_term t_j for each pair t_i, t_j and i <= j
     42  // a term is a product of factors where factors are ordered according to relation f_i (<=)_factor f_j for each pair f_i,f_j and i <= j
     43
     44  // we want to enforce lower-order terms before higher-order terms in expressions (based on number of variable references)
     45  // factors can have different types (variable, exp, log, inverse)
     46
     47  // (<=)_term  [IsSmallerOrEqualTerm(t_i, t_j)]
     48  //   1.  NumberOfVarRefs(t_i) < NumberOfVarRefs(t_j)  --> true           enforce terms with non-decreasing number of var refs
     49  //   2.  NumberOfVarRefs(t_i) > NumberOfVarRefs(t_j)  --> false
     50  //   3.  NumFactors(t_i) > NumFactors(t_j)            --> true           enforce terms with non-increasing number of factors
     51  //   4.  NumFactors(t_i) < NumFactors(t_j)            --> false
     52  //   5.  for all k factors: Factor(k, t_i) (<=)_factor  Factor(k, t_j) --> true // factors must be non-decreasing
     53  //   6.  all factors are (=)_factor                   --> true
     54  //   7.  else false
     55
     56  // (<=)_factor  [IsSmallerOrEqualFactor(f_i, f_j)]
     57  //   1.  FactorType(f_i) < FactorType(f_j)  --> true           enforce terms with non-decreasing factor type (var < exp < log < inv)
     58  //   2.  FactorType(f_i) > FactorType(f_j)  --> false
     59  //   3.  Compare the two factors specifically
     60  //     - variables: varIdx_i <= varIdx_j (only one var reference)
     61  //     - exp: number of variable references and then varIdx_i <= varIdx_j for each position
     62  //     - log: number of variable references and ...
     63  //     - inv: number of variable references and ...
     64  //
     65
     66  // for log and inverse factors we allow all polynomials as argument
     67  // a polynomial is a sum of terms t_1 ... t_n where terms are ordered according to a relation t_i (<=)_poly t_j for each pair t_i, t_j and i <= j
     68
     69  // (<=)_poly  [IsSmallerOrEqualPoly(t_i, t_j)]
     70  //  1. NumberOfVarRefs(t_i) < NumberOfVarRefs(t_j)         --> true // enforce non-decreasing number of var refs
     71  //  2. NumberOfVarRefs(t_i) > NumberOfVarRefs(t_j)         --> false // enforce non-decreasing number of var refs
     72  //  3. for all k variables: VarIdx(k,t_i) > VarIdx(k, t_j) --> false // enforce non-decreasing variable idx
     73
     74
     75  // we store the following to make comparisons:
     76  // - prevTerm (complete & containing all factors)
     77  // - curTerm  (incomplete & containing all completed factors)
     78  // - curFactor (incomplete)
    2979  internal class ConstraintHandler {
    3080    private int nVars;
    3181    private readonly int maxVariables;
    32 
    33     public int prevTermFirstVariableState;
    34     public int curTermFirstVariableState;
    35     public int prevTermFirstFactorType;
    36     public int curTermFirstFactorType;
    37     public int prevFactorType;
    38     public int curFactorType;
    39     public int prevFactorFirstVariableState;
    40     public int curFactorFirstVariableState;
    41     public int prevVariableRef;
     82    private bool invalidExpression;
     83
     84    public bool IsInvalidExpression {
     85      get { return invalidExpression; }
     86    }
     87
     88
     89    private TermInformation prevTerm;
     90    private TermInformation curTerm;
     91    private FactorInformation curFactor;
     92
     93
     94    private class TermInformation {
     95      public int numVarReferences { get { return factors.Sum(f => f.numVarReferences); } }
     96      public List<FactorInformation> factors = new List<FactorInformation>();
     97    }
     98
     99    private class FactorInformation {
     100      public int numVarReferences = 0;
     101      public int factorType; // use the state number to represent types
     102
     103      // for variable factors
     104      public int variableState = -1;
     105
     106      // for exp  factors
     107      public List<int> expVariableStates = new List<int>();
     108
     109      // for log and inv factors
     110      public List<List<int>> polyVariableStates = new List<List<int>>();
     111    }
    42112
    43113
     
    46116    }
    47117
    48     // 1) an expression is a sum of terms t_1 ... t_n
    49     //    FirstFactorType(t_i) <= FirstFactorType(t_j) for each pair t_i, t_j where i < j
    50     //    FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair t_i, t_j where i < j and FirstFactorType(t_i) = FirstFactorType(t_j)
    51     // 2) a term is a product of factors, each factor is either a variable factor, an exp factor, a log factor or an inverse factor
    52     //    FactorType(f_i) <= FactorType(f_j) for each pair of factors f_i, f_j and i < j
    53     //    FirstVarReference(f_i) <= FirstVarReference(f_j) for each pair of factors f_i, f_j and i < j and FactorType(f_i) = FactorType(f_j)
    54     // 3) a variable factor is a product of variable references v1...vn
    55     //    VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j
    56     //    (IMPLICIT) FirstVarReference(t) <= VarIdx(v_i) for each variable reference v_i in term t
    57     // 4) an exponential factor is the exponential of a product of variables v1...vn
    58     //    VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j
    59     //    (IMPLICIT) FirstVarReference(t) <= VarIdx(v_i) for each variable reference v_i in term t
    60     // 5) a log factor is a sum of terms t_i where each term is a product of variables
    61     //    FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair of terms t_i, t_j and i < j
    62     //    for each term t: VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j in t
     118    // the order relations for terms and factors
     119
     120    private static int CompareTerms(TermInformation a, TermInformation b) {
     121      if (a.numVarReferences < b.numVarReferences) return -1;
     122      if (a.numVarReferences > b.numVarReferences) return 1;
     123
     124      if (a.factors.Count > b.factors.Count) return -1;  // terms with more factors should be ordered first
     125      if (a.factors.Count < b.factors.Count) return +1;
     126
     127      var aFactors = a.factors.GetEnumerator();
     128      var bFactors = b.factors.GetEnumerator();
     129      while (aFactors.MoveNext() & bFactors.MoveNext()) {
     130        var c = CompareFactors(aFactors.Current, bFactors.Current);
     131        if (c < 0) return -1;
     132        if (c > 0) return 1;
     133      }
     134      // all factors are the same => terms are the same
     135      return 0;
     136    }
     137
     138    private static int CompareFactors(FactorInformation a, FactorInformation b) {
     139      if (a.factorType < b.factorType) return -1;
     140      if (a.factorType > b.factorType) return +1;
     141      // same factor types
     142      if (a.factorType == Automaton.StateVariableFactorStart) {
     143        return a.variableState.CompareTo(b.variableState);
     144      } else if (a.factorType == Automaton.StateExpFactorStart) {
     145        return CompareStateLists(a.expVariableStates, b.expVariableStates);
     146      } else {
     147        if (a.numVarReferences < b.numVarReferences) return -1;
     148        if (a.numVarReferences > b.numVarReferences) return +1;
     149        if (a.polyVariableStates.Count > b.polyVariableStates.Count) return -1; // more terms in the poly should be ordered first
     150        if (a.polyVariableStates.Count < b.polyVariableStates.Count) return +1;
     151        // log and inv
     152        var aTerms = a.polyVariableStates.GetEnumerator();
     153        var bTerms = b.polyVariableStates.GetEnumerator();
     154        while (aTerms.MoveNext() & bTerms.MoveNext()) {
     155          var c = CompareStateLists(aTerms.Current, bTerms.Current);
     156          if (c != 0) return c;
     157        }
     158        return 0; // all terms in the polynomial are the same
     159      }
     160    }
     161
     162    private static int CompareStateLists(List<int> a, List<int> b) {
     163      if (a.Count < b.Count) return -1;
     164      if (a.Count > b.Count) return +1;
     165      for (int i = 0; i < a.Count; i++) {
     166        if (a[i] < b[i]) return -1;
     167        if (a[i] > b[i]) return +1;
     168      }
     169      return 0; // all states are the same
     170    }
     171
     172
     173    private bool IsNewTermAllowed() {
     174      // next term must have at least as many variable references as the previous term
     175      return prevTerm == null || nVars + prevTerm.numVarReferences <= maxVariables;
     176    }
     177
     178    private bool IsNewFactorAllowed() {
     179      // next factor must have a larger or equal type compared to the previous factor.
     180      // if the types are the same it must have at least as many variable references.
     181      // so if the prevFactor is any other than invFactor (last possible type) then we only need to be able to add one variable
     182      // otherwise we need to be able to add at least as many variables as the previous factor
     183      return !curTerm.factors.Any() ||
     184             (nVars + curTerm.factors.Last().numVarReferences <= maxVariables);
     185    }
     186
     187    private bool IsAllowedAsNextFactorType(int followState) {
     188      // IsNewTermAllowed already ensures that we can add a term with enough variable references
     189
     190      // enforce constraints within terms (compare to prev factor)
     191      if (curTerm.factors.Any()) {
     192        // enforce non-decreasing factor types
     193        if (curTerm.factors.Last().factorType > followState) return false;
     194        // when the factor type is the same, starting a new factor is only allowed if we can add at least the number of variables of the prev factor
     195        if (curTerm.factors.Last().factorType == followState && nVars + curTerm.factors.Last().numVarReferences > maxVariables) return false;
     196      }
     197
     198      // enforce constraints on terms (compare to prev term)
     199      // meaning that we must ensure non-decreasing terms
     200      if (prevTerm != null) {
     201        // a factor type is only allowed if we can then produce a term that is larger or equal to the prev term
     202        // (1) if the number of variable references still remaining is larger than the number of variable references in the prev term
     203        //     then it is always possible to build a larger term
     204        // (2) otherwise we try to build the largest possible term starting from current factors in the term.
     205        //     
     206
     207        var numVarRefsRemaining = maxVariables - nVars;
     208        Contract.Assert(!curTerm.factors.Any() || curTerm.factors.Last().numVarReferences <= numVarRefsRemaining);
     209
     210        if (prevTerm.numVarReferences < numVarRefsRemaining) return true;
     211
     212        // variable factors must be handled differently because they can only contain one variable reference
     213        if (followState == Automaton.StateVariableFactorStart) {
     214          // append the variable factor and the maximum possible state from the previous factor to create a larger factor
     215          var varF = CreateLargestPossibleFactor(Automaton.StateVariableFactorStart, 1);
     216          var maxF = CreateLargestPossibleFactor(prevTerm.factors.Max(f => f.factorType), numVarRefsRemaining - 1);
     217          var origFactorCount = curTerm.factors.Count;
     218          // add this factor to the current term
     219          curTerm.factors.Add(varF);
     220          curTerm.factors.Add(maxF);
     221          var c = CompareTerms(prevTerm, curTerm);
     222          // restore term
     223          curTerm.factors.RemoveRange(origFactorCount, 2);
     224          // if the prev term is still larger then this followstate is not allowed
     225          if (c > 0) {
     226            return false;
     227          }
     228        } else {
     229          var newF = CreateLargestPossibleFactor(followState, numVarRefsRemaining);
     230
     231          var origFactorCount = curTerm.factors.Count;
     232          // add this factor to the current term
     233          curTerm.factors.Add(newF);
     234          var c = CompareTerms(prevTerm, curTerm);
     235          // restore term
     236          curTerm.factors.RemoveAt(origFactorCount);
     237          // if the prev term is still larger then this followstate is not allowed
     238          if (c > 0) {
     239            return false;
     240          }
     241        }
     242      }
     243      return true;
     244    }
     245
     246    // largest possible factor of the given kind
     247    private FactorInformation CreateLargestPossibleFactor(int factorType, int numVarRefs) {
     248      var newF = new FactorInformation();
     249      newF.factorType = factorType;
     250      if (factorType == Automaton.StateVariableFactorStart) {
     251        newF.variableState = int.MaxValue;
     252        newF.numVarReferences = 1;
     253      } else if (factorType == Automaton.StateExpFactorStart) {
     254        for (int i = 0; i < numVarRefs; i++)
     255          newF.expVariableStates.Add(int.MaxValue);
     256        newF.numVarReferences = numVarRefs;
     257      } else if (factorType == Automaton.StateInvFactorStart || factorType == Automaton.StateLogFactorStart) {
     258        for (int i = 0; i < numVarRefs; i++) {
     259          newF.polyVariableStates.Add(new List<int>());
     260          newF.polyVariableStates[i].Add(int.MaxValue);
     261        }
     262        newF.numVarReferences = numVarRefs;
     263      }
     264      return newF;
     265    }
     266
     267    private bool IsAllowedAsNextVariableFactor(int variableState) {
     268      Contract.Assert(variableState >= Automaton.FirstDynamicState);
     269      return !curTerm.factors.Any() || curTerm.factors.Last().variableState <= variableState;
     270    }
     271
     272    private bool IsAllowedAsNextInExp(int variableState) {
     273      Contract.Assert(variableState >= Automaton.FirstDynamicState);
     274      if (curFactor.expVariableStates.Any() && curFactor.expVariableStates.Last() > variableState) return false;
     275      if (curTerm.factors.Any()) {
     276        // try and compare with prev factor     
     277        curFactor.numVarReferences++;
     278        curFactor.expVariableStates.Add(variableState);
     279        var c = CompareFactors(curTerm.factors.Last(), curFactor);
     280        curFactor.numVarReferences--;
     281        curFactor.expVariableStates.RemoveAt(curFactor.expVariableStates.Count - 1);
     282        return c <= 0;
     283      }
     284      return true;
     285    }
     286
     287    private bool IsNewTermAllowedInPoly() {
     288      return nVars + curFactor.polyVariableStates.Last().Count() <= maxVariables;
     289    }
     290
     291    private bool IsAllowedAsNextInPoly(int variableState) {
     292      Contract.Assert(variableState >= Automaton.FirstDynamicState);
     293      return !curFactor.polyVariableStates.Any() ||
     294             !curFactor.polyVariableStates.Last().Any() ||
     295              curFactor.polyVariableStates.Last().Last() <= variableState;
     296    }
     297    private bool IsTermCompleteInPoly() {
     298      var nTerms = curFactor.polyVariableStates.Count;
     299      return nTerms == 1 ||
     300             curFactor.polyVariableStates[nTerms - 2].Count <= curFactor.polyVariableStates[nTerms - 1].Count;
     301
     302    }
     303    private bool IsCompleteExp() {
     304      return !curTerm.factors.Any() || CompareFactors(curTerm.factors.Last(), curFactor) <= 0;
     305    }
     306
    63307    public bool IsAllowedFollowState(int currentState, int followState) {
    64       // the following states are always allowed
     308      // an invalid action was taken earlier on => nothing can be done anymore
     309      if (invalidExpression) return false;
     310      // states that have no alternative are always allowed
     311      // some ending states are only allowed if enough variables have been used in the term
    65312      if (
    66         followState == Automaton.StateVariableFactorEnd ||
    67         followState == Automaton.StateExpFEnd ||
    68         followState == Automaton.StateExpFactorEnd ||
    69         followState == Automaton.StateLogTFEnd ||
    70         followState == Automaton.StateLogTEnd ||
    71         followState == Automaton.StateLogFactorEnd ||
    72         followState == Automaton.StateInvTFEnd ||
    73         followState == Automaton.StateInvTEnd ||
    74         followState == Automaton.StateInvFactorEnd ||
    75         followState == Automaton.StateFactorEnd ||
    76         followState == Automaton.StateTermEnd ||
    77         followState == Automaton.StateExprEnd
     313        currentState == Automaton.StateTermStart ||           // no alternative
     314        currentState == Automaton.StateExpFactorStart ||
     315        currentState == Automaton.StateLogFactorStart ||
     316        currentState == Automaton.StateInvFactorStart ||
     317        followState == Automaton.StateVariableFactorEnd ||    // no alternative
     318        followState == Automaton.StateExpFEnd ||              // no alternative
     319        followState == Automaton.StateLogTFEnd ||             // no alternative
     320        followState == Automaton.StateInvTFEnd ||             // no alternative
     321        followState == Automaton.StateFactorEnd ||            // always allowed because no alternative
     322        followState == Automaton.StateExprEnd                 // we could also constrain the minimum number of terms here
    78323      ) return true;
    79324
    80325
    81       // all other states are only allowed if we can add more variables
    82       if (nVars >= maxVariables) return false;
    83 
    84       // the following states are always allowed when we can add more variables
     326      // starting a new term is only allowed if we can add a term with at least the number of variables of the prev term
     327      if (followState == Automaton.StateTermStart && !IsNewTermAllowed()) return false;
     328      if (followState == Automaton.StateFactorStart && !IsNewFactorAllowed()) return false;
     329      if (currentState == Automaton.StateFactorStart && !IsAllowedAsNextFactorType(followState)) return false;
     330      if (followState == Automaton.StateTermEnd && prevTerm != null && CompareTerms(prevTerm, curTerm) > 0) return false;
     331
     332      // all of these states add at least one variable
    85333      if (
    86         followState == Automaton.StateTermStart ||
    87         followState == Automaton.StateFactorStart ||
    88         followState == Automaton.StateExpFStart ||
    89         followState == Automaton.StateLogTStart ||
    90         followState == Automaton.StateLogTFStart ||
    91         followState == Automaton.StateInvTStart ||
    92         followState == Automaton.StateInvTFStart
    93         ) return true;
    94 
    95       // enforce non-decreasing factor types
    96       if (currentState == Automaton.StateFactorStart) {
    97         if (curFactorType < 0) {
    98           //    FirstFactorType(t_i) <= FirstFactorType(t_j) for each pair t_i, t_j where i < j
    99           return prevTermFirstFactorType <= followState;
    100         } else {
    101           // FactorType(f_i) <= FactorType(f_j) for each pair of factors f_i, f_j and i < j
    102           return curFactorType <= followState;
    103         }
    104       }
    105       // enforce non-decreasing variables references in variable and exp factors
    106       if (currentState == Automaton.StateVariableFactorStart || currentState == Automaton.StateExpFStart || currentState == Automaton.StateLogTFStart || currentState == Automaton.StateInvTFStart) {
    107         if (prevVariableRef > followState) return false; // never allow decreasing variables
    108         if (prevFactorType < 0) {
    109           // FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair t_i, t_j where i < j
    110           return prevTermFirstVariableState <= followState;
    111         } else if (prevFactorType == curFactorType) {
    112           // (FirstVarReference(f_i) <= FirstVarReference(f_j) for each pair of factors f_i, f_j and i < j and FactorType(f_i) = FactorType(f_j)
    113           return prevFactorFirstVariableState <= followState;
    114         }
    115       }
    116 
    117 
    118       return true;
     334          followState == Automaton.StateVariableFactorStart ||
     335          followState == Automaton.StateExpFactorStart || followState == Automaton.StateExpFStart ||
     336          followState == Automaton.StateLogFactorStart || followState == Automaton.StateLogTStart ||
     337          followState == Automaton.StateLogTFStart ||
     338          followState == Automaton.StateInvFactorStart || followState == Automaton.StateInvTStart ||
     339          followState == Automaton.StateInvTFStart) {
     340        if (nVars + 1 > maxVariables) return false;
     341      }
     342
     343      if (currentState == Automaton.StateVariableFactorStart && !IsAllowedAsNextVariableFactor(followState)) return false;
     344      else if (currentState == Automaton.StateExpFStart && !IsAllowedAsNextInExp(followState)) return false;
     345      else if (followState == Automaton.StateLogTStart && !IsNewTermAllowedInPoly()) return false;
     346      else if (currentState == Automaton.StateLogTFStart && !IsAllowedAsNextInPoly(followState)) return false;
     347      else if (followState == Automaton.StateInvTStart && !IsNewTermAllowedInPoly()) return false;
     348      else if (currentState == Automaton.StateInvTFStart && !IsAllowedAsNextInPoly(followState)) return false;
     349      // finishing an exponential factor is only allowed when the number of variable references is large enough
     350      else if (followState == Automaton.StateExpFactorEnd && !IsCompleteExp()) return false;
     351      // finishing a polynomial (in log or inv) is only allowed when the number of variable references is large enough
     352      else if (followState == Automaton.StateInvTEnd && !IsTermCompleteInPoly()) return false;
     353      else if (followState == Automaton.StateLogTEnd && !IsTermCompleteInPoly()) return false;
     354
     355      else if (nVars > maxVariables) return false;
     356      else return true;
    119357    }
    120358
     
    122360    public void Reset() {
    123361      nVars = 0;
    124 
    125 
    126       prevTermFirstVariableState = -1;
    127       curTermFirstVariableState = -1;
    128       prevTermFirstFactorType = -1;
    129       curTermFirstFactorType = -1;
    130       prevVariableRef = -1;
    131       prevFactorType = -1;
    132       curFactorType = -1;
    133       curFactorFirstVariableState = -1;
    134       prevFactorFirstVariableState = -1;
     362      prevTerm = null;
     363      curTerm = null;
     364      curFactor = null;
     365      invalidExpression = false;
    135366    }
    136367
    137368    public void StartTerm() {
    138       // reset factor type. in each term we can start with each type of factor
    139       prevTermFirstVariableState = curTermFirstVariableState;
    140       curTermFirstVariableState = -1;
    141 
    142       prevTermFirstFactorType = curTermFirstFactorType;
    143       curTermFirstFactorType = -1;
    144 
    145 
    146       prevFactorType = -1;
    147       curFactorType = -1;
    148 
    149       curFactorFirstVariableState = -1;
    150       prevFactorFirstVariableState = -1;
     369      curTerm = new TermInformation();
    151370    }
    152371
    153372    public void StartFactor(int state) {
    154       prevFactorType = curFactorType;
    155       curFactorType = -1;
    156 
    157       prevFactorFirstVariableState = curFactorFirstVariableState;
    158       curFactorFirstVariableState = -1;
    159 
    160 
    161       // store the first factor type
    162       if (curTermFirstFactorType < 0) {
    163         curTermFirstFactorType = state;
    164       }
    165       curFactorType = state;
    166 
    167       // reset variable references. in each factor we can start with each variable reference
    168       prevVariableRef = -1;
     373      curFactor = new FactorInformation();
     374      curFactor.factorType = state;
    169375    }
    170376
    171377
    172378    public void AddVarToCurrentFactor(int state) {
    173 
    174       Contract.Assert(prevVariableRef <= state);
    175 
    176       // store the first variable reference for each factor
    177       if (curFactorFirstVariableState < 0) {
    178         curFactorFirstVariableState = state;
    179 
    180         // store the first variable reference for each term
    181         if (curTermFirstVariableState < 0) {
    182           curTermFirstVariableState = state;
    183         }
    184       }
    185       prevVariableRef = state;
     379      Contract.Assert(Automaton.FirstDynamicState <= state);
     380      Contract.Assert(curTerm != null);
     381      Contract.Assert(curFactor != null);
    186382
    187383      nVars++;
     384      curFactor.numVarReferences++;
     385
     386      if (curFactor.factorType == Automaton.StateVariableFactorStart) {
     387        Contract.Assert(curFactor.variableState < 0); // not set before
     388        curFactor.variableState = state;
     389      } else if (curFactor.factorType == Automaton.StateExpFactorStart) {
     390        curFactor.expVariableStates.Add(state);
     391      } else if (curFactor.factorType == Automaton.StateLogFactorStart ||
     392                 curFactor.factorType == Automaton.StateInvFactorStart) {
     393        curFactor.polyVariableStates.Last().Add(state);
     394      } else throw new InvalidProgramException();
     395    }
     396
     397    public void StartNewTermInPoly() {
     398      curFactor.polyVariableStates.Add(new List<int>());
    188399    }
    189400
    190401    public void EndFactor() {
    191       Contract.Assert(prevFactorFirstVariableState <= curFactorFirstVariableState);
    192       Contract.Assert(prevFactorType <= curFactorType);
     402      // enforce non-decreasing factors
     403      if (curTerm.factors.Any() && CompareFactors(curTerm.factors.Last(), curFactor) > 0)
     404        invalidExpression = true;
     405      curTerm.factors.Add(curFactor);
     406      curFactor = null;
    193407    }
    194408
    195409    public void EndTerm() {
    196 
    197       Contract.Assert(prevFactorType <= curFactorType);
    198       Contract.Assert(prevTermFirstVariableState <= curTermFirstVariableState);
     410      // enforce non-decreasing terms (TODO: equal terms should not be allowed)
     411      if (prevTerm != null && CompareTerms(prevTerm, curTerm) > 0)
     412        invalidExpression = true;
     413      prevTerm = curTerm;
     414      curTerm = null;
    199415    }
    200416  }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Disassembler.cs

    r13650 r13651  
    2323
    2424namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    25 #if DEBUG
    2625  internal class Disassembler {
    2726    public static string CodeToString(byte[] code, double[] consts) {
     
    5150    }
    5251  }
    53 #endif
    5452}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/ExpressionEvaluator.cs

    r13650 r13651  
    2828  internal class ExpressionEvaluator {
    2929    // manages its own vector buffers
    30     private readonly List<double[]> vectorBuffers = new List<double[]>();
    31     private readonly List<double[]> scalarBuffers = new List<double[]>(); // scalars are vectors of length 1 (to allow mixing scalars and vectors on the same stack)
     30    private readonly double[][] vectorBuffers;
     31    private readonly double[][] scalarBuffers; // scalars are vectors of length 1 (to allow mixing scalars and vectors on the same stack)
     32    private int lastVecBufIdx;
     33    private int lastScalarBufIdx;
    3234
    3335
    3436    private double[] GetVectorBuffer() {
    35       var v = vectorBuffers[vectorBuffers.Count - 1];
    36       vectorBuffers.RemoveAt(vectorBuffers.Count - 1);
    37       return v;
     37      return vectorBuffers[--lastVecBufIdx];
    3838    }
    3939    private double[] GetScalarBuffer() {
    40       var v = scalarBuffers[scalarBuffers.Count - 1];
    41       scalarBuffers.RemoveAt(scalarBuffers.Count - 1);
    42       return v;
     40      return scalarBuffers[--lastScalarBufIdx];
    4341    }
    4442
    4543    private void ReleaseBuffer(double[] buf) {
    46       (buf.Length == 1 ? scalarBuffers : vectorBuffers).Add(buf);
     44      if (buf.Length == 1) {
     45        scalarBuffers[lastScalarBufIdx++] = buf;
     46      } else {
     47        vectorBuffers[lastVecBufIdx++] = buf;
     48      }
    4749    }
    4850
     
    7375
    7476      // preallocate buffers
     77      vectorBuffers = new double[MaxStackSize * (1 + MaxParams)][];
     78      scalarBuffers = new double[MaxStackSize * (1 + MaxParams)][];
    7579      for (int i = 0; i < MaxStackSize; i++) {
    7680        ReleaseBuffer(new double[vLen]);
     
    9498      short arg;
    9599      // checked at the end to make sure we do not leak buffers
    96       int initialScalarCount = scalarBuffers.Count;
    97       int initialVectorCount = vectorBuffers.Count;
     100      int initialScalarCount = lastScalarBufIdx;
     101      int initialVectorCount = lastVecBufIdx;
    98102
    99103      while (true) {
     
    179183
    180184                var f = 1.0 / (maxFx * consts[curParamIdx]);
    181                 // adjust c so that maxFx*c = 1
     185                // adjust c so that maxFx*c = 1 TODO: this is not ideal as we enforce a positive argument to exp()
    182186                consts[curParamIdx] *= f;
    183187
     
    211215            }
    212216            ReleaseBuffer(r);
    213             Contract.Assert(vectorBuffers.Count == initialVectorCount);
    214             Contract.Assert(scalarBuffers.Count == initialScalarCount);
     217            Contract.Assert(lastVecBufIdx == initialVectorCount);
     218            Contract.Assert(lastScalarBufIdx == initialScalarCount);
    215219            return;
    216220        }
     
    232236
    233237      // checked at the end to make sure we do not leak buffers
    234       int initialScalarCount = scalarBuffers.Count;
    235       int initialVectorCount = vectorBuffers.Count;
     238      int initialScalarCount = lastScalarBufIdx;
     239      int initialVectorCount = lastVecBufIdx;
    236240
    237241      while (true) {
     
    400404            }
    401405
    402             Contract.Assert(vectorBuffers.Count == initialVectorCount);
    403             Contract.Assert(scalarBuffers.Count == initialScalarCount);
     406            Contract.Assert(lastVecBufIdx == initialVectorCount);
     407            Contract.Assert(lastScalarBufIdx == initialScalarCount);
    404408            return; // break loop
    405409        }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs

    r13650 r13651  
    221221      Results.Add(new Result("Average quality", avgQuality));
    222222
     223      var totalRollouts = new IntValue();
     224      Results.Add(new Result("Total rollouts", totalRollouts));
     225      var effRollouts = new IntValue();
     226      Results.Add(new Result("Effective rollouts", effRollouts));
     227      var funcEvals = new IntValue();
     228      Results.Add(new Result("Function evaluations", funcEvals));
     229      var gradEvals = new IntValue();
     230      Results.Add(new Result("Gradient evaluations", gradEvals));
     231
     232
    223233      // same as in SymbolicRegressionSingleObjectiveProblem
    224234      var y = Problem.ProblemData.Dataset.GetDoubleValues(Problem.ProblemData.TargetVariable,
     
    266276          curBestQ = 0.0;
    267277
     278          funcEvals.Value = state.FuncEvaluations;
     279          gradEvals.Value = state.GradEvaluations;
     280          effRollouts.Value = state.EffectiveRollouts;
     281          totalRollouts.Value = state.TotalRollouts;
     282
    268283          table.Rows["Best quality"].Values.Add(bestQuality.Value);
    269284          table.Rows["Current best quality"].Values.Add(curQuality.Value);
     
    280295        avgQuality.Value = sumQ / n;
    281296
     297        funcEvals.Value = state.FuncEvaluations;
     298        gradEvals.Value = state.GradEvaluations;
     299        effRollouts.Value = state.EffectiveRollouts;
     300        totalRollouts.Value = state.TotalRollouts;
     301
    282302        table.Rows["Best quality"].Values.Add(bestQuality.Value);
    283303        table.Rows["Current best quality"].Values.Add(curQuality.Value);
    284304        table.Rows["Average quality"].Values.Add(avgQuality.Value);
    285305        iterations.Value = iterations.Value + n;
     306
    286307      }
    287308
     
    289310      Results.Add(new Result("Best solution quality (train)", new DoubleValue(state.BestSolutionTrainingQuality)));
    290311      Results.Add(new Result("Best solution quality (test)", new DoubleValue(state.BestSolutionTestQuality)));
     312
    291313
    292314      // produce solution
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r13650 r13651  
    4444      double BestSolutionTrainingQuality { get; }
    4545      double BestSolutionTestQuality { get; }
     46      int TotalRollouts { get; }
     47      int EffectiveRollouts { get; }
     48      int FuncEvaluations { get; }
     49      int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
     50      // TODO other stats on LM optimizer might be interesting here
    4651    }
    4752
     
    5762      internal readonly List<Tree> bestChildrenBuf;
    5863      internal readonly Func<byte[], int, double> evalFun;
     64      // MCTS might get stuck. Track statistics on the number of effective rollouts
     65      internal int totalRollouts;
     66      internal int effectiveRollouts;
    5967
    6068
     
    7785      private int bestNParams;
    7886      private double[] bestConsts;
     87
     88      // stats
     89      private int funcEvaluations;
     90      private int gradEvaluations;
    7991
    8092      // buffers
     
    173185        }
    174186      }
     187
     188      public int TotalRollouts { get { return totalRollouts; } }
     189      public int EffectiveRollouts { get { return effectiveRollouts; } }
     190      public int FuncEvaluations { get { return funcEvaluations; } }
     191      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
     192
    175193      #endregion
    176194
     
    204222        Array.Copy(ones, constsBuf, nParams);
    205223        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
     224        funcEvaluations++;
    206225
    207226        // calc opt scaling (alpha*f(x) + beta)
     
    220239          // optimize constants using the starting point calculated above
    221240          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
     241
    222242          evaluator.Exec(code, x, constsBuf, predBuf);
     243          funcEvaluations++;
     244
    223245          rsq = RSq(y, predBuf);
    224246          optConsts = constsBuf;
     
    237259
    238260      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
    239         double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt
     261        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
    240262        Array.Copy(consts, optConsts, nParams);
    241263
     
    247269        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
    248270        alglib.minlmresults(state, out optConsts, out rep);
     271        funcEvaluations += rep.nfunc;
     272        gradEvaluations += rep.njac * nParams;
    249273
    250274        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);
     
    311335      var rand = mctsState.random;
    312336      double c = mctsState.c;
    313 
    314       automaton.Reset();
    315       return TreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf);
    316     }
    317 
    318     private static double TreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf) {
     337      double q = 0;
     338      bool success = false;
     339      do {
     340        automaton.Reset();
     341        success = TryTreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf, out q);
     342        mctsState.totalRollouts++;
     343      } while (!success && !tree.done);
     344      mctsState.effectiveRollouts++;
     345      return q;
     346    }
     347
     348    // tree search might fail because of constraints for expressions
     349    // in this case we get stuck and just restart
     350    // see ConstraintHandler.cs for more info
     351    private static bool TryTreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf,
     352      out double q) {
    319353      Tree selectedChild = null;
    320       double q;
    321354      Contract.Assert(tree.state == automaton.CurrentState);
    322355      Contract.Assert(!tree.done);
     
    332365          tree.visits++;
    333366          tree.sumQuality += q;
    334           return q;
     367          return true; // we reached a final state
    335368        } else {
    336369          // EXPAND
     
    338371          int nFs;
    339372          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    340 
     373          if (nFs == 0) {
     374            // stuck in a dead end (no final state and no allowed follow states)
     375            q = 0;
     376            tree.done = true;
     377            tree.children = null;
     378            return false;
     379          }
    341380          tree.children = new Tree[nFs];
    342           for (int i = 0; i < tree.children.Length; i++) tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };
     381          for (int i = 0; i < tree.children.Length; i++)
     382            tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };
    343383
    344384          selectedChild = SelectFinalOrRandom(automaton, tree, rand);
     
    351391      // make selected step and recurse
    352392      automaton.Goto(selectedChild.state);
    353       q = TreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf);
    354 
    355       tree.sumQuality += q;
    356       tree.visits++;
     393      var success = TryTreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf, out q);
     394      if (success) {
     395        // only update if successful
     396        tree.sumQuality += q;
     397        tree.visits++;
     398      }
    357399
    358400      // tree.done = tree.children.All(ch => ch.done);
    359401      tree.done = true; for (int i = 0; i < tree.children.Length && tree.done; i++) tree.done = tree.children[i].done;
    360402      if (tree.done) {
    361         tree.children = null; // cut of the sub-branch if it has been fully explored
     403        tree.children = null; // cut off the sub-branch if it has been fully explored
    362404        // TODO: update all qualities and visits to remove the information gained from this whole branch
    363405      }
    364       return q;
     406      return success;
    365407    }
    366408
Note: See TracChangeset for help on using the changeset viewer.