Changeset 13651
- Timestamp:
- 03/05/16 08:25:08 (9 years ago)
- Location:
- trunk/sources
- Files:
-
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs
r13650 r13651 26 26 namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression { 27 27 // this is the core class for generating expressions. 28 // the automaton determines which expressions are allowed 28 // it represents a finite state automaton, each state transition can be associated with an action (e.g. to produce code). 29 // the automaton determines the possible structures for expressions. 30 // 31 // to understand this code it is worthwile to generate a graphical visualization of the automaton (see PrintAutomaton). 32 // If the code is compiled in debug mode the automaton produces a Graphviz file into the folder of the application 33 // whenever an instance of the automaton is constructed. 34 // 35 // This class relies on two other classes: 36 // - CodeGenerator to produce code for a stack-based evaluator and 37 // - ConstraintHandler to restrict the allowed set of expressions. 38 // 39 // The ConstraintHandler extends the automaton and adds semantic restrictions for expressions produced by the automaton. 40 // 41 // 29 42 internal class Automaton { 30 43 public const int StateExpr = 1; … … 52 65 public const int StateInvTFStart = 23; 53 66 public const int StateInvTFEnd = 24; 54 private const int FirstDynamicState = 25; 67 public const int FirstDynamicState = 25; 68 // more states for individual variables are created dynamically 55 69 56 70 private const int StartState = StateExpr; … … 221 235 () => { 222 236 codeGenerator.Emit1(OpCodes.LoadConst0); 223 }, 224 "0"); 237 constraintHandler.StartNewTermInPoly(); 238 }, 239 "0, StartTermInPoly"); 225 240 AddTransition(StateLogTEnd, StateLogFactorEnd, 226 241 () => { … … 271 286 () => { 272 287 codeGenerator.Emit1(OpCodes.LoadConst1); 273 }, 274 "c"); 288 constraintHandler.StartNewTermInPoly(); 289 }, 290 "c, StartTermInPoly"); 275 291 AddTransition(StateInvTEnd, StateInvFactorEnd, 276 292 () => { … … 337 353 private readonly int[] followStatesBuf = new int[1000]; 338 354 public void FollowStates(int state, out int[] buf, out int nElements) { 339 // return followStates[state]340 // .Where(s => s < FirstDynamicState || s >= minVarIdx) // for variables we only allow non-decreasing state sequences341 // // the following states imply an additional variable being added to the expression342 // // F, Sum, Prod343 // .Where(s => (s != StateF && s != StateSum && s != StateProd) || variablesRemaining > 0);344 345 355 // for loop instead of where iterator 346 356 var fs = followStates[state]; 347 357 int j = 0; 348 //Console.Write(stateNames[CurrentState] + " allowed: ");349 358 for (int i = 0; i < fs.Count; i++) { 350 359 var s = fs[i]; 351 360 if (constraintHandler.IsAllowedFollowState(state, s)) { 352 //Console.Write(s + " ");353 361 followStatesBuf[j++] = s; 354 362 } 355 363 } 356 //Console.WriteLine();357 364 buf = followStatesBuf; 358 365 nElements = j; … … 361 368 362 369 public void Goto(int targetState) { 363 //Console.WriteLine("->{0}", stateNames[targetState]);364 // Contract.Assert(FollowStates(CurrentState).Contains(targetState));365 366 370 if (actions[CurrentState, targetState] != null) 367 371 actions[CurrentState, targetState].ForEach(a => a()); // execute all actions … … 370 374 371 375 public bool IsFinalState(int s) { 372 return s == StateExprEnd ;376 return s == StateExprEnd && !constraintHandler.IsInvalidExpression; 373 377 } 374 378 -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/ConstraintHandler.cs
r13645 r13651 2 2 /* HeuristicLab 3 3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 * and the BEACON Center for the Study of Evolution in Action.5 4 * 6 5 * This file is part of HeuristicLab. … … 21 20 #endregion 22 21 23 22 using System; 23 using System.Collections.Generic; 24 24 using System.Diagnostics.Contracts; 25 using System.Linq; 25 26 26 27 namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression { 27 28 28 // more states for individual variables are created dynamically 29 // This class restricts the set of allowed transitions of the automaton to prevent exploration of duplicate expressions. 30 // It would be possible to implement this class in such a way that the search never visits a duplicate expression. However, 31 // it seems very intricate to detect this robustly and in all cases while generating an expression because 32 // some for of lookahead is necessary. 33 // Instead the constraint handler only catches the obvious duplicates directly, but does not guarantee that the search always produces a valid expression. 34 // The ratio of the number of unsuccessful searches, that need backtracking should be tracked in the MCTS alg (MctsSymbolicRegressionStatic) 35 36 // All changes to this class should be tested through unit tests. It is important that the ConstraintHandler is not too restrictive. 37 38 // the constraints are derived from a canonical form for expressions. 39 // overall we can enforce a limited number of variable references 40 // 41 // an expression is a sum of terms t_1 ... t_n where terms are ordered according to a relation t_i (<=)_term t_j for each pair t_i, t_j and i <= j 42 // a term is a product of factors where factors are ordered according to relation f_i (<=)_factor f_j for each pair f_i,f_j and i <= j 43 44 // we want to enforce lower-order terms before higher-order terms in expressions (based on number of variable references) 45 // factors can have different types (variable, exp, log, inverse) 46 47 // (<=)_term [IsSmallerOrEqualTerm(t_i, t_j)] 48 // 1. NumberOfVarRefs(t_i) < NumberOfVarRefs(t_j) --> true enforce terms with non-decreasing number of var refs 49 // 2. NumberOfVarRefs(t_i) > NumberOfVarRefs(t_j) --> false 50 // 3. NumFactors(t_i) > NumFactors(t_j) --> true enforce terms with non-increasing number of factors 51 // 4. NumFactors(t_i) < NumFactors(t_j) --> false 52 // 5. for all k factors: Factor(k, t_i) (<=)_factor Factor(k, t_j) --> true // factors must be non-decreasing 53 // 6. all factors are (=)_factor --> true 54 // 7. else false 55 56 // (<=)_factor [IsSmallerOrEqualFactor(f_i, f_j)] 57 // 1. FactorType(t_i) < FactorType(t_j) --> true enforce terms with non-decreasing factor type (var < exp < log < inv) 58 // 2. FactorType(t_i) > FactorType(t_j) --> false 59 // 3. Compare the two factors specifically 60 // - variables: varIdx_i <= varIdx_j (only one var reference) 61 // - exp: number of variable references and then varIdx_i <= varIdx_j for each position 62 // - log: number of variable references and ... 63 // - inv: number of variable references and ... 64 // 65 66 // for log and inverse factors we allow all polynomials as argument 67 // a polynomial is a sum of terms t_1 ... t_n where terms are ordered according to a relation t_i (<=)_poly t_j for each pair t_i, t_j and i <= j 68 69 // (<=)_poly [IsSmallerOrEqualPoly(t_i, t_j)] 70 // 1. NumberOfVarRefs(t_i) < NumberOfVarRefs(t_j) --> true // enforce non-decreasing number of var refs 71 // 2. NumberOfVarRefs(t_i) > NumberOfVarRefs(t_j) --> false // enforce non-decreasing number of var refs 72 // 3. for all k variables: VarIdx(k,t_i) > VarIdx(k, t_j) --> false // enforce non-decreasing variable idx 73 74 75 // we store the following to make comparsions: 76 // - prevTerm (complete & containing all factors) 77 // - curTerm (incomplete & containing all completed factors) 78 // - curFactor (incomplete) 29 79 internal class ConstraintHandler { 30 80 private int nVars; 31 81 private readonly int maxVariables; 32 33 public int prevTermFirstVariableState; 34 public int curTermFirstVariableState; 35 public int prevTermFirstFactorType; 36 public int curTermFirstFactorType; 37 public int prevFactorType; 38 public int curFactorType; 39 public int prevFactorFirstVariableState; 40 public int curFactorFirstVariableState; 41 public int prevVariableRef; 82 private bool invalidExpression; 83 84 public bool IsInvalidExpression { 85 get { return invalidExpression; } 86 } 87 88 89 private TermInformation prevTerm; 90 private TermInformation curTerm; 91 private FactorInformation curFactor; 92 93 94 private class TermInformation { 95 public int numVarReferences { get { return factors.Sum(f => f.numVarReferences); } } 96 public List<FactorInformation> factors = new List<FactorInformation>(); 97 } 98 99 private class FactorInformation { 100 public int numVarReferences = 0; 101 public int factorType; // use the state number to represent types 102 103 // for variable factors 104 public int variableState = -1; 105 106 // for exp factors 107 public List<int> expVariableStates = new List<int>(); 108 109 // for log and inv factors 110 public List<List<int>> polyVariableStates = new List<List<int>>(); 111 } 42 112 43 113 … … 46 116 } 47 117 48 // 1) an expression is a sum of terms t_1 ... t_n 49 // FirstFactorType(t_i) <= FirstFactorType(t_j) for each pair t_i, t_j where i < j 50 // FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair t_i, t_j where i < j and FirstFactorType(t_i) = FirstFactorType(t_j) 51 // 2) a term is a product of factors, each factor is either a variable factor, an exp factor, a log factor or an inverse factor 52 // FactorType(f_i) <= FactorType(f_j) for each pair of factors f_i, f_j and i < j 53 // FirstVarReference(f_i) <= FirstVarReference(f_j) for each pair of factors f_i, f_j and i < j and FactorType(f_i) = FactorType(f_j) 54 // 3) a variable factor is a product of variable references v1...vn 55 // VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j 56 // (IMPLICIT) FirstVarReference(t) <= VarIdx(v_i) for each variable reference v_i in term t 57 // 4) an exponential factor is the exponential of a product of variables v1...vn 58 // VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j 59 // (IMPLICIT) FirstVarReference(t) <= VarIdx(v_i) for each variable reference v_i in term t 60 // 5) a log factor is a sum of terms t_i where each term is a product of variables 61 // FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair of terms t_i, t_j and i < j 62 // for each term t: VarIdx(v_i) <= VarIdx(v_j) for each pair of variable references v_i, v_j and i < j in t 118 // the order relations for terms and factors 119 120 private static int CompareTerms(TermInformation a, TermInformation b) { 121 if (a.numVarReferences < b.numVarReferences) return -1; 122 if (a.numVarReferences > b.numVarReferences) return 1; 123 124 if (a.factors.Count > b.factors.Count) return -1; // terms with more factors should be ordered first 125 if (a.factors.Count < b.factors.Count) return +1; 126 127 var aFactors = a.factors.GetEnumerator(); 128 var bFactors = b.factors.GetEnumerator(); 129 while (aFactors.MoveNext() & bFactors.MoveNext()) { 130 var c = CompareFactors(aFactors.Current, bFactors.Current); 131 if (c < 0) return -1; 132 if (c > 0) return 1; 133 } 134 // all factors are the same => terms are the same 135 return 0; 136 } 137 138 private static int CompareFactors(FactorInformation a, FactorInformation b) { 139 if (a.factorType < b.factorType) return -1; 140 if (a.factorType > b.factorType) return +1; 141 // same factor types 142 if (a.factorType == Automaton.StateVariableFactorStart) { 143 return a.variableState.CompareTo(b.variableState); 144 } else if (a.factorType == Automaton.StateExpFactorStart) { 145 return CompareStateLists(a.expVariableStates, b.expVariableStates); 146 } else { 147 if (a.numVarReferences < b.numVarReferences) return -1; 148 if (a.numVarReferences > b.numVarReferences) return +1; 149 if (a.polyVariableStates.Count > b.polyVariableStates.Count) return -1; // more terms in the poly should be ordered first 150 if (a.polyVariableStates.Count < b.polyVariableStates.Count) return +1; 151 // log and inv 152 var aTerms = a.polyVariableStates.GetEnumerator(); 153 var bTerms = b.polyVariableStates.GetEnumerator(); 154 while (aTerms.MoveNext() & bTerms.MoveNext()) { 155 var c = CompareStateLists(aTerms.Current, bTerms.Current); 156 if (c != 0) return c; 157 } 158 return 0; // all terms in the polynomial are the same 159 } 160 } 161 162 private static int CompareStateLists(List<int> a, List<int> b) { 163 if (a.Count < b.Count) return -1; 164 if (a.Count > b.Count) return +1; 165 for (int i = 0; i < a.Count; i++) { 166 if (a[i] < b[i]) return -1; 167 if (a[i] > b[i]) return +1; 168 } 169 return 0; // all states are the same 170 } 171 172 173 private bool IsNewTermAllowed() { 174 // next term must have at least as many variable references as the previous term 175 return prevTerm == null || nVars + prevTerm.numVarReferences <= maxVariables; 176 } 177 178 private bool IsNewFactorAllowed() { 179 // next factor must have a larger or equal type compared to the previous factor. 180 // if the types are the same it must have at least as many variable references. 181 // so if the prevFactor is any other than invFactor (last possible type) then we only need to be able to add one variable 182 // otherwise we need to be able to add at least as many variables as the previous factor 183 return !curTerm.factors.Any() || 184 (nVars + curTerm.factors.Last().numVarReferences <= maxVariables); 185 } 186 187 private bool IsAllowedAsNextFactorType(int followState) { 188 // IsNewTermAllowed already ensures that we can add a term with enough variable references 189 190 // enforce constraints within terms (compare to prev factor) 191 if (curTerm.factors.Any()) { 192 // enforce non-decreasing factor types 193 if (curTerm.factors.Last().factorType > followState) return false; 194 // when the factor type is the same, starting a new factor is only allowed if we can add at least the number of variables of the prev factor 195 if (curTerm.factors.Last().factorType == followState && nVars + curTerm.factors.Last().numVarReferences > maxVariables) return false; 196 } 197 198 // enforce constraints on terms (compare to prev term) 199 // meaning that we must ensure non-decreasing terms 200 if (prevTerm != null) { 201 // a factor type is only allowed if we can then produce a term that is larger or equal to the prev term 202 // (1) if we the number of variable references still remaining is larger than the number of variable references in the prev term 203 // then it is always possible to build a larger term 204 // (2) otherwise we try to build the largest possible term starting from current factors in the term. 205 // 206 207 var numVarRefsRemaining = maxVariables - nVars; 208 Contract.Assert(!curTerm.factors.Any() || curTerm.factors.Last().numVarReferences <= numVarRefsRemaining); 209 210 if (prevTerm.numVarReferences < numVarRefsRemaining) return true; 211 212 // variable factors must be handled differently because they can only contain one variable reference 213 if (followState == Automaton.StateVariableFactorStart) { 214 // append the variable factor and the maximum possible state from the previous factor to create a larger factor 215 var varF = CreateLargestPossibleFactor(Automaton.StateVariableFactorStart, 1); 216 var maxF = CreateLargestPossibleFactor(prevTerm.factors.Max(f => f.factorType), numVarRefsRemaining - 1); 217 var origFactorCount = curTerm.factors.Count; 218 // add this factor to the current term 219 curTerm.factors.Add(varF); 220 curTerm.factors.Add(maxF); 221 var c = CompareTerms(prevTerm, curTerm); 222 // restore term 223 curTerm.factors.RemoveRange(origFactorCount, 2); 224 // if the prev term is still larger then this followstate is not allowed 225 if (c > 0) { 226 return false; 227 } 228 } else { 229 var newF = CreateLargestPossibleFactor(followState, numVarRefsRemaining); 230 231 var origFactorCount = curTerm.factors.Count; 232 // add this factor to the current term 233 curTerm.factors.Add(newF); 234 var c = CompareTerms(prevTerm, curTerm); 235 // restore term 236 curTerm.factors.RemoveAt(origFactorCount); 237 // if the prev term is still larger then this followstate is not allowed 238 if (c > 0) { 239 return false; 240 } 241 } 242 } 243 return true; 244 } 245 246 // largest possible factor of the given kind 247 private FactorInformation CreateLargestPossibleFactor(int factorType, int numVarRefs) { 248 var newF = new FactorInformation(); 249 newF.factorType = factorType; 250 if (factorType == Automaton.StateVariableFactorStart) { 251 newF.variableState = int.MaxValue; 252 newF.numVarReferences = 1; 253 } else if (factorType == Automaton.StateExpFactorStart) { 254 for (int i = 0; i < numVarRefs; i++) 255 newF.expVariableStates.Add(int.MaxValue); 256 newF.numVarReferences = numVarRefs; 257 } else if (factorType == Automaton.StateInvFactorStart || factorType == Automaton.StateLogFactorStart) { 258 for (int i = 0; i < numVarRefs; i++) { 259 newF.polyVariableStates.Add(new List<int>()); 260 newF.polyVariableStates[i].Add(int.MaxValue); 261 } 262 newF.numVarReferences = numVarRefs; 263 } 264 return newF; 265 } 266 267 private bool IsAllowedAsNextVariableFactor(int variableState) { 268 Contract.Assert(variableState >= Automaton.FirstDynamicState); 269 return !curTerm.factors.Any() || curTerm.factors.Last().variableState <= variableState; 270 } 271 272 private bool IsAllowedAsNextInExp(int variableState) { 273 Contract.Assert(variableState >= Automaton.FirstDynamicState); 274 if (curFactor.expVariableStates.Any() && curFactor.expVariableStates.Last() > variableState) return false; 275 if (curTerm.factors.Any()) { 276 // try and compare with prev factor 277 curFactor.numVarReferences++; 278 curFactor.expVariableStates.Add(variableState); 279 var c = CompareFactors(curTerm.factors.Last(), curFactor); 280 curFactor.numVarReferences--; 281 curFactor.expVariableStates.RemoveAt(curFactor.expVariableStates.Count - 1); 282 return c <= 0; 283 } 284 return true; 285 } 286 287 private bool IsNewTermAllowedInPoly() { 288 return nVars + curFactor.polyVariableStates.Last().Count() <= maxVariables; 289 } 290 291 private bool IsAllowedAsNextInPoly(int variableState) { 292 Contract.Assert(variableState >= Automaton.FirstDynamicState); 293 return !curFactor.polyVariableStates.Any() || 294 !curFactor.polyVariableStates.Last().Any() || 295 curFactor.polyVariableStates.Last().Last() <= variableState; 296 } 297 private bool IsTermCompleteInPoly() { 298 var nTerms = curFactor.polyVariableStates.Count; 299 return nTerms == 1 || 300 curFactor.polyVariableStates[nTerms - 2].Count <= curFactor.polyVariableStates[nTerms - 1].Count; 301 302 } 303 private bool IsCompleteExp() { 304 return !curTerm.factors.Any() || CompareFactors(curTerm.factors.Last(), curFactor) <= 0; 305 } 306 63 307 public bool IsAllowedFollowState(int currentState, int followState) { 64 // the following states are always allowed 308 // an invalid action was taken earlier on => nothing can be done anymore 309 if (invalidExpression) return false; 310 // states that have no alternative are always allowed 311 // some ending states are only allowed if enough variables have been used in the term 65 312 if ( 66 followState == Automaton.StateVariableFactorEnd || 67 followState == Automaton.StateExpFEnd || 68 followState == Automaton.StateExpFactorEnd || 69 followState == Automaton.StateLogTFEnd || 70 followState == Automaton.StateLogTEnd || 71 followState == Automaton.StateLogFactorEnd || 72 followState == Automaton.StateInvTFEnd || 73 followState == Automaton.StateInvTEnd || 74 followState == Automaton.StateInvFactorEnd || 75 followState == Automaton.StateFactorEnd || 76 followState == Automaton.StateTermEnd || 77 followState == Automaton.StateExprEnd 313 currentState == Automaton.StateTermStart || // no alternative 314 currentState == Automaton.StateExpFactorStart || 315 currentState == Automaton.StateLogFactorStart || 316 currentState == Automaton.StateInvFactorStart || 317 followState == Automaton.StateVariableFactorEnd || // no alternative 318 followState == Automaton.StateExpFEnd || // no alternative 319 followState == Automaton.StateLogTFEnd || // no alternative 320 followState == Automaton.StateInvTFEnd || // no alternative 321 followState == Automaton.StateFactorEnd || // always allowed because no alternative 322 followState == Automaton.StateExprEnd // we could also constrain the minimum number of terms here 78 323 ) return true; 79 324 80 325 81 // all other states are only allowed if we can add more variables 82 if (nVars >= maxVariables) return false; 83 84 // the following states are always allowed when we can add more variables 326 // starting a new term is only allowed if we can add a term with at least the number of variables of the prev term 327 if (followState == Automaton.StateTermStart && !IsNewTermAllowed()) return false; 328 if (followState == Automaton.StateFactorStart && !IsNewFactorAllowed()) return false; 329 if (currentState == Automaton.StateFactorStart && !IsAllowedAsNextFactorType(followState)) return false; 330 if (followState == Automaton.StateTermEnd && prevTerm != null && CompareTerms(prevTerm, curTerm) > 0) return false; 331 332 // all of these states add at least one variable 85 333 if ( 86 followState == Automaton.StateTermStart || 87 followState == Automaton.StateFactorStart || 88 followState == Automaton.StateExpFStart || 89 followState == Automaton.StateLogTStart || 90 followState == Automaton.StateLogTFStart || 91 followState == Automaton.StateInvTStart || 92 followState == Automaton.StateInvTFStart 93 ) return true; 94 95 // enforce non-decreasing factor types 96 if (currentState == Automaton.StateFactorStart) { 97 if (curFactorType < 0) { 98 // FirstFactorType(t_i) <= FirstFactorType(t_j) for each pair t_i, t_j where i < j 99 return prevTermFirstFactorType <= followState; 100 } else { 101 // FactorType(f_i) <= FactorType(f_j) for each pair of factors f_i, f_j and i < j 102 return curFactorType <= followState; 103 } 104 } 105 // enforce non-decreasing variables references in variable and exp factors 106 if (currentState == Automaton.StateVariableFactorStart || currentState == Automaton.StateExpFStart || currentState == Automaton.StateLogTFStart || currentState == Automaton.StateInvTFStart) { 107 if (prevVariableRef > followState) return false; // never allow decreasing variables 108 if (prevFactorType < 0) { 109 // FirstVarReference(t_i) <= FirstVarReference(t_j) for each pair t_i, t_j where i < j 110 return prevTermFirstVariableState <= followState; 111 } else if (prevFactorType == curFactorType) { 112 // (FirstVarReference(f_i) <= FirstVarReference(f_j) for each pair of factors f_i, f_j and i < j and FactorType(f_i) = FactorType(f_j) 113 return prevFactorFirstVariableState <= followState; 114 } 115 } 116 117 118 return true; 334 followState == Automaton.StateVariableFactorStart || 335 followState == Automaton.StateExpFactorStart || followState == Automaton.StateExpFStart || 336 followState == Automaton.StateLogFactorStart || followState == Automaton.StateLogTStart || 337 followState == Automaton.StateLogTFStart || 338 followState == Automaton.StateInvFactorStart || followState == Automaton.StateInvTStart || 339 followState == Automaton.StateInvTFStart) { 340 if (nVars + 1 > maxVariables) return false; 341 } 342 343 if (currentState == Automaton.StateVariableFactorStart && !IsAllowedAsNextVariableFactor(followState)) return false; 344 else if (currentState == Automaton.StateExpFStart && !IsAllowedAsNextInExp(followState)) return false; 345 else if (followState == Automaton.StateLogTStart && !IsNewTermAllowedInPoly()) return false; 346 else if (currentState == Automaton.StateLogTFStart && !IsAllowedAsNextInPoly(followState)) return false; 347 else if (followState == Automaton.StateInvTStart && !IsNewTermAllowedInPoly()) return false; 348 else if (currentState == Automaton.StateInvTFStart && !IsAllowedAsNextInPoly(followState)) return false; 349 // finishing an exponential factor is only allowed when the number of variable references is large enough 350 else if (followState == Automaton.StateExpFactorEnd && !IsCompleteExp()) return false; 351 // finishing a polynomial (in log or inv) is only allowed when the number of variable references is large enough 352 else if (followState == Automaton.StateInvTEnd && !IsTermCompleteInPoly()) return false; 353 else if (followState == Automaton.StateLogTEnd && !IsTermCompleteInPoly()) return false; 354 355 else if (nVars > maxVariables) return false; 356 else return true; 119 357 } 120 358 … … 122 360 public void Reset() { 123 361 nVars = 0; 124 125 126 prevTermFirstVariableState = -1; 127 curTermFirstVariableState = -1; 128 prevTermFirstFactorType = -1; 129 curTermFirstFactorType = -1; 130 prevVariableRef = -1; 131 prevFactorType = -1; 132 curFactorType = -1; 133 curFactorFirstVariableState = -1; 134 prevFactorFirstVariableState = -1; 362 prevTerm = null; 363 curTerm = null; 364 curFactor = null; 365 invalidExpression = false; 135 366 } 136 367 137 368 public void StartTerm() { 138 // reset factor type. in each term we can start with each type of factor 139 prevTermFirstVariableState = curTermFirstVariableState; 140 curTermFirstVariableState = -1; 141 142 prevTermFirstFactorType = curTermFirstFactorType; 143 curTermFirstFactorType = -1; 144 145 146 prevFactorType = -1; 147 curFactorType = -1; 148 149 curFactorFirstVariableState = -1; 150 prevFactorFirstVariableState = -1; 369 curTerm = new TermInformation(); 151 370 } 152 371 153 372 public void StartFactor(int state) { 154 prevFactorType = curFactorType; 155 curFactorType = -1; 156 157 prevFactorFirstVariableState = curFactorFirstVariableState; 158 curFactorFirstVariableState = -1; 159 160 161 // store the first factor type 162 if (curTermFirstFactorType < 0) { 163 curTermFirstFactorType = state; 164 } 165 curFactorType = state; 166 167 // reset variable references. in each factor we can start with each variable reference 168 prevVariableRef = -1; 373 curFactor = new FactorInformation(); 374 curFactor.factorType = state; 169 375 } 170 376 171 377 172 378 public void AddVarToCurrentFactor(int state) { 173 174 Contract.Assert(prevVariableRef <= state); 175 176 // store the first variable reference for each factor 177 if (curFactorFirstVariableState < 0) { 178 curFactorFirstVariableState = state; 179 180 // store the first variable reference for each term 181 if (curTermFirstVariableState < 0) { 182 curTermFirstVariableState = state; 183 } 184 } 185 prevVariableRef = state; 379 Contract.Assert(Automaton.FirstDynamicState <= state); 380 Contract.Assert(curTerm != null); 381 Contract.Assert(curFactor != null); 186 382 187 383 nVars++; 384 curFactor.numVarReferences++; 385 386 if (curFactor.factorType == Automaton.StateVariableFactorStart) { 387 Contract.Assert(curFactor.variableState < 0); // not set before 388 curFactor.variableState = state; 389 } else if (curFactor.factorType == Automaton.StateExpFactorStart) { 390 curFactor.expVariableStates.Add(state); 391 } else if (curFactor.factorType == Automaton.StateLogFactorStart || 392 curFactor.factorType == Automaton.StateInvFactorStart) { 393 curFactor.polyVariableStates.Last().Add(state); 394 } else throw new InvalidProgramException(); 395 } 396 397 public void StartNewTermInPoly() { 398 curFactor.polyVariableStates.Add(new List<int>()); 188 399 } 189 400 190 401 public void EndFactor() { 191 Contract.Assert(prevFactorFirstVariableState <= curFactorFirstVariableState); 192 Contract.Assert(prevFactorType <= curFactorType); 402 // enforce non-decreasing factors 403 if (curTerm.factors.Any() && CompareFactors(curTerm.factors.Last(), curFactor) > 0) 404 invalidExpression = true; 405 curTerm.factors.Add(curFactor); 406 curFactor = null; 193 407 } 194 408 195 409 public void EndTerm() { 196 197 Contract.Assert(prevFactorType <= curFactorType); 198 Contract.Assert(prevTermFirstVariableState <= curTermFirstVariableState); 410 // enforce non-decreasing terms (TODO: equal terms should not be allowed) 411 if (prevTerm != null && CompareTerms(prevTerm, curTerm) > 0) 412 invalidExpression = true; 413 prevTerm = curTerm; 414 curTerm = null; 199 415 } 200 416 } -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Disassembler.cs
r13650 r13651 23 23 24 24 namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression { 25 #if DEBUG26 25 internal class Disassembler { 27 26 public static string CodeToString(byte[] code, double[] consts) { … … 51 50 } 52 51 } 53 #endif54 52 } -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/ExpressionEvaluator.cs
r13650 r13651 28 28 internal class ExpressionEvaluator { 29 29 // manages it's own vector buffers 30 private readonly List<double[]> vectorBuffers = new List<double[]>(); 31 private readonly List<double[]> scalarBuffers = new List<double[]>(); // scalars are vectors of length 1 (to allow mixing scalars and vectors on the same stack) 30 private readonly double[][] vectorBuffers; 31 private readonly double[][] scalarBuffers; // scalars are vectors of length 1 (to allow mixing scalars and vectors on the same stack) 32 private int lastVecBufIdx; 33 private int lastScalarBufIdx; 32 34 33 35 34 36 private double[] GetVectorBuffer() { 35 var v = vectorBuffers[vectorBuffers.Count - 1]; 36 vectorBuffers.RemoveAt(vectorBuffers.Count - 1); 37 return v; 37 return vectorBuffers[--lastVecBufIdx]; 38 38 } 39 39 private double[] GetScalarBuffer() { 40 var v = scalarBuffers[scalarBuffers.Count - 1]; 41 scalarBuffers.RemoveAt(scalarBuffers.Count - 1); 42 return v; 40 return scalarBuffers[--lastScalarBufIdx]; 43 41 } 44 42 45 43 private void ReleaseBuffer(double[] buf) { 46 (buf.Length == 1 ? scalarBuffers : vectorBuffers).Add(buf); 44 if (buf.Length == 1) { 45 scalarBuffers[lastScalarBufIdx++] = buf; 46 } else { 47 vectorBuffers[lastVecBufIdx++] = buf; 48 } 47 49 } 48 50 … … 73 75 74 76 // preallocate buffers 77 vectorBuffers = new double[MaxStackSize * (1 + MaxParams)][]; 78 scalarBuffers = new double[MaxStackSize * (1 + MaxParams)][]; 75 79 for (int i = 0; i < MaxStackSize; i++) { 76 80 ReleaseBuffer(new double[vLen]); … … 94 98 short arg; 95 99 // checked at the end to make sure we do not leak buffers 96 int initialScalarCount = scalarBuffers.Count;97 int initialVectorCount = vectorBuffers.Count;100 int initialScalarCount = lastScalarBufIdx; 101 int initialVectorCount = lastVecBufIdx; 98 102 99 103 while (true) { … … 179 183 180 184 var f = 1.0 / (maxFx * consts[curParamIdx]); 181 // adjust c so that maxFx*c = 1 185 // adjust c so that maxFx*c = 1 TODO: this is not ideal as enforce positive argument to exp() 182 186 consts[curParamIdx] *= f; 183 187 … … 211 215 } 212 216 ReleaseBuffer(r); 213 Contract.Assert( vectorBuffers.Count== initialVectorCount);214 Contract.Assert( scalarBuffers.Count== initialScalarCount);217 Contract.Assert(lastVecBufIdx == initialVectorCount); 218 Contract.Assert(lastScalarBufIdx == initialScalarCount); 215 219 return; 216 220 } … … 232 236 233 237 // checked at the end to make sure we do not leak buffers 234 int initialScalarCount = scalarBuffers.Count;235 int initialVectorCount = vectorBuffers.Count;238 int initialScalarCount = lastScalarBufIdx; 239 int initialVectorCount = lastVecBufIdx; 236 240 237 241 while (true) { … … 400 404 } 401 405 402 Contract.Assert( vectorBuffers.Count== initialVectorCount);403 Contract.Assert( scalarBuffers.Count== initialScalarCount);406 Contract.Assert(lastVecBufIdx == initialVectorCount); 407 Contract.Assert(lastScalarBufIdx == initialScalarCount); 404 408 return; // break loop 405 409 } -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs
r13650 r13651 221 221 Results.Add(new Result("Average quality", avgQuality)); 222 222 223 var totalRollouts = new IntValue(); 224 Results.Add(new Result("Total rollouts", totalRollouts)); 225 var effRollouts = new IntValue(); 226 Results.Add(new Result("Effective rollouts", effRollouts)); 227 var funcEvals = new IntValue(); 228 Results.Add(new Result("Function evaluations", funcEvals)); 229 var gradEvals = new IntValue(); 230 Results.Add(new Result("Gradient evaluations", gradEvals)); 231 232 223 233 // same as in SymbolicRegressionSingleObjectiveProblem 224 234 var y = Problem.ProblemData.Dataset.GetDoubleValues(Problem.ProblemData.TargetVariable, … … 266 276 curBestQ = 0.0; 267 277 278 funcEvals.Value = state.FuncEvaluations; 279 gradEvals.Value = state.GradEvaluations; 280 effRollouts.Value = state.EffectiveRollouts; 281 totalRollouts.Value = state.TotalRollouts; 282 268 283 table.Rows["Best quality"].Values.Add(bestQuality.Value); 269 284 table.Rows["Current best quality"].Values.Add(curQuality.Value); … … 280 295 avgQuality.Value = sumQ / n; 281 296 297 funcEvals.Value = state.FuncEvaluations; 298 gradEvals.Value = state.GradEvaluations; 299 effRollouts.Value = state.EffectiveRollouts; 300 totalRollouts.Value = state.TotalRollouts; 301 282 302 table.Rows["Best quality"].Values.Add(bestQuality.Value); 283 303 table.Rows["Current best quality"].Values.Add(curQuality.Value); 284 304 table.Rows["Average quality"].Values.Add(avgQuality.Value); 285 305 iterations.Value = iterations.Value + n; 306 286 307 } 287 308 … … 289 310 Results.Add(new Result("Best solution quality (train)", new DoubleValue(state.BestSolutionTrainingQuality))); 290 311 Results.Add(new Result("Best solution quality (test)", new DoubleValue(state.BestSolutionTestQuality))); 312 291 313 292 314 // produce solution -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r13650 r13651 44 44 double BestSolutionTrainingQuality { get; } 45 45 double BestSolutionTestQuality { get; } 46 int TotalRollouts { get; } 47 int EffectiveRollouts { get; } 48 int FuncEvaluations { get; } 49 int GradEvaluations { get; } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations 50 // TODO other stats on LM optimizer might be interesting here 46 51 } 47 52 … … 57 62 internal readonly List<Tree> bestChildrenBuf; 58 63 internal readonly Func<byte[], int, double> evalFun; 64 // MCTS might get stuck. Track statistics on the number of effective rollouts 65 internal int totalRollouts; 66 internal int effectiveRollouts; 59 67 60 68 … … 77 85 private int bestNParams; 78 86 private double[] bestConsts; 87 88 // stats 89 private int funcEvaluations; 90 private int gradEvaluations; 79 91 80 92 // buffers … … 173 185 } 174 186 } 187 188 public int TotalRollouts { get { return totalRollouts; } } 189 public int EffectiveRollouts { get { return effectiveRollouts; } } 190 public int FuncEvaluations { get { return funcEvaluations; } } 191 public int GradEvaluations { get { return gradEvaluations; } } // number of gradient evaluations (* num parameters) to get a value representative of the effort comparable to the number of function evaluations 192 175 193 #endregion 176 194 … … 204 222 Array.Copy(ones, constsBuf, nParams); 205 223 evaluator.Exec(code, x, constsBuf, predBuf, adjustOffsetForLogAndExp: true); 224 funcEvaluations++; 206 225 207 226 // calc opt scaling (alpha*f(x) + beta) … … 220 239 // optimize constants using the starting point calculated above 221 240 OptimizeConstsLm(code, constsBuf, nParams, 0.0, nIters: constOptIterations); 241 222 242 evaluator.Exec(code, x, constsBuf, predBuf); 243 funcEvaluations++; 244 223 245 rsq = RSq(y, predBuf); 224 246 optConsts = constsBuf; … … 237 259 238 260 private void OptimizeConstsLm(byte[] code, double[] consts, int nParams, double epsF = 0.0, int nIters = 100) { 239 double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt 261 double[] optConsts = new double[nParams]; // allocate a smaller buffer for constants opt (TODO perf?) 240 262 Array.Copy(consts, optConsts, nParams); 241 263 … … 247 269 alglib.minlmoptimize(state, Func, FuncAndJacobian, null, code); 248 270 alglib.minlmresults(state, out optConsts, out rep); 271 funcEvaluations += rep.nfunc; 272 gradEvaluations += rep.njac * nParams; 249 273 250 274 if (rep.terminationtype < 0) throw new ArgumentException("lm failed: termination type = " + rep.terminationtype); … … 311 335 var rand = mctsState.random; 312 336 double c = mctsState.c; 313 314 automaton.Reset(); 315 return TreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf); 316 } 317 318 private static double TreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf) { 337 double q = 0; 338 bool success = false; 339 do { 340 automaton.Reset(); 341 success = TryTreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf, out q); 342 mctsState.totalRollouts++; 343 } while (!success && !tree.done); 344 mctsState.effectiveRollouts++; 345 return q; 346 } 347 348 // tree search might fail because of constraints for expressions 349 // in this case we get stuck we just restart 350 // see ConstraintHandler.cs for more info 351 private static bool TryTreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf, 352 out double q) { 319 353 Tree selectedChild = null; 320 double q;321 354 Contract.Assert(tree.state == automaton.CurrentState); 322 355 Contract.Assert(!tree.done); … … 332 365 tree.visits++; 333 366 tree.sumQuality += q; 334 return q;367 return true; // we reached a final state 335 368 } else { 336 369 // EXPAND … … 338 371 int nFs; 339 372 automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs); 340 373 if (nFs == 0) { 374 // stuck in a dead end (no final state and no allowed follow states) 375 q = 0; 376 tree.done = true; 377 tree.children = null; 378 return false; 379 } 341 380 tree.children = new Tree[nFs]; 342 for (int i = 0; i < tree.children.Length; i++) tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 }; 381 for (int i = 0; i < tree.children.Length; i++) 382 tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 }; 343 383 344 384 selectedChild = SelectFinalOrRandom(automaton, tree, rand); … … 351 391 // make selected step and recurse 352 392 automaton.Goto(selectedChild.state); 353 q = TreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf); 354 355 tree.sumQuality += q; 356 tree.visits++; 393 var success = TryTreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf, out q); 394 if (success) { 395 // only update if successful 396 tree.sumQuality += q; 397 tree.visits++; 398 } 357 399 358 400 // tree.done = tree.children.All(ch => ch.done); 359 401 tree.done = true; for (int i = 0; i < tree.children.Length && tree.done; i++) tree.done = tree.children[i].done; 360 402 if (tree.done) { 361 tree.children = null; // cut of the sub-branch if it has been fully explored403 tree.children = null; // cut off the sub-branch if it has been fully explored 362 404 // TODO: update all qualities and visits to remove the information gained from this whole branch 363 405 } 364 return q;406 return success; 365 407 } 366 408 -
trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/MctsSymbolicRegressionTest.cs
r13648 r13651 1 1 using System; 2 using System.Diagnostics.Contracts; 2 3 using System.Linq; 3 4 using System.Threading; … … 31 32 { 32 33 // possible solutions with max two variable references: 34 // TODO: equal terms should not be allowed (see ConstraintHandler) 33 35 // x 34 36 // log(x) … … 40 42 // x * exp(x) 41 43 // x * 1/x 42 // x + x 44 // x + x ? 43 45 // x + log(x) 44 46 // x + exp(x) … … 48 50 // log(x) * exp(x) 49 51 // log(x) * 1/x 50 // log(x) + log(x) 51 // log(x) + exp(x) 52 // log(x) + log(x) ? 53 // log(x) + exp(x) ? 52 54 // log(x) + 1/x 53 55 // -- 6 54 56 // exp(x) * exp(x) 55 57 // exp(x) * 1/x 56 // exp(x) + exp(x) 58 // exp(x) + exp(x) ? 57 59 // exp(x) + 1/x 58 60 // -- 4 59 61 // 1/x * 1/x 60 // 1/x + 1/x 62 // 1/x + 1/x ? 61 63 // -- 2 62 // log(x+x) 64 // log(x+x) ? 63 65 // log(x*x) 64 66 // exp(x*x) 65 // 1/(x+x) 67 // 1/(x+x) ? 66 68 // 1/(x*x) 67 69 // -- 5 70 71 68 72 TestMctsNumberOfSolutions(regProblem, 2, 29); 69 73 } … … 75 79 // -- 2 76 80 // x * x 77 // x + x 81 // x + x ? 78 82 // x * exp(x) 79 83 // x + exp(x) 80 84 // exp(x) * exp(x) 81 // exp(x) + exp(x) 85 // exp(x) + exp(x) ? 82 86 // exp(x*x) 83 87 // -- 7 84 88 // x * x * x 85 // x + x * x 86 // x * x + x !! 87 // x + x + x 89 // x + x * x 90 // x + x + x ? 88 91 // x * x * exp(x) 89 // x + x * exp(x) 90 // x * x + exp(x) 91 // x + x + exp(x) 92 // x * exp(x) + x !! 93 // x * exp(x) + exp(x) 94 // x + exp(x) * exp(x) 95 // x + exp(x) + exp(x) 92 // x + x * exp(x) 93 // x + x + exp(x) ? 94 // exp(x) + x*x 95 // exp(x) + x*exp(x) 96 // x + exp(x) * exp(x) 97 // x + exp(x) + exp(x) ? 96 98 // x * exp(x) * exp(x) 97 99 // x * exp(x*x) 98 100 // x + exp(x*x) 99 // -- 1 5101 // -- 13 100 102 101 103 // exp(x) * exp(x) * exp(x) 102 // exp(x) + exp(x) * exp(x) 103 // exp(x) * exp(x) + exp(x) !! 104 // exp(x) + exp(x) + exp(x) 105 // -- 4 104 // exp(x) + exp(x) * exp(x) 105 // exp(x) + exp(x) + exp(x) ? 106 // -- 3 106 107 107 108 // exp(x) * exp(x*x) 108 109 // exp(x) + exp(x*x) 109 // exp(x*x) * exp(x) !! 110 // exp(x*x) + exp(x) !! 111 // -- 4 110 // -- 2 112 111 // exp(x*x*x) 113 112 // -- 1 114 115 TestMctsNumberOfSolutions(regProblem, 3, 2 + 7 + 15 + 4 + 4 + 1, allowLog: false, allowInv: false); 113 TestMctsNumberOfSolutions(regProblem, 3, 2 + 7 + 13 + 3 + 2 + 1, allowLog: false, allowInv: false); 114 } 115 { 116 // possible solutions with max 4 variable references: 117 // without exp, log and inv 118 // x 119 // x*x 120 // x+x ? 121 // x*x*x 122 // x+x*x 123 // x+x+x ? 124 // x*x*x*x 125 // x+x*x*x 126 // x*x+x*x ? 127 // x+x+x*x ? 128 // x+x+x+x ? 129 130 TestMctsNumberOfSolutions(regProblem, 4, 11, allowLog: false, allowInv: false, allowExp: false); 131 } 132 { 133 // possible solutions with max 5 variable references: 134 // without exp, log and inv 135 // x 136 // xx 137 // x+x ? 138 // xxx 139 // x+xx 140 // x+x+x ? 141 // xxxx 142 // x+xxx 143 // xx+xx ? 144 // x+x+xx ? 145 // x+x+x+x ? 146 // xxxxx 147 // x+xxxx 148 // xx+xxx 149 // x+x+xxx ? 150 // x+xx+xx ? 151 // x+x+x+xx ? 152 // x+x+x+x+x ? 153 TestMctsNumberOfSolutions(regProblem, 5, 18, allowLog: false, allowInv: false, allowExp: false); 116 154 } 117 155 } … … 236 274 #endregion 237 275 276 238 277 #region Nguyen 239 278 [TestMethod] … … 485 524 486 525 private void TestMctsNumberOfSolutions(IRegressionProblemData problemData, int maxNumberOfVariables, int expectedNumberOfSolutions, 526 bool allowProd = true, 487 527 bool allowExp = true, 488 528 bool allowLog = true, … … 494 534 regProblem.ProblemDataParameter.Value = problemData; 495 535 #region Algorithm Configuration 536 537 mctsSymbReg.SetSeedRandomly = false; 538 mctsSymbReg.Seed = 1234; 496 539 mctsSymbReg.Problem = regProblem; 497 540 mctsSymbReg.Iterations = int.MaxValue; // stopping when all solutions have been enumerated 498 541 mctsSymbReg.MaxSize = maxNumberOfVariables; 499 542 mctsSymbReg.C = 1000; // essentially breath first seach 543 mctsSymbReg.AllowedFactors.SetItemCheckedState(mctsSymbReg.AllowedFactors.Single(s => s.Value.StartsWith("prod")), allowProd); 500 544 mctsSymbReg.AllowedFactors.SetItemCheckedState(mctsSymbReg.AllowedFactors.Single(s => s.Value.Contains("exp")), allowExp); 501 545 mctsSymbReg.AllowedFactors.SetItemCheckedState(mctsSymbReg.AllowedFactors.Single(s => s.Value.Contains("log")), allowLog);
Note: See TracChangeset
for help on using the changeset viewer.