source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/BruteForceCodeGen.cs @ 10388

Last change on this file since 10388 was 10388, checked in by gkronber, 7 years ago

#2026 worked on code generator for brute force solver

File size: 8.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
namespace CodeGenerator {
  /// <summary>
  /// Generates the C# source of a brute-force solver for a GPDL problem by filling the
  /// placeholders (<c>?NAME?</c>) of a solver template with grammar-specific code.
  /// Solver states are numbered by the position of their symbol in <c>grammar.Symbols</c>;
  /// the first symbol must be the start symbol (state 0).
  /// </summary>
  public class BruteForceCodeGen {

    // Source template of the generated solver.
    // NOTE(review): the template is still work in progress and the generated solver does not compile yet:
    // SearchTreeNode.GetNextNode ends in an incomplete assignment, GenerateTrees has no return/yield,
    // GenerateTree refers to undeclared names (state, curDepth, steps, depth, SampleAlternative, SampleTree),
    // and GetNextNode uses Grammar.transitions while GenerateTree uses Grammar.transition — confirm which exists.
    // ?IDENT? and ?PROBLEMNAME? are not substituted in this class; presumably the caller replaces them in the
    // combined problem source — verify.
    private const string SolverTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?Solver {
    public class SearchTreeNode {
      public bool ready = false;
      public SearchTreeNode[] children;
      public int state;
      public int nextAltIdx;

      public SearchTreeNode(int state) {
        this.state = state;
        // do not initialize children yet to save mem
        nextAltIdx = 0;
      }
      public SearchTreeNode GetNextNode() {
        if(children == null) {
          int nChildren = Grammar.subtreeCount[state] == 1 ? Grammar.transitions[state].Length : Grammar.subtreeCount[state];
          children = new SearchTreeNode[nChildren];
        }
        if(children[nextAltIdx] == null) {
          children[nextAltIdx] =
        }
      }
    }


    private readonly ?IDENT?Problem problem;
    private readonly Random random;
    private SearchTreeNode searchTree = null;

    private IEnumerable<Tree> GenerateTrees(int maxDepth) {
      if(searchTree==null) searchTree = new SearchTreeNode(0);
     
    }

    private Tree GenerateTree(SearchTreeNode node) {
      curDepth += 1;
      steps += 1;
      depth = Math.Max(depth, curDepth);
      Tree t = null;

      // terminals
      if(Grammar.subtreeCount[state] == 0) {
        t = CreateTerminalNode(state, random, problem);
      } else {
        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
        if(Grammar.subtreeCount[state] == 1) {
          var targetStates = Grammar.transition[state];
          var altIdx = SampleAlternative(random, state, curDepth);
          var alternative = SampleTree(targetStates[altIdx], ref steps, ref curDepth, ref depth);
          t = new Tree(altIdx, new Tree[] { alternative });
        } else {
          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
            subtrees[i] = SampleTree(Grammar.transition[state][i], ref steps, ref curDepth, ref depth);
          }
          t = new Tree(-1, subtrees); // alternative index is ignored
        }
      }
      curDepth -=1;
      return t;
    }

    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
      switch(state) {
        ?CREATETERMINALNODECODE?
        default: { throw new ArgumentException(""Unknown state index"" + state); }
      }
    }

    private int NextAlternative(SearchTreeNode node) {
?RETURNNEXTALTERNATIVECODE?
    }

    public static void Main(string[] args) {
      var problem = new ?IDENT?Problem();
      var solver = new ?IDENT?Solver(problem);
      solver.Start();
    }

    public ?IDENT?Solver(?IDENT?Problem problem) {
      this.problem = problem;
      this.random = new Random();
    }

