Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15957

Last change on this file since 15957 was 15957, checked in by bburlacu, 6 years ago

#2886: Minor refactor; fix multiple analyzer event registration

File size: 21.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
6using HeuristicLab.Collections;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Optimization;
11using HeuristicLab.Parameters;
12using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
13using HeuristicLab.Problems.DataAnalysis;
14
15namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
16  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
17  [StorableClass]
18  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
19  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
20    #region properties and result names
// Lookup keys used with the Parameters and Results collections of this algorithm.
// Declared as constants: the values never vary per instance, so there is no need
// to carry them as instance fields.
private const string VariableImportanceWeightName = "Variable Importance Weight";
private const string SearchStructureSizeName = "Search Structure Size";
private const string GeneratedPhrasesName = "Generated/Archived Phrases";
private const string GeneratedSentencesName = "Generated Sentences";
private const string DistinctSentencesName = "Distinct Sentences";
private const string PhraseExpansionsName = "Phrase Expansions";
private const string AverageSentenceComplexityName = "Avg. Sentence Complexity among Distinct";
private const string OverwrittenSentencesName = "Sentences overwritten";
private const string AnalyzersParameterName = "Analyzers";
private const string ExpansionsPerSecondName = "Expansions per second";

private const string OptimizeConstantsParameterName = "Optimize Constants";
private const string ErrorWeightParameterName = "Error Weight";
private const string SearchDataStructureParameterName = "Search Data Structure";
private const string MaxComplexityParameterName = "Max. Complexity";
private const string GuiUpdateIntervalParameterName = "GUI Update Interval";
private const string GrammarSymbolsParameterName = "Grammar Symbols";
38
/// <summary>Always <c>false</c>: this algorithm reports that it cannot be paused.</summary>
public override bool SupportsPause => false;
40
/// <summary>Parameter stored under the "Variable Importance Weight" key.</summary>
protected IValueParameter<DoubleValue> VariableImportanceWeightParameter =>
  (IValueParameter<DoubleValue>)Parameters[VariableImportanceWeightName];

/// <summary>Current numeric value of <see cref="VariableImportanceWeightParameter"/>.</summary>
protected double VariableImportanceWeight => VariableImportanceWeightParameter.Value.Value;
48
/// <summary>Parameter stored under the "Optimize Constants" key.</summary>
protected IValueParameter<BoolValue> OptimizeConstantsParameter =>
  (IValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName];

/// <summary>Reads or writes the boolean held by <see cref="OptimizeConstantsParameter"/>.</summary>
public bool OptimizeConstants {
  get {
    BoolValue holder = OptimizeConstantsParameter.Value;
    return holder.Value;
  }
  set {
    BoolValue holder = OptimizeConstantsParameter.Value;
    holder.Value = value;
  }
}
57
/// <summary>Parameter stored under the "Max. Complexity" key.</summary>
protected IValueParameter<IntValue> MaxComplexityParameter =>
  (IValueParameter<IntValue>)Parameters[MaxComplexityParameterName];

/// <summary>Reads or writes the integer held by <see cref="MaxComplexityParameter"/>.</summary>
public int MaxComplexity {
  get {
    IntValue holder = MaxComplexityParameter.Value;
    return holder.Value;
  }
  set {
    IntValue holder = MaxComplexityParameter.Value;
    holder.Value = value;
  }
}
65
/// <summary>Parameter stored under the "Error Weight" key.</summary>
protected IValueParameter<DoubleValue> ErrorWeightParameter =>
  (IValueParameter<DoubleValue>)Parameters[ErrorWeightParameterName];

/// <summary>Reads or writes the number held by <see cref="ErrorWeightParameter"/>.</summary>
public double ErrorWeight {
  get {
    DoubleValue holder = ErrorWeightParameter.Value;
    return holder.Value;
  }
  set {
    DoubleValue holder = ErrorWeightParameter.Value;
    holder.Value = value;
  }
}
73
/// <summary>Parameter stored under the "GUI Update Interval" key.</summary>
protected IValueParameter<IntValue> GuiUpdateIntervalParameter =>
  (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName];

