Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/QualityFunctions/GbtApproximateStateValueFunction.cs @ 12955

Last change on this file since 12955 was 12955, checked in by gkronber, 7 years ago

#2471: implemented deterministic BFS and DFS for iterated symbolic expression construction

File size: 6.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Algorithms.DataAnalysis;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Parameters;
9using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
10using HeuristicLab.Problems.DataAnalysis;
11
namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  /// <summary>
  /// State value function that approximates the value of a state with a gradient boosted trees (GBT) regression model.
  /// Each state is described by a sparse list of named features. The maximum quality observed for each distinct
  /// feature state is used as regression target, and the model is retrained periodically as new distinct states are observed.
  /// </summary>
  [StorableClass]
  public class GbtApproximateStateValueFunction : ParameterizedNamedItem, IStateValueFunction {
    // name of the target column in the training dataset (features with this name are excluded from the inputs)
    private const string TargetVariableName = "target";
    // number of new distinct states that must be observed before the first model is trained
    private const int NewStatesBeforeFirstModel = 1000;
    // number of new distinct states between two subsequent model updates
    private const int NewStatesBetweenModelUpdates = 100;
    // maximum number of input features (the most frequently observed ones) used for training
    private const int MaxInputFeatures = 50;

    // Encapsulates a sparse feature list (feature name -> value) as a state that is equatable.
    // The feature list is fixed at construction time.
    [StorableClass]
    private class FeatureState : IEquatable<FeatureState> {
      // NOTE(review): this field is not marked [Storable] and the class has no [StorableConstructor],
      // so feature states are presumably not persisted correctly -- confirm whether persistence is required.
      private readonly SortedList<string, double> features;

      /// <param name="features">(name, value) pairs; names must be unique.</param>
      /// <exception cref="ArgumentNullException">If <paramref name="features"/> is null.</exception>
      public FeatureState(IEnumerable<Tuple<string, double>> features) {
        if (features == null) throw new ArgumentNullException("features");
        this.features = new SortedList<string, double>(features.ToDictionary(t => t.Item1, t => t.Item2));
      }

      // Returns the names of all features present in this state, in sorted order.
      public IEnumerable<string> GetFeatures() {
        return features.Keys;
      }

      // Returns the value of the named feature, or 0.0 if the feature is not present (sparse representation).
      public double GetValue(string featureName) {
        double d;
        if (!features.TryGetValue(featureName, out d)) return 0.0;
        else return d;
      }

      // Two states are equal if they contain the same feature names and all values differ by at most 1E-6.
      // Note that the tolerance makes this relation non-transitive.
      public bool Equals(FeatureState other) {
        if (ReferenceEquals(other, null)) return false;
        if (ReferenceEquals(other, this)) return true;
        if (other.features.Count != features.Count) return false;
        var keys0 = features.Keys;
        var keys1 = other.features.Keys;
        var values0 = features.Values;
        var values1 = other.features.Values;
        // both lists are sorted by key, so equal states have their entries at the same positions
        for (int i = 0; i < keys0.Count; i++) {
          if (keys0[i] != keys1[i] ||
             Math.Abs(values0[i] - values1[i]) > 1E-6) return false;
        }
        return true;
      }

      public override bool Equals(object obj) {
        return this.Equals(obj as FeatureState);
      }

      public override int GetHashCode() {
        // Only the (sorted) feature names are hashed. Values are compared with a tolerance in Equals,
        // so including them would break the Equals/GetHashCode contract.
        unchecked {
          int hash = 17;
          foreach (var key in features.Keys) hash = hash * 31 + key.GetHashCode();
          return hash;
        }
      }
    }

    // Calculates a sparse list of features: the relative frequency of each symbol name among the given actions.
    [StorableClass]
    [Item("FeaturesFunction", "")]
    private class FeaturesFunction : Item, IStateFunction {
      public FeaturesFunction() {
      }

      private FeaturesFunction(FeaturesFunction original, Cloner cloner)
        : base(original, cloner) {
      }
      [StorableConstructor]
      protected FeaturesFunction(bool deserializing) : base(deserializing) { }

      // Creates a FeatureState mapping each distinct action symbol name to (occurrences / number of actions).
      // Only the actions are used; root, parentNode and childIdx are ignored.
      public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
        return new FeatureState(actions.GroupBy(a => a.Name).Select(g => Tuple.Create(g.Key, g.Count() / (double)actions.Count))); // terminal frequencies
      }

      public override IDeepCloneable Clone(Cloner cloner) {
        return new FeaturesFunction(this, cloner);
      }
    }

    // max observed quality for each feature state (used as regression target)
    [Storable]
    private ITabularStateValueFunction observedValues = new TabularMaxStateValueFunction();
    // all distinct feature states observed so far (one training row each)
    [Storable]
    private readonly HashSet<FeatureState> observedStates = new HashSet<FeatureState>();
    // feature name -> number of Update calls in which the feature was present
    [Storable]
    private readonly Dictionary<string, int> observedFeatures = new Dictionary<string, int>();

    public IStateFunction StateFunction {
      get {
        return ((IValueParameter<IStateFunction>)Parameters["Feature function"]).Value;
      }
      private set { ((IValueParameter<IStateFunction>)Parameters["Feature function"]).Value = value; }
    }

    public GbtApproximateStateValueFunction()
      : base() {
      Parameters.Add(new ValueParameter<IStateFunction>("Feature function", "The function that generates features for value approximation", new FeaturesFunction()));
    }

    // NOTE(review): model, applicationDs and newStatesUntilModelUpdate are not [Storable]; after deserialization
    // the model has to be retrained -- confirm this is intended.
    private IRegressionModel model;
    // single-row dataset reused to evaluate the model; always has the same columns the current model was trained on
    private ModifiableDataset applicationDs;

    /// <summary>
    /// Returns the model's estimated value for <paramref name="state"/>.
    /// Returns <see cref="double.PositiveInfinity"/> as long as no model has been trained yet.
    /// </summary>
    /// <exception cref="ArgumentException">If a model exists and <paramref name="state"/> is not a feature state.</exception>
    public double Value(object state) {
      if (model == null) return double.PositiveInfinity; // no function approximation yet

      var featureState = AsFeatureState(state);
      // init the (only) row of the application dataset with the feature values of the state
      foreach (var variableName in applicationDs.VariableNames)
        applicationDs.SetVariableValue(featureState.GetValue(variableName), variableName, 0);

      return model.GetEstimatedValues(applicationDs, new int[] { 0 }).First();
    }

    // Counts down new distinct states until the next model update.
    private int newStatesUntilModelUpdate = NewStatesBeforeFirstModel;

    /// <summary>
    /// Records <paramref name="observedQuality"/> for <paramref name="state"/>, then retrains the model.
    /// The first retraining happens after 1000 new distinct states; later ones happen every 100 new distinct states.
    /// </summary>
    /// <exception cref="ArgumentException">If <paramref name="state"/> is not a feature state; nothing is recorded in this case.</exception>
    public virtual void Update(object state, double observedQuality) {
      // validate before touching any internal state so that a bad argument cannot corrupt the observations
      var featureState = AsFeatureState(state);

      if (observedStates.Add(featureState)) newStatesUntilModelUpdate--;
      observedValues.Update(state, observedQuality);
      foreach (var f in featureState.GetFeatures()) {
        int count;
        observedFeatures.TryGetValue(f, out count); // count stays 0 for unseen features
        observedFeatures[f] = count + 1;
      }

      if (newStatesUntilModelUpdate == 0) {
        newStatesUntilModelUpdate = NewStatesBetweenModelUpdates;
        UpdateModel();
      }
    }

    // Casts a state to FeatureState, throwing ArgumentException for null or foreign state objects.
    private static FeatureState AsFeatureState(object state) {
      var featureState = state as FeatureState;
      if (featureState == null)
        throw new ArgumentException("The state must be a feature state created by the features function.", "state");
      return featureState;
    }

    // Trains a new GBT model on all distinct observed states. The inputs are the most frequently observed features;
    // the target is the max observed quality per state. Model and application dataset are replaced together only
    // after training succeeds, so they always agree on their variables.
    private void UpdateModel() {
      var inputVariables = observedFeatures
        .Where(e => e.Key != TargetVariableName) // a feature with the target's name would duplicate a column
        .OrderByDescending(e => e.Value)
        .Select(e => e.Key)
        .Take(MaxInputFeatures);
      var variableNames = new[] { TargetVariableName }.Concat(inputVariables).ToArray();

      var rows = observedStates.Count;
      var cols = variableNames.Length;
      var variableValues = new double[rows, cols];
      int r = 0;
      foreach (var obsState in observedStates) {
        variableValues[r, 0] = observedValues.Value(obsState);
        for (int c = 1; c < cols; c++) {
          variableValues[r, c] = obsState.GetValue(variableNames[c]);
        }
        r++;
      }

      var trainingDs = new Dataset(variableNames, variableValues);
      var problemData = new RegressionProblemData(trainingDs, variableNames.Skip(1), TargetVariableName);
      // hyper-parameters are passed positionally (presumably maxSize, nu, r, m, maxIterations -- confirm against TrainGbm)
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), 30, 0.01, 0.5, 0.5, 200);

      applicationDs = new ModifiableDataset(variableNames, variableNames.Select(i => new List<double>(new[] { 0.0 }))); // init one row with zero values
      model = solution.Model;
    }


    #region item
    [StorableConstructor]
    protected GbtApproximateStateValueFunction(bool deserializing) : base(deserializing) { }
    protected GbtApproximateStateValueFunction(GbtApproximateStateValueFunction original, Cloner cloner)
      : base(original, cloner) {
      // TODO: observed values, states, features and the model are not copied; the clone starts with an empty state
      // (only the parameters are cloned by the base constructor).
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GbtApproximateStateValueFunction(this, cloner);
    }
    #endregion

    // Resets all learned information (see ClearState).
    public void InitializeState() {
      ClearState();
    }

    // Discards all observations and the trained model and restarts the countdown to the first model update.
    public void ClearState() {
      observedValues = new TabularMaxStateValueFunction();
      observedStates.Clear();
      observedFeatures.Clear();
      model = null;
      applicationDs = null;
      newStatesUntilModelUpdate = NewStatesBeforeFirstModel;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.