source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/UcbTuned.cs @ 15060

Last change on this file since 15060 was 15060, checked in by gkronber, 2 years ago

#2581: merged r13645,r13648,r13650,r13651,r13652,r13654,r13657,r13658,r13659,r13661,r13662,r13669,r13708,r14142 from trunk to stable (to be deleted in the next commit)

File size: 3.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics.Contracts;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Parameters;
11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
12
namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
  /// <summary>
  /// UCB-Tuned action selection policy. Like UCB, it prefers actions with a high average
  /// quality plus an exploration bonus. The bonus is additionally scaled by an upper
  /// confidence bound on each action's empirical quality variance, capped at 0.25.
  /// The parameter C weights the exploration bonus.
  /// </summary>
  [StorableClass]
  [Item("UcbTuned Policy", "UcbTuned is similar to Ucb but tracks empirical variance. Use parameter c to balance between exploitation and exploration")]
  public class UcbTuned : PolicyBase {
    /// <summary>
    /// Running per-action statistics: sum and sum of squares of observed qualities,
    /// number of tries, and a done flag.
    /// </summary>
    private class ActionStatistics : IActionStatistics {
      public double SumQuality { get; set; }
      public double SumSqrQuality { get; set; }

      /// <summary>Mean of the observed qualities (NaN while <see cref="Tries"/> is 0).</summary>
      public double AverageQuality { get { return SumQuality / Tries; } }

      /// <summary>
      /// Population variance E[q^2] - E[q]^2 of the observed qualities.
      /// The subtraction can come out marginally negative through floating-point
      /// cancellation, so the result is clamped at zero. NaN (Tries == 0) is propagated.
      /// </summary>
      public double QualityVariance {
        get {
          var avg = AverageQuality;
          return Math.Max(0.0, SumSqrQuality / Tries - avg * avg);
        }
      }

      public int Tries { get; set; }
      public bool Done { get; set; }
    }

    // Scratch list reused across Select calls to avoid per-call allocations.
    // Shared instance state: Select on one policy object is not safe to call concurrently.
    private readonly List<int> buf = new List<int>();

    /// <summary>Parameter "C" weighting the exploration term.</summary>
    public IFixedValueParameter<DoubleValue> CParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["C"]; }
    }

    /// <summary>Convenience accessor for the value of <see cref="CParameter"/>.</summary>
    public double C {
      get { return CParameter.Value.Value; }
      set { CParameter.Value.Value = value; }
    }

    [StorableConstructor]
    protected UcbTuned(bool deserializing) : base(deserializing) { }
    protected UcbTuned(UcbTuned original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Creates the policy with C = sqrt(2).</summary>
    public UcbTuned()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>("C", "Parameter to balance between exploration and exploitation 0 <= c < 100", new DoubleValue(Math.Sqrt(2))));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new UcbTuned(this, cloner);
    }

    /// <summary>Selects the index of the next action to try among <paramref name="actions"/>.</summary>
    public override int Select(IEnumerable<IActionStatistics> actions, IRandom random) {
      return Select(actions, random, C, buf);
    }

    /// <summary>
    /// Records one observed quality <paramref name="q"/> for <paramref name="action"/>.
    /// </summary>
    /// <exception cref="InvalidCastException">
    /// <paramref name="action"/> was not created by this policy's <see cref="CreateActionStatistics"/>.
    /// </exception>
    public override void Update(IActionStatistics action, double q) {
      // Direct cast: a foreign statistics object fails with a descriptive InvalidCastException
      // rather than a NullReferenceException on the first member access.
      var a = (ActionStatistics)action;
      a.SumQuality += q;
      a.SumSqrQuality += q * q;
      a.Tries++;
    }

    /// <summary>Creates a fresh, empty statistics object for one action.</summary>
    public override IActionStatistics CreateActionStatistics() {
      return new ActionStatistics();
    }

    /// <summary>
    /// Core selection rule.
    /// <list type="bullet">
    ///   <item>If any active (not done) action has never been tried, one of those is chosen uniformly at random.</item>
    ///   <item>Otherwise, the action maximising
    ///   avg + c * sqrt(ln(N) / n * min(0.25, var + sqrt(2 * ln(N) / n))) is chosen, with ties broken uniformly at random.
    ///   Here N is the total tries over active actions and n is the action's tries.</item>
    /// </list>
    /// </summary>
    /// <param name="actions">Statistics of all actions. The returned index refers to this sequence, which is enumerated twice.</param>
    /// <param name="rand">Random source used for uniform choice among candidates.</param>
    /// <param name="c">Weight of the exploration term.</param>
    /// <param name="buf">Scratch list. It is cleared and overwritten.</param>
    /// <returns>The index (position in <paramref name="actions"/>) of the selected action.</returns>
    private static int Select(IEnumerable<IActionStatistics> actions, IRandom rand, double c, IList<int> buf) {
      // First pass: collect untried active actions and sum the tries of the tried active ones.
      int totalTries = 0;
      buf.Clear();
      int aIdx = -1;
      foreach (var a in actions) {
        ++aIdx;
        if (a.Done) continue;
        if (a.Tries == 0) buf.Add(aIdx);
        else totalTries += a.Tries;
      }
      // Every untried active action is tried once before any bound is evaluated.
      if (buf.Count > 0) {
        return buf[rand.Next(buf.Count)];
      }
      // Precondition: at least one active action exists (otherwise there is nothing to select).
      Contract.Assert(totalTries > 0);
      double logTotalTries = Math.Log(totalTries);
      var bestQ = double.NegativeInfinity;
      aIdx = -1;
      foreach (var a in actions.Cast<ActionStatistics>()) {
        ++aIdx;
        if (a.Done) continue;
        // Upper confidence bound on the variance, capped at 0.25 as in UCB1-Tuned.
        // NOTE(review): the 0.25 cap is the variance bound for rewards in [0, 1];
        // confirm that the qualities passed to Update lie in that range.
        var varianceBound = Math.Min(0.25, a.QualityVariance + Math.Sqrt(2.0 * logTotalTries / a.Tries));
        var actionQ = a.AverageQuality + c * Math.Sqrt(logTotalTries / a.Tries * varianceBound);
        if (actionQ > bestQ) {
          buf.Clear();
          buf.Add(aIdx);
          bestQ = actionQ;
        } else if (actionQ == bestQ) {
          // Exact tie with the current best: keep it as a candidate for the random tie-break.
          buf.Add(aIdx);
        }
      }
      return buf[rand.Next(buf.Count)];
    }
  }
}
Note: See TracBrowser for help on using the repository browser.