/// <summary>Reads or writes the integer held by <see cref="GuiUpdateIntervalParameter"/>.</summary>
public int GuiUpdateInterval {
  get {
    IntValue holder = GuiUpdateIntervalParameter.Value;
    return holder.Value;
  }
  set {
    IntValue holder = GuiUpdateIntervalParameter.Value;
    holder.Value = value;
  }
}
81
/// <summary>Parameter stored under the "Search Data Structure" key.</summary>
protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter =>
  (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName];

/// <summary>Reads or writes the enum value held by <see cref="SearchDataStructureParameter"/>.</summary>
public StorageType SearchDataStructure {
  get {
    EnumValue<StorageType> holder = SearchDataStructureParameter.Value;
    return holder.Value;
  }
  set {
    EnumValue<StorageType> holder = SearchDataStructureParameter.Value;
    holder.Value = value;
  }
}
89
/// <summary>Parameter stored under the "Analyzers" key.</summary>
public IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>> AnalyzersParameter =>
  (IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>)Parameters[AnalyzersParameterName];

/// <summary>The collection held by <see cref="AnalyzersParameter"/>.</summary>
public ICheckedItemCollection<IGrammarEnumerationAnalyzer> Analyzers => AnalyzersParameter.Value;
97
/// <summary>Parameter stored under the "Grammar Symbols" key.</summary>
public IFixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>> GrammarSymbolsParameter =>
  (IFixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>>)Parameters[GrammarSymbolsParameterName];

/// <summary>The collection held by <see cref="GrammarSymbolsParameter"/>.</summary>
public ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>> GrammarSymbols => GrammarSymbolsParameter.Value;
105
/// <summary>
/// Best sentence found on the training data. Publicly settable; according to the
/// original author's note it is currently assigned from RSquaredEvaluator, which is
/// admittedly hacky but simplifies testing for now.
/// </summary>
public SymbolString BestTrainingSentence { get; set; }
107    #endregion
108
/// <summary>Stack/queue/etc. from which the next node of the search tree is fetched.</summary>
internal SearchDataStore OpenPhrases { get; private set; }

/// <summary>Hashes of phrases that have already been fetched from <see cref="OpenPhrases"/>.</summary>
public HashSet<int> ArchivedPhrases { get; private set; }

/// <summary>Semantically distinct sentences of a run, keyed by hash, mapped to their complexity.</summary>
public Dictionary<int, int> DistinctSentencesComplexity { get; private set; }
112
113    #region execution stats
/// <summary>Number of times a nonterminal symbol was replaced with a production rule.</summary>
public int PhraseExpansionCount { get; private set; }

/// <summary>Number of sentences generated in the run, including repeated ones.</summary>
public int AllGeneratedSentencesCount { get; private set; }

