source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/UcbTuned.cs @ 13659

Last change on this file since 13659 was 13659, checked in by gkronber, 3 years ago

#2581: added source files for policies

File size: 3.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics.Contracts;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Parameters;
11
namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
  /// <summary>
  /// Selection policy that scores every active action by its average quality plus an
  /// exploration bonus derived from the empirical variance of the qualities observed
  /// for that action. The parameter C scales the exploration bonus.
  /// </summary>
  /// <remarks>
  /// A single scratch list is reused by all calls to <see cref="Select(IEnumerable{IActionStatistics}, IRandom)"/>
  /// on the same instance, so concurrent calls on one instance would interfere with each other.
  /// </remarks>
  [Item("UcbTuned Policy", "UcbTuned is similar to Ucb but tracks empirical variance. Use parameter c to balance between exploitation and exploration")]
  internal class UcbTuned : PolicyBase {
    /// <summary>
    /// Per-action running statistics: number of tries, sum of qualities and sum of squared qualities.
    /// </summary>
    private class ActionStatistics : IActionStatistics {
      /// <summary>Sum of all qualities recorded for this action.</summary>
      public double SumQuality { get; set; }

      /// <summary>Sum of the squares of all qualities recorded for this action.</summary>
      public double SumSqrQuality { get; set; }

      /// <summary>Mean quality; NaN while <see cref="Tries"/> is zero (0.0 / 0).</summary>
      public double AverageQuality { get { return SumQuality / Tries; } }

      /// <summary>
      /// Empirical (population) variance computed as E[q^2] - E[q]^2; NaN while <see cref="Tries"/> is zero.
      /// </summary>
      public double QualityVariance {
        get {
          var mean = AverageQuality;
          var variance = SumSqrQuality / Tries - mean * mean;
          // The subtraction can round to a tiny negative number; a variance is never negative.
          // NaN fails the comparison and is therefore passed through unchanged.
          return variance < 0.0 ? 0.0 : variance;
        }
      }

      /// <summary>Number of times this action has been updated.</summary>
      public int Tries { get; set; }

      /// <summary>When true the action is skipped by the selection.</summary>
      public bool Done { get; set; }
    }

    // Scratch list of candidate indices, reused between Select calls to avoid reallocating.
    private readonly List<int> buf = new List<int>();

    /// <summary>The parameter object registered under the name "C".</summary>
    public IFixedValueParameter<DoubleValue> CParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["C"]; }
    }

    /// <summary>Multiplier of the exploration bonus; initialised to sqrt(2) by the default constructor.</summary>
    public double C {
      get { return CParameter.Value.Value; }
      set { CParameter.Value.Value = value; }
    }

    /// <summary>Cloning constructor used by <see cref="Clone(Cloner)"/>.</summary>
    private UcbTuned(UcbTuned original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Creates the policy and registers the parameter "C" with the value sqrt(2).</summary>
    public UcbTuned()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>("C", "Parameter to balance between exploration and exploitation 0 <= c < 100", new DoubleValue(Math.Sqrt(2))));
    }

    /// <summary>Returns a copy of this policy created through the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new UcbTuned(this, cloner);
    }

    /// <summary>
    /// Selects the index (position within <paramref name="actions"/>) of the next action to try.
    /// </summary>
    /// <param name="actions">Statistics objects created by <see cref="CreateActionStatistics"/>; the sequence is enumerated up to two times.</param>
    /// <param name="random">Random number generator used to pick among equally good candidates.</param>
    /// <returns>Zero-based index of the selected action.</returns>
    /// <exception cref="InvalidOperationException">Thrown when no action is active (all are done or the sequence is empty).</exception>
    public override int Select(IEnumerable<IActionStatistics> actions, IRandom random) {
      return Select(actions, random, C, buf);
    }

    /// <summary>
    /// Records one observed quality for the given action.
    /// </summary>
    /// <param name="action">Statistics object created by <see cref="CreateActionStatistics"/>.</param>
    /// <param name="q">Observed quality.</param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="action"/> is null or was not created by this policy type.</exception>
    public override void Update(IActionStatistics action, double q) {
      var a = action as ActionStatistics;
      if (a == null) {
        // Fail with a meaningful message instead of a NullReferenceException on the next line.
        throw new ArgumentException("The action statistics must be created by CreateActionStatistics() of the UcbTuned policy.", "action");
      }
      a.SumQuality += q;
      a.SumSqrQuality += q * q;
      a.Tries++;
    }

    /// <summary>Creates an empty statistics object for a new action.</summary>
    public override IActionStatistics CreateActionStatistics() {
      return new ActionStatistics();
    }

    /// <summary>
    /// Core selection routine.
    /// Actions flagged as done are ignored. If any active action has never been tried, one of
    /// those is returned uniformly at random. Otherwise each active action is scored with
    /// <c>avg + c * sqrt(ln(T) / n * min(0.25, var + sqrt(2 * ln(T) / n)))</c>, where <c>n</c> is the
    /// number of tries of the action and <c>T</c> the total number of tries of all active actions;
    /// the best-scoring action is returned, ties are broken uniformly at random.
    /// </summary>
    /// <param name="actions">Action statistics; enumerated twice when all active actions have been tried.</param>
    /// <param name="rand">Random number generator.</param>
    /// <param name="c">Multiplier of the exploration bonus.</param>
    /// <param name="buf">Scratch list for candidate indices; its previous content is discarded.</param>
    /// <returns>Zero-based index of the selected action.</returns>
    /// <exception cref="InvalidOperationException">Thrown when no action is active.</exception>
    private static int Select(IEnumerable<IActionStatistics> actions, IRandom rand, double c, IList<int> buf) {
      // determine total tries of still active actions and collect the unvisited ones
      int totalTries = 0;
      buf.Clear();
      int aIdx = -1;
      foreach (var a in actions) {
        ++aIdx;
        if (a.Done) continue;
        if (a.Tries == 0) buf.Add(aIdx);
        else totalTries += a.Tries;
      }
      // if there are unvisited actions select a random action
      if (buf.Count > 0) {
        return buf[rand.Next(buf.Count)];
      }
      // No unvisited action and no tries at all means that there is no active action left.
      // Checked explicitly because Contract.Assert is compiled conditionally and the code
      // below would otherwise index into an empty list.
      if (totalTries <= 0) {
        throw new InvalidOperationException("UcbTuned.Select requires at least one action that is not done.");
      }

      double logTotalTries = Math.Log(totalTries);
      var bestQ = double.NegativeInfinity;
      aIdx = -1;
      // The typed foreach performs the same explicit cast as Cast<ActionStatistics>() without the LINQ iterator.
      foreach (ActionStatistics a in actions) {
        ++aIdx;
        if (a.Done) continue;
        var explorationRatio = logTotalTries / a.Tries;
        // upper bound for the variance, capped at 0.25
        var varianceBound = a.QualityVariance + Math.Sqrt(2.0 * explorationRatio);
        if (varianceBound > 0.25) varianceBound = 0.25;
        var actionQ = a.AverageQuality + c * Math.Sqrt(explorationRatio * varianceBound);
        if (actionQ > bestQ) {
          // strictly better: restart the candidate list
          buf.Clear();
          buf.Add(aIdx);
          bestQ = actionQ;
        } else if (actionQ >= bestQ) {
          // equally good: remember as additional candidate
          buf.Add(aIdx);
        }
      }
      return buf[rand.Next(buf.Count)];
    }
  }
}
Note: See TracBrowser for help on using the repository browser.