Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/25/17 15:59:39 (8 years ago)
Author:
gkronber
Message:

#2581: merged r13645,r13648,r13650,r13651,r13652,r13654,r13657,r13658,r13659,r13661,r13662,r13669,r13708,r14142 from trunk to stable (to be deleted in the next commit)

Location:
stable
Files:
17 edited
2 copied

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r14116 r15060  
    253253    <Compile Include="Linear\MultinomialLogitClassificationSolution.cs" />
    254254    <Compile Include="Linear\MultinomialLogitModel.cs" />
     255    <Compile Include="MctsSymbolicRegression\Automaton.cs" />
     256    <Compile Include="MctsSymbolicRegression\CodeGenerator.cs" />
     257    <Compile Include="MctsSymbolicRegression\ConstraintHandler.cs" />
     258    <Compile Include="MctsSymbolicRegression\Disassembler.cs" />
     259    <Compile Include="MctsSymbolicRegression\ExpressionEvaluator.cs" />
     260    <Compile Include="MctsSymbolicRegression\MctsSymbolicRegressionAlgorithm.cs" />
     261    <Compile Include="MctsSymbolicRegression\MctsSymbolicRegressionStatic.cs" />
     262    <Compile Include="MctsSymbolicRegression\OpCodes.cs" />
     263    <Compile Include="MctsSymbolicRegression\Policies\EpsGreedy.cs" />
     264    <Compile Include="MctsSymbolicRegression\Policies\UcbTuned.cs" />
     265    <Compile Include="MctsSymbolicRegression\Policies\IActionStatistics.cs" />
     266    <Compile Include="MctsSymbolicRegression\Policies\IPolicy.cs" />
     267    <Compile Include="MctsSymbolicRegression\Policies\PolicyBase.cs" />
     268    <Compile Include="MctsSymbolicRegression\Policies\Ucb.cs" />
     269    <Compile Include="MctsSymbolicRegression\SymbolicExpressionGenerator.cs" />
     270    <Compile Include="MctsSymbolicRegression\Tree.cs" />
    255271    <Compile Include="Nca\Initialization\INcaInitializer.cs" />
    256272    <Compile Include="Nca\Initialization\LdaInitializer.cs" />
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    2726namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    2827  // this is the core class for generating expressions.
    29   // the automaton determines which expressions are allowed
     28  // it represents a finite state automaton, each state transition can be associated with an action (e.g. to produce code).
     29  // the automaton determines the possible structures for expressions.
     30  //
     31  // to understand this code it is worthwhile to generate a graphical visualization of the automaton (see PrintAutomaton).
     32  // If the code is compiled in debug mode the automaton produces a Graphviz file into the folder of the application
     33  // whenever an instance of the automaton is constructed.
     34  //
     35  // This class relies on two other classes:
     36  // - CodeGenerator to produce code for a stack-based evaluator and
     37  // - ConstraintHandler to restrict the allowed set of expressions.
     38  //
     39  // The ConstraintHandler extends the automaton and adds semantic restrictions for expressions produced by the automaton.
     40  //
     41  //
    3042  internal class Automaton {
    3143    public const int StateExpr = 1;
     
    5365    public const int StateInvTFStart = 23;
    5466    public const int StateInvTFEnd = 24;
    55     private const int FirstDynamicState = 25;
     67    public const int FirstDynamicState = 25;
     68    // more states for individual variables are created dynamically
    5669
    5770    private const int StartState = StateExpr;
     
    222235        () => {
    223236          codeGenerator.Emit1(OpCodes.LoadConst0);
    224         },
    225         "0");
     237          constraintHandler.StartNewTermInPoly();
     238        },
     239        "0, StartTermInPoly");
    226240      AddTransition(StateLogTEnd, StateLogFactorEnd,
    227241        () => {
     
    272286        () => {
    273287          codeGenerator.Emit1(OpCodes.LoadConst1);
    274         },
    275         "c");
     288          constraintHandler.StartNewTermInPoly();
     289        },
     290        "c, StartTermInPoly");
    276291      AddTransition(StateInvTEnd, StateInvFactorEnd,
    277292        () => {
     
    338353    private readonly int[] followStatesBuf = new int[1000];
    339354    public void FollowStates(int state, out int[] buf, out int nElements) {
    340       // return followStates[state]
    341       //   .Where(s => s < FirstDynamicState || s >= minVarIdx) // for variables we only allow non-decreasing state sequences
    342       //   // the following states imply an additional variable being added to the expression
    343       //   // F, Sum, Prod
    344       //   .Where(s => (s != StateF && s != StateSum && s != StateProd) || variablesRemaining > 0);
    345 
    346355      // for loop instead of where iterator
    347356      var fs = followStates[state];
    348357      int j = 0;
    349       //Console.Write(stateNames[CurrentState] + " allowed: ");
    350358      for (int i = 0; i < fs.Count; i++) {
    351359        var s = fs[i];
    352360        if (constraintHandler.IsAllowedFollowState(state, s)) {
    353           //Console.Write(s + " ");
    354361          followStatesBuf[j++] = s;
    355362        }
    356363      }
    357       //Console.WriteLine();
    358364      buf = followStatesBuf;
    359365      nElements = j;
     
    362368
    363369    public void Goto(int targetState) {
    364       //Console.WriteLine("->{0}", stateNames[targetState]);
    365       // Contract.Assert(FollowStates(CurrentState).Contains(targetState));
    366 
    367370      if (actions[CurrentState, targetState] != null)
    368371        actions[CurrentState, targetState].ForEach(a => a()); // execute all actions
     
    371374
    372375    public bool IsFinalState(int s) {
    373       return s == StateExprEnd;
     376      return s == StateExprEnd && !constraintHandler.IsInvalidExpression;
    374377    }
    375378
     
    389392        writer.WriteLine("digraph {");
    390393        // writer.WriteLine("rankdir=LR");
    391         int[] fs;
    392         int nFs;
    393394        for (int s = StartState; s < stateNames.Count; s++) {
    394395          for (int i = 0; i < followStates[s].Count; i++) {
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/CodeGenerator.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/ConstraintHandler.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    2120#endregion
    2221
    23 
     22using System;
     23using System.Collections.Generic;
    2424using System.Diagnostics.Contracts;
     25using System.Linq;
    2526
    2627namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    2728
    28   // more states for individual variables are created dynamically
     29  // This class restricts the set of allowed transitions of the automaton to prevent exploration of duplicate expressions.
     30  // It would be possible to implement this class in such a way that the search never visits a duplicate expression. However,
     31  // it seems very intricate to detect this robustly and in all cases while generating an expression because
     32  // some form of lookahead is necessary.
     33  // Instead the constraint handler only catches the obvious duplicates directly, but does not guarantee that the search always produces a valid expression.
     34  // The ratio of unsuccessful searches that need backtracking should be tracked in the MCTS alg (MctsSymbolicRegressionStatic)
     35
     36  // All changes to this class should be tested through unit tests. It is important that the ConstraintHandler is not too restrictive.
     37
     38  // the constraints are derived from a canonical form for expressions.
     39  // overall we can enforce a limited number of variable references
     40  //
     41  // an expression is a sum of terms t_1 ... t_n where terms are ordered according to a relation t_i (<=)_term t_j for each pair t_i, t_j and i <= j
     42  // a term is a product of factors where factors are ordered according to relation f_i (<=)_factor f_j for each pair f_i,f_j and i <= j
     43
     44  // we want to enforce lower-order terms before higher-order terms in expressions (based on number of variable references)
     45  // factors can have different types (variable, exp, log, inverse)
     46
     47  // (<=)_term  [IsSmallerOrEqualTerm(t_i, t_j)]
     48  //   1.  NumberOfVarRefs(t_i) < NumberOfVarRefs(t_j)  --> true           enforce terms with non-decreasing number of var refs
     49  //   2.  NumberOfVarRefs(t_i) > NumberOfVarRefs(t_j)  --> false
     50  //   3.  NumFactors(t_i) > NumFactors(t_j)            --> true           enforce terms with non-increasing number of factors
     51  //   4.  NumFactors(t_i) < NumFactors(t_j)            --> false
     52  //   5.  for all k factors: Factor(k, t_i) (<=)_factor  Factor(k, t_j) --> true // factors must be non-decreasing
     53  //   6.  all factors are (=)_factor                   --> true
     54  //   7.  else false
     55
     56  // (<=)_factor  [IsSmallerOrEqualFactor(f_i, f_j)]
     57  //   1.  FactorType(f_i) < FactorType(f_j)  --> true           enforce terms with non-decreasing factor type (var < exp < log < inv)
     58  //   2.  FactorType(f_i) > FactorType(f_j)  --> false
     59  //   3.  Compare the two factors specifically
     60  //     - variables: varIdx_i <= varIdx_j (only one var reference)
     61  //     - exp: number of variable references and then varIdx_i <= varIdx_j for each position
     62  //     - log: number of variable references and ...
     63  //     - inv: number of variable references and ...
     64  //
     65
     66  // for log and inverse factors we allow all polynomials as argument
     67  // a polynomial is a sum of terms t_1 ... t_n where terms are ordered according to a relation t_i (<=)_poly t_j for each pair t_i, t_j and i <= j
     68
     69  // (<=)_poly  [IsSmallerOrEqualPoly(t_i, t_j)]
     70  //  1. NumberOfVarRefs(t_i) < NumberOfVarRefs(t_j)         --> true // enforce non-decreasing number of var refs
     71  //  2. NumberOfVarRefs(t_i) > NumberOfVarRefs(t_j)         --> false // enforce non-decreasing number of var refs
     72  //  3. for all k variables: VarIdx(k,t_i) > VarIdx(k, t_j) --> false // enforce non-decreasing variable idx
     73
     74
     75  // we store the following to make comparisons:
     76  // - prevTerm (complete & containing all factors)
     77  // - curTerm  (incomplete & containing all completed factors)
     78  // - curFactor (incomplete)
    2979  internal class ConstraintHandler {
    3080    private int nVars;
    3181    private readonly int maxVariables;
    32 
    33     public int prevTermFirstVariableState;
    34     public int curTermFirstVariableState;
    35     public int prevTermFirstFactorType;
    36     public int curTermFirstFactorType;
    37     public int prevFactorType;
    38     public int curFactorType;
    39     public int prevFactorFirstVariableState;
    40     public int curFactorFirstVariableState;
    41     public int prevVariableRef;
     82    private bool invalidExpression;
     83
     84    public bool IsInvalidExpression {
     85      get { return invalidExpression; }
     86    }
     87
     88
     89    private TermInformation prevTerm;
     90    private TermInformation curTerm;
     91    private FactorInformation curFactor;
     92
     93
     94    private class TermInformation {
     95      public int numVarReferences { get { return factors.Sum(f => f.numVarReferences); } }
     96      public List<FactorInformation> factors = new List<FactorInformation>();
     97    }
     98
     99    private class FactorInformation {
     100      public int numVarReferences = 0;
     101      public int factorType; // use the state number to represent types
     102
     103      // for variable factors
     104      public int variableState = -1;
     105
     106      // for exp  factors
     107      public List<int> expVariableStates = new List<int>();
     108
     109      // for log and inv factors
     110      public List<List<int>> polyVariableStates = new List<List<int>>();
     111    }
    42112
    43113
     
    46116    }
    47117
    48     // 1) an expression is a sum of terms t_1 ... t_n
    49     //    FirstFactorType(t_i) <= FirstFactorType(t_j) for each pair t_i, t_j where i < j
    50     //    FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair t_i, t_j where i < j and FirstFactorType(t_i) = FirstFactorType(t_j)
    51     // 2) a term is a product of factors, each factor is either a variable factor, an exp factor, a log factor or an inverse factor
    52     //    FactorType(f_i) <= FactorType(f_j) for each pair of factors f_i, f_j and i < j
    53     //    FirstVarReference(f_i) <= FirstVarReference(f_j) for each pair of factors f_i, f_j and i < j and FactorType(f_i) = FactorType(f_j)
    54     // 3) a variable factor is a product of variable references v1...vn
    55     //    VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j
    56     //    (IMPLICIT) FirstVarReference(t) <= VarIdx(v_i) for each variable reference v_i in term t
    57     // 4) an exponential factor is the exponential of a product of variables v1...vn
    58     //    VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j
    59     //    (IMPLICIT) FirstVarReference(t) <= VarIdx(v_i) for each variable reference v_i in term t
    60     // 5) a log factor is a sum of terms t_i where each term is a product of variables
    61     //    FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair of terms t_i, t_j and i < j
    62     //    for each term t: VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j in t
     118    // the order relations for terms and factors
     119
     120    private static int CompareTerms(TermInformation a, TermInformation b) {
     121      if (a.numVarReferences < b.numVarReferences) return -1;
     122      if (a.numVarReferences > b.numVarReferences) return 1;
     123
     124      if (a.factors.Count > b.factors.Count) return -1;  // terms with more factors should be ordered first
     125      if (a.factors.Count < b.factors.Count) return +1;
     126
     127      var aFactors = a.factors.GetEnumerator();
     128      var bFactors = b.factors.GetEnumerator();
     129      while (aFactors.MoveNext() & bFactors.MoveNext()) {
     130        var c = CompareFactors(aFactors.Current, bFactors.Current);
     131        if (c < 0) return -1;
     132        if (c > 0) return 1;
     133      }
     134      // all factors are the same => terms are the same
     135      return 0;
     136    }
     137
     138    private static int CompareFactors(FactorInformation a, FactorInformation b) {
     139      if (a.factorType < b.factorType) return -1;
     140      if (a.factorType > b.factorType) return +1;
     141      // same factor types
     142      if (a.factorType == Automaton.StateVariableFactorStart) {
     143        return a.variableState.CompareTo(b.variableState);
     144      } else if (a.factorType == Automaton.StateExpFactorStart) {
     145        return CompareStateLists(a.expVariableStates, b.expVariableStates);
     146      } else {
     147        if (a.numVarReferences < b.numVarReferences) return -1;
     148        if (a.numVarReferences > b.numVarReferences) return +1;
     149        if (a.polyVariableStates.Count > b.polyVariableStates.Count) return -1; // more terms in the poly should be ordered first
     150        if (a.polyVariableStates.Count < b.polyVariableStates.Count) return +1;
     151        // log and inv
     152        var aTerms = a.polyVariableStates.GetEnumerator();
     153        var bTerms = b.polyVariableStates.GetEnumerator();
     154        while (aTerms.MoveNext() & bTerms.MoveNext()) {
     155          var c = CompareStateLists(aTerms.Current, bTerms.Current);
     156          if (c != 0) return c;
     157        }
     158        return 0; // all terms in the polynomial are the same
     159      }
     160    }
     161
     162    private static int CompareStateLists(List<int> a, List<int> b) {
     163      if (a.Count < b.Count) return -1;
     164      if (a.Count > b.Count) return +1;
     165      for (int i = 0; i < a.Count; i++) {
     166        if (a[i] < b[i]) return -1;
     167        if (a[i] > b[i]) return +1;
     168      }
     169      return 0; // all states are the same
     170    }
     171
     172
     173    private bool IsNewTermAllowed() {
     174      // next term must have at least as many variable references as the previous term
     175      return prevTerm == null || nVars + prevTerm.numVarReferences <= maxVariables;
     176    }
     177
     178    private bool IsNewFactorAllowed() {
     179      // next factor must have a larger or equal type compared to the previous factor.
     180      // if the types are the same it must have at least as many variable references.
     181      // so if the prevFactor is any other than invFactor (last possible type) then we only need to be able to add one variable
     182      // otherwise we need to be able to add at least as many variables as the previous factor
     183      return !curTerm.factors.Any() ||
     184             (nVars + curTerm.factors.Last().numVarReferences <= maxVariables);
     185    }
     186
     187    private bool IsAllowedAsNextFactorType(int followState) {
     188      // IsNewTermAllowed already ensures that we can add a term with enough variable references
     189
     190      // enforce constraints within terms (compare to prev factor)
     191      if (curTerm.factors.Any()) {
     192        // enforce non-decreasing factor types
     193        if (curTerm.factors.Last().factorType > followState) return false;
     194        // when the factor type is the same, starting a new factor is only allowed if we can add at least the number of variables of the prev factor
     195        if (curTerm.factors.Last().factorType == followState && nVars + curTerm.factors.Last().numVarReferences > maxVariables) return false;
     196      }
     197
     198      // enforce constraints on terms (compare to prev term)
     199      // meaning that we must ensure non-decreasing terms
     200      if (prevTerm != null) {
     201        // a factor type is only allowed if we can then produce a term that is larger or equal to the prev term
     202        // (1) if the number of variable references still remaining is larger than the number of variable references in the prev term
     203        //     then it is always possible to build a larger term
     204        // (2) otherwise we try to build the largest possible term starting from current factors in the term.
     205        //     
     206
     207        var numVarRefsRemaining = maxVariables - nVars;
     208        Contract.Assert(!curTerm.factors.Any() || curTerm.factors.Last().numVarReferences <= numVarRefsRemaining);
     209
     210        if (prevTerm.numVarReferences < numVarRefsRemaining) return true;
     211
     212        // variable factors must be handled differently because they can only contain one variable reference
     213        if (followState == Automaton.StateVariableFactorStart) {
     214          // append the variable factor and the maximum possible state from the previous factor to create a larger factor
     215          var varF = CreateLargestPossibleFactor(Automaton.StateVariableFactorStart, 1);
     216          var maxF = CreateLargestPossibleFactor(prevTerm.factors.Max(f => f.factorType), numVarRefsRemaining - 1);
     217          var origFactorCount = curTerm.factors.Count;
     218          // add this factor to the current term
     219          curTerm.factors.Add(varF);
     220          curTerm.factors.Add(maxF);
     221          var c = CompareTerms(prevTerm, curTerm);
     222          // restore term
     223          curTerm.factors.RemoveRange(origFactorCount, 2);
     224          // if the prev term is still larger then this followstate is not allowed
     225          if (c > 0) {
     226            return false;
     227          }
     228        } else {
     229          var newF = CreateLargestPossibleFactor(followState, numVarRefsRemaining);
     230
     231          var origFactorCount = curTerm.factors.Count;
     232          // add this factor to the current term
     233          curTerm.factors.Add(newF);
     234          var c = CompareTerms(prevTerm, curTerm);
     235          // restore term
     236          curTerm.factors.RemoveAt(origFactorCount);
     237          // if the prev term is still larger then this followstate is not allowed
     238          if (c > 0) {
     239            return false;
     240          }
     241        }
     242      }
     243      return true;
     244    }
     245
     246    // largest possible factor of the given kind
     247    private FactorInformation CreateLargestPossibleFactor(int factorType, int numVarRefs) {
     248      var newF = new FactorInformation();
     249      newF.factorType = factorType;
     250      if (factorType == Automaton.StateVariableFactorStart) {
     251        newF.variableState = int.MaxValue;
     252        newF.numVarReferences = 1;
     253      } else if (factorType == Automaton.StateExpFactorStart) {
     254        for (int i = 0; i < numVarRefs; i++)
     255          newF.expVariableStates.Add(int.MaxValue);
     256        newF.numVarReferences = numVarRefs;
     257      } else if (factorType == Automaton.StateInvFactorStart || factorType == Automaton.StateLogFactorStart) {
     258        for (int i = 0; i < numVarRefs; i++) {
     259          newF.polyVariableStates.Add(new List<int>());
     260          newF.polyVariableStates[i].Add(int.MaxValue);
     261        }
     262        newF.numVarReferences = numVarRefs;
     263      }
     264      return newF;
     265    }
     266
     267    private bool IsAllowedAsNextVariableFactor(int variableState) {
     268      Contract.Assert(variableState >= Automaton.FirstDynamicState);
     269      return !curTerm.factors.Any() || curTerm.factors.Last().variableState <= variableState;
     270    }
     271
     272    private bool IsAllowedAsNextInExp(int variableState) {
     273      Contract.Assert(variableState >= Automaton.FirstDynamicState);
     274      if (curFactor.expVariableStates.Any() && curFactor.expVariableStates.Last() > variableState) return false;
     275      if (curTerm.factors.Any()) {
     276        // try and compare with prev factor     
     277        curFactor.numVarReferences++;
     278        curFactor.expVariableStates.Add(variableState);
     279        var c = CompareFactors(curTerm.factors.Last(), curFactor);
     280        curFactor.numVarReferences--;
     281        curFactor.expVariableStates.RemoveAt(curFactor.expVariableStates.Count - 1);
     282        return c <= 0;
     283      }
     284      return true;
     285    }
     286
     287    private bool IsNewTermAllowedInPoly() {
     288      return nVars + curFactor.polyVariableStates.Last().Count() <= maxVariables;
     289    }
     290
     291    private bool IsAllowedAsNextInPoly(int variableState) {
     292      Contract.Assert(variableState >= Automaton.FirstDynamicState);
     293      return !curFactor.polyVariableStates.Any() ||
     294             !curFactor.polyVariableStates.Last().Any() ||
     295              curFactor.polyVariableStates.Last().Last() <= variableState;
     296    }
     297    private bool IsTermCompleteInPoly() {
     298      var nTerms = curFactor.polyVariableStates.Count;
     299      return nTerms == 1 ||
     300             curFactor.polyVariableStates[nTerms - 2].Count <= curFactor.polyVariableStates[nTerms - 1].Count;
     301
     302    }
     303    private bool IsCompleteExp() {
     304      return !curTerm.factors.Any() || CompareFactors(curTerm.factors.Last(), curFactor) <= 0;
     305    }
     306
    63307    public bool IsAllowedFollowState(int currentState, int followState) {
    64       // the following states are always allowed
     308      // an invalid action was taken earlier on => nothing can be done anymore
     309      if (invalidExpression) return false;
     310      // states that have no alternative are always allowed
     311      // some ending states are only allowed if enough variables have been used in the term
    65312      if (
    66         followState == Automaton.StateVariableFactorEnd ||
    67         followState == Automaton.StateExpFEnd ||
    68         followState == Automaton.StateExpFactorEnd ||
    69         followState == Automaton.StateLogTFEnd ||
    70         followState == Automaton.StateLogTEnd ||
    71         followState == Automaton.StateLogFactorEnd ||
    72         followState == Automaton.StateInvTFEnd ||
    73         followState == Automaton.StateInvTEnd ||
    74         followState == Automaton.StateInvFactorEnd ||
    75         followState == Automaton.StateFactorEnd ||
    76         followState == Automaton.StateTermEnd ||
    77         followState == Automaton.StateExprEnd
     313        currentState == Automaton.StateTermStart ||           // no alternative
     314        currentState == Automaton.StateExpFactorStart ||
     315        currentState == Automaton.StateLogFactorStart ||
     316        currentState == Automaton.StateInvFactorStart ||
     317        followState == Automaton.StateVariableFactorEnd ||    // no alternative
     318        followState == Automaton.StateExpFEnd ||              // no alternative
     319        followState == Automaton.StateLogTFEnd ||             // no alternative
     320        followState == Automaton.StateInvTFEnd ||             // no alternative
     321        followState == Automaton.StateFactorEnd ||            // always allowed because no alternative
     322        followState == Automaton.StateExprEnd                 // we could also constrain the minimum number of terms here
    78323      ) return true;
    79324
    80325
    81       // all other states are only allowed if we can add more variables
    82       if (nVars >= maxVariables) return false;
    83 
    84       // the following states are always allowed when we can add more variables
     326      // starting a new term is only allowed if we can add a term with at least the number of variables of the prev term
     327      if (followState == Automaton.StateTermStart && !IsNewTermAllowed()) return false;
     328      if (followState == Automaton.StateFactorStart && !IsNewFactorAllowed()) return false;
     329      if (currentState == Automaton.StateFactorStart && !IsAllowedAsNextFactorType(followState)) return false;
     330      if (followState == Automaton.StateTermEnd && prevTerm != null && CompareTerms(prevTerm, curTerm) > 0) return false;
     331
     332      // all of these states add at least one variable
    85333      if (
    86         followState == Automaton.StateTermStart ||
    87         followState == Automaton.StateFactorStart ||
    88         followState == Automaton.StateExpFStart ||
    89         followState == Automaton.StateLogTStart ||
    90         followState == Automaton.StateLogTFStart ||
    91         followState == Automaton.StateInvTStart ||
    92         followState == Automaton.StateInvTFStart
    93         ) return true;
    94 
    95       // enforce non-decreasing factor types
    96       if (currentState == Automaton.StateFactorStart) {
    97         if (curFactorType < 0) {
    98           //    FirstFactorType(t_i) <= FirstFactorType(t_j) for each pair t_i, t_j where i < j
    99           return prevTermFirstFactorType <= followState;
    100         } else {
    101           // FactorType(f_i) <= FactorType(f_j) for each pair of factors f_i, f_j and i < j
    102           return curFactorType <= followState;
    103         }
    104       }
    105       // enforce non-decreasing variables references in variable and exp factors
    106       if (currentState == Automaton.StateVariableFactorStart || currentState == Automaton.StateExpFStart || currentState == Automaton.StateLogTFStart || currentState == Automaton.StateInvTFStart) {
    107         if (prevVariableRef > followState) return false; // never allow decreasing variables
    108         if (prevFactorType < 0) {
    109           // FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair t_i, t_j where i < j
    110           return prevTermFirstVariableState <= followState;
    111         } else if (prevFactorType == curFactorType) {
    112           // (FirstVarReference(f_i) <= FirstVarReference(f_j) for each pair of factors f_i, f_j and i < j and FactorType(f_i) = FactorType(f_j)
    113           return prevFactorFirstVariableState <= followState;
    114         }
    115       }
    116 
    117 
    118       return true;
     334          followState == Automaton.StateVariableFactorStart ||
     335          followState == Automaton.StateExpFactorStart || followState == Automaton.StateExpFStart ||
     336          followState == Automaton.StateLogFactorStart || followState == Automaton.StateLogTStart ||
     337          followState == Automaton.StateLogTFStart ||
     338          followState == Automaton.StateInvFactorStart || followState == Automaton.StateInvTStart ||
     339          followState == Automaton.StateInvTFStart) {
     340        if (nVars + 1 > maxVariables) return false;
     341      }
     342
     343      if (currentState == Automaton.StateVariableFactorStart && !IsAllowedAsNextVariableFactor(followState)) return false;
     344      else if (currentState == Automaton.StateExpFStart && !IsAllowedAsNextInExp(followState)) return false;
     345      else if (followState == Automaton.StateLogTStart && !IsNewTermAllowedInPoly()) return false;
     346      else if (currentState == Automaton.StateLogTFStart && !IsAllowedAsNextInPoly(followState)) return false;
     347      else if (followState == Automaton.StateInvTStart && !IsNewTermAllowedInPoly()) return false;
     348      else if (currentState == Automaton.StateInvTFStart && !IsAllowedAsNextInPoly(followState)) return false;
     349      // finishing an exponential factor is only allowed when the number of variable references is large enough
     350      else if (followState == Automaton.StateExpFactorEnd && !IsCompleteExp()) return false;
     351      // finishing a polynomial (in log or inv) is only allowed when the number of variable references is large enough
     352      else if (followState == Automaton.StateInvTEnd && !IsTermCompleteInPoly()) return false;
     353      else if (followState == Automaton.StateLogTEnd && !IsTermCompleteInPoly()) return false;
     354
     355      else if (nVars > maxVariables) return false;
     356      else return true;
    119357    }
    120358
     
    122360    public void Reset() {
    123361      nVars = 0;
    124 
    125 
    126       prevTermFirstVariableState = -1;
    127       curTermFirstVariableState = -1;
    128       prevTermFirstFactorType = -1;
    129       curTermFirstFactorType = -1;
    130       prevVariableRef = -1;
    131       prevFactorType = -1;
    132       curFactorType = -1;
    133       curFactorFirstVariableState = -1;
    134       prevFactorFirstVariableState = -1;
     362      prevTerm = null;
     363      curTerm = null;
     364      curFactor = null;
     365      invalidExpression = false;
    135366    }
    136367
    137368    public void StartTerm() {
    138       // reset factor type. in each term we can start with each type of factor
    139       prevTermFirstVariableState = curTermFirstVariableState;
    140       curTermFirstVariableState = -1;
    141 
    142       prevTermFirstFactorType = curTermFirstFactorType;
    143       curTermFirstFactorType = -1;
    144 
    145 
    146       prevFactorType = -1;
    147       curFactorType = -1;
    148 
    149       curFactorFirstVariableState = -1;
    150       prevFactorFirstVariableState = -1;
     369      curTerm = new TermInformation();
    151370    }
    152371
    153372    public void StartFactor(int state) {
    154       prevFactorType = curFactorType;
    155       curFactorType = -1;
    156 
    157       prevFactorFirstVariableState = curFactorFirstVariableState;
    158       curFactorFirstVariableState = -1;
    159 
    160 
    161       // store the first factor type
    162       if (curTermFirstFactorType < 0) {
    163         curTermFirstFactorType = state;
    164       }
    165       curFactorType = state;
    166 
    167       // reset variable references. in each factor we can start with each variable reference
    168       prevVariableRef = -1;
     373      curFactor = new FactorInformation();
     374      curFactor.factorType = state;
    169375    }
    170376
    171377
    172378    public void AddVarToCurrentFactor(int state) {
    173 
    174       Contract.Assert(prevVariableRef <= state);
    175 
    176       // store the first variable reference for each factor
    177       if (curFactorFirstVariableState < 0) {
    178         curFactorFirstVariableState = state;
    179 
    180         // store the first variable reference for each term
    181         if (curTermFirstVariableState < 0) {
    182           curTermFirstVariableState = state;
    183         }
    184       }
    185       prevVariableRef = state;
     379      Contract.Assert(Automaton.FirstDynamicState <= state);
     380      Contract.Assert(curTerm != null);
     381      Contract.Assert(curFactor != null);
    186382
    187383      nVars++;
     384      curFactor.numVarReferences++;
     385
     386      if (curFactor.factorType == Automaton.StateVariableFactorStart) {
     387        Contract.Assert(curFactor.variableState < 0); // not set before
     388        curFactor.variableState = state;
     389      } else if (curFactor.factorType == Automaton.StateExpFactorStart) {
     390        curFactor.expVariableStates.Add(state);
     391      } else if (curFactor.factorType == Automaton.StateLogFactorStart ||
     392                 curFactor.factorType == Automaton.StateInvFactorStart) {
     393        curFactor.polyVariableStates.Last().Add(state);
     394      } else throw new InvalidProgramException();
     395    }
     396
     397    public void StartNewTermInPoly() {
     398      curFactor.polyVariableStates.Add(new List<int>());
    188399    }
    189400
    190401    public void EndFactor() {
    191       Contract.Assert(prevFactorFirstVariableState <= curFactorFirstVariableState);
    192       Contract.Assert(prevFactorType <= curFactorType);
     402      // enforce non-decreasing factors
     403      if (curTerm.factors.Any() && CompareFactors(curTerm.factors.Last(), curFactor) > 0)
     404        invalidExpression = true;
     405      curTerm.factors.Add(curFactor);
     406      curFactor = null;
    193407    }
    194408
    195409    public void EndTerm() {
    196 
    197       Contract.Assert(prevFactorType <= curFactorType);
    198       Contract.Assert(prevTermFirstVariableState <= curTermFirstVariableState);
     410      // enforce non-decreasing terms (TODO: equal terms should not be allowed)
     411      if (prevTerm != null && CompareTerms(prevTerm, curTerm) > 0)
     412        invalidExpression = true;
     413      prevTerm = curTerm;
     414      curTerm = null;
    199415    }
    200416  }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Disassembler.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    2423
    2524namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    26 #if DEBUG
    2725  internal class Disassembler {
    2826    public static string CodeToString(byte[] code, double[] consts) {
     
    4038          case (byte)OpCodes.LoadVar:
    4139          {
    42               short arg = (short)(((short)code[pc] << 8) | (short)code[pc + 1]);
     40              short arg = (short)((code[pc] << 8) | code[pc + 1]);
    4341              pc += 2;
    4442            sb.AppendFormat(" var{0} ", arg); break;
     
    5250    }
    5351  }
    54 #endif
    5552}
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/ExpressionEvaluator.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    2928  internal class ExpressionEvaluator {
    3029    // manages its own vector buffers
    31     private readonly List<double[]> vectorBuffers = new List<double[]>();
    32     private readonly List<double[]> scalarBuffers = new List<double[]>(); // scalars are vectors of length 1 (to allow mixing scalars and vectors on the same stack)
     30    private readonly double[][] vectorBuffers;
     31    private readonly double[][] scalarBuffers; // scalars are vectors of length 1 (to allow mixing scalars and vectors on the same stack)
     32    private int lastVecBufIdx;
     33    private int lastScalarBufIdx;
    3334
    3435
    3536    private double[] GetVectorBuffer() {
    36       var v = vectorBuffers[vectorBuffers.Count - 1];
    37       vectorBuffers.RemoveAt(vectorBuffers.Count - 1);
    38       return v;
     37      return vectorBuffers[--lastVecBufIdx];
    3938    }
    4039    private double[] GetScalarBuffer() {
    41       var v = scalarBuffers[scalarBuffers.Count - 1];
    42       scalarBuffers.RemoveAt(scalarBuffers.Count - 1);
    43       return v;
     40      return scalarBuffers[--lastScalarBufIdx];
    4441    }
    4542
    4643    private void ReleaseBuffer(double[] buf) {
    47       (buf.Length == 1 ? scalarBuffers : vectorBuffers).Add(buf);
     44      if (buf.Length == 1) {
     45        scalarBuffers[lastScalarBufIdx++] = buf;
     46      } else {
     47        vectorBuffers[lastVecBufIdx++] = buf;
     48      }
    4849    }
    4950
     
    7475
    7576      // preallocate buffers
     77      vectorBuffers = new double[MaxStackSize * (1 + MaxParams)][];
     78      scalarBuffers = new double[MaxStackSize * (1 + MaxParams)][];
    7679      for (int i = 0; i < MaxStackSize; i++) {
    7780        ReleaseBuffer(new double[vLen]);
     
    9598      short arg;
    9699      // checked at the end to make sure we do not leak buffers
    97       int initialScalarCount = scalarBuffers.Count;
    98       int initialVectorCount = vectorBuffers.Count;
     100      int initialScalarCount = lastScalarBufIdx;
     101      int initialVectorCount = lastVecBufIdx;
    99102
    100103      while (true) {
     
    180183
    181184                var f = 1.0 / (maxFx * consts[curParamIdx]);
    182                 // adjust c so that maxFx*c = 1
     185                // adjust c so that maxFx*c = 1 TODO: this is not ideal as it enforces positive arguments to exp()
    183186                consts[curParamIdx] *= f;
    184187
     
    212215            }
    213216            ReleaseBuffer(r);
    214             Contract.Assert(vectorBuffers.Count == initialVectorCount);
    215             Contract.Assert(scalarBuffers.Count == initialScalarCount);
     217            Contract.Assert(lastVecBufIdx == initialVectorCount);
     218            Contract.Assert(lastScalarBufIdx == initialScalarCount);
    216219            return;
    217220        }
     
    233236
    234237      // checked at the end to make sure we do not leak buffers
    235       int initialScalarCount = scalarBuffers.Count;
    236       int initialVectorCount = vectorBuffers.Count;
     238      int initialScalarCount = lastScalarBufIdx;
     239      int initialVectorCount = lastVecBufIdx;
    237240
    238241      while (true) {
     
    401404            }
    402405
    403             Contract.Assert(vectorBuffers.Count == initialVectorCount);
    404             Contract.Assert(scalarBuffers.Count == initialScalarCount);
     406            Contract.Assert(lastVecBufIdx == initialVectorCount);
     407            Contract.Assert(lastScalarBufIdx == initialScalarCount);
    405408            return; // break loop
    406409        }
     
    509512      s = 0;
    510513      if (op == (byte)OpCodes.LoadVar) {
    511         s = (short)(((short)code[pc] << 8) | (short)code[pc + 1]);
     514        s = (short)((code[pc] << 8) | code[pc + 1]);
    512515        pc += 2;
    513516      }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    2524using System.Runtime.CompilerServices;
    2625using System.Threading;
     26using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    2727using HeuristicLab.Analysis;
    2828using HeuristicLab.Common;
     
    5353    private const string AllowedFactorsParameterName = "Allowed factors";
    5454    private const string ConstantOptimizationIterationsParameterName = "Iterations (constant optimization)";
    55     private const string CParameterName = "C";
     55    private const string PolicyParameterName = "Policy";
    5656    private const string SeedParameterName = "Seed";
    5757    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
     
    7171      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    7272    }
    73     public IFixedValueParameter<IntValue> MaxSizeParameter {
     73    public IFixedValueParameter<IntValue> MaxVariableReferencesParameter {
    7474      get { return (IFixedValueParameter<IntValue>)Parameters[MaxVariablesParameterName]; }
    7575    }
     
    8080      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    8181    }
    82     public IFixedValueParameter<DoubleValue> CParameter {
    83       get { return (IFixedValueParameter<DoubleValue>)Parameters[CParameterName]; }
     82    public IValueParameter<IPolicy> PolicyParameter {
     83      get { return (IValueParameter<IPolicy>)Parameters[PolicyParameterName]; }
    8484    }
    8585    public IFixedValueParameter<DoubleValue> PunishmentFactorParameter {
     
    116116      set { SetSeedRandomlyParameter.Value.Value = value; }
    117117    }
    118     public int MaxSize {
    119       get { return MaxSizeParameter.Value.Value; }
    120       set { MaxSizeParameter.Value.Value = value; }
    121     }
    122     public double C {
    123       get { return CParameter.Value.Value; }
    124       set { CParameter.Value.Value = value; }
    125     }
    126 
     118    public int MaxVariableReferences {
     119      get { return MaxVariableReferencesParameter.Value.Value; }
     120      set { MaxVariableReferencesParameter.Value.Value = value; }
     121    }
     122    public IPolicy Policy {
     123      get { return PolicyParameter.Value; }
     124      set { PolicyParameter.Value = value; }
     125    }
    127126    public double PunishmentFactor {
    128127      get { return PunishmentFactorParameter.Value.Value; }
     
    174173      Parameters.Add(new FixedValueParameter<IntValue>(MaxVariablesParameterName,
    175174        "Maximal number of variables references in the symbolic regression models (multiple usages of the same variable are counted)", new IntValue(5)));
    176       Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName,
    177         "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0)));
     175      // Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName,
     176      //   "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0)));
     177      Parameters.Add(new ValueParameter<IPolicy>(PolicyParameterName,
     178        "The policy to use for selecting nodes in MCTS (e.g. Ucb)", new Ucb()));
     179      PolicyParameter.Hidden = true;
    178180      Parameters.Add(new ValueParameter<ICheckedItemList<StringValue>>(AllowedFactorsParameterName,
    179181        "Choose which expressions are allowed as factors in the model.", defaultFactorsList));
     
    207209      Results.Add(new Result("Iterations", iterations));
    208210
     211      var bestSolutionIteration = new IntValue(0);
     212      Results.Add(new Result("Best solution iteration", bestSolutionIteration));
     213
    209214      var table = new DataTable("Qualities");
    210215      table.Rows.Add(new DataRow("Best quality"));
     
    221226      var avgQuality = new DoubleValue();
    222227      Results.Add(new Result("Average quality", avgQuality));
     228
     229      var totalRollouts = new IntValue();
     230      Results.Add(new Result("Total rollouts", totalRollouts));
     231      var effRollouts = new IntValue();
     232      Results.Add(new Result("Effective rollouts", effRollouts));
     233      var funcEvals = new IntValue();
     234      Results.Add(new Result("Function evaluations", funcEvals));
     235      var gradEvals = new IntValue();
     236      Results.Add(new Result("Gradient evaluations", gradEvals));
     237
    223238
    224239      // same as in SymbolicRegressionSingleObjectiveProblem
     
    235250      var problemData = (IRegressionProblemData)Problem.ProblemData.Clone();
    236251      if (!AllowedFactors.CheckedItems.Any()) throw new ArgumentException("At least on type of factor must be allowed");
    237       var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxSize, C, ScaleVariables, ConstantOptimizationIterations,
     252      var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, ScaleVariables, ConstantOptimizationIterations,
     253        Policy,
    238254        lowerLimit, upperLimit,
    239255        allowProdOfVars: AllowedFactors.CheckedItems.Any(s => s.Value.Value == VariableProductFactorName),
     
    248264      double bestQ = 0.0;
    249265      double curBestQ = 0.0;
    250       double q = 0.0;
    251266      int n = 0;
    252267      // Loop until iteration limit reached or canceled.
     
    254269        cancellationToken.ThrowIfCancellationRequested();
    255270
    256         q = MctsSymbolicRegressionStatic.MakeStep(state);
     271        var q = MctsSymbolicRegressionStatic.MakeStep(state);
    257272        sumQ += q; // sum of qs in the last updateinterval iterations
    258273        curBestQ = Math.Max(q, curBestQ); // the best q in the last updateinterval iterations
     
    261276        // iteration results
    262277        if (n == updateInterval) {
     278          if (bestQ > bestQuality.Value) {
     279            bestSolutionIteration.Value = i;
     280          }
    263281          bestQuality.Value = bestQ;
    264282          curQuality.Value = curBestQ;
     
    266284          sumQ = 0.0;
    267285          curBestQ = 0.0;
     286
     287          funcEvals.Value = state.FuncEvaluations;
     288          gradEvals.Value = state.GradEvaluations;
     289          effRollouts.Value = state.EffectiveRollouts;
     290          totalRollouts.Value = state.TotalRollouts;
    268291
    269292          table.Rows["Best quality"].Values.Add(bestQuality.Value);
     
    277300      // final results
    278301      if (n > 0) {
     302        if (bestQ > bestQuality.Value) {
     303          bestSolutionIteration.Value = iterations.Value + n;
     304        }
    279305        bestQuality.Value = bestQ;
    280306        curQuality.Value = curBestQ;
    281307        avgQuality.Value = sumQ / n;
     308
     309        funcEvals.Value = state.FuncEvaluations;
     310        gradEvals.Value = state.GradEvaluations;
     311        effRollouts.Value = state.EffectiveRollouts;
     312        totalRollouts.Value = state.TotalRollouts;
    282313
    283314        table.Rows["Best quality"].Values.Add(bestQuality.Value);
     
    285316        table.Rows["Average quality"].Values.Add(avgQuality.Value);
    286317        iterations.Value = iterations.Value + n;
     318
    287319      }
    288320
     
    290322      Results.Add(new Result("Best solution quality (train)", new DoubleValue(state.BestSolutionTrainingQuality)));
    291323      Results.Add(new Result("Best solution quality (test)", new DoubleValue(state.BestSolutionTestQuality)));
     324
    292325
    293326      // produce solution
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    2524using System.Diagnostics.Contracts;
    2625using System.Linq;
     26using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
    2727using HeuristicLab.Common;
    2828using HeuristicLab.Core;
     
    4545      double BestSolutionTrainingQuality { get; }
    4646      double BestSolutionTestQuality { get; }
     47      int TotalRollouts { get; }
     48      int EffectiveRollouts { get; }
     49      int FuncEvaluations { get; }
     50      int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
     51      // TODO other stats on LM optimizer might be interesting here
    4752    }
    4853
     
    5459      internal readonly Automaton automaton;
    5560      internal IRandom random { get; private set; }
    56       internal readonly double c;
    5761      internal readonly Tree tree;
    58       internal readonly List<Tree> bestChildrenBuf;
    5962      internal readonly Func<byte[], int, double> evalFun;
     63      internal readonly IPolicy treePolicy;
     64      // MCTS might get stuck. Track statistics on the number of effective rollouts
     65      internal int totalRollouts;
     66      internal int effectiveRollouts;
    6067
    6168
     
    7986      private double[] bestConsts;
    8087
     88      // stats
     89      private int funcEvaluations;
     90      private int gradEvaluations;
     91
    8192      // buffers
    8293      private readonly double[] ones; // vector of ones (as default params)
     
    8596      private readonly double[][] gradBuf;
    8697
    87       public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, double c, bool scaleVariables, int constOptIterations,
     98      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, int constOptIterations,
     99        IPolicy treePolicy = null,
    88100        double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
    89101        bool allowProdOfVars = true,
     
    94106
    95107        this.problemData = problemData;
    96         this.c = c;
    97108        this.constOptIterations = constOptIterations;
    98109        this.evalFun = this.Eval;
     
    123134
    124135        this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
    125         this.tree = new Tree() { state = automaton.CurrentState };
     136        this.treePolicy = treePolicy ?? new Ucb();
     137        this.tree = new Tree() { state = automaton.CurrentState, actionStatistics = treePolicy.CreateActionStatistics() };
    126138
    127139        // reset best solution
     
    135147        this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray();
    136148        constsBuf = new double[MaxParams];
    137         this.bestChildrenBuf = new List<Tree>(2 * x.Length); // the number of follow states in the automaton is O(number of variables) 2 * number of variables should be sufficient (capacity is increased if necessary anyway)
    138149        this.predBuf = new double[y.Length];
    139150        this.testPredBuf = new double[testY.Length];
     
    143154
    144155      #region IState inferface
    145       public bool Done { get { return tree != null && tree.done; } }
     156      public bool Done { get { return tree != null && tree.Done; } }
    146157
    147158      public double BestSolutionTrainingQuality {
     
    164175          var treeGen = new SymbolicExpressionTreeGenerator(problemData.AllowedInputVariables.ToArray());
    165176          var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
    166           var simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
    167177
    168178          var t = new SymbolicExpressionTree(treeGen.Exec(bestCode, bestConsts, bestNParams, scalingFactor, scalingOffset));
    169           var simpleT = simplifier.Simplify(t);
    170           var model = new SymbolicRegressionModel(simpleT, interpreter, lowerEstimationLimit, upperEstimationLimit);
     179          var model = new SymbolicRegressionModel(t, interpreter, lowerEstimationLimit, upperEstimationLimit);
    171180
    172181          // model has already been scaled linearly in Eval
     
    174183        }
    175184      }
     185
     186      public int TotalRollouts { get { return totalRollouts; } }
     187      public int EffectiveRollouts { get { return effectiveRollouts; } }
     188      public int FuncEvaluations { get { return funcEvaluations; } }
     189      public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations
     190
    176191      #endregion
    177192
     
    205220        Array.Copy(ones, constsBuf, nParams);
    206221        evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true);
     222        funcEvaluations++;
    207223
    208224        // calc opt scaling (alpha*f(x) + beta)
     
    221237          // optimize constants using the starting point calculated above
    222238          OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations);
     239
    223240          evaluator.Exec(code, x, constsBuf, predBuf);
     241          funcEvaluations++;
     242
    224243          rsq = RSq(y, predBuf);
    225244          optConsts = constsBuf;
     
    238257
    239258      private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) {
    240         double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt
     259        double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?)
    241260        Array.Copy(consts, optConsts, nParams);
    242261
     
    248267        alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code);
    249268        alglib.minlmresults(state, out optConsts, out rep);
     269        funcEvaluations += rep.nfunc;
     270        gradEvaluations += rep.njac * nParams;
    250271
    251272        if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype);
     
    258279
    259280      private void Func(double[] arg, double[] fi, object obj) {
    260         // 0.5 * MSE and gradient
    261281        var code = (byte[])obj;
    262282        evaluator.Exec(code, x, arg, predBuf); // gradients are nParams x vLen
     
    282302    }
    283303
    284     public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, double c = 1.0,
    285       bool scaleVariables = true, int constOptIterations = 0, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     304    public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3,
     305      bool scaleVariables = true, int constOptIterations = 0,
     306      IPolicy policy = null,
     307      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
    286308      bool allowProdOfVars = true,
    287309      bool allowExp = true,
     
    290312      bool allowMultipleTerms = false
    291313      ) {
    292       return new State(problemData, randSeed, maxVariables, c, scaleVariables, constOptIterations,
     314      return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations,
     315        policy,
    293316        lowerEstimationLimit, upperEstimationLimit,
    294317        allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
     
    309332      var tree = mctsState.tree;
    310333      var eval = mctsState.evalFun;
    311       var bestChildrenBuf = mctsState.bestChildrenBuf;
    312334      var rand = mctsState.random;
    313       double c = mctsState.c;
    314 
    315       automaton.Reset();
    316       return TreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf);
    317     }
    318 
    319     private static double TreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf) {
     335      var treePolicy = mctsState.treePolicy;
     336      double q = 0;
     337      bool success = false;
     338      do {
     339        automaton.Reset();
     340        success = TryTreeSearchRec(rand, tree, automaton, eval, treePolicy, out q);
     341        mctsState.totalRollouts++;
     342      } while (!success && !tree.Done);
     343      mctsState.effectiveRollouts++;
     344      return q;
     345    }
     346
     347    // tree search might fail because of constraints for expressions
     348    // if we get stuck in this way we simply restart the search
     349    // see ConstraintHandler.cs for more info
     350    private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
     351      out double q) {
    320352      Tree selectedChild = null;
    321       double q;
    322353      Contract.Assert(tree.state == automaton.CurrentState);
    323       Contract.Assert(!tree.done);
     354      Contract.Assert(!tree.Done);
    324355      if (tree.children == null) {
    325356        if (automaton.IsFinalState(tree.state)) {
    326357          // final state
    327           tree.done = true;
     358          tree.Done = true;
    328359
    329360          // EVALUATE
     
    331362          automaton.GetCode(out code, out nParams);
    332363          q = eval(code, nParams);
    333           tree.visits++;
    334           tree.sumQuality += q;
    335           return q;
     364
     365          treePolicy.Update(tree.actionStatistics, q);
     366          return true; // we reached a final state
    336367        } else {
    337368          // EXPAND
     
    339370          int nFs;
    340371          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    341 
     372          if (nFs == 0) {
     373            // stuck in a dead end (no final state and no allowed follow states)
     374            q = 0;
     375            tree.Done = true;
     376            tree.children = null;
     377            return false;
     378          }
    342379          tree.children = new Tree[nFs];
    343           for (int i = 0; i < tree.children.Length; i++) tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };
    344 
    345           selectedChild = SelectFinalOrRandom(automaton, tree, rand);
     380          for (int i = 0; i < tree.children.Length; i++)
     381            tree.children[i] = new Tree() { children = null, state = possibleFollowStates[i], actionStatistics = treePolicy.CreateActionStatistics() };
     382
     383          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
    346384        }
    347385      } else {
    348386        // tree.children != null
    349387        // UCT selection within tree
    350         selectedChild = SelectUct(tree, rand, c, bestChildrenBuf);
     388        int selectedIdx = 0;
     389        if (tree.children.Length > 1) {
     390          selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand);
     391        }
     392        selectedChild = tree.children[selectedIdx];
    351393      }
    352394      // make selected step and recurse
    353395      automaton.Goto(selectedChild.state);
    354       q = TreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf);
    355 
    356       tree.sumQuality += q;
    357       tree.visits++;
    358 
    359       // tree.done = tree.children.All(ch => ch.done);
    360       tree.done = true; for (int i = 0; i < tree.children.Length && tree.done; i++) tree.done = tree.children[i].done;
    361       if (tree.done) {
    362         tree.children = null; // cut of the sub-branch if it has been fully explored
    363         // TODO: update all qualities and visits to remove the information gained from this whole branch
    364       }
    365       return q;
    366     }
    367 
    368     private static Tree SelectUct(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {
    369       // determine total tries of still active children
    370       int totalTries = 0;
    371       bestChildrenBuf.Clear();
    372       for (int i = 0; i < tree.children.Length; i++) {
    373         var ch = tree.children[i];
    374         if (ch.done) continue;
    375         if (ch.visits == 0) bestChildrenBuf.Add(ch);
    376         else totalTries += tree.children[i].visits;
    377       }
    378       // if there are unvisited children select a random child
    379       if (bestChildrenBuf.Any()) {
    380         return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
    381       }
    382       Contract.Assert(totalTries > 0); // the tree is not done yet so there is at least on child that is not done
    383       double logTotalTries = Math.Log(totalTries);
    384       var bestQ = double.NegativeInfinity;
    385       for (int i = 0; i < tree.children.Length; i++) {
    386         var ch = tree.children[i];
    387         if (ch.done) continue;
    388         var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits);
    389         if (childQ > bestQ) {
    390           bestChildrenBuf.Clear();
    391           bestChildrenBuf.Add(ch);
    392           bestQ = childQ;
    393         } else if (childQ >= bestQ) {
    394           bestChildrenBuf.Add(ch);
    395         }
    396       }
    397       return bestChildrenBuf.Count > 0 ? bestChildrenBuf[rand.Next(bestChildrenBuf.Count)] : bestChildrenBuf[0];
     396      var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q);
     397      if (success) {
     398        // only update if successful
     399        treePolicy.Update(tree.actionStatistics, q);
     400      }
     401
     402      tree.Done = tree.children.All(ch => ch.Done);
     403      if (tree.Done) {
     404        tree.children = null; // cut off the sub-branch if it has been fully explored
     405      }
     406      return success;
    398407    }
    399408
     
    409418        }
    410419      }
    411       // no final state -> select a random child
     420      // no final state -> select the first child
    412421      if (selectedChildIdx == -1) {
    413         selectedChildIdx = rand.Next(tree.children.Length);
     422        selectedChildIdx = 0;
    414423      }
    415424      return tree.children[selectedChildIdx];
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/OpCodes.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/EpsGreedy.cs

    r13659 r15060  
    99using HeuristicLab.Data;
    1010using HeuristicLab.Parameters;
     11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    1112
    1213namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
     14  [StorableClass]
    1315  [Item("EpsilonGreedy", "Epsilon greedy policy with parameter eps to balance between exploitation and exploration")]
    14   internal class EpsilonGreedy : PolicyBase {
     16  public class EpsilonGreedy : PolicyBase {
    1517    private class ActionStatistics : IActionStatistics {
    1618      public double SumQuality { get; set; }
     
    3032    }
    3133
    32     private EpsilonGreedy(EpsilonGreedy original, Cloner cloner)
     34    [StorableConstructor]
     35    protected EpsilonGreedy(bool deserializing) : base(deserializing) { }
     36    protected EpsilonGreedy(EpsilonGreedy original, Cloner cloner)
    3337      : base(original, cloner) {
    3438    }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/PolicyBase.cs

    r13659 r15060  
    1111namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
    1212  [StorableClass]
    13   internal abstract class PolicyBase : Item, IParameterizedItem, IPolicy {
     13  public abstract class PolicyBase : Item, IParameterizedItem, IPolicy {
    1414    [Storable]
    1515    public IKeyedItemCollection<string, IParameter> Parameters { get; private set; }
    1616
    1717    [StorableConstructor]
    18     private PolicyBase(bool deserializing) : base(deserializing) { }
     18    protected PolicyBase(bool deserializing) : base(deserializing) { }
    1919    protected PolicyBase(PolicyBase original, Cloner cloner)
    2020      : base(original, cloner) {
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs

    r13659 r15060  
    99using HeuristicLab.Data;
    1010using HeuristicLab.Parameters;
     11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    1112
    1213namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
     14  [StorableClass]
    1315  [Item("Ucb Policy", "Ucb with parameter c to balance between exploitation and exploration")]
    14   internal class Ucb : PolicyBase {
     16  public class Ucb : PolicyBase {
    1517    private class ActionStatistics : IActionStatistics {
    1618      public double SumQuality { get; set; }
     
    3032    }
    3133
    32     private Ucb(Ucb original, Cloner cloner)
     34    [StorableConstructor]
     35    protected Ucb(bool deserializing) : base(deserializing) { }
     36    protected Ucb(Ucb original, Cloner cloner)
    3337      : base(original, cloner) {
    3438    }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/UcbTuned.cs

    r13659 r15060  
    99using HeuristicLab.Data;
    1010using HeuristicLab.Parameters;
     11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    1112
    1213namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
     14  [StorableClass]
    1315  [Item("UcbTuned Policy", "UcbTuned is similar to Ucb but tracks empirical variance. Use parameter c to balance between exploitation and exploration")]
    14   internal class UcbTuned : PolicyBase {
     16  public class UcbTuned : PolicyBase {
    1517    private class ActionStatistics : IActionStatistics {
    1618      public double SumQuality { get; set; }
     
    3234    }
    3335
    34     private UcbTuned(UcbTuned original, Cloner cloner)
     36    [StorableConstructor]
     37    protected UcbTuned(bool deserializing) : base(deserializing) { }
     38    protected UcbTuned(UcbTuned original, Cloner cloner)
    3539      : base(original, cloner) {
    3640    }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/SymbolicExpressionGenerator.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    114113            break;
    115114          case OpCodes.Add: {
    116               var t1 = stack[topOfStack];
    117               var t2 = stack[topOfStack - 1];
     115              var t1 = stack[topOfStack - 1];
     116              var t2 = stack[topOfStack];
    118117              topOfStack--;
    119               if (t2.Symbol is Addition) {
    120                 t2.AddSubtree(t1);
     118              if (t1.Symbol is Addition) {
     119                t1.AddSubtree(t2);
    121120              } else {
    122121                var addNode = addSy.CreateTreeNode();
     
    128127            }
    129128          case OpCodes.Mul: {
    130               var t1 = stack[topOfStack];
    131               var t2 = stack[topOfStack - 1];
     129              var t1 = stack[topOfStack - 1];
     130              var t2 = stack[topOfStack];
    132131              topOfStack--;
    133               if (t2.Symbol is Multiplication) {
    134                 t2.AddSubtree(t1);
     132              if (t1.Symbol is Multiplication) {
     133                t1.AddSubtree(t2);
    135134              } else {
    136135                var mulNode = mulSy.CreateTreeNode();
     
    177176      s = 0;
    178177      if (op == OpCodes.LoadVar) {
    179         s = (short)(((short)code[pc] << 8) | (short)code[pc + 1]);
     178        s = (short)((code[pc] << 8) | code[pc + 1]);
    180179        pc += 2;
    181180      }
  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Tree.cs

    r13645 r15060  
    22/* HeuristicLab
    33 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    4  * and the BEACON Center for the Study of Evolution in Action.
    54 *
    65 * This file is part of HeuristicLab.
     
    2120#endregion
    2221
     22using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies;
     23
    2324namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
    2425  // represents tree nodes for the search tree in MCTS
    2526  internal class Tree {
    2627    public int state;
    27     public int visits;
    28     public double sumQuality;
    29     public double AverageQuality { get { return sumQuality / (double)visits; } }
    30     public bool done;
     28    public bool Done {
     29      get { return actionStatistics.Done; }
     30      set { actionStatistics.Done = value; }
     31    }
     32    public IActionStatistics actionStatistics;
    3133    public Tree[] children;
    3234  }
Note: See TracChangeset for help on using the changeset viewer.