/// <summary>
/// Number of times a stored sentence was replaced by a less complex one with the same
/// hash. Shorter solutions are not guaranteed to be found first.
/// </summary>
public int OverwrittenSentencesCount { get; private set; }
120    #endregion
121
/// <summary>Grammar used by the current run; assigned in <c>Prepare()</c>.</summary>
public Grammar Grammar { get; private set; }
123
124
125    #region ctors
/// <summary>Creates a copy of this algorithm through the cloning constructor.</summary>
/// <param name="cloner">Cloner handed on to the cloning constructor.</param>
/// <returns>The newly constructed copy.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new GrammarEnumerationAlgorithm(this, cloner);
  return copy;
}
129
/// <summary>
/// Creates the algorithm with its default parameters, registers the available
/// analyzers (only the RSquaredEvaluator is checked), enables every grammar rule
/// and installs a PolyTen regression problem as default problem.
/// </summary>
public GrammarEnumerationAlgorithm() {
  // Scalar settings.
  var variableWeight = new ValueParameter<DoubleValue>(VariableImportanceWeightName, "Variable Weight.", new DoubleValue(1.0));
  Parameters.Add(variableWeight);
  var optimizeConstants = new ValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false));
  Parameters.Add(optimizeConstants);
  var errorWeight = new ValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8));
  Parameters.Add(errorWeight);
  var maxComplexity = new ValueParameter<IntValue>(MaxComplexityParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(12));
  Parameters.Add(maxComplexity);
  var guiInterval = new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(5000));
  Parameters.Add(guiInterval);
  var searchStructure = new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.PriorityQueue));
  Parameters.Add(searchStructure);

  // Analyzer selection.
  var analyzerChoices = new CheckedItemCollection<IGrammarEnumerationAnalyzer>(
    new IGrammarEnumerationAnalyzer[] {
      new SearchGraphVisualizer(),
      new SentenceLogger(),
      new RSquaredEvaluator()
    });
  var analyzersParameter = new FixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>(
    AnalyzersParameterName,
    analyzerChoices.AsReadOnly());
  Parameters.Add(analyzersParameter);

  // Subscribe before touching the check states so the handler sees the changes below.
  Analyzers.CheckedItemsChanged += Analyzers_CheckedItemsChanged;

  foreach (var candidate in Analyzers) {
    Analyzers.SetItemCheckedState(candidate, false);
  }
  var defaultAnalyzer = Analyzers.First(candidate => candidate is RSquaredEvaluator);
  Analyzers.SetItemCheckedState(defaultAnalyzer, true);

  // Grammar rule selection: one entry per GrammarRule value, all checked.
  var ruleEntries = Enum.GetValues(typeof(GrammarRule))
    .Cast<GrammarRule>()
    .Select(rule => new EnumValue<GrammarRule>(rule));
  var ruleCollection = new CheckedItemCollection<EnumValue<GrammarRule>>(ruleEntries);
  var rulesParameter = new FixedValueParameter<ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>>(
    GrammarSymbolsParameterName,
    new ReadOnlyCheckedItemCollection<EnumValue<GrammarRule>>(ruleCollection));
  Parameters.Add(rulesParameter);

  foreach (EnumValue<GrammarRule> ruleEntry in GrammarSymbols) {
    GrammarSymbols.SetItemCheckedState(ruleEntry, true);
  }

  // Default problem.
  var defaultProblem = new RegressionProblem();
  defaultProblem.ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.PolyTen(seed: 1234).GenerateRegressionData();
  Problem = defaultProblem;
}
171
/// <summary>
/// Cloning constructor: after the base class has copied <paramref name="original"/>,
/// every checked analyzer is registered with the new instance and the
/// check-state handler is attached.
/// </summary>
/// <param name="original">Instance being cloned.</param>
/// <param name="cloner">Cloner passed through to the base constructor.</param>
public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) {
  foreach (var checkedAnalyzer in Analyzers.CheckedItems) {
    checkedAnalyzer.Register(this);
  }

  Analyzers.CheckedItemsChanged += Analyzers_CheckedItemsChanged;
}
177    #endregion
178
// Importance per variable terminal; filled at the start of Run() and read by GetPriority().
private Dictionary<VariableTerminalSymbol, double> variableImportance;

