using System;
using System.Collections.Generic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  /// <summary>
  /// Base class for state-value functions that keep one table entry per state:
  /// the current value estimate, an accumulator used for the variance and the
  /// number of updates seen. Subclasses decide how a new observation changes the
  /// value estimate by implementing <see cref="CalculateNewQ"/>.
  /// </summary>
  [StorableClass]
  public abstract class TabularStateValueFunctionBase : ParameterizedNamedItem, ITabularStateValueFunction {
    // current value estimate per state
    [Storable]
    private readonly Dictionary<object, double> q = new Dictionary<object, double>();

    // running sum used for the variance per state (divided by the number of tries on read)
    [Storable]
    private readonly Dictionary<object, double> qVariance = new Dictionary<object, double>();

    // number of updates per state
    [Storable]
    private readonly Dictionary<object, int> tries = new Dictionary<object, int>();

    /// <summary>
    /// Gets or sets the value of the "State function" parameter.
    /// </summary>
    public IStateFunction StateFunction {
      get { return ((IValueParameter<IStateFunction>)Parameters["State function"]).Value; }
      set { ((IValueParameter<IStateFunction>)Parameters["State function"]).Value = value; }
    }

    /// <summary>
    /// Creates an empty table and registers the "State function" parameter with a
    /// <see cref="DefaultStateFunction"/> as its default value.
    /// </summary>
    protected TabularStateValueFunctionBase()
      : base() {
      Parameters.Add(new ValueParameter<IStateFunction>("State function", "The function that is used to map partial trees to states", new DefaultStateFunction()));
    }

    /// <summary>
    /// Returns how often <see cref="Update"/> has been called for the given state.
    /// </summary>
    /// <param name="state">The state to look up.</param>
    /// <returns>The number of updates, or 0 if the state has never been updated.</returns>
    public int Tries(object state) {
      int count;
      return tries.TryGetValue(state, out count) ? count : 0;
    }

    /// <summary>
    /// Returns the current value estimate for the given state.
    /// </summary>
    /// <param name="state">The state to look up.</param>
    /// <returns>
    /// The stored value, or <see cref="double.PositiveInfinity"/> if the state has never been updated.
    /// </returns>
    public double Value(object state) {
      // a state that has never been tried has q == infinity
      double quality;
      if (!q.TryGetValue(state, out quality)) return double.PositiveInfinity;
      return quality;
    }

    /// <summary>
    /// Returns the accumulated variance sum for the given state divided by its number of tries.
    /// </summary>
    /// <param name="state">The state to look up.</param>
    /// <returns>
    /// The variance estimate, or <see cref="double.PositiveInfinity"/> if the state has never been updated.
    /// </returns>
    public double ValueVariance(object state) {
      // a state that has never been tried has qVariance == infinity
      double sum;
      if (!qVariance.TryGetValue(state, out sum)) return double.PositiveInfinity;
      // a state present in qVariance has been updated at least once, so Tries(state) >= 1
      return sum / Tries(state);
    }

    /// <summary>
    /// Records a new observation for the given state.
    /// </summary>
    /// <remarks>
    /// The first observation initializes the value with the observed quality and the variance
    /// accumulator with 0. Later observations obtain the new value from
    /// <see cref="CalculateNewQ"/> and add <c>delta * (observedQuality - newValue)</c> to the
    /// variance accumulator, where <c>delta</c> is the difference to the previous value.
    /// If <see cref="CalculateNewQ"/> returns the running mean this corresponds to the usual
    /// incremental (Welford-style) accumulation of squared deviations.
    /// </remarks>
    /// <param name="state">The state that was visited.</param>
    /// <param name="observedQuality">The quality observed for that state.</param>
    public virtual void Update(object state, double observedQuality) {
      int count;
      if (!tries.TryGetValue(state, out count)) {
        tries.Add(state, 1);
        q.Add(state, observedQuality);
        qVariance.Add(state, 0); // naive initialization
      } else {
        // The counter is incremented before CalculateNewQ is called, so Tries(state)
        // already includes the current observation while the new value is computed.
        tries[state] = count + 1;
        var delta = observedQuality - q[state];
        var newValue = CalculateNewQ(state, observedQuality);
        q[state] = newValue;
        qVariance[state] = qVariance[state] + delta * (observedQuality - newValue); // iterative calculation of the variance sum
      }
    }

    /// <summary>
    /// Calculates the new value estimate for a state that has been updated before.
    /// Called by <see cref="Update"/> after the try counter of the state has been incremented
    /// and before the stored value is overwritten.
    /// </summary>
    /// <param name="state">The state that is updated.</param>
    /// <param name="observedQuality">The newly observed quality.</param>
    /// <returns>The value to store for the state.</returns>
    protected abstract double CalculateNewQ(object state, double observedQuality);

    #region item
    [StorableConstructor]
    protected TabularStateValueFunctionBase(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Cloning constructor; copies the three tables of <paramref name="original"/>
    /// into new dictionaries (the state keys themselves are not cloned).
    /// </summary>
    protected TabularStateValueFunctionBase(TabularStateValueFunctionBase original, Cloner cloner)
      : base(original, cloner) {
      // TODO: these become really large. When this class is used only from BasicAlgorithms
      // it would not be necessary to clone and reset everything (pause is not allowed).
      this.q = new Dictionary<object, double>(original.q);
      this.tries = new Dictionary<object, int>(original.tries);
      this.qVariance = new Dictionary<object, double>(original.qVariance);
    }
    #endregion

    //public void InitializeState() {
    //  ClearState();
    //}
    //
    //public void ClearState() {
    //  q.Clear();
    //  tries.Clear();
    //  qVariance.Clear();
    //}
  }
}