source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs @ 13645

Last change on this file since 13645 was 13645, checked in by gkronber, 3 years ago

#2581: added an MCTS for symbolic regression models

File size: 15.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Collections.Generic;
25using System.IO;
26
27namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
28  // this is the core class for generating expressions.
29  // the automaton determines which expressions are allowed
30  internal class Automaton {
31    public const int StateExpr = 1;
32    public const int StateExprEnd = 2;
33    public const int StateTermStart = 3;
34    public const int StateTermEnd = 4;
35    public const int StateFactorStart = 5;
36    public const int StateFactorEnd = 6;
37    public const int StateVariableFactorStart = 7;
38    public const int StateVariableFactorEnd = 8;
39    public const int StateExpFactorStart = 9;
40    public const int StateExpFactorEnd = 10;
41    public const int StateLogFactorStart = 11;
42    public const int StateLogFactorEnd = 12;
43    public const int StateInvFactorStart = 13;
44    public const int StateInvFactorEnd = 14;
45    public const int StateExpFStart = 15;
46    public const int StateExpFEnd = 16;
47    public const int StateLogTStart = 17;
48    public const int StateLogTEnd = 18;
49    public const int StateLogTFStart = 19;
50    public const int StateLogTFEnd = 20;
51    public const int StateInvTStart = 21;
52    public const int StateInvTEnd = 22;
53    public const int StateInvTFStart = 23;
54    public const int StateInvTFEnd = 24;
55    private const int FirstDynamicState = 25;
56
57    private const int StartState = StateExpr;
58    public int CurrentState { get; private set; }
59
60    public readonly List<string> stateNames;
61    private List<int>[] followStates;
62    private List<Action>[,] actions; // not every follow state is possible but this representation should be efficient
63    private List<string>[,] actionStrings; // just for printing
64    private readonly CodeGenerator codeGenerator;
65    private readonly ConstraintHandler constraintHandler;
66
67    public Automaton(double[][] vars, int maxVarsInExpression = 100,
68       bool allowProdOfVars = true,
69       bool allowExp = true,
70       bool allowLog = true,
71       bool allowInv = true,
72       bool allowMultipleTerms = false) {
73      int nVars = vars.Length;
74      stateNames = new List<string>() { string.Empty, "Expr", "ExprEnd", "TermStart", "TermEnd", "FactorStart", "FactorEnd", "VarFactorStart", "VarFactorEnd", "ExpFactorStart", "ExpFactorEnd", "LogFactorStart", "LogFactorEnd", "InvFactorStart", "InvFactorEnd", "ExpFStart", "ExpFEnd", "LogTStart", "LogTEnd", "LogTFStart", "LogTFEnd", "InvTStart", "InvTEnd", "InvTFStart", "InvTFEnd" };
75      codeGenerator = new CodeGenerator();
76      constraintHandler = new ConstraintHandler(maxVarsInExpression);
77      BuildAutomaton(nVars, allowProdOfVars, allowExp, allowLog, allowInv, allowMultipleTerms);
78
79      Reset();
80#if DEBUG
81      PrintAutomaton();
82#endif
83    }
84
85    // reverse notation ops
86    // Expr -> c 0 Term { '+' Term } '+' '*' c '+' 'exit'
87    // Term -> c Fact { '*' Fact } '*'
88    // Fact -> VarFact | ExpFact | LogFact | InvFact
89    // VarFact -> var_1 ... var_n
90    // ExpFact -> 1 ExpF { '*' ExpF } '*' c '*' 'exp' // c must be at end to allow scaling in evaluator
91    // ExpF    -> var_1 ... var_n
92    // LogFact -> 0 LogT { '+' LogT } '+' c '+' 'log' // c must be at end to allow scaling in evaluator
93    // LogT    -> c LogTF { '*' LogTF } '*'
94    // LogTF   -> var_1 ... var_n
95    // InvFact -> 1 InvT { '+' InvT } '+' 'inv'
96    // InvT    -> (var_1 ... var_n) c '*'
97    private void BuildAutomaton(int nVars,
98      bool allowProdOfVars = true,
99       bool allowExp = true,
100       bool allowLog = true,
101       bool allowInv = true,
102       bool allowMultipleTerms = false) {
103
104      int nStates = FirstDynamicState + 4 * nVars;
105      followStates = new List<int>[nStates];
106      actions = new List<Action>[nStates, nStates];
107      actionStrings = new List<string>[nStates, nStates];
108
109      // Expr -> c 0 Term { '+' Term } '+' '*' c '+' 'exit'
110      AddTransition(StateExpr, StateTermStart, () => {
111        codeGenerator.Reset();
112        codeGenerator.Emit1(OpCodes.LoadParamN);
113        codeGenerator.Emit1(OpCodes.LoadConst0);
114        constraintHandler.Reset();
115      }, "c 0, Reset");
116      AddTransition(StateTermEnd, StateExprEnd, () => {
117        codeGenerator.Emit1(OpCodes.Add);
118        codeGenerator.Emit1(OpCodes.Mul);
119        codeGenerator.Emit1(OpCodes.LoadParamN);
120        codeGenerator.Emit1(OpCodes.Add);
121        codeGenerator.Emit1(OpCodes.Exit);
122      }, "+*c+ exit");
123      if (allowMultipleTerms)
124        AddTransition(StateTermEnd, StateTermStart, () => {
125          codeGenerator.Emit1(OpCodes.Add);
126        }, "+");
127
128      // Term -> c Fact { '*' Fact } '*'
129      AddTransition(StateTermStart, StateFactorStart,
130        () => {
131          codeGenerator.Emit1(OpCodes.LoadParamN);
132          constraintHandler.StartTerm();
133        },
134        "c, StartTerm");
135      AddTransition(StateFactorEnd, StateTermEnd,
136        () => {
137          codeGenerator.Emit1(OpCodes.Mul);
138          constraintHandler.EndTerm();
139        },
140        "*, EndTerm");
141
142      AddTransition(StateFactorEnd, StateFactorStart,
143        () => { codeGenerator.Emit1(OpCodes.Mul); },
144        "*");
145
146
147      // Fact -> VarFact | ExpFact | LogFact | InvFact
148      if (allowProdOfVars)
149        AddTransition(StateFactorStart, StateVariableFactorStart, () => {
150          constraintHandler.StartFactor(StateVariableFactorStart);
151        }, "StartFactor");
152      if (allowExp)
153        AddTransition(StateFactorStart, StateExpFactorStart, () => {
154          constraintHandler.StartFactor(StateExpFactorStart);
155        }, "StartFactor");
156      if (allowLog)
157        AddTransition(StateFactorStart, StateLogFactorStart, () => {
158          constraintHandler.StartFactor(StateLogFactorStart);
159        }, "StartFactor");
160      if (allowInv)
161        AddTransition(StateFactorStart, StateInvFactorStart, () => {
162          constraintHandler.StartFactor(StateInvFactorStart);
163        }, "StartFactor");
164      AddTransition(StateVariableFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "EndFactor");
165      AddTransition(StateExpFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "EndFactor");
166      AddTransition(StateLogFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "EndFactor");
167      AddTransition(StateInvFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "EndFactor");
168
169      // VarFact -> var_1 ... var_n
170      // add dynamic states for each variable
171      int curDynVarState = FirstDynamicState;
172      for (int i = 0; i < nVars; i++) {
173        short varIdx = (short)i;
174        var varState = curDynVarState;
175        stateNames.Add("var_1");
176        AddTransition(StateVariableFactorStart, curDynVarState,
177          () => {
178            codeGenerator.Emit2(OpCodes.LoadVar, varIdx);
179            constraintHandler.AddVarToCurrentFactor(varState);
180          },
181          "var_" + varIdx + ", AddVar");
182        AddTransition(curDynVarState, StateVariableFactorEnd);
183        curDynVarState++;
184      }
185
186      // ExpFact -> 1 ExpF { '*' ExpF } '*' c '*' 'exp'
187      AddTransition(StateExpFactorStart, StateExpFStart,
188        () => {
189          codeGenerator.Emit1(OpCodes.LoadConst1);
190        },
191        "1");
192      AddTransition(StateExpFEnd, StateExpFactorEnd,
193        () => {
194          codeGenerator.Emit1(OpCodes.Mul);
195          codeGenerator.Emit1(OpCodes.LoadParamN);
196          codeGenerator.Emit1(OpCodes.Mul);
197          codeGenerator.Emit1(OpCodes.Exp);
198        },
199        "*c*exp");
200      AddTransition(StateExpFEnd, StateExpFStart,
201        () => { codeGenerator.Emit1(OpCodes.Mul); },
202        "*");
203
204      // ExpF    -> var_1 ... var_n
205      for (int i = 0; i < nVars; i++) {
206        short varIdx = (short)i;
207        int varState = curDynVarState;
208        stateNames.Add("var_2");
209        AddTransition(StateExpFStart, curDynVarState,
210          () => {
211            codeGenerator.Emit2(OpCodes.LoadVar, varIdx);
212            constraintHandler.AddVarToCurrentFactor(varState);
213          },
214          "var_" + varIdx + ", AddVar");
215        AddTransition(curDynVarState, StateExpFEnd);
216        curDynVarState++;
217      }
218
219      // must have c at end because of adjustment of c in evaluator
220      // LogFact -> 0 LogT { '+' LogT } '+' c '+' 'log'
221      AddTransition(StateLogFactorStart, StateLogTStart,
222        () => {
223          codeGenerator.Emit1(OpCodes.LoadConst0);
224        },
225        "0");
226      AddTransition(StateLogTEnd, StateLogFactorEnd,
227        () => {
228          codeGenerator.Emit1(OpCodes.Add);
229          codeGenerator.Emit1(OpCodes.LoadParamN);
230          codeGenerator.Emit1(OpCodes.Add);
231          codeGenerator.Emit1(OpCodes.Log);
232        },
233        "+c+log");
234      AddTransition(StateLogTEnd, StateLogTStart,
235        () => { codeGenerator.Emit1(OpCodes.Add); },
236        "+");
237
238      // LogT    -> c LogTF { '*' LogTF } '*'
239      AddTransition(StateLogTStart, StateLogTFStart,
240        () => {
241          codeGenerator.Emit1(OpCodes.LoadParamN);
242        },
243        "c");
244      AddTransition(StateLogTFEnd, StateLogTEnd,
245        () => {
246          codeGenerator.Emit1(OpCodes.Mul);
247        },
248        "*");
249      AddTransition(StateLogTFEnd, StateLogTFStart,
250        () => {
251          codeGenerator.Emit1(OpCodes.Mul);
252        },
253        "*");
254
255      // LogTF   -> var_1 ... var_n
256      for (int i = 0; i < nVars; i++) {
257        short varIdx = (short)i;
258        int varState = curDynVarState;
259        stateNames.Add("var_3");
260        AddTransition(StateLogTFStart, curDynVarState,
261          () => {
262            codeGenerator.Emit2(OpCodes.LoadVar, varIdx);
263            constraintHandler.AddVarToCurrentFactor(varState);
264          },
265          "var_" + varIdx + ", AddVar");
266        AddTransition(curDynVarState, StateLogTFEnd);
267        curDynVarState++;
268      }
269
270      // InvFact -> 1 InvT { '+' InvT } '+' 'inv'
271      AddTransition(StateInvFactorStart, StateInvTStart,
272        () => {
273          codeGenerator.Emit1(OpCodes.LoadConst1);
274        },
275        "c");
276      AddTransition(StateInvTEnd, StateInvFactorEnd,
277        () => {
278          codeGenerator.Emit1(OpCodes.Add);
279          codeGenerator.Emit1(OpCodes.Inv);
280        },
281        "+inv");
282      AddTransition(StateInvTEnd, StateInvTStart,
283        () => { codeGenerator.Emit1(OpCodes.Add); },
284        "+");
285
286      // InvT    -> c InvTF { '*' InvTF } '*'
287      AddTransition(StateInvTStart, StateInvTFStart,
288        () => {
289          codeGenerator.Emit1(OpCodes.LoadParamN);
290        },
291        "c");
292      AddTransition(StateInvTFEnd, StateInvTEnd,
293        () => {
294          codeGenerator.Emit1(OpCodes.Mul);
295        },
296        "*");
297      AddTransition(StateInvTFEnd, StateInvTFStart,
298        () => {
299          codeGenerator.Emit1(OpCodes.Mul);
300        },
301        "*");
302
303      // InvTF    -> (var_1 ... var_n) c '*'
304      for (int i = 0; i < nVars; i++) {
305        short varIdx = (short)i;
306        int varState = curDynVarState;
307        stateNames.Add("var_4");
308        AddTransition(StateInvTFStart, curDynVarState,
309          () => {
310            codeGenerator.Emit2(OpCodes.LoadVar, varIdx);
311            constraintHandler.AddVarToCurrentFactor(varState);
312          },
313          "var_" + varIdx + ", AddVar");
314        AddTransition(curDynVarState, StateInvTFEnd);
315        curDynVarState++;
316      }
317
318      followStates[StateExprEnd] = new List<int>(); // no follow states
319    }
320
321    private void AddTransition(int fromState, int toState) {
322      if (followStates[fromState] == null) followStates[fromState] = new List<int>();
323      followStates[fromState].Add(toState);
324    }
325    private void AddTransition(int fromState, int toState, Action action, string str) {
326      if (followStates[fromState] == null) followStates[fromState] = new List<int>();
327      followStates[fromState].Add(toState);
328
329      if (actions[fromState, toState] == null) {
330        actions[fromState, toState] = new List<Action>();
331        actionStrings[fromState, toState] = new List<string>();
332      }
333
334      actions[fromState, toState].Add(action);
335      actionStrings[fromState, toState].Add(str);
336    }
337
338    private readonly int[] followStatesBuf = new int[1000];
339    public void FollowStates(int state, out int[] buf, out int nElements) {
340      // return followStates[state]
341      //   .Where(s => s < FirstDynamicState || s >= minVarIdx) // for variables we only allow non-decreasing state sequences
342      //   // the following states imply an additional variable being added to the expression
343      //   // F, Sum, Prod
344      //   .Where(s => (s != StateF && s != StateSum && s != StateProd) || variablesRemaining > 0);
345
346      // for loop instead of where iterator
347      var fs = followStates[state];
348      int j = 0;
349      //Console.Write(stateNames[CurrentState] + " allowed: ");
350      for (int i = 0; i < fs.Count; i++) {
351        var s = fs[i];
352        if (constraintHandler.IsAllowedFollowState(state, s)) {
353          //Console.Write(s + " ");
354          followStatesBuf[j++] = s;
355        }
356      }
357      //Console.WriteLine();
358      buf = followStatesBuf;
359      nElements = j;
360    }
361
362
363    public void Goto(int targetState) {
364      //Console.WriteLine("->{0}", stateNames[targetState]);
365      // Contract.Assert(FollowStates(CurrentState).Contains(targetState));
366
367      if (actions[CurrentState, targetState] != null)
368        actions[CurrentState, targetState].ForEach(a => a()); // execute all actions
369      CurrentState = targetState;
370    }
371
372    public bool IsFinalState(int s) {
373      return s == StateExprEnd;
374    }
375
376    public void GetCode(out byte[] code, out int nParams) {
377      codeGenerator.GetCode(out code, out nParams);
378    }
379
380    public void Reset() {
381      CurrentState = StartState;
382      codeGenerator.Reset();
383      constraintHandler.Reset();
384    }
385
386#if DEBUG
387    public void PrintAutomaton() {
388      using (var writer = new StreamWriter("automaton.gv")) {
389        writer.WriteLine("digraph {");
390        // writer.WriteLine("rankdir=LR");
391        int[] fs;
392        int nFs;
393        for (int s = StartState; s < stateNames.Count; s++) {
394          for (int i = 0; i < followStates[s].Count; i++) {
395            if (followStates[s][i] <= 0) continue;
396            var followS = followStates[s][i];
397            var label = actionStrings[s, followS] != null ? string.Join(" , ", actionStrings[s, followS]) : "";
398            writer.WriteLine("{0} -> {1} [ label = \"{2}\" ];", stateNames[s], stateNames[followS], label);
399          }
400        }
401        writer.WriteLine("}");
402      }
403    }
404#endif
405  }
406}
Note: See TracBrowser for help on using the repository browser.