Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15930

Last change on this file since 15930 was 15930, checked in by lkammere, 6 years ago

#2886: Make grammar more configurable in grammar enumeration.

File size: 18.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
6using HeuristicLab.Collections;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Optimization;
11using HeuristicLab.Parameters;
12using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
13using HeuristicLab.Problems.DataAnalysis;
14
15namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
16  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
17  [StorableClass]
18  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
19  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
20    #region properties and result names
// Names of the entries this algorithm adds to its result collection.
// Declared const: the values are fixed literals, so there is no need for a per-instance field each.
private const string SearchStructureSizeName = "Search Structure Size";
private const string GeneratedPhrasesName = "Generated/Archived Phrases";
private const string GeneratedSentencesName = "Generated Sentences";
private const string DistinctSentencesName = "Distinct Sentences";
private const string PhraseExpansionsName = "Phrase Expansions";
private const string AverageSentenceComplexityName = "Avg. Sentence Complexity among Distinct";
private const string OverwrittenSentencesName = "Sentences overwritten";
private const string ExpansionsPerSecondName = "Expansions per second";

// Names under which the algorithm's parameters are registered and looked up.
private const string AnalyzersParameterName = "Analyzers";
private const string OptimizeConstantsParameterName = "Optimize Constants";
private const string ErrorWeightParameterName = "Error Weight";
private const string SearchDataStructureParameterName = "Search Data Structure";
private const string MaxComplexityParameterName = "Max. Complexity";
private const string GuiUpdateIntervalParameterName = "GUI Update Interval";
private const string GrammarSymbolsParameterName = "Grammar Symbols";
38
/// <summary>This algorithm reports that it cannot be paused.</summary>
public override bool SupportsPause => false;

/// <summary>Parameter registered under <c>OptimizeConstantsParameterName</c>.</summary>
protected IValueParameter<BoolValue> OptimizeConstantsParameter =>
  (IValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName];

/// <summary>Gets or sets the boolean stored in <see cref="OptimizeConstantsParameter"/>.</summary>
public bool OptimizeConstants {
  get { return OptimizeConstantsParameter.Value.Value; }
  set { OptimizeConstantsParameter.Value.Value = value; }
}

/// <summary>Parameter registered under <c>MaxComplexityParameterName</c>.</summary>
protected IValueParameter<IntValue> MaxComplexityParameter =>
  (IValueParameter<IntValue>)Parameters[MaxComplexityParameterName];

/// <summary>Gets or sets the integer stored in <see cref="MaxComplexityParameter"/>.</summary>
public int MaxComplexity {
  get { return MaxComplexityParameter.Value.Value; }
  set { MaxComplexityParameter.Value.Value = value; }
}

/// <summary>Parameter registered under <c>ErrorWeightParameterName</c>.</summary>
protected IValueParameter<DoubleValue> ErrorWeightParameter =>
  (IValueParameter<DoubleValue>)Parameters[ErrorWeightParameterName];

/// <summary>Gets or sets the weight stored in <see cref="ErrorWeightParameter"/>.</summary>
public double ErrorWeight {
  get { return ErrorWeightParameter.Value.Value; }
  set { ErrorWeightParameter.Value.Value = value; }
}

/// <summary>Parameter registered under <c>GuiUpdateIntervalParameterName</c>.</summary>
protected IValueParameter<IntValue> GuiUpdateIntervalParameter =>
  (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName];

/// <summary>Gets or sets the integer stored in <see cref="GuiUpdateIntervalParameter"/>.</summary>
public int GuiUpdateInterval {
  get { return GuiUpdateIntervalParameter.Value.Value; }
  set { GuiUpdateIntervalParameter.Value.Value = value; }
}

/// <summary>Parameter registered under <c>SearchDataStructureParameterName</c>.</summary>
protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter =>
  (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName];

/// <summary>Gets or sets the storage type stored in <see cref="SearchDataStructureParameter"/>.</summary>
public StorageType SearchDataStructure {
  get { return SearchDataStructureParameter.Value.Value; }
  set { SearchDataStructureParameter.Value.Value = value; }
}

