using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.PluginInfrastructure;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  /// <summary>
  /// Construction policy that picks the next follow state by Boltzmann (softmax) exploration
  /// over the values reported by <see cref="StateValueFunction"/>.
  /// </summary>
  /// <remarks>
  /// The state values of all candidates are min-max normalized to [0, 1] and each candidate i is
  /// selected with probability proportional to exp(Beta * normalizedValue_i). When all candidate
  /// values are (almost) equal the choice is uniformly random.
  /// </remarks>
  [StorableClass]
  [Item("BoltzmannExplorationSymbolicExpressionConstructionPolicy", "")]
  public class BoltzmannExplorationSymbolicExpressionConstructionPolicy : SymbolicExpressionConstructionPolicyBase {
    private const string BetaParameterName = "Beta";
    private const string QualityFunctionParameterName = "Quality function";

    /// <summary>
    /// Weighting factor (inverse temperature) applied to the normalized state values.
    /// Larger values make selection greedier; 0 yields uniform selection; negative values
    /// favour lower-valued states.
    /// </summary>
    public double Beta {
      get { return ((IFixedValueParameter<DoubleValue>)Parameters[BetaParameterName]).Value.Value; }
      set { ((IFixedValueParameter<DoubleValue>)Parameters[BetaParameterName]).Value.Value = value; }
    }

    /// <summary>
    /// Value function used to score follow states and to create/update states.
    /// Defaults to a <see cref="TabularAvgStateValueFunction"/>.
    /// </summary>
    public IStateValueFunction StateValueFunction {
      get { return ((IValueParameter<IStateValueFunction>)Parameters[QualityFunctionParameterName]).Value; }
      set { ((IValueParameter<IStateValueFunction>)Parameters[QualityFunctionParameterName]).Value = value; }
    }

    /// <summary>Cloning constructor; all state lives in the parameters handled by the base class.</summary>
    protected BoltzmannExplorationSymbolicExpressionConstructionPolicy(BoltzmannExplorationSymbolicExpressionConstructionPolicy original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Constructor used by the HeuristicLab persistence layer during deserialization.</summary>
    [StorableConstructor]
    protected BoltzmannExplorationSymbolicExpressionConstructionPolicy(bool deserializing)
      : base(deserializing) {
    }

    /// <summary>
    /// Creates the policy with Beta = 1.0 and a <see cref="TabularAvgStateValueFunction"/> as quality function.
    /// </summary>
    public BoltzmannExplorationSymbolicExpressionConstructionPolicy()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(BetaParameterName, "The weighting factor beta", new DoubleValue(1.0)));
      Parameters.Add(new ValueParameter<IStateValueFunction>(QualityFunctionParameterName, "The quality function to use", new TabularAvgStateValueFunction()));
    }

    /// <summary>
    /// Selects the index of one follow state by Boltzmann exploration.
    /// </summary>
    /// <param name="followStates">Candidate states; must contain at least one element.</param>
    /// <param name="random">Random number generator used for sampling.</param>
    /// <returns>An index into <paramref name="followStates"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="followStates"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="followStates"/> is empty.</exception>
    protected sealed override int Select(IReadOnlyList<object> followStates, IRandom random) {
      if (followStates == null) throw new ArgumentNullException("followStates");
      if (followStates.Count == 0) throw new ArgumentException("At least one follow state is required.", "followStates");

      // Evaluate every state exactly once; the parameter lookups behind the properties are hoisted as well.
      var valueFunction = StateValueFunction;
      var values = new double[followStates.Count];
      for (int i = 0; i < values.Length; i++) {
        values[i] = valueFunction.Value(followStates[i]);
      }

      var idxs = Enumerable.Range(0, values.Length);

      // windowing: normalize values to [0, 1] so that Beta is independent of the value scale
      double min = values.Min();
      double max = values.Max();
      double range = max - min;
      if (range.IsAlmost(0.0)) return idxs.SampleRandom(random);

      double beta = Beta;
      var weights = values.Select(v => Math.Exp(beta * (v - min) / range)).ToArray();

      // The weights are already normalized and strictly positive. SampleProportional's default
      // windowing would subtract the minimum weight again, giving the worst state probability 0
      // and distorting the Boltzmann distribution, so it is disabled here.
      return idxs.SampleProportional(random, 1, weights, windowing: false).First();
    }

    /// <summary>
    /// Updates the value function for every visited state with the same final quality.
    /// </summary>
    /// <param name="stateSequence">States visited while constructing one solution.</param>
    /// <param name="quality">Quality of the constructed solution.</param>
    public sealed override void Update(IEnumerable<object> stateSequence, double quality) {
      if (stateSequence == null) throw new ArgumentNullException("stateSequence");
      var valueFunction = StateValueFunction;
      foreach (var state in stateSequence) {
        valueFunction.Update(state, quality);
      }
    }

    /// <summary>
    /// Delegates state creation to the state function of the configured value function.
    /// </summary>
    protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actionSequence, ISymbolicExpressionTreeNode parent, int childIdx) {
      return StateValueFunction.StateFunction.CreateState(root, actionSequence, parent, childIdx);
    }

    /// <inheritdoc />
    public override IDeepCloneable Clone(Cloner cloner) {
      return new BoltzmannExplorationSymbolicExpressionConstructionPolicy(this, cloner);
    }
  }
}