using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits.Models {
  /// <summary>
  /// Reward model representing the reward distribution as a one-dimensional mixture of
  /// Gaussians, each with a mixing weight, a mean and a variance.
  /// </summary>
  /// <remarks>
  /// Fitting is not yet supported: <see cref="Update"/> always throws
  /// <see cref="NotSupportedException"/>. The intended soft k-means / EM fitting logic is kept in
  /// private helpers so it can be enabled once it has been validated.
  /// </remarks>
  public class GaussianMixtureModel : IModel {
    // Lower bound for component variances; keeps densities finite and avoids division by zero.
    private const double MinVariance = 0.001;
    // Lower bound for an unnormalized responsibility; guarantees a non-zero normalizer and
    // therefore strictly positive per-component responsibility sums in the M-step.
    private const double MinResponsibility = 0.0001;
    // Number of EM iterations per refit.
    private const int FitIterations = 20;
    // The mixture is refit every RefitInterval rewards while fewer than MaxRewardsForRefit were seen.
    private const int RefitInterval = 10;
    private const int MaxRewardsForRefit = 1000;

    private double[] componentMeans;
    private double[] componentVars;
    private double[] componentProbs;
    private readonly List<double> allRewards = new List<double>();
    private readonly int numComponents;

    /// <summary>Creates a mixture with <paramref name="nComponents"/> randomly initialized components.</summary>
    /// <param name="nComponents">Number of Gaussian components; must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="nComponents"/> is less than 1.</exception>
    public GaussianMixtureModel(int nComponents = 5) {
      if (nComponents < 1) {
        throw new ArgumentOutOfRangeException("nComponents", "At least one mixture component is required.");
      }
      this.numComponents = nComponents;
      Reset();
    }

    /// <summary>
    /// Draws one sample from the mixture. First it picks a component with probability proportional
    /// to its mixing weight, then it draws from that component's normal distribution using alglib's
    /// inverse normal CDF.
    /// </summary>
    /// <param name="random">Source of randomness for both the component choice and the draw.</param>
    /// <returns>A sampled reward value.</returns>
    public double SampleExpectedReward(Random random) {
      var k = Enumerable.Range(0, numComponents).SampleProportional(random, componentProbs).First();
      return alglib.invnormaldistribution(random.NextDouble()) * Math.Sqrt(componentVars[k]) + componentMeans[k];
    }

    /// <summary>Not yet supported.</summary>
    /// <param name="reward">The observed reward (currently ignored).</param>
    /// <exception cref="NotSupportedException">Always; the mixture fit has not been validated yet.</exception>
    public void Update(double reward) {
      // Fail before touching any state. To enable fitting, replace this line with UpdateCore(reward).
      throw new NotSupportedException("this does not yet work");
    }

    // Records a reward and periodically refits the mixture from scratch on all recorded rewards.
    // Currently not called because Update throws; kept for when fitting is enabled.
    private void UpdateCore(double reward) {
      allRewards.Add(reward);
      if (allRewards.Count < MaxRewardsForRefit && allRewards.Count % RefitInterval == 0) {
        Fit();
      }
    }

    // Soft k-means / EM fit of a 1-D Gaussian mixture, see
    // http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2.
    // Restarts from a random initialization (Reset) and runs FitIterations EM steps over all rewards.
    // Precondition: allRewards is non-empty (UpdateCore only calls this after >= RefitInterval rewards).
    private void Fit() {
      Reset();
      var rewards = allRewards.ToArray();
      var n = rewards.Length;
      for (int iter = 0; iter < FitIterations; iter++) {
        // E-step: responsibility of every component for every reward.
        var responsibilities = new double[n][];
        for (int i = 0; i < n; i++) {
          responsibilities[i] = CalcResponsibility(rewards[i]);
        }

        // M-step (means): responsibility-weighted average of the rewards per component.
        var sumResponsibilities = new double[numComponents];
        var sumWeightedRewards = new double[numComponents];
        for (int i = 0; i < n; i++) {
          for (int k = 0; k < numComponents; k++) {
            sumResponsibilities[k] += responsibilities[i][k];
            sumWeightedRewards[k] += responsibilities[i][k] * rewards[i];
          }
        }
        for (int k = 0; k < numComponents; k++) {
          componentMeans[k] = sumWeightedRewards[k] / sumResponsibilities[k];
        }

        // M-step (variances, weights): weighted squared deviation around the updated means,
        // and each component's share of the total responsibility.
        var sumWeightedSqDev = new double[numComponents];
        for (int i = 0; i < n; i++) {
          for (int k = 0; k < numComponents; k++) {
            var d = rewards[i] - componentMeans[k];
            sumWeightedSqDev[k] += responsibilities[i][k] * d * d;
          }
        }
        var totalResponsibility = sumResponsibilities.Sum();
        for (int k = 0; k < numComponents; k++) {
          componentVars[k] = Math.Max(sumWeightedSqDev[k] / sumResponsibilities[k], MinVariance);
          componentProbs[k] = sumResponsibilities[k] / totalResponsibility;
        }
      }
    }

    // Posterior probability of each component given reward r under the current parameters:
    // res[k] is proportional to componentProbs[k] * N(r | componentMeans[k], componentVars[k]).
    // Unnormalized values are floored at MinResponsibility so the normalizer is never zero.
    // Does not modify the model's parameters.
    private double[] CalcResponsibility(double r) {
      var res = new double[numComponents];
      for (int k = 0; k < numComponents; k++) {
        var sd = Math.Sqrt(Math.Max(componentVars[k], MinVariance));
        var z = (r - componentMeans[k]) / sd;
        // Gaussian density (not the CDF), including the 1/sd factor so that components with
        // different variances are weighted correctly: exp(-z^2/2) / (sqrt(2*pi) * sd).
        var density = Math.Exp(-0.5 * z * z) / (Math.Sqrt(2.0 * Math.PI) * sd);
        res[k] = Math.Max(componentProbs[k] * density, MinResponsibility);
      }
      var sum = res.Sum();
      for (int k = 0; k < numComponents; k++) {
        res[k] /= sum;
      }
      return res;
    }

    /// <summary>
    /// Collapses every component to mean 0 and variance 0, so subsequent samples are 0.
    /// Mixing weights are left unchanged.
    /// </summary>
    public void Disable() {
      Array.Clear(componentMeans, 0, numComponents);
      Array.Clear(componentVars, 0, numComponents);
    }

    /// <summary>
    /// Returns a new, freshly initialized model with the same number of components.
    /// Current parameters and recorded rewards are not copied.
    /// </summary>
    public object Clone() {
      return new GaussianMixtureModel(numComponents);
    }

    /// <summary>
    /// Re-initializes all components. Mixing weights are random and normalized to sum to 1, means are
    /// drawn with <c>Rand.RandNormal</c>, and every variance is set to 0.01. Recorded rewards are kept.
    /// </summary>
    public void Reset() {
      // NOTE(review): on .NET Framework, `new Random()` is seeded from the system clock, so models
      // reset within the same clock tick may receive identical initial parameters.
      var rand = new Random();
      this.componentProbs = Enumerable.Range(0, numComponents).Select(_ => rand.NextDouble()).ToArray();
      var sum = componentProbs.Sum();
      for (int i = 0; i < componentProbs.Length; i++) {
        componentProbs[i] /= sum;
      }
      this.componentMeans = Enumerable.Range(0, numComponents).Select(_ => Rand.RandNormal(rand)).ToArray();
      this.componentVars = Enumerable.Repeat(0.01, numComponents).ToArray();
    }

    /// <summary>Not implemented.</summary>
    /// <exception cref="NotImplementedException">Always.</exception>
    public void PrintStats() {
      throw new NotImplementedException();
    }
  }
}