Changeset 10100


Ignore:
Timestamp:
10/31/13 13:15:17 (6 years ago)
Author:
gkronber
Message:

#2026: major refactoring of example GPDL solver (random search complete except for RANGE constrained terminals)

Location:
branches/HeuristicLab.Problems.GPDL
Files:
2 added
3 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GPDL/CodeGenerator/CodeGenerator.csproj

    r10080 r10100  
    3737  <ItemGroup>
    3838    <Compile Include="BruteForceCodeGen.cs" />
     39    <Compile Include="ProblemCodeGen.cs" />
     40    <Compile Include="SourceBuilder.cs" />
    3941    <Compile Include="Properties\AssemblyInfo.cs" />
    4042    <Compile Include="RandomSearchCodeGen.cs" />
  • branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs

    r10086 r10100  
    22using System.Collections.Generic;
    33using System.Diagnostics;
    4 using System.IO;
    54using System.Linq;
    6 using System.Text;
    75using HeuristicLab.Grammars;
    8 using Attribute = HeuristicLab.Grammars.Attribute;
    96
    107namespace CodeGenerator {
    118  public class RandomSearchCodeGen {
    12 
    13     private string usings = @"
    14 using System.Collections.Generic;
    15 using System.Linq;
    16 using System;
    17 ";
    189
    1910    private string solverTemplate = @"
    2011namespace ?PROBLEMNAME? {
    2112  public sealed class ?IDENT?Solver {
    22     public sealed class SolverState {
    23       private int currentSeed;
    24       public Random random;
     13    private static double baseTerminalProbability = 0.05; // 5% of all samples are only a terminal node
     14    private static double terminalProbabilityInc = 0.05; // for each level the probability to sample a terminal grows by 5%
     15   
     16    public sealed class SolverState : ISolverState {
     17      private class Tree {
     18        public int altIdx;
     19        // public string symbol; // for debugging
     20        public List<Tree> subtrees;
     21        public Tree(int state, int altIdx) {
     22          subtrees = new List<Tree>(subtreeCount[state]);
     23          this.altIdx = altIdx;
     24        }
     25      }
     26      public int curDepth;
     27      public int steps;
    2528      public int depth;
    26       public int steps;
    27       public int maxDepth;
    28       public SolverState(int seed) {
    29         this.currentSeed = seed;
    30       }
    31 
    32       public SolverState IncDepth() {
    33         depth++;
    34         if(depth>maxDepth) maxDepth = depth;
    35         return this;
    36       }
    37       public SolverState Prepare() {
    38         this.random = new Random(currentSeed);
    39         depth = 0;
    40         maxDepth = 0;
    41         steps = 0;
    42         return this;
     29      private readonly Stack<Tree> nodes;
     30      private readonly IGpdlProblem problem;
     31
     32      private static Dictionary<int, int[]> transition = new Dictionary<int, int[]>() {
     33?TRANSITIONTABLE?
     34      };
     35      private static Dictionary<int, int> subtreeCount = new Dictionary<int, int>() {
     36         { -1, 0 }, // terminals
     37?SUBTREECOUNTTABLE?
     38      };
     39      private static string[] symb = new string[] { ?SYMBOLNAMES? };
     40
     41      public SolverState(IGpdlProblem problem, int seed) {
     42        this.problem = problem;
     43        this.nodes = new Stack<Tree>();
     44       
     45        // create a random tree
     46        var tree = SampleTree(new Random(seed), 0, -1);  // state 0 is the state for the start symbol
     47        nodes.Push(tree);
     48      }
     49
     50      public void Reset() {
     51        // stack must contain only the root of the tree
     52        System.Diagnostics.Debug.Assert(nodes.Count == 1);
     53      }
     54
     55      private Tree SampleTree(Random random, int state, int altIdx) {
     56        // Console.Write(state + "" "");       
     57        curDepth += 1;
     58        steps += 1;
     59        depth = Math.Max(depth, curDepth);
     60        var t = new Tree(state, altIdx);
     61        // t.symbol = symb.Length > state ? symb[state] : ""TERM"";
     62        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
     63        if(subtreeCount[state] == 1) {
     64          var targetStates = transition[state];
     65          var i = SampleAlternative(random, state);
     66          if(targetStates.Length == 0) {
     67            //terminal
     68            t.subtrees.Add(SampleTree(random, -1, i));
     69          } else {
     70            t.subtrees.Add(SampleTree(random, targetStates[i], i));
     71          }
     72        } else {
     73          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
     74          for(int i = 0; i < subtreeCount[state]; i++) {
     75            t.subtrees.Add(SampleTree(random, transition[state][i], i));
     76          }
     77        }
     78        curDepth -=1;
     79        return t;
     80      }
     81
     82      public int PeekNextAlternative() {
     83        // this must only be called nodes that contain alternatives and therefore must only have single-symbols alternatives
     84        System.Diagnostics.Debug.Assert(nodes.Peek().subtrees.Count == 1);
     85        return nodes.Peek().subtrees[0].altIdx;
     86      }
     87
     88      public void Follow(int idx) {
     89        nodes.Push(nodes.Peek().subtrees[idx]);
     90      }
     91
     92      public void Unwind() {
     93        nodes.Pop();
     94      }
     95
     96      private int SampleAlternative(Random random, int state) {
     97        switch(state) {
     98
     99?SAMPLEALTERNATIVECODE?
     100
     101          default: throw new InvalidOperationException();
     102        }
     103      }
     104
     105      private double TerminalProbForDepth(int depth) {
     106        return baseTerminalProbability + depth * terminalProbabilityInc;
    43107      }
    44108    }
    45109
    46110    public static void Main(string[] args) {
    47       var solver = new ?IDENT?Solver();
     111      if(args.Length >= 1) ParseArguments(args);
     112
     113      var problem = new ?IDENT?Problem();
     114      var solver = new ?IDENT?Solver(problem);
    48115      solver.Start();
    49116    }
    50 
    51     public ?IDENT?Solver() {
    52       Initialize();
    53     }   
    54 
    55     private void Initialize() {
    56       ?INITCODE?
    57     }
    58 
    59     private bool IsBetter(double a, double b) {
    60       return ?MAXIMIZATION? ? a > b : a < b;
    61     }
    62 
    63     private double TerminalProbForDepth(int depth) {
    64       const double baseProb = 0.05;  // 5% of all samples are only a terminal node
    65       const double probIncPerLevel = 0.05; // for each level the probability to sample a terminal grows by 5%
    66       return baseProb + depth * probIncPerLevel;
    67     }
    68 
    69     private SolverState _state;
     117    private static void ParseArguments(string[] args) {
     118      var baseTerminalProbabilityRegex = new Regex(@""--terminalProbBase=(?<prob>.+)"");
     119      var terminalProbabilityIncRegex = new Regex(@""--terminalProbInc=(?<prob>.+)"");
     120      var helpRegex = new Regex(@""--help|/\?"");
     121     
     122      foreach(var arg in args) {
     123        var baseTerminalProbabilityMatch = baseTerminalProbabilityRegex.Match(arg);
     124        var terminalProbabilityIncMatch = terminalProbabilityIncRegex.Match(arg);
     125        var helpMatch = helpRegex.Match(arg);
     126        if(helpMatch.Success) { PrintUsage(); Environment.Exit(0); }
     127        else if(baseTerminalProbabilityMatch.Success) {
     128          baseTerminalProbability = double.Parse(baseTerminalProbabilityMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
     129          if(baseTerminalProbability < 0.0 || baseTerminalProbability > 1.0) throw new ArgumentException(""base terminal probability must lie in range [0.0 ... 1.0]"");
     130        } else if(terminalProbabilityIncMatch.Success) {
     131           terminalProbabilityInc = double.Parse(terminalProbabilityIncMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
     132           if(terminalProbabilityInc < 0.0 || terminalProbabilityInc > 1.0) throw new ArgumentException(""terminal probability increment must lie in range [0.0 ... 1.0]"");
     133        } else {
     134           Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(0);
     135        }
     136      }
     137    }
     138    private static void PrintUsage() {
     139      Console.WriteLine(""Find a solution using random tree search."");
     140      Console.WriteLine();
     141      Console.WriteLine(""Parameters:"");
     142      Console.WriteLine(""\t--terminalProbBase=<prob>\tSets the probability of sampling a terminal alternative in a rule [Default: 0.05]"");
     143      Console.WriteLine(""\t--terminalProbInc=<prob>\tSets the increment for the probability of sampling a terminal alternative for each level in the syntax tree [Default: 0.05]"");
     144    }
     145
     146
     147    private readonly ?IDENT?Problem problem;
     148    public ?IDENT?Solver(?IDENT?Problem problem) {
     149      this.problem = problem;
     150    }
     151
    70152    private void Start() {
    71153      var seedRandom = new Random();
    72154      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
    73155      int n = 0;
     156      long sumDepth = 0;
     157      long sumSize = 0;
     158      var sumF = 0.0;
    74159      var sw = new System.Diagnostics.Stopwatch();
    75160      sw.Start();
     
    79164        // so we use a PRNG for generating seeds for a separate PRNG that is reset each time the start symbol is called
    80165       
    81         _state = new SolverState(seedRandom.Next());
    82 
    83         var f = Calculate();
    84 
     166        var _state = new SolverState(problem, seedRandom.Next());
     167
     168        var f = problem.Evaluate(_state);
     169 
    85170        n++;
     171        sumSize += _state.steps;
     172        sumDepth += _state.depth;
     173        sumF += f;
    86174        if (IsBetter(f, bestF)) {
     175          // evaluate again with tracing to console
     176          // problem.Evaluate(new SolverState(_state.seed, true));
    87177          bestF = f;
    88           Console.WriteLine(""{0}\t{1}\t(depth={2})"", n, bestF, _state.maxDepth);
     178          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, _state.steps, _state.depth);
    89179        }
    90180        if (n % 1000 == 0) {
    91181          sw.Stop();
    92           Console.WriteLine(""{0}\t{1}\t{2}\t({3:0.00} sols/ms)"", n, bestF, f, 1000.0 / sw.ElapsedMilliseconds);
     182          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
    93183          sw.Reset();
     184          sumSize = 0;
     185          sumDepth = 0;
     186          sumF = 0.0;
    94187          sw.Start();
    95188        }
     
    97190    }
    98191
    99     public double Calculate() {
    100       ?FITNESSFUNCTION?
    101     }
    102 
    103     ?ADDITIONALCODE?
    104 
    105     ?INTERPRETERSOURCE?
    106 
    107     ?CONSTRAINTSSOURCE?
     192    private bool IsBetter(double a, double b) {
     193      return ?MAXIMIZATION? ? a > b : a < b;
     194    }
    108195  }
    109196}";
    110197
    111 
    112     /// <summary>
    113     /// Generates the source code for a brute force searcher that can be compiled with a C# compiler
    114     /// </summary>
    115     /// <param name="ast">An abstract syntax tree for a GPDL file</param>
    116     public void Generate(GPDefNode ast) {
    117       var problemSourceCode = new StringBuilder();
    118       problemSourceCode.AppendLine(usings);
    119 
    120       GenerateProblem(ast, problemSourceCode);
    121 
    122       problemSourceCode.Replace("?PROBLEMNAME?", ast.Name);
    123 
    124       // write the source file to disk
    125       using (var stream = new StreamWriter(ast.Name + ".cs")) {
    126         stream.WriteLine(problemSourceCode.ToString());
    127       }
    128     }
    129 
    130     private void GenerateProblem(GPDefNode ast, StringBuilder problemSourceCode) {
    131       var grammar = CreateGrammarFromAst(ast);
    132       var problemClassCode =
    133         solverTemplate
    134           .Replace("?MAXIMIZATION?", ast.FitnessFunctionNode.Maximization.ToString().ToLowerInvariant())
    135           .Replace("?IDENT?", ast.Name)
    136           .Replace("?FITNESSFUNCTION?", ast.FitnessFunctionNode.SrcCode)
    137           .Replace("?INTERPRETERSOURCE?", GenerateInterpreterSource(grammar))
    138           .Replace("?INITCODE?", ast.InitCodeNode.SrcCode)
    139           .Replace("?ADDITIONALCODE?", ast.ClassCodeNode.SrcCode)
    140           .Replace("?CONSTRAINTSSOURCE?", GenerateConstraintMethods(ast.Terminals))
    141           ;
    142 
    143       problemSourceCode.AppendLine(problemClassCode).AppendLine();
    144     }
    145 
    146     #region create grammar instance from AST
    147     private AttributedGrammar CreateGrammarFromAst(GPDefNode ast) {
    148 
    149       var nonTerminals = ast.NonTerminals.Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters))).ToArray();
    150       var terminals = ast.Terminals.Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters))).ToArray();
    151       string startSymbolName = ast.Rules.First().NtSymbol;
    152 
    153       // create startSymbol
    154       var startSymbol = nonTerminals.Single(s => s.Name == startSymbolName);
    155       var g = new AttributedGrammar(startSymbol, nonTerminals, terminals);
    156 
    157       // add all production rules
    158       foreach (var rule in ast.Rules) {
    159         var ntSymbol = nonTerminals.Single(s => s.Name == rule.NtSymbol);
    160         foreach (var alt in GetAlternatives(rule.Alternatives, nonTerminals.Concat(terminals))) {
    161           g.AddProductionRule(ntSymbol, alt);
    162         }
    163         // local initialization code
    164         if (!string.IsNullOrEmpty(rule.LocalCode)) g.AddLocalDefinitions(ntSymbol, rule.LocalCode);
    165       }
    166       return g;
    167     }
    168 
    169     private IEnumerable<IAttribute> GetSymbolAttributes(string formalParameters) {
    170       return (from fieldDef in Util.ExtractParameters(formalParameters)
    171               select new Attribute(fieldDef.Identifier, fieldDef.Type, AttributeType.Parse(fieldDef.RefOrOut))).
    172         ToList();
    173     }
    174 
    175     private IEnumerable<Sequence> GetAlternatives(AlternativesNode altNode, IEnumerable<ISymbol> allSymbols) {
    176       foreach (var alt in altNode.Alternatives) {
    177         yield return GetSequence(alt.Sequence, allSymbols);
    178       }
    179     }
    180 
    181     private Sequence GetSequence(IEnumerable<RuleExprNode> sequence, IEnumerable<ISymbol> allSymbols) {
    182       Debug.Assert(sequence.All(s => s is CallSymbolNode || s is RuleActionNode));
    183       var l = new List<ISymbol>();
    184       foreach (var node in sequence) {
    185         var callSymbolNode = node as CallSymbolNode;
    186         var actionNode = node as RuleActionNode;
    187         if (callSymbolNode != null) {
    188           Debug.Assert(allSymbols.Any(s => s.Name == callSymbolNode.Ident));
    189           // create a new symbol with actual parameters
    190           l.Add(new Symbol(callSymbolNode.Ident, GetSymbolAttributes(callSymbolNode.ActualParameter)));
    191         } else if (actionNode != null) {
    192           l.Add(new SemanticSymbol("SEM", actionNode.SrcCode));
    193         }
    194       }
    195       return new Sequence(l);
    196     }
    197     #endregion
    198 
    199     #region helper methods for terminal symbols
    200     // produces helper methods for the attributes of all terminal nodes
    201     private string GenerateConstraintMethods(List<SymbolNode> symbols) {
    202       var sb = new StringBuilder();
    203       var terminals = symbols.OfType<TerminalNode>();
    204       foreach (var t in terminals) {
    205         sb.AppendLine(GenerateConstraintMethods(t));
     198    public void Generate(IGrammar grammar, bool maximization, IEnumerable<SymbolNode> terminalSymbols, SourceBuilder problemSourceCode) {
     199      var solverSourceCode = new SourceBuilder();
     200      solverSourceCode.Append(solverTemplate)
     201        .Replace("?MAXIMIZATION?", maximization.ToString().ToLowerInvariant())
     202        .Replace("?SYMBOLNAMES?", grammar.Symbols.Select(s => s.Name).Aggregate(string.Empty, (str, symb) => str + "\"" + symb + "\", "))
     203        .Replace("?TRANSITIONTABLE?", GenerateTransitionTable(grammar))
     204        .Replace("?SUBTREECOUNTTABLE?", GenerateSubtreeCountTable(grammar))
     205        .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
     206      ;
     207
     208      problemSourceCode.Append(solverSourceCode.ToString());
     209    }
     210
     211
     212
     213    private string GenerateTransitionTable(IGrammar grammar) {
     214      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
     215      var sb = new SourceBuilder();
     216
     217      // state idx = idx of the corresponding symbol in the grammar
     218      var allSymbols = grammar.Symbols.ToList();
     219      var attributes = new List<string>();
     220      foreach (var s in grammar.Symbols) {
     221        var targetStates = new List<int>();
     222        if (grammar.IsTerminal(s)) {
     223          foreach (var att in s.Attributes) {
     224            targetStates.Add(allSymbols.Count + attributes.Count);
     225            attributes.Add(s.Name + "_" + att);
     226          }
     227        } else {
     228          if (grammar.NumberOfAlternatives(s) > 1) {
     229            foreach (var alt in grammar.GetAlternatives(s)) {
     230              // only single-symbol alternatives are supported
     231              Debug.Assert(alt.Count() == 1);
     232              targetStates.Add(allSymbols.IndexOf(alt.Single()));
     233            }
     234          } else {
     235            // rule is a sequence of symbols
     236            var seq = grammar.GetAlternatives(s).Single();
     237            targetStates.AddRange(seq.Select(symb => allSymbols.IndexOf(symb)));
     238          }
     239        }
     240
     241        var targetStateString = targetStates.Aggregate(string.Empty, (str, state) => str + state + ", ");
     242
     243        var idxOfSourceState = allSymbols.IndexOf(s);
     244        sb.AppendFormat("// {0}", s).AppendLine();
     245        sb.AppendFormat("{{ {0} , new int[] {{ {1} }} }},", idxOfSourceState, targetStateString).AppendLine();
     246      }
     247      for (int attIdx = 0; attIdx < attributes.Count; attIdx++) {
     248        sb.AppendFormat("// {0}", attributes[attIdx]).AppendLine();
     249        sb.AppendFormat("{{ {0} , new int[] {{ }} }},", attIdx + allSymbols.Count).AppendLine();
    206250      }
    207251      return sb.ToString();
    208252    }
    209 
    210     // generates helper methods for the attributes of a given terminal node
    211     private string GenerateConstraintMethods(TerminalNode t) {
    212       var sb = new StringBuilder();
    213       foreach (var c in t.Constraints) {
    214         var fieldType = t.FieldDefinitions.First(d => d.Identifier == c.Ident).Type;
    215         if (c.Type == ConstraintNodeType.Range) {
    216           sb.AppendFormat("public {0} GetMax{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMaxExpression).AppendLine();
    217           sb.AppendFormat("public {0} GetMin{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMinExpression).AppendLine();
    218           sb.AppendFormat("public {0} GetRandom{1}_{2}(SolverState _state) {{ return _state.random.NextDouble() * (GetMax{1}_{2}() - GetMin{1}_{2}()) + GetMin{1}_{2}(); }}", fieldType, t.Ident, c.Ident).AppendLine();
    219         } else if (c.Type == ConstraintNodeType.Set) {
    220           sb.AppendFormat("public IEnumerable<{0}> GetAllowed{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.SetExpression).AppendLine();
    221           sb.AppendFormat("public {0} GetRandom{1}_{2}(SolverState _state) {{ var tmp = GetAllowed{1}_{2}().ToArray(); return tmp[_state.random.Next(tmp.Length)]; }}", fieldType, t.Ident, c.Ident).AppendLine();
    222         }
     253    private string GenerateSubtreeCountTable(IGrammar grammar) {
     254      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
     255      var sb = new SourceBuilder();
     256
     257      // state idx = idx of the corresponding symbol in the grammar
     258      var allSymbols = grammar.Symbols.ToList();
     259      var attributes = new List<string>();
     260      foreach (var s in grammar.Symbols) {
     261        int subtreeCount;
     262        if (grammar.IsTerminal(s)) {
     263          subtreeCount = s.Attributes.Count();
     264          attributes.AddRange(s.Attributes.Select(att => s.Name + "_" + att.Name));
     265        } else {
     266          if (grammar.NumberOfAlternatives(s) > 1) {
     267            Debug.Assert(grammar.GetAlternatives(s).All(alt => alt.Count() == 1));
     268            subtreeCount = 1;
     269          } else {
     270            subtreeCount = grammar.GetAlternative(s, 0).Count();
     271          }
     272        }
     273
     274        sb.AppendFormat("// {0}", s).AppendLine();
     275        sb.AppendFormat("{{ {0} , {1} }},", allSymbols.IndexOf(s), subtreeCount).AppendLine();
     276      }
     277
     278      for (int attIdx = 0; attIdx < attributes.Count; attIdx++) {
     279        sb.AppendFormat("// {0}", attributes[attIdx]).AppendLine();
     280        sb.AppendFormat("{{ {0} , 1 }},", attIdx + allSymbols.Count).AppendLine();
    223281      }
    224282      return sb.ToString();
    225283    }
    226     #endregion
    227 
    228     private string GenerateInterpreterSource(AttributedGrammar grammar) {
    229       var sb = new StringBuilder();
    230 
    231       // generate methods for all nonterminals and terminals using the grammar instance
    232       foreach (var s in grammar.NonTerminalSymbols) {
    233         sb.AppendLine(GenerateInterpreterMethod(grammar, s));
    234       }
    235       foreach (var s in grammar.TerminalSymbols) {
    236         sb.AppendLine(GenerateTerminalInterpreterMethod(s));
     284
     285    private string GenerateSampleAlternativeSource(IGrammar grammar) {
     286      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
     287      var sb = new SourceBuilder();
     288      int stateCount = 0;
     289      var attributes = new List<Tuple<string, string>>();
     290      foreach (var s in grammar.Symbols) {
     291        sb.AppendFormat("case {0}: ", stateCount++);
     292        if (grammar.IsTerminal(s)) {
     293          GenerateReturnStatement(s.Attributes.Count(), sb);
     294          attributes.AddRange(s.Attributes.Select(att => Tuple.Create(s.Name, att.Name)));
     295        } else {
     296          var terminalAltIndexes = grammar.GetAlternatives(s)
     297            .Select((alt, idx) => new { alt, idx })
     298            .Where((p) => p.alt.All(symb => grammar.IsTerminal(symb)))
     299            .Select(p => p.idx);
     300          var nonTerminalAltIndexes = grammar.GetAlternatives(s)
     301            .Select((alt, idx) => new { alt, idx })
     302            .Where((p) => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
     303            .Select(p => p.idx);
     304          var hasTerminalAlts = terminalAltIndexes.Any();
     305          var hasNonTerminalAlts = nonTerminalAltIndexes.Any();
     306          if (hasTerminalAlts && hasNonTerminalAlts) {
     307            sb.Append("if(random.NextDouble() < TerminalProbForDepth(depth)) {").BeginBlock();
     308            GenerateReturnStatement(terminalAltIndexes, sb);
     309            sb.Append("} else {");
     310            GenerateReturnStatement(nonTerminalAltIndexes, sb);
     311            sb.Append("}").EndBlock();
     312          } else {
     313            GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
     314          }
     315        }
     316      }
     317      for (int attIdx = 0; attIdx < attributes.Count; attIdx++) {
     318        var terminalName = attributes[attIdx].Item1;
     319        var attributeName = attributes[attIdx].Item2;
     320        sb.AppendFormat("case {0}: return random.Next(problem.GetCardinality(\"{1}\", \"{2}\"));", attIdx + stateCount, terminalName, attributeName).AppendLine();
    237321      }
    238322      return sb.ToString();
    239323    }
    240 
    241     private string GenerateInterpreterMethod(AttributedGrammar g, ISymbol s) {
    242       var sb = new StringBuilder();
    243 
    244       // if this is the start symbol we additionally have to create the method which can be called from the fitness function
    245       if (g.StartSymbol.Equals(s)) {
    246         if (!s.Attributes.Any())
    247           sb.AppendFormat("private void {0}() {{", s.Name);
    248         else
    249           sb.AppendFormat("private void {0}({1}) {{", s.Name, s.GetAttributeString());
    250 
    251         // get formal parameters of start symbol
    252         var attr = g.StartSymbol.Attributes;
    253 
    254         // actual parameter are the same as formalparameter only without type identifier
    255         string actualParameter;
    256         if (attr.Any())
    257           actualParameter = attr.Skip(1).Aggregate(attr.First().AttributeType + " " + attr.First().Name, (str, a) => str + ", " + a.AttributeType + " " + a.Name);
    258         else
    259           actualParameter = string.Empty;
    260 
    261         sb.AppendFormat("{0}(_state.Prepare(), {1});", g.StartSymbol.Name, actualParameter).AppendLine();
    262         sb.AppendLine("}");
    263       }
    264 
    265       if (!s.Attributes.Any())
    266         sb.AppendFormat("private void {0}(SolverState _state) {{", s.Name);
    267       else
    268         sb.AppendFormat("private void {0}(SolverState _state, {1}) {{", s.Name, s.GetAttributeString());
    269 
    270       // generate local definitions
    271       sb.AppendLine(g.GetLocalDefinitions(s));
    272 
    273 
    274       var altsWithSemActions = g.GetAlternativesWithSemanticActions(s).ToArray();
    275       var terminalAlts = altsWithSemActions.Where(alt => alt.Count(g.IsNonTerminal) == 0);
    276       var nonTerminalAlts = altsWithSemActions.Where(alt => alt.Count(g.IsNonTerminal) > 0);
    277       bool hasTerminalAlts = terminalAlts.Any();
    278       bool hasNonTerminalAlts = nonTerminalAlts.Any();
    279 
    280       if (altsWithSemActions.Length > 1) {
    281         // here we need to bias the selection of alternatives (non-terminal vs terminal alternatives) to make sure that
    282         // terminals are selected with a certain probability to make sure that:
    283         // 1) we don't create the same small trees all the time (terminals have high probability to be selected)
    284         // 2) we don't create very big trees by recursing to deep (leads to stack-overflow) (terminals have a low probability to be selected)
    285         // so we first decide if we want to generate a terminal or non-terminal (50%, 50%) and then choose a symbol in the class randomly
    286         // the probability of choosing terminals should depend on the depth of the tree (small likelihood to choose terminals for small depths, large likelihood for large depths)
    287         if (hasTerminalAlts && hasNonTerminalAlts) {
    288           sb.AppendLine("if(_state.random.NextDouble() < TerminalProbForDepth(_state.depth)) {");
    289           // terminals
    290           sb.AppendLine("// terminals ");
    291           GenerateSwitchStatement(sb, terminalAlts);
    292           sb.AppendLine("} else {");
    293           // non-terminals
    294           sb.AppendLine("// non-terminals ");
    295           GenerateSwitchStatement(sb, nonTerminalAlts);
    296           sb.AppendLine("}");
    297         } else if (hasTerminalAlts) {
    298           sb.AppendLine("// terminals ");
    299           GenerateSwitchStatement(sb, terminalAlts);
    300         } else if (hasNonTerminalAlts) {
    301           sb.AppendLine("// non-terminals ");
    302           GenerateSwitchStatement(sb, nonTerminalAlts);
    303         }
     324    private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
     325      if (idxs.Count() == 1) {
     326        sb.AppendFormat("return {0};", idxs.Single()).AppendLine();
    304327      } else {
    305         foreach (var altSymb in altsWithSemActions.Single()) {
    306           sb.AppendLine(GenerateSourceForAction(altSymb));
    307         }
    308       }
    309       sb.AppendLine("}");
    310       return sb.ToString();
    311     }
    312 
    313     private void GenerateSwitchStatement(StringBuilder sb, IEnumerable<Sequence> alts) {
    314       sb.AppendFormat("switch(_state.random.Next({0})) {{", alts.Count());
    315       // generate a case for each alternative
    316       int i = 0;
    317       foreach (var alt in alts) {
    318         sb.AppendFormat("case {0}: {{ ", i).AppendLine();
    319 
    320         // this only works for alternatives with a single non-terminal symbol (ignoring semantic symbols) so far!
    321         // a way to handle this is through grammar transformation (the examplary grammars all have the correct from)
    322         Debug.Assert(alt.Count(symb => !(symb is SemanticSymbol)) == 1);
    323         foreach (var altSymb in alt) {
    324           sb.AppendLine(GenerateSourceForAction(altSymb));
    325         }
    326         i++;
    327         sb.AppendLine("break;").AppendLine("}");
    328       }
    329       sb.AppendLine("default: throw new System.InvalidOperationException();").AppendLine("}");
    330 
    331     }
    332 
    333     // helper for generating calls to other symbol methods
    334     private string GenerateSourceForAction(ISymbol s) {
    335       var action = s as SemanticSymbol;
    336       if (action != null) {
    337         return action.Code + ";";
     328        var idxStr = idxs.Aggregate(string.Empty, (str, idx) => str + idx + ", ");
     329        sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, idxs.Count()).AppendLine();
     330      }
     331    }
     332
     333    private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
     334      if (nAlts > 1) {
     335        sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
     336      } else if (nAlts == 1) {
     337        sb.AppendLine("return 0; ");
    338338      } else {
    339         if (!s.Attributes.Any())
    340           return string.Format("{0}(_state.IncDepth()); _state.depth--;", s.Name);
    341         else
    342           return string.Format("{0}(_state.IncDepth(), {1}); _state.depth--;", s.Name, s.GetAttributeString());
    343       }
    344     }
    345 
    346     private string GenerateTerminalInterpreterMethod(ISymbol s) {
    347       var sb = new StringBuilder();
    348       // if the terminal symbol has attributes then we must samples values for these attributes
    349       if (!s.Attributes.Any())
    350         sb.AppendFormat("private void {0}(SolverState _state) {{", s.Name);
    351       else
    352         sb.AppendFormat("private void {0}(SolverState _state, {1}) {{", s.Name, s.GetAttributeString());
    353 
    354       // each field must match a formal parameter, assign a value for each parameter
    355       foreach (var element in s.Attributes) {
    356         sb.AppendFormat("{{ {0} = GetRandom{1}_{0}(_state); }} ", element.Name, s.Name);
    357       }
    358       sb.AppendLine("}");
    359       return sb.ToString();
     339        sb.AppendLine("throw new InvalidProgramException();");
     340      }
    360341    }
    361342  }
  • branches/HeuristicLab.Problems.GPDL/GpdlCompiler/Program.cs

    r10080 r10100  
    3232      Parser parser = new Parser(scanner);
    3333      parser.Parse();
    34       var codeGen = new RandomSearchCodeGen();
     34      var codeGen = new ProblemCodeGen();
    3535
    3636      codeGen.Generate(parser.AbstractSyntaxTree);
Note: See TracChangeset for help on using the changeset viewer.