Switch to same algorithm for 128bit SIMD as used for 256 bit SIMD

Avoids needing to write to the haystack and also less chance of a bug in
the never tested simd since all CPUs I have access to have AVX2
This commit is contained in:
Kovid Goyal
2023-11-10 21:16:55 +05:30
parent 1925d5ea65
commit fe2cd543ba
3 changed files with 52 additions and 98 deletions

View File

@@ -66,8 +66,8 @@ byte_loader_skip(ByteLoader *self) {
#define prepare_for_hasvalue(n) (~0ULL/255 * (n))
#define hasvalue(x,n) (haszero((x) ^ (n)))
static uint8_t*
find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_t x, const uint8_t y) {
static const uint8_t*
find_either_of_two_bytes_simple(const uint8_t *haystack, const size_t sz, const uint8_t x, const uint8_t y) {
ByteLoader it; byte_loader_init(&it, (uint8_t*)haystack, sz);
// first align by testing the first few bytes one at a time
@@ -79,7 +79,7 @@ find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_
const BYTE_LOADER_T a = prepare_for_hasvalue(x), b = prepare_for_hasvalue(y);
while (it.num_left) {
if (hasvalue(it.m, a) || hasvalue(it.m, b)) {
uint8_t *ans = haystack + sz - it.num_left, q = hasvalue(it.m, a) ? x : y;
const uint8_t *ans = haystack + sz - it.num_left, q = hasvalue(it.m, a) ? x : y;
while (it.num_left) {
if (byte_loader_next(&it) == q) return ans;
ans++;
@@ -92,43 +92,19 @@ find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_
}
#undef SHIFT_OP
static uint8_t*
find_either_of_two_bytes_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) {
const size_t extra = (uintptr_t)haystack % sizeof(__m128i);
if (extra) { // need aligned loads for performance so search first few bytes by hand
const size_t es = MIN(sz, sizeof(__m128i) - extra);
uint8_t *ans = find_either_of_two_bytes_simple(haystack, es, needle_[0], needle_[1]);
if (ans) return ans;
sz -= es;
haystack += es;
if (!sz) return NULL;
}
const __m128i needle = _mm_load_si128((const __m128i *)needle_);
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += 16) {
const __m128i h = _mm_load_si128((const __m128i *)haystack);
int c = _mm_cmpistri(needle, h, _SIDD_CMP_EQUAL_ANY);
if (c != 16 && haystack + c < limit) {
return haystack + c;
}
}
return NULL;
}
static uint8_t*
find_either_of_two_bytes_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t x, const uint8_t y) {
uint8_t before = haystack[sz];
haystack[sz] = 0;
uint8_t needle[16] = {x, y, 0,0,0,0,0,0,0,0,0,0,0,0,0,0};
uint8_t *ans = find_either_of_two_bytes_sse4_2_impl(haystack, needle, sz);
haystack[sz] = before;
return ans;
}
#define _mm128_set1_epi8 _mm_set1_epi8
#define _mm128_load_si128 _mm_load_si128
#define _mm128_cmpeq_epi8 _mm_cmpeq_epi8
#define _mm128_or_si128 _mm_or_si128
#define _mm128_movemask_epi8 _mm_movemask_epi8
#define _mm128_cmpgt_epi8 _mm_cmpgt_epi8
#define _mm128_and_si128 _mm_and_si128
#define start_simd2(bits, aligner) \
const size_t extra = (uintptr_t)haystack % sizeof(__m##bits##i); \
if (extra) { /* do aligned loading */ \
size_t es = MIN(sz, sizeof(__m##bits##i) - extra); \
uint8_t *ans = aligner; \
const uint8_t *ans = aligner; \
if (ans) return ans; \
sz -= es; \
haystack += es; \
@@ -154,23 +130,29 @@ find_either_of_two_bytes_sse4_2(uint8_t *haystack, const size_t sz, const uint8_
end_simd2; \
} return NULL;
static uint8_t*
find_either_of_two_bytes_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
static const uint8_t*
find_either_of_two_bytes_sse4_2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
either_of_two(128, find_either_of_two_bytes_simple(haystack, es, a, b));
}
static const uint8_t*
find_either_of_two_bytes_avx2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
either_of_two(256, (has_sse4_2 && es > 15) ? find_either_of_two_bytes_sse4_2(haystack, es, a, b) : find_either_of_two_bytes_simple(haystack, es, a, b));
}
static uint8_t* (*find_either_of_two_bytes_impl)(uint8_t*, const size_t, const uint8_t, const uint8_t) = find_either_of_two_bytes_simple;
static const uint8_t* (*find_either_of_two_bytes_impl)(const uint8_t*, const size_t, const uint8_t, const uint8_t) = find_either_of_two_bytes_simple;
uint8_t*
find_either_of_two_bytes(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
const uint8_t*
find_either_of_two_bytes(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
return (uint8_t*)find_either_of_two_bytes_impl(haystack, sz, a, b);
}
// }}}
// find_byte_not_in_range {{{
static uint8_t*
find_byte_not_in_range_simple(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
static const uint8_t*
find_byte_not_in_range_simple(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
ByteLoader it; byte_loader_init(&it, haystack, sz);
while (it.num_left) {
const uint8_t ch = byte_loader_next(&it);
@@ -179,59 +161,31 @@ find_byte_not_in_range_simple(uint8_t *haystack, const size_t sz, const uint8_t
return NULL;
}
static uint8_t*
find_byte_not_in_range_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) {
const size_t extra = (uintptr_t)haystack % sizeof(__m128i);
if (extra) { // need aligned loads for performance so search first few bytes by hand
size_t es = MIN(sz, sizeof(__m128i) - extra);
uint8_t *ans = find_byte_not_in_range_simple(haystack, es, needle_[0], needle_[1]);
if (ans) return ans;
sz -= es;
haystack += es;
if (!sz) return NULL;
}
const __m128i needle = _mm_load_si128((const __m128i *)needle_);
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m128i)) {
const __m128i h = _mm_load_si128((const __m128i *)haystack);
int c = _mm_cmpistri(needle, h, _SIDD_CMP_RANGES | _SIDD_NEGATIVE_POLARITY);
if (c != 16 && haystack + c < limit) {
return haystack + c;
}
}
return NULL;
}
static uint8_t*
find_byte_not_in_range_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
uint8_t before = haystack[sz];
haystack[sz] = 0;
uint8_t needle[16] = {a, b, 0, 0, 0,0,0,0,0,0,0,0,0,0,0,0};
uint8_t *ans = (uint8_t*)find_byte_not_in_range_sse4_2_impl((uint8_t*)haystack, needle, sz);
haystack[sz] = before;
return ans;
}
#define not_in_range(bits, aligner) \
start_simd2(bits, aligner) { \
__m256i chunk = _mm256_load_si256((__m256i*)(haystack)); \
__m256i above_lower = _mm256_cmpgt_epi8(chunk, a_vec); \
__m256i below_upper = _mm256_cmpgt_epi8(b_vec, chunk); \
__m256i in_range = _mm256_and_si256(above_lower, below_upper); \
const int mask = ~_mm256_movemask_epi8(in_range); /* ~ as we want not in range */ \
__m##bits##i chunk = _mm##bits##_load_si##bits((__m##bits##i*)(haystack)); \
__m##bits##i above_lower = _mm##bits##_cmpgt_epi8(chunk, a_vec); \
__m##bits##i below_upper = _mm##bits##_cmpgt_epi8(b_vec, chunk); \
__m##bits##i in_range = _mm##bits##_and_si##bits(above_lower, below_upper); \
const int mask = ~_mm##bits##_movemask_epi8(in_range); /* ~ as we want not in range */ \
end_simd2; \
} return NULL;
static uint8_t*
find_byte_not_in_range_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
static const uint8_t*
find_byte_not_in_range_sse4_2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
not_in_range(128, find_byte_not_in_range_simple(haystack, es, a, b));
}
static const uint8_t*
find_byte_not_in_range_avx2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
not_in_range(256, (has_sse4_2 && extra > 15) ? find_byte_not_in_range_sse4_2(haystack, es, a, b) : find_byte_not_in_range_simple(haystack, es, a, b));
}
static uint8_t* (*find_byte_not_in_range_impl)(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple;
static const uint8_t* (*find_byte_not_in_range_impl)(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple;
uint8_t*
find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
const uint8_t*
find_byte_not_in_range(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
return (uint8_t*)find_byte_not_in_range_impl(haystack, sz, a, b);
}
// }}}

View File

@@ -8,6 +8,7 @@
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#define BYTE_LOADER_T unsigned long long
typedef struct ByteLoader {
@@ -24,13 +25,12 @@ uint8_t byte_loader_next(ByteLoader *self);
// Pass a PyModule PyObject* as the argument. Must be called once at application startup
bool init_simd(void* module);
// Requires haystack[sz] to be writable and 7 bytes to the left of haystack to
// be readable. Returns pointer to first position in haystack that contains
// either of the two chars or NULL if not found.
uint8_t* find_either_of_two_bytes(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b);
// Requires 7 bytes to the left of haystack to be readable. Returns pointer to
// first position in haystack that contains either of the two chars or NULL if
// not found.
const uint8_t* find_either_of_two_bytes(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b);
// Requires haystack[sz] to be writable and 7 bytes to the left of haystack to
// be readable. Returns pointer to first position in haystack that contains
// a char that is not in [a, b]. a must be <= b
uint8_t*
find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a1, const uint8_t b);
// Requires 7 bytes to the left of haystack to be readable. Returns pointer to
// first position in haystack that contains a char that is not in [a, b].
// a must be <= b
const uint8_t* find_byte_not_in_range(const uint8_t *haystack, const size_t sz, const uint8_t a1, const uint8_t b);

View File

@@ -284,7 +284,7 @@ consume_normal(PS *self) {
do {
if (self->utf8.state == UTF8_ACCEPT) {
size_t sz = self->read.sz - self->read.pos;
uint8_t *p = find_byte_not_in_range(self->buf + self->read.pos, sz, 32, 126);
const uint8_t *p = find_byte_not_in_range(self->buf + self->read.pos, sz, 32, 126);
if (p != NULL) sz = p - (self->buf + self->read.pos);
if (sz) dispatch_printable_ascii(self, sz);
else dispatch_normal_mode_byte(self, self->buf[self->read.pos++]);
@@ -408,7 +408,7 @@ consume_esc(PS *self) {
static bool
find_st_terminator(PS *self, size_t *end_pos) {
const size_t sz = self->read.sz - self->read.pos;
uint8_t *q = find_either_of_two_bytes(self->buf + self->read.pos, sz, BEL, ESC_ST);
const uint8_t *q = find_either_of_two_bytes(self->buf + self->read.pos, sz, BEL, ESC_ST);
if (q == NULL) {
self->read.pos += sz;
return false;
@@ -1459,7 +1459,7 @@ consume_input(PS *self) {
static bool
find_pending_stop_csi(PS *self) {
const size_t sz = self->read.sz - self->read.pos;
uint8_t *q = find_either_of_two_bytes(self->buf + self->read.pos, sz, ESC, 'l');
const uint8_t *q = find_either_of_two_bytes(self->buf + self->read.pos, sz, ESC, 'l');
if (q == NULL) {
self->read.pos += sz;
return false;