Use a custom movmask for ARM rather than the one from simde

Supposedly faster, although I cannot measure any difference.
It also gives neater code, so keep it.
This commit is contained in:
Kovid Goyal
2024-01-30 22:09:10 +05:30
parent 3b65c1a58a
commit 4d35fc2928

View File

@@ -30,6 +30,7 @@ _Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wbitwise-instead-of-logical\"")
#endif
#include <simde/x86/avx2.h>
#include <simde/arm/neon.h>
#if defined(__clang__) && __clang_major__ > 12
_Pragma("clang diagnostic pop")
#endif
@@ -203,13 +204,42 @@ static inline integer_t shuffle_impl256(const integer_t value, const integer_t s
#define debug(...)
#endif
#if defined(SIMDE_ARCH_AARCH64)
// See https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
typedef int32_t find_mask_t;
// Reduce a vector to a scalar find mask by delegating to movemask_epi8().
// A zero result means movemask_epi8() reported no bits for the vector.
static inline find_mask_t
mask_for_find(const integer_t a) {
    const find_mask_t mask = movemask_epi8(a);
    return mask;
}
// NEON substitute for the x86 movemask on a 128-bit vector (technique from the
// Arm blog post linked above). The vector is viewed as 16-bit lanes, each lane
// is shifted right by 4 and narrowed to 8 bits, and the resulting 8 bytes are
// returned as a single 64-bit value. Callers in this file treat the result as
// 4 bits per input byte (they divide the trailing-zero count by 4).
static inline uint64_t
movemask_arm128(const simde__m128i vec) {
    const simde_uint8x8_t narrowed = simde_vshrn_n_u16(simde_vreinterpretq_u16_u8(vec), 4);
    const uint64_t packed = simde_vget_lane_u64(simde_vreinterpret_u64_u8(narrowed), 0);
    return packed;
}
// Offset of the first match encoded in a find mask: the number of trailing
// zero bits in m. The mask must be non-zero, since __builtin_ctz(0) is
// undefined; callers test the mask before calling this.
static inline unsigned
bytes_to_first_match(const find_mask_t m) {
    const int trailing_zero_bits = __builtin_ctz(m);
    return trailing_zero_bits;
}
#if KITTY_SIMD_LEVEL == 128
// 128-bit ARM variant: index of the first matching byte in vec, or -1 when
// movemask_arm128() yields an all-zero mask.
static inline int
bytes_to_first_match(const integer_t vec) {
    const uint64_t nibble_mask = movemask_arm128(vec);
    if (!nibble_mask) return -1;
    // The trailing-zero bit count is divided by 4 to get a byte index.
    return __builtin_ctzll(nibble_mask) >> 2;
}
#else
// 256-bit ARM variant: index of the first matching byte in vec, or -1 when
// the whole vector is zero. The vector is processed as two 128-bit halves,
// each reduced with movemask_arm128().
static inline int
bytes_to_first_match(const integer_t vec) {
    if (is_zero(vec)) return -1;
    const simde__m128i low_half = simde_mm256_extracti128_si256(vec, 0);
    if (simde_mm_testz_si128(low_half, low_half)) {
        // The low half is all zero, so the match must be in the high half,
        // whose bytes start at offset 16.
        const simde__m128i high_half = simde_mm256_extracti128_si256(vec, 1);
        return 16 + (__builtin_ctzll(movemask_arm128(high_half)) >> 2);
    }
    return __builtin_ctzll(movemask_arm128(low_half)) >> 2;
}
#endif
#else
// Generic (non-ARM) variant: index of the first byte in vec whose high bit is
// set, or -1 when no byte has its high bit set.
//
// The emptiness test is made on the movemask result itself rather than on
// is_zero(vec): __builtin_ctz(0) is undefined, and a vector can be non-zero
// while still producing a zero movemask (no byte with 0x80 set). For the
// comparison results this is called with (bytes are 0x00 or 0xFF) the two
// tests agree, so results are unchanged.
static inline int
bytes_to_first_match(const integer_t vec) {
    const int32_t mask = movemask_epi8(vec);
    return mask ? __builtin_ctz(mask) : -1;
}
#endif
// }}}
@@ -226,12 +256,12 @@ const uint8_t*
FUNC(find_either_of_two_bytes)(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
const integer_t a_vec = set1_epi8(a), b_vec = set1_epi8(b);
const uint8_t* limit = haystack + sz;
integer_t chunk; find_mask_t mask;
integer_t chunk;
#define check_chunk() { \
const integer_t matches = or_si(cmpeq_epi8(chunk, a_vec), cmpeq_epi8(chunk, b_vec)); \
if ((mask = mask_for_find(matches))) { \
const uint8_t *ans = haystack + bytes_to_first_match(mask); \
const int n = bytes_to_first_match(or_si(cmpeq_epi8(chunk, a_vec), cmpeq_epi8(chunk, b_vec))); \
if (n > -1) { \
const uint8_t *ans = haystack + n; \
return ans < limit ? ans : NULL; \
}}
@@ -396,10 +426,9 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
const integer_t esc_vec = set1_epi8(0x1b);
const integer_t esc_cmp = cmpeq_epi8(vec, esc_vec);
const find_mask_t esc_test_mask = mask_for_find(esc_cmp);
bool sentinel_found = false;
unsigned short num_of_bytes_to_first_esc;
if (esc_test_mask && (num_of_bytes_to_first_esc = bytes_to_first_match(esc_test_mask)) < src_sz) {
int num_of_bytes_to_first_esc = bytes_to_first_match(esc_cmp);
if (num_of_bytes_to_first_esc > -1 && (unsigned)num_of_bytes_to_first_esc < src_sz) {
sentinel_found = true;
src_sz = num_of_bytes_to_first_esc;
d->num_consumed += src_sz + 1; // esc is also consumed
@@ -416,9 +445,9 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
// Check if we have pure ASCII and use fast path
debug_register(vec);
find_mask_t ascii_mask;
int32_t ascii_mask;
start_classification:
ascii_mask = mask_for_find(vec);
ascii_mask = movemask_epi8(vec);
if (!ascii_mask) { // no bytes with high bit (0x80) set, so just plain ASCII
FUNC(output_plain_ascii)(d, vec, src_sz);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
@@ -451,7 +480,7 @@ start_classification:
// counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes
debug_register(counts);
// check for an incomplete trailing utf8 sequence
if (check_for_trailing_bytes && mask_for_find(cmplt_epi8(one, and_si(counts, cmpeq_epi8(numbered_bytes(), set1_epi8(src_sz - 1)))))) {
if (check_for_trailing_bytes && !is_zero(cmplt_epi8(one, and_si(counts, cmpeq_epi8(numbered_bytes(), set1_epi8(src_sz - 1)))))) {
// The value of counts at the last byte is > 1 indicating we have a trailing incomplete sequence
check_for_trailing_bytes = false;
if (src[src_sz-1] >= 0xc0) num_of_trailing_bytes = 1; // 2-, 3- and 4-byte characters with only 1 byte left
@@ -462,10 +491,10 @@ start_classification:
goto start_classification;
}
// Only ASCII chars should have corresponding byte of counts == 0
if (ascii_mask != mask_for_find(cmpgt_epi8(counts, zero))) goto invalid_utf8;
if (ascii_mask != movemask_epi8(cmpgt_epi8(counts, zero))) goto invalid_utf8;
// The difference between a byte in counts and the next one should be negative,
// zero, or one. Any other value means there is not enough continuation bytes.
if (mask_for_find(cmpgt_epi8(subtract_epi8(shift_right_by_one_byte(counts), counts), one))) goto invalid_utf8;
if (!is_zero(cmpgt_epi8(subtract_epi8(shift_right_by_one_byte(counts), counts), one))) goto invalid_utf8;
// Process the bytes storing the three resulting bytes that make up the unicode codepoint
// mask all control bits so that we have only useful bits left