/// <summary>
/// Resets counters and search bookkeeping, forwards the constant-optimization
/// setting to the RSquaredEvaluator, rebuilds the grammar from the problem's allowed
/// input variables and the checked grammar rules, creates the open-phrase store and
/// finally calls the base implementation.
/// </summary>
public override void Prepare() {
  // Counters.
  PhraseExpansionCount = 0;
  OverwrittenSentencesCount = 0;
  AllGeneratedSentencesCount = 0;

  // Bookkeeping collections.
  ArchivedPhrases = new HashSet<int>();
  DistinctSentencesComplexity = new Dictionary<int, int>();

  Analyzers.OfType<RSquaredEvaluator>().First().OptimizeConstants = OptimizeConstants;

  var inputVariables = Problem.ProblemData.AllowedInputVariables.ToArray();
  var enabledRules = GrammarSymbols.CheckedItems.Select(entry => entry.Value);
  Grammar = new Grammar(inputVariables, enabledRules);

  // The configured storage type selects the search strategy.
  OpenPhrases = new SearchDataStore(SearchDataStructure);

  // Clears the results; they are re-created in Run().
  base.Prepare();
}
194
/// <summary>
/// Runs the enumeration: starting from the grammar's start symbol, repeatedly fetches
/// the next open phrase, expands its next nonterminal with every applicable
/// production, and either records the result as a sentence or stores it as a new open
/// phrase with a priority. The loop ends when no open phrases remain, when
/// cancellation is requested, or when the best training quality result is (almost) 1.0.
/// </summary>
/// <param name="cancellationToken">Checked once per fetched phrase.</param>
protected override void Run(CancellationToken cancellationToken) {
  InitResults();
  var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
  var phrase0Hash = Grammar.Hasher.CalcHashCode(phrase0);

  #region Variable Importance
  variableImportance = new Dictionary<VariableTerminalSymbol, double>();

  RandomForestRegression rf = new RandomForestRegression();
  rf.Problem = Problem;
  rf.Start();
  IRegressionSolution rfSolution = (RandomForestRegressionSolution)rf.Results["Random forest regression solution"].Value;

  // Materialize once: the impacts are needed for the sum and again for the per-variable
  // loop, and must be the same values in both passes.
  var rfImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(
    rfSolution,
    RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training,
    RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle).ToList();

  // Save the normalized importances. If the sum cannot be used as a divisor (zero or
  // not finite), fall back to equal importances instead of storing NaN/Infinity, which
  // would otherwise propagate into every phrase priority.
  var sum = rfImpacts.Sum(x => x.Item2);
  bool canNormalize = sum != 0.0 && !double.IsNaN(sum) && !double.IsInfinity(sum);
  foreach (Tuple<string, double> rfImpact in rfImpacts) {
    VariableTerminalSymbol varSym = Grammar.VarTerminals.First(v => v.StringRepresentation == rfImpact.Item1);
    variableImportance[varSym] = canNormalize ? rfImpact.Item2 / sum : 1.0 / rfImpacts.Count;
  }
  #endregion

  int maxSentenceLength = GetMaxSentenceLength();

  OpenPhrases.Store(new SearchNode(phrase0Hash, 0.0, 0.0, phrase0));

  // Read the parameters once; each property access is a lookup in the parameter collection.
  var errorWeight = ErrorWeight;
  var variableImportanceWeight = VariableImportanceWeight;
  int maxComplexity = MaxComplexity;

  while (OpenPhrases.Count > 0) {
    if (cancellationToken.IsCancellationRequested) break;

    SearchNode fetchedSearchNode = OpenPhrases.GetNext();
    SymbolString currPhrase = fetchedSearchNode.SymbolString;

    OnPhraseFetched(fetchedSearchNode.Hash, currPhrase);

    ArchivedPhrases.Add(fetchedSearchNode.Hash);

    // expand next nonterminal symbols
    int nonterminalSymbolIndex = currPhrase.NextNonterminalIndex();
    NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];
    var appliedProductions = Grammar.Productions[expandedSymbol];

    for (int i = 0; i < appliedProductions.Count; i++) {
      PhraseExpansionCount++;

      var production = appliedProductions[i];
      SymbolString newPhrase = currPhrase.DerivePhrase(nonterminalSymbolIndex, production);
      int newPhraseComplexity = Grammar.GetComplexity(newPhrase);

      if (newPhraseComplexity > maxComplexity)
        continue;

      var phraseHash = Grammar.Hasher.CalcHashCode(newPhrase);

      OnPhraseDerived(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);

      if (newPhrase.IsSentence()) {
        AllGeneratedSentencesCount++;

        OnSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);

        // Is the best solution found? (only if RSquaredEvaluator is activated)
        if (Results.ContainsKey(RSquaredEvaluator.BestTrainingQualityResultName)) {
          double r2 = ((DoubleValue)Results[RSquaredEvaluator.BestTrainingQualityResultName].Value).Value;
          if (r2.IsAlmost(1.0)) {
            UpdateView(force: true);
            return;
          }
        }

        // Record the sentence if its hash is new, or if it is less complex than the
        // sentence already stored under the same hash.
        int knownComplexity;
        bool alreadyKnown = DistinctSentencesComplexity.TryGetValue(phraseHash, out knownComplexity);
        if (!alreadyKnown || knownComplexity > newPhraseComplexity) {
          if (alreadyKnown) OverwrittenSentencesCount++; // for analysis only

          DistinctSentencesComplexity[phraseHash] = newPhraseComplexity;
          OnDistinctSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolString, phraseHash, newPhrase, expandedSymbol, production);
        }
        UpdateView();

      } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {

        double r2 = GetR2(newPhrase, fetchedSearchNode.R2);
        double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength, errorWeight, variableImportanceWeight);

        SearchNode newSearchNode = new SearchNode(phraseHash, phrasePriority, r2, newPhrase);
        OpenPhrases.Store(newSearchNode);
      }
    }
  }
  UpdateView(force: true);
}
289
/// <summary>
/// Computes the priority value of a phrase as the sum of three terms: its length
/// relative to <paramref name="maxSentenceLength"/>, the weighted error
/// (1 - <paramref name="r2"/>), and the weighted share of variable importance that is
/// not covered by the distinct variables occurring in the phrase.
/// </summary>
/// <param name="phrase">Phrase to rate.</param>
/// <param name="r2">R² value associated with the phrase.</param>
/// <param name="maxSentenceLength">Divisor for the relative length term.</param>
/// <param name="errorWeight">Factor applied to the error term.</param>
/// <param name="variableImportanceWeight">Factor applied to the variable importance term.</param>
/// <returns>The combined priority value.</returns>
protected double GetPriority(SymbolString phrase, double r2, int maxSentenceLength, double errorWeight, double variableImportanceWeight) {
  // Importance of every variable is counted once, no matter how often it occurs.
  double coveredImportance = phrase
    .OfType<VariableTerminalSymbol>()
    .Distinct()
    .Sum(variable => variableImportance[variable]);

  double lengthTerm = (double)phrase.Count() / maxSentenceLength;
  double errorTerm = errorWeight * (1.0 - r2);
  double importanceTerm = variableImportanceWeight * (1 - coveredImportance);

  return lengthTerm + errorTerm + importanceTerm;
}
304
/// <summary>
/// Returns the R² value to use for a phrase. If the phrase contains any nonterminal
/// other than <c>Grammar.Expr</c>, the parent's value is reused; otherwise the phrase
/// is evaluated against the problem data.
/// </summary>
/// <param name="phrase">Phrase whose R² is requested.</param>
/// <param name="parentR2">R² of the phrase this one was derived from.</param>
/// <returns>Either <paramref name="parentR2"/> or the freshly evaluated value.</returns>
private double GetR2(SymbolString phrase, double parentR2) {
  int symbolCount = phrase.Count();
  int position = 0;

  while (position < symbolCount) {
    var symbol = phrase[position];
    bool preventsEvaluation = symbol is NonterminalSymbol && symbol != Grammar.Expr;
    if (preventsEvaluation) {
      return parentR2;
    }
    position++;
  }

  // Only Expr nonterminals (or none at all) remain: evaluate the phrase itself.
  return Grammar.EvaluatePhrase(phrase, Problem.ProblemData, OptimizeConstants);
}
318
/// <summary>
/// Derives a phrase from the start symbol by repeatedly expanding the next nonterminal
/// until the phrase is a sentence or its complexity exceeds <c>MaxComplexity</c>, and
/// returns the number of symbols of the resulting phrase.
/// </summary>
/// <returns>Symbol count of the derived phrase.</returns>
private int GetMaxSentenceLength() {
  SymbolString phrase = new SymbolString(Grammar.StartSymbol);

  while (true) {
    if (phrase.IsSentence()) break;
    if (Grammar.GetComplexity(phrase) > MaxComplexity) break;

    int position = phrase.NextNonterminalIndex();
    NonterminalSymbol nonterminal = (NonterminalSymbol)phrase[position];

    // Chosen production: ascending by terminal count, ties broken by descending
    // nonterminal count, first element taken.
    // NOTE(review): the previous comment described this as "most terminal symbols ...
    // lowest nonterminal count", which does not match the ordering; confirm the intent.
    var chosenProduction = Grammar.Productions[nonterminal]
      .OrderBy(p => CountTerminals(p))
      .ThenByDescending(p => CountNonTerminals(p))
      .First();

    phrase = phrase.DerivePhrase(position, chosenProduction);
  }

  return phrase.Count();
}
337
/// <summary>Counts the symbols of a production that are terminal symbols.</summary>
private int CountTerminals(Production p) => p.OfType<TerminalSymbol>().Count();
341
/// <summary>Counts the symbols of a production that are nonterminal symbols.</summary>
private int CountNonTerminals(Production p) => p.OfType<NonterminalSymbol>().Count();
345
346    #region Visualization in HL
/// <summary>
/// Clears the result collection and adds one zero-initialized entry per statistic
/// (the average complexity entry starts at 1.0).
/// </summary>
private void InitResults() {
  Results.Clear();

  var initialEntries = new[] {
    new Result(GeneratedPhrasesName, new IntValue(0)),
    new Result(SearchStructureSizeName, new IntValue(0)),
    new Result(GeneratedSentencesName, new IntValue(0)),
    new Result(DistinctSentencesName, new IntValue(0)),
    new Result(PhraseExpansionsName, new IntValue(0)),
    new Result(OverwrittenSentencesName, new IntValue(0)),
    new Result(AverageSentenceComplexityName, new DoubleValue(1.0)),
    new Result(ExpansionsPerSecondName, "In Thousand expansions per second", new IntValue(0))
  };

  foreach (var entry in initialEntries) {
    Results.Add(entry);
  }
}
359
// Number of UpdateView calls so far; used to throttle result refreshes.
private int updates;

