Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.ExtLibs/HeuristicLab.NativeInterpreter/0.1/NativeInterpreter-0.1/src/interpreter.h @ 16683

Last change on this file since 16683 was 16356, checked in by gkronber, 6 years ago

#2915: merged all changes from branch to trunk

File size: 9.5 KB
RevLine 
[16269]1#ifndef NATIVE_TREE_INTERPRETER_CLANG_H
2#define NATIVE_TREE_INTERPRETER_CLANG_H
3
[16274]4#include "vector_operations.h"
[16269]5#include "instruction.h"
6
7inline double evaluate(instruction *code, int len, int row) noexcept
8{
9    for (int i = len - 1; i >= 0; --i)
10    {
11        instruction &in = code[i];
12        switch (in.opcode)
13        {
[16356]14            case OpCodes::Const: /* nothing to do */ break;
[16269]15            case OpCodes::Var:
16                {
17                    in.value = in.weight * in.data[row];
18                    break;
19                }
20            case OpCodes::Add:
21                {
22                    in.value = code[in.childIndex].value;
23                    for (int j = 1; j < in.narg; ++j)
24                    {
25                        in.value += code[in.childIndex + j].value;
26                    }
27                    break;
28                }
29            case OpCodes::Sub:
30                {
31                    in.value = code[in.childIndex].value;
32                    for (int j = 1; j < in.narg; ++j)
33                    {
34                        in.value -= code[in.childIndex + j].value;
35                    }
36                    if (in.narg == 1)
37                    {
38                        in.value = -in.value;
39                    }
40                    break;
41                }
42            case OpCodes::Mul:
43                {
44                    in.value = code[in.childIndex].value;
45                    for (int j = 1; j < in.narg; ++j)
46                    {
47                        in.value *= code[in.childIndex + j].value;
48                    }
49                    break;
50                }
51            case OpCodes::Div:
52                {
53                    in.value = code[in.childIndex].value;
54                    for (int j = 1; j < in.narg; ++j)
55                    {
56                        in.value /= code[in.childIndex + j].value;
57                    }
58                    if (in.narg == 1)
59                    {
60                        in.value = 1 / in.value;
61                    }
62                    break;
63                }
64            case OpCodes::Exp:
65                {
66                    in.value = std::exp(code[in.childIndex].value);
67                    break;
68                }
69            case OpCodes::Log:
70                {
71                    in.value = std::log(code[in.childIndex].value);
72                    break;
73                }
74            case OpCodes::Sin:
75                {
76                    in.value = std::sin(code[in.childIndex].value);
77                    break;
78                }
79            case OpCodes::Cos:
80                {
81                    in.value = std::cos(code[in.childIndex].value);
82                    break;
83                }
84            case OpCodes::Tan:
85                {
86                    in.value = std::tan(code[in.childIndex].value);
87                    break;
88                }
89            case OpCodes::Power:
90                {
91                    double x = code[in.childIndex].value;
92                    double y = std::round(code[in.childIndex + 1].value);
93                    in.value = std::pow(x, y);
94                    break;
95                }
96            case OpCodes::Root:
97                {
98                    double x = code[in.childIndex].value;
99                    double y = std::round(code[in.childIndex + 1].value);
100                    in.value = std::pow(x, 1 / y);
101                    break;
102                }
[16356]103            case OpCodes::Sqrt:
104                {
105                    in.value = std::pow(code[in.childIndex].value, 1./2.);
106                    break;
107                }
[16334]108            case OpCodes::Square:
109                {
110                    in.value = std::pow(code[in.childIndex].value, 2.);
111                    break;
112                }
[16356]113            case OpCodes::CubeRoot:
[16269]114                {
[16356]115                    in.value = std::pow(code[in.childIndex].value, 1./3.);
[16269]116                    break;
117                }
[16356]118            case OpCodes::Cube:
119                {
120                    in.value = std::pow(code[in.childIndex].value, 3.);
121                    break;
122                }
123            case OpCodes::Absolute:
124                {
125                    in.value = std::fabs(code[in.childIndex].value);
126                    break;
127                }
128            case OpCodes::AnalyticalQuotient:
129                {
130                    double x = code[in.childIndex].value;
131                    double y = code[in.childIndex + 1].value;
132                    in.value = x / std::sqrt(1 + y*y);
133                    break;
134                }
135            default: in.value = NAN;
[16269]136        }
137    }
138    return code[0].value;
139}
140
141inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
142{
143    for (int i = 0; i < batchSize; ++i)
144    {
145        auto row = rows[rowIndex + i];
146        in.buf[i] = in.weight * in.data[row];
147    }
148}
149
150inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
151{
152    for (int i = len - 1; i >= 0; --i)
153    {
154        instruction &in = code[i];
155        switch (in.opcode)
156        {
157            case OpCodes::Var:
158                {
159                    load_data(in, rows, rowIndex, batchSize); // buffer data
160                    break;
161                }
[16356]162            case OpCodes::Const: /* nothing to do because buffers for constants are already set */ break;
[16269]163            case OpCodes::Add:
164                {
165                    load(in.buf, code[in.childIndex].buf);
166                    for (int j = 1; j < in.narg; ++j)
167                    {
168                        add(in.buf, code[in.childIndex + j].buf);
169                    }
170                    break;
171                }
172            case OpCodes::Sub:
173                {
[16274]174                    if (in.narg == 1)
[16269]175                    {
[16274]176                        neg(in.buf, code[in.childIndex].buf);
177                        break;
[16269]178                    }
[16274]179                    else
[16269]180                    {
[16274]181                        load(in.buf, code[in.childIndex].buf);
182                        for (int j = 1; j < in.narg; ++j)
183                        {
184                            sub(in.buf, code[in.childIndex + j].buf);
185                        }
[16269]186                    }
187                    break;
188                }
189            case OpCodes::Mul:
190                {
191                    load(in.buf, code[in.childIndex].buf);
192                    for (int j = 1; j < in.narg; ++j)
193                    {
194                        mul(in.buf, code[in.childIndex + j].buf);
195                    }
196                    break;
197                }
198            case OpCodes::Div:
199                {
[16274]200                    if (in.narg == 1)
[16269]201                    {
[16274]202                        inv(in.buf, code[in.childIndex].buf);
203                        break;
[16269]204                    }
[16274]205                    else
[16269]206                    {
[16274]207                        load(in.buf, code[in.childIndex].buf);
208                        for (int j = 1; j < in.narg; ++j)
209                        {
210                            div(in.buf, code[in.childIndex + j].buf);
211                        }
[16269]212                    }
213                    break;
214                }
215            case OpCodes::Sin:
216                {
217                    sin(in.buf, code[in.childIndex].buf);
218                    break;
219                }
220            case OpCodes::Cos:
221                {
222                    cos(in.buf, code[in.childIndex].buf);
223                    break;
224                }
225            case OpCodes::Tan:
226                {
227                    tan(in.buf, code[in.childIndex].buf);
228                    break;
229                }
230            case OpCodes::Log:
231                {
232                    log(in.buf, code[in.childIndex].buf);
233                    break;
234                }
235            case OpCodes::Exp:
236                {
237                    exp(in.buf, code[in.childIndex].buf);
238                    break;
239                }
240            case OpCodes::Power:
241                {
242                    load(in.buf, code[in.childIndex].buf);
243                    pow(in.buf, code[in.childIndex + 1].buf);
244                    break;
245                }
246            case OpCodes::Root:
247                {
248                    load(in.buf, code[in.childIndex].buf);
249                    root(in.buf, code[in.childIndex + 1].buf);
250                    break;
251                }
252            case OpCodes::Square:
253                {
[16356]254                    pow(in.buf, code[in.childIndex].buf, 2.);
[16269]255                    break;
256                }
[16334]257            case OpCodes::Sqrt:
258                {
[16356]259                    pow(in.buf, code[in.childIndex].buf, 1./2.);
[16334]260                    break;
261                }
[16356]262            case OpCodes::CubeRoot:
263                {
264                    pow(in.buf, code[in.childIndex].buf, 1./3.);
265                    break;
266                }
267            case OpCodes::Cube:
268                {
269                    pow(in.buf, code[in.childIndex].buf, 3.);
270                    break;
271                }
272            case OpCodes::Absolute:
273                {
274                    abs(in.buf, code[in.childIndex].buf);
275                    break;
276                }
277            case OpCodes::AnalyticalQuotient:
278                {
279                    load(in.buf, code[in.childIndex].buf);
280                    analytical_quotient(in.buf, code[in.childIndex + 1].buf);
281                    break;
282                }
283            default: load(in.buf, NAN);
284            }
[16269]285    }
286}
287
288#endif
Note: See TracBrowser for help on using the repository browser.