Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15915

Last change on this file since 15915 was 15915, checked in by lkammere, 6 years ago

#2886: Add separate data structure for storing phrases in the queue.

File size: 17.8 KB
RevLine 
[15765]1using System;
2using System.Collections.Generic;
[15712]3using System.Linq;
4using System.Threading;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
[15821]6using HeuristicLab.Collections;
[15712]7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Optimization;
[15722]11using HeuristicLab.Parameters;
[15712]12using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
13using HeuristicLab.Problems.DataAnalysis;
14
15namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
16  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
17  [StorableClass]
18  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
19  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
[15746]20    #region properties and result names
[15803]21    private readonly string SearchStructureSizeName = "Search Structure Size";
22    private readonly string GeneratedPhrasesName = "Generated/Archived Phrases";
[15746]23    private readonly string GeneratedSentencesName = "Generated Sentences";
24    private readonly string DistinctSentencesName = "Distinct Sentences";
25    private readonly string PhraseExpansionsName = "Phrase Expansions";
[15860]26    private readonly string AverageSentenceComplexityName = "Avg. Sentence Complexity among Distinct";
[15821]27    private readonly string OverwrittenSentencesName = "Sentences overwritten";
28    private readonly string AnalyzersParameterName = "Analyzers";
[15824]29    private readonly string ExpansionsPerSecondName = "Expansions per second";
[15712]30
[15746]31
[15861]32    private readonly string OptimizeConstantsParameterName = "Optimize Constants";
[15910]33    private readonly string ErrorWeightParameterName = "Error Weight";
[15746]34    private readonly string SearchDataStructureParameterName = "Search Data Structure";
[15860]35    private readonly string MaxComplexityParameterName = "Max. Complexity";
[15722]36    private readonly string GuiUpdateIntervalParameterName = "GUI Update Interval";
[15765]37
[15746]38    public override bool SupportsPause { get { return false; } }
[15712]39
[15861]40    protected IValueParameter<BoolValue> OptimizeConstantsParameter {
41      get { return (IValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
42    }
43
44    public bool OptimizeConstants {
45      get { return OptimizeConstantsParameter.Value.Value; }
46      set { OptimizeConstantsParameter.Value.Value = value; }
47    }
48
[15860]49    protected IValueParameter<IntValue> MaxComplexityParameter {
50      get { return (IValueParameter<IntValue>)Parameters[MaxComplexityParameterName]; }
[15712]51    }
[15860]52    public int MaxComplexity {
53      get { return MaxComplexityParameter.Value.Value; }
54      set { MaxComplexityParameter.Value.Value = value; }
[15722]55    }
[15712]56
[15910]57    protected IValueParameter<DoubleValue> ErrorWeightParameter {
58      get { return (IValueParameter<DoubleValue>)Parameters[ErrorWeightParameterName]; }
59    }
60    public double ErrorWeight {
61      get { return ErrorWeightParameter.Value.Value; }
62      set { ErrorWeightParameter.Value.Value = value; }
63    }
64
[15723]65    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
66      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
[15722]67    }
68    public int GuiUpdateInterval {
69      get { return GuiUpdateIntervalParameter.Value.Value; }
[15723]70      set { GuiUpdateIntervalParameter.Value.Value = value; }
[15722]71    }
[15712]72
[15746]73    protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
74      get { return (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
[15723]75    }
[15746]76    public StorageType SearchDataStructure {
77      get { return SearchDataStructureParameter.Value.Value; }
78      set { SearchDataStructureParameter.Value.Value = value; }
[15723]79    }
80
[15821]81    public IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>> AnalyzersParameter {
82      get { return (IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>)Parameters[AnalyzersParameterName]; }
83    }
84
85    public ICheckedItemCollection<IGrammarEnumerationAnalyzer> Analyzers {
86      get { return AnalyzersParameter.Value; }
87    }
88
[15824]89    public SymbolString BestTrainingSentence { get; set; }     // Currently set in RSquaredEvaluator: quite hacky, but makes testing much easier for now...
[15722]90    #endregion
[15712]91
[15860]92    public Dictionary<int, int> DistinctSentencesComplexity { get; private set; }  // Semantically distinct sentences and their length in a run.
[15812]93    public HashSet<int> ArchivedPhrases { get; private set; }
[15821]94    internal SearchDataStore OpenPhrases { get; private set; }           // Stack/Queue/etc. for fetching the next node in the search tree. 
[15812]95
[15821]96    #region execution stats
97    public int AllGeneratedSentencesCount { get; private set; }
[15746]98
[15821]99    public int OverwrittenSentencesCount { get; private set; } // It is not guaranteed that shorter solutions are found first.
100                                                               // When longer solutions are overwritten with shorter ones,
101                                                               // this counter is increased.
102    public int PhraseExpansionCount { get; private set; }      // Number, how many times a nonterminal symbol is replaced with a production rule.
103    #endregion
104
[15800]105    public Grammar Grammar { get; private set; }
[15712]106
[15765]107
[15722]108    #region ctors
109    public override IDeepCloneable Clone(Cloner cloner) {
110      return new GrammarEnumerationAlgorithm(this, cloner);
111    }
[15712]112
[15722]113    public GrammarEnumerationAlgorithm() {
[15723]114      Problem = new RegressionProblem() {
[15910]115        ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.PolyTen(seed: 1234).GenerateRegressionData()
[15723]116      };
117
[15910]118      Parameters.Add(new ValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false)));
119      Parameters.Add(new ValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8)));
120      Parameters.Add(new ValueParameter<IntValue>(MaxComplexityParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(12)));
121      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(5000)));
122      Parameters.Add(new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.PriorityQueue)));
[15821]123
124      var availableAnalyzers = new IGrammarEnumerationAnalyzer[] {
125        new SearchGraphVisualizer(),
[15824]126        new SentenceLogger(),
127        new RSquaredEvaluator()
[15821]128      };
129      Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>(
130        AnalyzersParameterName,
131        new CheckedItemCollection<IGrammarEnumerationAnalyzer>(availableAnalyzers).AsReadOnly()));
132
133      foreach (var analyzer in Analyzers) {
134        Analyzers.SetItemCheckedState(analyzer, false);
135      }
[15824]136      Analyzers.SetItemCheckedState(Analyzers.First(analyzer => analyzer is RSquaredEvaluator), true);
[15910]137      //Analyzers.SetItemCheckedState(Analyzers.First(analyzer => analyzer is SentenceLogger), true);
[15722]138    }
[15712]139
[15910]140    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) {
141
142
143    }
[15722]144    #endregion
[15712]145
[15722]146    protected override void Run(CancellationToken cancellationToken) {
[15746]147      #region init
148      InitResults();
[15723]149
[15910]150      foreach (IGrammarEnumerationAnalyzer grammarEnumerationAnalyzer in Analyzers) {
151        if (Analyzers.ItemChecked(grammarEnumerationAnalyzer)) {
152          grammarEnumerationAnalyzer.Register(this);
153        } else {
154          grammarEnumerationAnalyzer.Deregister(this);
155        }
156      }
157
[15861]158      Analyzers.OfType<RSquaredEvaluator>().First().OptimizeConstants = OptimizeConstants;
159
[15812]160      ArchivedPhrases = new HashSet<int>();
161
[15860]162      DistinctSentencesComplexity = new Dictionary<int, int>();
[15821]163      AllGeneratedSentencesCount = 0;
164      OverwrittenSentencesCount = 0;
165      PhraseExpansionCount = 0;
[15746]166
[15724]167      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());
[15712]168
[15746]169      OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
[15734]170      var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
[15832]171      var phrase0Hash = Grammar.Hasher.CalcHashCode(phrase0);
[15746]172      #endregion
[15712]173
[15910]174      int maxSentenceLength = GetMaxSentenceLength();
175
[15915]176      OpenPhrases.Store(new SearchNode(phrase0Hash, 0.0, 0.0, phrase0));
[15821]177      while (OpenPhrases.Count > 0) {
178        if (cancellationToken.IsCancellationRequested) break;
[15746]179
[15915]180        SearchNode fetchedSearchNode = OpenPhrases.GetNext();
181        SymbolString currPhrase = fetchedSearchNode.SymbolString;
[15722]182
[15915]183        OnPhraseFetched(fetchedSearchNode.Hash, currPhrase);
[15765]184
[15915]185        ArchivedPhrases.Add(fetchedSearchNode.Hash);
[15726]186
[15821]187        // expand next nonterminal symbols
[15827]188        int nonterminalSymbolIndex = currPhrase.NextNonterminalIndex();
189        NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];
[15834]190        var appliedProductions = Grammar.Productions[expandedSymbol];
[15734]191
[15827]192        for (int i = 0; i < appliedProductions.Count; i++) {
[15821]193          PhraseExpansionCount++;
[15734]194
[15827]195          SymbolString newPhrase = currPhrase.DerivePhrase(nonterminalSymbolIndex, appliedProductions[i]);
[15860]196          int newPhraseComplexity = Grammar.GetComplexity(newPhrase);
[15712]197
[15860]198          if (newPhraseComplexity <= MaxComplexity) {
[15832]199            var phraseHash = Grammar.Hasher.CalcHashCode(newPhrase);
[15765]200
[15915]201            OnPhraseDerived(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
[15800]202
[15821]203            if (newPhrase.IsSentence()) {
204              AllGeneratedSentencesCount++;
205
[15915]206              OnSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
[15821]207
[15883]208              // Is the best solution found? (only if RSquaredEvaluator is activated)
[15907]209              if (Results.ContainsKey(RSquaredEvaluator.BestTrainingQualityResultName)) {
210                double r2 = ((DoubleValue)Results[RSquaredEvaluator.BestTrainingQualityResultName].Value).Value;
211                if (r2.IsAlmost(1.0)) {
212                  UpdateView(force: true);
213                  return;
214                }
[15883]215              }
216
[15860]217              if (!DistinctSentencesComplexity.ContainsKey(phraseHash) || DistinctSentencesComplexity[phraseHash] > newPhraseComplexity) {
218                if (DistinctSentencesComplexity.ContainsKey(phraseHash)) OverwrittenSentencesCount++; // for analysis only
[15821]219
[15860]220                DistinctSentencesComplexity[phraseHash] = newPhraseComplexity;
[15915]221                OnDistinctSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
[15746]222              }
[15821]223              UpdateView();
224
225            } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
[15915]226
227              double r2 = GetR2(newPhrase, fetchedSearchNode.R2);
228              double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength);
229
230              SearchNode newSearchNode = new SearchNode(phraseHash, phrasePriority, r2, newPhrase);
231              OpenPhrases.Store(newSearchNode);
[15712]232            }
233          }
234        }
235      }
[15812]236      UpdateView(force: true);
[15746]237    }
[15723]238
[15915]239    protected double GetPriority(SymbolString phrase, double r2, int maxSentenceLength) {
[15910]240      double relLength = (double)phrase.Count() / maxSentenceLength;
[15907]241      double error = 1.0 - r2;
242
[15910]243      return relLength + ErrorWeight * error;
244    }
245
[15915]246    private double GetR2(SymbolString phrase, double parentR2) {
247      int length = phrase.Count();
248
249      // If the only nonterminal symbol is Expr, we can need to evaluate the sentence. Otherwise
250      // the phrase has the same r2 as its parent, from which it was derived.
251      for (int i = 0; i < length; i++) {
252        if (phrase[i] is NonterminalSymbol && phrase[i] != Grammar.Expr) {
253          return parentR2;
254        }
255      }
256
257      return Grammar.EvaluatePhrase(phrase, Problem.ProblemData, OptimizeConstants);
258    }
259
[15910]260    private int GetMaxSentenceLength() {
261      SymbolString s = new SymbolString(Grammar.StartSymbol);
262
263      while (Grammar.GetComplexity(s) <= MaxComplexity) {
264        int expandedSymbolIndex = s.NextNonterminalIndex();
265        NonterminalSymbol expandedSymbol = (NonterminalSymbol)s[expandedSymbolIndex];
266
267        var productions = Grammar.Productions[expandedSymbol];
268        var longestProduction = productions // Find production with most terminal symbols to expand as much as possible...
269          .OrderBy(CountTerminals)          // but with lowest complexity/nonterminal count to keep complexity low.                                                                                     
270          .ThenByDescending(CountNonTerminals)
271          .First();
272
273        s = s.DerivePhrase(expandedSymbolIndex, longestProduction);
[15907]274      }
275
[15910]276      return s.Count();
277    }
[15907]278
[15910]279    private int CountTerminals(Production p) {
280      return p.Count(s => s is TerminalSymbol);
[15883]281    }
282
[15910]283    private int CountNonTerminals(Production p) {
284      return p.Count(s => s is NonterminalSymbol);
285    }
286
[15821]287    #region Visualization in HL
[15746]288    // Initialize entries in result set.
289    private void InitResults() {
[15803]290      Results.Add(new Result(GeneratedPhrasesName, new IntValue(0)));
291      Results.Add(new Result(SearchStructureSizeName, new IntValue(0)));
[15746]292      Results.Add(new Result(GeneratedSentencesName, new IntValue(0)));
293      Results.Add(new Result(DistinctSentencesName, new IntValue(0)));
294      Results.Add(new Result(PhraseExpansionsName, new IntValue(0)));
[15821]295      Results.Add(new Result(OverwrittenSentencesName, new IntValue(0)));
[15860]296      Results.Add(new Result(AverageSentenceComplexityName, new DoubleValue(1.0)));
[15824]297      Results.Add(new Result(ExpansionsPerSecondName, "In Thousand expansions per second", new IntValue(0)));
[15712]298    }
[15746]299
300    // Update the view for intermediate results in an algorithm run.
301    private int updates;
[15812]302    private void UpdateView(bool force = false) {
[15746]303      updates++;
304
[15812]305      if (force || updates % GuiUpdateInterval == 1) {
[15803]306        ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
307        ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
[15821]308        ((IntValue)Results[GeneratedSentencesName].Value).Value = AllGeneratedSentencesCount;
[15860]309        ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentencesComplexity.Count;
[15821]310        ((IntValue)Results[PhraseExpansionsName].Value).Value = PhraseExpansionCount;
[15860]311        ((DoubleValue)Results[AverageSentenceComplexityName].Value).Value = DistinctSentencesComplexity.Select(pair => pair.Value).Average();
[15821]312        ((IntValue)Results[OverwrittenSentencesName].Value).Value = OverwrittenSentencesCount;
[15824]313        ((IntValue)Results[ExpansionsPerSecondName].Value).Value = (int)((PhraseExpansionCount /
314                                                                          ExecutionTime.TotalSeconds) / 1000.0);
[15746]315      }
316    }
[15821]317    #endregion
[15746]318
[15821]319    #region events
320    public event EventHandler<PhraseEventArgs> PhraseFetched;
321    private void OnPhraseFetched(int hash, SymbolString symbolString) {
322      if (PhraseFetched != null) {
323        PhraseFetched(this, new PhraseEventArgs(hash, symbolString));
[15746]324      }
325    }
[15812]326
[15821]327    public event EventHandler<PhraseAddedEventArgs> PhraseDerived;
328    private void OnPhraseDerived(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
329      if (PhraseDerived != null) {
330        PhraseDerived(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
331      }
[15803]332    }
[15765]333
[15821]334    public event EventHandler<PhraseAddedEventArgs> SentenceGenerated;
335    private void OnSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
336      if (SentenceGenerated != null) {
337        SentenceGenerated(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
338      }
339    }
340
341    public event EventHandler<PhraseAddedEventArgs> DistinctSentenceGenerated;
342    private void OnDistinctSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
343      if (DistinctSentenceGenerated != null) {
344        DistinctSentenceGenerated(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
345      }
346    }
347
348    #endregion
349
[15712]350  }
[15821]351
352  #region events for analysis
353
354  public class PhraseEventArgs : EventArgs {
355    public int Hash { get; }
356
357    public SymbolString Phrase { get; }
358
359    public PhraseEventArgs(int hash, SymbolString phrase) {
360      Hash = hash;
361      Phrase = phrase;
362    }
363  }
364
365  public class PhraseAddedEventArgs : EventArgs {
366    public int ParentHash { get; }
367    public int NewHash { get; }
368
369    public SymbolString ParentPhrase { get; }
370    public SymbolString NewPhrase { get; }
371
372    public Symbol ExpandedSymbol { get; }
373
374    public Production ExpandedProduction { get; }
375
376    public PhraseAddedEventArgs(int parentHash, SymbolString parentPhrase, int newHash, SymbolString newPhrase, Symbol expandedSymbol, Production expandedProduction) {
377      ParentHash = parentHash;
378      ParentPhrase = parentPhrase;
379      NewHash = newHash;
380      NewPhrase = newPhrase;
381      ExpandedSymbol = expandedSymbol;
382      ExpandedProduction = expandedProduction;
383    }
384  }
385
386  #endregion
[15712]387}
Note: See TracBrowser for help on using the repository browser.