1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using HeuristicLab.Algorithms.DataAnalysis;
5 | using HeuristicLab.Common;
6 | using HeuristicLab.Core;
7 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8 | using HeuristicLab.Parameters;
9 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
10 | using HeuristicLab.Problems.DataAnalysis;
11 |
12 | namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
13 | [StorableClass]
14 | public class GbtApproximateStateValueFunction : ParameterizedNamedItem, IStateValueFunction {
// Encapsulates a sparse feature list (feature name -> value) as an equatable state.
// Features that are not contained in the list are treated as having the value 0.0.
// Instances are not modified after construction.
[StorableClass]
private class FeatureState : IEquatable<FeatureState> {
  // Two feature values are treated as equal when they differ by no more than this tolerance.
  private const double ValueTolerance = 1E-6;

  // Sorted by feature name so that two states can be compared entry by entry in Equals.
  // Marked as storable because feature states are held in a [Storable] collection of the owning class.
  [Storable]
  private readonly SortedList<string, double> features;

  // Constructor for the persistence framework; mirrors the [StorableConstructor] pattern
  // used by the other storable classes in this file. The storable field is not initialized here.
  [StorableConstructor]
  private FeatureState(bool deserializing) { }

  // Creates a state from (feature name, value) pairs.
  // Duplicate feature names are rejected by ToDictionary (ArgumentException).
  public FeatureState(IEnumerable<Tuple<string, double>> features) {
    this.features = new SortedList<string, double>(features.ToDictionary(t => t.Item1, t => t.Item2));
  }

  // Returns the names of all features contained in this state (in sorted order).
  public IEnumerable<string> GetFeatures() {
    return features.Keys;
  }

  // Returns the value of the given feature, or 0.0 when the feature is not contained in this state.
  public double GetValue(string featureName) {
    double d;
    if (!features.TryGetValue(featureName, out d)) return 0.0;
    else return d;
  }

  // Two states are equal when they contain the same feature names and all corresponding
  // values differ by at most ValueTolerance.
  public bool Equals(FeatureState other) {
    if (ReferenceEquals(other, null)) return false;
    if (ReferenceEquals(other, this)) return true;
    if (other.features.Count != features.Count) return false;

    // both lists are sorted by key and have the same length, so they can be walked in lockstep
    using (var mine = features.GetEnumerator())
    using (var theirs = other.features.GetEnumerator()) {
      while (mine.MoveNext() && theirs.MoveNext()) {
        if (mine.Current.Key != theirs.Current.Key ||
            Math.Abs(mine.Current.Value - theirs.Current.Value) > ValueTolerance) return false;
      }
    }
    return true;
  }

  public override bool Equals(object obj) {
    return this.Equals(obj as FeatureState);
  }

  // Only the feature names contribute to the hash code: values are compared with a tolerance in Equals,
  // so including them could give equal states different hash codes.
  public override int GetHashCode() {
    unchecked {
      int hash = 17;
      foreach (var name in features.Keys) hash = hash * 31 + name.GetHashCode();
      return hash;
    }
  }
}
54 |
// State function that calculates a sparse list of features:
// one feature per distinct symbol name, holding the relative frequency of that name among the actions.
[StorableClass]
[Item("FeaturesFunction", "")]
private class FeaturesFunction : Item, IStateFunction {
  [StorableConstructor]
  protected FeaturesFunction(bool deserializing) : base(deserializing) { }

  // cloning constructor, only used by Clone
  private FeaturesFunction(FeaturesFunction original, Cloner cloner)
    : base(original, cloner) {
  }

  public FeaturesFunction() {
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new FeaturesFunction(this, cloner);
  }

  // Creates the feature state for the given list of actions.
  // Only 'actions' is used; root, parentNode and childIdx are ignored by this implementation.
  public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
    // frequency of each symbol name relative to the total number of actions
    var frequencies =
      from action in actions
      group action by action.Name into sameName
      select Tuple.Create(sameName.Key, sameName.Count() / (double)actions.Count);
    return new FeatureState(frequencies);
  }
}
75 |
// Table of observed values per feature state;
// we only store the max observed value for each feature state.
[Storable]
private readonly ITabularStateValueFunction observedValues = new TabularMaxStateValueFunction();
// Set of all distinct feature states that have been passed to Update so far.
[Storable]
private readonly HashSet<FeatureState> observedStates = new HashSet<FeatureState>();
// Maps a feature name to the number of Update calls whose state contained that feature.
[Storable]
private readonly Dictionary<string, int> observedFeatures = new Dictionary<string, int>();
82 |
// The state function stored in the "Feature function" parameter,
// i.e. the function that generates the features used for value approximation.
// The value can only be replaced from within this class.
public IStateFunction StateFunction {
  get {
    var featureFunctionParameter = (IValueParameter<IStateFunction>)Parameters["Feature function"];
    return featureFunctionParameter.Value;
  }
  private set {
    var featureFunctionParameter = (IValueParameter<IStateFunction>)Parameters["Feature function"];
    featureFunctionParameter.Value = value;
  }
}
89 |
// Creates a value function whose "Feature function" parameter is initialized with a new FeaturesFunction.
public GbtApproximateStateValueFunction() {
  var featureFunctionParameter = new ValueParameter<IStateFunction>(
    "Feature function",
    "The function that generates features for value approximation",
    new FeaturesFunction());
  Parameters.Add(featureFunctionParameter);
}
94 |
// Regression model used to estimate state values; null until the first model has been trained in Update.
private IRegressionModel model;
// Dataset with a single row that is overwritten with the features of the queried state on each call to Value.
private ModifiableDataset applicationDs;