/// <summary>
/// Copies the current run statistics into the result collection. Without
/// <paramref name="force"/>, the results are refreshed only on every
/// <c>GuiUpdateInterval</c>-th call.
/// </summary>
/// <param name="force">When true, the results are refreshed regardless of the interval.</param>
private void UpdateView(bool force = false) {
  updates++;

  // An interval of 0 would make the modulo throw, and "x % 1 == 1" is never true;
  // treat any interval <= 1 as "refresh on every call".
  int interval = GuiUpdateInterval;
  bool refreshDue = interval <= 1 || updates % interval == 1;
  if (!force && !refreshDue) {
    return;
  }

  ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
  ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
  ((IntValue)Results[GeneratedSentencesName].Value).Value = AllGeneratedSentencesCount;
  ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentencesComplexity.Count;
  ((IntValue)Results[PhraseExpansionsName].Value).Value = PhraseExpansionCount;

  // Average() throws on an empty sequence, and a forced refresh can happen before any
  // distinct sentence has been recorded. Keep the previous value in that case.
  if (DistinctSentencesComplexity.Count > 0) {
    ((DoubleValue)Results[AverageSentenceComplexityName].Value).Value = DistinctSentencesComplexity.Values.Average();
  }

  ((IntValue)Results[OverwrittenSentencesName].Value).Value = OverwrittenSentencesCount;

  // Guard against a zero execution time: dividing by it yields Infinity, which has no
  // defined conversion to int.
  double elapsedSeconds = ExecutionTime.TotalSeconds;
  int thousandExpansionsPerSecond = elapsedSeconds > 0.0
    ? (int)((PhraseExpansionCount / elapsedSeconds) / 1000.0)
    : 0;
  ((IntValue)Results[ExpansionsPerSecondName].Value).Value = thousandExpansionsPerSecond;
}
377    #endregion
378
379    #region events
380
/// <summary>
/// Handles check-state changes of the analyzer collection. Every analyzer that occurs
/// in exactly one of <c>args.Items</c> and <c>args.OldItems</c> is registered with this
/// algorithm if it is currently checked, and deregistered otherwise.
/// </summary>
private void Analyzers_CheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IGrammarEnumerationAnalyzer> args) {
  var onlyInItems = args.Items.Except(args.OldItems);
  var onlyInOldItems = args.OldItems.Except(args.Items);

  foreach (var analyzer in onlyInItems.Union(onlyInOldItems)) {
    bool isChecked = Analyzers.ItemChecked(analyzer);
    if (!isChecked) {
      analyzer.Deregister(this);
      continue;
    }
    analyzer.Register(this);
  }
}
392
/// <summary>Raised by <c>Run</c> for every phrase fetched from the open-phrase store.</summary>
public event EventHandler<PhraseEventArgs> PhraseFetched;

