using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.PluginInfrastructure;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  /// <summary>
  /// Construction policy that picks the next action by Boltzmann-style exploration:
  /// each action is weighted with <c>exp(Beta * normalizedQuality)</c>, where the
  /// quality comes from a tabular quality function and is rescaled to the range
  /// spanned by the candidate actions.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the generic type arguments used in this file
  /// (<c>IFixedValueParameter&lt;DoubleValue&gt;</c>, <c>IEnumerable&lt;int&gt;</c>,
  /// <c>Tuple&lt;object, int&gt;</c>, <c>List&lt;int&gt;</c>, ...) were missing from the
  /// reviewed text and have been reconstructed from how the values are used.
  /// Confirm the override signatures against SymbolicExpressionConstructionPolicyBase.
  /// </remarks>
  [StorableClass]
  [Item("BoltzmannExplorationSymbolicExpressionConstructionPolicy", "")]
  public class BoltzmannExplorationSymbolicExpressionConstructionPolicy : SymbolicExpressionConstructionPolicyBase {
    // Keys under which the parameters are registered in the Parameters collection.
    private const string BetaParameterName = "Beta";
    private const string QualityFunctionParameterName = "Quality function";

    /// <summary>
    /// Weighting factor applied to the normalized quality inside the exponential.
    /// </summary>
    public double Beta {
      get { return ((IFixedValueParameter<DoubleValue>)Parameters[BetaParameterName]).Value.Value; }
      set { ((IFixedValueParameter<DoubleValue>)Parameters[BetaParameterName]).Value.Value = value; }
    }

    /// <summary>
    /// Quality function that supplies try counts and quality values per (state, action)
    /// pair and receives the quality updates.
    /// </summary>
    public ITabularQualityFunction QualityFunction {
      get { return ((IValueParameter<ITabularQualityFunction>)Parameters[QualityFunctionParameterName]).Value; }
      set { ((IValueParameter<ITabularQualityFunction>)Parameters[QualityFunctionParameterName]).Value = value; }
    }

    /// <summary>Cloning constructor; all work is delegated to the base class.</summary>
    protected BoltzmannExplorationSymbolicExpressionConstructionPolicy(BoltzmannExplorationSymbolicExpressionConstructionPolicy original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Deserialization constructor; all work is delegated to the base class.</summary>
    [StorableConstructor]
    protected BoltzmannExplorationSymbolicExpressionConstructionPolicy(bool deserializing)
      : base(deserializing) {
    }

    /// <summary>
    /// Creates a policy with Beta = 1.0 and a <see cref="TabularAvgQualityFunction"/>.
    /// </summary>
    public BoltzmannExplorationSymbolicExpressionConstructionPolicy()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(BetaParameterName, "The weighting factor beta", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<ITabularQualityFunction>(QualityFunctionParameterName, "The quality function to use", new TabularAvgQualityFunction()));
    }

    /// <summary>
    /// Selects one of <paramref name="actions"/> for the given <paramref name="state"/>.
    /// </summary>
    /// <param name="state">State object handed to the quality function.</param>
    /// <param name="actions">Candidate actions; must contain at least one element.</param>
    /// <param name="random">Random number generator used for sampling.</param>
    /// <returns>
    /// An action sampled via SampleRandom from the actions with zero tries if there are any;
    /// otherwise an action sampled via SampleRandom from all candidates when their qualities
    /// are (almost) equal; otherwise an action sampled via SampleProportional with weights
    /// <c>exp(Beta * (Q - min) / (max - min))</c>.
    /// </returns>
    /// <exception cref="InvalidOperationException">Thrown when <paramref name="actions"/> is empty.</exception>
    protected sealed override int Select(object state, IEnumerable<int> actions, IRandom random) {
      var qualityFunction = QualityFunction;

      // Materialize once: the sequence is traversed several times below and may be lazy.
      var candidates = actions as IList<int> ?? actions.ToList();
      if (candidates.Count == 0)
        throw new InvalidOperationException("Cannot select an action because no actions are available.");

      // Actions that have never been tried take precedence over everything else.
      var untried = candidates.Where(a => qualityFunction.Tries(state, a) == 0).ToList();
      if (untried.Count > 0)
        return untried.SampleRandom(random, 1).First();

      // Windowing: rescale the qualities to [0, 1] over the current candidates.
      var qualities = candidates.Select(a => qualityFunction.Q(state, a)).ToArray();
      var min = qualities.Min();
      var max = qualities.Max();
      double range = max - min;

      // All qualities are practically identical; this also avoids a division by zero below.
      if (range.IsAlmost(0.0))
        return candidates.SampleRandom(random, 1).First();

      double beta = Beta;
      var weights = qualities.Select(q => Math.Exp(beta * (q - min) / range)).ToArray();
      return candidates.SampleProportional(random, 1, weights).First();
    }

    /// <summary>
    /// Forwards <paramref name="quality"/> to the quality function for every
    /// (state, action) pair of <paramref name="stateActionSequence"/>, in sequence order.
    /// </summary>
    /// <param name="stateActionSequence">Pairs of state (Item1) and action (Item2).</param>
    /// <param name="quality">Quality value passed along with each pair.</param>
    public sealed override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) {
      var qualityFunction = QualityFunction;
      foreach (var stateAction in stateActionSequence) {
        qualityFunction.Update(stateAction.Item1, stateAction.Item2, quality);
      }
    }

    /// <summary>
    /// Creates the state object by delegating to the state function of the quality function.
    /// </summary>
    protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
      return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);
    }

    /// <summary>Creates a copy of this policy through the cloning constructor.</summary>
    public override IDeepCloneable Clone(HeuristicLab.Common.Cloner cloner) {
      return new BoltzmannExplorationSymbolicExpressionConstructionPolicy(this, cloner);
    }
  }
}