Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianMixtureModel.cs @ 11742

Last change on this file since 11742 was 11730, checked in by gkronber, 9 years ago

#2283: several major extensions for grammatical optimization

File size: 2.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8
namespace HeuristicLab.Algorithms.Bandits.Models {
  // Bandit reward model in which every action (arm) owns a mixture of numComponents Gaussian
  // components. Each component has a mean, a variance and a mixing weight (componentProb).
  // NOTE(review): Update is not implemented yet, so the model never moves away from its prior.
  public class GaussianMixtureModel : IModel {
    // prior variance of the mean of every component
    private const double PriorMeanVariance = 1.0;
    // number of draws averaged per action in SampleExpectedRewards
    private const int NumSamplesPerAction = 10000;

    private readonly int numActions;
    private readonly double[][] meanMean; // mean of mean for each arm and component
    private readonly double[][] meanVariance; // variance of mean for each arm and component
    private readonly double[][] componentProb; // mixing weight of each component for each arm

    // prior settings: number of mixture components per arm and the initial mean of every component
    private readonly int numComponents;
    private readonly double priorMean;

    /// <summary>
    /// Creates a model for <paramref name="numActions"/> actions. Every action starts with
    /// <paramref name="nComponents"/> equally weighted components whose means are
    /// <paramref name="priorMean"/> and whose variances are 1.
    /// </summary>
    /// <param name="numActions">Number of actions (arms); must not be negative.</param>
    /// <param name="priorMean">Initial mean of every component.</param>
    /// <param name="nComponents">Number of mixture components per action; must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="numActions"/> is negative or <paramref name="nComponents"/> is less than 1.
    /// </exception>
    public GaussianMixtureModel(int numActions, double priorMean = 0.5, int nComponents = 5) {
      if (numActions < 0)
        throw new ArgumentOutOfRangeException("numActions", "The number of actions must not be negative.");
      // an empty mixture has nothing to sample from in SampleExpectedRewards
      if (nComponents < 1)
        throw new ArgumentOutOfRangeException("nComponents", "At least one mixture component is required.");

      this.numActions = numActions;
      this.numComponents = nComponents;
      this.priorMean = priorMean;
      this.meanMean = new double[numActions][];
      this.meanVariance = new double[numActions][];
      this.componentProb = new double[numActions][];
      for (int a = 0; a < numActions; a++) {
        InitializeActionPrior(a);
      }
    }

    /// <summary>
    /// Returns one estimated expected reward per action. For each action, NumSamplesPerAction
    /// component indices are drawn according to the action's mixing weights. For each draw,
    /// a value is taken from that component's Gaussian, and the values are averaged.
    /// </summary>
    /// <param name="random">Random number generator used for all draws.</param>
    /// <returns>An array of length numActions.</returns>
    /// <remarks>
    /// NOTE(review): averaging 10000 draws approximates sum_k componentProb[k] * meanMean[k].
    /// That is close to deterministic, not a single Thompson-style posterior sample.
    /// Confirm this is the intended behavior.
    /// </remarks>
    public double[] SampleExpectedRewards(Random random) {
      var exp = new double[numActions];
      for (int a = 0; a < numActions; a++) {
        var sumReward = 0.0;
        // presumably SampleProportional yields an endless stream of indices drawn proportionally
        // to componentProb[a] (hence the Take); verify against HeuristicLab.Common
        var sampledComponents = Enumerable.Range(0, numComponents)
          .SampleProportional(random, componentProb[a])
          .Take(NumSamplesPerAction);
        foreach (var k in sampledComponents) {
          // Rand.RandNormal is assumed to return a standard-normal draw, scaled and shifted here
          sumReward += Rand.RandNormal(random) * Math.Sqrt(meanVariance[a][k]) + meanMean[a][k];
        }
        exp[a] = sumReward / NumSamplesPerAction;
      }

      return exp;
    }

    /// <summary>Not implemented yet; always throws.</summary>
    /// <exception cref="NotImplementedException">Always.</exception>
    public void Update(int action, double reward) {
      // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
      throw new NotImplementedException();
    }

    /// <summary>
    /// Sets the mean and variance of every component of <paramref name="action"/> to zero.
    /// SampleExpectedRewards then returns exactly 0 for that action.
    /// </summary>
    public void Disable(int action) {
      Array.Clear(meanMean[action], 0, meanMean[action].Length);
      Array.Clear(meanVariance[action], 0, meanVariance[action].Length);
    }

    /// <summary>
    /// Restores every action, including disabled ones, to the prior set up by the constructor.
    /// </summary>
    public void Reset() {
      // Re-create each row instead of clearing the jagged outer arrays. Array.Clear on them
      // would set every row to null and break all later sampling.
      for (int a = 0; a < numActions; a++) {
        InitializeActionPrior(a);
      }
    }

    /// <summary>Not implemented yet; always throws.</summary>
    /// <exception cref="NotImplementedException">Always.</exception>
    public void PrintStats() {
      throw new NotImplementedException();
    }

    // Initializes the prior of one action: all components have mean priorMean, variance
    // PriorMeanVariance and the same mixing weight.
    private void InitializeActionPrior(int action) {
      // TODO: probably need to initialize this randomly to allow learning
      meanMean[action] = Enumerable.Repeat(priorMean, numComponents).ToArray();
      meanVariance[action] = Enumerable.Repeat(PriorMeanVariance, numComponents).ToArray();
      // uniform prior for component probabilities
      componentProb[action] = Enumerable.Repeat(1.0 / numComponents, numComponents).ToArray();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.