Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15723

Last change on this file since 15723 was 15723, checked in by lkammere, 6 years ago

#2886: Add simple data analysis tests and further information about the algorithm run.

File size: 9.1 KB
Line 
1using System.Collections.Generic;
2using System.Linq;
3using System.Threading;
4using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
9using HeuristicLab.Optimization;
10using HeuristicLab.Parameters;
11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
12using HeuristicLab.Problems.DataAnalysis;
13using HeuristicLab.Problems.DataAnalysis.Symbolic;
14using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
15
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Enumerates every sentence (model structure) derivable from a fixed grammar up to a maximum
  /// sentential-form length and evaluates each sentence as a symbolic regression model, keeping
  /// track of the best solution w.r.t. training and test R².
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    // Result names.
    private const string BestTrainingSolution = "Best solution (training)";
    private const string BestTrainingSolutionString = "Best solution string (training)";
    private const string BestTrainingSolutionQuality = "Best solution quality (training)";
    private const string BestTestSolution = "Best solution (test)";
    private const string BestTestSolutionString = "Best solution string (test)";
    private const string BestTestSolutionQuality = "Best solution quality (test)";

    // Parameter names.
    private const string MaxTreeSizeParameterName = "Max. Tree Nodes";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";
    private const string UseMemoizationParameterName = "Use Memoization?";


    #region properties
    protected IValueParameter<IntValue> MaxTreeSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
    }
    /// <summary>
    /// Maximum number of symbols (terminals and nonterminals) a derived sentential form may contain;
    /// longer derivations are discarded.
    /// </summary>
    public int MaxTreeSize {
      get { return MaxTreeSizeParameter.Value.Value; }
      set { MaxTreeSizeParameter.Value.Value = value; }
    }

    protected IValueParameter<IntValue> GuiUpdateIntervalParameter {
      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }
    /// <summary>
    /// Number of generated sentences between two refreshes of the progress results.
    /// Values &lt;= 0 disable periodic refreshes (only the final refresh happens).
    /// </summary>
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
      set { GuiUpdateIntervalParameter.Value.Value = value; }
    }

    protected IValueParameter<BoolValue> UseMemoizationParameter {
      get { return (IValueParameter<BoolValue>)Parameters[UseMemoizationParameterName]; }
    }
    /// <summary>
    /// NOTE(review): currently not read by <see cref="Run"/>; memoization is not implemented yet.
    /// </summary>
    public bool UseMemoization {
      get { return UseMemoizationParameter.Value.Value; }
      set { UseMemoizationParameter.Value.Value = value; }
    }

    #endregion

    // Rebuilt at the start of every run from the problem's allowed input variables.
    private Grammar grammar;


    #region ctors
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    public GrammarEnumerationAlgorithm() {
      // Default problem for quick experimentation: the "Poly-10" benchmark instance.
      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));

      Problem = new RegressionProblem() {
        ProblemData = regProblem
      };

      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The maximum number of symbols (terminals and nonterminals) a generated sentence may contain.", new IntValue(6)));
      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(4000)));
      Parameters.Add(new ValueParameter<BoolValue>(UseMemoizationParameterName, "Should already subtrees be reused within a run.", new BoolValue(true)));
    }

    // The grammar is not cloned; it is recreated in Run.
    private GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    #endregion


    /// <summary>
    /// Depth-first enumeration of all sentential forms of at most <see cref="MaxTreeSize"/> symbols,
    /// always expanding the leftmost nonterminal. Every complete sentence is evaluated immediately.
    /// Stops early when cancellation is requested; the summary results are written in either case.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      List<SymbolString> allGenerated = new List<SymbolString>();
      // TODO: semantic de-duplication (e.g. via a grammar-level hash of each sentence) is not
      // implemented yet, so this list stays empty and "Distinct Solutions" reports 0.
      List<SymbolString> distinctGenerated = new List<SymbolString>();

      // Number of nonterminal expansion steps performed (one per expanded sentential form).
      int expansions = 0;

      // Loop-invariant: read the parameter once instead of once per produced alternative.
      int maxTreeSize = MaxTreeSize;

      grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

      Stack<SymbolString> remainingTrees = new Stack<SymbolString>();
      remainingTrees.Push(new SymbolString(new[] { grammar.StartSymbol }));

      while (remainingTrees.Count > 0) {
        if (cancellationToken.IsCancellationRequested) break;

        SymbolString currSymbolString = remainingTrees.Pop();

        if (currSymbolString.IsSentence()) {
          allGenerated.Add(currSymbolString);
          EvaluateSentence(currSymbolString);
          UpdateView(allGenerated, distinctGenerated, expansions);
        } else {
          // Expand the leftmost nonterminal; FindIndex guarantees the element is a NonterminalSymbol.
          int nonterminalSymbolIndex = currSymbolString.FindIndex(s => s is NonterminalSymbol);
          NonterminalSymbol expandedSymbol = (NonterminalSymbol)currSymbolString[nonterminalSymbolIndex];
          expansions++;

          foreach (Production productionAlternative in expandedSymbol.Alternatives) {
            SymbolString newSentence = new SymbolString(currSymbolString);
            newSentence.RemoveAt(nonterminalSymbolIndex);
            newSentence.InsertRange(nonterminalSymbolIndex, productionAlternative);

            if (newSentence.Count <= maxTreeSize) {
              remainingTrees.Push(newSentence);
            }
          }
        }
      }

      UpdateView(allGenerated, distinctGenerated, expansions, force: true);

      StringArray sentences = new StringArray(allGenerated.Select(r => r.ToString()).ToArray());
      Results.AddOrUpdateResult("All generated sentences", sentences);
      StringArray distinctSentences = new StringArray(distinctGenerated.Select(r => r.ToString()).ToArray());
      Results.AddOrUpdateResult("Distinct generated sentences", distinctSentences);
    }


    /// <summary>
    /// Publishes progress statistics to <c>Results</c> every <see cref="GuiUpdateInterval"/> generated
    /// sentences, or unconditionally when <paramref name="force"/> is true.
    /// </summary>
    private void UpdateView(List<SymbolString> allGenerated, List<SymbolString> distinctGenerated, int expansions, bool force = false) {
      int generatedSolutions = allGenerated.Count;
      int distinctSolutions = distinctGenerated.Count;
      int updateInterval = GuiUpdateInterval;

      // A non-positive interval disables periodic updates (an interval of 0 would otherwise
      // throw DivideByZeroException in the modulo).
      bool intervalReached = updateInterval > 0 && generatedSolutions % updateInterval == 0;
      if (!force && !intervalReached) return;

      Results.AddOrUpdateResult("Generated Solutions", new IntValue(generatedSolutions));
      Results.AddOrUpdateResult("Distinct Solutions", new IntValue(distinctSolutions));

      // Average() throws on an empty sequence, which happens on the forced final update when
      // nothing was generated (early cancellation or a too small MaxTreeSize).
      double averageTreeLength = generatedSolutions > 0 ? allGenerated.Average(r => r.Count) : 0.0;
      Results.AddOrUpdateResult("Average Tree Length of Solutions", new DoubleValue(averageTreeLength));

      Results.AddOrUpdateResult("Expansions", new IntValue(expansions));
    }

    /// <summary>
    /// Builds a regression solution from the given sentence and updates the best-so-far training
    /// and test results (ranked by R²) if the new solution improves on them.
    /// </summary>
    private void EvaluateSentence(SymbolString symbolString) {
      SymbolicExpressionTree tree = grammar.ParseSymbolicExpressionTree(symbolString);
      SymbolicRegressionModel model = new SymbolicRegressionModel(
        Problem.ProblemData.TargetVariable,
        tree,
        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());

      IRegressionSolution newSolution = model.CreateRegressionSolution(Problem.ProblemData);

      IResult currBestTrainingSolutionResult;
      IResult currBestTestSolutionResult;
      if (!Results.TryGetValue(BestTrainingSolution, out currBestTrainingSolutionResult)
           || !Results.TryGetValue(BestTestSolution, out currBestTestSolutionResult)) {
        // First evaluated sentence: initialize both best-solution entries. AddOrUpdateResult avoids a
        // duplicate-key exception should only one of the pair already exist.
        Results.AddOrUpdateResult(BestTrainingSolution, newSolution);
        Results.AddOrUpdateResult(BestTrainingSolutionString, new StringValue(symbolString.ToString()).AsReadOnly());
        Results.AddOrUpdateResult(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly());
        Results.AddOrUpdateResult(BestTestSolution, newSolution);
        Results.AddOrUpdateResult(BestTestSolutionString, new StringValue(symbolString.ToString()).AsReadOnly());
        Results.AddOrUpdateResult(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly());
        return;
      }

      IRegressionSolution currBestTrainingSolution = (IRegressionSolution)currBestTrainingSolutionResult.Value;
      if (IsBetterQuality(newSolution.TrainingRSquared, currBestTrainingSolution.TrainingRSquared)) {
        currBestTrainingSolutionResult.Value = newSolution;
        Results.AddOrUpdateResult(BestTrainingSolutionString, new StringValue(symbolString.ToString()).AsReadOnly());
        Results.AddOrUpdateResult(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly());
      }

      IRegressionSolution currBestTestSolution = (IRegressionSolution)currBestTestSolutionResult.Value;
      if (IsBetterQuality(newSolution.TestRSquared, currBestTestSolution.TestRSquared)) {
        currBestTestSolutionResult.Value = newSolution;
        Results.AddOrUpdateResult(BestTestSolutionString, new StringValue(symbolString.ToString()).AsReadOnly());
        Results.AddOrUpdateResult(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly());
      }
    }

    /// <summary>
    /// True if <paramref name="candidate"/> is a strictly better R² than <paramref name="incumbent"/>.
    /// A NaN candidate never wins; a NaN incumbent is beaten by any non-NaN candidate. Without this,
    /// a NaN incumbent (every comparison with NaN is false) would block all later improvements.
    /// </summary>
    private static bool IsBetterQuality(double candidate, double incumbent) {
      if (double.IsNaN(candidate)) return false;
      return double.IsNaN(incumbent) || candidate > incumbent;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.