[16269] | 1 | #ifndef VECTOR_OPERATIONS_H
|
---|
| 2 | #define VECTOR_OPERATIONS_H
|
---|
| 3 |
|
---|
[16274] | 4 | #define _USE_MATH_DEFINES
|
---|
| 5 | #include <cmath>
|
---|
[16269] | 6 | #include <cstring>
|
---|
| 7 |
|
---|
// Scalar math backend used by the batch kernels in this header.
// When USE_VDT is defined the hl_* names map to the vdt fast-math routines;
// otherwise they map to the <cmath> functions. hl_cbrt is std::cbrt in both
// configurations.
#ifdef USE_VDT
#include "vdt/vdtMath.h"
#include "vdt/stdwrap.h"
#define hl_exp vdt::fast_exp
#define hl_log vdt::fast_log
#define hl_sin vdt::fast_sin
#define hl_cos vdt::fast_cos
#define hl_tan vdt::fast_tan
#define hl_tanh vdt::fast_tanh
#define hl_sqrt vdt::fast_sqrt
#define hl_pow vdt::fast_pow
#define hl_cbrt std::cbrt
#define hl_round vdt::fast_round
#define hl_inv vdt::fast_inv
#else
#define hl_exp std::exp
#define hl_log std::log
#define hl_sin std::sin
#define hl_cos std::cos
#define hl_tan std::tan
#define hl_tanh std::tanh
#define hl_sqrt std::sqrt
#define hl_pow std::pow
#define hl_cbrt std::cbrt
#define hl_round std::round
// Function-like fallback for the reciprocal. The argument and the whole
// expansion are parenthesised and there is no trailing ';', so hl_inv(x)
// can be used anywhere an expression is expected, with any argument.
#define hl_inv(x) (1. / (x))
#endif
|
---|
| 36 |
|
---|
// Number of doubles every kernel in this header processes per call.
constexpr int BATCHSIZE{64};

// Loop header that runs the index `i` over [0, BATCHSIZE).
// The macro is removed again with #undef at the end of this header.
#define FOR(i) for (int i = 0; i < BATCHSIZE; ++i)
|
---|
[16269] | 40 |
|
---|
| 41 | // When auto-vectorizing without __restrict,
|
---|
| 42 | // gcc and clang check for overlap (with a bunch of integer code)
|
---|
| 43 | // before running the vectorized loop
|
---|
| 44 |
|
---|
| 45 | // vector - vector operations
|
---|
[16274] | 46 | inline void load(double* __restrict a, double const * __restrict b) noexcept { std::memcpy(a, b, BATCHSIZE * sizeof(double)); }
|
---|
| 47 | inline void add(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] += b[i]; }
|
---|
| 48 | inline void sub(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] -= b[i]; }
|
---|
| 49 | inline void mul(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] *= b[i]; }
|
---|
| 50 | inline void div(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] /= b[i]; }
|
---|
| 51 | inline void exp(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_exp(b[i]); }
|
---|
| 52 | inline void log(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_log(b[i]); }
|
---|
| 53 | inline void sin(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_sin(b[i]); }
|
---|
| 54 | inline void cos(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_cos(b[i]); }
|
---|
| 55 | inline void tan(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_tan(b[i]); }
|
---|
[16892] | 56 | inline void tanh(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_tanh(b[i]); }
|
---|
[16274] | 57 | inline void sqrt(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_sqrt(b[i]); }
|
---|
| 58 | inline void pow(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_pow(a[i], hl_round(b[i])); };
|
---|
| 59 | inline void root(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_pow(a[i], 1. / hl_round(b[i])); };
|
---|
[16988] | 60 | inline void cbrt(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_cbrt(b[i]); };
|
---|
[16274] | 61 | inline void square(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_pow(b[i], 2.); };
|
---|
[16892] | 62 | inline void inv(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = hl_inv(b[i]); }
|
---|
[16274] | 63 | inline void neg(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = -b[i]; }
|
---|
[16356] | 64 | inline void abs(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] = std::fabs(b[i]); }
|
---|
| 65 | inline void analytical_quotient(double* __restrict a, double const * __restrict b) noexcept { FOR(i) a[i] /= hl_sqrt(b[i]*b[i] + 1.); }
|
---|
[16269] | 66 |
|
---|
| 67 | // vector - scalar operations
|
---|
| 68 | inline void load(double* __restrict a, double s) noexcept { FOR(i) a[i] = s; }
|
---|
| 69 | inline void add(double* __restrict a, double s) noexcept { FOR(i) a[i] += s; }
|
---|
| 70 | inline void sub(double* __restrict a, double s) noexcept { FOR(i) a[i] -= s; }
|
---|
| 71 | inline void mul(double* __restrict a, double s) noexcept { FOR(i) a[i] *= s; }
|
---|
| 72 | inline void div(double* __restrict a, double s) noexcept { FOR(i) a[i] /= s; }
|
---|
[16356] | 73 | inline void pow(double* __restrict dst, double const * __restrict src, double s) noexcept { FOR(i) dst[i] = hl_pow(src[i], s); }
|
---|
[16269] | 74 |
|
---|
| 75 | // vector operations
|
---|
| 76 | inline void neg(double* __restrict a) noexcept { FOR(i) a[i] = -a[i]; }
|
---|
[16892] | 77 | inline void inv(double* __restrict a) noexcept { FOR(i) a[i] = hl_inv(a[i]); }
|
---|
[16269] | 78 | inline void exp(double* __restrict a) noexcept { FOR(i) a[i] = hl_exp(a[i]); }
|
---|
| 79 | inline void log(double* __restrict a) noexcept { FOR(i) a[i] = hl_log(a[i]); }
|
---|
| 80 | inline void sin(double* __restrict a) noexcept { FOR(i) a[i] = hl_sin(a[i]); }
|
---|
| 81 | inline void cos(double* __restrict a) noexcept { FOR(i) a[i] = hl_cos(a[i]); }
|
---|
[16892] | 82 | inline void tan(double* __restrict a) noexcept { FOR(i) a[i] = hl_tan(a[i]); }
|
---|
| 83 | inline void tanh(double* __restrict a) noexcept { FOR(i) a[i] = hl_tanh(a[i]); }
|
---|
[16269] | 84 | inline void round(double* __restrict a) noexcept { FOR(i) a[i] = hl_round(a[i]); }
|
---|
[16274] | 85 | inline void square(double* __restrict a) noexcept { FOR(i) a[i] = hl_pow(a[i], 2.); }
|
---|
[16988] | 86 | inline void cbrt(double* __restrict a) noexcept { FOR(i) a[i] = hl_cbrt(a[i]); }
|
---|
[16269] | 87 |
|
---|
| 88 | #undef FOR
|
---|
| 89 | #endif
|
---|