diff --git a/kitty/simd-string.c b/kitty/simd-string.c index 1ed54d9ec..e8cef9e5d 100644 --- a/kitty/simd-string.c +++ b/kitty/simd-string.c @@ -124,6 +124,35 @@ find_either_of_two_bytes_sse4_2(uint8_t *haystack, const size_t sz, const uint8_ return ans; } +static uint8_t* +find_either_of_two_bytes_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) { + const size_t extra = (uintptr_t)haystack % sizeof(__m256i); + if (extra) { // need aligned loads for performance so search first few bytes by hand + size_t es = MIN(sz, sizeof(__m256i) - extra); + uint8_t *ans = (has_sse4_2 && es > 15) ? find_either_of_two_bytes_sse4_2(haystack, es, a, b) : find_either_of_two_bytes_simple(haystack, es, a, b); + if (ans) return ans; + sz -= es; + haystack += es; + if (!sz) return NULL; + } + __m256i x = _mm256_set1_epi8(a); + __m256i y = _mm256_set1_epi8(b); + for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m256i)) { + __m256i chunk = _mm256_load_si256((__m256i*)(haystack)); + __m256i x_cmp = _mm256_cmpeq_epi8(chunk, x); + __m256i y_cmp = _mm256_cmpeq_epi8(chunk, y); + __m256i matches = _mm256_or_si256(x_cmp, y_cmp); + const int mask = _mm256_movemask_epi8(matches); + if (mask != 0) { + // The trailing zeroes in mask give us the position of the first match + size_t pos = __builtin_ctz(mask); + if (haystack + pos < limit) return haystack + pos; + } + } + return NULL; +} + + static uint8_t* (*find_either_of_two_bytes_impl)(uint8_t*, const size_t, const uint8_t, const uint8_t) = find_either_of_two_bytes_simple; uint8_t* @@ -195,10 +224,10 @@ find_byte_not_in_range_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const __m256i above_lower = _mm256_cmpgt_epi8(chunk, lower_bound); __m256i below_upper = _mm256_cmpgt_epi8(upper_bound, chunk); __m256i in_range = _mm256_and_si256(above_lower, below_upper); - int mask = _mm256_movemask_epi8(in_range); - if (mask != (int)0xFFFFFFFF) { - // The trailing zeroes in ~mask give us the position - const size_t pos = __builtin_ctz(~mask); + const int mask = ~_mm256_movemask_epi8(in_range); // ~ as we want not in range + if (mask != 0) { + // The trailing zeroes in mask give us the position of the first match on little endian + const int pos = __builtin_ctz(mask); if (haystack + pos < limit) return haystack + pos; } } @@ -221,13 +250,14 @@ init_simd(void *x) { if (has_avx2) { A(has_avx2, True); find_byte_not_in_range_impl = find_byte_not_in_range_avx2; + find_either_of_two_bytes_impl = find_either_of_two_bytes_avx2; } else { A(has_avx2, False); } if (has_sse4_2) { A(has_sse4_2, True); if (find_byte_not_in_range == find_byte_not_in_range_simple) find_byte_not_in_range_impl = find_byte_not_in_range_sse4_2; - find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2; + if (find_either_of_two_bytes_impl == find_either_of_two_bytes_simple) find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2; } else { A(has_sse4_2, False); }