    private void Start() {
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      int n = 0;
      long sumDepth = 0;
      long sumSize = 0;
      var sumF = 0.0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      for(int d = 1; d < 20; d++) {
        foreach(var t in GenerateTrees(d)) {
          var f = problem.Evaluate(t);
         
          n++;   
          sumF += f;
          if (problem.IsBetter(f, bestF)) {
            bestF = f;
            Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, 0, 0);
          }
          if (n % 1000 == 0) {
            sw.Stop();
            Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
            sumSize = 0;
            sumDepth = 0;
            sumF = 0.0;
            sw.Restart();
          }
        }
      }
    }
  }
}";

    /// <summary>
    /// Instantiates the solver template for <paramref name="grammar"/> and appends the resulting
    /// source to <paramref name="problemSourceCode"/>.
    /// </summary>
    /// <param name="grammar">Grammar whose symbols define the solver states; the first symbol must be the start symbol.</param>
    /// <param name="terminals">Terminal declarations; exactly one must have an <c>Ident</c> equal to each terminal symbol's name.</param>
    /// <param name="maximization">Substituted for <c>?MAXIMIZATION?</c> as the C# literal <c>true</c> or <c>false</c>.</param>
    /// <param name="problemSourceCode">Builder that receives the generated solver source.</param>
    public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
      var solverSourceCode = new SourceBuilder();
      solverSourceCode.Append(SolverTemplate)
        .Replace("?MAXIMIZATION?", maximization ? "true" : "false")
        // NOTE(review): the template contains ?RETURNNEXTALTERNATIVECODE? but not ?SAMPLEALTERNATIVECODE?,
        // so this replacement is currently a no-op and NextAlternative keeps its raw placeholder. The
        // brute-force enumeration code for that placeholder still has to be written.
        .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
        .Replace("?CREATETERMINALNODECODE?", GenerateCreateTerminalCode(grammar, terminals));

      problemSourceCode.Append(solverSourceCode.ToString());
    }

    // Emits one `case <state>:` per grammar symbol that returns the index of the alternative to expand.
    // For symbols that mix all-terminal and non-terminal alternatives, the generated code picks the
    // all-terminal group with probability TerminalProbForDepth(depth).
    // NOTE(review): the generated code references `random`, `depth` and `TerminalProbForDepth`; the current
    // template declares `random` only.
    private string GenerateSampleAlternativeSource(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      int state = 0;
      foreach (var s in grammar.Symbols) {
        sb.AppendFormat("case {0}: ", state++);
        if (grammar.IsTerminal(s)) {
          // Terminals have no alternatives. An explicit throw keeps this label from stacking onto the
          // next case (or from being an empty final case label, which does not compile).
          sb.AppendLine("throw new InvalidOperationException(\"Terminal symbols have no alternatives.\");");
          continue;
        }

        // materialize once: the index lists are inspected several times below
        var indexedAlternatives = grammar.GetAlternatives(s)
          .Select((alt, idx) => new { alt, idx })
          .ToArray();
        var terminalAltIndexes = indexedAlternatives
          .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
          .Select(p => p.idx)
          .ToArray();
        var nonTerminalAltIndexes = indexedAlternatives
          .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
          .Select(p => p.idx)
          .ToArray();

        if (terminalAltIndexes.Length > 0 && nonTerminalAltIndexes.Length > 0) {
          sb.Append("if(random.NextDouble() < TerminalProbForDepth(depth)) {").BeginBlock();
          GenerateReturnStatement(terminalAltIndexes, sb);
          sb.Append("} else {");
          GenerateReturnStatement(nonTerminalAltIndexes, sb);
          sb.Append("}").EndBlock();
        } else {
          GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
        }
      }
      return sb.ToString();
    }

    // Emits one `case <state>: { ... }` per terminal symbol that creates the symbol's tree node.
    // Constraint values cannot be enumerated yet, so every constrained terminal gets a throwing block
    // (SET: NotImplementedException, anything else: NotSupportedException).
    // Throws InvalidOperationException (from Single) unless exactly one entry of `terminals` matches each terminal symbol.
    private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      // same running state numbering as GenerateSampleAlternativeSource (position in grammar.Symbols)
      int state = 0;
      foreach (var s in grammar.Symbols) {
        int stateIdx = state++;
        if (!grammar.IsTerminal(s)) continue;

        sb.AppendFormat("case {0}: {{", stateIdx).BeginBlock();
        sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
        var terminal = terminals.Single(t => t.Ident == s.Name);
        foreach (var constr in terminal.Constraints) {
          if (constr.Type == ConstraintNodeType.Set) {
            sb.Append("{").BeginBlock();
            sb.AppendLine("throw new NotImplementedException(\"Enumeration of terminal values is not implemented.\");");
            sb.AppendFormat("//var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
            sb.AppendFormat("//t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
            sb.AppendLine("}");
          } else {
            // mirror the SET branch: every BeginBlock needs a matching EndBlock to keep indentation balanced
            sb.Append("{").BeginBlock();
            sb.Append("throw new NotSupportedException(\"The brute force solver does not support RANGE constraints\");").EndBlock();
            sb.AppendLine("}");
          }
        }
        sb.AppendLine("return t;").EndBlock();
        // newline so that the next case label does not start on the closing-brace line
        sb.AppendLine("}");
      }
      return sb.ToString();
    }

    // Emits a statement returning one of the given alternative indexes, chosen uniformly at random
    // (via the generated solver's `random` field) when there is more than one.
    private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
      var alternatives = idxs.ToArray(); // enumerate the (possibly lazy) input exactly once
      if (alternatives.Length == 0) {
        // consistent with the int overload: no alternative to choose from
        sb.AppendLine("throw new InvalidProgramException();");
      } else if (alternatives.Length == 1) {
        sb.AppendFormat("return {0};", alternatives[0]).AppendLine();
      } else {
        sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", string.Join(", ", alternatives), alternatives.Length).AppendLine();
      }
    }

    // Emits a statement returning a uniformly random alternative index in [0, nAlts),
    // or a throw when the symbol has no alternatives.
    private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
      if (nAlts > 1) {
        sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
      } else if (nAlts == 1) {
        sb.AppendLine("return 0; ");
      } else {
        sb.AppendLine("throw new InvalidProgramException();");
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.