- Timestamp:
- 03/07/16 14:50:02 (9 years ago)
- Location:
- trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
- Files:
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj
r13653 r13658 263 263 <Compile Include="MctsSymbolicRegression\MctsSymbolicRegressionStatic.cs" /> 264 264 <Compile Include="MctsSymbolicRegression\OpCodes.cs" /> 265 <Compile Include="MctsSymbolicRegression\Policies\EpsGreedy.cs" /> 266 <Compile Include="MctsSymbolicRegression\Policies\UcbTuned.cs" /> 267 <Compile Include="MctsSymbolicRegression\Policies\IActionStatistics.cs" /> 268 <Compile Include="MctsSymbolicRegression\Policies\IPolicy.cs" /> 269 <Compile Include="MctsSymbolicRegression\Policies\PolicyBase.cs" /> 270 <Compile Include="MctsSymbolicRegression\Policies\Ucb.cs" /> 265 271 <Compile Include="MctsSymbolicRegression\SymbolicExpressionGenerator.cs" /> 266 272 <Compile Include="MctsSymbolicRegression\Tree.cs" /> -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs
r13652 r13658 24 24 using System.Runtime.CompilerServices; 25 25 using System.Threading; 26 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies; 26 27 using HeuristicLab.Analysis; 27 28 using HeuristicLab.Common; … … 52 53 private const string AllowedFactorsParameterName = "Allowed factors"; 53 54 private const string ConstantOptimizationIterationsParameterName = "Iterations (constant optimization)"; 54 private const string CParameterName = "C";55 private const string PolicyParameterName = "Policy"; 55 56 private const string SeedParameterName = "Seed"; 56 57 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; … … 79 80 get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; } 80 81 } 81 public I FixedValueParameter<DoubleValue> CParameter {82 get { return (I FixedValueParameter<DoubleValue>)Parameters[CParameterName]; }82 public IValueParameter<IPolicy> PolicyParameter { 83 get { return (IValueParameter<IPolicy>)Parameters[PolicyParameterName]; } 83 84 } 84 85 public IFixedValueParameter<DoubleValue> PunishmentFactorParameter { … … 119 120 set { MaxVariableReferencesParameter.Value.Value = value; } 120 121 } 121 public double C { 122 get { return CParameter.Value.Value; } 123 set { CParameter.Value.Value = value; } 124 } 125 122 public IPolicy Policy { 123 get { return PolicyParameter.Value; } 124 set { PolicyParameter.Value = value; } 125 } 126 126 public double PunishmentFactor { 127 127 get { return PunishmentFactorParameter.Value.Value; } … … 173 173 Parameters.Add(new FixedValueParameter<IntValue>(MaxVariablesParameterName, 174 174 "Maximal number of variables references in the symbolic regression models (multiple usages of the same variable are counted)", new IntValue(5))); 175 Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName, 176 "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. 
Default: 1.0", new DoubleValue(1.0))); 175 // Parameters.Add(new FixedValueParameter<DoubleValue>(CParameterName, 176 // "Balancing parameter in UCT formula (0 < c < 1000). Small values: greedy search. Large values: enumeration. Default: 1.0", new DoubleValue(1.0))); 177 Parameters.Add(new ValueParameter<IPolicy>(PolicyParameterName, 178 "The policy to use for selecting nodes in MCTS (e.g. Ucb)", new Ucb())); 179 PolicyParameter.Hidden = true; 177 180 Parameters.Add(new ValueParameter<ICheckedItemList<StringValue>>(AllowedFactorsParameterName, 178 181 "Choose which expressions are allowed as factors in the model.", defaultFactorsList)); … … 244 247 var problemData = (IRegressionProblemData)Problem.ProblemData.Clone(); 245 248 if (!AllowedFactors.CheckedItems.Any()) throw new ArgumentException("At least on type of factor must be allowed"); 246 var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, C, ScaleVariables, ConstantOptimizationIterations, 249 var state = MctsSymbolicRegressionStatic.CreateState(problemData, (uint)Seed, MaxVariableReferences, ScaleVariables, ConstantOptimizationIterations, 250 Policy, 247 251 lowerLimit, upperLimit, 248 252 allowProdOfVars: AllowedFactors.CheckedItems.Any(s => s.Value.Value == VariableProductFactorName), -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs
r13657 r13658 24 24 using System.Diagnostics.Contracts; 25 25 using System.Linq; 26 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies; 26 27 using HeuristicLab.Common; 27 28 using HeuristicLab.Core; … … 58 59 internal readonly Automaton automaton; 59 60 internal IRandom random { get; private set; } 60 internal readonly double c;61 61 internal readonly Tree tree; 62 internal readonly List<Tree> bestChildrenBuf;63 62 internal readonly Func<byte[], int, double> evalFun; 63 internal readonly IPolicy treePolicy; 64 64 // MCTS might get stuck. Track statistics on the number of effective rollouts 65 65 internal int totalRollouts; … … 96 96 private readonly double[][] gradBuf; 97 97 98 public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, double c, bool scaleVariables, int constOptIterations, 98 public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables, int constOptIterations, 99 IPolicy treePolicy = null, 99 100 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, 100 101 bool allowProdOfVars = true, … … 105 106 106 107 this.problemData = problemData; 107 this.c = c;108 108 this.constOptIterations = constOptIterations; 109 109 this.evalFun = this.Eval; … … 134 134 135 135 this.automaton = new Automaton(x, maxVariables, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms); 136 this.tree = new Tree() { state = automaton.CurrentState }; 136 this.treePolicy = treePolicy ?? 
new Ucb(); 137 this.tree = new Tree() { state = automaton.CurrentState, actionStatistics = treePolicy.CreateActionStatistics() }; 137 138 138 139 // reset best solution … … 146 147 this.ones = Enumerable.Repeat(1.0, MaxParams).ToArray(); 147 148 constsBuf = new double[MaxParams]; 148 this.bestChildrenBuf = new List<Tree>(2 * x.Length); // the number of follow states in the automaton is O(number of variables) 2 * number of variables should be sufficient (capacity is increased if necessary anyway)149 149 this.predBuf = new double[y.Length]; 150 150 this.testPredBuf = new double[testY.Length]; … … 154 154 155 155 #region IState inferface 156 public bool Done { get { return tree != null && tree. done; } }156 public bool Done { get { return tree != null && tree.Done; } } 157 157 158 158 public double BestSolutionTrainingQuality { … … 302 302 } 303 303 304 public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, double c = 1.0, 305 bool scaleVariables = true, int constOptIterations = 0, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, 304 public static IState CreateState(IRegressionProblemData problemData, uint randSeed, int maxVariables = 3, 305 bool scaleVariables = true, int constOptIterations = 0, 306 IPolicy policy = null, 307 double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, 306 308 bool allowProdOfVars = true, 307 309 bool allowExp = true, … … 310 312 bool allowMultipleTerms = false 311 313 ) { 312 return new State(problemData, randSeed, maxVariables, c, scaleVariables, constOptIterations, 314 return new State(problemData, randSeed, maxVariables, scaleVariables, constOptIterations, 315 policy, 313 316 lowerEstimationLimit, upperEstimationLimit, 314 317 allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms); … … 329 332 var tree = mctsState.tree; 330 333 var eval = mctsState.evalFun; 331 var bestChildrenBuf = 
mctsState.bestChildrenBuf;332 334 var rand = mctsState.random; 333 double c = mctsState.c;335 var treePolicy = mctsState.treePolicy; 334 336 double q = 0; 335 double deltaQ = 0;336 double deltaSqrQ = 0;337 int deltaVisits = 0;338 337 bool success = false; 339 338 do { 340 339 automaton.Reset(); 341 success = TryTreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf, out q, out deltaQ, out deltaSqrQ, out deltaVisits);340 success = TryTreeSearchRec(rand, tree, automaton, eval, treePolicy, out q); 342 341 mctsState.totalRollouts++; 343 } while (!success && !tree. done);342 } while (!success && !tree.Done); 344 343 mctsState.effectiveRollouts++; 345 344 return q; … … 349 348 // in this case we get stuck we just restart 350 349 // see ConstraintHandler.cs for more info 351 private static bool TryTreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf, 352 out double q, // quality of the expression 353 out double deltaQ, out double deltaSqrQ, out int deltaVisits // the updates for total quality and number of visits (can be negative if branches have been fully explored) 354 ) { 350 private static bool TryTreeSearchRec(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy, 351 out double q) { 355 352 Tree selectedChild = null; 356 353 Contract.Assert(tree.state == automaton.CurrentState); 357 Contract.Assert(!tree. done);354 Contract.Assert(!tree.Done); 358 355 if (tree.children == null) { 359 356 if (automaton.IsFinalState(tree.state)) { 360 357 // final state 361 tree. 
done = true;358 tree.Done = true; 362 359 363 360 // EVALUATE … … 365 362 automaton.GetCode(out code, out nParams); 366 363 q = eval(code, nParams); 367 tree.visits += 1; 368 tree.sumQuality += q; 369 tree.sumSqrQuality += q * q; 370 deltaQ = q; 371 deltaVisits = 1; 372 deltaSqrQ = q * q; 364 365 treePolicy.Update(tree.actionStatistics, q); 373 366 return true; // we reached a final state 374 367 } else { … … 380 373 // stuck in a dead end (no final state and no allowed follow states) 381 374 q = 0; 382 deltaQ = 0; 383 deltaSqrQ = 0.0; 384 deltaVisits = 0; 385 tree.done = true; 375 tree.Done = true; 386 376 tree.children = null; 387 tree.visits = 1;388 377 return false; 389 378 } 390 379 tree.children = new Tree[nFs]; 391 380 for (int i = 0; i < tree.children.Length; i++) 392 tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0};381 tree.children[i] = new Tree() { children = null, state = possibleFollowStates[i], actionStatistics = treePolicy.CreateActionStatistics() }; 393 382 394 383 selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0]; … … 397 386 // tree.children != null 398 387 // UCT selection within tree 399 selectedChild = tree.children.Length > 1 ? 
SelectUctTuned(tree, rand, c, bestChildrenBuf) : tree.children[0]; 388 int selectedIdx = 0; 389 if (tree.children.Length > 1) { 390 selectedIdx = treePolicy.Select(tree.children.Select(ch => ch.actionStatistics), rand); 391 } 392 selectedChild = tree.children[selectedIdx]; 400 393 } 401 394 // make selected step and recurse 402 395 automaton.Goto(selectedChild.state); 403 var success = TryTreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf, 404 out q, out deltaQ, out deltaSqrQ, out deltaVisits); 396 var success = TryTreeSearchRec(rand, selectedChild, automaton, eval, treePolicy, out q); 405 397 if (success) { 406 398 // only update if successful 407 tree.sumQuality += deltaQ; 408 tree.sumSqrQuality += deltaSqrQ; 409 tree.visits += deltaVisits; 410 } 411 412 if (tree.children.All(ch => ch.done)) { 413 tree.done = true; 414 // update parent nodes to remove information from this branch 415 if (tree.children.Length > 1) { 416 deltaQ = -(tree.sumQuality - deltaQ); 417 deltaSqrQ = -(tree.sumSqrQuality - deltaSqrQ); 418 deltaVisits = -(tree.visits - deltaVisits); 419 } 399 treePolicy.Update(tree.actionStatistics, q); 400 } 401 402 tree.Done = tree.children.All(ch => ch.Done); 403 if (tree.Done) { 420 404 tree.children = null; // cut off the sub-branch if it has been fully explored 421 405 } 422 406 return success; 423 }424 425 private static Tree SelectUct(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {426 // determine total tries of still active children427 int totalTries = 0;428 bestChildrenBuf.Clear();429 for (int i = 0; i < tree.children.Length; i++) {430 var ch = tree.children[i];431 if (ch.done) continue;432 if (ch.visits == 0) bestChildrenBuf.Add(ch);433 else totalTries += tree.children[i].visits;434 }435 // if there are unvisited children select a random child436 if (bestChildrenBuf.Any()) {437 return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];438 }439 Contract.Assert(totalTries > 0); // the tree is not done yet so there 
is at least on child that is not done440 double logTotalTries = Math.Log(totalTries);441 var bestQ = double.NegativeInfinity;442 for (int i = 0; i < tree.children.Length; i++) {443 var ch = tree.children[i];444 if (ch.done) continue;445 var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits);446 if (childQ > bestQ) {447 bestChildrenBuf.Clear();448 bestChildrenBuf.Add(ch);449 bestQ = childQ;450 } else if (childQ >= bestQ) {451 bestChildrenBuf.Add(ch);452 }453 }454 return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];455 }456 457 private static Tree SelectUctTuned(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {458 // determine total tries of still active children459 int totalTries = 0;460 bestChildrenBuf.Clear();461 for (int i = 0; i < tree.children.Length; i++) {462 var ch = tree.children[i];463 if (ch.done) continue;464 if (ch.visits == 0) bestChildrenBuf.Add(ch);465 else totalTries += tree.children[i].visits;466 }467 // if there are unvisited children select a random child468 if (bestChildrenBuf.Any()) {469 return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];470 }471 Contract.Assert(totalTries > 0); // the tree is not done yet so there is at least on child that is not done472 double logTotalTries = Math.Log(totalTries);473 var bestQ = double.NegativeInfinity;474 for (int i = 0; i < tree.children.Length; i++) {475 var ch = tree.children[i];476 if (ch.done) continue;477 var varianceBound = ch.QualityVariance + Math.Sqrt(2.0 * logTotalTries / ch.visits);478 if (varianceBound > 0.25) varianceBound = 0.25;479 var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits * varianceBound);480 if (childQ > bestQ) {481 bestChildrenBuf.Clear();482 bestChildrenBuf.Add(ch);483 bestQ = childQ;484 } else if (childQ >= bestQ) {485 bestChildrenBuf.Add(ch);486 }487 }488 return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];489 407 } 490 408 -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Tree.cs
r13657 r13658 20 20 #endregion 21 21 22 using HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies; 23 22 24 namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression { 23 25 // represents tree nodes for the search tree in MCTS 24 26 internal class Tree { 25 27 public int state; 26 public int visits; 27 public double sumQuality; 28 public double sumSqrQuality; // for variance 29 public double AverageQuality { get { return sumQuality / (double)visits; } } 30 public double QualityVariance { get { return sumSqrQuality / (double)visits - AverageQuality * AverageQuality; } } 31 public bool done; 28 public bool Done { 29 get { return actionStatistics.Done; } 30 set { actionStatistics.Done = value; } 31 } 32 public IActionStatistics actionStatistics; 32 33 public Tree[] children; 33 34 }
Note: See TracChangeset for help on using the changeset viewer.