1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using HeuristicLab.Algorithms.DataAnalysis;
|
---|
5 | using HeuristicLab.Common;
|
---|
6 | using HeuristicLab.Core;
|
---|
7 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
8 | using HeuristicLab.Parameters;
|
---|
9 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
10 | using HeuristicLab.Problems.DataAnalysis;
|
---|
11 |
|
---|
12 | namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
|
---|
13 | [StorableClass]
|
---|
14 | public class GbtApproximateStateValueFunction : ParameterizedNamedItem, IStateValueFunction {
|
---|
// Encapsulates a sparse feature list (feature name -> value) as an immutable, equatable state.
// Two states are equal when they contain the same feature names and all values agree within
// ValueTolerance; the hash code therefore depends on the names only.
// NOTE(review): the class is marked [StorableClass] but declares neither [Storable] members nor a
// [StorableConstructor] - confirm whether instances held in persisted collections survive save/load.
[StorableClass]
private class FeatureState : IEquatable<FeatureState> {
  // Two feature values that differ by at most this amount are treated as equal.
  private const double ValueTolerance = 1E-6;

  // Feature name -> value, sorted by name so that equal states line up index by index.
  private readonly SortedList<string, double> features;

  // Creates a state from (feature name, value) pairs.
  // ToDictionary throws ArgumentNullException for a null sequence and ArgumentException for
  // duplicate feature names.
  public FeatureState(IEnumerable<Tuple<string, double>> features) {
    // Ordinal ordering matches the ordinal key comparison in Equals. The default comparer is
    // culture-sensitive: it can treat ordinally distinct names as duplicates (and throw) and
    // makes the order depend on the current culture.
    this.features = new SortedList<string, double>(
      features.ToDictionary(t => t.Item1, t => t.Item2),
      StringComparer.Ordinal);
  }

  // Names of all features present in this state, in ordinal order.
  public IEnumerable<string> GetFeatures() {
    return features.Keys;
  }

  // Value of the given feature, or 0.0 when the feature is not present (sparse representation).
  public double GetValue(string featureName) {
    double value;
    return features.TryGetValue(featureName, out value) ? value : 0.0;
  }

  // Pairwise comparison of the two sorted feature lists; values are compared with a tolerance.
  public bool Equals(FeatureState other) {
    if (ReferenceEquals(other, null)) return false;
    if (ReferenceEquals(other, this)) return true;

    int count = features.Count;
    if (other.features.Count != count) return false;

    // Index access avoids allocating enumerators that would have to be disposed.
    var keys = features.Keys;
    var values = features.Values;
    var otherKeys = other.features.Keys;
    var otherValues = other.features.Values;
    for (int i = 0; i < count; i++) {
      if (!string.Equals(keys[i], otherKeys[i], StringComparison.Ordinal)) return false;
      if (Math.Abs(values[i] - otherValues[i]) > ValueTolerance) return false;
    }
    return true;
  }

  public override bool Equals(object obj) {
    return Equals(obj as FeatureState);
  }

  // Combines the hashes of the feature names only. Values are excluded on purpose because
  // Equals compares them with a tolerance, so equal states must not hash differently.
  public override int GetHashCode() {
    unchecked {
      int hash = 17;
      var keys = features.Keys;
      for (int i = 0; i < keys.Count; i++) {
        hash = hash * 31 + keys[i].GetHashCode();
      }
      return hash;
    }
  }
}
|
---|
54 |
|
---|
// Turns the list of actions taken so far into a sparse FeatureState.
[StorableClass]
[Item("FeaturesFunction", "")]
private class FeaturesFunction : Item, IStateFunction {
  [StorableConstructor]
  protected FeaturesFunction(bool deserializing) : base(deserializing) { }

  private FeaturesFunction(FeaturesFunction original, Cloner cloner)
    : base(original, cloner) {
  }

  public FeaturesFunction() {
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new FeaturesFunction(this, cloner);
  }

  // Produces one feature per distinct symbol name in actions; its value is the share of
  // actions carrying that name. root, parentNode and childIdx are not used here.
  public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
    var relativeFrequencies =
      from action in actions
      group action by action.Name into sameName
      select Tuple.Create(sameName.Key, sameName.Count() / (double)actions.Count);
    return new FeatureState(relativeFrequencies);
  }
}
|
---|
75 |
|
---|
// Quality table keyed by feature state; only the maximum observed value per state is kept.
[Storable] private readonly ITabularStateValueFunction observedValues = new TabularMaxStateValueFunction();

