Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs @ 15722

Last change on this file since 15722 was 15722, checked in by lkammere, 6 years ago

#2886: Add evaluation of sentences.

File size: 7.4 KB
RevLine 
[15714]1using System;
2using System.Collections.Generic;
[15712]3using System.Collections.ObjectModel;
4using System.Diagnostics;
5using System.Linq;
6using System.Threading;
7using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
8using HeuristicLab.Common;
9using HeuristicLab.Core;
10using HeuristicLab.Data;
[15722]11using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[15712]12using HeuristicLab.Optimization;
[15722]13using HeuristicLab.Parameters;
[15712]14using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
15using HeuristicLab.Problems.DataAnalysis;
[15722]16using HeuristicLab.Problems.DataAnalysis.Symbolic;
17using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
[15712]18
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Enumerates the sentences derivable from a fixed grammar, evaluates every sentence whose
  /// grammar hash code has not been seen before as a symbolic regression model, and keeps
  /// track of the best solutions on the training and the test partition.
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    // Names of the entries written to the result collection.
    private const string BestTrainingSolution = "Best solution (training)";
    private const string BestTrainingSolutionQuality = "Best solution quality (training)";
    private const string BestTestSolution = "Best solution (test)";
    private const string BestTestSolutionQuality = "Best solution quality (test)";

    // Names of the algorithm parameters.
    private const string MaxTreeSizeParameterName = "Max. Tree Nodes";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";

    #region properties
    /// <summary>Parameter holding the maximum number of symbols a generated phrase may contain.</summary>
    public IValueParameter<IntValue> MaxTreeSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
    }

    /// <summary>Current value of <see cref="MaxTreeSizeParameter"/>.</summary>
    public int MaxTreeSize {
      get { return MaxTreeSizeParameter.Value.Value; }
    }

    /// <summary>Parameter holding the number of generated sentences between two result refreshes.</summary>
    public IValueParameter<IntValue> GuiUpdateIntervalParameter {
      // Fixed: this previously looked up MaxTreeSizeParameterName, so the interval
      // silently mirrored the maximum tree size.
      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }

    /// <summary>Current value of <see cref="GuiUpdateIntervalParameter"/>.</summary>
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
    }
    #endregion

    // Grammar of the current run; (re)created at the start of Run from the allowed input variables.
    private Grammar grammar;

    #region ctors
    /// <summary>Creates a copy of this algorithm via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    /// <summary>
    /// Creates the algorithm with a fresh <see cref="RegressionProblem"/> and registers
    /// the two parameters with their default values (4 and 4000).
    /// </summary>
    public GrammarEnumerationAlgorithm() {
      Problem = new RegressionProblem();

      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "Maximum number of symbols a sentence may contain; longer phrases are discarded.", new IntValue(4)));
      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(4000)));
    }

    private GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    #endregion

    /// <summary>
    /// Expands phrases taken from a stack until no phrase is left or cancellation is requested.
    /// Complete sentences are recorded and, if their hash code is new, evaluated; incomplete
    /// phrases are expanded at the first nonterminal symbol with every production alternative.
    /// The generated and the distinct sentences are added to the results at the end, also
    /// when the loop was left because of a cancellation request.
    /// </summary>
    /// <param name="cancellationToken">Token that is polled once per processed phrase.</param>
    protected override void Run(CancellationToken cancellationToken) {
      List<SymbolString> allGenerated = new List<SymbolString>();
      List<SymbolString> distinctGenerated = new List<SymbolString>();
      HashSet<int> evaluatedHashes = new HashSet<int>();

      grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

      Stack<SymbolString> remainingTrees = new Stack<SymbolString>();
      remainingTrees.Push(new SymbolString(new[] { grammar.StartSymbol }));

      while (remainingTrees.Count > 0 && !cancellationToken.IsCancellationRequested) {
        SymbolString currSymbolString = remainingTrees.Pop();

        if (currSymbolString.IsSentence()) {
          allGenerated.Add(currSymbolString);

          // HashSet.Add returns false for an already known hash code, so every
          // hash code is evaluated at most once.
          if (evaluatedHashes.Add(grammar.CalcHashCode(currSymbolString))) {
            EvaluateSentence(currSymbolString);
            distinctGenerated.Add(currSymbolString);
          }

          UpdateView(allGenerated, distinctGenerated);
        } else {
          // expand the first nonterminal symbol with each of its production alternatives
          int nonterminalSymbolIndex = currSymbolString.FindIndex(s => s is NonterminalSymbol);
          NonterminalSymbol expandedSymbol = (NonterminalSymbol)currSymbolString[nonterminalSymbolIndex];

          foreach (Production productionAlternative in expandedSymbol.Alternatives) {
            SymbolString newSentence = new SymbolString(currSymbolString);
            newSentence.RemoveAt(nonterminalSymbolIndex);
            newSentence.InsertRange(nonterminalSymbolIndex, productionAlternative);

            // Phrases exceeding the size limit are dropped and never expanded further.
            if (newSentence.Count <= MaxTreeSize) {
              remainingTrees.Push(newSentence);
            }
          }
        }
      }

      StringArray sentences = new StringArray(allGenerated.Select(r => r.ToString()).ToArray());
      Results.Add(new Result("All generated sentences", sentences));
      StringArray distinctSentences = new StringArray(distinctGenerated.Select(r => r.ToString()).ToArray());
      Results.Add(new Result("Distinct generated sentences", distinctSentences));
    }

    /// <summary>
    /// Refreshes the progress results whenever the number of generated sentences is a
    /// multiple of <see cref="GuiUpdateInterval"/>. A non-positive interval disables the refresh.
    /// </summary>
    /// <param name="allGenerated">All sentences generated so far, including repeated ones.</param>
    /// <param name="distinctGenerated">Sentences whose hash code was seen for the first time.</param>
    private void UpdateView(List<SymbolString> allGenerated, List<SymbolString> distinctGenerated) {
      int interval = GuiUpdateInterval;
      int generatedSolutions = allGenerated.Count;

      // Guard the modulo against an interval of zero (or a negative one).
      if (interval <= 0 || generatedSolutions % interval != 0) return;

      // AddOrUpdateResult is used for every entry because this code runs repeatedly;
      // a plain Results.Add would try to register the same result name again on the second refresh.
      Results.AddOrUpdateResult("Generated Solutions", new IntValue(generatedSolutions));
      Results.AddOrUpdateResult("Distinct Solutions", new IntValue(distinctGenerated.Count));
      Results.AddOrUpdateResult("Average Tree Length of Solutions", new DoubleValue(allGenerated.Average(r => r.Count)));
    }

    /// <summary>
    /// Builds a regression solution from the given sentence and stores it as best training
    /// and/or best test solution if its R² exceeds the one of the currently stored solution.
    /// The very first evaluated sentence initializes all four result entries.
    /// </summary>
    /// <param name="symbolString">Sentence that is parsed into a symbolic expression tree.</param>
    private void EvaluateSentence(SymbolString symbolString) {
      SymbolicExpressionTree tree = grammar.ParseSymbolicExpressionTree(symbolString);
      SymbolicRegressionModel model = new SymbolicRegressionModel(
        Problem.ProblemData.TargetVariable,
        tree,
        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());

      IRegressionSolution newSolution = model.CreateRegressionSolution(Problem.ProblemData);

      IResult currBestTrainingSolutionResult;
      IResult currBestTestSolutionResult;
      if (!Results.TryGetValue(BestTrainingSolution, out currBestTrainingSolutionResult)
           || !Results.TryGetValue(BestTestSolution, out currBestTestSolutionResult)) {
        // No best solutions recorded yet: the new solution is the best on both partitions.
        Results.Add(new Result(BestTrainingSolution, newSolution));
        Results.Add(new Result(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly()));
        Results.Add(new Result(BestTestSolution, newSolution));
        Results.Add(new Result(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly()));
      } else {
        IRegressionSolution currBestTrainingSolution = (IRegressionSolution)currBestTrainingSolutionResult.Value;
        if (IsBetterQuality(currBestTrainingSolution.TrainingRSquared, newSolution.TrainingRSquared)) {
          currBestTrainingSolutionResult.Value = newSolution;
          Results.AddOrUpdateResult(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly());
        }

        IRegressionSolution currBestTestSolution = (IRegressionSolution)currBestTestSolutionResult.Value;
        if (IsBetterQuality(currBestTestSolution.TestRSquared, newSolution.TestRSquared)) {
          currBestTestSolutionResult.Value = newSolution;
          Results.AddOrUpdateResult(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly());
        }
      }
    }

    /// <summary>
    /// Returns true if <paramref name="candidate"/> should replace <paramref name="current"/>.
    /// A plain "current &lt; candidate" is always false when current is NaN, which would pin a
    /// NaN incumbent forever; therefore any non-NaN candidate beats a NaN incumbent.
    /// </summary>
    private static bool IsBetterQuality(double current, double candidate) {
      if (double.IsNaN(candidate)) return false;
      return double.IsNaN(current) || current < candidate;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.