Free cookie consent management tool by TermsFeed Policy Generator

source: branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs @ 15426

Last change on this file since 15426 was 15425, checked in by gkronber, 7 years ago

#2796 made several changes for debugging

File size: 3.7 KB
RevLine 
[13659]1using System;
2using System.Collections.Generic;
[15414]3using System.Diagnostics;
[13659]4using System.Diagnostics.Contracts;
5using System.Linq;
6using System.Text;
7using System.Threading.Tasks;
8using HeuristicLab.Common;
9using HeuristicLab.Core;
10using HeuristicLab.Data;
11using HeuristicLab.Parameters;
[13669]12using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[13659]13
namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression.Policies {
  /// <summary>
  /// Selection policy that scores every selectable action by
  /// <c>AverageQuality + C * sqrt(ln(totalTries) / Tries)</c> and returns the index of the
  /// highest-scoring action; ties are broken uniformly at random.
  /// Actions that have never been tried are always selected first.
  /// </summary>
  [StorableClass]
  [Item("Ucb Policy", "Ucb with parameter c to balance between exploitation and exploration")]
  public class Ucb : PolicyBase {
    /// <summary>
    /// Per-action bookkeeping used by this policy: number of tries, sum of observed
    /// qualities and the best quality observed.
    /// </summary>
    private class ActionStatistics : IActionStatistics {
      public double SumQuality { get; set; }

      // Floating-point division: the result is NaN or infinite while Tries is 0.
      // Select() only reads this value for actions with Tries > 0.
      public double AverageQuality { get { return SumQuality / Tries; } }

      // NOTE(review): starts at the default value 0 and is only ever raised via Math.Max,
      // so observed qualities below 0 are never reflected here. Presumably qualities are
      // non-negative - confirm.
      public double BestQuality { get; internal set; }

      public int Tries { get; set; }
      public bool Done { get; set; }

      /// <summary>
      /// Merges the counters of <paramref name="other"/> into this instance.
      /// </summary>
      /// <param name="other">Statistics created by this policy.</param>
      /// <exception cref="ArgumentException">
      /// <paramref name="other"/> is null or was not created by this policy.
      /// </exception>
      public void Add(IActionStatistics other) {
        var o = other as ActionStatistics;
        if (o == null) throw new ArgumentException("Expected action statistics created by the Ucb policy.", "other");
        this.Tries += o.Tries;
        this.SumQuality += o.SumQuality;
        this.BestQuality = Math.Max(this.BestQuality, o.BestQuality);
      }
    }

    // Scratch list of candidate indices, reused across Select() calls to avoid allocations.
    // Because it is shared instance state, concurrent Select() calls on one instance would interfere.
    private List<int> buf = new List<int>();

    public IFixedValueParameter<DoubleValue> CParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["C"]; }
    }

    /// <summary>
    /// Weight of the exploration term; larger values favor rarely tried actions.
    /// </summary>
    public double C {
      get { return CParameter.Value.Value; }
      set { CParameter.Value.Value = value; }
    }

    [StorableConstructor]
    protected Ucb(bool deserializing) : base(deserializing) { }
    protected Ucb(Ucb original, Cloner cloner)
      : base(original, cloner) {
    }
    public Ucb()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>("C", "Parameter to balance between exploration and exploitation 0 <= c < 100", new DoubleValue(Math.Sqrt(2))));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new Ucb(this, cloner);
    }

    /// <summary>
    /// Returns the index (position within <paramref name="actions"/>) of the action to try next.
    /// </summary>
    public override int Select(IEnumerable<IActionStatistics> actions, IRandom random) {
      return Select(actions, random, C, buf);
    }

    /// <summary>
    /// Records one more try with quality <paramref name="q"/> for <paramref name="action"/>.
    /// </summary>
    /// <param name="action">Statistics object obtained from <see cref="CreateActionStatistics"/>.</param>
    /// <param name="q">Quality observed for this try.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="action"/> is null or was not created by this policy.
    /// </exception>
    public override void Update(IActionStatistics action, double q) {
      var a = action as ActionStatistics;
      // Fail with a descriptive error instead of a NullReferenceException (same contract as ActionStatistics.Add).
      if (a == null) throw new ArgumentException("Expected action statistics created by the Ucb policy.", "action");
      a.SumQuality += q;
      a.BestQuality = Math.Max(a.BestQuality, q);
      a.Tries++;
    }

    public override IActionStatistics CreateActionStatistics() {
      return new ActionStatistics();
    }

    /// <summary>
    /// Core selection routine.
    /// </summary>
    /// <param name="actions">
    /// Statistics of all actions. The sequence is enumerated more than once (and once more by
    /// the debug assertion), so it must yield the same elements in the same order each time.
    /// At least one action that is not done is expected; this is only asserted in debug builds.
    /// </param>
    /// <param name="rand">Random source used to pick among candidates.</param>
    /// <param name="c">Weight of the exploration term.</param>
    /// <param name="buf">Scratch list for candidate indices; it is cleared and overwritten.</param>
    /// <returns>Index of the selected action within <paramref name="actions"/>.</returns>
    private static int Select(IEnumerable<IActionStatistics> actions, IRandom rand, double c, IList<int> buf) {
      // determine total tries of still active actions
      int totalTries = 0;
      buf.Clear();
      int aIdx = -1;
      foreach (var a in actions) {
        ++aIdx;
        if (a.Done) continue;
        if (a.Tries == 0) buf.Add(aIdx);
        else totalTries += a.Tries;
      }
      // if there are unvisited actions select a random unvisited action
      if (buf.Count > 0) {
        return buf[rand.Next(buf.Count)];
      }

      Debug.Assert(actions.All(a => a.Done || a.Tries > 0));

      Debug.Assert(totalTries > 0);
      double logTotalTries = Math.Log(totalTries);
      var bestQ = double.NegativeInfinity;
      aIdx = -1;
      foreach (var a in actions) {
        ++aIdx;
        if (a.Done) continue;
        var actionQ = a.AverageQuality + c * Math.Sqrt(logTotalTries / a.Tries);
        if (actionQ > bestQ) {
          // strictly better: restart the candidate list with this action
          buf.Clear();
          buf.Add(aIdx);
          bestQ = actionQ;
        } else if (actionQ >= bestQ) {
          // exact tie with the current best: keep it as an additional candidate
          buf.Add(aIdx);
        }
      }
      // random tie-breaking among all actions that share the best score
      return buf[rand.Next(buf.Count)];
    }

  }
}
Note: See TracBrowser for help on using the repository browser.