source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15806

Last change on this file since 15806 was 15806, checked in by gkronber, 20 months ago

#2886 made a few comments

File size: 14.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.IO;
5using System.Linq;
6using System.Threading;
7using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
8using HeuristicLab.Common;
9using HeuristicLab.Core;
10using HeuristicLab.Data;
11using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
12using HeuristicLab.Optimization;
13using HeuristicLab.Parameters;
14using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
15using HeuristicLab.Problems.DataAnalysis;
16using HeuristicLab.Problems.DataAnalysis.Symbolic;
17using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
18
19namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
20  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
21  [StorableClass]
22  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
23  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
24    #region properties and result names
25    private readonly string BestTrainingQualityName = "Best R² (Training)";
26    private readonly string BestTrainingSolutionName = "Best solution (Training)";
27    private readonly string SearchStructureSizeName = "Search Structure Size";
28    private readonly string GeneratedPhrasesName = "Generated/Archived Phrases";
29    private readonly string GeneratedSentencesName = "Generated Sentences";
30    private readonly string DistinctSentencesName = "Distinct Sentences";
31    private readonly string PhraseExpansionsName = "Phrase Expansions";
32    private readonly string AverageTreeLengthName = "Avg. Sentence Length among Distinct";
33    private readonly string GeneratedEqualSentencesName = "Generated equal sentences";
34
35
36    private readonly string SearchDataStructureParameterName = "Search Data Structure";
37    private readonly string MaxTreeSizeParameterName = "Max. Tree Nodes";
38    private readonly string GuiUpdateIntervalParameterName = "GUI Update Interval";
39
40    public override bool SupportsPause { get { return false; } }
41
42    protected IValueParameter<IntValue> MaxTreeSizeParameter {
43      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
44    }
45    public int MaxTreeSize {
46      get { return MaxTreeSizeParameter.Value.Value; }
47      set { MaxTreeSizeParameter.Value.Value = value; }
48    }
49
50    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
51      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
52    }
53    public int GuiUpdateInterval {
54      get { return GuiUpdateIntervalParameter.Value.Value; }
55      set { GuiUpdateIntervalParameter.Value.Value = value; }
56    }
57
58    protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
59      get { return (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
60    }
61    public StorageType SearchDataStructure {
62      get { return SearchDataStructureParameter.Value.Value; }
63      set { SearchDataStructureParameter.Value.Value = value; }
64    }
65
66    public SymbolString BestTrainingSentence;
67
68    #endregion
69
70    public Dictionary<int, SymbolString> DistinctSentences { get; private set; }  // Semantically distinct sentences in a run.
71    public Dictionary<int, List<SymbolString>> AllSentences { get; private set; } // All sentences ever generated in a run.
72    public Dictionary<int, SymbolString> ArchivedPhrases { get; private set; }    // Nodes in the search tree
73    internal SearchDataStore OpenPhrases { get; private set; }                    // Stack/Queue/etc. for fetching the next node in the search tree. 
74
75    public int EqualGeneratedSentences { get; private set; } // It is not guaranteed that shorter solutions are found first.
76                                                             // When longer solutions are overwritten with shorter ones,
77                                                             // this counter is increased.
78    public int Expansions { get; private set; }              // Number, how many times a nonterminal symbol is replaced with a production rule.
79    public Grammar Grammar { get; private set; }
80
81    private readonly string dotFileName = Environment.GetFolderPath(System.Environment.SpecialFolder.DesktopDirectory) + @"\searchgraph.dot";
82
83    #region ctors
84    public override IDeepCloneable Clone(Cloner cloner) {
85      return new GrammarEnumerationAlgorithm(this, cloner);
86    }
87
88    public GrammarEnumerationAlgorithm() {
89      Problem = new RegressionProblem() {
90        ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenFunctionNine(seed: 1234).GenerateRegressionData()
91      };
92
93      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The number of clusters.", new IntValue(6)));
94      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(1000)));
95      Parameters.Add(new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.Stack)));
96    }
97
98    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
99    #endregion
100
101    protected override void Run(CancellationToken cancellationToken) {
102      #region init
103      InitResults();
104
105      AllSentences = new Dictionary<int, List<SymbolString>>();
106      ArchivedPhrases = new Dictionary<int, SymbolString>(); // Nodes in the search tree
107      DistinctSentences = new Dictionary<int, SymbolString>();
108      Expansions = 0;
109      EqualGeneratedSentences = 0;
110
111      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());
112
113      OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
114      var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
115      var phrase0Hash = Grammar.CalcHashCode(phrase0);
116      #endregion
117
118      using (TextWriterTraceListener dotFileTrace = new TextWriterTraceListener(new FileStream(dotFileName, FileMode.Create))) {
119        LogSearchGraph(dotFileTrace, "digraph searchgraph {");
120
121        OpenPhrases.Store(phrase0Hash, phrase0);
122        while (OpenPhrases.Count > 0) {
123          if (cancellationToken.IsCancellationRequested) break;
124
125          StoredSymbolString fetchedPhrase = OpenPhrases.GetNext();
126          SymbolString currPhrase = fetchedPhrase.SymbolString;
127
128          LogSearchGraph(dotFileTrace, $"{fetchedPhrase.Hash} [label=\"{Grammar.PostfixToInfixParser(fetchedPhrase.SymbolString)}\"];");
129
130          ArchivedPhrases.Add(fetchedPhrase.Hash, currPhrase);
131
132          // expand next nonterminal symbols
133          int nonterminalSymbolIndex = currPhrase.FindIndex(s => s is NonterminalSymbol);
134          NonterminalSymbol expandedSymbol = currPhrase[nonterminalSymbolIndex] as NonterminalSymbol;
135
136          foreach (Production productionAlternative in expandedSymbol.Alternatives) {
137            SymbolString newPhrase = new SymbolString(currPhrase);
138            newPhrase.RemoveAt(nonterminalSymbolIndex);     // TODO: removeat and insertRange are both O(n)
139            newPhrase.InsertRange(nonterminalSymbolIndex, productionAlternative);
140
141            Expansions++;
142            if (newPhrase.Count <= MaxTreeSize) {
143              var phraseHash = Grammar.CalcHashCode(newPhrase);
144
145              LogSearchGraph(dotFileTrace, $"{fetchedPhrase.Hash} -> {phraseHash} [label=\"{expandedSymbol.StringRepresentation} + &rarr; {productionAlternative}\"];");
146
147              if (newPhrase.IsSentence()) {
148                // Sentence was generated.
149                SaveToAllSentences(phraseHash, newPhrase);
150
151                if (!DistinctSentences.ContainsKey(phraseHash) || DistinctSentences[phraseHash].Count > newPhrase.Count) {
152                  if (DistinctSentences.ContainsKey(phraseHash)) EqualGeneratedSentences++; // for analysis only
153
154                  DistinctSentences[phraseHash] = newPhrase;
155                  EvaluateSentence(newPhrase);
156
157                  LogSearchGraph(dotFileTrace, $"{phraseHash} [label=\"{Grammar.PostfixToInfixParser(newPhrase)}\", style=\"filled\"];");
158                }
159                UpdateView(AllSentences, DistinctSentences);
160
161              } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.ContainsKey(phraseHash)) {
162                OpenPhrases.Store(phraseHash, newPhrase);
163              }
164            }
165          }
166        }
167
168        // Overwrite formatting of start search node and best found solution.
169        LogSearchGraph(dotFileTrace, $"{Grammar.CalcHashCode(BestTrainingSentence)} [label=\"{Grammar.PostfixToInfixParser(BestTrainingSentence)}\", shape=Mcircle, style=\"filled,bold\"];");
170        LogSearchGraph(dotFileTrace, $"{phrase0Hash} [label=\"{Grammar.PostfixToInfixParser(phrase0)}\", shape=doublecircle];}}");
171        dotFileTrace.Flush();
172      }
173
174      UpdateView(AllSentences, DistinctSentences, force: true);
175      UpdateFinalResults();
176    }
177
178    // Store sentence to "MultiDictionary"
179    private void SaveToAllSentences(int hash, SymbolString sentence) {
180      if (AllSentences.ContainsKey(hash))
181        AllSentences[hash].Add(sentence);
182      else
183        AllSentences[hash] = new List<SymbolString> { sentence }; //TODO: here we store all sentences even if they have the same hash value, this is not strictly necessary
184    }
185
186    #region Evaluation of generated models.
187
188    // Evaluate sentence within an algorithm run.
189    private void EvaluateSentence(SymbolString symbolString) {
190      SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(symbolString);
191      SymbolicRegressionModel model = new SymbolicRegressionModel(
192        Problem.ProblemData.TargetVariable,
193        tree,
194        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
195
196      var probData = Problem.ProblemData;
197      var target = probData.TargetVariableTrainingValues;
198      var estVals = model.GetEstimatedValues(probData.Dataset, probData.TrainingIndices);
199      OnlineCalculatorError error;
200      var r2 = OnlinePearsonsRSquaredCalculator.Calculate(target, estVals, out error);
201      if (error != OnlineCalculatorError.None) r2 = 0.0;
202
203      var bestR2 = ((DoubleValue)Results[BestTrainingQualityName].Value).Value;
204      if (r2 > bestR2) {
205        ((DoubleValue)Results[BestTrainingQualityName].Value).Value = r2;
206        BestTrainingSentence = symbolString;
207      }
208    }
209
210    #endregion
211
212    #region Visualization in HL
213    // Initialize entries in result set.
214    private void InitResults() {
215      BestTrainingSentence = null;
216
217      Results.Add(new Result(BestTrainingQualityName, new DoubleValue(-1.0)));
218
219      Results.Add(new Result(GeneratedPhrasesName, new IntValue(0)));
220      Results.Add(new Result(SearchStructureSizeName, new IntValue(0)));
221      Results.Add(new Result(GeneratedSentencesName, new IntValue(0)));
222      Results.Add(new Result(DistinctSentencesName, new IntValue(0)));
223      Results.Add(new Result(PhraseExpansionsName, new IntValue(0)));
224      Results.Add(new Result(GeneratedEqualSentencesName, new IntValue(0)));
225      Results.Add(new Result(AverageTreeLengthName, new DoubleValue(1.0)));
226    }
227
228    // Update the view for intermediate results in an algorithm run.
229    private int updates;
230    private void UpdateView(Dictionary<int, List<SymbolString>> allGenerated,
231        Dictionary<int, SymbolString> distinctGenerated, bool force = false) {
232      updates++;
233
234      if (force || updates % GuiUpdateInterval == 0) {
235        var allGeneratedEnum = allGenerated.Values.SelectMany(x => x).ToArray();
236        ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
237        ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
238        ((IntValue)Results[GeneratedSentencesName].Value).Value = allGeneratedEnum.Length;
239        ((IntValue)Results[DistinctSentencesName].Value).Value = distinctGenerated.Count;
240        ((IntValue)Results[PhraseExpansionsName].Value).Value = Expansions;
241        ((IntValue)Results[GeneratedEqualSentencesName].Value).Value = EqualGeneratedSentences;
242        ((DoubleValue)Results[AverageTreeLengthName].Value).Value = allGeneratedEnum.Select(sentence => sentence.Count).Average();
243      }
244    }
245
246    // Generate all Results after an algorithm run.
247    private void UpdateFinalResults() {
248      SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(BestTrainingSentence);
249      SymbolicRegressionModel model = new SymbolicRegressionModel(
250        Problem.ProblemData.TargetVariable,
251        tree,
252        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
253
254      IRegressionSolution bestTrainingSolution = new RegressionSolution(model, Problem.ProblemData);
255      Results.AddOrUpdateResult(BestTrainingSolutionName, bestTrainingSolution);
256
257      // Print generated sentences.
258      string[,] sentencesMatrix = new string[AllSentences.Values.SelectMany(x => x).Count(), 3];
259
260      int i = 0;
261      foreach (var sentenceSet in AllSentences) {
262        foreach (var sentence in sentenceSet.Value) {
263          sentencesMatrix[i, 0] = sentence.ToString();
264          sentencesMatrix[i, 1] = Grammar.PostfixToInfixParser(sentence).ToString();
265          sentencesMatrix[i, 2] = sentenceSet.Key.ToString();
266          i++;
267        }
268      }
269      Results.Add(new Result("All generated sentences", new StringMatrix(sentencesMatrix)));
270
271      string[,] distinctSentencesMatrix = new string[DistinctSentences.Count, 3];
272      i = 0;
273      foreach (KeyValuePair<int, SymbolString> distinctSentence in DistinctSentences) {
274        distinctSentencesMatrix[i, 0] = distinctSentence.Key.ToString();
275        distinctSentencesMatrix[i, 1] = Grammar.PostfixToInfixParser(distinctSentence.Value).ToString();
276        distinctSentencesMatrix[i, 2] = distinctSentence.Key.ToString();
277        i++;
278      }
279      Results.Add(new Result("Distinct generated sentences", new StringMatrix(distinctSentencesMatrix)));
280    }
281   
282    private void LogSearchGraph(TraceListener listener, String msg) {
283#if false
284      listener.Write(msg);
285#endif
286    }
287    #endregion
288
289  }
290}
Note: See TracBrowser for help on using the repository browser.