using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.PluginInfrastructure;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  /// <summary>
  /// Construction policy that selects the next follow state by an upper confidence bound (UCB1-style) score:
  /// <c>Value(s) + R * sqrt(2 * ln(totalTries) / Tries(s))</c>, where <c>totalTries</c> is the sum of the
  /// tries of all candidate follow states. States that have never been tried always win; ties between equally
  /// scored states are broken uniformly at random. States whose score is NaN are ranked lowest.
  /// </summary>
  [StorableClass]
  [Item("UcbSymbolicExpressionConstructionPolicy", "")]
  public class UcbSymbolicExpressionConstructionPolicy : SymbolicExpressionConstructionPolicyBase {
    private const string RParameterName = "R";
    private const string StateValueFunctionParameterName = "Quality function";

    /// <summary>
    /// Weighting factor of the exploration (confidence bound) term in the selection score.
    /// </summary>
    public double R {
      get { return ((IFixedValueParameter<DoubleValue>)Parameters[RParameterName]).Value.Value; }
      set { ((IFixedValueParameter<DoubleValue>)Parameters[RParameterName]).Value.Value = value; }
    }

    /// <summary>
    /// State value function providing the value estimates (<c>Value</c>) and try counts (<c>Tries</c>) used for
    /// selection, receiving quality updates, and supplying the state function used to create states.
    /// </summary>
    public ITabularStateValueFunction StateValueFunction {
      get { return ((IValueParameter<ITabularStateValueFunction>)Parameters[StateValueFunctionParameterName]).Value; }
      set { ((IValueParameter<ITabularStateValueFunction>)Parameters[StateValueFunctionParameterName]).Value = value; }
    }

    /// <summary>Cloning constructor.</summary>
    protected UcbSymbolicExpressionConstructionPolicy(UcbSymbolicExpressionConstructionPolicy original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Constructor used by the persistence framework.</summary>
    [StorableConstructor]
    protected UcbSymbolicExpressionConstructionPolicy(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Creates the policy with <c>R = 1.0</c> and a <see cref="TabularAvgStateValueFunction"/> as quality function.
    /// </summary>
    public UcbSymbolicExpressionConstructionPolicy()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(RParameterName, "The weighting factor for the confidence bound (should be scaled based on the range of the fitness values)", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<ITabularStateValueFunction>(StateValueFunctionParameterName, "The quality function to use", new TabularAvgStateValueFunction()));
    }

    /// <summary>
    /// Returns the index of the follow state with the highest UCB score.
    /// </summary>
    /// <param name="followStates">The candidate states; must contain at least one element.</param>
    /// <param name="random">Random number generator used to break ties between equally scored states.</param>
    /// <returns>An index into <paramref name="followStates"/>.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="followStates"/> is empty.</exception>
    protected sealed override int Select(IReadOnlyList<object> followStates, IRandom random) {
      if (followStates.Count == 0) throw new ArgumentException("At least one follow state is required.", "followStates");

      // Resolve the parameter values once; each property access is a dictionary lookup plus a cast.
      var valueFunction = StateValueFunction;
      var r = R;

      var bestFollowStates = new List<int>();
      var bestQuality = double.NegativeInfinity;
      int totalTries = followStates.Sum(s => valueFunction.Tries(s));
      for (int idx = 0; idx < followStates.Count; idx++) {
        var s = followStates[idx];
        int tries = valueFunction.Tries(s);
        double quality;
        if (tries == 0) {
          // Untried states are always explored before any tried state.
          quality = double.PositiveInfinity;
        } else {
          quality = valueFunction.Value(s) + r * Math.Sqrt((2 * Math.Log(totalTries)) / tries);
          // A NaN score fails every comparison below; if all candidates scored NaN no index would be
          // collected and there would be nothing to sample. Rank such states lowest instead.
          if (double.IsNaN(quality)) quality = double.NegativeInfinity;
        }

        if (quality >= bestQuality) {
          if (quality > bestQuality) {
            // strictly better: discard the previously collected ties
            bestFollowStates.Clear();
            bestQuality = quality;
          }
          bestFollowStates.Add(idx);
        }
      }
      return bestFollowStates.SampleRandom(random);
    }

    /// <summary>
    /// Passes <paramref name="quality"/> to the state value function for every state in
    /// <paramref name="stateSequence"/>.
    /// </summary>
    /// <param name="stateSequence">The states visited while constructing a solution.</param>
    /// <param name="quality">The quality obtained for that solution.</param>
    public sealed override void Update(IEnumerable<object> stateSequence, double quality) {
      var valueFunction = StateValueFunction;
      foreach (var state in stateSequence) {
        valueFunction.Update(state, quality);
      }
    }

    /// <summary>
    /// Delegates state creation to the state function of the configured state value function.
    /// </summary>
    protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actionSequence, ISymbolicExpressionTreeNode parent, int childIdx) {
      return StateValueFunction.StateFunction.CreateState(root, actionSequence, parent, childIdx);
    }

    /// <inheritdoc/>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new UcbSymbolicExpressionConstructionPolicy(this, cloner);
    }
  }
}