[15765] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
[15800] | 3 | using System.Diagnostics;
|
---|
| 4 | using System.IO;
|
---|
[15712] | 5 | using System.Linq;
|
---|
| 6 | using System.Threading;
|
---|
| 7 | using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
|
---|
| 8 | using HeuristicLab.Common;
|
---|
| 9 | using HeuristicLab.Core;
|
---|
| 10 | using HeuristicLab.Data;
|
---|
[15722] | 11 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
[15712] | 12 | using HeuristicLab.Optimization;
|
---|
[15722] | 13 | using HeuristicLab.Parameters;
|
---|
[15712] | 14 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
| 15 | using HeuristicLab.Problems.DataAnalysis;
|
---|
[15722] | 16 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 17 | using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
|
---|
[15712] | 18 |
|
---|
| 19 | namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
|
---|
| 20 | [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
|
---|
| 21 | [StorableClass]
|
---|
| 22 | [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
|
---|
| 23 | public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
|
---|
[15746] | 24 | #region properties and result names
|
---|
// Names of the entries this algorithm writes to its result collection.
// Declared const: they are compile-time literals and need no per-instance storage.
private const string BestTrainingQualityName = "Best R² (Training)";
private const string BestTrainingSolutionName = "Best solution (Training)";
private const string SearchStructureSizeName = "Search Structure Size";
private const string GeneratedPhrasesName = "Generated/Archived Phrases";
private const string GeneratedSentencesName = "Generated Sentences";
private const string DistinctSentencesName = "Distinct Sentences";
private const string PhraseExpansionsName = "Phrase Expansions";
private const string AverageTreeLengthName = "Avg. Sentence Length among Distinct";
private const string GeneratedEqualSentencesName = "Generated equal sentences";

// Names under which the algorithm's parameters are registered and looked up.
private const string SearchDataStructureParameterName = "Search Data Structure";
private const string MaxTreeSizeParameterName = "Max. Tree Nodes";
private const string GuiUpdateIntervalParameterName = "GUI Update Interval";
[15765] | 39 |
|
---|
/// <summary>
/// Always <c>false</c>: this algorithm reports that it cannot be paused.
/// </summary>
public override bool SupportsPause => false;
[15712] | 41 |
|
---|
/// <summary>Parameter object stored under the "Max. Tree Nodes" parameter name.</summary>
protected IValueParameter<IntValue> MaxTreeSizeParameter
  => (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName];

/// <summary>
/// Upper bound on the number of symbols in a phrase; longer phrases produced by an
/// expansion are neither evaluated nor stored for further expansion.
/// </summary>
public int MaxTreeSize {
  get {
    IntValue maxNodes = MaxTreeSizeParameter.Value;
    return maxNodes.Value;
  }
  set {
    IntValue maxNodes = MaxTreeSizeParameter.Value;
    maxNodes.Value = value;
  }
}

/// <summary>Parameter object stored under the "GUI Update Interval" parameter name.</summary>
protected IValueParameter<IntValue> GuiUpdateIntervalParameter
  => (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName];

/// <summary>
/// Interval, counted in view-update requests, at which the intermediate results are refreshed.
/// </summary>
public int GuiUpdateInterval {
  get {
    IntValue interval = GuiUpdateIntervalParameter.Value;
    return interval.Value;
  }
  set {
    IntValue interval = GuiUpdateIntervalParameter.Value;
    interval.Value = value;
  }
}

/// <summary>Parameter object stored under the "Search Data Structure" parameter name.</summary>
protected IValueParameter<EnumValue<StorageType>> SearchDataStructureParameter
  => (IValueParameter<EnumValue<StorageType>>)Parameters[SearchDataStructureParameterName];

/// <summary>
/// Kind of storage used for the open phrases; it is handed to the search data store
/// when a run starts.
/// </summary>
public StorageType SearchDataStructure {
  get {
    EnumValue<StorageType> storage = SearchDataStructureParameter.Value;
    return storage.Value;
  }
  set {
    EnumValue<StorageType> storage = SearchDataStructureParameter.Value;
    storage.Value = value;
  }
}

/// <summary>
/// Sentence with the highest training R² evaluated so far. Reset to <c>null</c> when the
/// results are initialized, so it stays <c>null</c> until a sentence has been evaluated.
/// </summary>
public SymbolString BestTrainingSentence;
| 67 |
|
---|
[15722] | 68 | #endregion
|
---|
[15712] | 69 |
|
---|
/// <summary>Semantically distinct sentences in a run, keyed by their hash.</summary>
public Dictionary<int, SymbolString> DistinctSentences { get; private set; }

/// <summary>All sentences ever generated in a run, grouped by their hash.</summary>
public Dictionary<int, List<SymbolString>> AllSentences { get; private set; }

/// <summary>Hashes of the phrases that were already fetched from the open phrases for expansion.</summary>
public HashSet<int> ArchivedPhrases { get; private set; }

/// <summary>Stack/queue/etc. for fetching the next node in the search tree.</summary>
internal SearchDataStore OpenPhrases { get; private set; }

/// <summary>
/// It is not guaranteed that shorter solutions are found first. When a longer solution is
/// overwritten with a shorter one that has the same hash, this counter is increased.
/// </summary>
public int EqualGeneratedSentences { get; private set; }

/// <summary>Number of times a nonterminal symbol was replaced with a production rule.</summary>
public int Expansions { get; private set; }

/// <summary>Grammar used by the current run; created from the problem's allowed input variables.</summary>
public Grammar Grammar { get; private set; }

// Target file of the search graph (Graphviz dot format) on the user's desktop.
// Path.Combine is used instead of concatenating "\" so the path is built with the
// platform's directory separator.
private readonly string dotFileName = Path.Combine(
  Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory),
  "searchgraph.dot");
[15765] | 83 |
|
---|
[15722] | 84 | #region ctors
|
---|
/// <summary>
/// Creates a copy of this algorithm by passing it and the given cloner to the cloning constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) => new GrammarEnumerationAlgorithm(this, cloner);
[15712] | 88 |
|
---|
/// <summary>
/// Creates the algorithm with a default regression problem (Nguyen function nine, seed 1234)
/// and registers its three parameters: maximum tree size (6), GUI update interval (1000)
/// and search data structure (stack).
/// </summary>
public GrammarEnumerationAlgorithm() {
  Problem = new RegressionProblem() {
    ProblemData = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenFunctionNine(seed: 1234).GenerateRegressionData()
  };

  // The previous description ("The number of clusters.") was a copy-paste leftover and
  // did not describe this parameter.
  Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The maximum number of symbols (tree nodes) a generated phrase or sentence may contain.", new IntValue(6)));
  Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(1000)));
  Parameters.Add(new ValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, "The data structure that holds the open phrases and thereby selects the search strategy.", new EnumValue<StorageType>(StorageType.Stack)));
}
[15712] | 98 |
|
---|
/// <summary>
/// Cloning constructor. Delegates entirely to the base class; no members declared in this
/// class are copied here.
/// </summary>
public GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner)
  : base(original, cloner) {
}
[15722] | 100 | #endregion
|
---|
[15712] | 101 |
|
---|
/// <summary>
/// Enumerates phrases starting from the grammar's start symbol. Each fetched phrase is
/// archived and its first nonterminal symbol is replaced by every production alternative.
/// Resulting phrases within the tree size limit are either evaluated (if they are sentences)
/// or stored for later expansion (if their hash is neither open nor archived).
/// The loop ends when no open phrases remain or cancellation is requested; afterwards the
/// view is refreshed and the final results are generated.
/// </summary>
/// <param name="cancellationToken">Checked once per fetched phrase to stop the search early.</param>
protected override void Run(CancellationToken cancellationToken) {
  #region init
  InitResults();

  AllSentences = new Dictionary<int, List<SymbolString>>();
  ArchivedPhrases = new HashSet<int>();

  DistinctSentences = new Dictionary<int, SymbolString>();
  Expansions = 0;
  EqualGeneratedSentences = 0;

  Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

  OpenPhrases = new SearchDataStore(SearchDataStructure); // Select search strategy
  var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
  var phrase0Hash = Grammar.CalcHashCode(phrase0);
  #endregion

  using (TextWriterTraceListener dotFileTrace = new TextWriterTraceListener(new FileStream(dotFileName, FileMode.Create))) {
    // NOTE(review): the graph header is written in every build configuration, while nodes,
    // edges and the closing brace are only written in DEBUG builds - confirm this is intended.
    LogSearchGraph(dotFileTrace, "digraph searchgraph {");

    OpenPhrases.Store(phrase0Hash, phrase0);
    while (OpenPhrases.Count > 0) {
      if (cancellationToken.IsCancellationRequested) break;

      StoredSymbolString fetchedPhrase = OpenPhrases.GetNext();
      SymbolString currPhrase = fetchedPhrase.SymbolString;
#if DEBUG
      LogSearchGraph(dotFileTrace, $"{fetchedPhrase.Hash} [label=\"{Grammar.PostfixToInfixParser(fetchedPhrase.SymbolString)}\"];");
#endif
      ArchivedPhrases.Add(fetchedPhrase.Hash);

      // expand next nonterminal symbol
      int nonterminalSymbolIndex = currPhrase.FindIndex(s => s is NonterminalSymbol);
      // Only non-sentences are stored as open phrases, so a nonterminal is expected here;
      // skip defensively instead of indexing with -1 should that ever not hold.
      if (nonterminalSymbolIndex < 0) continue;
      NonterminalSymbol expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];

      foreach (Production productionAlternative in expandedSymbol.Alternatives) {
        SymbolString newPhrase = new SymbolString(currPhrase.Count + productionAlternative.Count);
        newPhrase.AddRange(currPhrase);
        newPhrase.RemoveAt(nonterminalSymbolIndex); // TODO: removeat and insertRange are both O(n)
        newPhrase.InsertRange(nonterminalSymbolIndex, productionAlternative);

        Expansions++;
        // Phrases exceeding the size limit are dropped without being hashed or stored.
        if (newPhrase.Count > MaxTreeSize) continue;

        var phraseHash = Grammar.CalcHashCode(newPhrase);
#if DEBUG
        LogSearchGraph(dotFileTrace, $"{fetchedPhrase.Hash} -> {phraseHash} [label=\"{expandedSymbol.StringRepresentation} + → {productionAlternative}\"];");
#endif
        if (newPhrase.IsSentence()) {
          // Sentence was generated.
          SaveToAllSentences(phraseHash, newPhrase);

          // A sentence is (re-)evaluated when its hash is new or when it is shorter than
          // the sentence currently stored under the same hash.
          SymbolString knownSentence;
          bool alreadyKnown = DistinctSentences.TryGetValue(phraseHash, out knownSentence);
          if (!alreadyKnown || knownSentence.Count > newPhrase.Count) {
            if (alreadyKnown) EqualGeneratedSentences++; // for analysis only

            DistinctSentences[phraseHash] = newPhrase;
            EvaluateSentence(newPhrase);

#if DEBUG
            LogSearchGraph(dotFileTrace, $"{phraseHash} [label=\"{Grammar.PostfixToInfixParser(newPhrase)}\", style=\"filled\"];");
#endif
          }
          UpdateView();

        } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
          OpenPhrases.Store(phraseHash, newPhrase);
        }
      }
    }
