From 4d35fc2928ff768084bc1045cd874ec068d4ee6e Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Tue, 30 Jan 2024 22:09:10 +0530 Subject: [PATCH] Use a custom movmask for ARM rather than the one from simde Supposedly faster, not that I can measure it, but... Also gives neater code, so keep it. --- kitty/simd-string-impl.h | 63 +++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/kitty/simd-string-impl.h b/kitty/simd-string-impl.h index 22bb64d67..416338726 100644 --- a/kitty/simd-string-impl.h +++ b/kitty/simd-string-impl.h @@ -30,6 +30,7 @@ _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wbitwise-instead-of-logical\"") #endif #include <simde/x86/avx2.h> +#include <simde/arm/neon.h> #if defined(__clang__) && __clang_major__ > 12 _Pragma("clang diagnostic pop") #endif @@ -203,13 +204,42 @@ static inline integer_t shuffle_impl256(const integer_t value, const integer_t s #define debug(...) #endif +#if defined(SIMDE_ARCH_AARCH64) +// See https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon -typedef int32_t find_mask_t; -static inline find_mask_t -mask_for_find(const integer_t a) { return movemask_epi8(a); } +static inline uint64_t +movemask_arm128(const simde__m128i vec) { + simde_uint8x8_t res = simde_vshrn_n_u16(simde_vreinterpretq_u16_u8(vec), 4); + return simde_vget_lane_u64(simde_vreinterpret_u64_u8(res), 0); +} -static inline unsigned -bytes_to_first_match(const find_mask_t m) { return __builtin_ctz(m); } +#if KITTY_SIMD_LEVEL == 128 + +static inline int +bytes_to_first_match(const integer_t vec) { const uint64_t m = movemask_arm128(vec); return m ? 
(__builtin_ctzll(m) >> 2) : -1; } + +#else + +static inline int +bytes_to_first_match(const integer_t vec) { + if (is_zero(vec)) return -1; + simde__m128i v = simde_mm256_extracti128_si256(vec, 0); + if (!simde_mm_testz_si128(v, v)) return __builtin_ctzll(movemask_arm128(v)) >> 2; + v = simde_mm256_extracti128_si256(vec, 1); + return 16 + (__builtin_ctzll(movemask_arm128(v)) >> 2); +} + +#endif + +#else + +static inline int +bytes_to_first_match(const integer_t vec) { + return is_zero(vec) ? -1 : __builtin_ctz(movemask_epi8(vec)); +} + + +#endif // }}} @@ -226,12 +256,12 @@ const uint8_t* FUNC(find_either_of_two_bytes)(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) { const integer_t a_vec = set1_epi8(a), b_vec = set1_epi8(b); const uint8_t* limit = haystack + sz; - integer_t chunk; find_mask_t mask; + integer_t chunk; #define check_chunk() { \ - const integer_t matches = or_si(cmpeq_epi8(chunk, a_vec), cmpeq_epi8(chunk, b_vec)); \ - if ((mask = mask_for_find(matches))) { \ - const uint8_t *ans = haystack + bytes_to_first_match(mask); \ + const int n = bytes_to_first_match(or_si(cmpeq_epi8(chunk, a_vec), cmpeq_epi8(chunk, b_vec))); \ + if (n > -1) { \ + const uint8_t *ans = haystack + n; \ return ans < limit ? 
ans : NULL; \ }} @@ -396,10 +426,9 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { const integer_t esc_vec = set1_epi8(0x1b); const integer_t esc_cmp = cmpeq_epi8(vec, esc_vec); - const find_mask_t esc_test_mask = mask_for_find(esc_cmp); bool sentinel_found = false; - unsigned short num_of_bytes_to_first_esc; - if (esc_test_mask && (num_of_bytes_to_first_esc = bytes_to_first_match(esc_test_mask)) < src_sz) { + int num_of_bytes_to_first_esc = bytes_to_first_match(esc_cmp); + if (num_of_bytes_to_first_esc > -1 && (unsigned)num_of_bytes_to_first_esc < src_sz) { sentinel_found = true; src_sz = num_of_bytes_to_first_esc; d->num_consumed += src_sz + 1; // esc is also consumed @@ -416,9 +445,9 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { // Check if we have pure ASCII and use fast path debug_register(vec); - find_mask_t ascii_mask; + int32_t ascii_mask; start_classification: - ascii_mask = mask_for_find(vec); + ascii_mask = movemask_epi8(vec); if (!ascii_mask) { // no bytes with high bit (0x80) set, so just plain ASCII FUNC(output_plain_ascii)(d, vec, src_sz); if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes); @@ -451,7 +480,7 @@ start_classification: // counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes debug_register(counts); // check for an incomplete trailing utf8 sequence - if (check_for_trailing_bytes && mask_for_find(cmplt_epi8(one, and_si(counts, cmpeq_epi8(numbered_bytes(), set1_epi8(src_sz - 1)))))) { + if (check_for_trailing_bytes && !is_zero(cmplt_epi8(one, and_si(counts, cmpeq_epi8(numbered_bytes(), set1_epi8(src_sz - 1)))))) { // The value of counts at the last byte is > 1 indicating we have a trailing incomplete sequence check_for_trailing_bytes = false; if (src[src_sz-1] >= 0xc0) num_of_trailing_bytes = 1; // 2-, 3- and 4-byte characters with only 1 byte left @@ -462,10 +491,10 @@ start_classification: goto 
start_classification; } // Only ASCII chars should have corresponding byte of counts == 0 - if (ascii_mask != mask_for_find(cmpgt_epi8(counts, zero))) goto invalid_utf8; + if (ascii_mask != movemask_epi8(cmpgt_epi8(counts, zero))) goto invalid_utf8; // The difference between a byte in counts and the next one should be negative, // zero, or one. Any other value means there is not enough continuation bytes. - if (mask_for_find(cmpgt_epi8(subtract_epi8(shift_right_by_one_byte(counts), counts), one))) goto invalid_utf8; + if (!is_zero(cmpgt_epi8(subtract_epi8(shift_right_by_one_byte(counts), counts), one))) goto invalid_utf8; // Process the bytes storing the three resulting bytes that make up the unicode codepoint // mask all control bits so that we have only useful bits left