From aacdffd539e89b5ae4d35bb7dac05104fa8f6722 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Fri, 10 Nov 2023 20:45:46 +0530 Subject: [PATCH] DRYer --- kitty/simd-string.c | 68 +++++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/kitty/simd-string.c b/kitty/simd-string.c index e8cef9e5d..a95ceecc9 100644 --- a/kitty/simd-string.c +++ b/kitty/simd-string.c @@ -124,30 +124,35 @@ find_either_of_two_bytes_sse4_2(uint8_t *haystack, const size_t sz, const uint8_ return ans; } +#define start_simd2(bits, aligner) \ + const size_t extra = (uintptr_t)haystack % sizeof(__m##bits##i); \ + if (extra) { \ + size_t es = MIN(sz, sizeof(__m##bits##i) - extra); \ + uint8_t *ans = aligner; \ + if (ans) return ans; \ + sz -= es; \ + haystack += es; \ + if (!sz) return NULL; \ + } \ + __m##bits##i a_vec = _mm##bits##_set1_epi8(a); \ + __m##bits##i b_vec = _mm##bits##_set1_epi8(b); \ + for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m##bits##i)) + +#define end_simd2 \ + if (mask != 0) { \ + size_t pos = __builtin_ctz(mask); \ + if (haystack + pos < limit) return haystack + pos; \ + } + static uint8_t* find_either_of_two_bytes_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) { - const size_t extra = (uintptr_t)haystack % sizeof(__m256i); - if (extra) { // need aligned loads for performance so search first few bytes by hand - size_t es = MIN(sz, sizeof(__m256i) - extra); - uint8_t *ans = (has_sse4_2 && es > 15) ? find_either_of_two_bytes_sse4_2(haystack, es, a, b) : find_either_of_two_bytes_simple(haystack, es, a, b); - if (ans) return ans; - sz -= es; - haystack += es; - if (!sz) return NULL; - } - __m256i x = _mm256_set1_epi8(a); - __m256i y = _mm256_set1_epi8(b); - for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m256i)) { + start_simd2(256, (has_sse4_2 && es > 15) ? find_either_of_two_bytes_sse4_2(haystack, es, a, b) : find_either_of_two_bytes_simple(haystack, es, a, b)) { __m256i chunk = _mm256_load_si256((__m256i*)(haystack)); - __m256i x_cmp = _mm256_cmpeq_epi8(chunk, x); - __m256i y_cmp = _mm256_cmpeq_epi8(chunk, y); - __m256i matches = _mm256_or_si256(x_cmp, y_cmp); + __m256i a_cmp = _mm256_cmpeq_epi8(chunk, a_vec); + __m256i b_cmp = _mm256_cmpeq_epi8(chunk, b_vec); + __m256i matches = _mm256_or_si256(a_cmp, b_cmp); const int mask = _mm256_movemask_epi8(matches); - if (mask != 0) { - // The trailing zeroes in mask give us the position of the first match - size_t pos = __builtin_ctz(mask); - if (haystack + pos < limit) return haystack + pos; - } + end_simd2; } return NULL; } @@ -208,28 +213,13 @@ find_byte_not_in_range_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t static uint8_t* find_byte_not_in_range_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) { - const size_t extra = (uintptr_t)haystack % sizeof(__m256i); - if (extra) { // need aligned loads for performance so search first few bytes by hand - size_t es = MIN(sz, sizeof(__m256i) - extra); - uint8_t *ans = (has_sse4_2 && extra > 15) ? find_byte_not_in_range_sse4_2(haystack, es, a, b) : find_byte_not_in_range_simple(haystack, es, a, b); - if (ans) return ans; - sz -= es; - haystack += es; - if (!sz) return NULL; - } - __m256i lower_bound = _mm256_set1_epi8(a); - __m256i upper_bound = _mm256_set1_epi8(b); - for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m256i)) { + start_simd2(256, (has_sse4_2 && extra > 15) ? find_byte_not_in_range_sse4_2(haystack, es, a, b) : find_byte_not_in_range_simple(haystack, es, a, b)) { __m256i chunk = _mm256_load_si256((__m256i*)(haystack)); - __m256i above_lower = _mm256_cmpgt_epi8(chunk, lower_bound); - __m256i below_upper = _mm256_cmpgt_epi8(upper_bound, chunk); + __m256i above_lower = _mm256_cmpgt_epi8(chunk, a_vec); + __m256i below_upper = _mm256_cmpgt_epi8(b_vec, chunk); __m256i in_range = _mm256_and_si256(above_lower, below_upper); const int mask = ~_mm256_movemask_epi8(in_range); // ~ as we want not in range - if (mask != 0) { - // The trailing zeroes in mask give us the position of the first match on little endian - const int pos = __builtin_ctz(mask); - if (haystack + pos < limit) return haystack + pos; - } + end_simd2; } return NULL; }