Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15821

Last change on this file since 15821 was 15821, checked in by lkammere, 7 years ago

#2886 Move code for visualization and logging of sentences to separate classes.

File size: 15.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
6using HeuristicLab.Collections;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9using HeuristicLab.Data;
10using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
11using HeuristicLab.Optimization;
12using HeuristicLab.Parameters;
13using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
14using HeuristicLab.Problems.DataAnalysis;
15using HeuristicLab.Problems.DataAnalysis.Symbolic;
16using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
17
18namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
19  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
20  [StorableClass]
21  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
22  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
23    #region properties and result names
24    private readonly string BestTrainingQualityName = "Best R² (Training)";
25    private readonly string BestTrainingSolutionName = "Best solution (Training)";
26    private readonly string SearchStructureSizeName = "Search Structure Size";
27    private readonly string GeneratedPhrasesName = "Generated/Archived Phrases";
28    private readonly string GeneratedSentencesName = "Generated Sentences";
29    private readonly string DistinctSentencesName = "Distinct Sentences";
30    private readonly string PhraseExpansionsName = "Phrase Expansions";
31    private readonly string AverageSentenceLengthName = "Avg. Sentence Length among Distinct";
32    private readonly string OverwrittenSentencesName = "Sentences overwritten";
33    private readonly string AnalyzersParameterName = "Analyzers";
34
35
36    private readonly string SearchDataStructureParameterName = "Search Data Structure";
37    private readonly string MaxTreeSizeParameterName = "Max. Tree Nodes";
38    private readonly string GuiUpdateIntervalParameterName = "GUI Update Interval";
39
40    public override bool SupportsPause { get { return false; } }
41
42    protected IValueParameter<IntValue> MaxTreeSizeParameter {
43      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
44    }
45    public int MaxTreeSize {
46      get { return MaxTreeSizeParameter.Value.Value; }
47      set { MaxTreeSizeParameter.Value.Value = value; }
48    }
49
50    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
51      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
52    }
53    public int GuiUpdateInterval {
54      get { return GuiUpdateIntervalParameter.Value.Value; }
55      set { GuiUpdateIntervalParameter.Value.Value = value; }
56    }
57
58    protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
59      get { return (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
60    }
61    public StorageType SearchDataStructure {
62      get { return SearchDataStructureParameter.Value.Value; }
63      set { SearchDataStructureParameter.Value.Value = value; }
64    }
65
66    public IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>> AnalyzersParameter {
67      get { return (IFixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>)Parameters[AnalyzersParameterName]; }
68    }
69
70    public ICheckedItemCollection<IGrammarEnumerationAnalyzer> Analyzers {
71      get { return AnalyzersParameter.Value; }
72    }
73
74    public SymbolString BestTrainingSentence;
75
76    #endregion
77
78    public Dictionary<int, int> DistinctSentencesLength { get; private set; }  // Semantically distinct sentences and their length in a run.
79    public HashSet<int> ArchivedPhrases { get; private set; }
80    internal SearchDataStore OpenPhrases { get; private set; }           // Stack/Queue/etc. for fetching the next node in the search tree. 
81
82    #region execution stats
83    public int AllGeneratedSentencesCount { get; private set; }
84
85    public int OverwrittenSentencesCount { get; private set; } // It is not guaranteed that shorter solutions are found first.
86                                                               // When longer solutions are overwritten with shorter ones,
87                                                               // this counter is increased.
88    public int PhraseExpansionCount { get; private set; }      // Number, how many times a nonterminal symbol is replaced with a production rule.
89    #endregion
90
91    public Grammar Grammar { get; private set; }
92
93    private readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
94
95    #region ctors
96    public override IDeepCloneable Clone(Cloner cloner) {
97      return new GrammarEnumerationAlgorithm(this, cloner);
98    }
99
100    public GrammarEnumerationAlgorithm() {
101      Problem = new RegressionProblem() {
102        ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenFunctionNine(seed: 1234).GenerateRegressionData()
103      };
104
105      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The number of clusters.", new IntValue(6)));
106      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(1000)));
107      Parameters.Add(new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.Stack)));
108
109      var availableAnalyzers = new IGrammarEnumerationAnalyzer[] {
110        new SearchGraphVisualizer(),
111        new SentenceLogger()
112      };
113      Parameters.Add(new FixedValueParameter<ReadOnlyCheckedItemCollection<IGrammarEnumerationAnalyzer>>(
114        AnalyzersParameterName,
115        new CheckedItemCollection<IGrammarEnumerationAnalyzer>(availableAnalyzers).AsReadOnly()));
116
117      foreach (var analyzer in Analyzers) {
118        Analyzers.SetItemCheckedState(analyzer, false);
119      }
120      Analyzers.CheckedItemsChanged += AnalyzersOnCheckedItemsChanged;
121    }
122
123    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
124    #endregion
125
126    protected override void Run(CancellationToken cancellationToken) {
127      #region init
128      InitResults();
129
130      ArchivedPhrases = new HashSet<int>();
131
132      DistinctSentencesLength = new Dictionary<int, int>();
133      AllGeneratedSentencesCount = 0;
134      OverwrittenSentencesCount = 0;
135      PhraseExpansionCount = 0;
136
137      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());
138
139      OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
140      var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
141      var phrase0Hash = Grammar.CalcHashCode(phrase0);
142      #endregion
143
144      OpenPhrases.Store(phrase0Hash, phrase0);
145      while (OpenPhrases.Count > 0) {
146        if (cancellationToken.IsCancellationRequested) break;
147
148        StoredSymbolString fetchedPhrase = OpenPhrases.GetNext();
149        SymbolString currPhrase = fetchedPhrase.SymbolString;
150
151        OnPhraseFetched(fetchedPhrase.Hash, currPhrase);
152
153        ArchivedPhrases.Add(fetchedPhrase.Hash);
154
155        // expand next nonterminal symbols
156        int nonterminalSymbolIndex = currPhrase.FindIndex(s => s is NonterminalSymbol);
157        NonterminalSymbol expandedSymbol = currPhrase[nonterminalSymbolIndex] as NonterminalSymbol;
158
159        foreach (Production productionAlternative in expandedSymbol.Alternatives) {
160          PhraseExpansionCount++;
161
162          SymbolString newPhrase = new SymbolString(currPhrase.Count + productionAlternative.Count);
163          newPhrase.AddRange(currPhrase);
164          newPhrase.RemoveAt(nonterminalSymbolIndex);     // TODO: removeat and insertRange are both O(n)
165          newPhrase.InsertRange(nonterminalSymbolIndex, productionAlternative);
166
167          if (newPhrase.Count <= MaxTreeSize) {
168            var phraseHash = Grammar.CalcHashCode(newPhrase);
169
170            OnPhraseDerived(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, productionAlternative);
171
172            if (newPhrase.IsSentence()) {
173              AllGeneratedSentencesCount++;
174
175              OnSentenceGenerated(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, productionAlternative);
176
177              if (!DistinctSentencesLength.ContainsKey(phraseHash) || DistinctSentencesLength[phraseHash] > newPhrase.Count) {
178                if (DistinctSentencesLength.ContainsKey(phraseHash)) OverwrittenSentencesCount++; // for analysis only
179
180                DistinctSentencesLength[phraseHash] = newPhrase.Count;
181                EvaluateSentence(newPhrase);
182
183                OnDistinctSentenceGenerated(fetchedPhrase.Hash, fetchedPhrase.SymbolString, phraseHash, newPhrase, expandedSymbol, productionAlternative);
184              }
185              UpdateView();
186
187            } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
188              OpenPhrases.Store(phraseHash, newPhrase);
189            }
190          }
191        }
192      }
193
194      UpdateView(force: true);
195      UpdateFinalResults();
196    }
197
198    #region Evaluation of generated models.
199
200    // Evaluate sentence within an algorithm run.
201    private void EvaluateSentence(SymbolString symbolString) {
202      SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(symbolString);
203      SymbolicRegressionModel model = new SymbolicRegressionModel(
204        Problem.ProblemData.TargetVariable,
205        tree,
206        expressionTreeLinearInterpreter);
207
208      var probData = Problem.ProblemData;
209      var target = probData.TargetVariableTrainingValues;
210      var estVals = model.GetEstimatedValues(probData.Dataset, probData.TrainingIndices);
211      OnlineCalculatorError error;
212      var r2 = OnlinePearsonsRSquaredCalculator.Calculate(target, estVals, out error);
213      if (error != OnlineCalculatorError.None) r2 = 0.0;
214
215      var bestR2 = ((DoubleValue)Results[BestTrainingQualityName].Value).Value;
216      if (r2 > bestR2) {
217        ((DoubleValue)Results[BestTrainingQualityName].Value).Value = r2;
218        BestTrainingSentence = symbolString;
219      }
220    }
221
222    #endregion
223
224    #region Visualization in HL
225    // Initialize entries in result set.
226    private void InitResults() {
227      BestTrainingSentence = null;
228
229      Results.Add(new Result(BestTrainingQualityName, new DoubleValue(-1.0)));
230
231      Results.Add(new Result(GeneratedPhrasesName, new IntValue(0)));
232      Results.Add(new Result(SearchStructureSizeName, new IntValue(0)));
233      Results.Add(new Result(GeneratedSentencesName, new IntValue(0)));
234      Results.Add(new Result(DistinctSentencesName, new IntValue(0)));
235      Results.Add(new Result(PhraseExpansionsName, new IntValue(0)));
236      Results.Add(new Result(OverwrittenSentencesName, new IntValue(0)));
237      Results.Add(new Result(AverageSentenceLengthName, new DoubleValue(1.0)));
238    }
239
240    // Update the view for intermediate results in an algorithm run.
241    private int updates;
242    private void UpdateView(bool force = false) {
243      updates++;
244
245      if (force || updates % GuiUpdateInterval == 1) {
246        ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
247        ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
248        ((IntValue)Results[GeneratedSentencesName].Value).Value = AllGeneratedSentencesCount;
249        ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentencesLength.Count;
250        ((IntValue)Results[PhraseExpansionsName].Value).Value = PhraseExpansionCount;
251        ((DoubleValue)Results[AverageSentenceLengthName].Value).Value = DistinctSentencesLength.Select(pair => pair.Value).Average();
252        ((IntValue)Results[OverwrittenSentencesName].Value).Value = OverwrittenSentencesCount;
253      }
254    }
255
256    // Generate all Results after an algorithm run.
257    private void UpdateFinalResults() {
258      SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(BestTrainingSentence);
259      SymbolicRegressionModel model = new SymbolicRegressionModel(
260        Problem.ProblemData.TargetVariable,
261        tree,
262        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
263
264      IRegressionSolution bestTrainingSolution = new RegressionSolution(model, Problem.ProblemData);
265      Results.AddOrUpdateResult(BestTrainingSolutionName, bestTrainingSolution);
266    }
267    #endregion
268
269    #region events
270    private void AnalyzersOnCheckedItemsChanged(object sender, CollectionItemsChangedEventArgs<IGrammarEnumerationAnalyzer> collectionItemsChangedEventArgs) {
271      foreach (IGrammarEnumerationAnalyzer grammarEnumerationAnalyzer in collectionItemsChangedEventArgs.Items) {
272        if (Analyzers.ItemChecked(grammarEnumerationAnalyzer)) {
273          grammarEnumerationAnalyzer.Register(this);
274        } else {
275          grammarEnumerationAnalyzer.Deregister(this);
276        }
277      }
278    }
279
280    public event EventHandler<PhraseEventArgs> PhraseFetched;
281    private void OnPhraseFetched(int hash, SymbolString symbolString) {
282      if (PhraseFetched != null) {
283        PhraseFetched(this, new PhraseEventArgs(hash, symbolString));
284      }
285    }
286
287    public event EventHandler<PhraseAddedEventArgs> PhraseDerived;
288    private void OnPhraseDerived(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
289      if (PhraseDerived != null) {
290        PhraseDerived(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
291      }
292    }
293
294    public event EventHandler<PhraseAddedEventArgs> SentenceGenerated;
295    private void OnSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
296      if (SentenceGenerated != null) {
297        SentenceGenerated(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
298      }
299    }
300
301    public event EventHandler<PhraseAddedEventArgs> DistinctSentenceGenerated;
302    private void OnDistinctSentenceGenerated(int parentHash, SymbolString parentSymbolString, int addedHash, SymbolString addedSymbolString, Symbol expandedSymbol, Production expandedProduction) {
303      if (DistinctSentenceGenerated != null) {
304        DistinctSentenceGenerated(this, new PhraseAddedEventArgs(parentHash, parentSymbolString, addedHash, addedSymbolString, expandedSymbol, expandedProduction));
305      }
306    }
307
308    #endregion
309
310  }
311
312  #region events for analysis
313
314  public class PhraseEventArgs : EventArgs {
315    public int Hash { get; }
316
317    public SymbolString Phrase { get; }
318
319    public PhraseEventArgs(int hash, SymbolString phrase) {
320      Hash = hash;
321      Phrase = phrase;
322    }
323  }
324
325  public class PhraseAddedEventArgs : EventArgs {
326    public int ParentHash { get; }
327    public int NewHash { get; }
328
329    public SymbolString ParentPhrase { get; }
330    public SymbolString NewPhrase { get; }
331
332    public Symbol ExpandedSymbol { get; }
333
334    public Production ExpandedProduction { get; }
335
336    public PhraseAddedEventArgs(int parentHash, SymbolString parentPhrase, int newHash, SymbolString newPhrase, Symbol expandedSymbol, Production expandedProduction) {
337      ParentHash = parentHash;
338      ParentPhrase = parentPhrase;
339      NewHash = newHash;
340      NewPhrase = newPhrase;
341      ExpandedSymbol = expandedSymbol;
342      ExpandedProduction = expandedProduction;
343    }
344  }
345
346  #endregion
347}
Note: See TracBrowser for help on using the repository browser.