Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs @ 10386

Last change on this file since 10386 was 10386, checked in by gkronber, 10 years ago

#2026 refactoring

File size: 8.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
namespace CodeGenerator {
  /// <summary>
  /// Emits the C# source code of a random-search solver for a grammar-defined problem.
  /// The solver source is produced from a text template in which the placeholders
  /// <c>?MAXIMIZATION?</c> and <c>?SAMPLEALTERNATIVECODE?</c> are substituted by <see cref="Generate"/>.
  /// </summary>
  public class RandomSearchCodeGen {

    // Source template of the generated solver class.
    // Placeholders replaced in Generate(): ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE?.
    // ?PROBLEMNAME? and ?IDENT? are NOT replaced in this class;
    // NOTE(review): presumably they are substituted by the caller on the combined source - confirm.
    // Inside the verbatim string every double quote of the generated code is written as "".
    private const string solverTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?Solver {
    private static double baseTerminalProbability = 0.05; // 5% of all samples are only a terminal node
    private static double terminalProbabilityInc = 0.05; // for each level the probability to sample a terminal grows by 5%

    private readonly ?IDENT?Problem problem;
    private readonly Random random;

    public Tree SampleTree(out int steps, out int depth) {
      steps = 0;
      depth = 0;
      int curDepth = 0;
      return SampleTree(0, ref steps, ref curDepth, ref depth);
    }

    private Tree SampleTree(int state, ref int steps, ref int curDepth, ref int depth) {
      curDepth += 1;
      steps += 1;
      depth = Math.Max(depth, curDepth);
      Tree t = null;

      // terminals
      if(Grammar.subtreeCount[state] == 0) {
        t = Grammar.CreateTerminalNode(state, random, problem);
      } else {
        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
        if(Grammar.subtreeCount[state] == 1) {
          var targetStates = Grammar.transition[state];
          var altIdx = SampleAlternative(random, state, curDepth);
          var alternative = SampleTree(targetStates[altIdx], ref steps, ref curDepth, ref depth);
          t = new Tree(altIdx, new Tree[] { alternative });
        } else {
          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
            subtrees[i] = SampleTree(Grammar.transition[state][i], ref steps, ref curDepth, ref depth);
          }
          t = new Tree(-1, subtrees); // alternative index is ignored
        }
      }
      curDepth -=1;
      return t;
    }

    private int SampleAlternative(Random random, int state, int depth) {
      switch(state) {

?SAMPLEALTERNATIVECODE?

        default: throw new InvalidOperationException();
      }
    }

    private double TerminalProbForDepth(int depth) {
      return baseTerminalProbability + depth * terminalProbabilityInc;
    }

    public static void Main(string[] args) {
      if(args.Length >= 1) ParseArguments(args);

      var problem = new ?IDENT?Problem();
      var solver = new ?IDENT?Solver(problem);
      solver.Start();
    }
    private static void ParseArguments(string[] args) {
      var baseTerminalProbabilityRegex = new Regex(@""--terminalProbBase=(?<prob>.+)"");
      var terminalProbabilityIncRegex = new Regex(@""--terminalProbInc=(?<prob>.+)"");
      var helpRegex = new Regex(@""--help|/\?"");

      foreach(var arg in args) {
        var baseTerminalProbabilityMatch = baseTerminalProbabilityRegex.Match(arg);
        var terminalProbabilityIncMatch = terminalProbabilityIncRegex.Match(arg);
        var helpMatch = helpRegex.Match(arg);
        if(helpMatch.Success) { PrintUsage(); Environment.Exit(0); }
        else if(baseTerminalProbabilityMatch.Success) {
          baseTerminalProbability = double.Parse(baseTerminalProbabilityMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
          if(baseTerminalProbability < 0.0 || baseTerminalProbability > 1.0) throw new ArgumentException(""base terminal probability must lie in range [0.0 ... 1.0]"");
        } else if(terminalProbabilityIncMatch.Success) {
           terminalProbabilityInc = double.Parse(terminalProbabilityIncMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
           if(terminalProbabilityInc < 0.0 || terminalProbabilityInc > 1.0) throw new ArgumentException(""terminal probability increment must lie in range [0.0 ... 1.0]"");
        } else {
           Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(0);
        }
      }
    }
    private static void PrintUsage() {
      Console.WriteLine(""Find a solution using random tree search."");
      Console.WriteLine();
      Console.WriteLine(""Parameters:"");
      Console.WriteLine(""\t--terminalProbBase=<prob>\tSets the probability of sampling a terminal alternative in a rule [Default: 0.05]"");
      Console.WriteLine(""\t--terminalProbInc=<prob>\tSets the increment for the probability of sampling a terminal alternative for each level in the syntax tree [Default: 0.05]"");
    }


    public ?IDENT?Solver(?IDENT?Problem problem) {
      this.problem = problem;
      this.random = new Random();
    }

    private void Start() {
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      int n = 0;
      long sumDepth = 0;
      long sumSize = 0;
      var sumF = 0.0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      while (n < 10000) {

        int steps, depth;
        var _t = SampleTree(out steps, out depth);
        var f = problem.Evaluate(_t);

        n++;
        sumSize += steps;
        sumDepth += depth;
        sumF += f;
        if (problem.IsBetter(f, bestF)) {
          bestF = f;
          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
        }
        if (n % 1000 == 0) {
          sw.Stop();
          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
          sumSize = 0;
          sumDepth = 0;
          sumF = 0.0;
          sw.Restart();
        }
      }
    }
  }
}";