/// <summary>Raises <see cref="PhraseFetched"/> if any handler is attached.</summary>
/// <param name="hash">Hash of the fetched phrase.</param>
/// <param name="symbolString">The fetched phrase.</param>
private void OnPhraseFetched(int hash, SymbolString symbolString) {
  // ?.Invoke reads the delegate once, so a handler that is removed between the null
  // check and the call cannot cause a NullReferenceException.
  PhraseFetched?.Invoke(this, new PhraseEventArgs(hash, symbolString));
}
399
/// <summary>Raised by <c>Run</c> for every derived phrase that passes the complexity limit.</summary>
public event EventHandler<PhraseAddedEventArgs> PhraseDerived;

/// <summary>Raises <see cref="PhraseDerived"/> if any handler is attached.</summary>
private void OnPhraseDerived(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
  // ?.Invoke reads the delegate once, so a handler that is removed between the null
  // check and the call cannot cause a NullReferenceException.
  PhraseDerived?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
}
406
/// <summary>Raised by <c>Run</c> for every derived phrase that is a sentence.</summary>
public event EventHandler<PhraseAddedEventArgs> SentenceGenerated;

/// <summary>Raises <see cref="SentenceGenerated"/> if any handler is attached.</summary>
private void OnSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
  // ?.Invoke reads the delegate once, so a handler that is removed between the null
  // check and the call cannot cause a NullReferenceException.
  SentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
}
413
/// <summary>
/// Raised by <c>Run</c> when a sentence is stored in the distinct-sentence dictionary,
/// either under a new hash or replacing a more complex sentence with the same hash.
/// </summary>
public event EventHandler<PhraseAddedEventArgs> DistinctSentenceGenerated;

