source: stable/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs @ 15060

Last change on this file since 15060 was 15060, checked in by gkronber, 2 years ago

#2581: merged r13645,r13648,r13650,r13651,r13652,r13654,r13657,r13658,r13659,r13661,r13662,r13669,r13708,r14142 from trunk to stable (to be deleted in the next commit)

File size: 3.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics.Contracts;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Parameters;
11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
12
13namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
14  [StorableClass]
15  [Item("Ucb Policy", "Ucb with parameter c to balance between exploitation and exploration")]
16  public class Ucb : PolicyBase {
17    private class ActionStatistics : IActionStatistics {
18      public double SumQuality { get; set; }
19      public double AverageQuality { get { return SumQuality / Tries; } }
20      public int Tries { get; set; }
21      public bool Done { get; set; }
22    }
23    private List<int> buf = new List<int>();
24
25    public IFixedValueParameter<DoubleValue> CParameter {
26      get { return (IFixedValueParameter<DoubleValue>)Parameters["C"]; }
27    }
28
29    public double C {
30      get { return CParameter.Value.Value; }
31      set { CParameter.Value.Value = value; }
32    }
33
34    [StorableConstructor]
35    protected Ucb(bool deserializing) : base(deserializing) { }
36    protected Ucb(Ucb original, Cloner cloner)
37      : base(original, cloner) {
38    }
39    public Ucb()
40      : base() {
41      Parameters.Add(new FixedValueParameter<DoubleValue>("C", "Parameter to balance between exploration and exploitation 0 <= c < 100", new DoubleValue(Math.Sqrt(2))));
42    }
43
44    public override IDeepCloneable Clone(Cloner cloner) {
45      return new Ucb(this, cloner);
46    }
47
48    public override int Select(IEnumerable<IActionStatistics> actions, IRandom random) {
49      return Select(actions, random, C, buf);
50    }
51
52    public override void Update(IActionStatistics action, double q) {
53      var a = action as ActionStatistics;
54      a.SumQuality += q;
55      a.Tries++;
56    }
57
58    public override IActionStatistics CreateActionStatistics() {
59      return new ActionStatistics();
60    }
61
62    private static int Select(IEnumerable<IActionStatistics> actions, IRandom rand, double c, IList<int> buf) {
63      // determine total tries of still active actions
64      int totalTries = 0;
65      buf.Clear();
66      int aIdx = -1;
67      foreach (var a in actions) {
68        ++aIdx;
69        if (a.Done) continue;
70        if (a.Tries == 0) buf.Add(aIdx);
71        else totalTries += a.Tries;
72      }
73      // if there are unvisited actions select a random action
74      if (buf.Any()) {
75        return buf[rand.Next(buf.Count)];
76      }
77      Contract.Assert(totalTries > 0);
78      double logTotalTries = Math.Log(totalTries);
79      var bestQ = double.NegativeInfinity;
80      aIdx = -1;
81      foreach (var a in actions) {
82        ++aIdx;
83        if (a.Done) continue;
84        var actionQ = a.AverageQuality + c * Math.Sqrt(logTotalTries / a.Tries);
85        if (actionQ > bestQ) {
86          buf.Clear();
87          buf.Add(aIdx);
88          bestQ = actionQ;
89        } else if (actionQ >= bestQ) {
90          buf.Add(aIdx);
91        }
92      }
93      return buf[rand.Next(buf.Count)];
94    }
95
96  }
97}
Note: See TracBrowser for help on using the repository browser.