Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/12/15 21:23:01 (9 years ago)
Author:
gkronber
Message:

#2283: implemented test problems for MCTS

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianMixtureModel.cs

    r11744 r11747  
    99namespace HeuristicLab.Algorithms.Bandits.Models {
    1010  public class GaussianMixtureModel : IModel {
    11     private readonly double[] componentMeans;
    12     private readonly double[] componentVars;
    13     private readonly double[] componentProbs;
     11    private double[] componentMeans;
     12    private double[] componentVars;
     13    private double[] componentProbs;
     14    private readonly List<double> allRewards = new List<double>();
    1415
    1516    private int numComponents;
     
    1718    public GaussianMixtureModel(int nComponents = 5) {
    1819      this.numComponents = nComponents;
    19       this.componentProbs = new double[nComponents];
    20       this.componentMeans = new double[nComponents];
    21       this.componentVars = new double[nComponents];
     20
     21      Reset();
    2222    }
    2323
     
    2929
    3030    public void Update(double reward) {
    31       // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
    32       throw new NotImplementedException();
     31      allRewards.Add(reward);
     32      throw new NotSupportedException("this does not yet work");
     33      if (allRewards.Count < 1000 && allRewards.Count % 10 == 0) {
     34        // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
     35        Reset();
     36        for (int i = 0; i < 20; i++) {
     37          var responsibilities = allRewards.Select(r => CalcResponsibility(r)).ToArray();
     38
     39
     40          var sumWeightedRewards = new double[numComponents];
     41          var sumResponsibilities = new double[numComponents];
     42          foreach (var p in allRewards.Zip(responsibilities, Tuple.Create)) {
     43            for (int k = 0; k < numComponents; k++) {
     44              sumWeightedRewards[k] += p.Item2[k] * p.Item1;
     45              sumResponsibilities[k] += p.Item2[k];
     46            }
     47          }
     48          for (int k = 0; k < numComponents; k++) {
     49            componentMeans[k] = sumWeightedRewards[k] / sumResponsibilities[k];
     50          }
     51
     52          sumWeightedRewards = new double[numComponents];
     53          foreach (var p in allRewards.Zip(responsibilities, Tuple.Create)) {
     54            for (int k = 0; k < numComponents; k++) {
     55              sumWeightedRewards[k] += p.Item2[k] * Math.Pow(p.Item1 - componentMeans[k], 2);
     56            }
     57          }
     58          for (int k = 0; k < numComponents; k++) {
     59            componentVars[k] = sumWeightedRewards[k] / sumResponsibilities[k];
     60            componentProbs[k] = sumResponsibilities[k] / sumResponsibilities.Sum();
     61          }
     62        }
     63      }
     64    }
     65
     66    private double[] CalcResponsibility(double r) {
     67      var res = new double[numComponents];
     68      for (int k = 0; k < numComponents; k++) {
     69        componentVars[k] = Math.Max(componentVars[k], 0.001);
     70        res[k] = componentProbs[k] * alglib.normaldistribution((r - componentMeans[k]) / Math.Sqrt(componentVars[k]));
     71        res[k] = Math.Max(res[k], 0.0001);
     72      }
     73      var sum = res.Sum();
     74      for (int k = 0; k < numComponents; k++) {
     75        res[k] /= sum;
     76      }
     77      return res;
    3378    }
    3479
     
    4489
    4590    public void Reset() {
    46       Array.Clear(componentMeans, 0, numComponents);
    47       Array.Clear(componentVars, 0, numComponents);
    48       Array.Clear(componentProbs, 0, numComponents);
     91      var rand = new Random();
     92      this.componentProbs = Enumerable.Range(0, numComponents).Select((_) => rand.NextDouble()).ToArray();
     93      var sum = componentProbs.Sum();
     94      for (int i = 0; i < componentProbs.Length; i++) componentProbs[i] /= sum;
     95      this.componentMeans = Enumerable.Range(0, numComponents).Select((_) => Rand.RandNormal(rand)).ToArray();
     96      this.componentVars = Enumerable.Range(0, numComponents).Select((_) => 0.01).ToArray();
    4997    }
    5098
Note: See TracChangeset for help on using the changeset viewer.