/// <summary>Raises <see cref="DistinctSentenceGenerated"/> if any handler is attached.</summary>
private void OnDistinctSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
  // ?.Invoke reads the delegate once, so a handler that is removed between the null
  // check and the call cannot cause a NullReferenceException.
  DistinctSentenceGenerated?.Invoke(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
}
420
421    #endregion
422
423  }
424
425  #region events for analysis
426
/// <summary>Event data identifying a single phrase by hash and symbol string.</summary>
public class PhraseEventArgs : EventArgs {
  private readonly int hash;
  private readonly SymbolString phrase;

  /// <summary>Creates the event data from the given hash and phrase.</summary>
  public PhraseEventArgs(int hash, SymbolString phrase) {
    this.hash = hash;
    this.phrase = phrase;
  }

  /// <summary>Hash passed to the constructor.</summary>
  public int Hash => hash;

  /// <summary>Phrase passed to the constructor.</summary>
  public SymbolString Phrase => phrase;
}
437
/// <summary>
/// Event data describing one derivation step: the parent phrase, the new phrase,
/// the symbol that was expanded and the production that replaced it.
/// </summary>
public class PhraseAddedEventArgs : EventArgs {
  private readonly int parentHash;
  private readonly int newHash;
  private readonly SymbolString parentPhrase;
  private readonly SymbolString newPhrase;
  private readonly Symbol expandedSymbol;
  private readonly Production expandedProduction;

  /// <summary>Creates the event data; every argument is exposed unchanged by the matching property.</summary>
  public PhraseAddedEventArgs(int parentHash, SymbolString parentPhrase, int newHash, SymbolString newPhrase, Symbol expandedSymbol, Production expandedProduction) {
    this.parentHash = parentHash;
    this.parentPhrase = parentPhrase;
    this.newHash = newHash;
    this.newPhrase = newPhrase;
    this.expandedSymbol = expandedSymbol;
    this.expandedProduction = expandedProduction;
  }

  /// <summary>Hash of the parent phrase.</summary>
  public int ParentHash => parentHash;

  /// <summary>Hash of the new phrase.</summary>
  public int NewHash => newHash;

  /// <summary>Phrase the new phrase was derived from.</summary>
  public SymbolString ParentPhrase => parentPhrase;

  /// <summary>Phrase produced by the derivation.</summary>
  public SymbolString NewPhrase => newPhrase;

  /// <summary>Symbol that was expanded.</summary>
  public Symbol ExpandedSymbol => expandedSymbol;

  /// <summary>Production used for the expansion.</summary>
  public Production ExpandedProduction => expandedProduction;
}
458
459  #endregion
460}
Note: See TracBrowser for help on using the repository browser.