using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
using HeuristicLab.Collections;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Enumerates phrases and sentences derivable from a fixed grammar built over the problem's
  /// allowed input variables, pruning every phrase whose complexity exceeds <see cref="MaxComplexity"/>.
  /// Open phrases are kept in a configurable search data structure (<see cref="SearchDataStructure"/>);
  /// registered analyzers observe the search through the events declared at the end of this class.
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region properties and result names
    private const string SearchStructureSizeName = "Search Structure Size";
    private const string GeneratedPhrasesName = "Generated/Archived Phrases";
    private const string GeneratedSentencesName = "Generated Sentences";
    private const string DistinctSentencesName = "Distinct Sentences";
    private const string PhraseExpansionsName = "Phrase Expansions";
    private const string AverageSentenceComplexityName = "Avg. Sentence Complexity among Distinct";
    private const string OverwrittenSentencesName = "Sentences overwritten";
    private const string AnalyzersParameterName = "Analyzers";
    private const string ExpansionsPerSecondName = "Expansions per second";

    private const string OptimizeConstantsParameterName = "Optimize Constants";
    private const string ErrorWeightParameterName = "Error Weight";
    private const string SearchDataStructureParameterName = "Search Data Structure";
    private const string MaxComplexityParameterName = "Max. Complexity";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";

    /// <summary>Pausing is not supported; a run can only be stopped (cancelled).</summary>
    public override bool SupportsPause { get { return false; } }

    protected IValueParameter<BoolValue> OptimizeConstantsParameter {
      get { return (IValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
    }

    /// <summary>Whether constant optimization is used when phrases/sentences are evaluated.</summary>
    public bool OptimizeConstants {
      get { return OptimizeConstantsParameter.Value.Value; }
      set { OptimizeConstantsParameter.Value.Value = value; }
    }

    protected IValueParameter<IntValue> MaxComplexityParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxComplexityParameterName]; }
    }

    /// <summary>Phrases whose <c>Grammar.GetComplexity</c> exceeds this value are discarded.</summary>
    public int MaxComplexity {
      get { return MaxComplexityParameter.Value.Value; }
      set { MaxComplexityParameter.Value.Value = value; }
    }

    protected IValueParameter<DoubleValue> ErrorWeightParameter {
      get { return (IValueParameter<DoubleValue>)Parameters[ErrorWeightParameterName]; }
    }

    /// <summary>Weight of the error term (1 - r²) in <see cref="GetPriority"/>.</summary>
    public double ErrorWeight {
      get { return ErrorWeightParameter.Value.Value; }
      set { ErrorWeightParameter.Value.Value = value; }
    }

    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }

    /// <summary>
    /// Number of generated sentences between two refreshes of the result values.
    /// Values below 1 are treated as 1 (refresh on every sentence).
    /// </summary>
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
      set { GuiUpdateIntervalParameter.Value.Value = value; }
    }

    protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
      get { return (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
    }

    /// <summary>Kind of data structure (e.g. priority queue) used to fetch the next open phrase.</summary>
    public StorageType SearchDataStructure {
      get { return SearchDataStructureParameter.Value.Value; }
      set { SearchDataStructureParameter.Value.Value = value; }
    }

    public IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>> AnalyzersParameter {
      get { return (IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>)Parameters[AnalyzersParameterName]; }
    }

    /// <summary>Available analyzers; only checked ones are registered at the start of a run.</summary>
    public ICheckedItemCollection<IGrammarEnumerationAnalyzer> Analyzers {
      get { return AnalyzersParameter.Value; }
    }

    public SymbolString BestTrainingSentence { get; set; } // Currently set in RSquaredEvaluator: quite hacky, but makes testing much easier for now...
    #endregion

    /// <summary>Semantically distinct sentences (by hash) mapped to the lowest complexity seen for them in a run.</summary>
    public Dictionary<int, int> DistinctSentencesComplexity { get; private set; }

    /// <summary>Hashes of all phrases that have already been fetched and expanded.</summary>
    public HashSet<int> ArchivedPhrases { get; private set; }

    /// <summary>Stack/Queue/etc. for fetching the next node in the search tree.</summary>
    internal SearchDataStore OpenPhrases { get; private set; }

    #region execution stats
    /// <summary>Number of sentences generated (including duplicates).</summary>
    public int AllGeneratedSentencesCount { get; private set; }

    /// <summary>
    /// It is not guaranteed that shorter solutions are found first. When a known sentence is
    /// found again with a lower complexity, its stored complexity is overwritten and this
    /// counter is increased.
    /// </summary>
    public int OverwrittenSentencesCount { get; private set; }

    /// <summary>Number of times a nonterminal symbol was replaced with a production rule.</summary>
    public int PhraseExpansionCount { get; private set; }
    #endregion

    /// <summary>Grammar built from the problem's allowed input variables at the start of a run.</summary>
    public Grammar Grammar { get; private set; }

    #region ctors
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    public GrammarEnumerationAlgorithm() {
      Problem = new RegressionProblem() {
        ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.PolyTen(seed: 1234).GenerateRegressionData()
      };

      Parameters.Add(new ValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false)));
      Parameters.Add(new ValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8)));
      Parameters.Add(new ValueParameter<IntValue>(MaxComplexityParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(12)));
      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(5000)));
      Parameters.Add(new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.PriorityQueue)));

      var availableAnalyzers = new IGrammarEnumerationAnalyzer[] {
        new SearchGraphVisualizer(),
        new SentenceLogger(),
        new RSquaredEvaluator()
      };
      Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>(
        AnalyzersParameterName,
        new CheckedItemCollection<IGrammarEnumerationAnalyzer>(availableAnalyzers).AsReadOnly()));

      // Only the RSquaredEvaluator is enabled by default.
      foreach (var analyzer in Analyzers) {
        Analyzers.SetItemCheckedState(analyzer, false);
      }
      Analyzers.SetItemCheckedState(Analyzers.First(analyzer => analyzer is RSquaredEvaluator), true);
    }

    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    #endregion

    /// <summary>
    /// Runs the enumeration: repeatedly fetches an open phrase, expands its next nonterminal with
    /// every production, records resulting sentences and stores new, unseen phrases whose
    /// complexity does not exceed <see cref="MaxComplexity"/>. Stops when no open phrase is left,
    /// when cancellation is requested, or when the best training R² reported by the
    /// RSquaredEvaluator is (almost) 1.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      #region init
      InitResults();

      // Only checked analyzers listen to the search events; unchecked ones are detached.
      foreach (IGrammarEnumerationAnalyzer grammarEnumerationAnalyzer in Analyzers) {
        if (Analyzers.ItemChecked(grammarEnumerationAnalyzer)) {
          grammarEnumerationAnalyzer.Register(this);
        } else {
          grammarEnumerationAnalyzer.Deregister(this);
        }
      }
      Analyzers.OfType<RSquaredEvaluator>().First().OptimizeConstants = OptimizeConstants;

      ArchivedPhrases = new HashSet<int>();
      DistinctSentencesComplexity = new Dictionary<int, int>();
      AllGeneratedSentencesCount = 0;
      OverwrittenSentencesCount = 0;
      PhraseExpansionCount = 0;
      updates = 0; // so the refresh pattern of UpdateView restarts with every run

      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

      OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
      var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
      var phrase0Hash = Grammar.Hasher.CalcHashCode(phrase0);
      #endregion

      int maxSentenceLength = GetMaxSentenceLength();

      OpenPhrases.Store(phrase0Hash, 0.0, phrase0);
      while (OpenPhrases.Count > 0) {
        if (cancellationToken.IsCancellationRequested) break;

        StoredSymbolString fetchedPhrase = OpenPhrases.GetNext();
        SymbolString currPhrase = fetchedPhrase.SymbolString;
        OnPhraseFetched(fetchedPhrase.Hash, currPhrase);

        ArchivedPhrases.Add(fetchedPhrase.Hash);

        // expand next nonterminal symbol
        int nonterminalSymbolIndex = currPhrase.NextNonterminalIndex();
        NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];
        var appliedProductions = Grammar.Productions[expandedSymbol];

        for (int i = 0; i < appliedProductions.Count; i++) {
          PhraseExpansionCount++;

          var production = appliedProductions[i];
          SymbolString newPhrase = currPhrase.DerivePhrase(nonterminalSymbolIndex, production);
          int newPhraseComplexity = Grammar.GetComplexity(newPhrase);

          if (newPhraseComplexity > MaxComplexity) continue; // prune: too complex

          var phraseHash = Grammar.Hasher.CalcHashCode(newPhrase);
          OnPhraseDerived(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, production);

          if (newPhrase.IsSentence()) {
            AllGeneratedSentencesCount++;
            OnSentenceGenerated(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, production);

            // Is the best solution found? (only if RSquaredEvaluator is activated)
            if (Results.ContainsKey(RSquaredEvaluator.BestTrainingQualityResultName)) {
              double r2 = ((DoubleValue)Results[RSquaredEvaluator.BestTrainingQualityResultName].Value).Value;
              if (r2.IsAlmost(1.0)) {
                UpdateView(force: true);
                return;
              }
            }

            // Keep the lowest complexity seen for each semantically distinct sentence.
            int knownComplexity;
            bool isKnown = DistinctSentencesComplexity.TryGetValue(phraseHash, out knownComplexity);
            if (!isKnown || knownComplexity > newPhraseComplexity) {
              if (isKnown) OverwrittenSentencesCount++; // for analysis only

              DistinctSentencesComplexity[phraseHash] = newPhraseComplexity;
              OnDistinctSentenceGenerated(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, production);
            }
            UpdateView();

          } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
            double phrasePriority = GetPriority(newPhrase, maxSentenceLength);
            OpenPhrases.Store(phraseHash, phrasePriority, newPhrase);
          }
        }
      }

      UpdateView(force: true);
    }

    /// <summary>
    /// Priority of an open phrase: its length relative to <paramref name="maxSentenceLength"/>
    /// plus <see cref="ErrorWeight"/> times (1 - r²), where r² comes from
    /// <c>Grammar.EvaluatePhrase</c> on the problem data.
    /// </summary>
    protected double GetPriority(SymbolString phrase, int maxSentenceLength) {
      double relLength = (double)phrase.Count() / maxSentenceLength;

      double r2 = Grammar.EvaluatePhrase(phrase, Problem.ProblemData, OptimizeConstants);
      double error = 1.0 - r2;

      return relLength + ErrorWeight * error;
    }

    /// <summary>
    /// Derives a phrase from the start symbol, always expanding its next nonterminal, until its
    /// complexity exceeds <see cref="MaxComplexity"/>, and returns that phrase's length.
    /// Used as the normalization constant for the length term in <see cref="GetPriority"/>.
    /// </summary>
    private int GetMaxSentenceLength() {
      SymbolString s = new SymbolString(Grammar.StartSymbol);

      while (Grammar.GetComplexity(s) <= MaxComplexity) {
        int expandedSymbolIndex = s.NextNonterminalIndex();
        NonterminalSymbol expandedSymbol = (NonterminalSymbol)s[expandedSymbolIndex];

        var productions = Grammar.Productions[expandedSymbol];
        // NOTE(review): the intent stated below ("most terminal symbols", "lowest nonterminal
        // count") is the opposite of what OrderBy (ascending terminals) and ThenByDescending
        // (descending nonterminals) select. Confirm against the grammar which one is intended.
        var longestProduction = productions // Find production with most terminal symbols to expand as much as possible...
          .OrderBy(CountTerminals)          // but with lowest complexity/nonterminal count to keep complexity low.
          .ThenByDescending(CountNonTerminals)
          .First();

        s = s.DerivePhrase(expandedSymbolIndex, longestProduction);
      }

      return s.Count();
    }

    /// <summary>Number of terminal symbols on the right-hand side of a production.</summary>
    private int CountTerminals(Production p) {
      return p.Count(s => s is TerminalSymbol);
    }

    /// <summary>Number of nonterminal symbols on the right-hand side of a production.</summary>
    private int CountNonTerminals(Production p) {
      return p.Count(s => s is NonterminalSymbol);
    }

    #region Visualization in HL
    // Initialize entries in result set.
    private void InitResults() {
      Results.Add(new Result(GeneratedPhrasesName, new IntValue(0)));
      Results.Add(new Result(SearchStructureSizeName, new IntValue(0)));
      Results.Add(new Result(GeneratedSentencesName, new IntValue(0)));
      Results.Add(new Result(DistinctSentencesName, new IntValue(0)));
      Results.Add(new Result(PhraseExpansionsName, new IntValue(0)));
      Results.Add(new Result(OverwrittenSentencesName, new IntValue(0)));
      Results.Add(new Result(AverageSentenceComplexityName, new DoubleValue(1.0)));
      Results.Add(new Result(ExpansionsPerSecondName, "In Thousand expansions per second", new IntValue(0)));
    }

    // Number of UpdateView calls in the current run; reset at the start of Run.
    private int updates;

    /// <summary>
    /// Copies the current search statistics into the result collection. Without
    /// <paramref name="force"/>, this only happens on the 1st, (interval+1)-th, (2*interval+1)-th, ...
    /// call, where interval is <see cref="GuiUpdateInterval"/> clamped to at least 1.
    /// </summary>
    private void UpdateView(bool force = false) {
      updates++;

      // Clamp so that an interval of 0 cannot divide by zero and an interval of 1 refreshes every call.
      int interval = Math.Max(1, GuiUpdateInterval);
      if (!force && (updates - 1) % interval != 0) return;

      ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
      ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
      ((IntValue)Results[GeneratedSentencesName].Value).Value = AllGeneratedSentencesCount;
      ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentencesComplexity.Count;
      ((IntValue)Results[PhraseExpansionsName].Value).Value = PhraseExpansionCount;
      ((IntValue)Results[OverwrittenSentencesName].Value).Value = OverwrittenSentencesCount;

      // Average() throws on an empty sequence, e.g. when the run ends/is cancelled before any
      // sentence was recorded; keep the previous value in that case.
      if (DistinctSentencesComplexity.Count > 0) {
        ((DoubleValue)Results[AverageSentenceComplexityName].Value).Value = DistinctSentencesComplexity.Values.Average();
      }

      // Avoid dividing by a zero execution time (the cast of an infinite/NaN double to int is meaningless).
      double seconds = ExecutionTime.TotalSeconds;
      ((IntValue)Results[ExpansionsPerSecondName].Value).Value =
        seconds > 0.0 ? (int)((PhraseExpansionCount / seconds) / 1000.0) : 0;
    }
    #endregion

    #region events
    /// <summary>Raised when a phrase is taken from the open phrases for expansion.</summary>
    public event EventHandler<PhraseEventArgs> PhraseFetched;
    private void OnPhraseFetched(int hash, SymbolString symbolString) {
      PhraseFetched?.Invoke(this, new PhraseEventArgs(hash, symbolString));
    }

    /// <summary>Raised for every derived phrase whose complexity is within <see cref="MaxComplexity"/>.</summary>
    public event EventHandler<PhraseAddedEventArgs> PhraseDerived;
    private void OnPhraseDerived(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      PhraseDerived?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }

    /// <summary>Raised for every derived phrase that is a sentence (duplicates included).</summary>
    public event EventHandler<PhraseAddedEventArgs> SentenceGenerated;
    private void OnSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      SentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }

    /// <summary>
    /// Raised when a sentence hash is seen for the first time, or again with a lower complexity.
    /// </summary>
    public event EventHandler<PhraseAddedEventArgs> DistinctSentenceGenerated;
    private void OnDistinctSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      DistinctSentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }
    #endregion
  }

  #region events for analysis
  /// <summary>Event data carrying a single phrase and its hash.</summary>
  public class PhraseEventArgs : EventArgs {
    public int Hash { get; }

    public SymbolString Phrase { get; }

    public PhraseEventArgs(int hash, SymbolString phrase) {
      Hash = hash;
      Phrase = phrase;
    }
  }

  /// <summary>
  /// Event data for a derivation step: the parent phrase, the derived phrase, and the
  /// nonterminal/production that produced it.
  /// </summary>
  public class PhraseAddedEventArgs : EventArgs {
    public int ParentHash { get; }
    public int NewHash { get; }

    public SymbolString ParentPhrase { get; }
    public SymbolString NewPhrase { get; }

    public Symbol ExpandedSymbol { get; }

    public Production ExpandedProduction { get; }

    public PhraseAddedEventArgs(int parentHash, SymbolString parentPhrase, int newHash, SymbolString newPhrase, Symbol expandedSymbol, Production expandedProduction) {
      ParentHash = parentHash;
      ParentPhrase = parentPhrase;
      NewHash = newHash;
      NewPhrase = newPhrase;
      ExpandedSymbol = expandedSymbol;
      ExpandedProduction = expandedProduction;
    }
  }
  #endregion
}