Rev | Line | |
---|
[11730] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Diagnostics;
|
---|
| 4 | using System.Linq;
|
---|
| 5 | using System.Text;
|
---|
| 6 | using System.Threading.Tasks;
|
---|
| 7 | using HeuristicLab.Common;
|
---|
| 8 |
|
---|
| 9 | namespace HeuristicLab.Algorithms.Bandits.Models {
|
---|
/// <summary>
/// Reward model for a single bandit arm with binary (success/failure) outcomes.
/// Keeps a count of observed successes and failures and combines them with the
/// parameters of a Beta prior (<c>alpha</c>, <c>beta</c>) when sampling or reporting.
/// </summary>
public class BernoulliModel : IModel {

  // number of observed successes / failures since construction or the last Reset()
  private int success;
  private int failure;

  // parameters of beta prior distribution
  private readonly double alpha;
  private readonly double beta;

  /// <summary>
  /// Creates a model with no observations and the given Beta prior parameters.
  /// </summary>
  /// <param name="alpha">First parameter of the Beta prior (added to the success count). Defaults to 1.0.</param>
  /// <param name="beta">Second parameter of the Beta prior (added to the failure count). Defaults to 1.0.</param>
  public BernoulliModel(double alpha = 1.0, double beta = 1.0) {
    this.alpha = alpha;
    this.beta = beta;
  }

  /// <summary>
  /// Draws a value for the Bernoulli mean via <c>Rand.BetaRand</c>, using the prior parameters
  /// updated with the observed counts: (success + alpha, failure + beta).
  /// </summary>
  /// <param name="random">Random number generator handed through to <c>Rand.BetaRand</c>.</param>
  /// <returns>The value returned by <c>Rand.BetaRand</c>.</returns>
  public double Sample(Random random) {
    // the parameters include the observed counts, i.e. this samples from the
    // updated (posterior) distribution, not from the bare prior
    return Rand.BetaRand(random, success + alpha, failure + beta);
  }

  /// <summary>
  /// Records one observation.
  /// </summary>
  /// <param name="reward">Observed reward; any value greater than zero counts as a success,
  /// everything else counts as a failure.</param>
  public void Update(double reward) {
    // rewards are deliberately not asserted to be exactly 0.0 or 1.0;
    // only the sign decides which counter is incremented
    if (reward > 0) {
      success++;
    } else {
      failure++;
    }
  }

  /// <summary>
  /// Discards all observations. The prior parameters are kept.
  /// </summary>
  public void Reset() {
    success = 0;
    failure = 0;
  }

  /// <summary>
  /// Creates an independent copy of this model with the same prior parameters
  /// and the same observation counts.
  /// </summary>
  /// <returns>A new <see cref="BernoulliModel"/> instance.</returns>
  public object Clone() {
    // pass alpha/beta explicitly: the parameterless call would fall back to the
    // default prior (1.0, 1.0) and the copy would no longer match this instance
    return new BernoulliModel(alpha, beta) { failure = this.failure, success = this.success };
  }

  /// <summary>
  /// Returns a short description containing the current mean estimate
  /// (success + alpha) / (success + alpha + failure + beta), formatted with two decimals.
  /// </summary>
  public override string ToString() {
    return string.Format("Bernoulli with Beta prior: mu={0:F2}", (success + alpha) / (success + alpha + failure + beta));
  }
}
|
---|
| 51 | }
|
---|
Note: See TracBrowser for help on using the repository browser.