using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits.Models {
  /// <summary>
  /// Bandit model that keeps, for every arm, a mixture of <c>nComponents</c> Gaussian
  /// distributions over the arm's mean reward (per-component mean, variance and mixing
  /// probability). Posterior updates (<see cref="Update"/>) are not implemented yet.
  /// </summary>
  public class GaussianMixtureModel : IModel {
    private readonly int numActions;
    private readonly double[][] meanMean; // mean of mean for each arm and component
    private readonly double[][] meanVariance; // variance of mean for each arm and component
    private readonly double[][] componentProb; // mixing probability of each component for each arm
    private readonly int numComponents;
    private readonly double priorMean;

    /// <summary>
    /// Creates the model with every arm set to the prior: all component means equal
    /// <paramref name="priorMean"/>, all component variances equal 1, and uniform mixing probabilities.
    /// </summary>
    /// <param name="numActions">Number of arms.</param>
    /// <param name="priorMean">Prior mean assigned to every component of every arm.</param>
    /// <param name="nComponents">Number of mixture components per arm; must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="nComponents"/> is less than 1.</exception>
    public GaussianMixtureModel(int numActions, double priorMean = 0.5, int nComponents = 5) {
      if (nComponents < 1)
        throw new ArgumentOutOfRangeException("nComponents", "At least one mixture component is required.");

      this.numActions = numActions;
      this.numComponents = nComponents;
      this.priorMean = priorMean;
      this.meanMean = new double[numActions][];
      this.meanVariance = new double[numActions][];
      this.componentProb = new double[numActions][];
      for (int a = 0; a < numActions; a++) {
        meanMean[a] = new double[nComponents];
        meanVariance[a] = new double[nComponents];
        componentProb[a] = new double[nComponents];
        SetArmToPrior(a);
      }
    }

    /// <summary>
    /// Returns, for every arm, the average of 10000 draws from that arm's mixture. Each draw picks
    /// a component k with weight <c>componentProb[a][k]</c>, then adds
    /// <c>meanMean[a][k]</c> to <c>Rand.RandNormal(random)</c> scaled by <c>sqrt(meanVariance[a][k])</c>.
    /// </summary>
    /// <param name="random">Source of randomness for component selection and the normal draws.</param>
    /// <returns>One estimated expected reward per arm.</returns>
    public double[] SampleExpectedRewards(Random random) {
      // NOTE(review): averaging this many draws is a Monte Carlo estimate of
      // sum_k componentProb[a][k] * meanMean[a][k] (assuming Rand.RandNormal is a standard normal),
      // i.e. almost no sampling noise rather than a single Thompson-style sample — confirm intent.
      const int numSamples = 10000;
      var exp = new double[numActions];
      for (int a = 0; a < numActions; a++) {
        var sumReward = 0.0;
        // assumes SampleProportional yields at least numSamples elements (e.g. an unbounded stream) — TODO confirm
        var sampledComponents = Enumerable.Range(0, numComponents).SampleProportional(random, componentProb[a]).Take(numSamples);
        foreach (var k in sampledComponents) {
          sumReward += Rand.RandNormal(random) * Math.Sqrt(meanVariance[a][k]) + meanMean[a][k];
        }
        exp[a] = sumReward / (double)numSamples;
      }
      return exp;
    }

    /// <summary>Posterior update for an observed reward. Not implemented yet.</summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public void Update(int action, double reward) {
      // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
      throw new NotImplementedException();
    }

    /// <summary>
    /// Zeros the component means and variances of <paramref name="action"/>, so each draw for this
    /// arm in <see cref="SampleExpectedRewards"/> contributes 0. Mixing probabilities are left unchanged.
    /// </summary>
    public void Disable(int action) {
      Array.Clear(meanMean[action], 0, meanMean[action].Length);
      Array.Clear(meanVariance[action], 0, meanVariance[action].Length);
    }

    /// <summary>
    /// Restores every arm, including disabled ones, to the prior state established by the constructor.
    /// </summary>
    public void Reset() {
      // Reset the per-arm arrays in place: clearing the outer jagged arrays would replace every
      // per-arm array with null and make the next sampling call throw NullReferenceException.
      for (int a = 0; a < numActions; a++) {
        SetArmToPrior(a);
      }
    }

    /// <summary>Not implemented yet.</summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public void PrintStats() {
      throw new NotImplementedException();
    }

    // Writes the prior (mean = priorMean, variance = 1, uniform mixing weights) into the existing arrays of one arm.
    private void SetArmToPrior(int action) {
      // TODO: probably need to initialize this randomly to allow learning
      for (int k = 0; k < numComponents; k++) {
        meanMean[action][k] = priorMean;
        meanVariance[action][k] = 1.0; // prior variance of mean = 1
        componentProb[action][k] = 1.0 / numComponents; // uniform prior for component probabilities
      }
    }
  }
}