Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianMixtureModel.cs @ 11747

Last change to this file as of revision 11747: revision 11747, checked in by gkronber, 9 years ago

#2283: implemented test problems for MCTS

File size: 3.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8
namespace HeuristicLab.Algorithms.Bandits.Models {
  /// <summary>
  /// Reward model that describes the rewards of an arm as a one-dimensional mixture of
  /// Gaussian components. The mixture is periodically refitted to all observed rewards with
  /// soft k-means / expectation maximization, see MacKay, "Information Theory, Inference, and
  /// Learning Algorithms", Algorithm 22.2 (http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf).
  /// </summary>
  public class GaussianMixtureModel : IModel {
    // The mixture is refitted after every RefitInterval-th reward as long as fewer than
    // MaxRefitRewards rewards have been observed; afterwards the parameters stay fixed.
    private const int RefitInterval = 10;
    private const int MaxRefitRewards = 1000;
    // Number of EM (soft k-means) iterations per refit.
    private const int EmIterations = 20;
    // Lower bound for component variances. Without it a component can collapse onto a single
    // reward: its variance goes to zero and its density becomes unbounded.
    private const double MinVariance = 0.001;
    // Components whose total responsibility is below this value keep their previous mean and
    // variance instead of dividing by (almost) zero.
    private const double MinResponsibility = 1e-10;
    // Variance every component starts with when the mixture is (re)initialized.
    private const double InitialVariance = 0.01;

    // Creating `new Random()` repeatedly can yield identical clock-based seeds when the calls
    // happen in quick succession (.NET Framework), which would give every model and every refit
    // the same starting components. Each instance therefore draws its seed from this shared
    // generator (guarded by a lock because Random is not thread-safe).
    private static readonly Random seedSource = new Random();

    // Per-instance generator used only to initialize the components before EM.
    private readonly Random initRandom = CreateRandom();
    private readonly List<double> allRewards = new List<double>();
    private readonly int numComponents;

    private double[] componentMeans;
    private double[] componentVars;
    private double[] componentProbs;

    /// <summary>
    /// Creates a mixture model with <paramref name="nComponents"/> randomly initialized components.
    /// </summary>
    /// <param name="nComponents">Number of Gaussian components; must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// If <paramref name="nComponents"/> is smaller than 1.
    /// </exception>
    public GaussianMixtureModel(int nComponents = 5) {
      if (nComponents < 1)
        throw new ArgumentOutOfRangeException("nComponents", "A mixture needs at least one component.");
      this.numComponents = nComponents;

      Reset();
    }

    /// <summary>
    /// Draws a value from the current mixture. A component index is chosen with
    /// <c>SampleProportional</c> weighted by the component probabilities. The value is then
    /// generated from that component's normal distribution by transforming a uniform draw
    /// with the inverse standard-normal CDF.
    /// </summary>
    /// <param name="random">Random number generator used for both draws.</param>
    /// <returns>The sampled reward.</returns>
    public double SampleExpectedReward(Random random) {
      var k = Enumerable.Range(0, numComponents).SampleProportional(random, componentProbs).First();
      return alglib.invnormaldistribution(random.NextDouble()) * Math.Sqrt(componentVars[k]) + componentMeans[k];
    }

    /// <summary>
    /// Records an observed reward. After every <c>RefitInterval</c>-th reward (while fewer than
    /// <c>MaxRefitRewards</c> rewards have been seen) the mixture is refitted to all rewards.
    /// </summary>
    /// <param name="reward">The observed reward.</param>
    public void Update(double reward) {
      allRewards.Add(reward);
      if (allRewards.Count < MaxRefitRewards && allRewards.Count % RefitInterval == 0) {
        FitMixture();
      }
    }

    // Soft k-means / EM: restarts from freshly initialized components and alternates between
    // computing responsibilities (E-step) and re-estimating means, variances and weights (M-step).
    // Only called from Update right after a reward was added, so allRewards is never empty here.
    private void FitMixture() {
      InitializeComponents();
      int n = allRewards.Count;
      var responsibilities = new double[n][];

      for (int iteration = 0; iteration < EmIterations; iteration++) {
        // E-step: posterior probability of each component for each reward.
        for (int i = 0; i < n; i++) {
          responsibilities[i] = CalcResponsibility(allRewards[i]);
        }

        // M-step (means): responsibility-weighted average of the rewards.
        var sumResponsibilities = new double[numComponents];
        var sumWeightedRewards = new double[numComponents];
        for (int i = 0; i < n; i++) {
          for (int k = 0; k < numComponents; k++) {
            sumResponsibilities[k] += responsibilities[i][k];
            sumWeightedRewards[k] += responsibilities[i][k] * allRewards[i];
          }
        }
        for (int k = 0; k < numComponents; k++) {
          if (sumResponsibilities[k] > MinResponsibility)
            componentMeans[k] = sumWeightedRewards[k] / sumResponsibilities[k];
        }

        // M-step (variances): weighted squared deviation around the updated means.
        var sumWeightedSquares = new double[numComponents];
        for (int i = 0; i < n; i++) {
          for (int k = 0; k < numComponents; k++) {
            var diff = allRewards[i] - componentMeans[k];
            sumWeightedSquares[k] += responsibilities[i][k] * diff * diff;
          }
        }

        // Every responsibility vector sums to one, so this total equals n (> 0).
        var totalResponsibility = sumResponsibilities.Sum();
        for (int k = 0; k < numComponents; k++) {
          if (sumResponsibilities[k] > MinResponsibility)
            componentVars[k] = Math.Max(sumWeightedSquares[k] / sumResponsibilities[k], MinVariance);
          componentProbs[k] = sumResponsibilities[k] / totalResponsibility;
        }
      }
    }