// Returns the estimated value of the given state.
//   state: a feature state as produced by the feature function of this value function.
// Returns double.PositiveInfinity as long as no model has been trained (the state is not inspected in that case),
// otherwise the model's estimate for the state's features.
// Throws ArgumentException when a model exists and the state is null or not a feature state.
public double Value(object state) {
  if (model == null) return double.PositiveInfinity; // no function approximation yet

  var featureState = state as FeatureState;
  if (featureState == null)
    throw new ArgumentException("The state must be a non-null feature state created by the feature function.", "state");

  // init the row: every variable of the application dataset gets the state's feature value
  // (0.0 for features that the state does not contain)
  foreach (var variableName in applicationDs.VariableNames)
    applicationDs.SetVariableValue(featureState.GetValue(variableName), variableName, 0);

  // only one row
  return model.GetEstimatedValues(applicationDs, new int[] { 0 }).First();
}
108 |
109 |
// Number of previously unseen states that must be observed before the first model is trained.
private const int NewStatesUntilFirstModelUpdate = 1000;
// Number of previously unseen states that must be observed between two subsequent model updates.
private const int NewStatesBetweenModelUpdates = 100;
// Maximum number of features (the most frequently observed ones) that are used as model inputs.
private const int MaxModelInputs = 50;
// Name of the dataset column that holds the observed value of a state.
private const string TargetVariableName = "target";

// Countdown of new states until the model is (re-)trained.
private int newStatesUntilModelUpdate = NewStatesUntilFirstModelUpdate;

// Records an observed quality for the given state and re-trains the model whenever enough
// previously unseen states have been collected.
//   state: a feature state as produced by the feature function of this value function.
//   observedQuality: the quality observed for the state.
// Throws ArgumentException when the state is null or not a feature state; no internal data is changed in that case.
public virtual void Update(object state, double observedQuality) {
  var featureState = state as FeatureState;
  if (featureState == null)
    throw new ArgumentException("The state must be a non-null feature state created by the feature function.", "state");

  // update dataset of observed values;
  // HashSet.Add returns true only when the state has not been observed before
  if (observedStates.Add(featureState)) newStatesUntilModelUpdate--;
  observedValues.Update(state, observedQuality);

  // count in how many updates each feature occurred (repeated states are counted again)
  foreach (var f in featureState.GetFeatures()) {
    int count;
    observedFeatures.TryGetValue(f, out count); // count stays 0 for a feature seen for the first time
    observedFeatures[f] = count + 1;
  }

  if (newStatesUntilModelUpdate == 0) {
    newStatesUntilModelUpdate = NewStatesBetweenModelUpdates;
    UpdateModel();
  }
}

// Trains a new gradient boosted trees model on all observed states and replaces the current model.
// The inputs are the MaxModelInputs most frequently observed features, the target is the stored observed value.
private void UpdateModel() {
  var variableNames = new string[] { TargetVariableName }
    .Concat(observedFeatures.OrderByDescending(e => e.Value).Select(e => e.Key).Take(MaxModelInputs))
    .ToArray();
  int rows = observedStates.Count;
  int cols = variableNames.Length;

  // one row per observed state: column 0 is the target, the remaining columns are the feature values
  var variableValues = new double[rows, cols];
  int r = 0;
  foreach (var obsState in observedStates) {
    variableValues[r, 0] = observedValues.Value(obsState);
    for (int c = 1; c < cols; c++) {
      variableValues[r, c] = obsState.GetValue(variableNames[c]);
    }
    r++;
  }

  var trainingDs = new Dataset(variableNames, variableValues);
  var problemData = new RegressionProblemData(trainingDs, variableNames.Skip(1), variableNames[0]);
  // (a linear regression solution was tried here previously instead of gradient boosted trees)
  var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), 30, 0.01, 0.5, 0.5, 200);

  // Replace dataset and model together, and only after training succeeded, so that the model
  // is never paired with an application dataset that has a different set of variables.
  applicationDs = new ModifiableDataset(variableNames, variableNames.Select(i => new List<double>(new[] { 0.0 }))); // init one row with zero values
  model = solution.Model;
}
148 |
149 |
150 | #region item
// Constructor marked as [StorableConstructor]; it only forwards the flag to the base class.
[StorableConstructor]
protected GbtApproximateStateValueFunction(bool deserializing)
  : base(deserializing) {
}

// Cloning constructor.
// NOTE(review): only the base class constructor is invoked; none of the fields declared in this class
// is copied from 'original' (see the TODO below), so a clone does not take over the observed data or the model.
protected GbtApproximateStateValueFunction(GbtApproximateStateValueFunction original, Cloner cloner)
  : base(original, cloner) {
  // TODO
}

// Creates a clone of this instance via the cloning constructor.
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new GbtApproximateStateValueFunction(this, cloner);
  return copy;
}
160 | #endregion
161 |
// Initializes the state of this value function; this implementation only delegates to ClearState.
public void InitializeState() {
  ClearState();
}
165 |
// Does nothing in this implementation.
// NOTE(review): because the body is empty, the observed states, observed values, feature counts
// and the trained model are all kept when this method is called — confirm that this is intended.
public void ClearState() {
}
168 | }
169 | }