Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15812

Last change on this file since 15812 was 15812, checked in by lkammere, 6 years ago

#2886: Performance Improvements - Only store hash of archived phrases and reduce number of enumerators.

File size: 14.0 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.IO;
5using System.Linq;
6using System.Threading;
7using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
8using HeuristicLab.Common;
9using HeuristicLab.Core;
10using HeuristicLab.Data;
11using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
12using HeuristicLab.Optimization;
13using HeuristicLab.Parameters;
14using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
15using HeuristicLab.Problems.DataAnalysis;
16using HeuristicLab.Problems.DataAnalysis.Symbolic;
17using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
18
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Enumerates the phrases derivable from a fixed grammar, up to <see cref="MaxTreeSize"/> symbols.
  /// Every generated sentence whose hash (<c>Grammar.CalcHashCode</c>) is new, or which is shorter than the
  /// stored sentence with the same hash, is evaluated on the training partition. The sentence with the
  /// highest Pearson R² is reported as the best training solution.
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region properties and result names
    private readonly string BestTrainingQualityName = "Best R² (Training)";
    private readonly string BestTrainingSolutionName = "Best solution (Training)";
    private readonly string SearchStructureSizeName = "Search Structure Size";
    private readonly string GeneratedPhrasesName = "Generated/Archived Phrases";
    private readonly string GeneratedSentencesName = "Generated Sentences";
    private readonly string DistinctSentencesName = "Distinct Sentences";
    private readonly string PhraseExpansionsName = "Phrase Expansions";
    private readonly string AverageTreeLengthName = "Avg. Sentence Length among Distinct";
    private readonly string GeneratedEqualSentencesName = "Generated equal sentences";

    private readonly string SearchDataStructureParameterName = "Search Data Structure";
    private readonly string MaxTreeSizeParameterName = "Max. Tree Nodes";
    private readonly string GuiUpdateIntervalParameterName = "GUI Update Interval";

    public override bool SupportsPause { get { return false; } }

    protected IValueParameter<IntValue> MaxTreeSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
    }
    /// <summary>Maximum number of symbols a phrase may contain; longer phrases are discarded.</summary>
    public int MaxTreeSize {
      get { return MaxTreeSizeParameter.Value.Value; }
      set { MaxTreeSizeParameter.Value.Value = value; }
    }

    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }
    /// <summary>The result view is refreshed once per this many generated sentences (values &lt;= 1: every time).</summary>
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
      set { GuiUpdateIntervalParameter.Value.Value = value; }
    }

    protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
      get { return (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
    }
    /// <summary>Selects the data structure that decides which open phrase is expanded next.</summary>
    public StorageType SearchDataStructure {
      get { return SearchDataStructureParameter.Value.Value; }
      set { SearchDataStructureParameter.Value.Value = value; }
    }

    // Sentence with the highest training R² seen so far in the current run; null until one is evaluated.
    public SymbolString BestTrainingSentence;

    #endregion

    public Dictionary<int, SymbolString> DistinctSentences { get; private set; }  // Semantically distinct sentences in a run.
    public Dictionary<int, List<SymbolString>> AllSentences { get; private set; } // All sentences ever generated in a run.
    public HashSet<int> ArchivedPhrases { get; private set; }                     // Hashes of phrases that were already expanded.

    internal SearchDataStore OpenPhrases { get; private set; }                    // Stack/Queue/etc. for fetching the next node in the search tree.

    public int EqualGeneratedSentences { get; private set; } // It is not guaranteed that shorter solutions are found first.
                                                             // When longer solutions are overwritten with shorter ones,
                                                             // this counter is increased.
    public int Expansions { get; private set; }              // Number, how many times a nonterminal symbol is replaced with a production rule.
    public Grammar Grammar { get; private set; }

    // Target of the search-graph log (DOT format), placed on the user's desktop.
    private readonly string dotFileName = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "searchgraph.dot");

    #region ctors
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    public GrammarEnumerationAlgorithm() {
      Problem = new RegressionProblem() {
        ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenFunctionNine(seed: 1234).GenerateRegressionData()
      };

      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The maximum number of symbols (tree nodes) in a generated phrase.", new IntValue(6)));
      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(1000)));
      Parameters.Add(new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.Stack)));
    }

    // Cloning constructor; run-time state (sentences, phrases, counters) is rebuilt by Run and not copied.
    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    #endregion

    /// <summary>
    /// Runs the enumeration. Each iteration:
    /// 1. Takes the next open phrase and archives its hash.
    /// 2. Replaces the phrase's leftmost nonterminal with every alternative production.
    /// 3. Handles each result:
    ///    - sentences are recorded and possibly evaluated;
    ///    - non-sentences are queued unless already open or archived;
    ///    - phrases longer than <see cref="MaxTreeSize"/> are dropped.
    /// </summary>
    /// <param name="cancellationToken">Checked before each phrase is fetched; stops the enumeration early.</param>
    protected override void Run(CancellationToken cancellationToken) {
      #region init
      InitResults();

      AllSentences = new Dictionary<int, List<SymbolString>>();
      ArchivedPhrases = new HashSet<int>();

      DistinctSentences = new Dictionary<int, SymbolString>();
      Expansions = 0;
      EqualGeneratedSentences = 0;

      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

      OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
      var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
      var phrase0Hash = Grammar.CalcHashCode(phrase0);
      #endregion

      using (var dotFileTrace = new TextWriterTraceListener(new FileStream(dotFileName, FileMode.Create))) {
        LogSearchGraph(dotFileTrace, "digraph searchgraph {");

        OpenPhrases.Store(phrase0Hash, phrase0);
        while (OpenPhrases.Count > 0) {
          if (cancellationToken.IsCancellationRequested) break;

          StoredSymbolString fetchedPhrase = OpenPhrases.GetNext();
          SymbolString currPhrase = fetchedPhrase.SymbolString;
#if DEBUG
          LogSearchGraph(dotFileTrace, $"{fetchedPhrase.Hash} [label=\"{Grammar.PostfixToInfixParser(fetchedPhrase.SymbolString)}\"];");
#endif
          ArchivedPhrases.Add(fetchedPhrase.Hash);

          // Expand the leftmost nonterminal symbol. Only non-sentences (and the start phrase) are ever
          // stored as open phrases, so a nonterminal is expected to be present.
          int nonterminalSymbolIndex = currPhrase.FindIndex(s => s is NonterminalSymbol);
          NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];

          foreach (Production productionAlternative in expandedSymbol.Alternatives) {
            SymbolString newPhrase = new SymbolString(currPhrase.Count + productionAlternative.Count);
            newPhrase.AddRange(currPhrase);
            newPhrase.RemoveAt(nonterminalSymbolIndex);     // TODO: removeat and insertRange are both O(n)
            newPhrase.InsertRange(nonterminalSymbolIndex, productionAlternative);

            Expansions++;
            if (newPhrase.Count > MaxTreeSize) continue; // too long: discard without recording

            var phraseHash = Grammar.CalcHashCode(newPhrase);
#if DEBUG
            LogSearchGraph(dotFileTrace, $"{fetchedPhrase.Hash} -> {phraseHash} [label=\"{expandedSymbol.StringRepresentation} + &rarr; {productionAlternative}\"];");
#endif
            if (newPhrase.IsSentence()) {
              ProcessSentence(phraseHash, newPhrase, dotFileTrace);
            } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
              OpenPhrases.Store(phraseHash, newPhrase);
            }
          }
        }
