Line | |
---|
1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Diagnostics;
|
---|
4 | using System.Linq;
|
---|
5 | using System.Text;
|
---|
6 | using System.Threading.Tasks;
|
---|
7 | using HeuristicLab.Common;
|
---|
8 |
|
---|
9 | namespace HeuristicLab.Algorithms.Bandits.Models {
|
---|
/// <summary>
/// Model for a bandit arm with binary (success/failure) rewards.
/// Keeps counts of observed successes and failures and combines them with
/// the parameters of a Beta prior (alpha, beta).
/// </summary>
public class BernoulliModel : IModel {

  // number of observed rewards counted as success / failure (see Update)
  private int success;
  private int failure;

  // parameters of beta prior distribution
  private readonly double alpha;
  private readonly double beta;

  /// <summary>
  /// Creates a model with no observations and the given Beta prior parameters.
  /// </summary>
  /// <param name="alpha">First parameter of the Beta prior (added to the success count).</param>
  /// <param name="beta">Second parameter of the Beta prior (added to the failure count).</param>
  public BernoulliModel(double alpha = 1.0, double beta = 1.0) {
    this.alpha = alpha;
    this.beta = beta;
  }

  /// <summary>
  /// Samples a value for the Bernoulli mean using the prior parameters
  /// updated with the observed success and failure counts.
  /// </summary>
  /// <param name="random">Random number generator passed through to Rand.BetaRand.</param>
  /// <returns>The value returned by Rand.BetaRand for (success + alpha, failure + beta).</returns>
  public double SampleExpectedReward(Random random) {
    // The parameters are prior + observed counts, i.e. the posterior parameters.
    // NOTE(review): Rand.BetaRand is presumably a Beta(a, b) sampler - confirm against its definition.
    return Rand.BetaRand(random, success + alpha, failure + beta);
  }

  /// <summary>
  /// Records one observed reward: any value greater than zero counts as a
  /// success, everything else as a failure.
  /// </summary>
  /// <param name="reward">Observed reward.</param>
  public void Update(double reward) {
    // The disabled assertion below suggests rewards are expected to be 0 or 1,
    // but this is not enforced; only the sign of the reward is checked.
    // Debug.Assert(reward.IsAlmost(1.0) || reward.IsAlmost(0.0));
    if (reward > 0) {
      success++;
    } else {
      failure++;
    }
  }

  /// <summary>
  /// Discards all observations. The prior parameters are kept.
  /// </summary>
  public void Reset() {
    success = 0;
    failure = 0;
  }

  /// <summary>
  /// Writes the ratio of successes to failures (two decimals, followed by a
  /// space) to the console.
  /// </summary>
  public void PrintStats() {
    // Floating-point division: with failure == 0 this does not throw but
    // yields infinity (or NaN when success is 0 as well).
    Console.Write("{0:F2} ", success / (double)failure);
  }

  /// <summary>
  /// Creates a copy of this model with the same prior parameters and the same
  /// observation counts.
  /// </summary>
  /// <returns>A new, independent BernoulliModel instance.</returns>
  public object Clone() {
    // alpha and beta must be passed on explicitly; the parameterless call
    // would silently reset the clone to the default prior (1.0, 1.0).
    return new BernoulliModel(alpha, beta) { failure = this.failure, success = this.success };
  }

  /// <summary>
  /// Returns a short description containing the current mean estimate
  /// (success + alpha) / (success + alpha + failure + beta).
  /// </summary>
  public override string ToString() {
    return string.Format("Bernoulli with Beta prior: mu={0:F2}", (success + alpha) / (success + alpha + failure + beta));
  }
}
|
---|
55 | }
|
---|
Note: See TracBrowser for help on using the repository browser.