// Package symexpr implements symbolic equations as an abstract syntax tree (AST).
// It aims to make dynamic manipulation of the tree easy.
//
// This work comes out of my Master's thesis at Binghamton University
// and is geared towards symbolic regression.
7 | package symexpr
8 |
9 | import (
10 | "fmt"
11 | "math"
12 | )
13 |
// ExprType identifies the kind of a node in an expression tree.
//
// All values come from one contiguous iota sequence, so declaration order is
// significant: NULL is the smallest value (Null.AmILess relies on this) and
// serialized forms store these integers. The STARTx/LASTx entries are marker
// values, not node kinds; they bracket the leaf and function groups
// (presumably for range checks such as STARTLEAF < t && t < LASTLEAF).
// DumpExprTypes and ExprType.String reference every marker, so they must stay.
type ExprType int

const (
	NULL ExprType = iota

	// Leaf (terminal) node kinds.
	STARTLEAF
	CONSTANT  // indexed constant, useful for non-linear regression tasks
	CONSTANTF // floating point constant
	TIME      // useful when looking at time series and RK4 integration
	SYSTEM    // used like a variable that changes between experiments, but not with time (mass, size, etc.)
	VAR       // a canonical variable
	LASTLEAF

	// Function node kinds.
	STARTFUNC
	NEG
	ABS
	SQRT
	SIN
	COS
	TAN
	EXP
	LOG
	LASTFUNC

	// Power and division node kinds.
	POWI // Expr^Integer
	POWF // Expr^Float
	POWE // Expr^Expr
	DIV  // Expr/Expr

	// Associative operators.
	ADD // these can have more than two child nodes
	MUL // this eases sorting and simplification

	EXPR_MAX // presumably the upper bound of the node kinds above
	STARTVAR // for serialization reduction of variables
)
49 |
50 | // Expr is the interface to all node types for the AST of mathematical expression
51 | //
52 | type Expr interface {
53 |
54 | // types.go (this file)
55 | ExprType() ExprType
56 | Clone() Expr
57 |
58 | // stats.go
59 | Size() int
60 | Depth() int
61 | Height() int
62 | NumChildren() int
63 | CalcExprStats()
64 | calcExprStatsR(depth int, pos *int)
65 |
66 | // compare.go
67 | AmILess(rhs Expr) bool
68 | AmIEqual(rhs Expr) bool
69 | AmISame(rhs Expr) bool // equality without coefficient values/index
70 | AmIAlmostSame(rhs Expr) bool // adds flexibility to mul comparison to AmISame
71 | Sort()
72 |
73 | // has.go
74 | HasVar() bool
75 | HasVarI(i int) bool
76 | NumVar() int
77 |
78 | // DFS for a floating point valued ConstantF
79 | HasConst() bool
80 | // DFS for a indexed valued Constant
81 | HasConstI(i int) bool
82 | // Counts the number of indexed Constant nodes
83 | NumConstants() int
84 |
85 | // convert.go
86 |
87 | // Converts indexed Constant nodes to ConstantF nodes
88 | // using the input slice as the values for replacement
89 | ConvertToConstantFs(cs []float64) Expr
90 | // DFS converting float valued constants to indexed constants
91 | // the input should be an empty slice
92 | // the output is an appended slice the size = |ConstantF| in the tree
93 | ConvertToConstants(cs []float64) ([]float64, Expr)
94 | // IndexConstants( ci int ) int
95 |
96 | // getset.go
97 | // DFS retrieval of a node by index
98 | GetExpr(pos *int) Expr
99 | // DFS replacement of a node and it's subtree
100 | // replaced is used to discontinue the DFS after replacement
101 | // replace_me gets triggered when pos == 0 and informs the parent node to replace the respective child node
102 | SetExpr(pos *int, e Expr) (replace_me, replaced bool)
103 |
104 | // print.go
105 |
106 | // prints the AST
107 | String() string
108 |
109 | // creates an integer representation of the AST in ~prefix notation
110 | // The input is an empty slice, output is the representation.
111 | // The output is generally the ExprType integer value
112 | // Associative operators (+ & *) also include the number of children.
113 | // The terminal nodes include the index when appropriate.
114 | Serial([]int) []int
115 | StackSerial([]int) []int
116 |
117 | // Pretty print acts like String, but replaces the internal indexed
118 | // formatting with user specified strings and values
119 | PrettyPrint(dnames, snames []string, cvals []float64) string
120 | // WriteString( buf *bytes.Buffer )
121 |
122 | // Similar to PrettyPrint, but in latex format
123 | Latex(dnames, snames []string, cvals []float64) string
124 | Javascript(dnames, snames []string, cvals []float64) string
125 |
126 | // eval.go
127 | // Evaluates an expression at one point
128 | // t is a time value
129 | // x are the input Var values
130 | // c are the indexed Constant values
131 | // s are the indexed System values
132 | // the output is the result of DFS evaluation
133 | Eval(t float64, x, c, s []float64) float64
134 |
135 | // simp.go
136 | Simplify(rules SimpRules) Expr
137 |
138 | // deriv.go
139 | // Calculate the derivative w.r.t. Var_i
140 | DerivVar(i int) Expr
141 | // Calculate the derivative w.r.t. Constant_i
142 | DerivConst(i int) Expr
143 | }
144 |
145 | type ExprArray []Expr
146 |
147 | func (p ExprArray) Len() int { return len(p) }
148 | func (p ExprArray) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
149 | func (p ExprArray) Less(i, j int) bool {
150 | if p[i] == nil && p[j] == nil {
151 | return false
152 | }
153 | if p[i] == nil {
154 | return false
155 | }
156 | if p[j] == nil {
157 | return true
158 | }
159 | return p[i].AmILess(p[j])
160 | }
161 |
162 | // Null Leaf (shouldn't really appear)
163 | // This is a sample for the other node types
164 | type Null struct {
165 | ExprStats
166 | }
167 |
168 | func NewNull() *Null { return new(Null) }
169 | func (n *Null) ExprType() ExprType { return NULL }
170 | func (n *Null) Clone() Expr { return NewNull() }
171 |
172 | func (n *Null) CalcExprStats() {
173 | n.depth = 1
174 | n.height = 1
175 | n.size = 1
176 | n.pos = 0
177 | n.numchld = 0
178 | }
179 | func (n *Null) calcExprStatsR(depth int, pos *int) {
180 | n.depth = depth + 1
181 | n.height = 1
182 | n.size = 1
183 | n.pos = *pos
184 | (*pos)++
185 | n.numchld = 0
186 | }
187 |
188 | func (n *Null) AmILess(r Expr) bool { return NULL < r.ExprType() }
189 | func (n *Null) AmIEqual(r Expr) bool { return r.ExprType() == NULL }
190 | func (n *Null) AmISame(r Expr) bool { return r.ExprType() == NULL }
191 | func (n *Null) AmIAlmostSame(r Expr) bool { return r.ExprType() == NULL }
192 | func (n *Null) Sort() { return }
193 |
194 | func (n *Null) HasVar() bool { return false }
195 | func (n *Null) HasVarI(i int) bool { return false }
196 | func (n *Null) NumVar() int { return 0 }
197 | func (n *Null) HasConst() bool { return false }
198 | func (n *Null) HasConstI(i int) bool { return false }
199 | func (n *Null) NumConstants() int { return 0 }
200 |
201 | func (n *Null) ConvertToConstantFs(cs []float64) Expr { return n }
202 | func (n *Null) ConvertToConstants(cs []float64) ([]float64, Expr) { return cs, n }
203 |
204 | func (n *Null) GetExpr(pos *int) Expr {
205 | if (*pos) == 0 {
206 | return n
207 | }
208 | (*pos)--
209 | return nil
210 | }
211 | func (n *Null) SetExpr(pos *int, e Expr) (replace_me, replaced bool) {
212 | if (*pos) == 0 {
213 | return true, false
214 | }
215 | (*pos)--
216 | return false, false
217 | }
218 |
219 | func (n *Null) String() string { return "NULL" }
220 | func (n *Null) Serial(sofar []int) []int { return append(sofar, int(NULL)) }
221 | func (n *Null) StackSerial(sofar []int) []int { return append(sofar, int(NULL)) }
222 | func (n *Null) PrettyPrint(dnames, snames []string, cvals []float64) string { return "NULL" }
223 | func (n *Null) Latex(dnames, snames []string, cvals []float64) string { return "NULL" }
224 | func (n *Null) Javascript(dnames, snames []string, cvals []float64) string { return "null" }
225 |
226 | func (n *Null) Eval(t float64, x, c, s []float64) float64 { return math.NaN() }
227 |
228 | func (n *Null) Simplify(rules SimpRules) Expr { return n }
229 |
230 | func (n *Null) DerivConst(i int) Expr { return &ConstantF{F: 0} }
231 | func (n *Null) DerivVar(i int) Expr { return &ConstantF{F: 0} }
232 |
233 | func DumpExprTypes() {
234 | fmt.Printf("ExprTypes:\n")
235 | fmt.Printf("---------------\n")
236 |
237 | fmt.Printf("NULL: %d\n", int(NULL))
238 |
239 | fmt.Printf("STARTLEAF: %d\n", int(STARTLEAF))
240 | fmt.Printf("CONSTANT: %d\n", int(CONSTANT))
241 | fmt.Printf("TIME: %d\n", int(TIME))
242 | fmt.Printf("SYSTEM: %d\n", int(SYSTEM))
243 | fmt.Printf("VAR: %d\n", int(VAR))
244 | fmt.Printf("LASTLEAF: %d\n", int(LASTLEAF))
245 |
246 | fmt.Printf("STARTFUNC: %d\n", int(STARTFUNC))
247 | fmt.Printf("NEG: %d\n", int(NEG))
248 | fmt.Printf("ABS: %d\n", int(ABS))
249 | fmt.Printf("SQRT: %d\n", int(SQRT))
250 | fmt.Printf("SIN: %d\n", int(SIN))
251 | fmt.Printf("COS: %d\n", int(COS))
252 | fmt.Printf("TAN: %d\n", int(TAN))
253 | fmt.Printf("EXP: %d\n", int(EXP))
254 | fmt.Printf("LOG: %d\n", int(LOG))
255 | fmt.Printf("LASTFUNC: %d\n", int(LASTFUNC))
256 |
257 | fmt.Printf("POWI: %d\n", int(POWI))
258 | fmt.Printf("POWF: %d\n", int(POWF))
259 | fmt.Printf("POWE: %d\n", int(POWE))
260 | fmt.Printf("DIV: %d\n", int(DIV))
261 |
262 | fmt.Printf("ADD: %d\n", int(ADD))
263 | fmt.Printf("MUL: %d\n", int(MUL))
264 |
265 | fmt.Printf("EXPR_MAX: %d\n", int(EXPR_MAX))
266 | fmt.Printf("STARTVAR: %d\n", int(STARTVAR))
267 |
268 | }
269 |
270 | func (e ExprType) String() string {
271 | switch e {
272 | case NULL:
273 | return "NULL"
274 | case STARTLEAF:
275 | return "STARTLEAF"
276 | case CONSTANT:
277 | return "CONSTANT"
278 | case CONSTANTF:
279 | return "CONSTANTF"
280 | case TIME:
281 | return "TIME"
282 | case SYSTEM:
283 | return "SYSTEM"
284 | case VAR:
285 | return "VAR"
286 | case LASTLEAF:
287 | return "LASTLEAF"
288 | case STARTFUNC:
289 | return "STARTFUNC"
290 | case NEG:
291 | return "NEG"
292 | case ABS:
293 | return "ABS"
294 | case SQRT:
295 | return "SQRT"
296 | case SIN:
297 | return "SIN"
298 | case COS:
299 | return "COS"
300 | case TAN:
301 | return "TAN"
302 | case EXP:
303 | return "EXP"
304 | case LOG:
305 | return "LOG"
306 | case LASTFUNC:
307 | return "LASTFUNC"
308 | case POWI:
309 | return "POWI"
310 | case POWF:
311 | return "POWF"
312 | case POWE:
313 | return "POWE"
314 | case DIV:
315 | return "DIV"
316 | case ADD:
317 | return "ADD"
318 | case MUL:
319 | return "MUL"
320 | case EXPR_MAX:
321 | return "EXPR_MAX"
322 | case STARTVAR:
323 | return "STARTVAR"
324 | default:
325 | return "Unknown ExprType"
326 | }
327 | return "Unknown ExprType"
328 | }