    /// <summary>
    /// Instantiates the solver template for <paramref name="grammar"/> and appends the
    /// resulting source text to <paramref name="problemSourceCode"/>.
    /// </summary>
    /// <param name="grammar">Grammar whose symbols determine the generated <c>SampleAlternative</c> switch cases.</param>
    /// <param name="maximization">Inserted as the lower-case literal <c>true</c>/<c>false</c> for <c>?MAXIMIZATION?</c>.</param>
    /// <param name="problemSourceCode">Builder that receives the generated solver source.</param>
    public void Generate(IGrammar grammar, bool maximization, SourceBuilder problemSourceCode) {
      var solverSourceCode = new SourceBuilder();
      solverSourceCode.Append(solverTemplate)
        .Replace("?MAXIMIZATION?", maximization.ToString().ToLowerInvariant())
        .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
      ;

      problemSourceCode.Append(solverSourceCode.ToString());
    }

    /// <summary>
    /// Produces the <c>case</c> sections of the generated <c>SampleAlternative</c> switch.
    /// States are numbered by the position of the symbol in <c>grammar.Symbols</c>.
    /// </summary>
    /// <param name="grammar">Grammar to generate the cases for; its first symbol is asserted to be the start symbol.</param>
    /// <returns>The generated source fragment.</returns>
    private string GenerateSampleAlternativeSource(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      int stateCount = 0;
      foreach (var s in grammar.Symbols) {
        // every symbol consumes a state number, even when no case is emitted for it
        int state = stateCount++;

        // Terminal symbols get no case label: an empty label would stack onto the next
        // symbol's case and silently reuse its sampling code. Without a label a terminal
        // state reaches the template's 'default' branch, which throws.
        if (grammar.IsTerminal(s)) continue;

        sb.AppendFormat("case {0}: ", state);

        // materialize once; the index lists are inspected several times below
        var indexedAlternatives = grammar.GetAlternatives(s)
          .Select((alt, idx) => new { alt, idx })
          .ToArray();
        // alternatives consisting only of terminal symbols
        int[] terminalAltIndexes = indexedAlternatives
          .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
          .Select(p => p.idx)
          .ToArray();
        // alternatives containing at least one non-terminal symbol
        int[] nonTerminalAltIndexes = indexedAlternatives
          .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
          .Select(p => p.idx)
          .ToArray();

        if (terminalAltIndexes.Length > 0 && nonTerminalAltIndexes.Length > 0) {
          // both kinds available: choose the kind with the depth-dependent terminal probability
          sb.Append("if(random.NextDouble() < TerminalProbForDepth(depth)) {").BeginBlock();
          GenerateReturnStatement(terminalAltIndexes, sb);
          sb.Append("} else {");
          GenerateReturnStatement(nonTerminalAltIndexes, sb);
          sb.Append("}").EndBlock();
        } else {
          // only one kind of alternative: sample uniformly over all alternatives
          GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
        }
      }
      return sb.ToString();
    }

    /// <summary>
    /// Appends a statement that returns one of the given alternative indexes:
    /// the index itself when there is exactly one, otherwise a uniformly random pick
    /// from an inline array of the indexes.
    /// </summary>
    /// <param name="idxs">Alternative indexes to choose from.</param>
    /// <param name="sb">Builder that receives the generated statement.</param>
    private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
      // enumerate the (possibly deferred) sequence exactly once
      int[] indexes = idxs.ToArray();
      if (indexes.Length == 0) {
        // nothing to choose from; mirror the behaviour of the count-based overload
        sb.AppendLine("throw new InvalidProgramException();");
      } else if (indexes.Length == 1) {
        sb.AppendFormat("return {0};", indexes[0]).AppendLine();
      } else {
        var idxStr = string.Join(", ",
          indexes.Select(i => i.ToString(System.Globalization.CultureInfo.InvariantCulture)).ToArray());
        sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, indexes.Length).AppendLine();
      }
    }

    /// <summary>
    /// Appends a statement that returns a uniformly random alternative index in
    /// <c>[0, nAlts)</c>; returns 0 for a single alternative and throws in the
    /// generated code when there is none.
    /// </summary>
    /// <param name="nAlts">Number of alternatives of the symbol.</param>
    /// <param name="sb">Builder that receives the generated statement.</param>
    private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
      if (nAlts > 1) {
        sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
      } else if (nAlts == 1) {
        sb.AppendLine("return 0; ");
      } else {
        sb.AppendLine("throw new InvalidProgramException();");
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.