mirror of
https://github.com/kovidgoyal/kitty
synced 2026-06-11 19:19:35 +02:00
Switch to same algorithm for 128bit SIMD as used for 256 bit SIMD
Avoids needing to write to the haystack and also less chance of a bug in the never tested simd since all CPUs I have access to have AVX2
This commit is contained in:
@@ -66,8 +66,8 @@ byte_loader_skip(ByteLoader *self) {
|
||||
#define prepare_for_hasvalue(n) (~0ULL/255 * (n))
|
||||
#define hasvalue(x,n) (haszero((x) ^ (n)))
|
||||
|
||||
static uint8_t*
|
||||
find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_t x, const uint8_t y) {
|
||||
static const uint8_t*
|
||||
find_either_of_two_bytes_simple(const uint8_t *haystack, const size_t sz, const uint8_t x, const uint8_t y) {
|
||||
ByteLoader it; byte_loader_init(&it, (uint8_t*)haystack, sz);
|
||||
|
||||
// first align by testing the first few bytes one at a time
|
||||
@@ -79,7 +79,7 @@ find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_
|
||||
const BYTE_LOADER_T a = prepare_for_hasvalue(x), b = prepare_for_hasvalue(y);
|
||||
while (it.num_left) {
|
||||
if (hasvalue(it.m, a) || hasvalue(it.m, b)) {
|
||||
uint8_t *ans = haystack + sz - it.num_left, q = hasvalue(it.m, a) ? x : y;
|
||||
const uint8_t *ans = haystack + sz - it.num_left, q = hasvalue(it.m, a) ? x : y;
|
||||
while (it.num_left) {
|
||||
if (byte_loader_next(&it) == q) return ans;
|
||||
ans++;
|
||||
@@ -92,43 +92,19 @@ find_either_of_two_bytes_simple(uint8_t *haystack, const size_t sz, const uint8_
|
||||
}
|
||||
#undef SHIFT_OP
|
||||
|
||||
static uint8_t*
|
||||
find_either_of_two_bytes_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) {
|
||||
const size_t extra = (uintptr_t)haystack % sizeof(__m128i);
|
||||
if (extra) { // need aligned loads for performance so search first few bytes by hand
|
||||
const size_t es = MIN(sz, sizeof(__m128i) - extra);
|
||||
uint8_t *ans = find_either_of_two_bytes_simple(haystack, es, needle_[0], needle_[1]);
|
||||
if (ans) return ans;
|
||||
sz -= es;
|
||||
haystack += es;
|
||||
if (!sz) return NULL;
|
||||
}
|
||||
const __m128i needle = _mm_load_si128((const __m128i *)needle_);
|
||||
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += 16) {
|
||||
const __m128i h = _mm_load_si128((const __m128i *)haystack);
|
||||
int c = _mm_cmpistri(needle, h, _SIDD_CMP_EQUAL_ANY);
|
||||
if (c != 16 && haystack + c < limit) {
|
||||
return haystack + c;
|
||||
}
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static uint8_t*
|
||||
find_either_of_two_bytes_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t x, const uint8_t y) {
|
||||
uint8_t before = haystack[sz];
|
||||
haystack[sz] = 0;
|
||||
uint8_t needle[16] = {x, y, 0,0,0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
uint8_t *ans = find_either_of_two_bytes_sse4_2_impl(haystack, needle, sz);
|
||||
haystack[sz] = before;
|
||||
return ans;
|
||||
}
|
||||
#define _mm128_set1_epi8 _mm_set1_epi8
|
||||
#define _mm128_load_si128 _mm_load_si128
|
||||
#define _mm128_cmpeq_epi8 _mm_cmpeq_epi8
|
||||
#define _mm128_or_si128 _mm_or_si128
|
||||
#define _mm128_movemask_epi8 _mm_movemask_epi8
|
||||
#define _mm128_cmpgt_epi8 _mm_cmpgt_epi8
|
||||
#define _mm128_and_si128 _mm_and_si128
|
||||
|
||||
#define start_simd2(bits, aligner) \
|
||||
const size_t extra = (uintptr_t)haystack % sizeof(__m##bits##i); \
|
||||
if (extra) { /* do aligned loading */ \
|
||||
size_t es = MIN(sz, sizeof(__m##bits##i) - extra); \
|
||||
uint8_t *ans = aligner; \
|
||||
const uint8_t *ans = aligner; \
|
||||
if (ans) return ans; \
|
||||
sz -= es; \
|
||||
haystack += es; \
|
||||
@@ -154,23 +130,29 @@ find_either_of_two_bytes_sse4_2(uint8_t *haystack, const size_t sz, const uint8_
|
||||
end_simd2; \
|
||||
} return NULL;
|
||||
|
||||
static uint8_t*
|
||||
find_either_of_two_bytes_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
|
||||
static const uint8_t*
|
||||
find_either_of_two_bytes_sse4_2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
|
||||
either_of_two(128, find_either_of_two_bytes_simple(haystack, es, a, b));
|
||||
}
|
||||
|
||||
|
||||
static const uint8_t*
|
||||
find_either_of_two_bytes_avx2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
|
||||
either_of_two(256, (has_sse4_2 && es > 15) ? find_either_of_two_bytes_sse4_2(haystack, es, a, b) : find_either_of_two_bytes_simple(haystack, es, a, b));
|
||||
}
|
||||
|
||||
|
||||
static uint8_t* (*find_either_of_two_bytes_impl)(uint8_t*, const size_t, const uint8_t, const uint8_t) = find_either_of_two_bytes_simple;
|
||||
static const uint8_t* (*find_either_of_two_bytes_impl)(const uint8_t*, const size_t, const uint8_t, const uint8_t) = find_either_of_two_bytes_simple;
|
||||
|
||||
uint8_t*
|
||||
find_either_of_two_bytes(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
|
||||
const uint8_t*
|
||||
find_either_of_two_bytes(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
|
||||
return (uint8_t*)find_either_of_two_bytes_impl(haystack, sz, a, b);
|
||||
}
|
||||
// }}}
|
||||
|
||||
// find_byte_not_in_range {{{
|
||||
static uint8_t*
|
||||
find_byte_not_in_range_simple(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
|
||||
static const uint8_t*
|
||||
find_byte_not_in_range_simple(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
|
||||
ByteLoader it; byte_loader_init(&it, haystack, sz);
|
||||
while (it.num_left) {
|
||||
const uint8_t ch = byte_loader_next(&it);
|
||||
@@ -179,59 +161,31 @@ find_byte_not_in_range_simple(uint8_t *haystack, const size_t sz, const uint8_t
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static uint8_t*
|
||||
find_byte_not_in_range_sse4_2_impl(uint8_t *haystack, const uint8_t* needle_, size_t sz) {
|
||||
const size_t extra = (uintptr_t)haystack % sizeof(__m128i);
|
||||
if (extra) { // need aligned loads for performance so search first few bytes by hand
|
||||
size_t es = MIN(sz, sizeof(__m128i) - extra);
|
||||
uint8_t *ans = find_byte_not_in_range_simple(haystack, es, needle_[0], needle_[1]);
|
||||
if (ans) return ans;
|
||||
sz -= es;
|
||||
haystack += es;
|
||||
if (!sz) return NULL;
|
||||
}
|
||||
const __m128i needle = _mm_load_si128((const __m128i *)needle_);
|
||||
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m128i)) {
|
||||
const __m128i h = _mm_load_si128((const __m128i *)haystack);
|
||||
int c = _mm_cmpistri(needle, h, _SIDD_CMP_RANGES | _SIDD_NEGATIVE_POLARITY);
|
||||
if (c != 16 && haystack + c < limit) {
|
||||
return haystack + c;
|
||||
}
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
static uint8_t*
|
||||
find_byte_not_in_range_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
|
||||
uint8_t before = haystack[sz];
|
||||
haystack[sz] = 0;
|
||||
uint8_t needle[16] = {a, b, 0, 0, 0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
uint8_t *ans = (uint8_t*)find_byte_not_in_range_sse4_2_impl((uint8_t*)haystack, needle, sz);
|
||||
haystack[sz] = before;
|
||||
return ans;
|
||||
|
||||
}
|
||||
|
||||
#define not_in_range(bits, aligner) \
|
||||
start_simd2(bits, aligner) { \
|
||||
__m256i chunk = _mm256_load_si256((__m256i*)(haystack)); \
|
||||
__m256i above_lower = _mm256_cmpgt_epi8(chunk, a_vec); \
|
||||
__m256i below_upper = _mm256_cmpgt_epi8(b_vec, chunk); \
|
||||
__m256i in_range = _mm256_and_si256(above_lower, below_upper); \
|
||||
const int mask = ~_mm256_movemask_epi8(in_range); /* ~ as we want not in range */ \
|
||||
__m##bits##i chunk = _mm##bits##_load_si##bits((__m##bits##i*)(haystack)); \
|
||||
__m##bits##i above_lower = _mm##bits##_cmpgt_epi8(chunk, a_vec); \
|
||||
__m##bits##i below_upper = _mm##bits##_cmpgt_epi8(b_vec, chunk); \
|
||||
__m##bits##i in_range = _mm##bits##_and_si##bits(above_lower, below_upper); \
|
||||
const int mask = ~_mm##bits##_movemask_epi8(in_range); /* ~ as we want not in range */ \
|
||||
end_simd2; \
|
||||
} return NULL;
|
||||
|
||||
static uint8_t*
|
||||
find_byte_not_in_range_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
|
||||
static const uint8_t*
|
||||
find_byte_not_in_range_sse4_2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
|
||||
not_in_range(128, find_byte_not_in_range_simple(haystack, es, a, b));
|
||||
}
|
||||
|
||||
|
||||
static const uint8_t*
|
||||
find_byte_not_in_range_avx2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
|
||||
not_in_range(256, (has_sse4_2 && extra > 15) ? find_byte_not_in_range_sse4_2(haystack, es, a, b) : find_byte_not_in_range_simple(haystack, es, a, b));
|
||||
}
|
||||
|
||||
static uint8_t* (*find_byte_not_in_range_impl)(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple;
|
||||
static const uint8_t* (*find_byte_not_in_range_impl)(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) = find_byte_not_in_range_simple;
|
||||
|
||||
uint8_t*
|
||||
find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
|
||||
const uint8_t*
|
||||
find_byte_not_in_range(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
|
||||
return (uint8_t*)find_byte_not_in_range_impl(haystack, sz, a, b);
|
||||
}
|
||||
// }}}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#define BYTE_LOADER_T unsigned long long
|
||||
typedef struct ByteLoader {
|
||||
@@ -24,13 +25,12 @@ uint8_t byte_loader_next(ByteLoader *self);
|
||||
// Pass a PyModule PyObject* as the argument. Must be called once at application startup
|
||||
bool init_simd(void* module);
|
||||
|
||||
// Requires haystack[sz] to be writable and 7 bytes to the left of haystack to
|
||||
// be readable. Returns pointer to first position in haystack that contains
|
||||
// either of the two chars or NULL if not found.
|
||||
uint8_t* find_either_of_two_bytes(uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b);
|
||||
// Requires 7 bytes to the left of haystack to be readable. Returns pointer to
|
||||
// first position in haystack that contains either of the two chars or NULL if
|
||||
// not found.
|
||||
const uint8_t* find_either_of_two_bytes(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b);
|
||||
|
||||
// Requires haystack[sz] to be writable and 7 bytes to the left of haystack to
|
||||
// be readable. Returns pointer to first position in haystack that contains
|
||||
// a char that is not in [a, b]. a must be <= b
|
||||
uint8_t*
|
||||
find_byte_not_in_range(uint8_t *haystack, const size_t sz, const uint8_t a1, const uint8_t b);
|
||||
// Requires 7 bytes to the left of haystack to be readable. Returns pointer to
|
||||
// first position in haystack that contains a char that is not in [a, b].
|
||||
// a must be <= b
|
||||
const uint8_t* find_byte_not_in_range(const uint8_t *haystack, const size_t sz, const uint8_t a1, const uint8_t b);
|
||||
|
||||
@@ -284,7 +284,7 @@ consume_normal(PS *self) {
|
||||
do {
|
||||
if (self->utf8.state == UTF8_ACCEPT) {
|
||||
size_t sz = self->read.sz - self->read.pos;
|
||||
uint8_t *p = find_byte_not_in_range(self->buf + self->read.pos, sz, 32, 126);
|
||||
const uint8_t *p = find_byte_not_in_range(self->buf + self->read.pos, sz, 32, 126);
|
||||
if (p != NULL) sz = p - (self->buf + self->read.pos);
|
||||
if (sz) dispatch_printable_ascii(self, sz);
|
||||
else dispatch_normal_mode_byte(self, self->buf[self->read.pos++]);
|
||||
@@ -408,7 +408,7 @@ consume_esc(PS *self) {
|
||||
static bool
|
||||
find_st_terminator(PS *self, size_t *end_pos) {
|
||||
const size_t sz = self->read.sz - self->read.pos;
|
||||
uint8_t *q = find_either_of_two_bytes(self->buf + self->read.pos, sz, BEL, ESC_ST);
|
||||
const uint8_t *q = find_either_of_two_bytes(self->buf + self->read.pos, sz, BEL, ESC_ST);
|
||||
if (q == NULL) {
|
||||
self->read.pos += sz;
|
||||
return false;
|
||||
@@ -1459,7 +1459,7 @@ consume_input(PS *self) {
|
||||
static bool
|
||||
find_pending_stop_csi(PS *self) {
|
||||
const size_t sz = self->read.sz - self->read.pos;
|
||||
uint8_t *q = find_either_of_two_bytes(self->buf + self->read.pos, sz, ESC, 'l');
|
||||
const uint8_t *q = find_either_of_two_bytes(self->buf + self->read.pos, sz, ESC, 'l');
|
||||
if (q == NULL) {
|
||||
self->read.pos += sz;
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user