// All distinct feature states that have been passed to Update.
[Storable] private readonly HashSet<FeatureState> observedStates = new HashSet<FeatureState>();

// Feature name -> number of Update calls whose state contained that feature.
[Storable] private readonly Dictionary<string, int> observedFeatures = new Dictionary<string, int>();
|
---|
82 |
|
---|
// Typed access to the "Feature function" entry of the parameter collection.
private IValueParameter<IStateFunction> FeatureFunctionParameter {
  get { return (IValueParameter<IStateFunction>)Parameters["Feature function"]; }
}

// The function that generates the feature states; stored in the "Feature function" parameter.
public IStateFunction StateFunction {
  get { return FeatureFunctionParameter.Value; }
  private set { FeatureFunctionParameter.Value = value; }
}
|
---|
89 |
|
---|
// Creates a value function whose "Feature function" parameter defaults to a new FeaturesFunction.
public GbtApproximateStateValueFunction()
  : base() {
  var featureFunctionParameter = new ValueParameter<IStateFunction>(
    "Feature function",
    "The function that generates features for value approximation",
    new FeaturesFunction());
  Parameters.Add(featureFunctionParameter);
}
|
---|
94 |
|
---|
// Regression model used to estimate state values; null until Update has trained one.
// Not marked [Storable], so it is not part of the persisted state.
private IRegressionModel model;

// One-row scratch dataset with the model's variables; Value fills row 0 before each prediction.
private ModifiableDataset applicationDs;

// Estimates the value of the given state with the current model.
// Returns double.PositiveInfinity while no model has been trained yet.
// Throws ArgumentException when a model exists and state is null or not a feature state.
public double Value(object state) {
  if (model == null) return double.PositiveInfinity; // no function approximation yet

  var featureState = state as FeatureState;
  if (featureState == null)
    throw new ArgumentException("The state must be a non-null feature state created by the feature function.", "state");

  // Write the state's feature values into the single row; features absent from the state read as 0.0.
  foreach (var variableName in applicationDs.VariableNames)
    applicationDs.SetVariableValue(featureState.GetValue(variableName), variableName, 0);

  // The dataset has only one row, so the first estimate is the prediction for this state.
  return model.GetEstimatedValues(applicationDs, new int[] { 0 }).First();
}
|
---|
108 |
|
---|
109 |
|
---|
// Number of previously unseen states that still have to be observed before the model is
// (re)trained. Starts at 1000 and is reset to 100 after every training.
private int newStatesUntilModelUpdate = 1000;

// Records an observed quality for the given state and retrains the model once enough
// previously unseen states have accumulated.
// Throws ArgumentException when state is null or not a feature state; nothing is recorded then.
public virtual void Update(object state, double observedQuality) {
  // Validate before touching any bookkeeping so that a bad argument cannot leave a null
  // entry in observedStates or skew the retraining counter.
  var featureState = state as FeatureState;
  if (featureState == null)
    throw new ArgumentException("The state must be a non-null feature state created by the feature function.", "state");

  // HashSet.Add returns true only for states that were not stored yet.
  if (observedStates.Add(featureState)) newStatesUntilModelUpdate--;

  observedValues.Update(state, observedQuality);

  // Count how often each feature occurs; the most frequent ones become model inputs.
  foreach (var feature in featureState.GetFeatures()) {
    int occurrences;
    observedFeatures.TryGetValue(feature, out occurrences); // stays 0 for a new feature
    observedFeatures[feature] = occurrences + 1;
  }

  if (newStatesUntilModelUpdate <= 0) {
    newStatesUntilModelUpdate = 100; // next update after 100 more new states
    TrainModel();
  }
}

