1 | using System;
|
---|
2 | using System.Collections;
|
---|
3 | using System.Collections.Generic;
|
---|
4 | using System.Linq;
|
---|
5 | using HeuristicLab.Common;
|
---|
6 | using HeuristicLab.Core;
|
---|
7 | using HeuristicLab.Parameters;
|
---|
8 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
9 |
|
---|
namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  /// <summary>
  /// Linear state-value approximation over sparse features. The value of a state is the
  /// sum, over its (feature id, feature value) pairs, of feature value times the learned
  /// weight for that feature id. Feature ids that have never been updated have weight 0.
  /// Weights are trained online with the gradient-descent (delta / LMS) rule in <see cref="Update"/>.
  /// </summary>
  /// <remarks>
  /// The learned weights are runtime state only. The weight dictionary is not marked
  /// [Storable]. It is copied when the item is cloned and emptied by
  /// <see cref="InitializeState"/> and <see cref="ClearState"/>.
  /// </remarks>
  [StorableClass]
  public class SparseLinearApproximateStateValueFunction : ParameterizedNamedItem, IStateValueFunction, IStatefulItem {

    // Step size (learning rate) of the weight update in Update.
    private const double Alpha = 0.01;

    // Learned weight per feature id; a missing key is treated as weight 0 (see GetWeight).
    private readonly Dictionary<string, double> w = new Dictionary<string, double>();

    /// <summary>
    /// The function held by the "Feature function" parameter, described there as
    /// the function that generates features for value approximation.
    /// </summary>
    public IStateFunction StateFunction {
      get { return ((IValueParameter<IStateFunction>)Parameters["Feature function"]).Value; }
      private set { ((IValueParameter<IStateFunction>)Parameters["Feature function"]).Value = value; }
    }

    public SparseLinearApproximateStateValueFunction()
      : base() {
      Parameters.Add(new ValueParameter<IStateFunction>("Feature function", "The function that generates features for value approximation"));
    }

    /// <summary>
    /// Returns the approximated value of a state: the sum of feature value times the
    /// current weight of the corresponding feature id.
    /// </summary>
    /// <param name="state">The state's features as an IEnumerable&lt;Tuple&lt;string, double&gt;&gt; of (feature id, feature value) pairs.</param>
    /// <returns>The linear prediction for the state (0 for an empty feature sequence).</returns>
    /// <exception cref="InvalidProgramException">If <paramref name="state"/> is not a sequence of (string, double) tuples.</exception>
    public double Value(object state) {
      return Evaluate(GetFeatures(state));
    }

    /// <summary>
    /// Moves the weights of the state's features towards the observed quality using the
    /// delta rule: w_i += Alpha * (observedQuality - Value(state)) * x_i.
    /// </summary>
    /// <param name="state">The state's features as an IEnumerable&lt;Tuple&lt;string, double&gt;&gt; of (feature id, feature value) pairs.</param>
    /// <param name="observedQuality">The target value observed for this state.</param>
    /// <exception cref="InvalidProgramException">If <paramref name="state"/> is not a sequence of (string, double) tuples.</exception>
    public void Update(object state, double observedQuality) {
      // Materialize once: the features are enumerated twice (prediction and update), and a
      // lazily evaluated sequence could otherwise be recomputed or yield different items.
      var features = GetFeatures(state).ToList();
      var delta = observedQuality - Evaluate(features);

      foreach (var feature in features) {
        // The gradient of the linear prediction with respect to this weight is the feature value,
        // so the step is scaled by it. Previously the feature value was ignored.
        w[feature.Item1] = GetWeight(feature.Item1) + Alpha * delta * feature.Item2;
      }
    }

    // Casts the opaque state to its feature sequence or fails with the exception type callers already see.
    private static IEnumerable<Tuple<string, double>> GetFeatures(object state) {
      var features = state as IEnumerable<Tuple<string, double>>;
      if (features == null)
        throw new InvalidProgramException("Expected the state to be an IEnumerable<Tuple<string, double>> of (feature id, feature value) pairs.");
      return features;
    }

    // Linear prediction: sum of feature value times current weight.
    private double Evaluate(IEnumerable<Tuple<string, double>> features) {
      return features.Sum(f => f.Item2 * GetWeight(f.Item1));
    }

    // Current weight of a feature id; ids without a stored weight count as 0.
    private double GetWeight(string featureId) {
      double weight;
      return w.TryGetValue(featureId, out weight) ? weight : 0.0;
    }

    #region item
    [StorableConstructor]
    protected SparseLinearApproximateStateValueFunction(bool deserializing) : base(deserializing) { }
    protected SparseLinearApproximateStateValueFunction(SparseLinearApproximateStateValueFunction original, Cloner cloner)
      : base(original, cloner) {
      // The clone gets its own copy of the learned weights.
      w = new Dictionary<string, double>(original.w);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SparseLinearApproximateStateValueFunction(this, cloner);
    }
    #endregion

    /// <summary>Discards all learned weights, so every feature weight is 0 again.</summary>
    public void InitializeState() {
      w.Clear();
    }

    /// <summary>Discards all learned weights, so every feature weight is 0 again.</summary>
    public void ClearState() {
      w.Clear();
    }
  }
}