source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs @ 13659

Last change on this file since 13659 was 13659, checked in by gkronber, 3 years ago

#2581: added source files for policies

File size: 3.0 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics.Contracts;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Parameters;
11
12namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
13  [Item("Ucb Policy", "Ucb with parameter c to balance between exploitation and exploration")]
14  internal class Ucb : PolicyBase {
15    private class ActionStatistics : IActionStatistics {
16      public double SumQuality { get; set; }
17      public double AverageQuality { get { return SumQuality / Tries; } }
18      public int Tries { get; set; }
19      public bool Done { get; set; }
20    }
21    private List<int> buf = new List<int>();
22
23    public IFixedValueParameter<DoubleValue> CParameter {
24      get { return (IFixedValueParameter<DoubleValue>)Parameters["C"]; }
25    }
26
27    public double C {
28      get { return CParameter.Value.Value; }
29      set { CParameter.Value.Value = value; }
30    }
31
32    private Ucb(Ucb original, Cloner cloner)
33      : base(original, cloner) {
34    }
35    public Ucb()
36      : base() {
37      Parameters.Add(new FixedValueParameter<DoubleValue>("C", "Parameter to balance between exploration and exploitation 0 <= c < 100", new DoubleValue(Math.Sqrt(2))));
38    }
39
40    public override IDeepCloneable Clone(Cloner cloner) {
41      return new Ucb(this, cloner);
42    }
43
44    public override int Select(IEnumerable<IActionStatistics> actions, IRandom random) {
45      return Select(actions, random, C, buf);
46    }
47
48    public override void Update(IActionStatistics action, double q) {
49      var a = action as ActionStatistics;
50      a.SumQuality += q;
51      a.Tries++;
52    }
53
54    public override IActionStatistics CreateActionStatistics() {
55      return new ActionStatistics();
56    }
57
58    private static int Select(IEnumerable<IActionStatistics> actions, IRandom rand, double c, IList<int> buf) {
59      // determine total tries of still active actions
60      int totalTries = 0;
61      buf.Clear();
62      int aIdx = -1;
63      foreach (var a in actions) {
64        ++aIdx;
65        if (a.Done) continue;
66        if (a.Tries == 0) buf.Add(aIdx);
67        else totalTries += a.Tries;
68      }
69      // if there are unvisited actions select a random action
70      if (buf.Any()) {
71        return buf[rand.Next(buf.Count)];
72      }
73      Contract.Assert(totalTries > 0);
74      double logTotalTries = Math.Log(totalTries);
75      var bestQ = double.NegativeInfinity;
76      aIdx = -1;
77      foreach (var a in actions) {
78        ++aIdx;
79        if (a.Done) continue;
80        var actionQ = a.AverageQuality + c * Math.Sqrt(logTotalTries / a.Tries);
81        if (actionQ > bestQ) {
82          buf.Clear();
83          buf.Add(aIdx);
84          bestQ = actionQ;
85        } else if (actionQ >= bestQ) {
86          buf.Add(aIdx);
87        }
88      }
89      return buf[rand.Next(buf.Count)];
90    }
91
92  }
93}
Note: See TracBrowser for help on using the repository browser.