Changeset 16157


Ignore:
Timestamp:
09/20/18 11:12:57 (13 months ago)
Author:
bburlacu
Message:

#2886: Update IGrammarEnumerationEvaluator interface (add Evaluate method accepting an ISymbolicExpressionTree for the case when the constants have already been optimized in the tree, add boolean OptimizeConstants flag), small refactor in GrammarEnumeration/GrammarEnumerationAlgorithm.cs, add unit tests

Location:
branches/2886_SymRegGrammarEnumeration
Files:
1 added
5 edited

Legend:

Unmodified
Added
Removed
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/IGrammarEnumerationEvaluator.cs

    r16053 r16157  
    2121
    2222using HeuristicLab.Core;
     23using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2324using HeuristicLab.Problems.DataAnalysis;
    2425
     
    2627  public interface IGrammarEnumerationEvaluator : IItem {
    2728    double Evaluate(IRegressionProblemData problemData, Grammar grammar, SymbolList sentence);
     29    double Evaluate(IRegressionProblemData problemData, ISymbolicExpressionTree tree);
     30    bool OptimizeConstants { get; set; }
    2831  }
    2932}
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs

    r16053 r16157  
    5252
    5353    private readonly string EvaluatorParameterName = "Evaluator";
    54 
    55     private readonly string ErrorWeightParameterName = "Error Weight";
    5654    private readonly string SearchDataStructureParameterName = "Search Data Structure";
    5755    private readonly string MaxComplexityParameterName = "Max. Complexity";
     56    private readonly string MaxLengthParameterName = "Max. Length";
    5857    private readonly string GuiUpdateIntervalParameterName = "GUI Update Interval";
    5958    private readonly string GrammarSymbolsParameterName = "Grammar Symbols";
     
    6766    public override bool SupportsPause { get { return true; } }
    6867
    69     public IFixedValueParameter<RSquaredEvaluator> EvaluatorParameter {
    70       get { return (IFixedValueParameter<RSquaredEvaluator>)Parameters[EvaluatorParameterName]; }
    71     }
    72 
    73     public RSquaredEvaluator Evaluator {
     68    public IValueParameter<IGrammarEnumerationEvaluator> EvaluatorParameter {
     69      get { return (IValueParameter<IGrammarEnumerationEvaluator>)Parameters[EvaluatorParameterName]; }
     70    }
     71
     72    public IGrammarEnumerationEvaluator Evaluator {
    7473      get { return EvaluatorParameter.Value; }
    7574    }
     
    7776    protected IFixedValueParameter<IntValue> MaxComplexityParameter {
    7877      get { return (IFixedValueParameter<IntValue>)Parameters[MaxComplexityParameterName]; }
     78    }
     79
     80    protected IFixedValueParameter<IntValue> MaxLengthParameter {
     81      get { return (IFixedValueParameter<IntValue>)Parameters[MaxLengthParameterName]; }
    7982    }
    8083
     
    8487    }
    8588
    86     protected IFixedValueParameter<DoubleValue> ErrorWeightParameter {
    87       get { return (IFixedValueParameter<DoubleValue>)Parameters[ErrorWeightParameterName]; }
    88     }
    89 
    90     public double ErrorWeight {
    91       get { return ErrorWeightParameter.Value.Value; }
    92       set { ErrorWeightParameter.Value.Value = value; }
     89    public int MaxLength {
     90      get { return MaxLengthParameter.Value.Value; }
     91      set { MaxLengthParameter.Value.Value = value; }
    9392    }
    9493
     
    200199
    201200    public GrammarEnumerationAlgorithm() {
    202       Parameters.Add(new FixedValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8)));
    203201      Parameters.Add(new FixedValueParameter<IntValue>(MaxComplexityParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(12)));
     202      Parameters.Add(new FixedValueParameter<IntValue>(MaxLengthParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(20)));
    204203      Parameters.Add(new FixedValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(5000)));
    205204      Parameters.Add(new FixedValueParameter<IntValue>(SearchDataStructureSizeParameterName, "The size of the search data structure.", new IntValue((int)1e5)));
    206205      Parameters.Add(new FixedValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.SortedSet)));
    207       Parameters.Add(new FixedValueParameter<RSquaredEvaluator>(EvaluatorParameterName, new RSquaredEvaluator()));
     206      Parameters.Add(new ValueParameter<IGrammarEnumerationEvaluator>(EvaluatorParameterName, new RSquaredEvaluator()));
    208207
    209208      SearchDataStructureParameter.Value.ValueChanged += (o, e) => Prepare();
     
    283282      }
    284283
    285       MaxSentenceLength = Grammar.GetMaxSentenceLength(MaxComplexity);
    286       var errorWeight = ErrorWeight;
    287284      var evaluator = EvaluatorParameter.Value;
    288285      var problemData = Problem.ProblemData;
     286
     287      int maxLength = MaxLength;
     288      int maxComplexity = MaxComplexity;
    289289
    290290      // main search loop
     
    313313
    314314          SymbolList newPhrase = currPhrase.DerivePhrase(nonterminalSymbolIndex, appliedProductions[i]);
     315
     316          if (newPhrase.Count > maxLength)
     317            continue;
     318
    315319          int newPhraseComplexity = newPhrase.Complexity;
    316 
    317           if (newPhraseComplexity > MaxComplexity)
     320          if (newPhraseComplexity > maxComplexity)
    318321            continue;
    319322
     
    327330            OnSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolList, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
    328331
    329             // Is the best solution found? (only if RSquaredEvaluator is activated)
    330             //if (Results.ContainsKey(RSquaredEvaluator.BestTrainingQualityResultName)) {
    331             //  double r2 = ((DoubleValue)Results[RSquaredEvaluator.BestTrainingQualityResultName].Value).Value;
    332             //  if (r2.IsAlmost(1.0)) {
    333             //    UpdateView(force: true);
    334             //    return;
    335             //  }
    336             //}
    337 
    338332            if (!DistinctSentencesComplexity.ContainsKey(phraseHash) || DistinctSentencesComplexity[phraseHash] > newPhraseComplexity) {
    339333              if (DistinctSentencesComplexity.ContainsKey(phraseHash)) OverwrittenSentencesCount++; // for analysis only
    340 
    341334              DistinctSentencesComplexity[phraseHash] = newPhraseComplexity;
    342335              OnDistinctSentenceGenerated(fetchedSearchNode.Hash, fetchedSearchNode.SymbolList, phraseHash, newPhrase, expandedSymbol, appliedProductions[i]);
    343336            }
    344337            UpdateView();
    345 
    346338          } else if (!OpenPhrases.Contains(phraseHash) && !ArchivedPhrases.Contains(phraseHash)) {
    347339            double r2 = IsCompleteSentence(newPhrase) ? evaluator.Evaluate(problemData, Grammar, newPhrase) : fetchedSearchNode.R2;
     
    397389      var tree = Grammar.ParseSymbolicExpressionTree(BestTrainingSentence);
    398390      var model = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, interpreter);
    399 
    400       var iterations = EvaluatorParameter.Value.ConstantOptimizationIterations;
    401       var applyLinearScaling = EvaluatorParameter.Value.ApplyLinearScaling;
    402 
    403       SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(
    404         interpreter,
    405         model.SymbolicExpressionTree,
    406         Problem.ProblemData,
    407         Problem.ProblemData.TrainingIndices,
    408         applyLinearScaling: applyLinearScaling,
    409         maxIterations: iterations,
    410         updateVariableWeights: false,
    411         updateConstantsInTree: true);
     391      Evaluator.Evaluate(Problem.ProblemData, tree); // this call will optimize the constants in the tree (if enabled)
    412392
    413393      model.Scale(Problem.ProblemData);
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/RSquaredEvaluator.cs

    r16073 r16157  
    6161    }
    6262
     63    private IFixedValueParameter<IntValue> SeedParameter {
     64      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
     65    }
     66
    6367    private int Restarts {
    6468      get { return RestartsParameter.Value.Value; }
    6569      set { RestartsParameter.Value.Value = value; }
     70    }
     71
     72    private int Seed {
     73      get { return SeedParameter.Value.Value; }
     74      set { SeedParameter.Value.Value = value; }
    6675    }
    6776
     
    113122
    114123    public double Evaluate(IRegressionProblemData problemData, ISymbolicExpressionTree tree) {
     124      random.Seed((uint)Seed); // not the ideal solution for ensuring result consistency
    115125      return Evaluate(problemData, tree, random, OptimizeConstants, ConstantOptimizationIterations, ApplyLinearScaling, Restarts);
    116126    }
  • branches/2886_SymRegGrammarEnumeration/Test/Test.csproj

    r15974 r16157  
    119119    <Compile Include="Properties\AssemblyInfo.cs" />
    120120    <Compile Include="TreeHashingTest.cs" />
     121    <Compile Include="AlgorithmPerformanceTest.cs" />
    121122  </ItemGroup>
    122123  <ItemGroup>
  • branches/2886_SymRegGrammarEnumeration/Test/TreeHashingTest.cs

    r16056 r16157  
    1 using System.Linq;
     1using System;
     2using System.Collections.Generic;
     3using System.Linq;
    24using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration;
    35using Microsoft.VisualStudio.TestTools.UnitTesting;
     
    146148    }
    147149
    148     /* DEPRECATED; SINCE WE DO NOT ALLOW COMPOUND DIVISIONS
    149     [TestMethod]
    150     [TestCategory("TreeHashing")]
    151     public void CompoundInverseCancellationToSingleInverse() {
    152       SymbolList s1 = new SymbolList(new Symbol[] { varA, varB, grammar.Addition, grammar.Inv, grammar.Inv, grammar.Inv });
    153       SymbolList s2 = new SymbolList(new Symbol[] { varA, varB, grammar.Addition, grammar.Inv });
    154 
    155       int hash1 = grammar.CalcHashCode(s1);
    156       int hash2 = grammar.CalcHashCode(s2);
    157 
    158       Assert.AreEqual(hash1, hash2);
    159     }
    160 
    161     [TestMethod]
    162     [TestCategory("TreeHashing")]
    163     public void CompoundInverseCancellationToDivisor() {
    164       SymbolList s1 = new SymbolList(new Symbol[] { varA, varB, grammar.Addition, grammar.Inv, grammar.Inv });
    165       SymbolList s2 = new SymbolList(new Symbol[] { varA, varB, grammar.Addition });
    166 
    167       int hash1 = grammar.CalcHashCode(s1);
    168       int hash2 = grammar.CalcHashCode(s2);
    169 
    170       Assert.AreEqual(hash1, hash2);
    171     }
    172 
    173     [TestMethod]
    174     [TestCategory("TreeHashing")]
    175     public void UncancelableCompoundInverse() {
    176       // 1 / ( 1/b + sin(a*c) )
    177       SymbolList s1 = new SymbolList(new Symbol[] { varB, grammar.Inv, varA, varC, grammar.Multiplication, grammar.Sin, grammar.Addition, grammar.Inv });
    178       // b + sin(a*c)
    179       SymbolList s2 = new SymbolList(new Symbol[] { varB, varA, varC, grammar.Multiplication, grammar.Sin, grammar.Addition });
    180 
    181       int hash1 = grammar.CalcHashCode(s1);
    182       int hash2 = grammar.CalcHashCode(s2);
    183 
    184       Assert.AreNotEqual(hash1, hash2);
    185     }*/
     150    [TestMethod]
     151    [TestCategory("TreeHashing")]
     152    public void EnumerateGrammarTest() {
     153      //const int nvars = 1;
     154      //var variables = Enumerable.Range(1, nvars).Select(x => $"x{x}").ToArray();
     155      var variables = new[] { "b", "a" };
     156      var grammar = new Grammar(variables, Enum.GetValues(typeof(GrammarRule)).Cast<GrammarRule>());
     157
     158      int hash(SymbolList s) => grammar.Hasher.CalcHashCode(s);
     159
     160      List<SymbolList> sentences = EnumerateGrammarBreadth(grammar, length: 20, hashPhrases: false).ToList();
     161      Console.WriteLine($"Breadth: {sentences.Count};{sentences.Select(hash).Distinct().Count() }");
     162
     163      sentences = EnumerateGrammarBreadth(grammar, length: 20, hashPhrases: true).ToList();
     164      Console.WriteLine($"Breadth (hashed): {sentences.Count};{sentences.Select(hash).Distinct().Count() }");
     165
     166      sentences = EnumerateGrammarDepth(grammar, length: 20, hashPhrases: false).ToList();
     167      Console.WriteLine($"Depth: {sentences.Count};{sentences.Select(hash).Distinct().Count() }");
     168
     169      sentences = EnumerateGrammarDepth(grammar, length: 20, hashPhrases: true).ToList();
     170      Console.WriteLine($"Depth (hashed): {sentences.Count};{sentences.Select(hash).Distinct().Count() }");
     171    }
     172
     173    private static IEnumerable<SymbolList> EnumerateGrammarBreadth(Grammar grammar, int length, bool hashPhrases = true) {
     174      var phrases = new Queue<SymbolList>();
     175      phrases.Enqueue(new SymbolList(grammar.StartSymbol));
     176      var sentences = new List<SymbolList>();
     177      var archive = new HashSet<int>();
     178
     179      while (phrases.Any()) {
     180        var phrase = phrases.Dequeue();
     181
     182        if (phrase.Count > length)
     183          continue;
     184
     185        if (phrase.IsSentence()) {
     186          sentences.Add(phrase);
     187          continue;
     188        }
     189
     190        if (hashPhrases && !archive.Add(grammar.Hasher.CalcHashCode(phrase))) {
     191          continue;
     192        }
     193
     194        var idx = phrase.NextNonterminalIndex();
     195        var productions = grammar.Productions[phrase[idx]];
     196        var derived = productions.Select(p => phrase.DerivePhrase(idx, p)).Where(p => p.Count <= length);
     197        foreach (var d in derived)
     198          phrases.Enqueue(d);
     199      }
     200      return sentences;
     201    }
     202
     203    private static IEnumerable<SymbolList> EnumerateGrammarDepth(Grammar grammar, int length, bool hashPhrases = true) {
     204      return Expand(new SymbolList(grammar.StartSymbol), grammar, length, hashPhrases ? new HashSet<int>() : null);
     205    }
     206
     207    private static IEnumerable<SymbolList> Expand(SymbolList phrase, Grammar grammar, int maxLength, HashSet<int> visited) {
     208      if (phrase.Count > maxLength) {
     209        yield break;
     210      }
     211
     212      if (phrase.IsSentence()) {
     213        yield return phrase;
     214        yield break;
     215      }
     216
     217      if (visited != null && !visited.Add(grammar.Hasher.CalcHashCode(phrase))) {
     218        yield break;
     219      }
     220
     221      var i = phrase.NextNonterminalIndex();
     222      var productions = grammar.Productions[phrase[i]];
     223
     224      foreach (var s in productions.SelectMany(p => Expand(phrase.DerivePhrase(i, p), grammar, maxLength, visited)))
     225        yield return s;
     226    }
    186227  }
    187228}
Note: See TracChangeset for help on using the changeset viewer.