#if DEBUG
    // Overwrite formatting of start search node and best found solution.
    // BestTrainingSentence is still null when no sentence was evaluated (e.g. the run was
    // cancelled right away), so the best-solution node is only written when it exists.
    if (BestTrainingSentence != null) {
      LogSearchGraph(dotFileTrace, $"{Grammar.CalcHashCode(BestTrainingSentence)} [label=\"{Grammar.PostfixToInfixParser(BestTrainingSentence)}\", shape=Mcircle, style=\"filled,bold\"];");
    }
    LogSearchGraph(dotFileTrace, $"{phrase0Hash} [label=\"{Grammar.PostfixToInfixParser(phrase0)}\", shape=doublecircle];}}");
    dotFileTrace.Flush();
#endif
  }

  UpdateView(force: true);
  UpdateFinalResults();
}
[15723] | 183 |
|
---|
// Appends the sentence to the list kept under its hash ("MultiDictionary" behavior);
// the list is created the first time a hash is seen.
private void SaveToAllSentences(int hash, SymbolString sentence) {
  List<SymbolString> sentencesWithHash;
  if (!AllSentences.TryGetValue(hash, out sentencesWithHash)) {
    sentencesWithHash = new List<SymbolString>();
    AllSentences[hash] = sentencesWithHash;
  }

  //TODO: here we store all sentences even if they have the same hash value, this is not strictly necessary
  sentencesWithHash.Add(sentence);
}
[15712] | 191 |
|
---|
[15812] | 192 | #region Evaluation of generated models.
|
---|
[15712] | 193 |
|
---|
// Evaluates a sentence within an algorithm run: the sentence is parsed into an expression
// tree, wrapped in a regression model and scored by R² on the training partition. A
// calculator error counts as a quality of 0. If the score is greater than the best quality
// in the results, both the stored quality and BestTrainingSentence are replaced.
private void EvaluateSentence(SymbolString symbolString) {
  var problemData = Problem.ProblemData;

  SymbolicExpressionTree expressionTree = Grammar.ParseSymbolicExpressionTree(symbolString);
  var candidate = new SymbolicRegressionModel(
    problemData.TargetVariable,
    expressionTree,
    new SymbolicDataAnalysisExpressionTreeLinearInterpreter());

  var targetValues = problemData.TargetVariableTrainingValues;
  var estimatedValues = candidate.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);

  OnlineCalculatorError calculatorError;
  var quality = OnlinePearsonsRSquaredCalculator.Calculate(targetValues, estimatedValues, out calculatorError);
  if (calculatorError != OnlineCalculatorError.None) {
    quality = 0.0;
  }

  var bestQuality = (DoubleValue)Results[BestTrainingQualityName].Value;
  // Written as a negated '>' so that a NaN quality is rejected.
  if (!(quality > bestQuality.Value)) return;

  bestQuality.Value = quality;
  BestTrainingSentence = symbolString;
}
[15712] | 215 |
|
---|
[15812] | 216 | #endregion
|
---|
[15746] | 217 |
|
---|
[15812] | 218 | #region Visualization in HL
|
---|
// Clears the best sentence and registers all result entries with their start values:
// best training quality -1.0, every counter 0, average sentence length 1.0.
private void InitResults() {
  BestTrainingSentence = null;

  Results.Add(new Result(BestTrainingQualityName, new DoubleValue(-1.0)));

  // Integer counters, listed in the order in which they are added to the results.
  string[] counterNames = {
    GeneratedPhrasesName,
    SearchStructureSizeName,
    GeneratedSentencesName,
    DistinctSentencesName,
    PhraseExpansionsName,
    GeneratedEqualSentencesName
  };
  foreach (string counterName in counterNames) {
    Results.Add(new Result(counterName, new IntValue(0)));
  }

  Results.Add(new Result(AverageTreeLengthName, new DoubleValue(1.0)));
}
[15746] | 233 |
|
---|
// Number of UpdateView calls so far; used to throttle refreshes of the result values.
private int updates;

