diff --git a/kitty/simd-string.c b/kitty/simd-string.c index 0986b34ac..1ed54d9ec 100644 --- a/kitty/simd-string.c +++ b/kitty/simd-string.c @@ -9,6 +9,8 @@ #include "simd-string.h" #include +static bool has_sse4_2 = false, has_avx2 = false; + // ByteLoader {{{ uint8_t byte_loader_peek(const ByteLoader *self) { @@ -92,20 +94,20 @@ find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_ static uint8_t* find_either_of_two_bytes_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) { - size_t extra = (uintptr_t)haystack % sizeof(__m128i); + const size_t extra = (uintptr_t)haystack % sizeof(__m128i); if (extra) { // need aligned loads for performance so search first few bytes by hand - uint8_t *ans = find_either_of_two_bytes_simple(haystack, MIN(sz, extra), needle_[0], needle_[1]); + const size_t es = MIN(sz, sizeof(__m128i) - extra); + uint8_t *ans = find_either_of_two_bytes_simple(haystack, es, needle_[0], needle_[1]); if (ans) return ans; - extra = MIN(extra, sz); - sz -= extra; - haystack += extra; + sz -= es; + haystack += es; if (!sz) return NULL; } const __m128i needle = _mm_load_si128((const __m128i *)needle_); for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += 16) { const __m128i h = _mm_load_si128((const __m128i *)haystack); int c = _mm_cmpistri(needle, h, _SIDD_CMP_EQUAL_ANY); - if (c != 16) { + if (c != 16 && haystack + c < limit) { return haystack + c; } } @@ -143,20 +145,20 @@ find_byte_not_in_range_simple(uint8_t *haystack, const size_t sz, const uint8_t static uint8_t* find_byte_not_in_range_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) { - size_t extra = (uintptr_t)haystack % sizeof(__m128i); + const size_t extra = (uintptr_t)haystack % sizeof(__m128i); if (extra) { // need aligned loads for performance so search first few bytes by hand - uint8_t *ans = find_byte_not_in_range_simple(haystack, MIN(sz, extra), needle_[0], needle_[1]); + size_t es = MIN(sz, sizeof(__m128i) - extra); + uint8_t *ans = find_byte_not_in_range_simple(haystack, es, needle_[0], needle_[1]); if (ans) return ans; - extra = MIN(extra, sz); - sz -= extra; - haystack += extra; + sz -= es; + haystack += es; if (!sz) return NULL; } const __m128i needle = _mm_load_si128((const __m128i *)needle_); - for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += 16) { + for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m128i)) { const __m128i h = _mm_load_si128((const __m128i *)haystack); int c = _mm_cmpistri(needle, h, _SIDD_CMP_RANGES | _SIDD_NEGATIVE_POLARITY); - if (c != 16) { + if (c != 16 && haystack + c < limit) { return haystack + c; } } @@ -175,7 +177,35 @@ find_byte_not_in_range_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t } -static uint8_t* (*find_byte_not_in_range_impl)(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple; +static uint8_t* +find_byte_not_in_range_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) { + const size_t extra = (uintptr_t)haystack % sizeof(__m256i); + if (extra) { // need aligned loads for performance so search first few bytes by hand + size_t es = MIN(sz, sizeof(__m256i) - extra); + uint8_t *ans = (has_sse4_2 && extra > 15) ? find_byte_not_in_range_sse4_2(haystack, es, a, b) : find_byte_not_in_range_simple(haystack, es, a, b); + if (ans) return ans; + sz -= es; + haystack += es; + if (!sz) return NULL; + } + __m256i lower_bound = _mm256_set1_epi8(a); + __m256i upper_bound = _mm256_set1_epi8(b); + for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m256i)) { + __m256i chunk = _mm256_load_si256((__m256i*)(haystack)); + __m256i above_lower = _mm256_cmpgt_epi8(chunk, lower_bound); + __m256i below_upper = _mm256_cmpgt_epi8(upper_bound, chunk); + __m256i in_range = _mm256_and_si256(above_lower, below_upper); + int mask = _mm256_movemask_epi8(in_range); + if (mask != (int)0xFFFFFFFF) { + // The trailing zeroes in ~mask give us the position + const size_t pos = __builtin_ctz(~mask); + if (haystack + pos < limit) return haystack + pos; + } + } + return NULL; +} + +static uint8_t* (*find_byte_not_in_range_impl)(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple; uint8_t* find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) { @@ -186,12 +216,21 @@ find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a, cons bool init_simd(void *x) { PyObject *module = (PyObject*)x; - bool has_sse4_2 = __builtin_cpu_supports("sse4.2") != 0; bool has_avx2 = __builtin_cpu_supports("avx2"); - if (0 != PyModule_AddObjectRef(module, "has_sse4_2", has_sse4_2 ? Py_True : Py_False)) return false; - if (0 != PyModule_AddObjectRef(module, "has_avx2", has_avx2 ? Py_True : Py_False)) return false; - if (has_sse4_2) { - find_byte_not_in_range_impl = find_byte_not_in_range_sse4_2; - find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2; +#define A(x, val) { Py_INCREF(Py_##val); if (0 != PyModule_AddObject(module, #x, Py_##val)) return false; } + has_sse4_2 = __builtin_cpu_supports("sse4.2") != 0; has_avx2 = __builtin_cpu_supports("avx2"); + if (has_avx2) { + A(has_avx2, True); + find_byte_not_in_range_impl = find_byte_not_in_range_avx2; + } else { + A(has_avx2, False); } + if (has_sse4_2) { + A(has_sse4_2, True); + if (find_byte_not_in_range == find_byte_not_in_range_simple) find_byte_not_in_range_impl = find_byte_not_in_range_sse4_2; + find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2; + } else { + A(has_sse4_2, False); + } +#undef A return true; }