source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15907

Last change on this file since 15907 was 15907, checked in by lkammere, 17 months ago

#2886: Changes in search heuristic for solving Poly-10 problem. Adapt tree evaluation to cover non-terminal symbols.

File size: 16.0 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Security.Cryptography;
5using System.Threading;
6using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
7using HeuristicLab.Collections;
8using HeuristicLab.Common;
9using HeuristicLab.Core;
10using HeuristicLab.Data;
11using HeuristicLab.Optimization;
12using HeuristicLab.Parameters;
13using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
14using HeuristicLab.Problems.DataAnalysis;
15
16namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
17  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
18  [StorableClass]
19  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
20  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
21    #region properties and result names
22    private readonly string SearchStructureSizeName = "Search Structure Size";
23    private readonly string GeneratedPhrasesName = "Generated/Archived Phrases";
24    private readonly string GeneratedSentencesName = "Generated Sentences";
25    private readonly string DistinctSentencesName = "Distinct Sentences";
26    private readonly string PhraseExpansionsName = "Phrase Expansions";
27    private readonly string AverageSentenceComplexityName = "Avg. Sentence Complexity among Distinct";
28    private readonly string OverwrittenSentencesName = "Sentences overwritten";
29    private readonly string AnalyzersParameterName = "Analyzers";
30    private readonly string ExpansionsPerSecondName = "Expansions per second";
31
32
33    private readonly string OptimizeConstantsParameterName = "Optimize Constants";
34    private readonly string SearchDataStructureParameterName = "Search Data Structure";
35    private readonly string MaxComplexityParameterName = "Max. Complexity";
36    private readonly string GuiUpdateIntervalParameterName = "GUI Update Interval";
37
38    public override bool SupportsPause { get { return false; } }
39
40    protected IValueParameter<BoolValue> OptimizeConstantsParameter {
41      get { return (IValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
42    }
43
44    public bool OptimizeConstants {
45      get { return OptimizeConstantsParameter.Value.Value; }
46      set { OptimizeConstantsParameter.Value.Value = value; }
47    }
48
49    protected IValueParameter<IntValue> MaxComplexityParameter {
50      get { return (IValueParameter<IntValue>)Parameters[MaxComplexityParameterName]; }
51    }
52    public int MaxComplexity {
53      get { return MaxComplexityParameter.Value.Value; }
54      set { MaxComplexityParameter.Value.Value = value; }
55    }
56
57    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
58      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
59    }
60    public int GuiUpdateInterval {
61      get { return GuiUpdateIntervalParameter.Value.Value; }
62      set { GuiUpdateIntervalParameter.Value.Value = value; }
63    }
64
65    protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
66      get { return (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
67    }
68    public StorageType SearchDataStructure {
69      get { return SearchDataStructureParameter.Value.Value; }
70      set { SearchDataStructureParameter.Value.Value = value; }
71    }
72
73    public IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>> AnalyzersParameter {
74      get { return (IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>)Parameters[AnalyzersParameterName]; }
75    }
76
77    public ICheckedItemCollection<IGrammarEnumerationAnalyzer> Analyzers {
78      get { return AnalyzersParameter.Value; }
79    }
80
81    public SymbolString BestTrainingSentence { get; set; }     // Currently set in RSquaredEvaluator: quite hacky, but makes testing much easier for now...
82    #endregion
83
84    public Dictionary<int, int> DistinctSentencesComplexity { get; private set; }  // Semantically distinct sentences and their length in a run.
85    public HashSet<int> ArchivedPhrases { get; private set; }
86    internal SearchDataStore OpenPhrases { get; private set; }           // Stack/Queue/etc. for fetching the next node in the search tree. 
87
88    #region execution stats
89    public int AllGeneratedSentencesCount { get; private set; }
90
91    public int OverwrittenSentencesCount { get; private set; } // It is not guaranteed that shorter solutions are found first.
92                                                               // When longer solutions are overwritten with shorter ones,
93                                                               // this counter is increased.
94    public int PhraseExpansionCount { get; private set; }      // Number, how many times a nonterminal symbol is replaced with a production rule.
95    #endregion
96
97    public Grammar Grammar { get; private set; }
98
99
100    #region ctors
101    public override IDeepCloneable Clone(Cloner cloner) {
102      return new GrammarEnumerationAlgorithm(this, cloner);
103    }
104
105    public GrammarEnumerationAlgorithm() {
106      Problem = new RegressionProblem() {
107        ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenFunctionNine(seed: 1234).GenerateRegressionData()
108      };
109
110      Parameters.Add(new ValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(true)));
111      Parameters.Add(new ValueParameter<IntValue>(MaxComplexityParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(5)));
112      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(100000)));
113      Parameters.Add(new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.Stack)));
114
115      var availableAnalyzers = new IGrammarEnumerationAnalyzer[] {
116        new SearchGraphVisualizer(),
117        new SentenceLogger(),
118        new RSquaredEvaluator()
119      };
120      Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>(
121        AnalyzersParameterName,
122        new CheckedItemCollection<IGrammarEnumerationAnalyzer>(availableAnalyzers).AsReadOnly()));
123
124      foreach (var analyzer in Analyzers) {
125        Analyzers.SetItemCheckedState(analyzer, false);
126      }
127      Analyzers.CheckedItemsChanged += AnalyzersOnCheckedItemsChanged;
128      Analyzers.SetItemCheckedState(Analyzers.First(analyzer => analyzer is RSquaredEvaluator), true);
129      Analyzers.SetItemCheckedState(Analyzers.First(analyzer => analyzer is SentenceLogger), true);
130    }
131
132    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
133    #endregion
134
135    protected override void Run(CancellationToken cancellationToken) {
136      #region init
137      InitResults();
138
139      Analyzers.OfType<RSquaredEvaluator>().First().OptimizeConstants = OptimizeConstants;
140
141      ArchivedPhrases = new HashSet<int>();
142
143      DistinctSentencesComplexity = new Dictionary<int, int>();
144      AllGeneratedSentencesCount = 0;
145      OverwrittenSentencesCount = 0;
146      PhraseExpansionCount = 0;
147
148      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());
149
150      OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
151      var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
152      var phrase0Hash = Grammar.Hasher.CalcHashCode(phrase0);
153      #endregion
154
155      OpenPhrases.Store(phrase0Hash, 0.0, phrase0);
156      while (OpenPhrases.Count > 0) {
157        if (cancellationToken.IsCancellationRequested) break;
158
159        StoredSymbolString fetchedPhrase = OpenPhrases.GetNext();
160        SymbolString currPhrase = fetchedPhrase.SymbolString;
161
162        OnPhraseFetched(fetchedPhrase.Hash, currPhrase);
163
164        ArchivedPhrases.Add(fetchedPhrase.Hash);
165
166        // expand next nonterminal symbols
167        int nonterminalSymbolIndex = currPhrase.NextNonterminalIndex();
168        NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];
169        var appliedProductions = Grammar.Productions[expandedSymbol];
170
171        for (int i = 0; i < appliedProductions.Count; i++) {
172          PhraseExpansionCount++;
173
174          SymbolString newPhrase = currPhrase.DerivePhrase(nonterminalSymbolIndex, appliedProductions[i]);
175          int newPhraseComplexity = Grammar.GetComplexity(newPhrase);
176
177          if (newPhraseComplexity <= MaxComplexity) {
178            var phraseHash = Grammar.Hasher.CalcHashCode(newPhrase);
179
180            OnPhraseDerived(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
181
182            if (newPhrase.IsSentence()) {
183              AllGeneratedSentencesCount++;
184
185              OnSentenceGenerated(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
186
187              // Is the best solution found? (only if RSquaredEvaluator is activated)
188              if (Results.ContainsKey(RSquaredEvaluator.BestTrainingQualityResultName)) {
189                double r2 = ((DoubleValue)Results[RSquaredEvaluator.BestTrainingQualityResultName].Value).Value;
190                if (r2.IsAlmost(1.0)) {
191                  UpdateView(force: true);
192                  return;
193                }
194              }
195
196              if (!DistinctSentencesComplexity.ContainsKey(phraseHash) || DistinctSentencesComplexity[phraseHash] > newPhraseComplexity) {
197                if (DistinctSentencesComplexity.ContainsKey(phraseHash)) OverwrittenSentencesCount++; // for analysis only
198
199                DistinctSentencesComplexity[phraseHash] = newPhraseComplexity;
200                OnDistinctSentenceGenerated(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
201              }
202              UpdateView();
203
204            } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
205              double phrasePriority = GetPriority(newPhrase);
206              OpenPhrases.Store(phraseHash, phrasePriority, newPhrase);
207            }
208          }
209        }
210      }
211      UpdateView(force: true);
212    }
213
214    protected double GetPriority(SymbolString phrase) {
215      double complexity = (double)Grammar.GetComplexity(phrase);
216
217      double length = phrase.Count();
218      double relLength = (length - 2) / (MaxComplexity * 7);
219      double r2 = Grammar.EvaluatePhrase(phrase, Problem.ProblemData, OptimizeConstants);
220      double error = 1.0 - r2;
221
222      double variables = 0;
223      for (int i = 0; i < phrase.Count(); i++) {
224        if (phrase[i] is VariableTerminalSymbol) variables++;
225      }
226
227      double variableRatio = 1.0 - variables / complexity;
228
229      return 1.5*relLength + error;
230    }
231
232    #region Visualization in HL
233    // Initialize entries in result set.
234    private void InitResults() {
235      Results.Add(new Result(GeneratedPhrasesName, new IntValue(0)));
236      Results.Add(new Result(SearchStructureSizeName, new IntValue(0)));
237      Results.Add(new Result(GeneratedSentencesName, new IntValue(0)));
238      Results.Add(new Result(DistinctSentencesName, new IntValue(0)));
239      Results.Add(new Result(PhraseExpansionsName, new IntValue(0)));
240      Results.Add(new Result(OverwrittenSentencesName, new IntValue(0)));
241      Results.Add(new Result(AverageSentenceComplexityName, new DoubleValue(1.0)));
242      Results.Add(new Result(ExpansionsPerSecondName, "In Thousand expansions per second", new IntValue(0)));
243    }
244
245    // Update the view for intermediate results in an algorithm run.
246    private int updates;
247    private void UpdateView(bool force = false) {
248      updates++;
249
250      if (force || updates % GuiUpdateInterval == 1) {
251        ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
252        ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
253        ((IntValue)Results[GeneratedSentencesName].Value).Value = AllGeneratedSentencesCount;
254        ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentencesComplexity.Count;
255        ((IntValue)Results[PhraseExpansionsName].Value).Value = PhraseExpansionCount;
256        ((DoubleValue)Results[AverageSentenceComplexityName].Value).Value = DistinctSentencesComplexity.Select(pair => pair.Value).Average();
257        ((IntValue)Results[OverwrittenSentencesName].Value).Value = OverwrittenSentencesCount;
258        ((IntValue)Results[ExpansionsPerSecondName].Value).Value = (int)((PhraseExpansionCount /
259                                                                          ExecutionTime.TotalSeconds) / 1000.0);
260      }
261    }
262    #endregion
263
264    #region events
265    private void AnalyzersOnCheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IGrammarEnumerationAnalyzer> collectionItemsChangedEventArgs) {
266      foreach (IGrammarEnumerationAnalyzer grammarEnumerationAnalyzer in collectionItemsChangedEventArgs.Items) {
267        if (Analyzers.ItemChecked(grammarEnumerationAnalyzer)) {
268          grammarEnumerationAnalyzer.Register(this);
269        } else {
270          grammarEnumerationAnalyzer.Deregister(this);
271        }
272      }
273    }
274
275    public event EventHandler<PhraseEventArgs> PhraseFetched;
276    private void OnPhraseFetched(int hash, SymbolString symbolString) {
277      if (PhraseFetched != null) {
278        PhraseFetched(this, new PhraseEventArgs(hash, symbolString));
279      }
280    }
281
282    public event EventHandler<PhraseAddedEventArgs> PhraseDerived;
283    private void OnPhraseDerived(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
284      if (PhraseDerived != null) {
285        PhraseDerived(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
286      }
287    }
288
289    public event EventHandler<PhraseAddedEventArgs> SentenceGenerated;
290    private void OnSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
291      if (SentenceGenerated != null) {
292        SentenceGenerated(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
293      }
294    }
295
296    public event EventHandler<PhraseAddedEventArgs> DistinctSentenceGenerated;
297    private void OnDistinctSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
298      if (DistinctSentenceGenerated != null) {
299        DistinctSentenceGenerated(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
300      }
301    }
302
303    #endregion
304
305  }
306
307  #region events for analysis
308
309  public class PhraseEventArgs : EventArgs {
310    public int Hash { get; }
311
312    public SymbolString Phrase { get; }
313
314    public PhraseEventArgs(int hash, SymbolString phrase) {
315      Hash = hash;
316      Phrase = phrase;
317    }
318  }
319
320  public class PhraseAddedEventArgs : EventArgs {
321    public int ParentHash { get; }
322    public int NewHash { get; }
323
324    public SymbolString ParentPhrase { get; }
325    public SymbolString NewPhrase { get; }
326
327    public Symbol ExpandedSymbol { get; }
328
329    public Production ExpandedProduction { get; }
330
331    public PhraseAddedEventArgs(int parentHash, SymbolString parentPhrase, int newHash, SymbolString newPhrase, Symbol expandedSymbol, Production expandedProduction) {
332      ParentHash = parentHash;
333      ParentPhrase = parentPhrase;
334      NewHash = newHash;
335      NewPhrase = newPhrase;
336      ExpandedSymbol = expandedSymbol;
337      ExpandedProduction = expandedProduction;
338    }
339  }
340
341  #endregion
342}
Note: See TracBrowser for help on using the repository browser.