using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
  /// <summary>
  /// Upper-confidence-bound selection policy. Among the actions that are not done, an action that has
  /// never been tried is preferred (chosen uniformly at random); otherwise the action maximizing
  /// <c>AverageQuality + C * sqrt(ln(totalTries) / Tries)</c> is chosen, with ties broken at random.
  /// </summary>
  [StorableClass]
  [Item("Ucb Policy", "Ucb with parameter c to balance between exploitation and exploration")]
  public class Ucb : PolicyBase {
    /// <summary>
    /// Per-action bookkeeping used by this policy: number of tries, sum of observed qualities and the
    /// best observed quality.
    /// </summary>
    private class ActionStatistics : IActionStatistics {
      /// <summary>Sum of all qualities reported for this action.</summary>
      public double SumQuality { get; set; }

      /// <summary>
      /// Mean observed quality. This is a floating point division, so with zero tries the result is
      /// NaN or infinite rather than an exception.
      /// </summary>
      public double AverageQuality { get { return SumQuality / Tries; } }

      /// <summary>
      /// Largest quality seen so far. Starts at the default value 0.0, so it never drops below zero.
      /// </summary>
      public double BestQuality { get; internal set; }

      /// <summary>Number of times this action has been updated.</summary>
      public int Tries { get; set; }

      /// <summary>Actions flagged as done are skipped during selection.</summary>
      public bool Done { get; set; }

      /// <summary>
      /// Merges the counters of <paramref name="other"/> into this instance. The Done flag is not merged.
      /// </summary>
      /// <param name="other">Statistics created by this policy.</param>
      /// <exception cref="ArgumentException">
      /// <paramref name="other"/> is null or was not created by this policy.
      /// </exception>
      public void Add(IActionStatistics other) {
        var o = other as ActionStatistics;
        if (o == null) throw new ArgumentException("Statistics object was not created by the Ucb policy.", "other");
        this.Tries += o.Tries;
        this.SumQuality += o.SumQuality;
        this.BestQuality = Math.Max(this.BestQuality, o.BestQuality);
      }
    }

    // Scratch list of candidate action indices, reused across Select calls to avoid re-allocation.
    // NOTE(review): because the list is shared, concurrent Select calls on one instance would interfere.
    private readonly List<int> buf = new List<int>();

    /// <summary>The parameter object registered under the name "C".</summary>
    public IFixedValueParameter<DoubleValue> CParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["C"]; }
    }

    /// <summary>Weight of the exploration term in the selection score.</summary>
    public double C {
      get { return CParameter.Value.Value; }
      set { CParameter.Value.Value = value; }
    }

    [StorableConstructor]
    protected Ucb(bool deserializing) : base(deserializing) { }

    protected Ucb(Ucb original, Cloner cloner) : base(original, cloner) { }

    /// <summary>Creates the policy with C initialized to sqrt(2).</summary>
    public Ucb()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>("C", "Parameter to balance between exploration and exploitation 0 <= c < 100", new DoubleValue(Math.Sqrt(2))));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new Ucb(this, cloner);
    }

    /// <summary>
    /// Selects one of <paramref name="actions"/> using the current value of <see cref="C"/>.
    /// </summary>
    /// <param name="actions">Statistics of the candidate actions; the sequence is enumerated more than once.</param>
    /// <param name="random">Random source used to pick among unvisited or equally scored actions.</param>
    /// <returns>Zero-based position of the selected action within <paramref name="actions"/>.</returns>
    public override int Select(IEnumerable<IActionStatistics> actions, IRandom random) {
      return Select(actions, random, C, buf);
    }

    /// <summary>
    /// Records one observation of quality <paramref name="q"/> for <paramref name="action"/>.
    /// </summary>
    /// <param name="action">Statistics object obtained from <see cref="CreateActionStatistics"/>.</param>
    /// <param name="q">Observed quality.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="action"/> is null or was not created by this policy.
    /// </exception>
    public override void Update(IActionStatistics action, double q) {
      var a = action as ActionStatistics;
      // Fail with a descriptive error instead of a NullReferenceException on a foreign or null argument.
      if (a == null) throw new ArgumentException("Statistics object was not created by the Ucb policy.", "action");
      a.SumQuality += q;
      a.BestQuality = Math.Max(a.BestQuality, q);
      a.Tries++;
    }

    /// <summary>Creates an empty statistics object for a new action.</summary>
    public override IActionStatistics CreateActionStatistics() {
      return new ActionStatistics();
    }

    /// <summary>
    /// Core selection routine. <paramref name="buf"/> is cleared and used as scratch space for
    /// candidate indices.
    /// </summary>
    private static int Select(IEnumerable<IActionStatistics> actions, IRandom rand, double c, IList<int> buf) {
      // determine total tries of still active actions
      int totalTries = 0;
      buf.Clear();
      int aIdx = -1;
      foreach (var a in actions) {
        ++aIdx;
        if (a.Done) continue;
        if (a.Tries == 0) buf.Add(aIdx);
        else totalTries += a.Tries;
      }

      // if there are unvisited actions select a random unvisited action
      if (buf.Count > 0) {
        return buf[rand.Next(buf.Count)];
      }

      Debug.Assert(actions.All(a => a.Done || a.Tries > 0));
      // NOTE(review): if every action is Done, this assertion fails in debug builds and the candidate
      // list below stays empty in release builds - callers presumably never select in that state; confirm.
      Debug.Assert(totalTries > 0);

      double logTotalTries = Math.Log(totalTries);
      var bestQ = double.NegativeInfinity;
      aIdx = -1;
      foreach (var a in actions) {
        ++aIdx;
        if (a.Done) continue;
        var actionQ = a.AverageQuality + c * Math.Sqrt(logTotalTries / a.Tries);
        if (actionQ > bestQ) {
          // strictly better score: restart the list of tied candidates
          buf.Clear();
          buf.Add(aIdx);
          bestQ = actionQ;
        } else if (actionQ >= bestQ) {
          // only reachable when actionQ equals bestQ: remember the tie
          buf.Add(aIdx);
        }
      }
      return buf[rand.Next(buf.Count)];
    }
  }
}