    // Returns the normalized responsibilities p(k | r) ∝ π_k · N(r | μ_k, σ_k²) of all components
    // for reward r. Computed in the log domain and shifted by the maximum (log-sum-exp) so that
    // rewards far away from every component do not underflow to an all-zero vector.
    // Does not modify the model's parameters.
    private double[] CalcResponsibility(double r) {
      var res = new double[numComponents];
      var maxLog = double.NegativeInfinity;
      for (int k = 0; k < numComponents; k++) {
        var variance = Math.Max(componentVars[k], MinVariance);
        var diff = r - componentMeans[k];
        // log(π_k) + log of the Gaussian density (not the CDF) at r
        res[k] = Math.Log(componentProbs[k])
                 - 0.5 * Math.Log(2.0 * Math.PI * variance)
                 - diff * diff / (2.0 * variance);
        if (res[k] > maxLog) maxLog = res[k];
      }

      if (double.IsNegativeInfinity(maxLog) || double.IsNaN(maxLog)) {
        // No component assigns any density to r (e.g. r is infinite): spread it uniformly.
        for (int k = 0; k < numComponents; k++) res[k] = 1.0 / numComponents;
        return res;
      }

      var sum = 0.0;
      for (int k = 0; k < numComponents; k++) {
        res[k] = Math.Exp(res[k] - maxLog);
        sum += res[k];
      }
      for (int k = 0; k < numComponents; k++) {
        res[k] /= sum;
      }
      return res;
    }

    /// <summary>
    /// Sets all component means and variances to zero, so that <see cref="SampleExpectedReward"/>
    /// returns 0 until the components are re-initialized (by <see cref="Reset"/> or by a refit
    /// triggered in <see cref="Update"/>).
    /// </summary>
    public void Disable() {
      Array.Clear(componentMeans, 0, numComponents);
      Array.Clear(componentVars, 0, numComponents);
    }

    /// <summary>
    /// Returns a new, untrained model with the same number of components. Observed rewards and
    /// fitted parameters are not copied.
    /// </summary>
    public object Clone() {
      return new GaussianMixtureModel(numComponents);
    }

    /// <summary>
    /// Forgets all observed rewards and re-initializes the components with random parameters.
    /// </summary>
    public void Reset() {
      allRewards.Clear();
      InitializeComponents();
    }

    // Random starting point for EM: random normalized weights, means drawn with Rand.RandNormal
    // and a small common variance. Keeps the observed rewards.
    private void InitializeComponents() {
      componentProbs = Enumerable.Range(0, numComponents).Select(_ => initRandom.NextDouble()).ToArray();
      var sum = componentProbs.Sum();
      for (int i = 0; i < componentProbs.Length; i++) {
        // NextDouble may return 0; fall back to uniform weights if every draw did.
        componentProbs[i] = sum > 0 ? componentProbs[i] / sum : 1.0 / numComponents;
      }
      componentMeans = Enumerable.Range(0, numComponents).Select(_ => Rand.RandNormal(initRandom)).ToArray();
      componentVars = Enumerable.Repeat(InitialVariance, numComponents).ToArray();
    }

    /// <summary>
    /// Writes "weight:mean(stddev)" for every component to the console, without a trailing newline.
    /// </summary>
    public void PrintStats() {
      // NOTE(review): output format not coordinated with the other IModel implementations — confirm.
      for (int k = 0; k < numComponents; k++) {
        Console.Write("{0:F2}:{1:F2}({2:F2}) ", componentProbs[k], componentMeans[k], Math.Sqrt(componentVars[k]));
      }
    }

    // Creates a per-instance generator seeded from the shared, lock-protected seed source.
    private static Random CreateRandom() {
      lock (seedSource) {
        return new Random(seedSource.Next());
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.