source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15724

Last change on this file since 15724 was 15724, checked in by lkammere, 21 months ago

#2886: Add parsing to infix form for debugging purposes.

File size: 8.9 KB
RevLine 
[15723]1using System.Collections.Generic;
[15712]2using System.Linq;
3using System.Threading;
4using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
[15722]8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[15712]9using HeuristicLab.Optimization;
[15722]10using HeuristicLab.Parameters;
[15712]11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
12using HeuristicLab.Problems.DataAnalysis;
[15722]13using HeuristicLab.Problems.DataAnalysis.Symbolic;
14using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
[15712]15
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Enumerates all sentences of a fixed grammar up to a maximum number of symbols,
  /// evaluates every complete sentence as a symbolic regression model and keeps track
  /// of the best solutions on the training and on the test partition.
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    // Result names.
    private const string BestTrainingSolution = "Best solution (training)";
    private const string BestTrainingSolutionQuality = "Best solution quality (training)";
    private const string BestTestSolution = "Best solution (test)";
    private const string BestTestSolutionQuality = "Best solution quality (test)";

    // Parameter names.
    private const string MaxTreeSizeParameterName = "Max. Tree Nodes";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";
    private const string UseMemoizationParameterName = "Use Memoization?";


    #region properties
    protected IValueParameter<IntValue> MaxTreeSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
    }

    /// <summary>
    /// Maximum number of symbols a (partial) sentence may contain. Derivations that would
    /// exceed this length are pruned during enumeration.
    /// </summary>
    public int MaxTreeSize {
      get { return MaxTreeSizeParameter.Value.Value; }
      set { MaxTreeSizeParameter.Value.Value = value; }
    }

    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }

    /// <summary>
    /// Progress results are refreshed every <c>GuiUpdateInterval</c> generated sentences.
    /// A non-positive value disables the periodic refresh; the final refresh at the end of
    /// a run always happens.
    /// </summary>
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
      set { GuiUpdateIntervalParameter.Value.Value = value; }
    }

    protected IValueParameter<BoolValue> UseMemoizationParameter {
      get { return (IValueParameter<BoolValue>)Parameters[UseMemoizationParameterName]; }
    }

    /// <summary>
    /// Whether already generated subtrees should be reused within a run.
    /// NOTE(review): not read by <see cref="Run"/> yet.
    /// </summary>
    public bool UseMemoization {
      get { return UseMemoizationParameter.Value.Value; }
      set { UseMemoizationParameter.Value.Value = value; }
    }

    /// <summary>Sentence of the best training-R² solution of the current run (null until the first evaluation).</summary>
    public SymbolString BestTrainingSentence;
    /// <summary>Sentence of the best test-R² solution of the current run (null until the first evaluation).</summary>
    public SymbolString BestTestSentence;

    #endregion

    /// <summary>Grammar used for enumeration; rebuilt from the problem's allowed input variables at the start of every run.</summary>
    public Grammar Grammar;


    #region ctors
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    public GrammarEnumerationAlgorithm() {
      // NOTE(review): a fixed benchmark problem ("Poly-10", seed 1234) is loaded by default,
      // presumably for debugging. Single() throws unless exactly one descriptor matches.
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));

      Problem = new RegressionProblem() {
        ProblemData = regProblem
      };

      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The maximum number of symbols in a generated sentence.", new IntValue(6)));
      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(4000)));
      Parameters.Add(new ValueParameter<BoolValue>(UseMemoizationParameterName, "Should already subtrees be reused within a run.", new BoolValue(true)));
    }

    // Grammar and the best sentences are not copied; Run recreates/resets them.
    private GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    #endregion


    /// <summary>
    /// Depth-first enumeration of all sentences derivable from the grammar's start symbol
    /// with at most <see cref="MaxTreeSize"/> symbols. Each complete sentence is evaluated
    /// and recorded; stops early when cancellation is requested.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      // Reset the bests of any previous run.
      BestTrainingSentence = null;
      BestTestSentence = null;

      List<SymbolString> allGenerated = new List<SymbolString>();
      List<SymbolString> distinctGenerated = new List<SymbolString>();

      // Number of nonterminal expansions performed (one per popped non-sentence).
      int expansions = 0;

      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

      Stack<SymbolString> remainingTrees = new Stack<SymbolString>();
      remainingTrees.Push(new SymbolString(new[] { Grammar.StartSymbol }));

      while (remainingTrees.Count > 0) {
        if (cancellationToken.IsCancellationRequested) break;

        SymbolString currSymbolString = remainingTrees.Pop();

        if (currSymbolString.IsSentence()) {
          // Infix form is only kept for display/debugging of the generated sentences.
          SymbolString infix = Grammar.PostfixToInfixParser(currSymbolString);
          allGenerated.Add(infix);

          // TODO(review): UseMemoization is not honoured yet - every sentence is evaluated and
          // counted as distinct. An earlier revision sketched a hash-based duplicate filter here.
          EvaluateSentence(currSymbolString);
          distinctGenerated.Add(infix);

          UpdateView(allGenerated, distinctGenerated, expansions);
        } else {
          expansions++;

          // Expand the leftmost nonterminal with every production alternative.
          int nonterminalSymbolIndex = currSymbolString.FindIndex(s => s is NonterminalSymbol);
          NonterminalSymbol expandedSymbol = (NonterminalSymbol)currSymbolString[nonterminalSymbolIndex];

          foreach (Production productionAlternative in expandedSymbol.Alternatives) {
            SymbolString newSentence = new SymbolString(currSymbolString);
            newSentence.RemoveAt(nonterminalSymbolIndex);
            newSentence.InsertRange(nonterminalSymbolIndex, productionAlternative);

            // Prune derivations that already exceed the length limit.
            if (newSentence.Count <= MaxTreeSize) {
              remainingTrees.Push(newSentence);
            }
          }
        }
      }

      UpdateView(allGenerated, distinctGenerated, expansions, force: true);

      // AddOrUpdate (not Add): Run may be invoked again (e.g. resumed after a pause) while
      // these results still exist, and Add would throw on the duplicate name.
      StringArray sentences = new StringArray(allGenerated.Select(r => r.ToString()).ToArray());
      Results.AddOrUpdateResult("All generated sentences", sentences);
      StringArray distinctSentences = new StringArray(distinctGenerated.Select(r => r.ToString()).ToArray());
      Results.AddOrUpdateResult("Distinct generated sentences", distinctSentences);
    }


    /// <summary>
    /// Publishes progress statistics to the results collection every
    /// <see cref="GuiUpdateInterval"/> generated sentences, or unconditionally when
    /// <paramref name="force"/> is set.
    /// </summary>
    /// <param name="allGenerated">All generated sentences so far.</param>
    /// <param name="distinctGenerated">Sentences counted as distinct so far.</param>
    /// <param name="expansions">Number of nonterminal expansions performed so far.</param>
    /// <param name="force">Update regardless of the interval (used for the final update).</param>
    private void UpdateView(List<SymbolString> allGenerated, List<SymbolString> distinctGenerated, int expansions, bool force = false) {
      int generatedSolutions = allGenerated.Count;
      int distinctSolutions = distinctGenerated.Count;
      int updateInterval = GuiUpdateInterval;

      // Guard against a zero interval, which would raise DivideByZeroException.
      bool intervalReached = updateInterval > 0 && generatedSolutions % updateInterval == 0;
      if (!force && !intervalReached) return;

      Results.AddOrUpdateResult("Generated Solutions", new IntValue(generatedSolutions));
      Results.AddOrUpdateResult("Distinct Solutions", new IntValue(distinctSolutions));

      // Average() throws on an empty sequence (e.g. cancelled before the first sentence).
      double averageTreeLength = generatedSolutions > 0 ? allGenerated.Average(r => r.Count) : 0.0;
      Results.AddOrUpdateResult("Average Tree Length of Solutions", new DoubleValue(averageTreeLength));

      Results.AddOrUpdateResult("Expansions", new IntValue(expansions));
    }

    /// <summary>
    /// Builds a symbolic regression model for a complete sentence and replaces the best
    /// training and/or test solution results if the new solution has a higher R².
    /// </summary>
    /// <param name="symbolString">A complete sentence (no nonterminals) in postfix form.</param>
    private void EvaluateSentence(SymbolString symbolString) {
      SymbolicExpressionTree tree = Grammar.ParseSymbolicExpressionTree(symbolString);
      SymbolicRegressionModel model = new SymbolicRegressionModel(
        Problem.ProblemData.TargetVariable,
        tree,
        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());

      IRegressionSolution newSolution = model.CreateRegressionSolution(Problem.ProblemData);

      // Training and test bests are tracked independently; a missing result is simply added.
      IResult currBestTrainingSolutionResult;
      if (!Results.TryGetValue(BestTrainingSolution, out currBestTrainingSolutionResult)
          || IsBetter(newSolution.TrainingRSquared, ((IRegressionSolution)currBestTrainingSolutionResult.Value).TrainingRSquared)) {
        BestTrainingSentence = symbolString;
        Results.AddOrUpdateResult(BestTrainingSolution, newSolution);
        Results.AddOrUpdateResult(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly());
      }

      IResult currBestTestSolutionResult;
      if (!Results.TryGetValue(BestTestSolution, out currBestTestSolutionResult)
          || IsBetter(newSolution.TestRSquared, ((IRegressionSolution)currBestTestSolutionResult.Value).TestRSquared)) {
        BestTestSentence = symbolString;
        Results.AddOrUpdateResult(BestTestSolution, newSolution);
        Results.AddOrUpdateResult(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly());
      }
    }

    /// <summary>
    /// True if <paramref name="candidate"/> strictly improves on <paramref name="incumbent"/>.
    /// A NaN candidate never wins; a NaN incumbent is always replaced (a plain '&lt;' against
    /// NaN is always false and would otherwise keep a NaN-quality solution forever).
    /// </summary>
    private static bool IsBetter(double candidate, double incumbent) {
      if (double.IsNaN(candidate)) return false;
      return double.IsNaN(incumbent) || incumbent < candidate;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.