source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15734

Last change on this file since 15734 was 15734, checked in by gkronber, 2 years ago

#2886 worked on grammar enumeration

File size: 11.2 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Threading;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Data;
9using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
10using HeuristicLab.Optimization;
11using HeuristicLab.Parameters;
12using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
13using HeuristicLab.Problems.DataAnalysis;
14using HeuristicLab.Problems.DataAnalysis.Symbolic;
15using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
16
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Enumerates the sentences of a fixed symbolic-regression grammar (built from the problem's
  /// allowed input variables) up to a maximum number of symbols, and evaluates every distinct
  /// sentence by its Pearson's R² on the training partition.
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    // Result names. The solution-related ones are only referenced by the disabled code in EvaluateSentence.
    private const string BestTrainingSolution = "Best solution (training)";
    private const string BestTrainingSolutionQuality = "Best solution quality (training)";
    private const string BestTestSolution = "Best solution (test)";
    private const string BestTestSolutionQuality = "Best solution quality (test)";
    private const string BestR2ResultName = "Best R²";

    // Parameter names.
    private const string MaxTreeSizeParameterName = "Max. Tree Nodes";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";
    private const string UseMemoizationParameterName = "Use Memoization?";

    #region properties
    protected IValueParameter<IntValue> MaxTreeSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
    }
    /// <summary>Maximum number of symbols a phrase may contain to be kept for further expansion.</summary>
    public int MaxTreeSize {
      get { return MaxTreeSizeParameter.Value.Value; }
      set { MaxTreeSizeParameter.Value.Value = value; }
    }

    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }
    /// <summary>Number of generated sentences between two refreshes of the statistics results.</summary>
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
      set { GuiUpdateIntervalParameter.Value.Value = value; }
    }

    protected IValueParameter<BoolValue> UseMemoizationParameter {
      get { return (IValueParameter<BoolValue>)Parameters[UseMemoizationParameterName]; }
    }
    // NOTE(review): this flag is not read anywhere in this class yet.
    public bool UseMemoization {
      get { return UseMemoizationParameter.Value.Value; }
      set { UseMemoizationParameter.Value.Value = value; }
    }

    // Reset at the start of Run; only assigned by the (currently disabled) code in EvaluateSentence.
    public SymbolString BestTrainingSentence;
    public SymbolString BestTestSentence;

    // (sentence, grammar hash) pairs collected by the last Run: distinct-by-hash and all generated.
    public List<Tuple<SymbolString, int>> distinctSentences;
    public List<Tuple<SymbolString, int>> sentences;
    #endregion

    /// <summary>Grammar used by the last Run; rebuilt from the problem's allowed input variables.</summary>
    public Grammar Grammar;


    #region ctors
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    /// <summary>
    /// Creates the algorithm with the "Poly-10" instance of the VariousInstanceProvider
    /// (seed 1234) preloaded as the regression problem, and registers the algorithm parameters.
    /// </summary>
    public GrammarEnumerationAlgorithm() {

      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));

      Problem = new RegressionProblem() {
        ProblemData = regProblem
      };

      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The maximum number of symbols (tree nodes) in a generated sentence.", new IntValue(6)));
      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(4000)));
      Parameters.Add(new ValueParameter<BoolValue>(UseMemoizationParameterName, "Should already subtrees be reused within a run.", new BoolValue(true)));
    }

    private GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    #endregion


    /// <summary>
    /// Enumerates phrases starting from the grammar's start symbol, expanding the leftmost
    /// nonterminal of a randomly selected open phrase in each step. Phrases longer than
    /// <see cref="MaxTreeSize"/> are discarded. Phrases whose hash was already queued or
    /// processed are skipped. Each complete sentence is recorded. The first sentence per hash
    /// is evaluated. Enumeration stops when no open phrases remain or cancellation is requested.
    /// The collected sentences are then published as results.
    /// </summary>
    /// <param name="cancellationToken">Checked once per selected phrase; stops the enumeration early.</param>
    protected override void Run(CancellationToken cancellationToken) {
      // AddOrUpdateResult rather than Add: Add fails if a result with the same name is already
      // present (e.g. if Run is invoked again on a result collection that was not cleared).
      Results.AddOrUpdateResult(BestR2ResultName, new DoubleValue(0.0));
      var rand = new System.Random(1234);  // fixed seed: the random phrase selection is reproducible
      BestTrainingSentence = null;
      BestTestSentence = null;
      this.sentences = new List<Tuple<SymbolString, int>>();
      this.distinctSentences = new List<Tuple<SymbolString, int>>();
      var archivedPhrases = new Dictionary<int, SymbolString>();
      var evaluatedHashes = new HashSet<int>();

      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

      // Open phrases keyed by grammar hash; initially only the phrase consisting of the start symbol.
      var phrases = new Dictionary<int, SymbolString>();
      var phrase0 = new SymbolString(new[] { Grammar.StartSymbol });
      phrases.Add(Grammar.CalcHashCode(phrase0), phrase0);

      while (phrases.Any()) {
        if (cancellationToken.IsCancellationRequested) break;

        // Alternative selection strategies (disabled):
        // FIFO
        // SymbolString currSymbolString = phrases.First();
        // phrases.RemoveAt(0);

        // LIFO
        // SymbolString currSymbolString = phrases.Last();
        // phrases.RemoveAt(phrases.Count - 1);

        // RANDOM
        int idx = rand.Next(phrases.Count);
        var selectedEntry = phrases.ElementAt(idx);  // TODO: ElementAt on a Dictionary is O(n); poor performance.
        phrases.Remove(selectedEntry.Key);
        var currPhrase = selectedEntry.Value;

        // Remember every processed phrase so that phrases with the same hash are not queued again.
        archivedPhrases.Add(selectedEntry.Key, currPhrase);

        if (currPhrase.IsSentence()) {
          int currSymbolStringHash = Grammar.CalcHashCode(currPhrase);
          this.sentences.Add(new Tuple<SymbolString, int>(currPhrase, currSymbolStringHash));

          // Only the first sentence per hash is evaluated; later ones are considered duplicates.
          if (evaluatedHashes.Add(currSymbolStringHash)) {
            this.distinctSentences.Add(new Tuple<SymbolString, int>(currPhrase, currSymbolStringHash));
            EvaluateSentence(currPhrase);
          }
          UpdateView(this.sentences, this.distinctSentences);

        } else {
          // Expand the leftmost nonterminal symbol with each of its alternative productions.
          int nonterminalSymbolIndex = currPhrase.FindIndex(s => s is NonterminalSymbol);
          var expandedSymbol = (NonterminalSymbol)currPhrase[nonterminalSymbolIndex];

          foreach (Production productionAlternative in expandedSymbol.Alternatives) {
            SymbolString newSentence = new SymbolString(currPhrase);
            newSentence.RemoveAt(nonterminalSymbolIndex);
            newSentence.InsertRange(nonterminalSymbolIndex, productionAlternative);

            if (newSentence.Count <= MaxTreeSize) {
              var phraseHash = Grammar.CalcHashCode(newSentence);
              if (!phrases.ContainsKey(phraseHash) &&
                  !archivedPhrases.ContainsKey(phraseHash)) {
                phrases.Add(phraseHash, newSentence);
              }
            }
          }
        }
      }

      UpdateView(this.sentences, this.distinctSentences, force: true);

      Results.AddOrUpdateResult("All generated sentences", CreateSentenceMatrix(this.sentences));
      Results.AddOrUpdateResult("Distinct generated sentences", CreateSentenceMatrix(this.distinctSentences));
    }

    /// <summary>
    /// Builds a three-column table with one row per entry. The columns are the sentence's string form,
    /// its infix form (via the grammar's PostfixToInfixParser), and its hash.
    /// </summary>
    /// <param name="entries">(sentence, hash) pairs to tabulate; may be empty.</param>
    private StringMatrix CreateSentenceMatrix(List<Tuple<SymbolString, int>> entries) {
      var matrix = new string[entries.Count, 3];
      for (int i = 0; i < entries.Count; i++) {
        matrix[i, 0] = entries[i].Item1.ToString();
        matrix[i, 1] = Grammar.PostfixToInfixParser(entries[i].Item1).ToString();
        matrix[i, 2] = entries[i].Item2.ToString();
      }
      return new StringMatrix(matrix);
    }


    /// <summary>
    /// Publishes the number of generated and distinct sentences and their average length (symbol count).
    /// This happens when forced, or whenever the generated count is a multiple of <see cref="GuiUpdateInterval"/>.
    /// </summary>
    /// <param name="allGenerated">All generated sentences.</param>
    /// <param name="distinctGenerated">Distinct generated sentences.</param>
    /// <param name="force">Publish regardless of the update interval.</param>
    private void UpdateView(List<Tuple<SymbolString, int>> allGenerated,
        List<Tuple<SymbolString, int>> distinctGenerated, bool force = false) {
      int generatedSolutions = allGenerated.Count;
      int distinctSolutions = distinctGenerated.Count;
      int interval = GuiUpdateInterval;

      // A non-positive interval disables periodic refreshes (and avoids a division by zero);
      // forced updates still happen.
      bool periodicUpdateDue = interval > 0 && generatedSolutions % interval == 0;
      if (!force && !periodicUpdateDue) return;

      Results.AddOrUpdateResult("Generated Solutions", new IntValue(generatedSolutions));
      Results.AddOrUpdateResult("Distinct Solutions", new IntValue(distinctSolutions));

      // Enumerable.Average throws on an empty sequence, which occurs on a forced update
      // when no sentence has been generated (e.g. cancellation before the first sentence).
      double averageTreeLength = generatedSolutions > 0 ? allGenerated.Average(r => r.Item1.Count) : 0.0;
      Results.AddOrUpdateResult("Average Tree Length of Solutions", new DoubleValue(averageTreeLength));
    }

    /// <summary>
    /// Parses <paramref name="symbolString"/> into an expression tree. It then computes the tree's Pearson's R²
    /// against the target on the training partition and raises the "Best R²" result if the R² is higher.
    /// Failed calculations count as R² = 0.
    /// </summary>
    /// <param name="symbolString">A complete sentence of the grammar.</param>
    private void EvaluateSentence(SymbolString symbolString) {
      SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(symbolString);
      var probData = Problem.ProblemData;
      SymbolicRegressionModel model = new SymbolicRegressionModel(
        probData.TargetVariable,
        tree,
        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      var target = probData.TargetVariableTrainingValues;
      var estVals = model.GetEstimatedValues(probData.Dataset, probData.TrainingIndices);
      OnlineCalculatorError error;
      var r2 = OnlinePearsonsRSquaredCalculator.Calculate(target, estVals, out error);
      // Math.Max propagates NaN, so a NaN R² would permanently overwrite the best value.
      if (error != OnlineCalculatorError.None || double.IsNaN(r2)) r2 = 0.0;

      var bestR2 = (DoubleValue)Results[BestR2ResultName].Value;
      bestR2.Value = Math.Max(r2, bestR2.Value);

      // Disabled: tracking of the best training/test solutions.
      // IRegressionSolution newSolution = model.CreateRegressionSolution(Problem.ProblemData);
      //
      // IResult currBestTrainingSolutionResult;
      // IResult currBestTestSolutionResult;
      // if (!Results.TryGetValue(BestTrainingSolution, out currBestTrainingSolutionResult)
      //      || !Results.TryGetValue(BestTestSolution, out currBestTestSolutionResult)) {
      //
      //   BestTrainingSentence = symbolString;
      //   Results.Add(new Result(BestTrainingSolution, newSolution));
      //   Results.Add(new Result(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly()));
      //
      //   BestTestSentence = symbolString;
      //   Results.Add(new Result(BestTestSolution, newSolution));
      //   Results.Add(new Result(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly()));
      //
      // } else {
      //   IRegressionSolution currBestTrainingSolution = (IRegressionSolution)currBestTrainingSolutionResult.Value;
      //   if (currBestTrainingSolution.TrainingRSquared <= newSolution.TrainingRSquared) {
      //     BestTrainingSentence = symbolString;
      //     currBestTrainingSolutionResult.Value = newSolution;
      //     Results.AddOrUpdateResult(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly());
      //   }
      //
      //   IRegressionSolution currBestTestSolution = (IRegressionSolution)currBestTestSolutionResult.Value;
      //   if (currBestTestSolution.TestRSquared <= newSolution.TestRSquared) {
      //     BestTestSentence = symbolString;
      //     currBestTestSolutionResult.Value = newSolution;
      //     Results.AddOrUpdateResult(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly());
      //   }
      // }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.