Changeset 15722


Ignore:
Timestamp:
02/05/18 18:16:25 (19 months ago)
Author:
lkammere
Message:

#2886: Add evaluation of sentences.

Location:
branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/Grammar.cs

    r15714 r15722  
    66using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
    77using HeuristicLab.Common;
     8using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     9using HeuristicLab.Problems.DataAnalysis.Symbolic;
    810
    911namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
     
    1315
    1416    #region Symbols
     17
    1518    public VariableSymbol Var;
    1619
     
    2124    public TerminalSymbol Addition;
    2225    public TerminalSymbol Multiplication;
    23     #endregion
    24 
     26
     27    #endregion
     28
     29
     30    #region HL Symbols for Parsing ExpressionTrees
     31
     32    private TypeCoherentExpressionGrammar symbolicExpressionGrammar;
     33
     34    private ISymbol constSy;
     35    private ISymbol varSy;
     36
     37    private ISymbol addSy;
     38    private ISymbol mulSy;
     39    private ISymbol logSy;
     40    private ISymbol expSy;
     41    private ISymbol divSy;
     42
     43    private ISymbol rootSy;
     44    private ISymbol startSy;
     45
     46    #endregion
    2547
    2648    public Grammar(string[] variables) {
     
    4870      Factor.AddProduction(Var);
    4971      #endregion
     72
     73      #region Parsing to SymbolicExpressionTree
     74      symbolicExpressionGrammar = new TypeCoherentExpressionGrammar();
     75      symbolicExpressionGrammar.ConfigureAsDefaultRegressionGrammar();
     76
     77      constSy = symbolicExpressionGrammar.Symbols.OfType<Constant>().First();
     78      varSy = symbolicExpressionGrammar.Symbols.OfType<Variable>().First();
     79      addSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Addition>().First();
     80      mulSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Multiplication>().First();
     81      logSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Logarithm>().First();
     82      expSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Exponential>().First();
     83      divSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Division>().First();
     84
     85      rootSy = symbolicExpressionGrammar.AllowedSymbols.OfType<ProgramRootSymbol>().First();
     86      startSy = symbolicExpressionGrammar.AllowedSymbols.OfType<StartSymbol>().First();
     87
     88      #endregion
    5089    }
    5190
     
    6099    }
    61100
     101    #region Hashing
    62102    private int[] GetSubtreeHashes(TerminalSymbol currentSymbol, Stack<TerminalSymbol> parseStack) {
    63103      List<int> childHashes = null;
     
    66106        childHashes = currentSymbol.StringRepresentation.GetHashCode().ToEnumerable().ToList();
    67107
    68       } else if (ReferenceEquals(currentSymbol, Multiplication)) { // MULTIPLICATION
     108      } else if (ReferenceEquals(currentSymbol, Multiplication)) {
     109        // MULTIPLICATION
    69110        childHashes = new List<int>();
    70111
     
    94135
    95136
    96       } else if (ReferenceEquals(currentSymbol, Addition)) { // ADDITION
     137      } else if (ReferenceEquals(currentSymbol, Addition)) {
     138        // ADDITION
    97139        HashSet<int> uniqueChildHashes = new HashSet<int>();
    98140
     
    129171      return hashes.Aggregate(0, (result, ti) => ((result << 5) + result) ^ ti.GetHashCode());
    130172    }
     173    #endregion
     174
     175    #region Parse to SymbolicExpressionTree
     176    public SymbolicExpressionTree ParseSymbolicExpressionTree(SymbolString sentence) { // Wraps the parsed sentence as root -> start -> expression and returns the tree.
     177      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");
     178      Debug.Assert(sentence.All(s => s is TerminalSymbol), "Trying to evaluate symbol sequence with nonterminalsymbols!");
     179
     180      symbolicExpressionGrammar.ConfigureAsDefaultRegressionGrammar(); // NOTE(review): the constructor already configures this grammar - confirm the repeat call is needed
     181
     182      var rootNode = rootSy.CreateTreeNode();
     183      var startNode = startSy.CreateTreeNode();
     184      rootNode.AddSubtree(startNode);
     185
     186      Stack<TerminalSymbol> parseStack = new Stack<TerminalSymbol>(sentence.OfType<TerminalSymbol>()); // Stack(IEnumerable) pushes in order, so the last symbol of the sentence is popped first
     187      startNode.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // the Stack overload does the recursive parsing
     188
     189      return new SymbolicExpressionTree(rootNode);
     190    }
     191
     192    public ISymbolicExpressionTreeNode ParseSymbolicExpressionTree(Stack<TerminalSymbol> parseStack) { // Pops one symbol and recursively builds the subtree rooted at it.
     193      TerminalSymbol currentSymbol = parseStack.Pop(); // Stack<T>.Pop throws InvalidOperationException when the stack is empty
     194
     195      ISymbolicExpressionTreeNode parsedSubTree = null;
     196
     197      if (ReferenceEquals(currentSymbol, Addition)) { // binary operator: two operands are parsed from the remaining stack
     198        parsedSubTree = addSy.CreateTreeNode();
     199        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // left part
     200        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // right part
     201
     202      } else if (ReferenceEquals(currentSymbol, Multiplication)) { // binary operator: same shape as addition
     203        parsedSubTree = mulSy.CreateTreeNode();
     204        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // left part
     205        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // right part
     206
     207      } else if (Var.VariableTerminalSymbols.Contains(currentSymbol)) { // variable terminal: becomes a leaf node
     208        VariableTreeNode varNode = (VariableTreeNode)varSy.CreateTreeNode();
     209        varNode.Weight = 1.0; // weight is fixed to 1.0 here
     210        varNode.VariableName = currentSymbol.StringRepresentation;
     211        parsedSubTree = varNode;
     212      }
     213
     214      Debug.Assert(parsedSubTree != null); // Debug-only check: with assertions compiled out, an unrecognized symbol makes this method return null
     215      return parsedSubTree;
     216    }
     217    #endregion
    131218  }
    132219}
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs

    r15714 r15722  
    99using HeuristicLab.Core;
    1010using HeuristicLab.Data;
     11using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    1112using HeuristicLab.Optimization;
     13using HeuristicLab.Parameters;
    1214using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    1315using HeuristicLab.Problems.DataAnalysis;
     16using HeuristicLab.Problems.DataAnalysis.Symbolic;
     17using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
    1418
    1519namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
     
    1822  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
    1923  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
     24    private readonly string BestTrainingSolution = "Best solution (training)";
     25    private readonly string BestTrainingSolutionQuality = "Best solution quality (training)";
     26    private readonly string BestTestSolution = "Best solution (test)";
     27    private readonly string BestTestSolutionQuality = "Best solution quality (test)";
     28
     29    private readonly string MaxTreeSizeParameterName = "Max. Tree Nodes";
     30    private readonly string GuiUpdateIntervalParameterName = "GUI Update Interval";
     31
     32
     33    #region properties
     34    public IValueParameter<IntValue> MaxTreeSizeParameter {
     35      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
     36    }
     37
     38    public int MaxTreeSize {
     39      get { return MaxTreeSizeParameter.Value.Value; }
     40    }
     41
     42    public IValueParameter<IntValue> GuiUpdateIntervalParameter {
     43      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
     44    }
     45
     46    public int GuiUpdateInterval {
     47      get { return GuiUpdateIntervalParameter.Value.Value; }
     48    }
     49
     50    #endregion
     51
     52    private Grammar grammar;
     53
     54
     55    #region ctors
    2056    public override IDeepCloneable Clone(Cloner cloner) {
    2157      return new GrammarEnumerationAlgorithm(this, cloner);
    2258    }
    2359
    24 
    2560    public GrammarEnumerationAlgorithm() { // Creates the algorithm with a default regression problem and registers its two parameters.
    2661      Problem = new RegressionProblem();
    2762
     63      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The maximum number of symbols in an enumerated sentence; longer sentences are discarded.", new IntValue(4))); // description previously read "The number of clusters." (copy-paste leftover)
     64      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(4000)));
    2865    }
    2966
    30 
    3167    private GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    32 
    33 
    34 
     68    #endregion
    3569
    3670
    3771    protected override void Run(CancellationToken cancellationToken) {
     72      List<SymbolString> allGenerated = new List<SymbolString>();
     73      List<SymbolString> distinctGenerated = new List<SymbolString>();
     74      HashSet<int> evaluatedHashes = new HashSet<int>();
    3875
    39       IntValue generatedSolutions = new IntValue();
    40       Results.Add(new Result("Generated Solutions", generatedSolutions));
    41 
    42       DoubleValue averageTreeLength = new DoubleValue();
    43       Results.Add(new Result("Average Tree Length of Solutions", averageTreeLength));
    44 
    45       int maxStringLength = 4;
    46 
    47 
    48       List<SymbolString> results = new List<SymbolString>();
    49 
    50       Grammar grammar = new Grammar(new[] { "a", "b" });
     76      grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());
    5177
    5278      Stack<SymbolString> remainingTrees = new Stack<SymbolString>();
     
    5480
    5581      while (remainingTrees.Any()) {
     82        if (cancellationToken.IsCancellationRequested) break;
     83
    5684        SymbolString currSymbolString = remainingTrees.Pop();
    5785
    5886        if (currSymbolString.IsSentence()) {
    59           results.Add(currSymbolString);
     87          allGenerated.Add(currSymbolString);
    6088
    61           generatedSolutions.Value++;
    62           averageTreeLength.Value = results.Select(r => r.Count).Average();
     89          if (evaluatedHashes.Add(grammar.CalcHashCode(currSymbolString))) {
     90            EvaluateSentence(currSymbolString);
     91            distinctGenerated.Add(currSymbolString);
     92          }
     93
     94          UpdateView(allGenerated, distinctGenerated);
    6395
    6496        } else {
    6597          // expand next nonterminal symbols
    6698          int nonterminalSymbolIndex = currSymbolString.FindIndex(s => s is NonterminalSymbol);
    67 
    6899          NonterminalSymbol expandedSymbol = currSymbolString[nonterminalSymbolIndex] as NonterminalSymbol;
    69100
     
    73104            newSentence.InsertRange(nonterminalSymbolIndex, productionAlternative);
    74105
    75             if (newSentence.Count <= maxStringLength) {
     106            if (newSentence.Count <= MaxTreeSize) {
    76107              remainingTrees.Push(newSentence);
    77108            }
     
    80111      }
    81112
    82 
    83       StringArray sentences = new StringArray(results.Select(r => r.ToString()).ToArray());
    84       Results.Add(new Result("All sentences", sentences));
     113      StringArray sentences = new StringArray(allGenerated.Select(r => r.ToString()).ToArray());
     114      Results.Add(new Result("All generated sentences", sentences));
     115      StringArray distinctSentences = new StringArray(distinctGenerated.Select(r => r.ToString()).ToArray());
     116      Results.Add(new Result("Distinct generated sentences", distinctSentences));
     117    }
    85118
    86119
     120    private void UpdateView(List<SymbolString> allGenerated, List<SymbolString> distinctGenerated) {
     121      int generatedSolutions = allGenerated.Count;
     122      int distinctSolutions = distinctGenerated.Count;
    87123
    88       /*
    89       addSy = grammar.AllowedSymbols.OfType<Addition>().First();
    90       mulSy = grammar.AllowedSymbols.OfType<Multiplication>().First();
    91       logSy = grammar.AllowedSymbols.OfType<Logarithm>().First();
    92       expSy = grammar.AllowedSymbols.OfType<Exponential>().First();
    93       divSy = grammar.AllowedSymbols.OfType<Division>().First();
     124      if (generatedSolutions % GuiUpdateInterval == 0) {
     125        Results.AddOrUpdateResult("Generated Solutions", new IntValue(generatedSolutions));
     126        Results.Add(new Result("Distinct Solutions", new IntValue(distinctSolutions)));
    94127
    95       progRootSy = grammar.AllowedSymbols.OfType<ProgramRootSymbol>().First();
    96       startSy = grammar.AllowedSymbols.OfType<StartSymbol>().First();
     128        DoubleValue averageTreeLength = new DoubleValue(allGenerated.Select(r => r.Count).Average());
     129        Results.Add(new Result("Average Tree Length of Solutions", averageTreeLength));
     130      }
     131    }
    97132
    98 
    99 
    100       var rootNode = progRootSy.CreateTreeNode();
    101 
    102       var startNode = startSy.CreateTreeNode();
    103       rootNode.AddSubtree(startNode);
    104 
    105       startNode.AddSubtree(const0);
    106 
    107 
    108 
    109       SymbolicExpressionTree tree = new SymbolicExpressionTree(rootNode);
    110 
     133    private void EvaluateSentence(SymbolString symbolString) {
     134      SymbolicExpressionTree tree = grammar.ParseSymbolicExpressionTree(symbolString);
    111135      SymbolicRegressionModel model = new SymbolicRegressionModel(
    112136        Problem.ProblemData.TargetVariable,
     
    114138        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
    115139
    116       IRegressionSolution solution = model.CreateRegressionSolution(Problem.ProblemData);
    117       Results.Add(new Result("Best solution (training)", solution));
    118       Results.Add(new Result("Best solution quality (training)", new DoubleValue(solution.TrainingRSquared).AsReadOnly()));
     140      IRegressionSolution newSolution = model.CreateRegressionSolution(Problem.ProblemData);
    119141
    120       Results.Add(new Result("Best solution (test)", solution));
    121       Results.Add(new Result("Best solution quality (test)", new DoubleValue(solution.TestRSquared).AsReadOnly()));
    122       */
     142      IResult currBestTrainingSolutionResult;
     143      IResult currBestTestSolutionResult;
     144      if (!Results.TryGetValue(BestTrainingSolution, out currBestTrainingSolutionResult)
     145           || !Results.TryGetValue(BestTestSolution, out currBestTestSolutionResult)) {
     146
     147        Results.Add(new Result(BestTrainingSolution, newSolution));
     148        Results.Add(new Result(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly()));
     149        Results.Add(new Result(BestTestSolution, newSolution));
     150        Results.Add(new Result(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly()));
     151
     152      } else {
     153        IRegressionSolution currBestTrainingSolution = (IRegressionSolution)currBestTrainingSolutionResult.Value;
     154        if (currBestTrainingSolution.TrainingRSquared < newSolution.TrainingRSquared) {
     155          currBestTrainingSolutionResult.Value = newSolution;
     156          Results.AddOrUpdateResult(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly());
     157        }
     158
     159        IRegressionSolution currBestTestSolution = (IRegressionSolution)currBestTestSolutionResult.Value;
     160        if (currBestTestSolution.TestRSquared < newSolution.TestRSquared) {
     161          currBestTestSolutionResult.Value = newSolution;
     162          Results.AddOrUpdateResult(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly());
     163        }
     164      }
    123165    }
    124166  }
Note: See TracChangeset for help on using the changeset viewer.