#if DEBUG
        // Overwrite formatting of the best found solution (if any) and of the start search node.
        if (BestTrainingSentence != null) {
          LogSearchGraph(dotFileTrace, $"{Grammar.CalcHashCode(BestTrainingSentence)} [label=\"{Grammar.PostfixToInfixParser(BestTrainingSentence)}\", shape=Mcircle, style=\"filled,bold\"];");
        }
        LogSearchGraph(dotFileTrace, $"{phrase0Hash} [label=\"{Grammar.PostfixToInfixParser(phrase0)}\", shape=doublecircle];");
#endif
        // Close the graph unconditionally: the file is created in every build configuration,
        // so it must also be valid DOT in release builds.
        LogSearchGraph(dotFileTrace, "}");
        dotFileTrace.Flush();
      }

      UpdateView(force: true);
      UpdateFinalResults();
    }

    // Records a generated sentence. It is (re-)evaluated when its hash is new, or when it is shorter
    // than the sentence currently stored for that hash.
    private void ProcessSentence(int phraseHash, SymbolString sentence, TraceListener dotFileTrace) {
      SaveToAllSentences(phraseHash, sentence);

      SymbolString storedSentence;
      bool alreadyKnown = DistinctSentences.TryGetValue(phraseHash, out storedSentence);
      if (!alreadyKnown || storedSentence.Count > sentence.Count) {
        if (alreadyKnown) EqualGeneratedSentences++; // for analysis only

        DistinctSentences[phraseHash] = sentence;
        EvaluateSentence(sentence);
#if DEBUG
        LogSearchGraph(dotFileTrace, $"{phraseHash} [label=\"{Grammar.PostfixToInfixParser(sentence)}\", style=\"filled\"];");
#endif
      }
      UpdateView();
    }

    // Store sentence to "MultiDictionary": every sentence is kept, grouped by hash value.
    private void SaveToAllSentences(int hash, SymbolString sentence) {
      List<SymbolString> sentences;
      if (!AllSentences.TryGetValue(hash, out sentences)) {
        sentences = new List<SymbolString>(); //TODO: here we store all sentences even if they have the same hash value, this is not strictly necessary
        AllSentences[hash] = sentences;
      }
      sentences.Add(sentence);
    }

    #region Evaluation of generated models.

    // Builds a regression model for the sentence on the problem's target variable, using the linear interpreter.
    private SymbolicRegressionModel CreateModel(SymbolString sentence) {
      SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(sentence);
      return new SymbolicRegressionModel(
        Problem.ProblemData.TargetVariable,
        tree,
        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
    }

    // Evaluate sentence within an algorithm run: computes the Pearson R² on the training partition
    // and updates the best-quality result and BestTrainingSentence when it improves.
    private void EvaluateSentence(SymbolString symbolString) {
      var probData = Problem.ProblemData;
      SymbolicRegressionModel model = CreateModel(symbolString);

      var target = probData.TargetVariableTrainingValues;
      var estVals = model.GetEstimatedValues(probData.Dataset, probData.TrainingIndices);
      OnlineCalculatorError error;
      var r2 = OnlinePearsonsRSquaredCalculator.Calculate(target, estVals, out error);
      if (error != OnlineCalculatorError.None) r2 = 0.0; // a failed calculation counts as R² = 0

      var bestR2 = (DoubleValue)Results[BestTrainingQualityName].Value;
      if (r2 > bestR2.Value) {
        bestR2.Value = r2;
        BestTrainingSentence = symbolString;
      }
    }

    #endregion

    #region Visualization in HL
    // Initialize (or reset) entries in the result set and the view-update counter.
    private void InitResults() {
      BestTrainingSentence = null;
      updates = 0;

      Results.AddOrUpdateResult(BestTrainingQualityName, new DoubleValue(-1.0));

      Results.AddOrUpdateResult(GeneratedPhrasesName, new IntValue(0));
      Results.AddOrUpdateResult(SearchStructureSizeName, new IntValue(0));
      Results.AddOrUpdateResult(GeneratedSentencesName, new IntValue(0));
      Results.AddOrUpdateResult(DistinctSentencesName, new IntValue(0));
      Results.AddOrUpdateResult(PhraseExpansionsName, new IntValue(0));
      Results.AddOrUpdateResult(GeneratedEqualSentencesName, new IntValue(0));
      Results.AddOrUpdateResult(AverageTreeLengthName, new DoubleValue(1.0));
    }

    // Update the view for intermediate results in an algorithm run.
    // Refreshes on the 1st, (interval+1)th, (2*interval+1)th, ... call, or always when forced.
    private int updates;
    private void UpdateView(bool force = false) {
      updates++;

      int interval = GuiUpdateInterval;
      // For interval <= 1 the modulo test could never yield 1 (or would divide by zero): refresh every time.
      bool due = interval <= 1 || updates % interval == 1;
      if (!force && !due) return;

      var allGeneratedEnum = AllSentences.Values.SelectMany(x => x).ToArray();
      ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
      ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
      ((IntValue)Results[GeneratedSentencesName].Value).Value = allGeneratedEnum.Length;
      ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentences.Count;
      ((IntValue)Results[PhraseExpansionsName].Value).Value = Expansions;
      ((IntValue)Results[GeneratedEqualSentencesName].Value).Value = EqualGeneratedSentences;
      // Average() throws on an empty sequence, e.g. when the run was cancelled before any sentence was found.
      if (allGeneratedEnum.Length > 0) {
        ((DoubleValue)Results[AverageTreeLengthName].Value).Value = allGeneratedEnum.Average(sentence => sentence.Count);
      }
    }

    // Generate all Results after an algorithm run.
    private void UpdateFinalResults() {
      // No sentence may have been evaluated (e.g. immediate cancellation or a too small MaxTreeSize).
      if (BestTrainingSentence != null) {
        IRegressionSolution bestTrainingSolution = new RegressionSolution(CreateModel(BestTrainingSentence), Problem.ProblemData);
        Results.AddOrUpdateResult(BestTrainingSolutionName, bestTrainingSolution);
      }

      // Print generated sentences. Columns: postfix sentence, infix representation, hash value.
      string[,] sentencesMatrix = new string[AllSentences.Values.Sum(list => list.Count), 3];

      int i = 0;
      foreach (var sentenceSet in AllSentences) {
        foreach (var sentence in sentenceSet.Value) {
          sentencesMatrix[i, 0] = sentence.ToString();
          sentencesMatrix[i, 1] = Grammar.PostfixToInfixParser(sentence).ToString();
          sentencesMatrix[i, 2] = sentenceSet.Key.ToString();
          i++;
        }
      }
      Results.AddOrUpdateResult("All generated sentences", new StringMatrix(sentencesMatrix));

      // Same column layout, one row per distinct hash value.
      string[,] distinctSentencesMatrix = new string[DistinctSentences.Count, 3];
      i = 0;
      foreach (KeyValuePair<int, SymbolString> distinctSentence in DistinctSentences) {
        distinctSentencesMatrix[i, 0] = distinctSentence.Value.ToString();
        distinctSentencesMatrix[i, 1] = Grammar.PostfixToInfixParser(distinctSentence.Value).ToString();
        distinctSentencesMatrix[i, 2] = distinctSentence.Key.ToString();
        i++;
      }
      Results.AddOrUpdateResult("Distinct generated sentences", new StringMatrix(distinctSentencesMatrix));
    }

    // Appends one DOT statement per line to the search-graph log.
    private void LogSearchGraph(TraceListener listener, string msg) {
      listener.WriteLine(msg);
    }
    #endregion

  }
}
Note: See TracBrowser for help on using the repository browser.