/// <summary>Parameter registered under <c>AnalyzersParameterName</c>.</summary>
public IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>> AnalyzersParameter =>
  (IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>)Parameters[AnalyzersParameterName];

/// <summary>The analyzer collection held by <see cref="AnalyzersParameter"/>.</summary>
public ICheckedItemCollection<IGrammarEnumerationAnalyzer> Analyzers => AnalyzersParameter.Value;

/// <summary>Parameter registered under <c>GrammarSymbolsParameterName</c>.</summary>
public IFixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>> GrammarSymbolsParameter =>
  (IFixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>>)Parameters[GrammarSymbolsParameterName];

/// <summary>The grammar rule collection held by <see cref="GrammarSymbolsParameter"/>.</summary>
public ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>> GrammarSymbols => GrammarSymbolsParameter.Value;

/// <summary>
/// Best sentence on the training data. Not assigned anywhere in this class; according to the
/// original author's note it is currently set from RSquaredEvaluator, which is hacky but
/// makes testing much easier for now.
/// </summary>
public SymbolString BestTrainingSentence { get; set; }
99    #endregion
100
/// <summary>
/// Maps the hash of each distinct sentence recorded during a run to the complexity stored for it.
/// Run replaces an entry when the same hash is derived again with a lower complexity.
/// </summary>
public Dictionary<int, int> DistinctSentencesComplexity { get; private set; }

/// <summary>Hashes of the search nodes that Run has already fetched from <see cref="OpenPhrases"/>.</summary>
public HashSet<int> ArchivedPhrases { get; private set; }

/// <summary>Store (stack/queue/etc.) from which Run fetches the next node of the search tree.</summary>
internal SearchDataStore OpenPhrases { get; private set; }

#region execution stats
/// <summary>Number of sentences derived during a run, counted before any distinctness check.</summary>
public int AllGeneratedSentencesCount { get; private set; }

/// <summary>
/// It is not guaranteed that shorter solutions are found first. This counter is increased
/// whenever a stored sentence is overwritten by one of lower complexity with the same hash.
/// </summary>
public int OverwrittenSentencesCount { get; private set; }

/// <summary>Number of times a nonterminal symbol was replaced with a production rule.</summary>
public int PhraseExpansionCount { get; private set; }
#endregion

/// <summary>Grammar instance created at the start of Run.</summary>
public Grammar Grammar { get; private set; }
115
116
117    #region ctors
/// <summary>Creates a copy of this algorithm through the cloning constructor.</summary>
/// <param name="cloner">Cloner handed on to the cloning constructor.</param>
/// <returns>The newly constructed <see cref="GrammarEnumerationAlgorithm"/>.</returns>
public override IDeepCloneable Clone(Cloner cloner) => new GrammarEnumerationAlgorithm(this, cloner);
121
/// <summary>
/// Creates the algorithm with a PolyTen regression problem (seed 1234) and registers all
/// parameters with their default values. Of the analyzers only the RSquaredEvaluator starts
/// out checked; every grammar rule starts out checked.
/// </summary>
public GrammarEnumerationAlgorithm() {
  var defaultProblem = new RegressionProblem();
  defaultProblem.ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.PolyTen(seed: 1234).GenerateRegressionData();
  Problem = defaultProblem;

  // Scalar parameters.
  Parameters.Add(new ValueParameter<BoolValue>(
    OptimizeConstantsParameterName,
    "Run constant optimization in sentence evaluation.",
    new BoolValue(false)));
  Parameters.Add(new ValueParameter<DoubleValue>(
    ErrorWeightParameterName,
    "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.",
    new DoubleValue(0.8)));
  Parameters.Add(new ValueParameter<IntValue>(
    MaxComplexityParameterName,
    "The maximum number of variable symbols in a sentence.",
    new IntValue(12)));
  Parameters.Add(new ValueParameter<IntValue>(
    GuiUpdateIntervalParameterName,
    "Number of generated sentences, until GUI is refreshed.",
    new IntValue(5000)));
  Parameters.Add(new ValueParameter<EnumValue<StorageType>>(
    SearchDataStructureParameterName,
    new EnumValue<StorageType>(StorageType.PriorityQueue)));

  // Analyzers: the set is fixed, only the checked state can be changed.
  var analyzerCollection = new CheckedItemCollection<IGrammarEnumerationAnalyzer>(new IGrammarEnumerationAnalyzer[] {
    new SearchGraphVisualizer(),
    new SentenceLogger(),
    new RSquaredEvaluator()
  });
  Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>(
    AnalyzersParameterName,
    analyzerCollection.AsReadOnly()));

  // Uncheck everything first, then switch the evaluator back on.
  foreach (var item in Analyzers) {
    Analyzers.SetItemCheckedState(item, false);
  }
  var evaluator = Analyzers.First(candidate => candidate is RSquaredEvaluator);
  Analyzers.SetItemCheckedState(evaluator, true);
  //Analyzers.SetItemCheckedState(Analyzers.First(analyzer => analyzer is SentenceLogger), true);

  // Grammar rules: one checkable entry per GrammarRule enum member.
  var ruleItems = Enum.GetValues(typeof(GrammarRule))
    .Cast<GrammarRule>()
    .Select(rule => new EnumValue<GrammarRule>(rule));
  var ruleCollection = new CheckedItemCollection<EnumValue<GrammarRule>>(ruleItems);
  Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>>(
    GrammarSymbolsParameterName,
    new ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>(ruleCollection)));

  foreach (EnumValue<GrammarRule> ruleItem in GrammarSymbols) {
    GrammarSymbols.SetItemCheckedState(ruleItem, true);
  }
}
160
/// <summary>
/// Cloning constructor used by Clone. It only forwards to the base cloning constructor and
/// copies no members declared in this class itself.
/// </summary>
/// <param name="original">Algorithm to copy from.</param>
/// <param name="cloner">Cloner passed on to the base class.</param>
public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner)
  : base(original, cloner) { }
