using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.PluginInfrastructure;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  /// <summary>
  /// Tree-construction policy that selects actions with a UCB1-Tuned style
  /// upper confidence bound: the exploration bonus of each action is scaled by
  /// an estimate that combines the observed quality variance with a
  /// count-based correction term.
  /// </summary>
  /// <remarks>
  /// Unlike the textbook UCB1-Tuned rule (Auer et al., 2002), the variance
  /// estimate is not capped at 1/4; the bonus is instead weighted by the
  /// user-supplied factor <see cref="R"/>.
  /// </remarks>
  [StorableClass]
  [Item("UcbTunedSymbolicExpressionConstructionPolicy", "Also uses an estimate of the variance")]
  public class UcbTunedSymbolicExpressionConstructionPolicy : SymbolicExpressionConstructionPolicyBase {
    /// <summary>
    /// Weighting factor applied to the confidence-bound term
    /// (backed by the "R" parameter).
    /// </summary>
    public double R {
      get { return ((IFixedValueParameter<DoubleValue>)Parameters["R"]).Value.Value; }
      set { ((IFixedValueParameter<DoubleValue>)Parameters["R"]).Value.Value = value; }
    }

    /// <summary>
    /// Tabular quality function that supplies the per (state, action) try
    /// counts, quality estimates (Q) and quality variances used for selection,
    /// receives the quality updates, and creates states
    /// (backed by the "Quality function" parameter).
    /// </summary>
    public ITabularQualityFunction QualityFunction {
      get { return ((IValueParameter<ITabularQualityFunction>)Parameters["Quality function"]).Value; }
      set { ((IValueParameter<ITabularQualityFunction>)Parameters["Quality function"]).Value = value; }
    }

    /// <summary>Cloning constructor.</summary>
    protected UcbTunedSymbolicExpressionConstructionPolicy(UcbTunedSymbolicExpressionConstructionPolicy original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Constructor used by the persistence framework.</summary>
    [StorableConstructor]
    protected UcbTunedSymbolicExpressionConstructionPolicy(bool deserializing)
      : base(deserializing) {
    }

    /// <summary>
    /// Creates the policy with R = 1.0 and a <see cref="TabularAvgQualityFunction"/>.
    /// </summary>
    public UcbTunedSymbolicExpressionConstructionPolicy()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>("R", "The weighting factor for the confidence bound (should be scaled based on the range of the fitness values)", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<ITabularQualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));
    }

    /// <summary>
    /// Selects one of <paramref name="actions"/> for <paramref name="state"/>.
    /// Actions that were never tried score +infinity so each is tried at least
    /// once. Every other action scores
    /// <c>Q + R * sqrt(ln(N) / n * (Var + sqrt(2 * ln(N) / n)))</c>,
    /// where n is the action's try count and N is the sum of try counts over
    /// all candidates. Ties for the highest score are broken at random.
    /// </summary>
    /// <param name="state">State object created by <see cref="CreateState"/>.</param>
    /// <param name="actions">Candidate actions; enumerated exactly once.</param>
    /// <param name="random">Random source used for tie-breaking.</param>
    /// <returns>The selected action.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown when <paramref name="actions"/> is empty or every score is NaN.
    /// </exception>
    protected sealed override int Select(object state, IEnumerable<int> actions, IRandom random) {
      // Resolve the parameter values once instead of performing a
      // parameter-collection lookup for every candidate.
      var qualityFunction = QualityFunction;
      var r = R;

      // Materialize the candidates and their try counts once; both are needed
      // for the total and again for scoring each action.
      var candidates = actions.ToList();
      var tries = new int[candidates.Count];
      for (int i = 0; i < candidates.Count; i++) {
        tries[i] = qualityFunction.Tries(state, candidates[i]);
      }
      int totalTries = tries.Sum();
      // Only used when some action has tries > 0, hence totalTries >= 1 and the log is >= 0.
      double logTotalTries = Math.Log(totalTries);

      var bestActions = new List<int>();
      var bestQuality = double.NegativeInfinity;
      for (int i = 0; i < candidates.Count; i++) {
        int a = candidates[i];
        double quality;
        if (tries[i] == 0) {
          // force exploration of untried actions
          quality = double.PositiveInfinity;
        } else {
          double v = qualityFunction.QVariance(state, a) + Math.Sqrt(2 * logTotalTries / tries[i]);
          quality = qualityFunction.Q(state, a) + r * Math.Sqrt(logTotalTries / tries[i] * v);
        }
        if (quality >= bestQuality) {
          if (quality > bestQuality) {
            bestActions.Clear();
            bestQuality = quality;
          }
          bestActions.Add(a);
        }
      }

      if (bestActions.Count == 0) {
        // Reachable only when there are no candidates or every score is NaN
        // (NaN never satisfies >=); sampling from an empty list cannot yield an action.
        throw new InvalidOperationException("No action could be selected: the action set is empty or all UCB-Tuned scores are NaN.");
      }
      return bestActions.SampleRandom(random, 1).First();
    }

    /// <summary>
    /// Feeds <paramref name="quality"/> into the quality function for every
    /// (state, action) pair visited while constructing a tree.
    /// </summary>
    /// <param name="stateActionSequence">Visited (state, action) pairs.</param>
    /// <param name="quality">Quality of the tree built from this sequence.</param>
    public sealed override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) {
      var qualityFunction = QualityFunction;
      foreach (var t in stateActionSequence) {
        var state = t.Item1;
        var action = t.Item2;
        qualityFunction.Update(state, action, quality);
      }
    }

    /// <summary>
    /// Delegates state creation to the state function of <see cref="QualityFunction"/>.
    /// </summary>
    protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
      return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);
    }

    /// <inheritdoc />
    public override IDeepCloneable Clone(Cloner cloner) {
      return new UcbTunedSymbolicExpressionConstructionPolicy(this, cloner);
    }
  }
}