Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2915-AbsoluteSymbol/HeuristicLab.ExtLibs/HeuristicLab.NativeInterpreter/0.1/NativeInterpreter-0.1/src/interpreter.h @ 16350

Last change on this file since 16350 was 16350, checked in by gkronber, 5 years ago

#2915: merged r16333:16343 from trunk to branch (resolving a conflict in interpreter.h - the branch for sqrt)

File size: 9.7 KB
RevLine 
[16269]1#ifndef NATIVE_TREE_INTERPRETER_CLANG_H
2#define NATIVE_TREE_INTERPRETER_CLANG_H
3
[16274]4#include "vector_operations.h"
[16269]5#include "instruction.h"
6
7inline double evaluate(instruction *code, int len, int row) noexcept
8{
9    for (int i = len - 1; i >= 0; --i)
10    {
11        instruction &in = code[i];
12        switch (in.opcode)
13        {
[16349]14            case OpCodes::Const: /* nothing to do */ break;
[16269]15            case OpCodes::Var:
16                {
17                    in.value = in.weight * in.data[row];
18                    break;
19                }
20            case OpCodes::Add:
21                {
22                    in.value = code[in.childIndex].value;
23                    for (int j = 1; j < in.narg; ++j)
24                    {
25                        in.value += code[in.childIndex + j].value;
26                    }
27                    break;
28                }
29            case OpCodes::Sub:
30                {
31                    in.value = code[in.childIndex].value;
32                    for (int j = 1; j < in.narg; ++j)
33                    {
34                        in.value -= code[in.childIndex + j].value;
35                    }
36                    if (in.narg == 1)
37                    {
38                        in.value = -in.value;
39                    }
40                    break;
41                }
42            case OpCodes::Mul:
43                {
44                    in.value = code[in.childIndex].value;
45                    for (int j = 1; j < in.narg; ++j)
46                    {
47                        in.value *= code[in.childIndex + j].value;
48                    }
49                    break;
50                }
51            case OpCodes::Div:
52                {
53                    in.value = code[in.childIndex].value;
54                    for (int j = 1; j < in.narg; ++j)
55                    {
56                        in.value /= code[in.childIndex + j].value;
57                    }
58                    if (in.narg == 1)
59                    {
60                        in.value = 1 / in.value;
61                    }
62                    break;
63                }
64            case OpCodes::Exp:
65                {
66                    in.value = std::exp(code[in.childIndex].value);
67                    break;
68                }
69            case OpCodes::Log:
70                {
71                    in.value = std::log(code[in.childIndex].value);
72                    break;
73                }
74            case OpCodes::Sin:
75                {
76                    in.value = std::sin(code[in.childIndex].value);
77                    break;
78                }
79            case OpCodes::Cos:
80                {
81                    in.value = std::cos(code[in.childIndex].value);
82                    break;
83                }
84            case OpCodes::Tan:
85                {
86                    in.value = std::tan(code[in.childIndex].value);
87                    break;
88                }
89            case OpCodes::Power:
90                {
91                    double x = code[in.childIndex].value;
92                    double y = std::round(code[in.childIndex + 1].value);
93                    in.value = std::pow(x, y);
94                    break;
95                }
96            case OpCodes::Root:
97                {
98                    double x = code[in.childIndex].value;
99                    double y = std::round(code[in.childIndex + 1].value);
100                    in.value = std::pow(x, 1 / y);
101                    break;
102                }
[16350]103            case OpCodes::Square:
104                {
105                    in.value = std::pow(code[in.childIndex].value, 2.);
106                    break;
107                }
[16269]108            case OpCodes::Sqrt:
109                {
110                    in.value = std::sqrt(code[in.childIndex].value);
111                    break;
112                }
[16349]113            case OpCodes::Square:
114                {
115                    in.value = std::pow(code[in.childIndex].value, 2.);
116                    break;
117                }
118            case OpCodes::CubeRoot:
119                {
120                    in.value = std::pow(code[in.childIndex].value, 1./3.);
121                    break;
122                }
123            case OpCodes::Cube:
124                {
125                    in.value = std::pow(code[in.childIndex].value, 3.);
126                    break;
127                }
128            case OpCodes::Absolute:
129                {
130                    in.value = std::abs(code[in.childIndex].value);
131                    break;
132                }
133            case OpCodes::AnalyticalQuotient:
134                {
135                    double x = code[in.childIndex].value;
136                    double y = code[in.childIndex + 1].value;
137                    in.value = x / std::sqrt(1 + y*y);
138                    break;
139                }
140            default: in.value = NAN;
[16269]141        }
142    }
143    return code[0].value;
144}
145
146inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
147{
148    for (int i = 0; i < batchSize; ++i)
149    {
150        auto row = rows[rowIndex + i];
151        in.buf[i] = in.weight * in.data[row];
152    }
153}
154
155inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
156{
157    for (int i = len - 1; i >= 0; --i)
158    {
159        instruction &in = code[i];
160        switch (in.opcode)
161        {
162            case OpCodes::Var:
163                {
164                    load_data(in, rows, rowIndex, batchSize); // buffer data
165                    break;
166                }
[16349]167            case OpCodes::Const: /* nothing to do because buffers for constants are already set */ break;
[16269]168            case OpCodes::Add:
169                {
170                    load(in.buf, code[in.childIndex].buf);
171                    for (int j = 1; j < in.narg; ++j)
172                    {
173                        add(in.buf, code[in.childIndex + j].buf);
174                    }
175                    break;
176                }
177            case OpCodes::Sub:
178                {
[16274]179                    if (in.narg == 1)
[16269]180                    {
[16274]181                        neg(in.buf, code[in.childIndex].buf);
182                        break;
[16269]183                    }
[16274]184                    else
[16269]185                    {
[16274]186                        load(in.buf, code[in.childIndex].buf);
187                        for (int j = 1; j < in.narg; ++j)
188                        {
189                            sub(in.buf, code[in.childIndex + j].buf);
190                        }
[16269]191                    }
192                    break;
193                }
194            case OpCodes::Mul:
195                {
196                    load(in.buf, code[in.childIndex].buf);
197                    for (int j = 1; j < in.narg; ++j)
198                    {
199                        mul(in.buf, code[in.childIndex + j].buf);
200                    }
201                    break;
202                }
203            case OpCodes::Div:
204                {
[16274]205                    if (in.narg == 1)
[16269]206                    {
[16274]207                        inv(in.buf, code[in.childIndex].buf);
208                        break;
[16269]209                    }
[16274]210                    else
[16269]211                    {
[16274]212                        load(in.buf, code[in.childIndex].buf);
213                        for (int j = 1; j < in.narg; ++j)
214                        {
215                            div(in.buf, code[in.childIndex + j].buf);
216                        }
[16269]217                    }
218                    break;
219                }
220            case OpCodes::Sin:
221                {
222                    sin(in.buf, code[in.childIndex].buf);
223                    break;
224                }
225            case OpCodes::Cos:
226                {
227                    cos(in.buf, code[in.childIndex].buf);
228                    break;
229                }
230            case OpCodes::Tan:
231                {
232                    tan(in.buf, code[in.childIndex].buf);
233                    break;
234                }
235            case OpCodes::Log:
236                {
237                    log(in.buf, code[in.childIndex].buf);
238                    break;
239                }
240            case OpCodes::Exp:
241                {
242                    exp(in.buf, code[in.childIndex].buf);
243                    break;
244                }
245            case OpCodes::Power:
246                {
247                    load(in.buf, code[in.childIndex].buf);
248                    pow(in.buf, code[in.childIndex + 1].buf);
249                    break;
250                }
251            case OpCodes::Root:
252                {
253                    load(in.buf, code[in.childIndex].buf);
254                    root(in.buf, code[in.childIndex + 1].buf);
255                    break;
256                }
257            case OpCodes::Square:
258                {
[16349]259                    pow(in.buf, code[in.childIndex].buf, 2.);
[16269]260                    break;
261                }
[16349]262            case OpCodes::Sqrt:
263                {
[16350]264                    sqrt(in.buf, code[in.childIndex].buf);
[16349]265                    break;
266                }
267            case OpCodes::CubeRoot:
268                {
269                    pow(in.buf, code[in.childIndex].buf, 1./3.);
270                    break;
271                }
272            case OpCodes::Cube:
273                {
274                    pow(in.buf, code[in.childIndex].buf, 3.);
275                    break;
276                }
277            case OpCodes::Absolute:
278                {
279                    abs(in.buf, code[in.childIndex].buf);
280                    break;
281                }
282            case OpCodes::AnalyticalQuotient:
283                {
284                    load(in.buf, code[in.childIndex].buf);
285                    analytical_quotient(in.buf, code[in.childIndex + 1].buf);
286                    break;
287                }
288            default: load(in.buf, NAN);
289            }
[16269]290    }
291}
292
293#endif
Note: See TracBrowser for help on using the repository browser.