Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianMixtureModel.cs @ 11744

Last change to this file as of revision 11744 was revision 11744, checked in by gkronber 9 years ago.

#2283 worked on TD, and models for MCTS

File size: 1.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8
namespace HeuristicLab.Algorithms.Bandits.Models {
  /// <summary>
  /// Reward model built from a finite mixture of univariate Gaussian components,
  /// each described by a mean, a variance and a mixture weight.
  /// </summary>
  /// <remarks>
  /// Work in progress: <see cref="Update"/> and <see cref="PrintStats"/> are not implemented yet,
  /// so nothing in this class ever changes the component parameters away from zero.
  /// </remarks>
  public class GaussianMixtureModel : IModel {
    // Per-component parameters; all three arrays have length numComponents and share the same index.
    private readonly double[] componentMeans;
    private readonly double[] componentVars;
    private readonly double[] componentProbs;

    // Only assigned in the constructor, so it is readonly like the arrays it sizes.
    private readonly int numComponents;

    /// <summary>
    /// Creates a mixture whose component means, variances and weights are all initialised to zero.
    /// </summary>
    /// <param name="nComponents">Number of Gaussian components; must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="nComponents"/> is less than 1.</exception>
    public GaussianMixtureModel(int nComponents = 5) {
      // With zero components SampleExpectedReward has no index to draw from, and a negative
      // count would otherwise only fail later with an opaque OverflowException from the array allocation.
      if (nComponents < 1) {
        throw new ArgumentOutOfRangeException("nComponents", nComponents, "A Gaussian mixture needs at least one component.");
      }
      this.numComponents = nComponents;
      this.componentProbs = new double[nComponents];
      this.componentMeans = new double[nComponents];
      this.componentVars = new double[nComponents];
    }


    /// <summary>
    /// Draws one value from the mixture: first a component index k is chosen, then a value is
    /// drawn from that component's Gaussian with mean componentMeans[k] and variance componentVars[k].
    /// </summary>
    /// <param name="random">Source of randomness for both the component choice and the Gaussian draw.</param>
    /// <returns>The sampled value.</returns>
    public double SampleExpectedReward(Random random) {
      // Component choice is delegated to HeuristicLab.Common's SampleProportional, presumably drawing
      // an index with probability proportional to componentProbs.
      // NOTE(review): every weight is zero until Update is implemented — confirm how SampleProportional handles that.
      var k = Enumerable.Range(0, numComponents).SampleProportional(random, componentProbs).First();
      // Inverse-CDF transform: the uniform draw is mapped to a standard normal value,
      // which is then scaled by the component's standard deviation and shifted by its mean.
      return alglib.invnormaldistribution(random.NextDouble()) * Math.Sqrt(componentVars[k]) + componentMeans[k];
    }

    /// <summary>
    /// Intended to incorporate an observed reward into the mixture parameters. Not implemented yet.
    /// </summary>
    /// <param name="reward">The observed reward.</param>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public void Update(double reward) {
      // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
      throw new NotImplementedException();
    }

    /// <summary>
    /// Sets the mean and variance of every component to zero. The mixture weights are left unchanged.
    /// </summary>
    public void Disable() {
      Array.Clear(componentMeans, 0, numComponents);
      Array.Clear(componentVars, 0, numComponents);
    }

    /// <summary>
    /// Creates a new model with the same number of components.
    /// </summary>
    /// <remarks>
    /// Only the component count is carried over; the new instance starts with all means, variances
    /// and weights at zero, regardless of this instance's current values.
    /// NOTE(review): confirm callers use Clone as a fresh-model factory rather than expecting a state copy.
    /// </remarks>
    /// <returns>A new <see cref="GaussianMixtureModel"/>.</returns>
    public object Clone() {
      return new GaussianMixtureModel(numComponents);
    }

    /// <summary>
    /// Sets all component means, variances and mixture weights back to zero.
    /// </summary>
    public void Reset() {
      Array.Clear(componentMeans, 0, numComponents);
      Array.Clear(componentVars, 0, numComponents);
      Array.Clear(componentProbs, 0, numComponents);
    }

    /// <summary>
    /// Intended to print the model's statistics. Not implemented yet.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public void PrintStats() {
      throw new NotImplementedException();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.