using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
using HeuristicLab.Collections;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Symbolic regression by enumeration: starting from the grammar's start symbol, phrases are
  /// repeatedly fetched from a search data store and their next nonterminal symbol is replaced by
  /// every applicable production. Fully derived phrases (sentences) are reported through events so
  /// that registered analyzers can evaluate or log them.
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region properties and result names
    // Names of the entries shown in the result collection.
    private const string SearchStructureSizeName = "Search Structure Size";
    private const string GeneratedPhrasesName = "Generated/Archived Phrases";
    private const string GeneratedSentencesName = "Generated Sentences";
    private const string DistinctSentencesName = "Distinct Sentences";
    private const string PhraseExpansionsName = "Phrase Expansions";
    private const string AverageSentenceComplexityName = "Avg. Sentence Complexity among Distinct";
    private const string OverwrittenSentencesName = "Sentences overwritten";
    private const string ExpansionsPerSecondName = "Expansions per second";

    // Names of the algorithm parameters.
    private const string AnalyzersParameterName = "Analyzers";
    private const string OptimizeConstantsParameterName = "Optimize Constants";
    private const string ErrorWeightParameterName = "Error Weight";
    private const string SearchDataStructureParameterName = "Search Data Structure";
    private const string MaxComplexityParameterName = "Max. Complexity";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";
    private const string GrammarSymbolsParameterName = "Grammar Symbols";
    private const string SearchDataStructureSizeParameterName = "Search Data Structure Size";

    // result names
    public static readonly string BestTrainingModelResultName = "Best model (Training)";
    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";
    public static readonly string BestComplexityResultName = "Best solution complexity";

    /// <summary>The algorithm can be paused; <see cref="Run"/> resumes without re-seeding the search.</summary>
    public override bool SupportsPause { get { return true; } }

    protected IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
    }

    /// <summary>Whether constant optimization is requested when a phrase is evaluated.</summary>
    public bool OptimizeConstants {
      get { return OptimizeConstantsParameter.Value.Value; }
      set { OptimizeConstantsParameter.Value.Value = value; }
    }

    protected IFixedValueParameter<IntValue> MaxComplexityParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxComplexityParameterName]; }
    }

    /// <summary>Derived phrases whose grammar complexity exceeds this value are discarded.</summary>
    public int MaxComplexity {
      get { return MaxComplexityParameter.Value.Value; }
      set { MaxComplexityParameter.Value.Value = value; }
    }

    protected IFixedValueParameter<DoubleValue> ErrorWeightParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ErrorWeightParameterName]; }
    }

    /// <summary>
    /// Value of the "Error Weight" parameter.
    /// NOTE(review): not consulted by <see cref="GetPriority"/> in this file.
    /// </summary>
    public double ErrorWeight {
      get { return ErrorWeightParameter.Value.Value; }
      set { ErrorWeightParameter.Value.Value = value; }
    }

    protected IFixedValueParameter<IntValue> GuiUpdateIntervalParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }

    /// <summary>Number of view-update requests between two refreshes of the result values.</summary>
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
      set { GuiUpdateIntervalParameter.Value.Value = value; }
    }

    protected IFixedValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
      get { return (IFixedValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
    }

    public IFixedValueParameter<IntValue> SearchDataStructureSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SearchDataStructureSizeParameterName]; }
    }

    /// <summary>Size argument handed to the search data store when it is created in <see cref="Prepare"/>.</summary>
    public int SearchDataStructureSize {
      get { return SearchDataStructureSizeParameter.Value.Value; }
    }

    /// <summary>Kind of container used for the open phrases.</summary>
    public StorageType SearchDataStructure {
      get { return SearchDataStructureParameter.Value.Value; }
      set { SearchDataStructureParameter.Value.Value = value; }
    }

    public IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>> AnalyzersParameter {
      get { return (IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>)Parameters[AnalyzersParameterName]; }
    }

    /// <summary>Available analyzers; checked analyzers are registered with this algorithm.</summary>
    public ICheckedItemCollection<IGrammarEnumerationAnalyzer> Analyzers {
      get { return AnalyzersParameter.Value; }
    }

    public IFixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>> GrammarSymbolsParameter {
      get { return (IFixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>>)Parameters[GrammarSymbolsParameterName]; }
    }

    /// <summary>Grammar rules; the checked ones are passed to the grammar built in <see cref="Prepare"/>.</summary>
    public ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>> GrammarSymbols {
      get { return GrammarSymbolsParameter.Value; }
    }

    // Currently set in RSquaredEvaluator: quite hacky, but makes testing much easier for now...
    [Storable]
    public SymbolString BestTrainingSentence { get; set; }
    #endregion

    // Semantically distinct sentences and their length in a run.
    [Storable]
    public Dictionary<int, int> DistinctSentencesComplexity { get; private set; }

    // Hashes of all phrases that have already been fetched for expansion.
    [Storable]
    public HashSet<int> ArchivedPhrases { get; private set; }

    // Stack/Queue/etc. for fetching the next node in the search tree.
    [Storable]
    internal SearchDataStore OpenPhrases { get; private set; }

    [Storable]
    public int MaxSentenceLength { get; private set; }

    #region execution stats
    [Storable]
    public int AllGeneratedSentencesCount { get; private set; }

    // It is not guaranteed that shorter solutions are found first.
    // When longer solutions are overwritten with shorter ones,
    // this counter is increased.
    [Storable]
    public int OverwrittenSentencesCount { get; private set; }

    // Number, how many times a nonterminal symbol is replaced with a production rule.
    [Storable]
    public int PhraseExpansionCount { get; private set; }
    #endregion

    [Storable]
    public Grammar Grammar { get; private set; }

    #region ctors
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    [StorableConstructor]
    protected GrammarEnumerationAlgorithm(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Registers all checked analyzers and subscribes to the parameter events this algorithm
    /// reacts to. Counterpart of <see cref="DeregisterEvents"/>.
    /// </summary>
    private void RegisterEvents() {
      // re-wire analyzer events
      foreach (var analyzer in Analyzers.CheckedItems) {
        analyzer.Register(this);
      }
      Analyzers.CheckedItemsChanged += Analyzers_CheckedItemsChanged;
      SearchDataStructureParameter.Value.ValueChanged += SearchDataStructure_ValueChanged;
      SearchDataStructureSizeParameter.Value.ValueChanged += SearchDataStructure_ValueChanged;
    }

    /// <summary>
    /// Undoes <see cref="RegisterEvents"/>: deregisters all checked analyzers and removes the
    /// parameter event subscriptions.
    /// </summary>
    private void DeregisterEvents() {
      foreach (var analyzer in Analyzers.CheckedItems) {
        analyzer.Deregister(this);
      }
      Analyzers.CheckedItemsChanged -= Analyzers_CheckedItemsChanged;
      // A named handler is required here: removing a freshly created lambda never matches
      // the delegate instance that was added, so the subscription would silently stay alive.
      SearchDataStructureParameter.Value.ValueChanged -= SearchDataStructure_ValueChanged;
      SearchDataStructureSizeParameter.Value.ValueChanged -= SearchDataStructure_ValueChanged;
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEvents();
    }

    /// <summary>
    /// Creates the algorithm with its default parameters, checks only the
    /// <see cref="RSquaredEvaluator"/> analyzer, checks all grammar symbols and assigns a
    /// generated PolyTen regression problem as the default problem.
    /// </summary>
    public GrammarEnumerationAlgorithm() {
      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxComplexityParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(12)));
      Parameters.Add(new FixedValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(5000)));
      Parameters.Add(new FixedValueParameter<IntValue>(SearchDataStructureSizeParameterName, "The size of the search data structure.", new IntValue((int)1e5)));
      Parameters.Add(new FixedValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.SortedSet)));

      SearchDataStructureParameter.Value.ValueChanged += SearchDataStructure_ValueChanged;
      SearchDataStructureSizeParameter.Value.ValueChanged += SearchDataStructure_ValueChanged;

      var availableAnalyzers = new IGrammarEnumerationAnalyzer[] {
        new SearchGraphVisualizer(),
        new SentenceLogger(),
        new RSquaredEvaluator()
      };

      Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>(
        AnalyzersParameterName,
        new CheckedItemCollection<IGrammarEnumerationAnalyzer>(availableAnalyzers).AsReadOnly()));

      // Subscribe before changing check states so that the handler (de)registers the analyzers.
      Analyzers.CheckedItemsChanged += Analyzers_CheckedItemsChanged;

      foreach (var analyzer in Analyzers) {
        Analyzers.SetItemCheckedState(analyzer, false);
      }
      Analyzers.SetItemCheckedState(Analyzers.First(analyzer => analyzer is RSquaredEvaluator), true);

      var grammarSymbols = Enum.GetValues(typeof(GrammarRule))
        .Cast<GrammarRule>()
        .Select(v => new EnumValue<GrammarRule>(v));

      Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>>(
        GrammarSymbolsParameterName,
        new ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>(new CheckedItemCollection<EnumValue<GrammarRule>>(grammarSymbols))
      ));
      foreach (EnumValue<GrammarRule> grammarSymbol in GrammarSymbols) {
        GrammarSymbols.SetItemCheckedState(grammarSymbol, true);
      }

      // set a default problem
      Problem = new RegressionProblem() {
        ProblemData = new Problems.Instances.DataAnalysis.PolyTen(seed: 1234).GenerateRegressionData()
      };
    }

    /// <summary>
    /// Cloning constructor. Wires the same events as a freshly created or deserialized instance
    /// and copies the search state of <paramref name="original"/>.
    /// </summary>
    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) {
      // Includes the search-data-structure handlers, so a clone re-prepares on parameter changes too.
      RegisterEvents();

      // The collections may still be null when the original has never been prepared.
      DistinctSentencesComplexity = original.DistinctSentencesComplexity == null
        ? null
        : new Dictionary<int, int>(original.DistinctSentencesComplexity);
      ArchivedPhrases = original.ArchivedPhrases == null
        ? null
        : new HashSet<int>(original.ArchivedPhrases);
      OpenPhrases = cloner.Clone(original.OpenPhrases);
      Grammar = cloner.Clone(original.Grammar);

      // NOTE(review): the sentence reference is shared with the original; this assumes
      // SymbolString instances are not mutated after creation - confirm.
      BestTrainingSentence = original.BestTrainingSentence;
      MaxSentenceLength = original.MaxSentenceLength;

      AllGeneratedSentencesCount = original.AllGeneratedSentencesCount;
      OverwrittenSentencesCount = original.OverwrittenSentencesCount;
      PhraseExpansionCount = original.PhraseExpansionCount;
    }
    #endregion

    // NOTE(review): this field is never read or written in this file, so its type arguments
    // cannot be verified from usage here - confirm against the project history.
    [Storable]
    private Dictionary<VariableTerminalSymbol, double> variableImportance;

    /// <summary>
    /// Resets counters and search state, rebuilds the grammar from the problem's allowed input
    /// variables and the checked grammar symbols, and creates a new search data store.
    /// </summary>
    public override void Prepare() {
      DistinctSentencesComplexity = new Dictionary<int, int>();
      ArchivedPhrases = new HashSet<int>();
      AllGeneratedSentencesCount = 0;
      OverwrittenSentencesCount = 0;
      PhraseExpansionCount = 0;

      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray(), GrammarSymbols.CheckedItems.Select(v => v.Value));

      OpenPhrases = new SearchDataStore(SearchDataStructure, SearchDataStructureSize); // Select search strategy

      base.Prepare(); // this actually clears the results which will get reinitialized on Run()
    }

    /// <summary>
    /// Main search loop. Fetches open phrases until none are left, cancellation is requested, or
    /// the best training quality result is (almost) 1.0.
    /// </summary>
    /// <param name="cancellationToken">Checked once per fetched phrase.</param>
    protected override void Run(CancellationToken cancellationToken) {
      // do not reinitialize the algorithm if we're resuming from pause
      if (previousExecutionState != ExecutionState.Paused) {
        InitResults();
        var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
        var phrase0Hash = Grammar.Hasher.CalcHashCode(phrase0);

        OpenPhrases.Store(new SearchNode(phrase0Hash, 0.0, 0.0, phrase0));
      }

      MaxSentenceLength = Grammar.GetMaxSentenceLength(MaxComplexity);
      var optimizeConstants = OptimizeConstants; // cache value to avoid parameter lookup

      // main search loop
      while (OpenPhrases.Count > 0) {
        if (cancellationToken.IsCancellationRequested) break;

        SearchNode fetchedSearchNode = OpenPhrases.GetNext();
        if (fetchedSearchNode == null) continue;

        SymbolString currPhrase = fetchedSearchNode.SymbolString;

        OnPhraseFetched(fetchedSearchNode.Hash, currPhrase);

        ArchivedPhrases.Add(fetchedSearchNode.Hash);

        // expand next nonterminal symbols
        int nonterminalSymbolIndex = currPhrase.NextNonterminalIndex();
        NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];
        var appliedProductions = Grammar.Productions[expandedSymbol];

        for (int i = 0; i < appliedProductions.Count; i++) {
          var production = appliedProductions[i];
          PhraseExpansionCount++;

          SymbolString newPhrase = currPhrase.DerivePhrase(nonterminalSymbolIndex, production);
          int newPhraseComplexity = Grammar.GetComplexity(newPhrase);

          if (newPhraseComplexity > MaxComplexity) continue;

          var phraseHash = Grammar.Hasher.CalcHashCode(newPhrase);

          OnPhraseDerived(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);

          if (newPhrase.IsSentence()) {
            AllGeneratedSentencesCount++;

            OnSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);

            // Is the best solution found? (only if RSquaredEvaluator is activated)
            if (Results.ContainsKey(RSquaredEvaluator.BestTrainingQualityResultName)) {
              double bestR2 = ((DoubleValue)Results[RSquaredEvaluator.BestTrainingQualityResultName].Value).Value;
              if (bestR2.IsAlmost(1.0)) {
                UpdateView(force: true);
                return;
              }
            }

            // Record the sentence if its hash is new, or if it is less complex than the recorded one.
            int knownComplexity;
            bool alreadyKnown = DistinctSentencesComplexity.TryGetValue(phraseHash, out knownComplexity);
            if (!alreadyKnown || knownComplexity > newPhraseComplexity) {
              if (alreadyKnown) OverwrittenSentencesCount++; // for analysis only

              DistinctSentencesComplexity[phraseHash] = newPhraseComplexity;
              OnDistinctSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);
            }
            UpdateView();

          } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
            // Phrases that cannot be evaluated yet inherit the r2 of the phrase they were derived from.
            bool isCompleteSentence = IsCompleteSentence(newPhrase);
            double r2 = isCompleteSentence
              ? Grammar.EvaluatePhrase(newPhrase, Problem.ProblemData, optimizeConstants)
              : fetchedSearchNode.R2;
            double phrasePriority = GetPriority(newPhrase, r2, MaxSentenceLength);

            OpenPhrases.Store(new SearchNode(phraseHash, phrasePriority, r2, newPhrase));
          }
        }
      }
      UpdateView(force: true);
    }

    /// <summary>
    /// Computes the priority value stored with a search node: <c>(1 - r2)</c> multiplied by the
    /// number of symbols in <paramref name="phrase"/>.
    /// </summary>
    /// <param name="phrase">Phrase to prioritize.</param>
    /// <param name="r2">Quality value associated with the phrase.</param>
    /// <param name="maxSentenceLength">Currently not used by the computation.</param>
    protected static double GetPriority(SymbolString phrase, double r2, int maxSentenceLength) {
      return (1 - r2) * phrase.Count();
    }

    /// <summary>
    /// Returns true if <paramref name="phrase"/> contains no nonterminal symbol other than
    /// <c>Grammar.Expr</c>.
    /// </summary>
    private bool IsCompleteSentence(SymbolString phrase) {
      return !phrase.Any(x => x is NonterminalSymbol && x != Grammar.Expr);
    }

    #region pause support
    // Execution state captured just before the most recent state transition; used by Run()
    // to tell a resume-from-pause apart from a fresh start.
    private ExecutionState previousExecutionState;

    protected override void OnPaused() {
      previousExecutionState = this.ExecutionState;
      base.OnPaused();
    }

    protected override void OnPrepared() {
      previousExecutionState = this.ExecutionState;
      base.OnPrepared();
    }

    protected override void OnStarted() {
      previousExecutionState = this.ExecutionState;
      base.OnStarted();
    }

    /// <summary>
    /// Releases the search state and, if a best training sentence is known, publishes the
    /// corresponding model, solution and complexity as results.
    /// </summary>
    protected override void OnStopped() {
      previousExecutionState = this.ExecutionState;

      // free memory at the end of the run (this saves a lot of memory)
      // (null-conditional: the collections do not exist before the first Prepare())
      ArchivedPhrases?.Clear();
      OpenPhrases?.Clear();
      DistinctSentencesComplexity?.Clear();

      if (BestTrainingSentence != null) {
        AddBestTrainingSolutionResults();
      }

      base.OnStopped();
    }

    // Builds a scaled, constant-optimized model from BestTrainingSentence and stores it in Results.
    private void AddBestTrainingSolutionResults() {
      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
      var tree = Grammar.ParseSymbolicExpressionTree(BestTrainingSentence);
      var model = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, interpreter);

      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(
        interpreter,
        model.SymbolicExpressionTree,
        Problem.ProblemData,
        Problem.ProblemData.TrainingIndices,
        applyLinearScaling: true,
        maxIterations: 10,
        updateVariableWeights: false,
        updateConstantsInTree: true);

      model.Scale(Problem.ProblemData);
      var bestTrainingSolution = new SymbolicRegressionSolution(model, Problem.ProblemData);

      Results.AddOrUpdateResult(BestTrainingModelResultName, model);
      Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
      Results.AddOrUpdateResult(BestComplexityResultName, new IntValue(Grammar.GetComplexity(BestTrainingSentence)));
    }
    #endregion

    #region Visualization in HL
    // Initialize entries in result set.
    private void InitResults() {
      Results.Clear();
      Results.Add(new Result(GeneratedPhrasesName, new IntValue(0)));
      Results.Add(new Result(SearchStructureSizeName, new IntValue(0)));
      Results.Add(new Result(GeneratedSentencesName, new IntValue(0)));
      Results.Add(new Result(DistinctSentencesName, new IntValue(0)));
      Results.Add(new Result(PhraseExpansionsName, new IntValue(0)));
      Results.Add(new Result(OverwrittenSentencesName, new IntValue(0)));
      Results.Add(new Result(AverageSentenceComplexityName, new DoubleValue(1.0)));
      Results.Add(new Result(ExpansionsPerSecondName, "In Thousand expansions per second", new IntValue(0)));
    }

    // Number of UpdateView() calls so far.
    private int updates;

    /// <summary>
    /// Writes the current run statistics into the result collection. Without
    /// <paramref name="force"/>, values are only refreshed on the first call and then once per
    /// <see cref="GuiUpdateInterval"/> calls.
    /// </summary>
    /// <param name="force">Refresh regardless of the update interval.</param>
    private void UpdateView(bool force = false) {
      updates++;

      if (!force) {
        int interval = GuiUpdateInterval;
        // An interval of 0 would throw on the modulo and an interval of 1 would never yield a
        // remainder of 1; both are treated as "refresh on every call".
        bool refreshDue = interval <= 1 || updates % interval == 1;
        if (!refreshDue) return;
      }

      ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
      ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
      ((IntValue)Results[GeneratedSentencesName].Value).Value = AllGeneratedSentencesCount;
      ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentencesComplexity.Count;
      ((IntValue)Results[PhraseExpansionsName].Value).Value = PhraseExpansionCount;
      ((DoubleValue)Results[AverageSentenceComplexityName].Value).Value = DistinctSentencesComplexity.Count > 0
        ? DistinctSentencesComplexity.Values.Average()
        : 0;
      ((IntValue)Results[OverwrittenSentencesName].Value).Value = OverwrittenSentencesCount;

      // Dividing by zero elapsed seconds gives Infinity/NaN, whose conversion to int is unspecified.
      double elapsedSeconds = ExecutionTime.TotalSeconds;
      ((IntValue)Results[ExpansionsPerSecondName].Value).Value = elapsedSeconds > 0
        ? (int)((PhraseExpansionCount / elapsedSeconds) / 1000.0)
        : 0;
    }
    #endregion

    #region events
    // Re-prepares the algorithm when the search data structure type or size changes.
    private void SearchDataStructure_ValueChanged(object sender, EventArgs e) {
      Prepare();
    }

    // private event handlers for analyzers
    private void Analyzers_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IGrammarEnumerationAnalyzer> args) {
      // items contained in exactly one of args.Items / args.OldItems
      var changedItems = args.Items.Except(args.OldItems).Union(args.OldItems.Except(args.Items));
      foreach (var item in changedItems) {
        if (Analyzers.ItemChecked(item)) {
          item.Register(this);
        } else {
          item.Deregister(this);
        }
      }
    }

    /// <summary>Raised when a phrase has been fetched from the open phrases for expansion.</summary>
    public event EventHandler<PhraseEventArgs> PhraseFetched;
    private void OnPhraseFetched(int hash, SymbolString symbolString) {
      PhraseFetched?.Invoke(this, new PhraseEventArgs(hash, symbolString));
    }

    /// <summary>Raised for every derived phrase that does not exceed <see cref="MaxComplexity"/>.</summary>
    public event EventHandler<PhraseAddedEventArgs> PhraseDerived;
    private void OnPhraseDerived(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      PhraseDerived?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }

    /// <summary>Raised for every derived phrase that is a sentence.</summary>
    public event EventHandler<PhraseAddedEventArgs> SentenceGenerated;
    private void OnSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      SentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }

    /// <summary>
    /// Raised when a sentence with a new hash, or with a lower complexity than the one recorded
    /// for its hash, has been generated.
    /// </summary>
    public event EventHandler<PhraseAddedEventArgs> DistinctSentenceGenerated;
    private void OnDistinctSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      DistinctSentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }
    #endregion
  }

  #region events for analysis
  /// <summary>Event data identifying a single phrase by its hash and symbol string.</summary>
  public class PhraseEventArgs : EventArgs {
    public int Hash { get; }

    public SymbolString Phrase { get; }

    public PhraseEventArgs(int hash, SymbolString phrase) {
      Hash = hash;
      Phrase = phrase;
    }
  }

  /// <summary>
  /// Event data describing one derivation step: the parent phrase, the new phrase, the symbol
  /// that was expanded and the production that replaced it.
  /// </summary>
  public class PhraseAddedEventArgs : EventArgs {
    public int ParentHash { get; }
    public int NewHash { get; }

    public SymbolString ParentPhrase { get; }
    public SymbolString NewPhrase { get; }

    public Symbol ExpandedSymbol { get; }

    public Production ExpandedProduction { get; }

    public PhraseAddedEventArgs(int parentHash, SymbolString parentPhrase, int newHash, SymbolString newPhrase, Symbol expandedSymbol, Production expandedProduction) {
      ParentHash = parentHash;
      ParentPhrase = parentPhrase;
      NewHash = newHash;
      NewPhrase = newPhrase;
      ExpandedSymbol = expandedSymbol;
      ExpandedProduction = expandedProduction;
    }
  }
  #endregion
}