// Trains a gradient boosted trees model on all observed states (inputs: the 50 most frequently
// observed features, target: the tabulated value of the state) and publishes it for Value.
private void TrainModel() {
  // NOTE(review): a feature that is itself named "target" would occur twice in this list -
  // confirm how Dataset handles duplicate variable names.
  var variableNames = new string[] { "target" }
    .Concat(observedFeatures.OrderByDescending(e => e.Value).Select(e => e.Key).Take(50))
    .ToArray();

  int rows = observedStates.Count;
  int cols = variableNames.Length;
  var variableValues = new double[rows, cols];
  int r = 0;
  foreach (var obsState in observedStates) {
    variableValues[r, 0] = observedValues.Value(obsState); // column 0 holds the target
    for (int c = 1; c < cols; c++) {
      variableValues[r, c] = obsState.GetValue(variableNames[c]);
    }
    r++;
  }

  var trainingDs = new Dataset(variableNames, variableValues);
  var problemData = new RegressionProblemData(trainingDs, variableNames.Skip(1), variableNames.First());
  // Hyper-parameters are hard-coded; see GradientBoostedTreesAlgorithmStatic.TrainGbm for their meaning.
  var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), 30, 0.01, 0.5, 0.5, 200);

  // Replace the scratch dataset and the model together, and only after training succeeded,
  // so that Value never pairs a model with a dataset built for different variables.
  applicationDs = new ModifiableDataset(variableNames, variableNames.Select(i => new List<double>(new[] { 0.0 }))); // one row of zeros
  model = solution.Model;
}
|
---|
148 |
|
---|
149 |
|
---|
#region item
// Constructor used by the persistence framework during deserialization.
[StorableConstructor]
protected GbtApproximateStateValueFunction(bool deserializing) : base(deserializing) { }

// Cloning constructor: copies the observations, the feature statistics and the trained model
// so that a clone continues from the same learned state instead of starting empty.
protected GbtApproximateStateValueFunction(GbtApproximateStateValueFunction original, Cloner cloner)
  : base(original, cloner) {
  // NOTE(review): assumes the tabular value function is deep-cloneable and copies its table
  // when cloned - confirm against its implementation.
  observedValues = (ITabularStateValueFunction)cloner.Clone((IDeepCloneable)original.observedValues);
  // FeatureState instances are never modified after construction, so they can be shared.
  observedStates = new HashSet<FeatureState>(original.observedStates);
  observedFeatures = new Dictionary<string, int>(original.observedFeatures);
  model = cloner.Clone(original.model);
  applicationDs = cloner.Clone(original.applicationDs);
  newStatesUntilModelUpdate = original.newStatesUntilModelUpdate;
}

public override IDeepCloneable Clone(Cloner cloner) {
  return new GbtApproximateStateValueFunction(this, cloner);
}
#endregion
|
---|
161 |
|
---|
// Prepares the value function for use by discarding everything learned so far.
public void InitializeState() {
  ClearState();
}

// Discards all observations, feature statistics and the trained model. Afterwards Value
// returns double.PositiveInfinity again until Update has trained a new model.
public void ClearState() {
  // NOTE(review): assumes ITabularStateValueFunction exposes ClearState() like this class does - confirm.
  observedValues.ClearState();
  observedStates.Clear();
  observedFeatures.Clear();
  model = null;
  applicationDs = null;
  newStatesUntilModelUpdate = 1000; // same value as the field's initial value
}
|
---|
168 | }
|
---|
169 | }
|
---|