[16269] | 1 | #ifndef NATIVE_TREE_INTERPRETER_CLANG_H
|
---|
| 2 | #define NATIVE_TREE_INTERPRETER_CLANG_H
|
---|
| 3 |
|
---|
[16274] | 4 | #include "vector_operations.h"
|
---|
[16269] | 5 | #include "instruction.h"
|
---|
| 6 |
|
---|
| 7 | inline double evaluate(instruction *code, int len, int row) noexcept
|
---|
| 8 | {
|
---|
| 9 | for (int i = len - 1; i >= 0; --i)
|
---|
| 10 | {
|
---|
| 11 | instruction &in = code[i];
|
---|
| 12 | switch (in.opcode)
|
---|
| 13 | {
|
---|
[16356] | 14 | case OpCodes::Const: /* nothing to do */ break;
|
---|
[16269] | 15 | case OpCodes::Var:
|
---|
| 16 | {
|
---|
| 17 | in.value = in.weight * in.data[row];
|
---|
| 18 | break;
|
---|
| 19 | }
|
---|
| 20 | case OpCodes::Add:
|
---|
| 21 | {
|
---|
| 22 | in.value = code[in.childIndex].value;
|
---|
| 23 | for (int j = 1; j < in.narg; ++j)
|
---|
| 24 | {
|
---|
| 25 | in.value += code[in.childIndex + j].value;
|
---|
| 26 | }
|
---|
| 27 | break;
|
---|
| 28 | }
|
---|
| 29 | case OpCodes::Sub:
|
---|
| 30 | {
|
---|
| 31 | in.value = code[in.childIndex].value;
|
---|
| 32 | for (int j = 1; j < in.narg; ++j)
|
---|
| 33 | {
|
---|
| 34 | in.value -= code[in.childIndex + j].value;
|
---|
| 35 | }
|
---|
| 36 | if (in.narg == 1)
|
---|
| 37 | {
|
---|
| 38 | in.value = -in.value;
|
---|
| 39 | }
|
---|
| 40 | break;
|
---|
| 41 | }
|
---|
| 42 | case OpCodes::Mul:
|
---|
| 43 | {
|
---|
| 44 | in.value = code[in.childIndex].value;
|
---|
| 45 | for (int j = 1; j < in.narg; ++j)
|
---|
| 46 | {
|
---|
| 47 | in.value *= code[in.childIndex + j].value;
|
---|
| 48 | }
|
---|
| 49 | break;
|
---|
| 50 | }
|
---|
| 51 | case OpCodes::Div:
|
---|
| 52 | {
|
---|
| 53 | in.value = code[in.childIndex].value;
|
---|
| 54 | for (int j = 1; j < in.narg; ++j)
|
---|
| 55 | {
|
---|
| 56 | in.value /= code[in.childIndex + j].value;
|
---|
| 57 | }
|
---|
| 58 | if (in.narg == 1)
|
---|
| 59 | {
|
---|
| 60 | in.value = 1 / in.value;
|
---|
| 61 | }
|
---|
| 62 | break;
|
---|
| 63 | }
|
---|
| 64 | case OpCodes::Exp:
|
---|
| 65 | {
|
---|
[16701] | 66 | in.value = hl_exp(code[in.childIndex].value);
|
---|
[16269] | 67 | break;
|
---|
| 68 | }
|
---|
| 69 | case OpCodes::Log:
|
---|
| 70 | {
|
---|
[16701] | 71 | in.value = hl_log(code[in.childIndex].value);
|
---|
[16269] | 72 | break;
|
---|
| 73 | }
|
---|
| 74 | case OpCodes::Sin:
|
---|
| 75 | {
|
---|
[16701] | 76 | in.value = hl_sin(code[in.childIndex].value);
|
---|
[16269] | 77 | break;
|
---|
| 78 | }
|
---|
| 79 | case OpCodes::Cos:
|
---|
| 80 | {
|
---|
[16701] | 81 | in.value = hl_cos(code[in.childIndex].value);
|
---|
[16269] | 82 | break;
|
---|
| 83 | }
|
---|
| 84 | case OpCodes::Tan:
|
---|
| 85 | {
|
---|
[16701] | 86 | in.value = hl_tan(code[in.childIndex].value);
|
---|
[16269] | 87 | break;
|
---|
| 88 | }
|
---|
[16701] | 89 | case OpCodes::Tanh:
|
---|
| 90 | {
|
---|
| 91 | in.value = hl_tanh(code[in.childIndex].value);
|
---|
| 92 | break;
|
---|
| 93 | }
|
---|
[16269] | 94 | case OpCodes::Power:
|
---|
| 95 | {
|
---|
| 96 | double x = code[in.childIndex].value;
|
---|
[16701] | 97 | double y = hl_round(code[in.childIndex + 1].value);
|
---|
| 98 | in.value = hl_pow(x, y);
|
---|
[16269] | 99 | break;
|
---|
| 100 | }
|
---|
| 101 | case OpCodes::Root:
|
---|
| 102 | {
|
---|
| 103 | double x = code[in.childIndex].value;
|
---|
[16701] | 104 | double y = hl_round(code[in.childIndex + 1].value);
|
---|
| 105 | in.value = hl_pow(x, 1 / y);
|
---|
[16269] | 106 | break;
|
---|
| 107 | }
|
---|
[16356] | 108 | case OpCodes::Sqrt:
|
---|
| 109 | {
|
---|
[16701] | 110 | in.value = hl_pow(code[in.childIndex].value, 1./2.);
|
---|
[16356] | 111 | break;
|
---|
| 112 | }
|
---|
[16334] | 113 | case OpCodes::Square:
|
---|
| 114 | {
|
---|
[16701] | 115 | in.value = hl_pow(code[in.childIndex].value, 2.);
|
---|
[16334] | 116 | break;
|
---|
| 117 | }
|
---|
[16356] | 118 | case OpCodes::CubeRoot:
|
---|
[16269] | 119 | {
|
---|
[16905] | 120 | in.value = hl_cbrt(code[in.childIndex].value);
|
---|
[16269] | 121 | break;
|
---|
| 122 | }
|
---|
[16356] | 123 | case OpCodes::Cube:
|
---|
| 124 | {
|
---|
[16701] | 125 | in.value = hl_pow(code[in.childIndex].value, 3.);
|
---|
[16356] | 126 | break;
|
---|
| 127 | }
|
---|
| 128 | case OpCodes::Absolute:
|
---|
| 129 | {
|
---|
| 130 | in.value = std::fabs(code[in.childIndex].value);
|
---|
| 131 | break;
|
---|
| 132 | }
|
---|
| 133 | case OpCodes::AnalyticalQuotient:
|
---|
| 134 | {
|
---|
| 135 | double x = code[in.childIndex].value;
|
---|
| 136 | double y = code[in.childIndex + 1].value;
|
---|
[16701] | 137 | in.value = x / hl_sqrt(1 + y*y);
|
---|
[16356] | 138 | break;
|
---|
| 139 | }
|
---|
| 140 | default: in.value = NAN;
|
---|
[16269] | 141 | }
|
---|
| 142 | }
|
---|
| 143 | return code[0].value;
|
---|
| 144 | }
|
---|
| 145 |
|
---|
| 146 | inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
|
---|
| 147 | {
|
---|
| 148 | for (int i = 0; i < batchSize; ++i)
|
---|
| 149 | {
|
---|
| 150 | auto row = rows[rowIndex + i];
|
---|
| 151 | in.buf[i] = in.weight * in.data[row];
|
---|
| 152 | }
|
---|
| 153 | }
|
---|
| 154 |
|
---|
| 155 | inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
|
---|
| 156 | {
|
---|
| 157 | for (int i = len - 1; i >= 0; --i)
|
---|
| 158 | {
|
---|
| 159 | instruction &in = code[i];
|
---|
| 160 | switch (in.opcode)
|
---|
| 161 | {
|
---|
| 162 | case OpCodes::Var:
|
---|
| 163 | {
|
---|
| 164 | load_data(in, rows, rowIndex, batchSize); // buffer data
|
---|
| 165 | break;
|
---|
| 166 | }
|
---|
[16356] | 167 | case OpCodes::Const: /* nothing to do because buffers for constants are already set */ break;
|
---|
[16269] | 168 | case OpCodes::Add:
|
---|
| 169 | {
|
---|
| 170 | load(in.buf, code[in.childIndex].buf);
|
---|
| 171 | for (int j = 1; j < in.narg; ++j)
|
---|
| 172 | {
|
---|
| 173 | add(in.buf, code[in.childIndex + j].buf);
|
---|
| 174 | }
|
---|
| 175 | break;
|
---|
| 176 | }
|
---|
| 177 | case OpCodes::Sub:
|
---|
| 178 | {
|
---|
[16274] | 179 | if (in.narg == 1)
|
---|
[16269] | 180 | {
|
---|
[16274] | 181 | neg(in.buf, code[in.childIndex].buf);
|
---|
| 182 | break;
|
---|
[16269] | 183 | }
|
---|
[16274] | 184 | else
|
---|
[16269] | 185 | {
|
---|
[16274] | 186 | load(in.buf, code[in.childIndex].buf);
|
---|
| 187 | for (int j = 1; j < in.narg; ++j)
|
---|
| 188 | {
|
---|
| 189 | sub(in.buf, code[in.childIndex + j].buf);
|
---|
| 190 | }
|
---|
[16269] | 191 | }
|
---|
| 192 | break;
|
---|
| 193 | }
|
---|
| 194 | case OpCodes::Mul:
|
---|
| 195 | {
|
---|
| 196 | load(in.buf, code[in.childIndex].buf);
|
---|
| 197 | for (int j = 1; j < in.narg; ++j)
|
---|
| 198 | {
|
---|
| 199 | mul(in.buf, code[in.childIndex + j].buf);
|
---|
| 200 | }
|
---|
| 201 | break;
|
---|
| 202 | }
|
---|
| 203 | case OpCodes::Div:
|
---|
| 204 | {
|
---|
[16274] | 205 | if (in.narg == 1)
|
---|
[16269] | 206 | {
|
---|
[16274] | 207 | inv(in.buf, code[in.childIndex].buf);
|
---|
| 208 | break;
|
---|
[16269] | 209 | }
|
---|
[16274] | 210 | else
|
---|
[16269] | 211 | {
|
---|
[16274] | 212 | load(in.buf, code[in.childIndex].buf);
|
---|
| 213 | for (int j = 1; j < in.narg; ++j)
|
---|
| 214 | {
|
---|
| 215 | div(in.buf, code[in.childIndex + j].buf);
|
---|
| 216 | }
|
---|
[16269] | 217 | }
|
---|
| 218 | break;
|
---|
| 219 | }
|
---|
| 220 | case OpCodes::Sin:
|
---|
| 221 | {
|
---|
| 222 | sin(in.buf, code[in.childIndex].buf);
|
---|
| 223 | break;
|
---|
| 224 | }
|
---|
| 225 | case OpCodes::Cos:
|
---|
| 226 | {
|
---|
| 227 | cos(in.buf, code[in.childIndex].buf);
|
---|
| 228 | break;
|
---|
| 229 | }
|
---|
| 230 | case OpCodes::Tan:
|
---|
| 231 | {
|
---|
| 232 | tan(in.buf, code[in.childIndex].buf);
|
---|
| 233 | break;
|
---|
| 234 | }
|
---|
[16701] | 235 | case OpCodes::Tanh:
|
---|
| 236 | {
|
---|
| 237 | tanh(in.buf, code[in.childIndex].buf);
|
---|
| 238 | break;
|
---|
| 239 | }
|
---|
[16269] | 240 | case OpCodes::Log:
|
---|
| 241 | {
|
---|
| 242 | log(in.buf, code[in.childIndex].buf);
|
---|
| 243 | break;
|
---|
| 244 | }
|
---|
| 245 | case OpCodes::Exp:
|
---|
| 246 | {
|
---|
| 247 | exp(in.buf, code[in.childIndex].buf);
|
---|
| 248 | break;
|
---|
| 249 | }
|
---|
| 250 | case OpCodes::Power:
|
---|
| 251 | {
|
---|
| 252 | load(in.buf, code[in.childIndex].buf);
|
---|
| 253 | pow(in.buf, code[in.childIndex + 1].buf);
|
---|
| 254 | break;
|
---|
| 255 | }
|
---|
| 256 | case OpCodes::Root:
|
---|
| 257 | {
|
---|
| 258 | load(in.buf, code[in.childIndex].buf);
|
---|
| 259 | root(in.buf, code[in.childIndex + 1].buf);
|
---|
| 260 | break;
|
---|
| 261 | }
|
---|
| 262 | case OpCodes::Square:
|
---|
| 263 | {
|
---|
[16356] | 264 | pow(in.buf, code[in.childIndex].buf, 2.);
|
---|
[16269] | 265 | break;
|
---|
| 266 | }
|
---|
[16334] | 267 | case OpCodes::Sqrt:
|
---|
| 268 | {
|
---|
[16356] | 269 | pow(in.buf, code[in.childIndex].buf, 1./2.);
|
---|
[16334] | 270 | break;
|
---|
| 271 | }
|
---|
[16356] | 272 | case OpCodes::CubeRoot:
|
---|
| 273 | {
|
---|
[16905] | 274 | cbrt(in.buf, code[in.childIndex].buf);
|
---|
[16356] | 275 | break;
|
---|
| 276 | }
|
---|
| 277 | case OpCodes::Cube:
|
---|
| 278 | {
|
---|
| 279 | pow(in.buf, code[in.childIndex].buf, 3.);
|
---|
| 280 | break;
|
---|
| 281 | }
|
---|
| 282 | case OpCodes::Absolute:
|
---|
| 283 | {
|
---|
| 284 | abs(in.buf, code[in.childIndex].buf);
|
---|
| 285 | break;
|
---|
| 286 | }
|
---|
| 287 | case OpCodes::AnalyticalQuotient:
|
---|
| 288 | {
|
---|
| 289 | load(in.buf, code[in.childIndex].buf);
|
---|
| 290 | analytical_quotient(in.buf, code[in.childIndex + 1].buf);
|
---|
| 291 | break;
|
---|
| 292 | }
|
---|
| 293 | default: load(in.buf, NAN);
|
---|
| 294 | }
|
---|
[16269] | 295 | }
|
---|
| 296 | }
|
---|
| 297 |
|
---|
| 298 | #endif
|
---|