// Update the view for intermediate results in an algorithm run.
// The result values are refreshed when 'force' is set or on every GuiUpdateInterval-th call
// (calls 1, interval + 1, 2 * interval + 1, ...). An interval of 1 or less refreshes on
// every call; previously an interval of 0 divided by zero and an interval of 1 never
// refreshed, because 'updates % 1' is never 1.
private void UpdateView(bool force = false) {
  updates++;

  int interval = GuiUpdateInterval;
  bool refreshDue = interval <= 1 || updates % interval == 1;
  if (!force && !refreshDue) return;

  var allGeneratedEnum = AllSentences.Values.SelectMany(x => x).ToArray();
  ((IntValue)Results[GeneratedPhrasesName].Value).Value = ArchivedPhrases.Count;
  ((IntValue)Results[SearchStructureSizeName].Value).Value = OpenPhrases.Count;
  ((IntValue)Results[GeneratedSentencesName].Value).Value = allGeneratedEnum.Length;
  ((IntValue)Results[DistinctSentencesName].Value).Value = DistinctSentences.Count;
  ((IntValue)Results[PhraseExpansionsName].Value).Value = Expansions;
  ((IntValue)Results[GeneratedEqualSentencesName].Value).Value = EqualGeneratedSentences;

  // Average() throws on an empty sequence, which happens on the forced final refresh when
  // no sentence was generated; the previous value is kept in that case.
  // NOTE(review): the result is named "... among Distinct" but is computed over all
  // generated sentences - confirm which one is intended.
  if (allGeneratedEnum.Length > 0) {
    ((DoubleValue)Results[AverageTreeLengthName].Value).Value = allGeneratedEnum.Select(sentence => sentence.Count).Average();
  }
}
| 250 |
|
---|
// Generate all Results after an algorithm run: the best training solution (if any sentence
// was evaluated) and two string matrices listing all generated and all distinct sentences.
private void UpdateFinalResults() {
  // BestTrainingSentence is null when no sentence was evaluated (e.g. the run was cancelled
  // right away); there is nothing to parse in that case, so no solution is added.
  if (BestTrainingSentence != null) {
    SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(BestTrainingSentence);
    SymbolicRegressionModel model = new SymbolicRegressionModel(
      Problem.ProblemData.TargetVariable,
      tree,
      new SymbolicDataAnalysisExpressionTreeLinearInterpreter());

    IRegressionSolution bestTrainingSolution = new RegressionSolution(model, Problem.ProblemData);
    Results.AddOrUpdateResult(BestTrainingSolutionName, bestTrainingSolution);
  }

  // Print generated sentences. Columns: sentence, infix form, hash.
  string[,] sentencesMatrix = new string[AllSentences.Values.Sum(x => x.Count), 3];

  int i = 0;
  foreach (var sentenceSet in AllSentences) {
    foreach (var sentence in sentenceSet.Value) {
      sentencesMatrix[i, 0] = sentence.ToString();
      sentencesMatrix[i, 1] = Grammar.PostfixToInfixParser(sentence).ToString();
      sentencesMatrix[i, 2] = sentenceSet.Key.ToString();
      i++;
    }
  }
  Results.Add(new Result("All generated sentences", new StringMatrix(sentencesMatrix)));

  // Print distinct sentences with the same column layout as above. Column 0 previously
  // repeated the hash that is already shown in column 2; it now holds the sentence itself.
  string[,] distinctSentencesMatrix = new string[DistinctSentences.Count, 3];
  i = 0;
  foreach (KeyValuePair<int, SymbolString> distinctSentence in DistinctSentences) {
    distinctSentencesMatrix[i, 0] = distinctSentence.Value.ToString();
    distinctSentencesMatrix[i, 1] = Grammar.PostfixToInfixParser(distinctSentence.Value).ToString();
    distinctSentencesMatrix[i, 2] = distinctSentence.Key.ToString();
    i++;
  }
  Results.Add(new Result("Distinct generated sentences", new StringMatrix(distinctSentencesMatrix)));
}
[15812] | 286 |
|
---|
// Forwards msg unchanged to the listener's Write method.
private void LogSearchGraph(TraceListener listener, string msg) => listener.Write(msg);
[15812] | 290 | #endregion
|
---|
[15765] | 291 |
|
---|
[15712] | 292 | }
|
---|
| 293 | } |
---|