Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.ExtLibs/HeuristicLab.NativeInterpreter/0.1/NativeInterpreter-0.1/src/interpreter.h @ 16892

Last change on this file since 16892 was 16892, checked in by gkronber, 5 years ago

#2925 merged r16661:16890 from trunk to branch

File size: 9.8 KB
Line 
1#ifndef NATIVE_TREE_INTERPRETER_CLANG_H
2#define NATIVE_TREE_INTERPRETER_CLANG_H
3
4#include "vector_operations.h"
5#include "instruction.h"
6
7inline double evaluate(instruction *code, int len, int row) noexcept
8{
9    for (int i = len - 1; i >= 0; --i)
10    {
11        instruction &in = code[i];
12        switch (in.opcode)
13        {
14            case OpCodes::Const: /* nothing to do */ break;
15            case OpCodes::Var:
16                {
17                    in.value = in.weight * in.data[row];
18                    break;
19                }
20            case OpCodes::Add:
21                {
22                    in.value = code[in.childIndex].value;
23                    for (int j = 1; j < in.narg; ++j)
24                    {
25                        in.value += code[in.childIndex + j].value;
26                    }
27                    break;
28                }
29            case OpCodes::Sub:
30                {
31                    in.value = code[in.childIndex].value;
32                    for (int j = 1; j < in.narg; ++j)
33                    {
34                        in.value -= code[in.childIndex + j].value;
35                    }
36                    if (in.narg == 1)
37                    {
38                        in.value = -in.value;
39                    }
40                    break;
41                }
42            case OpCodes::Mul:
43                {
44                    in.value = code[in.childIndex].value;
45                    for (int j = 1; j < in.narg; ++j)
46                    {
47                        in.value *= code[in.childIndex + j].value;
48                    }
49                    break;
50                }
51            case OpCodes::Div:
52                {
53                    in.value = code[in.childIndex].value;
54                    for (int j = 1; j < in.narg; ++j)
55                    {
56                        in.value /= code[in.childIndex + j].value;
57                    }
58                    if (in.narg == 1)
59                    {
60                        in.value = 1 / in.value;
61                    }
62                    break;
63                }
64            case OpCodes::Exp:
65                {
66                    in.value = hl_exp(code[in.childIndex].value);
67                    break;
68                }
69            case OpCodes::Log:
70                {
71                    in.value = hl_log(code[in.childIndex].value);
72                    break;
73                }
74            case OpCodes::Sin:
75                {
76                    in.value = hl_sin(code[in.childIndex].value);
77                    break;
78                }
79            case OpCodes::Cos:
80                {
81                    in.value = hl_cos(code[in.childIndex].value);
82                    break;
83                }
84            case OpCodes::Tan:
85                {
86                    in.value = hl_tan(code[in.childIndex].value);
87                    break;
88                }
89            case OpCodes::Tanh:
90                {
91                    in.value = hl_tanh(code[in.childIndex].value);
92                    break;
93                }
94            case OpCodes::Power:
95                {
96                    double x = code[in.childIndex].value;
97                    double y = hl_round(code[in.childIndex + 1].value);
98                    in.value = hl_pow(x, y);
99                    break;
100                }
101            case OpCodes::Root:
102                {
103                    double x = code[in.childIndex].value;
104                    double y = hl_round(code[in.childIndex + 1].value);
105                    in.value = hl_pow(x, 1 / y);
106                    break;
107                }
108            case OpCodes::Sqrt:
109                {
110                    in.value = hl_pow(code[in.childIndex].value, 1./2.);
111                    break;
112                }
113            case OpCodes::Square:
114                {
115                    in.value = hl_pow(code[in.childIndex].value, 2.);
116                    break;
117                }
118            case OpCodes::CubeRoot:
119                {
120                    in.value = hl_pow(code[in.childIndex].value, 1./3.);
121                    break;
122                }
123            case OpCodes::Cube:
124                {
125                    in.value = hl_pow(code[in.childIndex].value, 3.);
126                    break;
127                }
128            case OpCodes::Absolute:
129                {
130                    in.value = std::fabs(code[in.childIndex].value);
131                    break;
132                }
133            case OpCodes::AnalyticalQuotient:
134                {
135                    double x = code[in.childIndex].value;
136                    double y = code[in.childIndex + 1].value;
137                    in.value = x / hl_sqrt(1 + y*y);
138                    break;
139                }
140            default: in.value = NAN;
141        }
142    }
143    return code[0].value;
144}
145
146inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
147{
148    for (int i = 0; i < batchSize; ++i)
149    {
150        auto row = rows[rowIndex + i];
151        in.buf[i] = in.weight * in.data[row];
152    }
153}
154
155inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
156{
157    for (int i = len - 1; i >= 0; --i)
158    {
159        instruction &in = code[i];
160        switch (in.opcode)
161        {
162            case OpCodes::Var:
163                {
164                    load_data(in, rows, rowIndex, batchSize); // buffer data
165                    break;
166                }
167            case OpCodes::Const: /* nothing to do because buffers for constants are already set */ break;
168            case OpCodes::Add:
169                {
170                    load(in.buf, code[in.childIndex].buf);
171                    for (int j = 1; j < in.narg; ++j)
172                    {
173                        add(in.buf, code[in.childIndex + j].buf);
174                    }
175                    break;
176                }
177            case OpCodes::Sub:
178                {
179                    if (in.narg == 1)
180                    {
181                        neg(in.buf, code[in.childIndex].buf);
182                        break;
183                    }
184                    else
185                    {
186                        load(in.buf, code[in.childIndex].buf);
187                        for (int j = 1; j < in.narg; ++j)
188                        {
189                            sub(in.buf, code[in.childIndex + j].buf);
190                        }
191                    }
192                    break;
193                }
194            case OpCodes::Mul:
195                {
196                    load(in.buf, code[in.childIndex].buf);
197                    for (int j = 1; j < in.narg; ++j)
198                    {
199                        mul(in.buf, code[in.childIndex + j].buf);
200                    }
201                    break;
202                }
203            case OpCodes::Div:
204                {
205                    if (in.narg == 1)
206                    {
207                        inv(in.buf, code[in.childIndex].buf);
208                        break;
209                    }
210                    else
211                    {
212                        load(in.buf, code[in.childIndex].buf);
213                        for (int j = 1; j < in.narg; ++j)
214                        {
215                            div(in.buf, code[in.childIndex + j].buf);
216                        }
217                    }
218                    break;
219                }
220            case OpCodes::Sin:
221                {
222                    sin(in.buf, code[in.childIndex].buf);
223                    break;
224                }
225            case OpCodes::Cos:
226                {
227                    cos(in.buf, code[in.childIndex].buf);
228                    break;
229                }
230            case OpCodes::Tan:
231                {
232                    tan(in.buf, code[in.childIndex].buf);
233                    break;
234                }
235            case OpCodes::Tanh:
236                {
237                    tanh(in.buf, code[in.childIndex].buf);
238                    break;
239                }
240            case OpCodes::Log:
241                {
242                    log(in.buf, code[in.childIndex].buf);
243                    break;
244                }
245            case OpCodes::Exp:
246                {
247                    exp(in.buf, code[in.childIndex].buf);
248                    break;
249                }
250            case OpCodes::Power:
251                {
252                    load(in.buf, code[in.childIndex].buf);
253                    pow(in.buf, code[in.childIndex + 1].buf);
254                    break;
255                }
256            case OpCodes::Root:
257                {
258                    load(in.buf, code[in.childIndex].buf);
259                    root(in.buf, code[in.childIndex + 1].buf);
260                    break;
261                }
262            case OpCodes::Square:
263                {
264                    pow(in.buf, code[in.childIndex].buf, 2.);
265                    break;
266                }
267            case OpCodes::Sqrt:
268                {
269                    pow(in.buf, code[in.childIndex].buf, 1./2.);
270                    break;
271                }
272            case OpCodes::CubeRoot:
273                {
274                    pow(in.buf, code[in.childIndex].buf, 1./3.);
275                    break;
276                }
277            case OpCodes::Cube:
278                {
279                    pow(in.buf, code[in.childIndex].buf, 3.);
280                    break;
281                }
282            case OpCodes::Absolute:
283                {
284                    abs(in.buf, code[in.childIndex].buf);
285                    break;
286                }
287            case OpCodes::AnalyticalQuotient:
288                {
289                    load(in.buf, code[in.childIndex].buf);
290                    analytical_quotient(in.buf, code[in.childIndex + 1].buf);
291                    break;
292                }
293            default: load(in.buf, NAN);
294            }
295    }
296}
297
298#endif
Note: See TracBrowser for help on using the repository browser.