mirror of
https://github.com/kovidgoyal/kitty
synced 2026-06-11 11:09:16 +02:00
Add AVX2 implementation of find byte not in range
Also fix alignment bug and ensure the simd finders dont return a pointer beyond the end
This commit is contained in:
@@ -9,6 +9,8 @@
|
||||
#include "simd-string.h"
|
||||
#include <immintrin.h>
|
||||
|
||||
static bool has_sse4_2 = false, has_avx2 = false;
|
||||
|
||||
// ByteLoader {{{
|
||||
uint8_t
|
||||
byte_loader_peek(const ByteLoader *self) {
|
||||
@@ -92,20 +94,20 @@ find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_
|
||||
|
||||
static uint8_t*
|
||||
find_either_of_two_bytes_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) {
|
||||
size_t extra = (uintptr_t)haystack % sizeof(__m128i);
|
||||
const size_t extra = (uintptr_t)haystack % sizeof(__m128i);
|
||||
if (extra) { // need aligned loads for performance so search first few bytes by hand
|
||||
uint8_t *ans = find_either_of_two_bytes_simple(haystack, MIN(sz, extra), needle_[0], needle_[1]);
|
||||
const size_t es = MIN(sz, sizeof(__m128i) - extra);
|
||||
uint8_t *ans = find_either_of_two_bytes_simple(haystack, es, needle_[0], needle_[1]);
|
||||
if (ans) return ans;
|
||||
extra = MIN(extra, sz);
|
||||
sz -= extra;
|
||||
haystack += extra;
|
||||
sz -= es;
|
||||
haystack += es;
|
||||
if (!sz) return NULL;
|
||||
}
|
||||
const __m128i needle = _mm_load_si128((const __m128i *)needle_);
|
||||
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += 16) {
|
||||
const __m128i h = _mm_load_si128((const __m128i *)haystack);
|
||||
int c = _mm_cmpistri(needle, h, _SIDD_CMP_EQUAL_ANY);
|
||||
if (c != 16) {
|
||||
if (c != 16 && haystack + c < limit) {
|
||||
return haystack + c;
|
||||
}
|
||||
}
|
||||
@@ -143,20 +145,20 @@ find_byte_not_in_range_simple(uint8_t *haystack, const size_t sz, const uint8_t
|
||||
|
||||
static uint8_t*
|
||||
find_byte_not_in_range_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) {
|
||||
size_t extra = (uintptr_t)haystack % sizeof(__m128i);
|
||||
const size_t extra = (uintptr_t)haystack % sizeof(__m128i);
|
||||
if (extra) { // need aligned loads for performance so search first few bytes by hand
|
||||
uint8_t *ans = find_byte_not_in_range_simple(haystack, MIN(sz, extra), needle_[0], needle_[1]);
|
||||
size_t es = MIN(sz, sizeof(__m128i) - extra);
|
||||
uint8_t *ans = find_byte_not_in_range_simple(haystack, es, needle_[0], needle_[1]);
|
||||
if (ans) return ans;
|
||||
extra = MIN(extra, sz);
|
||||
sz -= extra;
|
||||
haystack += extra;
|
||||
sz -= es;
|
||||
haystack += es;
|
||||
if (!sz) return NULL;
|
||||
}
|
||||
const __m128i needle = _mm_load_si128((const __m128i *)needle_);
|
||||
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += 16) {
|
||||
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m128i)) {
|
||||
const __m128i h = _mm_load_si128((const __m128i *)haystack);
|
||||
int c = _mm_cmpistri(needle, h, _SIDD_CMP_RANGES | _SIDD_NEGATIVE_POLARITY);
|
||||
if (c != 16) {
|
||||
if (c != 16 && haystack + c < limit) {
|
||||
return haystack + c;
|
||||
}
|
||||
}
|
||||
@@ -175,7 +177,35 @@ find_byte_not_in_range_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t
|
||||
|
||||
}
|
||||
|
||||
static uint8_t* (*find_byte_not_in_range_impl)(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple;
|
||||
static uint8_t*
|
||||
find_byte_not_in_range_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
|
||||
const size_t extra = (uintptr_t)haystack % sizeof(__m256i);
|
||||
if (extra) { // need aligned loads for performance so search first few bytes by hand
|
||||
size_t es = MIN(sz, sizeof(__m256i) - extra);
|
||||
uint8_t *ans = (has_sse4_2 && extra > 15) ? find_byte_not_in_range_sse4_2(haystack, es, a, b) : find_byte_not_in_range_simple(haystack, es, a, b);
|
||||
if (ans) return ans;
|
||||
sz -= es;
|
||||
haystack += es;
|
||||
if (!sz) return NULL;
|
||||
}
|
||||
__m256i lower_bound = _mm256_set1_epi8(a);
|
||||
__m256i upper_bound = _mm256_set1_epi8(b);
|
||||
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m256i)) {
|
||||
__m256i chunk = _mm256_load_si256((__m256i*)(haystack));
|
||||
__m256i above_lower = _mm256_cmpgt_epi8(chunk, lower_bound);
|
||||
__m256i below_upper = _mm256_cmpgt_epi8(upper_bound, chunk);
|
||||
__m256i in_range = _mm256_and_si256(above_lower, below_upper);
|
||||
int mask = _mm256_movemask_epi8(in_range);
|
||||
if (mask != (int)0xFFFFFFFF) {
|
||||
// The trailing zeroes in ~mask give us the position
|
||||
const size_t pos = __builtin_ctz(~mask);
|
||||
if (haystack + pos < limit) return haystack + pos;
|
||||
}
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static uint8_t* (*find_byte_not_in_range_impl)(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple;
|
||||
|
||||
uint8_t*
|
||||
find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
|
||||
@@ -186,12 +216,21 @@ find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a, cons
|
||||
bool
|
||||
init_simd(void *x) {
|
||||
PyObject *module = (PyObject*)x;
|
||||
bool has_sse4_2 = __builtin_cpu_supports("sse4.2") != 0; bool has_avx2 = __builtin_cpu_supports("avx2");
|
||||
if (0 != PyModule_AddObjectRef(module, "has_sse4_2", has_sse4_2 ? Py_True : Py_False)) return false;
|
||||
if (0 != PyModule_AddObjectRef(module, "has_avx2", has_avx2 ? Py_True : Py_False)) return false;
|
||||
if (has_sse4_2) {
|
||||
find_byte_not_in_range_impl = find_byte_not_in_range_sse4_2;
|
||||
find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2;
|
||||
#define A(x, val) { Py_INCREF(Py_##val); if (0 != PyModule_AddObject(module, #x, Py_##val)) return false; }
|
||||
has_sse4_2 = __builtin_cpu_supports("sse4.2") != 0; has_avx2 = __builtin_cpu_supports("avx2");
|
||||
if (has_avx2) {
|
||||
A(has_avx2, True);
|
||||
find_byte_not_in_range_impl = find_byte_not_in_range_avx2;
|
||||
} else {
|
||||
A(has_avx2, False);
|
||||
}
|
||||
if (has_sse4_2) {
|
||||
A(has_sse4_2, True);
|
||||
if (find_byte_not_in_range == find_byte_not_in_range_simple) find_byte_not_in_range_impl = find_byte_not_in_range_sse4_2;
|
||||
find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2;
|
||||
} else {
|
||||
A(has_sse4_2, False);
|
||||
}
|
||||
#undef A
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user