using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
using HeuristicLab.Collections;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Symbolic regression by enumerating the phrases derivable from a fixed <see cref="Grammar"/>.
  /// Phrases are fetched from a search data store (stack/queue/priority queue, see
  /// <see cref="SearchDataStructure"/>), their next nonterminal is expanded with every applicable
  /// production, and complete sentences are reported to the registered analyzers via events.
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region properties and result names
    private const string VariableImportanceWeightName = "Variable Importance Weight";
    private const string SearchStructureSizeName = "Search Structure Size";
    private const string GeneratedPhrasesName = "Generated/Archived Phrases";
    private const string GeneratedSentencesName = "Generated Sentences";
    private const string DistinctSentencesName = "Distinct Sentences";
    private const string PhraseExpansionsName = "Phrase Expansions";
    private const string AverageSentenceComplexityName = "Avg. Sentence Complexity among Distinct";
    private const string OverwrittenSentencesName = "Sentences overwritten";
    private const string AnalyzersParameterName = "Analyzers";
    private const string ExpansionsPerSecondName = "Expansions per second";

    private const string OptimizeConstantsParameterName = "Optimize Constants";
    private const string ErrorWeightParameterName = "Error Weight";
    private const string SearchDataStructureParameterName = "Search Data Structure";
    private const string MaxComplexityParameterName = "Max. Complexity";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";
    private const string GrammarSymbolsParameterName = "Grammar Symbols";
    private const string SearchCacheSizeParameterName = "Search Cache Size";
    private const string SearchDataStructureSizeParameterName = "Search Data Structure Size";

    // result names
    public static readonly string BestTrainingModelResultName = "Best model (Training)";
    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";
    public static readonly string BestComplexityResultName = "Best solution complexity";

    public override bool SupportsPause { get { return true; } }

    protected IFixedValueParameter<DoubleValue> VariableImportanceWeightParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[VariableImportanceWeightName]; }
    }

    protected double VariableImportanceWeight {
      get { return VariableImportanceWeightParameter.Value.Value; }
    }

    protected IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
    }

    public bool OptimizeConstants {
      get { return OptimizeConstantsParameter.Value.Value; }
      set { OptimizeConstantsParameter.Value.Value = value; }
    }

    protected IFixedValueParameter<IntValue> MaxComplexityParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxComplexityParameterName]; }
    }

    public int MaxComplexity {
      get { return MaxComplexityParameter.Value.Value; }
      set { MaxComplexityParameter.Value.Value = value; }
    }

    protected IFixedValueParameter<DoubleValue> ErrorWeightParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ErrorWeightParameterName]; }
    }

    public double ErrorWeight {
      get { return ErrorWeightParameter.Value.Value; }
      set { ErrorWeightParameter.Value.Value = value; }
    }

    protected IFixedValueParameter<IntValue> GuiUpdateIntervalParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }

    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
      set { GuiUpdateIntervalParameter.Value.Value = value; }
    }

    protected IFixedValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
      get { return (IFixedValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
    }

    public IFixedValueParameter<IntValue> SearchDataStructureSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SearchDataStructureSizeParameterName]; }
    }

    public int SearchDataStructureSize {
      get { return SearchDataStructureSizeParameter.Value.Value; }
    }

    public IFixedValueParameter<IntValue> SearchCacheSizeParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SearchCacheSizeParameterName]; }
    }

    public int SearchCacheSize {
      get { return SearchCacheSizeParameter.Value.Value; }
    }

    public StorageType SearchDataStructure {
      get { return SearchDataStructureParameter.Value.Value; }
      set { SearchDataStructureParameter.Value.Value = value; }
    }

    public IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>> AnalyzersParameter {
      get { return (IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>)Parameters[AnalyzersParameterName]; }
    }

    public ICheckedItemCollection<IGrammarEnumerationAnalyzer> Analyzers {
      get { return AnalyzersParameter.Value; }
    }

    public IFixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>> GrammarSymbolsParameter {
      get { return (IFixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>>)Parameters[GrammarSymbolsParameterName]; }
    }

    public ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>> GrammarSymbols {
      get { return GrammarSymbolsParameter.Value; }
    }

    // Currently set in RSquaredEvaluator: quite hacky, but makes testing much easier for now...
    public SymbolString BestTrainingSentence { get; set; }
    #endregion

    // Semantically distinct sentences (by hash) and the lowest complexity seen for each in a run.
    [Storable]
    public Dictionary<int, int> DistinctSentencesComplexity { get; private set; }

    // Hashes of all phrases that have already been fetched and expanded.
    [Storable]
    public HashSet<int> ArchivedPhrases { get; private set; }

    // Stack/Queue/etc. for fetching the next node in the search tree.
    [Storable]
    internal SearchDataStore OpenPhrases { get; private set; }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Event subscriptions are not persisted, so they must be restored after loading.
      RegisterEventHandlers();
      // Grammar is not [Storable]; without it the importances cannot be mapped to grammar symbols.
      // Run() recomputes them anyway whenever it does not resume from a pause.
      if (Grammar != null)
        variableImportance = CalculateVariableImportances();
    }

    #region execution stats
    [Storable]
    public int AllGeneratedSentencesCount { get; private set; }

    // It is not guaranteed that shorter solutions are found first.
    // When longer solutions are overwritten with shorter ones,
    // this counter is increased.
    [Storable]
    public int OverwrittenSentencesCount { get; private set; }

    // Number, how many times a nonterminal symbol is replaced with a production rule.
    [Storable]
    public int PhraseExpansionCount { get; private set; }
    #endregion

    public Grammar Grammar { get; private set; }

    #region ctors
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    public GrammarEnumerationAlgorithm() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(VariableImportanceWeightName, "Variable Weight.", new DoubleValue(1.0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxComplexityParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(12)));
      Parameters.Add(new FixedValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(5000)));
      Parameters.Add(new FixedValueParameter<IntValue>(SearchCacheSizeParameterName, "The size of the search node cache.", new IntValue((int)1e5)));
      Parameters.Add(new FixedValueParameter<IntValue>(SearchDataStructureSizeParameterName, "The size of the search data structure.", new IntValue((int)1e5)));
      Parameters.Add(new FixedValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.PriorityQueue)));

      RegisterParameterEventHandlers();

      var availableAnalyzers = new IGrammarEnumerationAnalyzer[] {
        new SearchGraphVisualizer(),
        new SentenceLogger(),
        new RSquaredEvaluator()
      };
      Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>(
        AnalyzersParameterName,
        new CheckedItemCollection<IGrammarEnumerationAnalyzer>(availableAnalyzers).AsReadOnly()));

      // Subscribe before toggling the check states so the handler performs (de)registration.
      Analyzers.CheckedItemsChanged += Analyzers_CheckedItemsChanged;

      foreach (var analyzer in Analyzers) {
        Analyzers.SetItemCheckedState(analyzer, false);
      }
      Analyzers.SetItemCheckedState(Analyzers.First(analyzer => analyzer is RSquaredEvaluator), true);

      var grammarSymbols = Enum.GetValues(typeof(GrammarRule))
        .Cast<GrammarRule>()
        .Select(v => new EnumValue<GrammarRule>(v));

      Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>>(
        GrammarSymbolsParameterName,
        new ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>(new CheckedItemCollection<EnumValue<GrammarRule>>(grammarSymbols))
      ));
      foreach (EnumValue<GrammarRule> grammarSymbol in GrammarSymbols) {
        GrammarSymbols.SetItemCheckedState(grammarSymbol, true);
      }

      // set a default problem
      Problem = new RegressionProblem() {
        ProblemData = new Problems.Instances.DataAnalysis.PolyTen(seed: 1234).GenerateRegressionData()
      };
    }

    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) {
      // The cloned parameters/analyzers carry no subscriptions, so wire them up again
      // (previously the ValueChanged -> Prepare handlers were missing on clones).
      RegisterEventHandlers();

      if (original.DistinctSentencesComplexity != null)
        DistinctSentencesComplexity = new Dictionary<int, int>(original.DistinctSentencesComplexity);
      if (original.ArchivedPhrases != null)
        ArchivedPhrases = new HashSet<int>(original.ArchivedPhrases);
      OpenPhrases = cloner.Clone(original.OpenPhrases);
      Grammar = cloner.Clone(original.Grammar);

      AllGeneratedSentencesCount = original.AllGeneratedSentencesCount;
      OverwrittenSentencesCount = original.OverwrittenSentencesCount;
      PhraseExpansionCount = original.PhraseExpansionCount;

      if (original.variableImportance != null)
        variableImportance = new Dictionary<VariableTerminalSymbol, double>(original.variableImportance);
    }

    // Re-prepares the algorithm whenever the search data structure configuration changes.
    private void RegisterParameterEventHandlers() {
      SearchDataStructureParameter.Value.ValueChanged += (o, e) => Prepare();
      SearchDataStructureSizeParameter.Value.ValueChanged += (o, e) => Prepare();
      SearchCacheSizeParameter.Value.ValueChanged += (o, e) => Prepare();
    }

    // Restores all subscriptions for an instance whose members were cloned or deserialized.
    private void RegisterEventHandlers() {
      RegisterParameterEventHandlers();
      foreach (var analyzer in Analyzers.CheckedItems)
        analyzer.Register(this);
      Analyzers.CheckedItemsChanged += Analyzers_CheckedItemsChanged;
    }
    #endregion

    private Dictionary<VariableTerminalSymbol, double> variableImportance;

    /// <summary>
    /// Resets all run state, rebuilds the grammar from the problem's allowed input variables and the
    /// checked grammar symbols, and creates a fresh search data store.
    /// </summary>
    public override void Prepare() {
      DistinctSentencesComplexity = new Dictionary<int, int>();
      ArchivedPhrases = new HashSet<int>();
      AllGeneratedSentencesCount = 0;
      OverwrittenSentencesCount = 0;
      PhraseExpansionCount = 0;

      Analyzers.OfType<RSquaredEvaluator>().First().OptimizeConstants = OptimizeConstants;

      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray(), GrammarSymbols.CheckedItems.Select(v => v.Value));
      OpenPhrases = new SearchDataStore(SearchDataStructure, SearchDataStructureSize, SearchCacheSize); // Select search strategy

      base.Prepare(); // this actually clears the results which will get reinitialized on Run()
    }

    /// <summary>
    /// Trains a random forest on the problem and stores each variable's shuffle impact (training
    /// partition), normalized by the sum of all impacts, in <c>variableImportance</c>.
    /// </summary>
    /// <returns>The newly filled <c>variableImportance</c> dictionary.</returns>
    private Dictionary<VariableTerminalSymbol, double> CalculateVariableImportances() {
      variableImportance = new Dictionary<VariableTerminalSymbol, double>();

      RandomForestRegression rf = new RandomForestRegression();
      rf.Problem = Problem;
      rf.Start();
      IRegressionSolution rfSolution = (RandomForestRegressionSolution)rf.Results["Random forest regression solution"].Value;

      // Materialize once: the impacts are enumerated both for the sum and for the loop below.
      var rfImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(
        rfSolution,
        RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training,
        RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle).ToList();

      // save the normalized importances
      var sum = rfImpacts.Sum(x => x.Item2);
      foreach (Tuple<string, double> rfImpact in rfImpacts) {
        VariableTerminalSymbol varSym = Grammar.VarTerminals.First(v => v.StringRepresentation == rfImpact.Item1);
        // A zero total would yield NaN here and propagate into every phrase priority.
        variableImportance[varSym] = sum != 0.0 ? rfImpact.Item2 / sum : 0.0;
      }
      return variableImportance;
    }

    /// <summary>
    /// Main enumeration loop. Unless resuming from a pause, it recomputes the variable importances,
    /// resets the results and seeds the store with the start symbol. It then repeatedly fetches a
    /// phrase, expands its next nonterminal with every production and either reports the resulting
    /// sentences or stores the new, unseen phrases.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      // do not reinitialize the algorithm if we're resuming from pause
      if (previousExecutionState != ExecutionState.Paused) {
        CalculateVariableImportances();
        InitResults();
        var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
        var phrase0Hash = Grammar.Hasher.CalcHashCode(phrase0);

        OpenPhrases.Store(new SearchNode(phrase0Hash, 0.0, 0.0, phrase0));
      }

      int maxSentenceLength = GetMaxSentenceLength();
      var errorWeight = ErrorWeight;
      var variableImportanceWeight = VariableImportanceWeight;

      // main search loop
      while (OpenPhrases.Count > 0) {
        if (cancellationToken.IsCancellationRequested)
          break;

        SearchNode fetchedSearchNode = OpenPhrases.GetNext();
        if (fetchedSearchNode == null)
          continue;

        SymbolString currPhrase = fetchedSearchNode.SymbolString;

        OnPhraseFetched(fetchedSearchNode.Hash, currPhrase);

        ArchivedPhrases.Add(fetchedSearchNode.Hash);

        // expand next nonterminal symbols
        int nonterminalSymbolIndex = currPhrase.NextNonterminalIndex();
        NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];
        var appliedProductions = Grammar.Productions[expandedSymbol];

        for (int i = 0; i < appliedProductions.Count; i++) {
          PhraseExpansionCount++;

          SymbolString newPhrase = currPhrase.DerivePhrase(nonterminalSymbolIndex, appliedProductions[i]);
          int newPhraseComplexity = Grammar.GetComplexity(newPhrase);

          if (newPhraseComplexity > MaxComplexity)
            continue;

          var phraseHash = Grammar.Hasher.CalcHashCode(newPhrase);

          OnPhraseDerived(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);

          if (newPhrase.IsSentence()) {
            AllGeneratedSentencesCount++;

            OnSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);

            // Is the best solution found? (only if RSquaredEvaluator is activated)
            if (Results.ContainsKey(RSquaredEvaluator.BestTrainingQualityResultName)) {
              double r2 = ((DoubleValue)Results[RSquaredEvaluator.BestTrainingQualityResultName].Value).Value;
              if (r2.IsAlmost(1.0)) {
                UpdateView(force: true);
                return;
              }
            }

            // Record the sentence if it is new or was previously seen only with a higher complexity.
            int knownComplexity;
            bool isKnown = DistinctSentencesComplexity.TryGetValue(phraseHash, out knownComplexity);
            if (!isKnown || knownComplexity > newPhraseComplexity) {
              if (isKnown) OverwrittenSentencesCount++; // for analysis only

              DistinctSentencesComplexity[phraseHash] = newPhraseComplexity;
              OnDistinctSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
            }
            UpdateView();

          } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {

            double r2 = GetR2(newPhrase, fetchedSearchNode.R2);
            double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength, errorWeight, variableImportanceWeight);

            SearchNode newSearchNode = new SearchNode(phraseHash, phrasePriority, r2, newPhrase);
            OpenPhrases.Store(newSearchNode);
          }
        }
      }
      UpdateView(force: true);
    }

    /// <summary>
    /// Priority of a phrase: its length relative to <paramref name="maxSentenceLength"/>, plus
    /// <paramref name="errorWeight"/> times (1 - r²), plus <paramref name="variableImportanceWeight"/>
    /// times the share of normalized variable importance not covered by the phrase's distinct variables.
    /// </summary>
    protected double GetPriority(SymbolString phrase, double r2, int maxSentenceLength, double errorWeight, double variableImportanceWeight) {
      var distinctVars = phrase.OfType<VariableTerminalSymbol>().Distinct();

      var sum = 0d;
      foreach (var v in distinctVars) {
        sum += variableImportance[v];
      }
      var phraseVariableImportance = 1 - sum;

      double relLength = (double)phrase.Count() / maxSentenceLength;
      double error = 1.0 - r2;

      return relLength + errorWeight * error + variableImportanceWeight * phraseVariableImportance;
    }

    /// <summary>
    /// Returns the phrase's r² when <c>Grammar.Expr</c> is its only kind of nonterminal; otherwise
    /// returns <paramref name="parentR2"/> unchanged.
    /// </summary>
    private double GetR2(SymbolString phrase, double parentR2) {
      int length = phrase.Count();

      // If the only nonterminal symbol is Expr, we need to evaluate the phrase. Otherwise
      // the phrase has the same r2 as its parent, from which it was derived.
      for (int i = 0; i < length; i++) {
        if (phrase[i] is NonterminalSymbol && phrase[i] != Grammar.Expr) {
          return parentR2;
        }
      }

      return Grammar.EvaluatePhrase(phrase, Problem.ProblemData, OptimizeConstants);
    }

    /// <summary>
    /// Greedily derives a phrase from the start symbol until it is a sentence or its complexity
    /// exceeds <see cref="MaxComplexity"/>, and returns its symbol count. The count is used to
    /// normalize phrase lengths in <see cref="GetPriority"/>.
    /// </summary>
    private int GetMaxSentenceLength() {
      SymbolString s = new SymbolString(Grammar.StartSymbol);

      while (!s.IsSentence() && Grammar.GetComplexity(s) <= MaxComplexity) {
        int expandedSymbolIndex = s.NextNonterminalIndex();
        NonterminalSymbol expandedSymbol = (NonterminalSymbol)s[expandedSymbolIndex];

        var productions = Grammar.Productions[expandedSymbol];
        // Picks the production with the FEWEST terminal symbols and, among those, the MOST
        // nonterminal symbols.
        // NOTE(review): an earlier comment claimed "most terminals ... lowest nonterminal count",
        // which is the opposite of this ordering. Confirm the intended ordering before changing it.
        var longestProduction = productions
          .OrderBy(CountTerminals)
          .ThenByDescending(CountNonTerminals)
          .First();

        s = s.DerivePhrase(expandedSymbolIndex, longestProduction);
      }

      return s.Count();
    }

    private int CountTerminals(Production p) {
      return p.Count(s => s is TerminalSymbol);
    }

    private int CountNonTerminals(Production p) {
      return p.Count(s => s is NonterminalSymbol);
    }

    #region pause support
    // Run() uses this to decide whether to resume (Paused) or start a fresh enumeration.
    private ExecutionState previousExecutionState;

    protected override void OnPaused() {
      previousExecutionState = this.ExecutionState;
      base.OnPaused();
    }

    protected override void OnPrepared() {
      previousExecutionState = this.ExecutionState;
      base.OnPrepared();
    }

    protected override void OnStarted() {
      previousExecutionState = this.ExecutionState;
      base.OnStarted();
    }

    // When a best sentence is known, builds a scaled symbolic regression model/solution from it
    // and publishes model, solution and complexity as results.
    protected override void OnStopped() {
      previousExecutionState = this.ExecutionState;

      if (BestTrainingSentence == null) {
        base.OnStopped();
        return;
      }

      var tree = Grammar.ParseSymbolicExpressionTree(BestTrainingSentence);
      var model = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      model.Scale(Problem.ProblemData);
      var bestTrainingSolution = new SymbolicRegressionSolution(model, Problem.ProblemData);

      Results.AddOrUpdateResult(BestTrainingModelResultName, model);
      Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
      Results.AddOrUpdateResult(BestComplexityResultName, new IntValue(Grammar.GetComplexity(BestTrainingSentence)));
      base.OnStopped();
    }
    #endregion

    #region Visualization in HL
    // Initialize entries in result set.
    private void InitResults() {
      Results.Clear();
      Results.Add(new Result(GeneratedPhrasesName, new IntValue(0)));
      Results.Add(new Result(SearchStructureSizeName, new IntValue(0)));
      Results.Add(new Result(GeneratedSentencesName, new IntValue(0)));
      Results.Add(new Result(DistinctSentencesName, new IntValue(0)));
      Results.Add(new Result(PhraseExpansionsName, new IntValue(0)));
      Results.Add(new Result(OverwrittenSentencesName, new IntValue(0)));
      Results.Add(new Result(AverageSentenceComplexityName, new DoubleValue(1.0)));
      Results.Add(new Result(ExpansionsPerSecondName, "In Thousand expansions per second", new IntValue(0)));
    }

    // Update the view for intermediate results in an algorithm run.
    private int updates;
    private void UpdateView(bool force = false) {
      updates++;

      // Intervals of 0 or 1 would otherwise divide by zero or never match the "== 1" test below;
      // treat them as "refresh on every call".
      int interval = GuiUpdateInterval;
      bool due = interval <= 1 || updates % interval == 1;
      if (!force && !due) return;

      ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
      ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
      ((IntValue)Results[GeneratedSentencesName].Value).Value = AllGeneratedSentencesCount;
      ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentencesComplexity.Count;
      ((IntValue)Results[PhraseExpansionsName].Value).Value = PhraseExpansionCount;
      ((DoubleValue)Results[AverageSentenceComplexityName].Value).Value = DistinctSentencesComplexity.Count > 0
        ? DistinctSentencesComplexity.Select(pair => pair.Value).Average()
        : 0;
      ((IntValue)Results[OverwrittenSentencesName].Value).Value = OverwrittenSentencesCount;

      // Guard against a zero elapsed time, which would produce Infinity/NaN before the int cast.
      double elapsedSeconds = ExecutionTime.TotalSeconds;
      ((IntValue)Results[ExpansionsPerSecondName].Value).Value = elapsedSeconds > 0
        ? (int)((PhraseExpansionCount / elapsedSeconds) / 1000.0)
        : 0;
    }
    #endregion

    #region events

    // private event handlers for analyzers
    private void Analyzers_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IGrammarEnumerationAnalyzer> args) {
      // Items whose checked state changed: present in exactly one of Items / OldItems.
      foreach (var item in args.Items.Except(args.OldItems).Union(args.OldItems.Except(args.Items))) {
        if (Analyzers.ItemChecked(item)) {
          item.Register(this);
        } else {
          item.Deregister(this);
        }
      }
    }

    public event EventHandler<PhraseEventArgs> PhraseFetched;
    private void OnPhraseFetched(int hash, SymbolString symbolString) {
      PhraseFetched?.Invoke(this, new PhraseEventArgs(hash, symbolString));
    }

    public event EventHandler<PhraseAddedEventArgs> PhraseDerived;
    private void OnPhraseDerived(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      PhraseDerived?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }

    public event EventHandler<PhraseAddedEventArgs> SentenceGenerated;
    private void OnSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      SentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }

    public event EventHandler<PhraseAddedEventArgs> DistinctSentenceGenerated;
    private void OnDistinctSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
      DistinctSentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
    }

    #endregion

  }

  #region events for analysis

  /// <summary>Carries a single phrase and its hash.</summary>
  public class PhraseEventArgs : EventArgs {
    public int Hash { get; }

    public SymbolString Phrase { get; }

    public PhraseEventArgs(int hash, SymbolString phrase) {
      Hash = hash;
      Phrase = phrase;
    }
  }

  /// <summary>Describes one derivation step: parent phrase, derived phrase and the production applied.</summary>
  public class PhraseAddedEventArgs : EventArgs {
    public int ParentHash { get; }
    public int NewHash { get; }

    public SymbolString ParentPhrase { get; }
    public SymbolString NewPhrase { get; }

    public Symbol ExpandedSymbol { get; }

    public Production ExpandedProduction { get; }

    public PhraseAddedEventArgs(int parentHash, SymbolString parentPhrase, int newHash, SymbolString newPhrase, Symbol expandedSymbol, Production expandedProduction) {
      ParentHash = parentHash;
      ParentPhrase = parentPhrase;
      NewHash = newHash;
      NewPhrase = newPhrase;
      ExpandedSymbol = expandedSymbol;
      ExpandedProduction = expandedProduction;
    }
  }

  #endregion
}