[16269] | 1 | #ifndef NATIVE_TREE_INTERPRETER_CLANG_H
|
---|
| 2 | #define NATIVE_TREE_INTERPRETER_CLANG_H
|
---|
| 3 |
|
---|
[16274] | 4 | #include "vector_operations.h"
|
---|
[16269] | 5 | #include "instruction.h"
|
---|
| 6 |
|
---|
| 7 | inline double evaluate(instruction *code, int len, int row) noexcept
|
---|
| 8 | {
|
---|
| 9 | for (int i = len - 1; i >= 0; --i)
|
---|
| 10 | {
|
---|
| 11 | instruction &in = code[i];
|
---|
| 12 | switch (in.opcode)
|
---|
| 13 | {
|
---|
[16356] | 14 | case OpCodes::Const: /* nothing to do */ break;
|
---|
[16269] | 15 | case OpCodes::Var:
|
---|
| 16 | {
|
---|
| 17 | in.value = in.weight * in.data[row];
|
---|
| 18 | break;
|
---|
| 19 | }
|
---|
| 20 | case OpCodes::Add:
|
---|
| 21 | {
|
---|
| 22 | in.value = code[in.childIndex].value;
|
---|
| 23 | for (int j = 1; j < in.narg; ++j)
|
---|
| 24 | {
|
---|
| 25 | in.value += code[in.childIndex + j].value;
|
---|
| 26 | }
|
---|
| 27 | break;
|
---|
| 28 | }
|
---|
| 29 | case OpCodes::Sub:
|
---|
| 30 | {
|
---|
| 31 | in.value = code[in.childIndex].value;
|
---|
| 32 | for (int j = 1; j < in.narg; ++j)
|
---|
| 33 | {
|
---|
| 34 | in.value -= code[in.childIndex + j].value;
|
---|
| 35 | }
|
---|
| 36 | if (in.narg == 1)
|
---|
| 37 | {
|
---|
| 38 | in.value = -in.value;
|
---|
| 39 | }
|
---|
| 40 | break;
|
---|
| 41 | }
|
---|
| 42 | case OpCodes::Mul:
|
---|
| 43 | {
|
---|
| 44 | in.value = code[in.childIndex].value;
|
---|
| 45 | for (int j = 1; j < in.narg; ++j)
|
---|
| 46 | {
|
---|
| 47 | in.value *= code[in.childIndex + j].value;
|
---|
| 48 | }
|
---|
| 49 | break;
|
---|
| 50 | }
|
---|
| 51 | case OpCodes::Div:
|
---|
| 52 | {
|
---|
| 53 | in.value = code[in.childIndex].value;
|
---|
| 54 | for (int j = 1; j < in.narg; ++j)
|
---|
| 55 | {
|
---|
| 56 | in.value /= code[in.childIndex + j].value;
|
---|
| 57 | }
|
---|
| 58 | if (in.narg == 1)
|
---|
| 59 | {
|
---|
| 60 | in.value = 1 / in.value;
|
---|
| 61 | }
|
---|
| 62 | break;
|
---|
| 63 | }
|
---|
| 64 | case OpCodes::Exp:
|
---|
| 65 | {
|
---|
| 66 | in.value = std::exp(code[in.childIndex].value);
|
---|
| 67 | break;
|
---|
| 68 | }
|
---|
| 69 | case OpCodes::Log:
|
---|
| 70 | {
|
---|
| 71 | in.value = std::log(code[in.childIndex].value);
|
---|
| 72 | break;
|
---|
| 73 | }
|
---|
| 74 | case OpCodes::Sin:
|
---|
| 75 | {
|
---|
| 76 | in.value = std::sin(code[in.childIndex].value);
|
---|
| 77 | break;
|
---|
| 78 | }
|
---|
| 79 | case OpCodes::Cos:
|
---|
| 80 | {
|
---|
| 81 | in.value = std::cos(code[in.childIndex].value);
|
---|
| 82 | break;
|
---|
| 83 | }
|
---|
| 84 | case OpCodes::Tan:
|
---|
| 85 | {
|
---|
| 86 | in.value = std::tan(code[in.childIndex].value);
|
---|
| 87 | break;
|
---|
| 88 | }
|
---|
| 89 | case OpCodes::Power:
|
---|
| 90 | {
|
---|
| 91 | double x = code[in.childIndex].value;
|
---|
| 92 | double y = std::round(code[in.childIndex + 1].value);
|
---|
| 93 | in.value = std::pow(x, y);
|
---|
| 94 | break;
|
---|
| 95 | }
|
---|
| 96 | case OpCodes::Root:
|
---|
| 97 | {
|
---|
| 98 | double x = code[in.childIndex].value;
|
---|
| 99 | double y = std::round(code[in.childIndex + 1].value);
|
---|
| 100 | in.value = std::pow(x, 1 / y);
|
---|
| 101 | break;
|
---|
| 102 | }
|
---|
[16356] | 103 | case OpCodes::Sqrt:
|
---|
| 104 | {
|
---|
| 105 | in.value = std::pow(code[in.childIndex].value, 1./2.);
|
---|
| 106 | break;
|
---|
| 107 | }
|
---|
[16334] | 108 | case OpCodes::Square:
|
---|
| 109 | {
|
---|
| 110 | in.value = std::pow(code[in.childIndex].value, 2.);
|
---|
| 111 | break;
|
---|
| 112 | }
|
---|
[16356] | 113 | case OpCodes::CubeRoot:
|
---|
[16269] | 114 | {
|
---|
[16356] | 115 | in.value = std::pow(code[in.childIndex].value, 1./3.);
|
---|
[16269] | 116 | break;
|
---|
| 117 | }
|
---|
[16356] | 118 | case OpCodes::Cube:
|
---|
| 119 | {
|
---|
| 120 | in.value = std::pow(code[in.childIndex].value, 3.);
|
---|
| 121 | break;
|
---|
| 122 | }
|
---|
| 123 | case OpCodes::Absolute:
|
---|
| 124 | {
|
---|
| 125 | in.value = std::fabs(code[in.childIndex].value);
|
---|
| 126 | break;
|
---|
| 127 | }
|
---|
| 128 | case OpCodes::AnalyticalQuotient:
|
---|
| 129 | {
|
---|
| 130 | double x = code[in.childIndex].value;
|
---|
| 131 | double y = code[in.childIndex + 1].value;
|
---|
| 132 | in.value = x / std::sqrt(1 + y*y);
|
---|
| 133 | break;
|
---|
| 134 | }
|
---|
| 135 | default: in.value = NAN;
|
---|
[16269] | 136 | }
|
---|
| 137 | }
|
---|
| 138 | return code[0].value;
|
---|
| 139 | }
|
---|
| 140 |
|
---|
| 141 | inline void load_data(instruction &in, int* __restrict rows, int rowIndex, int batchSize) noexcept
|
---|
| 142 | {
|
---|
| 143 | for (int i = 0; i < batchSize; ++i)
|
---|
| 144 | {
|
---|
| 145 | auto row = rows[rowIndex + i];
|
---|
| 146 | in.buf[i] = in.weight * in.data[row];
|
---|
| 147 | }
|
---|
| 148 | }
|
---|
| 149 |
|
---|
| 150 | inline void evaluate(instruction* code, int len, int* __restrict rows, int rowIndex, int batchSize) noexcept
|
---|
| 151 | {
|
---|
| 152 | for (int i = len - 1; i >= 0; --i)
|
---|
| 153 | {
|
---|
| 154 | instruction &in = code[i];
|
---|
| 155 | switch (in.opcode)
|
---|
| 156 | {
|
---|
| 157 | case OpCodes::Var:
|
---|
| 158 | {
|
---|
| 159 | load_data(in, rows, rowIndex, batchSize); // buffer data
|
---|
| 160 | break;
|
---|
| 161 | }
|
---|
[16356] | 162 | case OpCodes::Const: /* nothing to do because buffers for constants are already set */ break;
|
---|
[16269] | 163 | case OpCodes::Add:
|
---|
| 164 | {
|
---|
| 165 | load(in.buf, code[in.childIndex].buf);
|
---|
| 166 | for (int j = 1; j < in.narg; ++j)
|
---|
| 167 | {
|
---|
| 168 | add(in.buf, code[in.childIndex + j].buf);
|
---|
| 169 | }
|
---|
| 170 | break;
|
---|
| 171 | }
|
---|
| 172 | case OpCodes::Sub:
|
---|
| 173 | {
|
---|
[16274] | 174 | if (in.narg == 1)
|
---|
[16269] | 175 | {
|
---|
[16274] | 176 | neg(in.buf, code[in.childIndex].buf);
|
---|
| 177 | break;
|
---|
[16269] | 178 | }
|
---|
[16274] | 179 | else
|
---|
[16269] | 180 | {
|
---|
[16274] | 181 | load(in.buf, code[in.childIndex].buf);
|
---|
| 182 | for (int j = 1; j < in.narg; ++j)
|
---|
| 183 | {
|
---|
| 184 | sub(in.buf, code[in.childIndex + j].buf);
|
---|
| 185 | }
|
---|
[16269] | 186 | }
|
---|
| 187 | break;
|
---|
| 188 | }
|
---|
| 189 | case OpCodes::Mul:
|
---|
| 190 | {
|
---|
| 191 | load(in.buf, code[in.childIndex].buf);
|
---|
| 192 | for (int j = 1; j < in.narg; ++j)
|
---|
| 193 | {
|
---|
| 194 | mul(in.buf, code[in.childIndex + j].buf);
|
---|
| 195 | }
|
---|
| 196 | break;
|
---|
| 197 | }
|
---|
| 198 | case OpCodes::Div:
|
---|
| 199 | {
|
---|
[16274] | 200 | if (in.narg == 1)
|
---|
[16269] | 201 | {
|
---|
[16274] | 202 | inv(in.buf, code[in.childIndex].buf);
|
---|
| 203 | break;
|
---|
[16269] | 204 | }
|
---|
[16274] | 205 | else
|
---|
[16269] | 206 | {
|
---|
[16274] | 207 | load(in.buf, code[in.childIndex].buf);
|
---|
| 208 | for (int j = 1; j < in.narg; ++j)
|
---|
| 209 | {
|
---|
| 210 | div(in.buf, code[in.childIndex + j].buf);
|
---|
| 211 | }
|
---|
[16269] | 212 | }
|
---|
| 213 | break;
|
---|
| 214 | }
|
---|
| 215 | case OpCodes::Sin:
|
---|
| 216 | {
|
---|
| 217 | sin(in.buf, code[in.childIndex].buf);
|
---|
| 218 | break;
|
---|
| 219 | }
|
---|
| 220 | case OpCodes::Cos:
|
---|
| 221 | {
|
---|
| 222 | cos(in.buf, code[in.childIndex].buf);
|
---|
| 223 | break;
|
---|
| 224 | }
|
---|
| 225 | case OpCodes::Tan:
|
---|
| 226 | {
|
---|
| 227 | tan(in.buf, code[in.childIndex].buf);
|
---|
| 228 | break;
|
---|
| 229 | }
|
---|
| 230 | case OpCodes::Log:
|
---|
| 231 | {
|
---|
| 232 | log(in.buf, code[in.childIndex].buf);
|
---|
| 233 | break;
|
---|
| 234 | }
|
---|
| 235 | case OpCodes::Exp:
|
---|
| 236 | {
|
---|
| 237 | exp(in.buf, code[in.childIndex].buf);
|
---|
| 238 | break;
|
---|
| 239 | }
|
---|
| 240 | case OpCodes::Power:
|
---|
| 241 | {
|
---|
| 242 | load(in.buf, code[in.childIndex].buf);
|
---|
| 243 | pow(in.buf, code[in.childIndex + 1].buf);
|
---|
| 244 | break;
|
---|
| 245 | }
|
---|
| 246 | case OpCodes::Root:
|
---|
| 247 | {
|
---|
| 248 | load(in.buf, code[in.childIndex].buf);
|
---|
| 249 | root(in.buf, code[in.childIndex + 1].buf);
|
---|
| 250 | break;
|
---|
| 251 | }
|
---|
| 252 | case OpCodes::Square:
|
---|
| 253 | {
|
---|
[16356] | 254 | pow(in.buf, code[in.childIndex].buf, 2.);
|
---|
[16269] | 255 | break;
|
---|
| 256 | }
|
---|
[16334] | 257 | case OpCodes::Sqrt:
|
---|
| 258 | {
|
---|
[16356] | 259 | pow(in.buf, code[in.childIndex].buf, 1./2.);
|
---|
[16334] | 260 | break;
|
---|
| 261 | }
|
---|
[16356] | 262 | case OpCodes::CubeRoot:
|
---|
| 263 | {
|
---|
| 264 | pow(in.buf, code[in.childIndex].buf, 1./3.);
|
---|
| 265 | break;
|
---|
| 266 | }
|
---|
| 267 | case OpCodes::Cube:
|
---|
| 268 | {
|
---|
| 269 | pow(in.buf, code[in.childIndex].buf, 3.);
|
---|
| 270 | break;
|
---|
| 271 | }
|
---|
| 272 | case OpCodes::Absolute:
|
---|
| 273 | {
|
---|
| 274 | abs(in.buf, code[in.childIndex].buf);
|
---|
| 275 | break;
|
---|
| 276 | }
|
---|
| 277 | case OpCodes::AnalyticalQuotient:
|
---|
| 278 | {
|
---|
| 279 | load(in.buf, code[in.childIndex].buf);
|
---|
| 280 | analytical_quotient(in.buf, code[in.childIndex + 1].buf);
|
---|
| 281 | break;
|
---|
| 282 | }
|
---|
| 283 | default: load(in.buf, NAN);
|
---|
| 284 | }
|
---|
[16269] | 285 | }
|
---|
| 286 | }
|
---|
| 287 |
|
---|
| 288 | #endif
|
---|