163    #endregion
164
/// <summary>
/// Enumerates phrases starting from the grammar's start symbol. Each fetched phrase has its
/// next nonterminal replaced by every production of that symbol. Derived phrases whose
/// complexity exceeds MaxComplexity are dropped, derived sentences are counted and recorded,
/// and derived non-sentence phrases that are neither open nor archived are stored for later
/// expansion. The loop ends when no open phrases remain, when cancellation is requested, or
/// as soon as the best training quality result is almost 1.0.
/// </summary>
/// <param name="cancellationToken">Checked once per fetched phrase; when set, the search stops.</param>
protected override void Run(CancellationToken cancellationToken) {
  #region init
  InitResults();

  foreach (IGrammarEnumerationAnalyzer grammarEnumerationAnalyzer in Analyzers) {
    if (Analyzers.ItemChecked(grammarEnumerationAnalyzer)) {
      grammarEnumerationAnalyzer.Register(this);
    } else {
      grammarEnumerationAnalyzer.Deregister(this);
    }
  }

  Analyzers.OfType<RSquaredEvaluator>().First().OptimizeConstants = OptimizeConstants;

  ArchivedPhrases = new HashSet<int>();

  DistinctSentencesComplexity = new Dictionary<int, int>();
  AllGeneratedSentencesCount = 0;
  OverwrittenSentencesCount = 0;
  PhraseExpansionCount = 0;

  Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray(), GrammarSymbols.CheckedItems.Select(v => v.Value));

  OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
  var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
  var phrase0Hash = Grammar.Hasher.CalcHashCode(phrase0);
  #endregion

  int maxSentenceLength = GetMaxSentenceLength();

  OpenPhrases.Store(new SearchNode(phrase0Hash, 0.0, 0.0, phrase0));
  while (OpenPhrases.Count > 0) {
    if (cancellationToken.IsCancellationRequested) break;

    SearchNode fetchedSearchNode = OpenPhrases.GetNext();
    SymbolString currPhrase = fetchedSearchNode.SymbolString;

    OnPhraseFetched(fetchedSearchNode.Hash, currPhrase);

    ArchivedPhrases.Add(fetchedSearchNode.Hash);

    // Expand the next nonterminal symbol with each of its productions.
    int nonterminalSymbolIndex = currPhrase.NextNonterminalIndex();
    NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];
    var appliedProductions = Grammar.Productions[expandedSymbol];

    for (int i = 0; i < appliedProductions.Count; i++) {
      PhraseExpansionCount++;

      var production = appliedProductions[i];
      SymbolString newPhrase = currPhrase.DerivePhrase(nonterminalSymbolIndex, production);
      int newPhraseComplexity = Grammar.GetComplexity(newPhrase);

      // Phrases above the complexity limit are neither reported nor stored.
      if (newPhraseComplexity > MaxComplexity) continue;

      var phraseHash = Grammar.Hasher.CalcHashCode(newPhrase);

      OnPhraseDerived(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);

      if (newPhrase.IsSentence()) {
        AllGeneratedSentencesCount++;

        OnSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);

        // Is the best solution found? (only if RSquaredEvaluator is activated)
        if (Results.ContainsKey(RSquaredEvaluator.BestTrainingQualityResultName)) {
          double bestR2 = ((DoubleValue)Results[RSquaredEvaluator.BestTrainingQualityResultName].Value).Value;
          if (bestR2.IsAlmost(1.0)) {
            UpdateView(force: true);
            return;
          }
        }

        // Single dictionary lookup instead of ContainsKey + indexer + ContainsKey.
        int knownComplexity;
        bool alreadyKnown = DistinctSentencesComplexity.TryGetValue(phraseHash, out knownComplexity);
        if (!alreadyKnown || knownComplexity > newPhraseComplexity) {
          if (alreadyKnown) OverwrittenSentencesCount++; // for analysis only

          DistinctSentencesComplexity[phraseHash] = newPhraseComplexity;
          OnDistinctSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);
        }
        UpdateView();

      } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
        // New, not yet visited phrase: rate it and queue it for expansion.
        double r2 = GetR2(newPhrase, fetchedSearchNode.R2);
        double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength);

        SearchNode newSearchNode = new SearchNode(phraseHash, phrasePriority, r2, newPhrase);
        OpenPhrases.Store(newSearchNode);
      }
    }
  }
  UpdateView(force: true);
}
257
/// <summary>
/// Computes the priority value stored with a search node: the phrase length relative to
/// <paramref name="maxSentenceLength"/> plus ErrorWeight times (1 - <paramref name="r2"/>).
/// </summary>
/// <param name="phrase">Phrase whose symbol count is used.</param>
/// <param name="r2">R² value assigned to the phrase.</param>
/// <param name="maxSentenceLength">Length used to normalize the phrase length.</param>
/// <returns>The combined length and error term.</returns>
protected double GetPriority(SymbolString phrase, double r2, int maxSentenceLength) {
  double lengthTerm = phrase.Count() / (double)maxSentenceLength;
  double errorTerm = ErrorWeight * (1.0 - r2);
  return lengthTerm + errorTerm;
}
264
/// <summary>
/// Returns the R² value to associate with <paramref name="phrase"/>. The phrase is evaluated
/// only if it contains no nonterminal symbol other than Grammar.Expr; otherwise it keeps the
/// R² of the parent phrase it was derived from.
/// </summary>
/// <param name="phrase">Newly derived phrase.</param>
/// <param name="parentR2">R² value of the parent phrase.</param>
private double GetR2(SymbolString phrase, double parentR2) {
  int symbolCount = phrase.Count();
  int position = 0;

  while (position < symbolCount) {
    var symbol = phrase[position];
    bool isOtherNonterminal = symbol is NonterminalSymbol && symbol != Grammar.Expr;
    if (isOtherNonterminal) {
      // A nonterminal other than Expr is still open: inherit the parent's value.
      return parentR2;
    }
    position++;
  }

  return Grammar.EvaluatePhrase(phrase, Problem.ProblemData, OptimizeConstants);
}
278
/// <summary>
/// Derives one phrase from the start symbol by repeatedly expanding its next nonterminal,
/// until the phrase is a sentence or its complexity exceeds MaxComplexity, and returns the
/// symbol count of the resulting phrase.
/// </summary>
private int GetMaxSentenceLength() {
  var phrase = new SymbolString(Grammar.StartSymbol);

  while (true) {
    if (phrase.IsSentence()) break;
    if (Grammar.GetComplexity(phrase) > MaxComplexity) break;

    int position = phrase.NextNonterminalIndex();
    var nonterminal = (NonterminalSymbol)phrase[position];

    // Picks the production with the fewest terminal symbols and, among those, the most
    // nonterminal symbols.
    // NOTE(review): the original comment described the opposite ordering ("most terminal
    // symbols ... lowest complexity/nonterminal count") - confirm which one is intended.
    var chosenProduction = Grammar.Productions[nonterminal]
      .OrderBy(CountTerminals)
      .ThenByDescending(CountNonTerminals)
      .First();

    phrase = phrase.DerivePhrase(position, chosenProduction);
  }

  return phrase.Count();
}
297
/// <summary>Counts the symbols of <paramref name="p"/> that are TerminalSymbol instances.</summary>
private int CountTerminals(Production p) => p.OfType<TerminalSymbol>().Count();
301
/// <summary>Counts the symbols of <paramref name="p"/> that are NonterminalSymbol instances.</summary>
private int CountNonTerminals(Production p) => p.OfType<NonterminalSymbol>().Count();
305
306    #region Visualization in HL
/// <summary>
/// Adds the result entries that UpdateView later refreshes: the integer counters start at 0,
/// the average sentence complexity at 1.0.
/// </summary>
private void InitResults() {
  var counterNames = new[] {
    GeneratedPhrasesName,
    SearchStructureSizeName,
    GeneratedSentencesName,
    DistinctSentencesName,
    PhraseExpansionsName,
    OverwrittenSentencesName
  };
  foreach (string counterName in counterNames) {
    Results.Add(new Result(counterName, new IntValue(0)));
  }

  Results.Add(new Result(AverageSentenceComplexityName, new DoubleValue(1.0)));
  Results.Add(new Result(ExpansionsPerSecondName, "In Thousand expansions per second", new IntValue(0)));
}
318
// Number of UpdateView calls so far; used to refresh only every GuiUpdateInterval-th call.
private int updates;

