Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.ExtLibs/HeuristicLab.NativeInterpreter/0.1/NativeInterpreter-0.1/src/interpreter.h @ 17071

Last change on this file since 17071 was 17071, checked in by mkommend, 5 years ago

#2958: Merged 16266, 16269, 16274, 16276, 16277, 16285, 16286, 16287, 16289, 16293, 16296, 16297, 16298, 16333, 16334 into stable.

File size: 7.8 KB
Line 
1#ifndef NATIVE_TREE_INTERPRETER_CLANG_H
2#define NATIVE_TREE_INTERPRETER_CLANG_H
3
4#include "vector_operations.h"
5#include "instruction.h"
6
7inline double evaluate(instruction *code, int len, int row) noexcept
8{
9    for (int i = len - 1; i >= 0; --i)
10    {
11        instruction &in = code[i];
12        switch (in.opcode)
13        {
14            case OpCodes::Var:
15                {
16                    in.value = in.weight * in.data[row];
17                    break;
18                }
19            case OpCodes::Add:
20                {
21                    in.value = code[in.childIndex].value;
22                    for (int j = 1; j < in.narg; ++j)
23                    {
24                        in.value += code[in.childIndex + j].value;
25                    }
26                    break;
27                }
28            case OpCodes::Sub:
29                {
30                    in.value = code[in.childIndex].value;
31                    for (int j = 1; j < in.narg; ++j)
32                    {
33                        in.value -= code[in.childIndex + j].value;
34                    }
35                    if (in.narg == 1)
36                    {
37                        in.value = -in.value;
38                    }
39                    break;
40                }
41            case OpCodes::Mul:
42                {
43                    in.value = code[in.childIndex].value;
44                    for (int j = 1; j < in.narg; ++j)
45                    {
46                        in.value *= code[in.childIndex + j].value;
47                    }
48                    break;
49                }
50            case OpCodes::Div:
51                {
52                    in.value = code[in.childIndex].value;
53                    for (int j = 1; j < in.narg; ++j)
54                    {
55                        in.value /= code[in.childIndex + j].value;
56                    }
57                    if (in.narg == 1)
58                    {
59                        in.value = 1 / in.value;
60                    }
61                    break;
62                }
63            case OpCodes::Exp:
64                {
65                    in.value = std::exp(code[in.childIndex].value);
66                    break;
67                }
68            case OpCodes::Log:
69                {
70                    in.value = std::log(code[in.childIndex].value);
71                    break;
72                }
73            case OpCodes::Sin:
74                {
75                    in.value = std::sin(code[in.childIndex].value);
76                    break;
77                }
78            case OpCodes::Cos:
79                {
80                    in.value = std::cos(code[in.childIndex].value);
81                    break;
82                }
83            case OpCodes::Tan:
84                {
85                    in.value = std::tan(code[in.childIndex].value);
86                    break;
87                }
88            case OpCodes::Power:
89                {
90                    double x = code[in.childIndex].value;
91                    double y = std::round(code[in.childIndex + 1].value);
92                    in.value = std::pow(x, y);
93                    break;
94                }
95            case OpCodes::Root:
96                {
97                    double x = code[in.childIndex].value;
98                    double y = std::round(code[in.childIndex + 1].value);
99                    in.value = std::pow(x, 1 / y);
100                    break;
101                }
102            case OpCodes::Square:
103                {
104                    in.value = std::pow(code[in.childIndex].value, 2.);
105                    break;
106                }
107            case OpCodes::Sqrt:
108                {
109                    in.value = std::sqrt(code[in.childIndex].value);
110                    break;
111                }
112        }
113    }
114    return code[0].value;
115}
116
117inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
118{
119    for (int i = 0; i < batchSize; ++i)
120    {
121        auto row = rows[rowIndex + i];
122        in.buf[i] = in.weight * in.data[row];
123    }
124}
125
126inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
127{
128    for (int i = len - 1; i >= 0; --i)
129    {
130        instruction &in = code[i];
131        switch (in.opcode)
132        {
133            case OpCodes::Var:
134                {
135                    load_data(in, rows, rowIndex, batchSize); // buffer data
136                    break;
137                }
138            case OpCodes::Add:
139                {
140                    load(in.buf, code[in.childIndex].buf);
141                    for (int j = 1; j < in.narg; ++j)
142                    {
143                        add(in.buf, code[in.childIndex + j].buf);
144                    }
145                    break;
146                }
147            case OpCodes::Sub:
148                {
149                    if (in.narg == 1)
150                    {
151                        neg(in.buf, code[in.childIndex].buf);
152                        break;
153                    }
154                    else
155                    {
156                        load(in.buf, code[in.childIndex].buf);
157                        for (int j = 1; j < in.narg; ++j)
158                        {
159                            sub(in.buf, code[in.childIndex + j].buf);
160                        }
161                    }
162                    break;
163                }
164            case OpCodes::Mul:
165                {
166                    load(in.buf, code[in.childIndex].buf);
167                    for (int j = 1; j < in.narg; ++j)
168                    {
169                        mul(in.buf, code[in.childIndex + j].buf);
170                    }
171                    break;
172                }
173            case OpCodes::Div:
174                {
175                    if (in.narg == 1)
176                    {
177                        inv(in.buf, code[in.childIndex].buf);
178                        break;
179                    }
180                    else
181                    {
182                        load(in.buf, code[in.childIndex].buf);
183                        for (int j = 1; j < in.narg; ++j)
184                        {
185                            div(in.buf, code[in.childIndex + j].buf);
186                        }
187                    }
188                    break;
189                }
190            case OpCodes::Sin:
191                {
192                    sin(in.buf, code[in.childIndex].buf);
193                    break;
194                }
195            case OpCodes::Cos:
196                {
197                    cos(in.buf, code[in.childIndex].buf);
198                    break;
199                }
200            case OpCodes::Tan:
201                {
202                    tan(in.buf, code[in.childIndex].buf);
203                    break;
204                }
205            case OpCodes::Log:
206                {
207                    log(in.buf, code[in.childIndex].buf);
208                    break;
209                }
210            case OpCodes::Exp:
211                {
212                    exp(in.buf, code[in.childIndex].buf);
213                    break;
214                }
215            case OpCodes::Power:
216                {
217                    load(in.buf, code[in.childIndex].buf);
218                    pow(in.buf, code[in.childIndex + 1].buf);
219                    break;
220                }
221            case OpCodes::Root:
222                {
223                    load(in.buf, code[in.childIndex].buf);
224                    root(in.buf, code[in.childIndex + 1].buf);
225                    break;
226                }
227            case OpCodes::Square:
228                {
229                    square(in.buf, code[in.childIndex].buf);
230                    break;
231                }
232            case OpCodes::Sqrt:
233                {
234                    sqrt(in.buf, code[in.childIndex].buf);
235                    break;
236                }
237        }
238    }
239}
240
241#endif
Note: See TracBrowser for help on using the repository browser.