Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15812

Last change on this file since 15812 was 15812, checked in by lkammere, 6 years ago

#2886: Performance Improvements - Only store hash of archived phrases and reduce number of enumerators.

File size: 14.0 KB
RevLine 
[15765]1using System;
2using System.Collections.Generic;
[15800]3using System.Diagnostics;
4using System.IO;
[15712]5using System.Linq;
6using System.Threading;
7using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
8using HeuristicLab.Common;
9using HeuristicLab.Core;
10using HeuristicLab.Data;
[15722]11using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[15712]12using HeuristicLab.Optimization;
[15722]13using HeuristicLab.Parameters;
[15712]14using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
15using HeuristicLab.Problems.DataAnalysis;
[15722]16using HeuristicLab.Problems.DataAnalysis.Symbolic;
17using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
[15712]18
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Enumerates the sentences (model structures) derivable from a fixed grammar up to a maximum
  /// number of symbols and evaluates every newly found distinct sentence on the training partition
  /// of the regression problem (Pearson's R²).
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    #region properties and result names
    private const string BestTrainingQualityName = "Best R² (Training)";
    private const string BestTrainingSolutionName = "Best solution (Training)";
    private const string SearchStructureSizeName = "Search Structure Size";
    private const string GeneratedPhrasesName = "Generated/Archived Phrases";
    private const string GeneratedSentencesName = "Generated Sentences";
    private const string DistinctSentencesName = "Distinct Sentences";
    private const string PhraseExpansionsName = "Phrase Expansions";
    private const string AverageTreeLengthName = "Avg. Sentence Length among Distinct";
    private const string GeneratedEqualSentencesName = "Generated equal sentences";
    private const string AllGeneratedSentencesName = "All generated sentences";
    private const string DistinctGeneratedSentencesName = "Distinct generated sentences";

    private const string SearchDataStructureParameterName = "Search Data Structure";
    private const string MaxTreeSizeParameterName = "Max. Tree Nodes";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";

    public override bool SupportsPause { get { return false; } }

    protected IValueParameter<IntValue> MaxTreeSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
    }
    // Phrases with more symbols than this are discarded during the search.
    public int MaxTreeSize {
      get { return MaxTreeSizeParameter.Value.Value; }
      set { MaxTreeSizeParameter.Value.Value = value; }
    }

    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }
    // Number of generated sentences between two refreshes of the intermediate results.
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
      set { GuiUpdateIntervalParameter.Value.Value = value; }
    }

    protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter {
      get { return (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName]; }
    }
    // Selects the container (e.g. stack) that determines the order in which open phrases are expanded.
    public StorageType SearchDataStructure {
      get { return SearchDataStructureParameter.Value.Value; }
      set { SearchDataStructureParameter.Value.Value = value; }
    }

    // Sentence with the highest training R² found so far; null until a sentence has been evaluated.
    public SymbolString BestTrainingSentence;

    #endregion

    public Dictionary<int, SymbolString> DistinctSentences { get; private set; }  // Semantically distinct sentences in a run, keyed by grammar hash.
    public Dictionary<int, List<SymbolString>> AllSentences { get; private set; } // All sentences ever generated in a run, grouped by grammar hash.
    public HashSet<int> ArchivedPhrases { get; private set; }                     // Hashes of phrases that have already been fetched and expanded.

    internal SearchDataStore OpenPhrases { get; private set; }                    // Stack/Queue/etc. for fetching the next node in the search tree.

    public int EqualGeneratedSentences { get; private set; } // It is not guaranteed that shorter solutions are found first.
                                                             // When longer solutions are overwritten with shorter ones,
                                                             // this counter is increased.
    public int Expansions { get; private set; }              // Number, how many times a nonterminal symbol is replaced with a production rule.
    public Grammar Grammar { get; private set; }

    // Target of the DOT search graph that is written in DEBUG builds only.
    private readonly string dotFileName = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "searchgraph.dot");

    #region ctors
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    public GrammarEnumerationAlgorithm() {
      Problem = new RegressionProblem() {
        ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenFunctionNine(seed: 1234).GenerateRegressionData()
      };

      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The maximum number of symbols (tree nodes) of a generated sentence.", new IntValue(6)));
      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(1000)));
      Parameters.Add(new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.Stack)));
    }

    public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    #endregion

    // Enumerates phrases starting from the grammar's start symbol. Each fetched phrase has its
    // left-most nonterminal replaced by every production alternative. Complete sentences are
    // recorded and, if new or shorter than a known sentence with the same hash, evaluated.
    // Incomplete phrases not yet seen are queued for expansion. The loop stops when no open
    // phrases remain or cancellation is requested.
    protected override void Run(CancellationToken cancellationToken) {
      #region init
      InitResults();

      AllSentences = new Dictionary<int, List<SymbolString>>();
      ArchivedPhrases = new HashSet<int>();

      DistinctSentences = new Dictionary<int, SymbolString>();
      Expansions = 0;
      EqualGeneratedSentences = 0;

      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

      OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
      var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
      var phrase0Hash = Grammar.CalcHashCode(phrase0);
      int maxTreeSize = MaxTreeSize;
      #endregion

      // The listener is null outside DEBUG builds, so no file is created or truncated there.
      // A null resource is permitted in a using statement.
      using (TextWriterTraceListener dotFileTrace = CreateSearchGraphListener()) {
        LogSearchGraph(dotFileTrace, "digraph searchgraph {");

        OpenPhrases.Store(phrase0Hash, phrase0);
        while (OpenPhrases.Count > 0) {
          if (cancellationToken.IsCancellationRequested) break;

          StoredSymbolString fetchedPhrase = OpenPhrases.GetNext();
          SymbolString currPhrase = fetchedPhrase.SymbolString;
          LogSearchGraph(dotFileTrace, $"{fetchedPhrase.Hash} [label=\"{Grammar.PostfixToInfixParser(fetchedPhrase.SymbolString)}\"];");
          ArchivedPhrases.Add(fetchedPhrase.Hash);

          // Expand the left-most nonterminal symbol. Only non-sentences are ever stored, so one exists.
          int nonterminalSymbolIndex = currPhrase.FindIndex(s => s is NonterminalSymbol);
          NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];

          foreach (Production productionAlternative in expandedSymbol.Alternatives) {
            SymbolString newPhrase = new SymbolString(currPhrase.Count + productionAlternative.Count);
            newPhrase.AddRange(currPhrase);
            newPhrase.RemoveAt(nonterminalSymbolIndex);     // TODO: removeat and insertRange are both O(n)
            newPhrase.InsertRange(nonterminalSymbolIndex, productionAlternative);

            Expansions++;
            if (newPhrase.Count > maxTreeSize) continue;

            var phraseHash = Grammar.CalcHashCode(newPhrase);
            LogSearchGraph(dotFileTrace, $"{fetchedPhrase.Hash} -> {phraseHash} [label=\"{expandedSymbol.StringRepresentation} &rarr; {productionAlternative}\"];");

            if (newPhrase.IsSentence()) {
              // Sentence was generated.
              SaveToAllSentences(phraseHash, newPhrase);

              SymbolString knownSentence;
              bool isKnown = DistinctSentences.TryGetValue(phraseHash, out knownSentence);
              if (!isKnown || knownSentence.Count > newPhrase.Count) {
                if (isKnown) EqualGeneratedSentences++; // for analysis only

                DistinctSentences[phraseHash] = newPhrase;
                EvaluateSentence(newPhrase);

                LogSearchGraph(dotFileTrace, $"{phraseHash} [label=\"{Grammar.PostfixToInfixParser(newPhrase)}\", style=\"filled\"];");
              }
              UpdateView();

            } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
              OpenPhrases.Store(phraseHash, newPhrase);
            }
          }
        }

        // Overwrite formatting of the best found solution (if any) and the start search node.
        if (BestTrainingSentence != null)
          LogSearchGraph(dotFileTrace, $"{Grammar.CalcHashCode(BestTrainingSentence)} [label=\"{Grammar.PostfixToInfixParser(BestTrainingSentence)}\", shape=Mcircle, style=\"filled,bold\"];");
        LogSearchGraph(dotFileTrace, $"{phrase0Hash} [label=\"{Grammar.PostfixToInfixParser(phrase0)}\", shape=doublecircle];}}");
      }

      UpdateView(force: true);
      UpdateFinalResults();
    }

    // Store sentence to "MultiDictionary": all sentences sharing a hash are kept in one list.
    private void SaveToAllSentences(int hash, SymbolString sentence) {
      List<SymbolString> sentencesWithHash;
      if (AllSentences.TryGetValue(hash, out sentencesWithHash))
        sentencesWithHash.Add(sentence);
      else
        AllSentences[hash] = new List<SymbolString> { sentence }; //TODO: here we store all sentences even if they have the same hash value, this is not strictly necessary
    }

    #region Evaluation of generated models.

    // Evaluate sentence within an algorithm run: computes Pearson's R² on the training partition
    // (0 if the calculator reports an error) and records the sentence if it beats the best so far.
    private void EvaluateSentence(SymbolString symbolString) {
      SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(symbolString);
      SymbolicRegressionModel model = new SymbolicRegressionModel(
        Problem.ProblemData.TargetVariable,
        tree,
        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());

      var probData = Problem.ProblemData;
      var target = probData.TargetVariableTrainingValues;
      var estVals = model.GetEstimatedValues(probData.Dataset, probData.TrainingIndices);
      OnlineCalculatorError error;
      var r2 = OnlinePearsonsRSquaredCalculator.Calculate(target, estVals, out error);
      if (error != OnlineCalculatorError.None) r2 = 0.0;

      var bestR2 = ((DoubleValue)Results[BestTrainingQualityName].Value).Value;
      if (r2 > bestR2) {
        ((DoubleValue)Results[BestTrainingQualityName].Value).Value = r2;
        BestTrainingSentence = symbolString;
      }
    }

    #endregion

    #region Visualization in HL
    // Initialize entries in result set and reset the per-run GUI refresh counter.
    private void InitResults() {
      BestTrainingSentence = null;
      updates = 0;

      Results.Add(new Result(BestTrainingQualityName, new DoubleValue(-1.0)));

      Results.Add(new Result(GeneratedPhrasesName, new IntValue(0)));
      Results.Add(new Result(SearchStructureSizeName, new IntValue(0)));
      Results.Add(new Result(GeneratedSentencesName, new IntValue(0)));
      Results.Add(new Result(DistinctSentencesName, new IntValue(0)));
      Results.Add(new Result(PhraseExpansionsName, new IntValue(0)));
      Results.Add(new Result(GeneratedEqualSentencesName, new IntValue(0)));
      Results.Add(new Result(AverageTreeLengthName, new DoubleValue(1.0)));
    }

    // Update the view for intermediate results in an algorithm run.
    private int updates;
    private void UpdateView(bool force = false) {
      updates++;

      // Refresh on the first call and then every GuiUpdateInterval-th call. The interval is clamped
      // to at least 1: an interval of 1 refreshes on every call, and 0 cannot divide by zero.
      int interval = Math.Max(1, GuiUpdateInterval);
      if (!force && (updates - 1) % interval != 0) return;

      // Count generated sentences and their total length in one pass without materializing a copy.
      int generatedSentences = 0;
      long totalSentenceLength = 0;
      foreach (List<SymbolString> sentencesWithHash in AllSentences.Values) {
        generatedSentences += sentencesWithHash.Count;
        foreach (SymbolString sentence in sentencesWithHash)
          totalSentenceLength += sentence.Count;
      }

      ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
      ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
      ((IntValue)Results[GeneratedSentencesName].Value).Value = generatedSentences;
      ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentences.Count;
      ((IntValue)Results[PhraseExpansionsName].Value).Value = Expansions;
      ((IntValue)Results[GeneratedEqualSentencesName].Value).Value = EqualGeneratedSentences;
      // The average is undefined without sentences; keep the previous value instead of throwing.
      if (generatedSentences > 0)
        ((DoubleValue)Results[AverageTreeLengthName].Value).Value = (double)totalSentenceLength / generatedSentences;
    }

    // Generate all Results after an algorithm run.
    private void UpdateFinalResults() {
      // No sentence may have been evaluated (run cancelled early, or MaxTreeSize too small to
      // derive any sentence); then there is no best solution to report.
      if (BestTrainingSentence != null) {
        SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(BestTrainingSentence);
        SymbolicRegressionModel model = new SymbolicRegressionModel(
          Problem.ProblemData.TargetVariable,
          tree,
          new SymbolicDataAnalysisExpressionTreeLinearInterpreter());

        IRegressionSolution bestTrainingSolution = new RegressionSolution(model, Problem.ProblemData);
        Results.AddOrUpdateResult(BestTrainingSolutionName, bestTrainingSolution);
      }

      // Print generated sentences. Columns: symbol string, infix representation, hash.
      string[,] sentencesMatrix = new string[AllSentences.Values.Sum(x => x.Count), 3];

      int i = 0;
      foreach (var sentenceSet in AllSentences) {
        foreach (var sentence in sentenceSet.Value) {
          sentencesMatrix[i, 0] = sentence.ToString();
          sentencesMatrix[i, 1] = Grammar.PostfixToInfixParser(sentence).ToString();
          sentencesMatrix[i, 2] = sentenceSet.Key.ToString();
          i++;
        }
      }
      Results.AddOrUpdateResult(AllGeneratedSentencesName, new StringMatrix(sentencesMatrix));

      // Same column layout for the distinct sentences.
      string[,] distinctSentencesMatrix = new string[DistinctSentences.Count, 3];
      i = 0;
      foreach (KeyValuePair<int, SymbolString> distinctSentence in DistinctSentences) {
        distinctSentencesMatrix[i, 0] = distinctSentence.Value.ToString();
        distinctSentencesMatrix[i, 1] = Grammar.PostfixToInfixParser(distinctSentence.Value).ToString();
        distinctSentencesMatrix[i, 2] = distinctSentence.Key.ToString();
        i++;
      }
      Results.AddOrUpdateResult(DistinctGeneratedSentencesName, new StringMatrix(distinctSentencesMatrix));
    }

    // Opens the DOT search-graph file in DEBUG builds; returns null otherwise so nothing is written.
    private TextWriterTraceListener CreateSearchGraphListener() {
#if DEBUG
      return new TextWriterTraceListener(new FileStream(dotFileName, FileMode.Create));
#else
      return null;
#endif
    }

    // Writes one DOT statement per line. Call sites (including evaluation of their arguments)
    // are removed by the compiler unless DEBUG is defined.
    [Conditional("DEBUG")]
    private void LogSearchGraph(TraceListener listener, string msg) {
      if (listener != null) listener.WriteLine(msg);
    }
    #endregion

  }
}
Note: See TracBrowser for help on using the repository browser.