source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/SymbolicExpressionGenerator.cs @ 13645

Last change on this file since 13645 was 13645, checked in by gkronber, 3 years ago

#2581: added an MCTS for symbolic regression models

File size: 7.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22using System;
23using System.Diagnostics.Contracts;
24using System.Linq;
25
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Problems.DataAnalysis.Symbolic;
28
29namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression {
30
31  // translates byte code to a symbolic expression tree
32  internal class SymbolicExpressionTreeGenerator {
33    const int MaxStackSize = 100;
34    private readonly ISymbolicExpressionTreeNode[] stack;
35    private readonly ConstantTreeNode const0;
36    private readonly ConstantTreeNode const1;
37    private readonly Addition addSy;
38    private readonly Multiplication mulSy;
39    private readonly Exponential expSy;
40    private readonly Logarithm logSy;
41    private readonly Division divSy;
42    private readonly VariableTreeNode varNode;
43    private readonly string[] variableNames;
44    private readonly StartSymbol startSy;
45    private readonly ProgramRootSymbol progRootSy;
46
47    public SymbolicExpressionTreeGenerator(string[] variableNames) {
48      stack = new ISymbolicExpressionTreeNode[MaxStackSize];
49      var grammar = new TypeCoherentExpressionGrammar();
50      this.variableNames = variableNames;
51
52      grammar.ConfigureAsDefaultRegressionGrammar();
53      const0 = (ConstantTreeNode)grammar.Symbols.OfType<Constant>().First().CreateTreeNode();
54      const0.Value = 0;
55      const1 = (ConstantTreeNode)grammar.Symbols.OfType<Constant>().First().CreateTreeNode();
56      const1.Value = 1;
57      varNode = (VariableTreeNode)grammar.Symbols.OfType<Variable>().First().CreateTreeNode();
58
59      addSy = grammar.AllowedSymbols.OfType<Addition>().First();
60      mulSy = grammar.AllowedSymbols.OfType<Multiplication>().First();
61      logSy = grammar.AllowedSymbols.OfType<Logarithm>().First();
62      expSy = grammar.AllowedSymbols.OfType<Exponential>().First();
63      divSy = grammar.AllowedSymbols.OfType<Division>().First();
64
65      progRootSy = grammar.AllowedSymbols.OfType<ProgramRootSymbol>().First();
66      startSy = grammar.AllowedSymbols.OfType<StartSymbol>().First();
67    }
68
69    public ISymbolicExpressionTreeNode Exec(byte[] code, double[] consts, int nParams, double[] scalingFactor, double[] scalingOffset) {
70      int topOfStack = -1;
71      int pc = 0;
72      int nextParamIdx = -1;
73      OpCodes op;
74      short arg;
75      while (true) {
76        ReadNext(code, ref pc, out op, out arg);
77        switch (op) {
78          case OpCodes.Nop: break;
79          case OpCodes.LoadConst0: {
80              ++topOfStack;
81              stack[topOfStack] = (ISymbolicExpressionTreeNode)const0.Clone();
82              break;
83            }
84          case OpCodes.LoadConst1: {
85              ++topOfStack;
86              stack[topOfStack] = (ISymbolicExpressionTreeNode)const1.Clone();
87              break;
88            }
89          case OpCodes.LoadParamN: {
90              ++topOfStack;
91              var p = (ConstantTreeNode)const1.Clone(); // value will be tuned later (evaluator and tree generator both use 1 as initial values)
92              p.Value = consts[++nextParamIdx];
93              stack[topOfStack] = p;
94              break;
95            }
96          case OpCodes.LoadVar:
97            ++topOfStack;
98            if (scalingOffset != null) {
99              var sumNode = addSy.CreateTreeNode();
100              var varNode = (VariableTreeNode)this.varNode.Clone();
101              var constNode = (ConstantTreeNode)const0.Clone();
102              varNode.Weight = scalingFactor[arg];
103              varNode.VariableName = variableNames[arg];
104              constNode.Value = scalingOffset[arg];
105              sumNode.AddSubtree(varNode);
106              sumNode.AddSubtree(constNode);
107              stack[topOfStack] = sumNode;
108            } else {
109              var varNode = (VariableTreeNode)this.varNode.Clone();
110              varNode.Weight = 1.0;
111              varNode.VariableName = variableNames[arg];
112              stack[topOfStack] = varNode;
113            }
114            break;
115          case OpCodes.Add: {
116              var t1 = stack[topOfStack];
117              var t2 = stack[topOfStack - 1];
118              topOfStack--;
119              if (t2.Symbol is Addition) {
120                t2.AddSubtree(t1);
121              } else {
122                var addNode = addSy.CreateTreeNode();
123                addNode.AddSubtree(t1);
124                addNode.AddSubtree(t2);
125                stack[topOfStack] = addNode;
126              }
127              break;
128            }
129          case OpCodes.Mul: {
130              var t1 = stack[topOfStack];
131              var t2 = stack[topOfStack - 1];
132              topOfStack--;
133              if (t2.Symbol is Multiplication) {
134                t2.AddSubtree(t1);
135              } else {
136                var mulNode = mulSy.CreateTreeNode();
137                mulNode.AddSubtree(t1);
138                mulNode.AddSubtree(t2);
139                stack[topOfStack] = mulNode;
140              }
141              break;
142            }
143          case OpCodes.Log: {
144              var v1 = stack[topOfStack];
145              var logNode = logSy.CreateTreeNode();
146              logNode.AddSubtree(v1);
147              stack[topOfStack] = logNode;
148              break;
149            }
150          case OpCodes.Exp: {
151              var v1 = stack[topOfStack];
152              var expNode = expSy.CreateTreeNode();
153              expNode.AddSubtree(v1);
154              stack[topOfStack] = expNode;
155              break;
156            }
157          case OpCodes.Inv: {
158              var v1 = stack[topOfStack];
159              var divNode = divSy.CreateTreeNode();
160              divNode.AddSubtree(v1);
161              stack[topOfStack] = divNode;
162              break;
163            }
164          case OpCodes.Exit:
165            Contract.Assert(topOfStack == 0);
166            var rootNode = progRootSy.CreateTreeNode();
167            var startNode = startSy.CreateTreeNode();
168            startNode.AddSubtree(stack[topOfStack]);
169            rootNode.AddSubtree(startNode);
170            return rootNode;
171        }
172      }
173    }
174
175    private void ReadNext(byte[] code, ref int pc, out OpCodes op, out short s) {
176      op = (OpCodes)Enum.ToObject(typeof(OpCodes), code[pc++]);
177      s = 0;
178      if (op == OpCodes.LoadVar) {
179        s = (short)(((short)code[pc] << 8) | (short)code[pc + 1]);
180        pc += 2;
181      }
182    }
183  }
184}
Note: See TracBrowser for help on using the repository browser.