Add AVX2 implementation of find byte not in range

Also fix an alignment bug and ensure the SIMD finders don't return a pointer
beyond the end
This commit is contained in:
Kovid Goyal
2023-11-10 19:34:26 +05:30
parent 021dd168e5
commit e4c48a5f17

View File

@@ -9,6 +9,8 @@
#include "simd-string.h"
#include <immintrin.h>
static bool has_sse4_2 = false, has_avx2 = false;
// ByteLoader {{{
uint8_t
byte_loader_peek(const ByteLoader *self) {
@@ -92,20 +94,20 @@ find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_
static uint8_t*
find_either_of_two_bytes_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) {
size_t extra = (uintptr_t)haystack % sizeof(__m128i);
const size_t extra = (uintptr_t)haystack % sizeof(__m128i);
if (extra) { // need aligned loads for performance so search first few bytes by hand
uint8_t *ans = find_either_of_two_bytes_simple(haystack, MIN(sz, extra), needle_[0], needle_[1]);
const size_t es = MIN(sz, sizeof(__m128i) - extra);
uint8_t *ans = find_either_of_two_bytes_simple(haystack, es, needle_[0], needle_[1]);
if (ans) return ans;
extra = MIN(extra, sz);
sz -= extra;
haystack += extra;
sz -= es;
haystack += es;
if (!sz) return NULL;
}
const __m128i needle = _mm_load_si128((const __m128i *)needle_);
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += 16) {
const __m128i h = _mm_load_si128((const __m128i *)haystack);
int c = _mm_cmpistri(needle, h, _SIDD_CMP_EQUAL_ANY);
if (c != 16) {
if (c != 16 && haystack + c < limit) {
return haystack + c;
}
}
@@ -143,20 +145,20 @@ find_byte_not_in_range_simple(uint8_t *haystack, const size_t sz, const uint8_t
static uint8_t*
find_byte_not_in_range_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) {
size_t extra = (uintptr_t)haystack % sizeof(__m128i);
const size_t extra = (uintptr_t)haystack % sizeof(__m128i);
if (extra) { // need aligned loads for performance so search first few bytes by hand
uint8_t *ans = find_byte_not_in_range_simple(haystack, MIN(sz, extra), needle_[0], needle_[1]);
size_t es = MIN(sz, sizeof(__m128i) - extra);
uint8_t *ans = find_byte_not_in_range_simple(haystack, es, needle_[0], needle_[1]);
if (ans) return ans;
extra = MIN(extra, sz);
sz -= extra;
haystack += extra;
sz -= es;
haystack += es;
if (!sz) return NULL;
}
const __m128i needle = _mm_load_si128((const __m128i *)needle_);
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += 16) {
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m128i)) {
const __m128i h = _mm_load_si128((const __m128i *)haystack);
int c = _mm_cmpistri(needle, h, _SIDD_CMP_RANGES | _SIDD_NEGATIVE_POLARITY);
if (c != 16) {
if (c != 16 && haystack + c < limit) {
return haystack + c;
}
}
@@ -175,7 +177,35 @@ find_byte_not_in_range_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t
}
static uint8_t* (*find_byte_not_in_range_impl)(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple;
/*
 * AVX2 search for the first byte of haystack[0, sz) that lies outside the
 * inclusive range [a, b], comparing bytes as unsigned values (the same
 * semantics as the _SIDD_CMP_RANGES comparison used by the SSE4.2 finder).
 *
 * Returns a pointer to the first such byte, or NULL when every byte is in
 * range (or sz is 0). The returned pointer is always < haystack + sz.
 */
static uint8_t*
find_byte_not_in_range_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
    // An empty range contains no byte, so the very first byte is the answer.
    // Handling it here also keeps the (b - a) span below from wrapping around.
    if (a > b) return sz ? haystack : NULL;

    const size_t extra = (uintptr_t)haystack % sizeof(__m256i);
    if (extra) { // need aligned loads for performance so search first few bytes by hand
        const size_t es = MIN(sz, sizeof(__m256i) - extra);
        // NOTE(review): `extra > 15` selects the SSE4.2 helper when at most 16
        // prefix bytes remain; the intent may have been `es > 15`. Either
        // helper yields the same result, so the condition is left unchanged.
        uint8_t *ans = (has_sse4_2 && extra > 15)
            ? find_byte_not_in_range_sse4_2(haystack, es, a, b)
            : find_byte_not_in_range_simple(haystack, es, a, b);
        if (ans) return ans;
        sz -= es;
        haystack += es;
        if (!sz) return NULL;
    }

    // AVX2 only offers signed byte comparisons, so test a <= x <= b as the
    // unsigned check (x - a) <= (b - a): subtract with wraparound, then use the
    // unsigned minimum to express "<=".
    const __m256i lower_bound = _mm256_set1_epi8((char)a);
    const __m256i span = _mm256_set1_epi8((char)(uint8_t)(b - a));
    for (const uint8_t *limit = haystack + sz; haystack < limit; haystack += sizeof(__m256i)) {
        const __m256i chunk = _mm256_load_si256((const __m256i*)haystack);
        const __m256i offset = _mm256_sub_epi8(chunk, lower_bound);
        // min(offset, span) == offset  <=>  offset <= span  <=>  byte in [a, b]
        const __m256i in_range = _mm256_cmpeq_epi8(_mm256_min_epu8(offset, span), offset);
        // One bit per byte; a set bit marks a byte outside the range.
        const uint32_t outside = ~(uint32_t)_mm256_movemask_epi8(in_range);
        if (outside) {
            // The trailing zeroes give the position of the first offending byte
            const size_t pos = (size_t)__builtin_ctz(outside);
            // The last load may cover bytes past limit; never report those.
            if (haystack + pos < limit) return haystack + pos;
        }
    }
    return NULL;
}
static uint8_t* (*find_byte_not_in_range_impl)(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple;
uint8_t*
find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
@@ -186,12 +216,21 @@ find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a, cons
bool
init_simd(void *x) {
PyObject *module = (PyObject*)x;
bool has_sse4_2 = __builtin_cpu_supports("sse4.2") != 0; bool has_avx2 = __builtin_cpu_supports("avx2");
if (0 != PyModule_AddObjectRef(module, "has_sse4_2", has_sse4_2 ? Py_True : Py_False)) return false;
if (0 != PyModule_AddObjectRef(module, "has_avx2", has_avx2 ? Py_True : Py_False)) return false;
if (has_sse4_2) {
find_byte_not_in_range_impl = find_byte_not_in_range_sse4_2;
find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2;
#define A(x, val) { Py_INCREF(Py_##val); if (0 != PyModule_AddObject(module, #x, Py_##val)) return false; }
has_sse4_2 = __builtin_cpu_supports("sse4.2") != 0; has_avx2 = __builtin_cpu_supports("avx2");
if (has_avx2) {
A(has_avx2, True);
find_byte_not_in_range_impl = find_byte_not_in_range_avx2;
} else {
A(has_avx2, False);
}
if (has_sse4_2) {
A(has_sse4_2, True);
if (find_byte_not_in_range == find_byte_not_in_range_simple) find_byte_not_in_range_impl = find_byte_not_in_range_sse4_2;
find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2;
} else {
A(has_sse4_2, False);
}
#undef A
return true;
}