/// <summary>
/// Writes the current run statistics into the result collection. Unless
/// <paramref name="force"/> is set, this only happens on every GuiUpdateInterval-th call.
/// </summary>
/// <param name="force">True to refresh the results regardless of the update interval.</param>
private void UpdateView(bool force = false) {
  updates++;

  if (!force) {
    int interval = GuiUpdateInterval;
    // An interval of 0 would throw on the modulo and an interval of 1 could never satisfy
    // "% interval == 1", so intervals of 1 or less refresh on every call.
    if (interval > 1 && updates % interval != 1) return;
  }

  ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
  ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
  ((IntValue)Results[GeneratedSentencesName].Value).Value = AllGeneratedSentencesCount;
  ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentencesComplexity.Count;
  ((IntValue)Results[PhraseExpansionsName].Value).Value = PhraseExpansionCount;

  // Average() throws on an empty sequence, and a forced refresh can happen before any
  // distinct sentence has been recorded. In that case the previous value is kept.
  if (DistinctSentencesComplexity.Count > 0) {
    ((DoubleValue)Results[AverageSentenceComplexityName].Value).Value = DistinctSentencesComplexity.Values.Average();
  }

  ((IntValue)Results[OverwrittenSentencesName].Value).Value = OverwrittenSentencesCount;

  // Report 0 while no time has elapsed instead of casting Infinity/NaN to int.
  double elapsedSeconds = ExecutionTime.TotalSeconds;
  int thousandExpansionsPerSecond = elapsedSeconds > 0.0
    ? (int)((PhraseExpansionCount / elapsedSeconds) / 1000.0)
    : 0;
  ((IntValue)Results[ExpansionsPerSecondName].Value).Value = thousandExpansionsPerSecond;
}
336    #endregion
337
338    #region events
/// <summary>
/// Raised by Run for each search node fetched from OpenPhrases, before the node's hash is
/// added to ArchivedPhrases.
/// </summary>
public event EventHandler<PhraseEventArgs> PhraseFetched;

