using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
  [StorableClass]
  public class GbtApproximateStateValueFunction : ParameterizedNamedItem, IStateValueFunction {

    // encapsulates a feature list as a state (that is equatable)
    [StorableClass]
    private class FeatureState : IEquatable<FeatureState> {
      private SortedList<string, double> features; // readonly

      public FeatureState(IEnumerable<Tuple<string, double>> features) {
        this.features = new SortedList<string, double>(features.ToDictionary(t => t.Item1, t => t.Item2));
      }

      public IEnumerable<string> GetFeatures() {
        return features.Keys;
      }

      // features are sparse: features that are not present have value 0.0
      public double GetValue(string featureName) {
        double d;
        if (!features.TryGetValue(featureName, out d)) return 0.0;
        else return d;
      }

      public bool Equals(FeatureState other) {
        if (other == null) return false;
        if (other.features.Count != this.features.Count) return false;
        var f0 = this.features.GetEnumerator();
        var f1 = other.features.GetEnumerator();
        while (f0.MoveNext() & f1.MoveNext()) {
          if (f0.Current.Key != f1.Current.Key ||
              Math.Abs(f0.Current.Value - f1.Current.Value) > 1E-6) return false;
        }
        return true;
      }

      public override bool Equals(object obj) {
        return this.Equals(obj as FeatureState);
      }

      public override int GetHashCode() {
        // TODO perf
        return string.Join("", features.Keys).GetHashCode();
      }
    }

    // calculates a sparse list of features
    [StorableClass]
    [Item("FeaturesFunction", "")]
    private class FeaturesFunction : Item, IStateFunction {
      public FeaturesFunction() { }
      private FeaturesFunction(FeaturesFunction original, Cloner cloner) : base(original, cloner) { }
      [StorableConstructor]
      protected FeaturesFunction(bool deserializing) : base(deserializing) { }

      public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
        // terminal frequencies
        return new FeatureState(actions.GroupBy(a => a.Name).Select(g => Tuple.Create(g.Key, g.Count() / (double)actions.Count)));
      }

      public override IDeepCloneable Clone(Cloner cloner) {
        return new FeaturesFunction(this, cloner);
      }
    }

    [Storable]
    private readonly ITabularStateValueFunction observedValues = new TabularMaxStateValueFunction(); // we only store the max observed value for each feature state
    [Storable]
    private readonly HashSet<FeatureState> observedStates = new HashSet<FeatureState>();
    [Storable]
    private readonly Dictionary<string, int> observedFeatures = new Dictionary<string, int>();

    public IStateFunction StateFunction {
      get { return ((IValueParameter<IStateFunction>)Parameters["Feature function"]).Value; }
      private set { ((IValueParameter<IStateFunction>)Parameters["Feature function"]).Value = value; }
    }

    public GbtApproximateStateValueFunction() : base() {
      Parameters.Add(new ValueParameter<IStateFunction>("Feature function", "The function that generates features for value approximation", new FeaturesFunction()));
    }

    private IRegressionModel model;
    private ModifiableDataset applicationDs;

    public double Value(object state) {
      if (model == null) return double.PositiveInfinity; // no function approximation yet

      // init the row
      var featureState = state as FeatureState;
      foreach (var variableName in applicationDs.VariableNames)
        applicationDs.SetVariableValue(featureState.GetValue(variableName), variableName, 0); // only one row

      return model.GetEstimatedValues(applicationDs, new int[] { 0 }).First();
    }
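
    // Update() accumulates observed qualities per feature state (only the maximum
    // per state is kept in observedValues) and, once enough new states have been
    // collected, retrains the gradient boosted trees surrogate queried by Value().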
    private int newStatesUntilModelUpdate = 1000; // the first model is trained after 1000 distinct states have been observed

    public virtual void Update(object state, double observedQuality) {
      // update dataset of observed values
      var featureState = state as FeatureState;
      if (!observedStates.Contains(featureState)) {
        newStatesUntilModelUpdate--;
        observedStates.Add(featureState);
      }
      observedValues.Update(state, observedQuality);

      // count how often each feature has been observed (used for feature selection below)
      foreach (var f in featureState.GetFeatures()) {
        if (observedFeatures.ContainsKey(f)) observedFeatures[f]++;
        else observedFeatures.Add(f, 1);
      }

      if (newStatesUntilModelUpdate == 0) {
        newStatesUntilModelUpdate = 100; // update model after 100 new states have been observed

        // build a training dataset: the target column followed by the 50 most frequently observed features
        var variableNames = new string[] { "target" }
          .Concat(observedFeatures.OrderByDescending(e => e.Value).Select(e => e.Key).Take(50))
          .ToArray();
        var rows = observedStates.Count;
        var cols = variableNames.Length;
        var variableValues = new double[rows, cols];
        int r = 0;
        foreach (var obsState in observedStates) {
          variableValues[r, 0] = observedValues.Value(obsState);
          for (int c = 1; c < cols; c++) {
            variableValues[r, c] = obsState.GetValue(variableNames[c]);
          }
          r++;
        }

        var trainingDs = new Dataset(variableNames, variableValues);
        applicationDs = new ModifiableDataset(variableNames, variableNames.Select(i => new List<double>(new[] { 0.0 }))); // init one row with zero values
        var problemData = new RegressionProblemData(trainingDs, variableNames.Skip(1), variableNames.First());

        var solution = GradientBoostedTreesAlgorithmStatic.TrainGbm(problemData, new SquaredErrorLoss(), 30, 0.01, 0.5, 0.5, 200);
        //var rmsError = 0.0;
        //var cvRmsError = 0.0;
        //var solution = (SymbolicRegressionSolution)LinearRegression.CreateLinearRegressionSolution(problemData, out rmsError, out cvRmsError);
        model = solution.Model;
      }
    }

    #region item
    [StorableConstructor]
    protected GbtApproximateStateValueFunction(bool deserializing) : base(deserializing) { }
    protected GbtApproximateStateValueFunction(GbtApproximateStateValueFunction original, Cloner cloner)
      : base(original, cloner) {
      // TODO
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GbtApproximateStateValueFunction(this, cloner);
    }
    #endregion

    public void InitializeState() {
      ClearState();
    }

    public void ClearState() { }
  }
}