Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3138_Shape_Constraints_Transformations/HeuristicLab.ExtLibs/HeuristicLab.NativeInterpreter/0.1/NativeInterpreter-0.1/src/interpreter.h @ 18180

Last change on this file since 18180 was 18180, checked in by dpiringe, 2 years ago

#3138

  • merged trunk into branch
File size: 10.0 KB
Line 
1#ifndef NATIVE_TREE_INTERPRETER_CLANG_H
2#define NATIVE_TREE_INTERPRETER_CLANG_H
3
4#include "vector_operations.h"
5#include "instruction.h"
6
7inline double evaluate(instruction *code, int len, int row) noexcept
8{
9    for (int i = len - 1; i >= 0; --i)
10    {
11        instruction &in = code[i];
12        switch (in.opcode)
13        {
14            case OpCodes::Number: /* nothing to do */ break;
15            case OpCodes::Constant: /* nothing to do */ break;
16            case OpCodes::Var:
17                {
18                    in.value = in.weight * in.data[row];
19                    break;
20                }
21            case OpCodes::Add:
22                {
23                    in.value = code[in.childIndex].value;
24                    for (int j = 1; j < in.narg; ++j)
25                    {
26                        in.value += code[in.childIndex + j].value;
27                    }
28                    break;
29                }
30            case OpCodes::Sub:
31                {
32                    in.value = code[in.childIndex].value;
33                    for (int j = 1; j < in.narg; ++j)
34                    {
35                        in.value -= code[in.childIndex + j].value;
36                    }
37                    if (in.narg == 1)
38                    {
39                        in.value = -in.value;
40                    }
41                    break;
42                }
43            case OpCodes::Mul:
44                {
45                    in.value = code[in.childIndex].value;
46                    for (int j = 1; j < in.narg; ++j)
47                    {
48                        in.value *= code[in.childIndex + j].value;
49                    }
50                    break;
51                }
52            case OpCodes::Div:
53                {
54                    in.value = code[in.childIndex].value;
55                    for (int j = 1; j < in.narg; ++j)
56                    {
57                        in.value /= code[in.childIndex + j].value;
58                    }
59                    if (in.narg == 1)
60                    {
61                        in.value = 1 / in.value;
62                    }
63                    break;
64                }
65            case OpCodes::Exp:
66                {
67                    in.value = hl_exp(code[in.childIndex].value);
68                    break;
69                }
70            case OpCodes::Log:
71                {
72                    in.value = hl_log(code[in.childIndex].value);
73                    break;
74                }
75            case OpCodes::Sin:
76                {
77                    in.value = hl_sin(code[in.childIndex].value);
78                    break;
79                }
80            case OpCodes::Cos:
81                {
82                    in.value = hl_cos(code[in.childIndex].value);
83                    break;
84                }
85            case OpCodes::Tan:
86                {
87                    in.value = hl_tan(code[in.childIndex].value);
88                    break;
89                }
90            case OpCodes::Tanh:
91                {
92                    in.value = hl_tanh(code[in.childIndex].value);
93                    break;
94                }
95            case OpCodes::Power:
96                {
97                    double x = code[in.childIndex].value;
98                    double y = hl_round(code[in.childIndex + 1].value);
99                    in.value = hl_pow(x, y);
100                    break;
101                }
102            case OpCodes::Root:
103                {
104                    double x = code[in.childIndex].value;
105                    double y = hl_round(code[in.childIndex + 1].value);
106                    in.value = hl_pow(x, 1 / y);
107                    break;
108                }
109            case OpCodes::Sqrt:
110                {
111                    in.value = hl_pow(code[in.childIndex].value, 1./2.);
112                    break;
113                }
114            case OpCodes::Square:
115                {
116                    in.value = hl_pow(code[in.childIndex].value, 2.);
117                    break;
118                }
119            case OpCodes::CubeRoot:
120                {
121                    in.value = hl_cbrt(code[in.childIndex].value);
122                    break;
123                }
124            case OpCodes::Cube:
125                {
126                    in.value = hl_pow(code[in.childIndex].value, 3.);
127                    break;
128                }
129            case OpCodes::Absolute:
130                {
131                    in.value = std::fabs(code[in.childIndex].value);
132                    break;
133                }
134            case OpCodes::AnalyticalQuotient:
135                {
136                    double x = code[in.childIndex].value;
137                    double y = code[in.childIndex + 1].value;
138                    in.value = x / hl_sqrt(1 + y*y);
139                    break;
140                }
141            default: in.value = NAN;
142        }
143    }
144    return code[0].value;
145}
146
147inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
148{
149    for (int i = 0; i < batchSize; ++i)
150    {
151        auto row = rows[rowIndex + i];
152        in.buf[i] = in.weight * in.data[row];
153    }
154}
155
156inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
157{
158    for (int i = len - 1; i >= 0; --i)
159    {
160        instruction &in = code[i];
161        switch (in.opcode)
162        {
163            case OpCodes::Var:
164                {
165                    load_data(in, rows, rowIndex, batchSize); // buffer data
166                    break;
167                }
168            case OpCodes::Number: /* nothing to do because buffers for numbers are already set */ break;
169            case OpCodes::Constant: /* nothing to do because buffers for constants are already set */ break;
170            case OpCodes::Add:
171                {
172                    load(in.buf, code[in.childIndex].buf);
173                    for (int j = 1; j < in.narg; ++j)
174                    {
175                        add(in.buf, code[in.childIndex + j].buf);
176                    }
177                    break;
178                }
179            case OpCodes::Sub:
180                {
181                    if (in.narg == 1)
182                    {
183                        neg(in.buf, code[in.childIndex].buf);
184                        break;
185                    }
186                    else
187                    {
188                        load(in.buf, code[in.childIndex].buf);
189                        for (int j = 1; j < in.narg; ++j)
190                        {
191                            sub(in.buf, code[in.childIndex + j].buf);
192                        }
193                    }
194                    break;
195                }
196            case OpCodes::Mul:
197                {
198                    load(in.buf, code[in.childIndex].buf);
199                    for (int j = 1; j < in.narg; ++j)
200                    {
201                        mul(in.buf, code[in.childIndex + j].buf);
202                    }
203                    break;
204                }
205            case OpCodes::Div:
206                {
207                    if (in.narg == 1)
208                    {
209                        inv(in.buf, code[in.childIndex].buf);
210                        break;
211                    }
212                    else
213                    {
214                        load(in.buf, code[in.childIndex].buf);
215                        for (int j = 1; j < in.narg; ++j)
216                        {
217                            div(in.buf, code[in.childIndex + j].buf);
218                        }
219                    }
220                    break;
221                }
222            case OpCodes::Sin:
223                {
224                    sin(in.buf, code[in.childIndex].buf);
225                    break;
226                }
227            case OpCodes::Cos:
228                {
229                    cos(in.buf, code[in.childIndex].buf);
230                    break;
231                }
232            case OpCodes::Tan:
233                {
234                    tan(in.buf, code[in.childIndex].buf);
235                    break;
236                }
237            case OpCodes::Tanh:
238                {
239                    tanh(in.buf, code[in.childIndex].buf);
240                    break;
241                }
242            case OpCodes::Log:
243                {
244                    log(in.buf, code[in.childIndex].buf);
245                    break;
246                }
247            case OpCodes::Exp:
248                {
249                    exp(in.buf, code[in.childIndex].buf);
250                    break;
251                }
252            case OpCodes::Power:
253                {
254                    load(in.buf, code[in.childIndex].buf);
255                    pow(in.buf, code[in.childIndex + 1].buf);
256                    break;
257                }
258            case OpCodes::Root:
259                {
260                    load(in.buf, code[in.childIndex].buf);
261                    root(in.buf, code[in.childIndex + 1].buf);
262                    break;
263                }
264            case OpCodes::Square:
265                {
266                    pow(in.buf, code[in.childIndex].buf, 2.);
267                    break;
268                }
269            case OpCodes::Sqrt:
270                {
271                    pow(in.buf, code[in.childIndex].buf, 1./2.);
272                    break;
273                }
274            case OpCodes::CubeRoot:
275                {
276                    cbrt(in.buf, code[in.childIndex].buf);
277                    break;
278                }
279            case OpCodes::Cube:
280                {
281                    pow(in.buf, code[in.childIndex].buf, 3.);
282                    break;
283                }
284            case OpCodes::Absolute:
285                {
286                    abs(in.buf, code[in.childIndex].buf);
287                    break;
288                }
289            case OpCodes::AnalyticalQuotient:
290                {
291                    load(in.buf, code[in.childIndex].buf);
292                    analytical_quotient(in.buf, code[in.childIndex + 1].buf);
293                    break;
294                }
295            default: load(in.buf, NAN);
296            }
297    }
298}
299
300#endif
Note: See TracBrowser for help on using the repository browser.