[16269] | 1 | #ifndef NATIVE_TREE_INTERPRETER_CLANG_H
|
---|
| 2 | #define NATIVE_TREE_INTERPRETER_CLANG_H
|
---|
| 3 |
|
---|
[16274] | 4 | #include "vector_operations.h"
|
---|
[16269] | 5 | #include "instruction.h"
|
---|
| 6 |
|
---|
| 7 | inline double evaluate(instruction *code, int len, int row) noexcept
|
---|
| 8 | {
|
---|
| 9 | for (int i = len - 1; i >= 0; --i)
|
---|
| 10 | {
|
---|
| 11 | instruction &in = code[i];
|
---|
| 12 | switch (in.opcode)
|
---|
| 13 | {
|
---|
[18160] | 14 | case OpCodes::Number: /* nothing to do */ break;
|
---|
| 15 | case OpCodes::Constant: /* nothing to do */ break;
|
---|
[16269] | 16 | case OpCodes::Var:
|
---|
| 17 | {
|
---|
| 18 | in.value = in.weight * in.data[row];
|
---|
| 19 | break;
|
---|
| 20 | }
|
---|
| 21 | case OpCodes::Add:
|
---|
| 22 | {
|
---|
| 23 | in.value = code[in.childIndex].value;
|
---|
| 24 | for (int j = 1; j < in.narg; ++j)
|
---|
| 25 | {
|
---|
| 26 | in.value += code[in.childIndex + j].value;
|
---|
| 27 | }
|
---|
| 28 | break;
|
---|
| 29 | }
|
---|
| 30 | case OpCodes::Sub:
|
---|
| 31 | {
|
---|
| 32 | in.value = code[in.childIndex].value;
|
---|
| 33 | for (int j = 1; j < in.narg; ++j)
|
---|
| 34 | {
|
---|
| 35 | in.value -= code[in.childIndex + j].value;
|
---|
| 36 | }
|
---|
| 37 | if (in.narg == 1)
|
---|
| 38 | {
|
---|
| 39 | in.value = -in.value;
|
---|
| 40 | }
|
---|
| 41 | break;
|
---|
| 42 | }
|
---|
| 43 | case OpCodes::Mul:
|
---|
| 44 | {
|
---|
| 45 | in.value = code[in.childIndex].value;
|
---|
| 46 | for (int j = 1; j < in.narg; ++j)
|
---|
| 47 | {
|
---|
| 48 | in.value *= code[in.childIndex + j].value;
|
---|
| 49 | }
|
---|
| 50 | break;
|
---|
| 51 | }
|
---|
| 52 | case OpCodes::Div:
|
---|
| 53 | {
|
---|
| 54 | in.value = code[in.childIndex].value;
|
---|
| 55 | for (int j = 1; j < in.narg; ++j)
|
---|
| 56 | {
|
---|
| 57 | in.value /= code[in.childIndex + j].value;
|
---|
| 58 | }
|
---|
| 59 | if (in.narg == 1)
|
---|
| 60 | {
|
---|
| 61 | in.value = 1 / in.value;
|
---|
| 62 | }
|
---|
| 63 | break;
|
---|
| 64 | }
|
---|
| 65 | case OpCodes::Exp:
|
---|
| 66 | {
|
---|
[16701] | 67 | in.value = hl_exp(code[in.childIndex].value);
|
---|
[16269] | 68 | break;
|
---|
| 69 | }
|
---|
| 70 | case OpCodes::Log:
|
---|
| 71 | {
|
---|
[16701] | 72 | in.value = hl_log(code[in.childIndex].value);
|
---|
[16269] | 73 | break;
|
---|
| 74 | }
|
---|
| 75 | case OpCodes::Sin:
|
---|
| 76 | {
|
---|
[16701] | 77 | in.value = hl_sin(code[in.childIndex].value);
|
---|
[16269] | 78 | break;
|
---|
| 79 | }
|
---|
| 80 | case OpCodes::Cos:
|
---|
| 81 | {
|
---|
[16701] | 82 | in.value = hl_cos(code[in.childIndex].value);
|
---|
[16269] | 83 | break;
|
---|
| 84 | }
|
---|
| 85 | case OpCodes::Tan:
|
---|
| 86 | {
|
---|
[16701] | 87 | in.value = hl_tan(code[in.childIndex].value);
|
---|
[16269] | 88 | break;
|
---|
| 89 | }
|
---|
[16701] | 90 | case OpCodes::Tanh:
|
---|
| 91 | {
|
---|
| 92 | in.value = hl_tanh(code[in.childIndex].value);
|
---|
| 93 | break;
|
---|
| 94 | }
|
---|
[16269] | 95 | case OpCodes::Power:
|
---|
| 96 | {
|
---|
| 97 | double x = code[in.childIndex].value;
|
---|
[16701] | 98 | double y = hl_round(code[in.childIndex + 1].value);
|
---|
| 99 | in.value = hl_pow(x, y);
|
---|
[16269] | 100 | break;
|
---|
| 101 | }
|
---|
| 102 | case OpCodes::Root:
|
---|
| 103 | {
|
---|
| 104 | double x = code[in.childIndex].value;
|
---|
[16701] | 105 | double y = hl_round(code[in.childIndex + 1].value);
|
---|
| 106 | in.value = hl_pow(x, 1 / y);
|
---|
[16269] | 107 | break;
|
---|
| 108 | }
|
---|
[16356] | 109 | case OpCodes::Sqrt:
|
---|
| 110 | {
|
---|
[16701] | 111 | in.value = hl_pow(code[in.childIndex].value, 1./2.);
|
---|
[16356] | 112 | break;
|
---|
| 113 | }
|
---|
[16334] | 114 | case OpCodes::Square:
|
---|
| 115 | {
|
---|
[16701] | 116 | in.value = hl_pow(code[in.childIndex].value, 2.);
|
---|
[16334] | 117 | break;
|
---|
| 118 | }
|
---|
[16356] | 119 | case OpCodes::CubeRoot:
|
---|
[16269] | 120 | {
|
---|
[16905] | 121 | in.value = hl_cbrt(code[in.childIndex].value);
|
---|
[16269] | 122 | break;
|
---|
| 123 | }
|
---|
[16356] | 124 | case OpCodes::Cube:
|
---|
| 125 | {
|
---|
[16701] | 126 | in.value = hl_pow(code[in.childIndex].value, 3.);
|
---|
[16356] | 127 | break;
|
---|
| 128 | }
|
---|
| 129 | case OpCodes::Absolute:
|
---|
| 130 | {
|
---|
| 131 | in.value = std::fabs(code[in.childIndex].value);
|
---|
| 132 | break;
|
---|
| 133 | }
|
---|
| 134 | case OpCodes::AnalyticalQuotient:
|
---|
| 135 | {
|
---|
| 136 | double x = code[in.childIndex].value;
|
---|
| 137 | double y = code[in.childIndex + 1].value;
|
---|
[16701] | 138 | in.value = x / hl_sqrt(1 + y*y);
|
---|
[16356] | 139 | break;
|
---|
| 140 | }
|
---|
[18220] | 141 | case OpCodes::SubFunction:
|
---|
| 142 | {
|
---|
| 143 | in.value = code[in.childIndex].value;
|
---|
| 144 | break;
|
---|
| 145 | }
|
---|
[16356] | 146 | default: in.value = NAN;
|
---|
[16269] | 147 | }
|
---|
| 148 | }
|
---|
| 149 | return code[0].value;
|
---|
| 150 | }
|
---|
| 151 |
|
---|
| 152 | inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
|
---|
| 153 | {
|
---|
| 154 | for (int i = 0; i < batchSize; ++i)
|
---|
| 155 | {
|
---|
| 156 | auto row = rows[rowIndex + i];
|
---|
| 157 | in.buf[i] = in.weight * in.data[row];
|
---|
| 158 | }
|
---|
| 159 | }
|
---|
| 160 |
|
---|
| 161 | inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
|
---|
| 162 | {
|
---|
| 163 | for (int i = len - 1; i >= 0; --i)
|
---|
| 164 | {
|
---|
| 165 | instruction &in = code[i];
|
---|
| 166 | switch (in.opcode)
|
---|
| 167 | {
|
---|
| 168 | case OpCodes::Var:
|
---|
| 169 | {
|
---|
| 170 | load_data(in, rows, rowIndex, batchSize); // buffer data
|
---|
| 171 | break;
|
---|
| 172 | }
|
---|
[18160] | 173 | case OpCodes::Number: /* nothing to do because buffers for numbers are already set */ break;
|
---|
| 174 | case OpCodes::Constant: /* nothing to do because buffers for constants are already set */ break;
|
---|
[16269] | 175 | case OpCodes::Add:
|
---|
| 176 | {
|
---|
| 177 | load(in.buf, code[in.childIndex].buf);
|
---|
| 178 | for (int j = 1; j < in.narg; ++j)
|
---|
| 179 | {
|
---|
| 180 | add(in.buf, code[in.childIndex + j].buf);
|
---|
| 181 | }
|
---|
| 182 | break;
|
---|
| 183 | }
|
---|
| 184 | case OpCodes::Sub:
|
---|
| 185 | {
|
---|
[16274] | 186 | if (in.narg == 1)
|
---|
[16269] | 187 | {
|
---|
[16274] | 188 | neg(in.buf, code[in.childIndex].buf);
|
---|
| 189 | break;
|
---|
[16269] | 190 | }
|
---|
[16274] | 191 | else
|
---|
[16269] | 192 | {
|
---|
[16274] | 193 | load(in.buf, code[in.childIndex].buf);
|
---|
| 194 | for (int j = 1; j < in.narg; ++j)
|
---|
| 195 | {
|
---|
| 196 | sub(in.buf, code[in.childIndex + j].buf);
|
---|
| 197 | }
|
---|
[16269] | 198 | }
|
---|
| 199 | break;
|
---|
| 200 | }
|
---|
| 201 | case OpCodes::Mul:
|
---|
| 202 | {
|
---|
| 203 | load(in.buf, code[in.childIndex].buf);
|
---|
| 204 | for (int j = 1; j < in.narg; ++j)
|
---|
| 205 | {
|
---|
| 206 | mul(in.buf, code[in.childIndex + j].buf);
|
---|
| 207 | }
|
---|
| 208 | break;
|
---|
| 209 | }
|
---|
| 210 | case OpCodes::Div:
|
---|
| 211 | {
|
---|
[16274] | 212 | if (in.narg == 1)
|
---|
[16269] | 213 | {
|
---|
[16274] | 214 | inv(in.buf, code[in.childIndex].buf);
|
---|
| 215 | break;
|
---|
[16269] | 216 | }
|
---|
[16274] | 217 | else
|
---|
[16269] | 218 | {
|
---|
[16274] | 219 | load(in.buf, code[in.childIndex].buf);
|
---|
| 220 | for (int j = 1; j < in.narg; ++j)
|
---|
| 221 | {
|
---|
| 222 | div(in.buf, code[in.childIndex + j].buf);
|
---|
| 223 | }
|
---|
[16269] | 224 | }
|
---|
| 225 | break;
|
---|
| 226 | }
|
---|
| 227 | case OpCodes::Sin:
|
---|
| 228 | {
|
---|
| 229 | sin(in.buf, code[in.childIndex].buf);
|
---|
| 230 | break;
|
---|
| 231 | }
|
---|
| 232 | case OpCodes::Cos:
|
---|
| 233 | {
|
---|
| 234 | cos(in.buf, code[in.childIndex].buf);
|
---|
| 235 | break;
|
---|
| 236 | }
|
---|
| 237 | case OpCodes::Tan:
|
---|
| 238 | {
|
---|
| 239 | tan(in.buf, code[in.childIndex].buf);
|
---|
| 240 | break;
|
---|
| 241 | }
|
---|
[16701] | 242 | case OpCodes::Tanh:
|
---|
| 243 | {
|
---|
| 244 | tanh(in.buf, code[in.childIndex].buf);
|
---|
| 245 | break;
|
---|
| 246 | }
|
---|
[16269] | 247 | case OpCodes::Log:
|
---|
| 248 | {
|
---|
| 249 | log(in.buf, code[in.childIndex].buf);
|
---|
| 250 | break;
|
---|
| 251 | }
|
---|
| 252 | case OpCodes::Exp:
|
---|
| 253 | {
|
---|
| 254 | exp(in.buf, code[in.childIndex].buf);
|
---|
| 255 | break;
|
---|
| 256 | }
|
---|
| 257 | case OpCodes::Power:
|
---|
| 258 | {
|
---|
| 259 | load(in.buf, code[in.childIndex].buf);
|
---|
| 260 | pow(in.buf, code[in.childIndex + 1].buf);
|
---|
| 261 | break;
|
---|
| 262 | }
|
---|
| 263 | case OpCodes::Root:
|
---|
| 264 | {
|
---|
| 265 | load(in.buf, code[in.childIndex].buf);
|
---|
| 266 | root(in.buf, code[in.childIndex + 1].buf);
|
---|
| 267 | break;
|
---|
| 268 | }
|
---|
| 269 | case OpCodes::Square:
|
---|
| 270 | {
|
---|
[16356] | 271 | pow(in.buf, code[in.childIndex].buf, 2.);
|
---|
[16269] | 272 | break;
|
---|
| 273 | }
|
---|
[16334] | 274 | case OpCodes::Sqrt:
|
---|
| 275 | {
|
---|
[16356] | 276 | pow(in.buf, code[in.childIndex].buf, 1./2.);
|
---|
[16334] | 277 | break;
|
---|
| 278 | }
|
---|
[16356] | 279 | case OpCodes::CubeRoot:
|
---|
| 280 | {
|
---|
[16905] | 281 | cbrt(in.buf, code[in.childIndex].buf);
|
---|
[16356] | 282 | break;
|
---|
| 283 | }
|
---|
| 284 | case OpCodes::Cube:
|
---|
| 285 | {
|
---|
| 286 | pow(in.buf, code[in.childIndex].buf, 3.);
|
---|
| 287 | break;
|
---|
| 288 | }
|
---|
| 289 | case OpCodes::Absolute:
|
---|
| 290 | {
|
---|
| 291 | abs(in.buf, code[in.childIndex].buf);
|
---|
| 292 | break;
|
---|
| 293 | }
|
---|
| 294 | case OpCodes::AnalyticalQuotient:
|
---|
| 295 | {
|
---|
| 296 | load(in.buf, code[in.childIndex].buf);
|
---|
| 297 | analytical_quotient(in.buf, code[in.childIndex + 1].buf);
|
---|
| 298 | break;
|
---|
| 299 | }
|
---|
[18220] | 300 | case OpCodes::SubFunction:
|
---|
| 301 | {
|
---|
| 302 | load(in.buf, code[in.childIndex].buf);
|
---|
| 303 | break;
|
---|
| 304 | }
|
---|
| 305 |
|
---|
[16356] | 306 | default: load(in.buf, NAN);
|
---|
[18220] | 307 | }
|
---|
[16269] | 308 | }
|
---|
| 309 | }
|
---|
| 310 |
|
---|
| 311 | #endif
|
---|