Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/QualityFunctions/GbtApproximateStateValueFunction.cs @ 12924

Last change on this file since 12924 was 12924, checked in by gkronber, 7 years ago

#2471

  • implemented a demo for value approximation with gbt and a trivial feature function
File size: 6.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Collections.Specialized;
4using System.Linq;
5using System.Runtime.InteropServices.WindowsRuntime;
6using HeuristicLab.Algorithms.DataAnalysis;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
10using HeuristicLab.Parameters;
11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
12using HeuristicLab.Problems.DataAnalysis;
13
namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  /// <summary>
  /// State value function that approximates the value of a state with a gradient boosted trees (GBT)
  /// regression model. The model is trained on the feature states observed so far and is retrained
  /// whenever a fixed number of previously unseen feature states has been collected.
  /// </summary>
  [StorableClass]
  public class GbtApproximateStateValueFunction : ParameterizedNamedItem, IStateValueFunction {
    private const string FeatureFunctionParameterName = "Feature function";
    private const string TargetVariableName = "target";

    // number of distinct new states that must be observed before the very first model is trained
    private const int InitialNewStatesUntilModelUpdate = 1000;
    // number of distinct new states that must be observed between two subsequent model updates
    private const int NewStatesBetweenModelUpdates = 100;
    // only the most frequently observed features are used as input variables of the model
    private const int MaxModelFeatures = 50;

    // encapsulates a feature list as a state (that is equatable)
    // NOTE(review): the type is marked [StorableClass] but has neither [Storable] members nor a
    // storable constructor; verify that instances can actually be persisted.
    [StorableClass]
    private class FeatureState : IEquatable<FeatureState> {
      // two feature values that differ by at most this amount are considered equal
      private const double ValueTolerance = 1E-6;

      private readonly SortedList<string, double> features;
      // the feature list never changes after construction, so the hash code is calculated only once
      private readonly int hashCode;

      public FeatureState(IEnumerable<Tuple<string, double>> features) {
        this.features = new SortedList<string, double>(features.ToDictionary(t => t.Item1, t => t.Item2));
        this.hashCode = CalculateKeyHash(this.features.Keys);
      }

      // names of all features that are present in this state (in sort order)
      public IEnumerable<string> GetFeatures() {
        return features.Keys;
      }

      // value of the given feature; features that are not present in this state have the value zero
      public double GetValue(string featureName) {
        double d;
        return features.TryGetValue(featureName, out d) ? d : 0.0;
      }

      // two states are equal if they contain the same feature names and all values match within the tolerance
      public bool Equals(FeatureState other) {
        if (ReferenceEquals(other, null)) return false;
        if (ReferenceEquals(other, this)) return true;
        if (other.features.Count != features.Count) return false;

        // both lists are sorted by key, so entries can be compared position by position
        var keys = features.Keys;
        var values = features.Values;
        var otherKeys = other.features.Keys;
        var otherValues = other.features.Values;
        for (int i = 0; i < keys.Count; i++) {
          if (keys[i] != otherKeys[i] ||
              Math.Abs(values[i] - otherValues[i]) > ValueTolerance) return false;
        }
        return true;
      }

      public override bool Equals(object obj) {
        return this.Equals(obj as FeatureState);
      }

      public override int GetHashCode() {
        return hashCode;
      }

      // Only the feature names contribute to the hash. The values are compared with a tolerance
      // in Equals and therefore must not influence the hash code.
      private static int CalculateKeyHash(IEnumerable<string> keys) {
        unchecked {
          int h = 17;
          foreach (var key in keys) h = h * 31 + key.GetHashCode();
          return h;
        }
      }
    }

    // calculates a sparse list of features
    [StorableClass]
    [Item("FeaturesFunction", "")]
    private class FeaturesFunction : Item, IStateFunction {
      public FeaturesFunction() {
      }

      private FeaturesFunction(FeaturesFunction original, Cloner cloner)
        : base(original, cloner) {
      }
      [StorableConstructor]
      protected FeaturesFunction(bool deserializing) : base(deserializing) { }

      // Creates a feature state with one feature per distinct symbol name in actions. The feature value is
      // the relative frequency of that name within actions. root, parentNode and childIdx are not used.
      public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
        double total = actions.Count;
        var frequencies = actions
          .GroupBy(a => a.Name)
          .Select(g => Tuple.Create(g.Key, g.Count() / total));
        return new FeatureState(frequencies);
      }

      public override IDeepCloneable Clone(Cloner cloner) {
        return new FeaturesFunction(this, cloner);
      }
    }

    [Storable]
    private readonly ITabularStateValueFunction observedValues = new TabularMaxStateValueFunction(); // we only store the max observed value for each feature state
    [Storable]
    private readonly HashSet<FeatureState> observedStates = new HashSet<FeatureState>();
    [Storable]
    private readonly Dictionary<string, int> observedFeatures = new Dictionary<string, int>(); // feature name -> number of updates in which the feature occurred

    /// <summary>
    /// The function that maps a partially constructed expression to the state object used by this value function.
    /// </summary>
    public IStateFunction StateFunction {
      get {
        return ((IValueParameter<IStateFunction>)Parameters[FeatureFunctionParameterName]).Value;
      }
      private set { ((IValueParameter<IStateFunction>)Parameters[FeatureFunctionParameterName]).Value = value; }
    }

    public GbtApproximateStateValueFunction()
      : base() {
      Parameters.Add(new ValueParameter<IStateFunction>(FeatureFunctionParameterName, "The function that generates features for value approximation", new FeaturesFunction()));
    }

    private IRegressionModel model;          // null until the first model has been trained
    private ModifiableDataset applicationDs; // single-row dataset that is reused to evaluate the model
    private int newStatesUntilModelUpdate = InitialNewStatesUntilModelUpdate;

    /// <summary>
    /// Returns the value estimated by the current model for the given state.
    /// </summary>
    /// <param name="state">A state created by the feature function of this value function.</param>
    /// <returns>The model estimate, or positive infinity as long as no model has been trained.</returns>
    /// <exception cref="ArgumentException">A model exists and state is null or not a feature state.</exception>
    public double Value(object state) {
      if (model == null) return double.PositiveInfinity; // no function approximation yet

      var featureState = ToFeatureState(state);

      // init the row
      foreach (var variableName in applicationDs.VariableNames)
        applicationDs.SetVariableValue(featureState.GetValue(variableName), variableName, 0);

      // only one row
      return model.GetEstimatedValues(applicationDs, new int[] { 0 }).First();
    }

    /// <summary>
    /// Records an observed quality for the given state and retrains the model once enough
    /// previously unseen states have been collected.
    /// </summary>
    /// <param name="state">A state created by the feature function of this value function.</param>
    /// <param name="observedQuality">The quality that has been observed for the state.</param>
    /// <exception cref="ArgumentException">state is null or not a feature state.</exception>
    public virtual void Update(object state, double observedQuality) {
      // validate before any bookkeeping is touched so that an invalid state cannot corrupt the observations
      var featureState = ToFeatureState(state);

      // update dataset of observed values
      if (observedStates.Add(featureState)) newStatesUntilModelUpdate--;
      observedValues.Update(state, observedQuality);
      foreach (var f in featureState.GetFeatures()) {
        int count;
        observedFeatures.TryGetValue(f, out count); // count stays zero for features that have not been seen before
        observedFeatures[f] = count + 1;
      }

      if (newStatesUntilModelUpdate <= 0) {
        newStatesUntilModelUpdate = NewStatesBetweenModelUpdates;
        TrainModel();
      }
    }

    // Trains a new model on all observed states. The target is the value stored for the state and the
    // inputs are the most frequently observed features.
    private void TrainModel() {
      var variableNames = new string[] { TargetVariableName }
        .Concat(observedFeatures.OrderByDescending(e => e.Value).Select(e => e.Key).Take(MaxModelFeatures))
        .ToArray();
      var rows = observedStates.Count;
      var cols = variableNames.Length;
      var variableValues = new double[rows, cols];
      int r = 0;
      foreach (var obsState in observedStates) {
        variableValues[r, 0] = observedValues.Value(obsState); // column 0 is the target
        for (int c = 1; c < cols; c++) {
          variableValues[r, c] = obsState.GetValue(variableNames[c]);
        }
        r++;
      }
      var trainingDs = new Dataset(variableNames, variableValues);
      var problemData = new RegressionProblemData(trainingDs, variableNames.Skip(1), variableNames.First());
      // NOTE(review): the numeric arguments are presumably maxSize, nu, r, m and maxIterations; confirm against TrainGbm
      var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), 30, 0.01, 0.5, 0.5, 200);

      // The dataset and the model are replaced together, and only after training has succeeded, so that
      // Value never combines an old model with a dataset for a different set of variables.
      applicationDs = new ModifiableDataset(variableNames, variableNames.Select(i => new List<double>(new[] { 0.0 }))); // init one row with zero values
      model = solution.Model;
    }

    // Casts the state to the feature state type used by this value function and rejects everything else.
    private static FeatureState ToFeatureState(object state) {
      var featureState = state as FeatureState;
      if (featureState == null)
        throw new ArgumentException("The state must be a non-null feature state created by the feature function of this value function.", "state");
      return featureState;
    }

    #region item
    [StorableConstructor]
    protected GbtApproximateStateValueFunction(bool deserializing) : base(deserializing) { }
    protected GbtApproximateStateValueFunction(GbtApproximateStateValueFunction original, Cloner cloner)
      : base(original, cloner) {
      // TODO: the observed values, states, features and the trained model are not copied,
      // so a clone starts with the field initializer values (no observations, no model)
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GbtApproximateStateValueFunction(this, cloner);
    }
    #endregion

    public void InitializeState() {
      ClearState();
    }

    public void ClearState() {
      // NOTE(review): intentionally left empty as in the original; observations and the trained model
      // are kept across calls. Confirm whether they should be reset here.
    }

  }


}
Note: See TracBrowser for help on using the repository browser.