// Raises PhraseFetched. The null-conditional call reads the delegate only once, so a handler
// removed between a null check and the invocation can no longer cause a NullReferenceException.
private void OnPhraseFetched(int hash, SymbolString symbolString) {
  PhraseFetched?.Invoke(this, new PhraseEventArgs(hash, symbolString));
}
345
/// <summary>
/// Raised by Run for each derived phrase whose complexity does not exceed MaxComplexity,
/// whether or not the phrase is a sentence.
/// </summary>
public event EventHandler<PhraseAddedEventArgs> PhraseDerived;

// Raises PhraseDerived. The null-conditional call reads the delegate only once, so a handler
// removed between a null check and the invocation can no longer cause a NullReferenceException.
private void OnPhraseDerived(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
  PhraseDerived?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
}
352
/// <summary>
/// Raised by Run for each derived phrase that is a sentence, before Run checks whether the
/// sentence is distinct.
/// </summary>
public event EventHandler<PhraseAddedEventArgs> SentenceGenerated;

// Raises SentenceGenerated. The null-conditional call reads the delegate only once, so a handler
// removed between a null check and the invocation can no longer cause a NullReferenceException.
private void OnSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
  SentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
}
359
/// <summary>
/// Raised by Run when a sentence is stored in DistinctSentencesComplexity, either because its
/// hash is new or because it has a lower complexity than the stored entry.
/// </summary>
public event EventHandler<PhraseAddedEventArgs> DistinctSentenceGenerated;

