[11730] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Diagnostics;
|
---|
| 4 | using System.Linq;
|
---|
| 5 | using System.Text;
|
---|
| 6 | using System.Threading.Tasks;
|
---|
[11732] | 7 | using HeuristicLab.Common;
|
---|
[11730] | 8 |
|
---|
namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
  /// <summary>
  /// Boltzmann exploration bandit policy (also called softmax policy).
  /// Every untried action is played first (picked at random among the untried ones).
  /// Afterwards an action is sampled with probability proportional to exp(beta * r), where r is the
  /// action's MaxReward linearly rescaled to [0, 1] over all actions ("windowing").
  /// </summary>
  public class BoltzmannExplorationPolicy : IBanditPolicy {
    // inverse temperature: 0 => all weights are exp(0) = 1 (uniform selection), larger => greedier selection
    private readonly double beta;

    /// <summary>Creates a Boltzmann exploration policy.</summary>
    /// <param name="beta">Inverse temperature; must be a non-negative number.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="beta"/> is negative or NaN.</exception>
    public BoltzmannExplorationPolicy(double beta) {
      // !(beta >= 0) also rejects NaN, which would otherwise turn every selection weight into NaN
      if (!(beta >= 0)) throw new ArgumentOutOfRangeException("beta", beta, "beta must be a non-negative number.");
      this.beta = beta;
    }

    /// <summary>Selects the index of the next action to play.</summary>
    /// <param name="random">Random number generator used for all random choices.</param>
    /// <param name="actionInfos">
    /// Per-action statistics; expected to be the <see cref="DefaultPolicyActionInfo"/> instances
    /// created by <see cref="CreateActionInfo"/>.
    /// </param>
    /// <returns>The index of the selected action within <paramref name="actionInfos"/>.</returns>
    public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) {
      Debug.Assert(actionInfos.Any());

      // materialize once: the sequence is used several times below and all indices must refer to the same elements
      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray();
      // returned indices only line up with actionInfos if OfType did not filter anything out
      Debug.Assert(myActionInfos.Length == actionInfos.Count());

      // try any of the untried actions first (chosen at random among them)
      var untriedActions = Enumerable.Range(0, myActionInfos.Length)
        .Where(idx => myActionInfos[idx].Tries == 0)
        .ToArray();
      if (untriedActions.Length > 0) {
        return untriedActions.SelectRandom(random);
      }

      // windowing: rescale MaxReward linearly to [0, 1] over all actions
      double max = myActionInfos.Max(aInfo => aInfo.MaxReward);
      double min = myActionInfos.Min(aInfo => aInfo.MaxReward);
      double range = max - min;

      double[] w;
      if (range > 0) {
        // exp(beta * (scaled - 1)) = exp(beta * scaled) * exp(-beta): same proportions, but every weight
        // lies in (0, 1], so a large beta cannot overflow to Infinity (Infinity / Infinity would be NaN)
        w = myActionInfos
          .Select(aInfo => Math.Exp(beta * ((aInfo.MaxReward - min) / range - 1.0)))
          .ToArray();
      } else {
        // all actions share the same MaxReward (or range is NaN): (x - min) / range would be 0/0 = NaN
        // for every action, so fall back to equal weights
        w = Enumerable.Repeat(1.0, myActionInfos.Length).ToArray();
      }

      var bestAction = Enumerable.Range(0, myActionInfos.Length).SampleProportional(random, w);
      Debug.Assert(bestAction >= 0);
      return bestAction;
    }

    /// <summary>Creates the per-action statistics object this policy expects in <see cref="SelectAction"/>.</summary>
    public IBanditPolicyActionInfo CreateActionInfo() {
      return new DefaultPolicyActionInfo();
    }

    public override string ToString() {
      return string.Format("BoltzmannExplorationPolicy({0:F2})", beta);
    }
  }
}
|
---|