Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.ExtLibs/HeuristicLab.NativeInterpreter/0.1/NativeInterpreter-0.1/src/interpreter.h @ 17800

Last change on this file since 17800 was 16905, checked in by gkronber, 6 years ago

#2915: Corrected calculation of cuberoot function in all interpreters and updated all formatters.
I tested a run with all interpreters and got the same results with all of them.

File size: 9.8 KB
RevLine 
[16269]1#ifndef NATIVE_TREE_INTERPRETER_CLANG_H
2#define NATIVE_TREE_INTERPRETER_CLANG_H
3
[16274]4#include "vector_operations.h"
[16269]5#include "instruction.h"
6
7inline double evaluate(instruction *code, int len, int row) noexcept
8{
9    for (int i = len - 1; i >= 0; --i)
10    {
11        instruction &in = code[i];
12        switch (in.opcode)
13        {
[16356]14            case OpCodes::Const: /* nothing to do */ break;
[16269]15            case OpCodes::Var:
16                {
17                    in.value = in.weight * in.data[row];
18                    break;
19                }
20            case OpCodes::Add:
21                {
22                    in.value = code[in.childIndex].value;
23                    for (int j = 1; j < in.narg; ++j)
24                    {
25                        in.value += code[in.childIndex + j].value;
26                    }
27                    break;
28                }
29            case OpCodes::Sub:
30                {
31                    in.value = code[in.childIndex].value;
32                    for (int j = 1; j < in.narg; ++j)
33                    {
34                        in.value -= code[in.childIndex + j].value;
35                    }
36                    if (in.narg == 1)
37                    {
38                        in.value = -in.value;
39                    }
40                    break;
41                }
42            case OpCodes::Mul:
43                {
44                    in.value = code[in.childIndex].value;
45                    for (int j = 1; j < in.narg; ++j)
46                    {
47                        in.value *= code[in.childIndex + j].value;
48                    }
49                    break;
50                }
51            case OpCodes::Div:
52                {
53                    in.value = code[in.childIndex].value;
54                    for (int j = 1; j < in.narg; ++j)
55                    {
56                        in.value /= code[in.childIndex + j].value;
57                    }
58                    if (in.narg == 1)
59                    {
60                        in.value = 1 / in.value;
61                    }
62                    break;
63                }
64            case OpCodes::Exp:
65                {
[16701]66                    in.value = hl_exp(code[in.childIndex].value);
[16269]67                    break;
68                }
69            case OpCodes::Log:
70                {
[16701]71                    in.value = hl_log(code[in.childIndex].value);
[16269]72                    break;
73                }
74            case OpCodes::Sin:
75                {
[16701]76                    in.value = hl_sin(code[in.childIndex].value);
[16269]77                    break;
78                }
79            case OpCodes::Cos:
80                {
[16701]81                    in.value = hl_cos(code[in.childIndex].value);
[16269]82                    break;
83                }
84            case OpCodes::Tan:
85                {
[16701]86                    in.value = hl_tan(code[in.childIndex].value);
[16269]87                    break;
88                }
[16701]89            case OpCodes::Tanh:
90                {
91                    in.value = hl_tanh(code[in.childIndex].value);
92                    break;
93                }
[16269]94            case OpCodes::Power:
95                {
96                    double x = code[in.childIndex].value;
[16701]97                    double y = hl_round(code[in.childIndex + 1].value);
98                    in.value = hl_pow(x, y);
[16269]99                    break;
100                }
101            case OpCodes::Root:
102                {
103                    double x = code[in.childIndex].value;
[16701]104                    double y = hl_round(code[in.childIndex + 1].value);
105                    in.value = hl_pow(x, 1 / y);
[16269]106                    break;
107                }
[16356]108            case OpCodes::Sqrt:
109                {
[16701]110                    in.value = hl_pow(code[in.childIndex].value, 1./2.);
[16356]111                    break;
112                }
[16334]113            case OpCodes::Square:
114                {
[16701]115                    in.value = hl_pow(code[in.childIndex].value, 2.);
[16334]116                    break;
117                }
[16356]118            case OpCodes::CubeRoot:
[16269]119                {
[16905]120                    in.value = hl_cbrt(code[in.childIndex].value);
[16269]121                    break;
122                }
[16356]123            case OpCodes::Cube:
124                {
[16701]125                    in.value = hl_pow(code[in.childIndex].value, 3.);
[16356]126                    break;
127                }
128            case OpCodes::Absolute:
129                {
130                    in.value = std::fabs(code[in.childIndex].value);
131                    break;
132                }
133            case OpCodes::AnalyticalQuotient:
134                {
135                    double x = code[in.childIndex].value;
136                    double y = code[in.childIndex + 1].value;
[16701]137                    in.value = x / hl_sqrt(1 + y*y);
[16356]138                    break;
139                }
140            default: in.value = NAN;
[16269]141        }
142    }
143    return code[0].value;
144}
145
146inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
147{
148    for (int i = 0; i < batchSize; ++i)
149    {
150        auto row = rows[rowIndex + i];
151        in.buf[i] = in.weight * in.data[row];
152    }
153}
154
155inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
156{
157    for (int i = len - 1; i >= 0; --i)
158    {
159        instruction &in = code[i];
160        switch (in.opcode)
161        {
162            case OpCodes::Var:
163                {
164                    load_data(in, rows, rowIndex, batchSize); // buffer data
165                    break;
166                }
[16356]167            case OpCodes::Const: /* nothing to do because buffers for constants are already set */ break;
[16269]168            case OpCodes::Add:
169                {
170                    load(in.buf, code[in.childIndex].buf);
171                    for (int j = 1; j < in.narg; ++j)
172                    {
173                        add(in.buf, code[in.childIndex + j].buf);
174                    }
175                    break;
176                }
177            case OpCodes::Sub:
178                {
[16274]179                    if (in.narg == 1)
[16269]180                    {
[16274]181                        neg(in.buf, code[in.childIndex].buf);
182                        break;
[16269]183                    }
[16274]184                    else
[16269]185                    {
[16274]186                        load(in.buf, code[in.childIndex].buf);
187                        for (int j = 1; j < in.narg; ++j)
188                        {
189                            sub(in.buf, code[in.childIndex + j].buf);
190                        }
[16269]191                    }
192                    break;
193                }
194            case OpCodes::Mul:
195                {
196                    load(in.buf, code[in.childIndex].buf);
197                    for (int j = 1; j < in.narg; ++j)
198                    {
199                        mul(in.buf, code[in.childIndex + j].buf);
200                    }
201                    break;
202                }
203            case OpCodes::Div:
204                {
[16274]205                    if (in.narg == 1)
[16269]206                    {
[16274]207                        inv(in.buf, code[in.childIndex].buf);
208                        break;
[16269]209                    }
[16274]210                    else
[16269]211                    {
[16274]212                        load(in.buf, code[in.childIndex].buf);
213                        for (int j = 1; j < in.narg; ++j)
214                        {
215                            div(in.buf, code[in.childIndex + j].buf);
216                        }
[16269]217                    }
218                    break;
219                }
220            case OpCodes::Sin:
221                {
222                    sin(in.buf, code[in.childIndex].buf);
223                    break;
224                }
225            case OpCodes::Cos:
226                {
227                    cos(in.buf, code[in.childIndex].buf);
228                    break;
229                }
230            case OpCodes::Tan:
231                {
232                    tan(in.buf, code[in.childIndex].buf);
233                    break;
234                }
[16701]235            case OpCodes::Tanh:
236                {
237                    tanh(in.buf, code[in.childIndex].buf);
238                    break;
239                }
[16269]240            case OpCodes::Log:
241                {
242                    log(in.buf, code[in.childIndex].buf);
243                    break;
244                }
245            case OpCodes::Exp:
246                {
247                    exp(in.buf, code[in.childIndex].buf);
248                    break;
249                }
250            case OpCodes::Power:
251                {
252                    load(in.buf, code[in.childIndex].buf);
253                    pow(in.buf, code[in.childIndex + 1].buf);
254                    break;
255                }
256            case OpCodes::Root:
257                {
258                    load(in.buf, code[in.childIndex].buf);
259                    root(in.buf, code[in.childIndex + 1].buf);
260                    break;
261                }
262            case OpCodes::Square:
263                {
[16356]264                    pow(in.buf, code[in.childIndex].buf, 2.);
[16269]265                    break;
266                }
[16334]267            case OpCodes::Sqrt:
268                {
[16356]269                    pow(in.buf, code[in.childIndex].buf, 1./2.);
[16334]270                    break;
271                }
[16356]272            case OpCodes::CubeRoot:
273                {
[16905]274                    cbrt(in.buf, code[in.childIndex].buf);
[16356]275                    break;
276                }
277            case OpCodes::Cube:
278                {
279                    pow(in.buf, code[in.childIndex].buf, 3.);
280                    break;
281                }
282            case OpCodes::Absolute:
283                {
284                    abs(in.buf, code[in.childIndex].buf);
285                    break;
286                }
287            case OpCodes::AnalyticalQuotient:
288                {
289                    load(in.buf, code[in.childIndex].buf);
290                    analytical_quotient(in.buf, code[in.childIndex + 1].buf);
291                    break;
292                }
293            default: load(in.buf, NAN);
294            }
[16269]295    }
296}
297
298#endif
Note: See TracBrowser for help on using the repository browser.