This commit is contained in:
Kovid Goyal
2023-11-10 20:45:46 +05:30
parent a0e1eb4985
commit aacdffd539

View File

@@ -124,30 +124,35 @@ find_either_of_two_bytes_sse4_2(uint8_t *haystack, const size_t sz, const uint8_
return ans;
}
// start_simd2(bits, aligner): shared prologue for the two-byte SIMD searches.
//
// Expands inside a function that returns uint8_t* and has a modifiable
// `uint8_t *haystack`, a modifiable `size_t sz`, and bytes `a` and `b` in scope.
//
//  1. If haystack is not aligned to the size of an __m<bits>i vector, the
//     unaligned prefix of `es` bytes (at most sz) is searched by evaluating the
//     `aligner` expression. `aligner` may refer to the locals `extra` (the
//     misalignment in bytes) and `es` (prefix length) declared here. A non-NULL
//     result is returned immediately; otherwise haystack/sz are advanced past
//     the prefix and NULL is returned if nothing is left.
//  2. `a` and `b` are broadcast into the vectors `a_vec` and `b_vec`.
//  3. The expansion ends with an open `for` header that steps haystack one
//     vector at a time up to `limit` (haystack + sz); the caller supplies the
//     loop body directly after the macro invocation.
//
// NOTE(review): the intrinsic name is built as _mm<bits>_set1_epi8, so this
// presumably only works for bits == 256 (or 512); the 128-bit intrinsic is
// spelled _mm_set1_epi8 -- confirm before using with other widths.
#define start_simd2(bits, aligner) \
const size_t extra = (uintptr_t)haystack % sizeof(__m##bits##i); \
if (extra) { \
size_t es = MIN(sz, sizeof(__m##bits##i) - extra); \
uint8_t *ans = aligner; \
if (ans) return ans; \
sz -= es; \
haystack += es; \
if (!sz) return NULL; \
} \
__m##bits##i a_vec = _mm##bits##_set1_epi8(a); \
__m##bits##i b_vec = _mm##bits##_set1_epi8(b); \
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m##bits##i))
// end_simd2: shared epilogue for the body of a start_simd2 loop.
//
// Expects `mask` (one bit per byte of the current vector, set where the byte
// matches), `haystack` and `limit` to be in scope. The number of trailing zero
// bits in mask is the offset of the first match in the current vector. The
// match is returned only if it lies before `limit`, because the loop always
// advances by a whole vector and the last vector can therefore extend past the
// end of the searched region.
#define end_simd2 \
if (mask != 0) { \
size_t pos = __builtin_ctz(mask); \
if (haystack + pos < limit) return haystack + pos; \
}
// Return a pointer to the first byte in the first sz bytes of haystack that is
// equal to either a or b, or NULL if there is no such byte.
//
// The start_simd2 macro handles the unaligned prefix (searched with the SSE4.2
// helper when it is available and the prefix is at least 16 bytes long,
// otherwise with the scalar helper) and opens the loop over 32-byte aligned
// chunks; end_simd2 turns the per-byte match mask into the returned pointer.
static uint8_t*
find_either_of_two_bytes_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
    start_simd2(256, (has_sse4_2 && es > 15) ? find_either_of_two_bytes_sse4_2(haystack, es, a, b) : find_either_of_two_bytes_simple(haystack, es, a, b)) {
        // haystack is 32-byte aligned here, so the aligned load is valid
        __m256i chunk = _mm256_load_si256((__m256i*)(haystack));
        __m256i a_cmp = _mm256_cmpeq_epi8(chunk, a_vec);
        __m256i b_cmp = _mm256_cmpeq_epi8(chunk, b_vec);
        __m256i matches = _mm256_or_si256(a_cmp, b_cmp);
        // one bit per byte of chunk, set where the byte equals a or b
        const int mask = _mm256_movemask_epi8(matches);
        end_simd2;
    }
    return NULL;
}
@@ -208,28 +213,13 @@ find_byte_not_in_range_sse4_2(uint8_t *haystack, const size_t sz, const uint8_t
// Return a pointer to the first byte in the first sz bytes of haystack that is
// NOT in range, or NULL if every byte is in range.
//
// A byte counts as in range when it compares strictly greater than a and
// strictly less than b using _mm256_cmpgt_epi8.
// NOTE(review): _mm256_cmpgt_epi8 compares bytes as signed 8-bit values, so
// bytes >= 0x80 compare as negative -- confirm this agrees with the scalar and
// SSE4.2 helpers used for the unaligned prefix.
//
// The start_simd2 macro handles the unaligned prefix and opens the loop over
// 32-byte aligned chunks; end_simd2 turns the per-byte mask into the returned
// pointer. The SSE4.2 helper is used for the prefix when it is available and
// the prefix (es) is at least 16 bytes long, matching the selection rule used
// by find_either_of_two_bytes_avx2.
static uint8_t*
find_byte_not_in_range_avx2(uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
    start_simd2(256, (has_sse4_2 && es > 15) ? find_byte_not_in_range_sse4_2(haystack, es, a, b) : find_byte_not_in_range_simple(haystack, es, a, b)) {
        // haystack is 32-byte aligned here, so the aligned load is valid
        __m256i chunk = _mm256_load_si256((__m256i*)(haystack));
        __m256i above_lower = _mm256_cmpgt_epi8(chunk, a_vec);
        __m256i below_upper = _mm256_cmpgt_epi8(b_vec, chunk);
        __m256i in_range = _mm256_and_si256(above_lower, below_upper);
        const int mask = ~_mm256_movemask_epi8(in_range); // ~ as we want not in range
        end_simd2;
    }
    return NULL;
}