// Raises DistinctSentenceGenerated. The null-conditional call reads the delegate only once, so a
// handler removed between a null check and the invocation can no longer cause a NullReferenceException.
private void OnDistinctSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
  DistinctSentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
}
366
367    #endregion
368
369  }
370
371  #region events for analysis
372
/// <summary>Event data describing a single phrase by its hash and its symbol string.</summary>
public class PhraseEventArgs : EventArgs {
  /// <summary>Creates event data for the given phrase.</summary>
  /// <param name="hash">Hash value of the phrase.</param>
  /// <param name="phrase">The phrase itself.</param>
  public PhraseEventArgs(int hash, SymbolString phrase) {
    Phrase = phrase;
    Hash = hash;
  }

  /// <summary>Hash value passed to the constructor.</summary>
  public int Hash { get; }

  /// <summary>Phrase passed to the constructor.</summary>
  public SymbolString Phrase { get; }
}
383
/// <summary>
/// Event data describing a derivation step: the parent phrase, the phrase derived from it,
/// the symbol that was expanded and the production used for the expansion.
/// </summary>
public class PhraseAddedEventArgs : EventArgs {
  /// <summary>Creates event data for one derivation step.</summary>
  /// <param name="parentHash">Hash of the parent phrase.</param>
  /// <param name="parentPhrase">Phrase that was expanded.</param>
  /// <param name="newHash">Hash of the derived phrase.</param>
  /// <param name="newPhrase">Phrase resulting from the expansion.</param>
  /// <param name="expandedSymbol">Symbol that was replaced.</param>
  /// <param name="expandedProduction">Production that replaced the symbol.</param>
  public PhraseAddedEventArgs(int parentHash, SymbolString parentPhrase, int newHash, SymbolString newPhrase, Symbol expandedSymbol, Production expandedProduction) {
    // Parent side of the derivation.
    ParentHash = parentHash;
    ParentPhrase = parentPhrase;

    // Derived side of the derivation.
    NewHash = newHash;
    NewPhrase = newPhrase;

    // How the parent was turned into the derived phrase.
    ExpandedSymbol = expandedSymbol;
    ExpandedProduction = expandedProduction;
  }

  /// <summary>Hash of the parent phrase.</summary>
  public int ParentHash { get; }

  /// <summary>Phrase that was expanded.</summary>
  public SymbolString ParentPhrase { get; }

  /// <summary>Hash of the derived phrase.</summary>
  public int NewHash { get; }

  /// <summary>Phrase resulting from the expansion.</summary>
  public SymbolString NewPhrase { get; }

  /// <summary>Symbol that was replaced.</summary>
  public Symbol ExpandedSymbol { get; }

  /// <summary>Production that replaced the symbol.</summary>
  public Production ExpandedProduction { get; }
}
404
405  #endregion
406}
Note: See TracBrowser for help on using the repository browser.