A further 5% speedup for UTF-8 decoding

Achieved by decoding in larger chunks thereby amortizing the cost
of creating various constant vectors over larger chunks.
This commit is contained in:
Kovid Goyal
2024-02-01 17:09:12 +05:30
parent 0bccada9d1
commit 6cdc7ac91d
4 changed files with 201 additions and 165 deletions

View File

@@ -413,187 +413,210 @@ scalar_decode_all(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
}
return pos;
}
#undef do_one_byte
bool
FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src_data, size_t src_len) {
// Based on the algorithm described in: https://woboq.com/blog/utf-8-processing-using-simd.html
d->output.pos = 0; d->num_consumed = 0;
if (d->state.cur != UTF8_ACCEPT) {
// Finish the trailing sequence only
d->num_consumed += scalar_decode_to_accept(d, src, src_sz);
src += d->num_consumed; src_sz -= d->num_consumed;
d->num_consumed = scalar_decode_to_accept(d, src_data, src_len);
src_data += d->num_consumed; src_len -= d->num_consumed;
}
src_sz = MIN(src_sz, sizeof(integer_t));
integer_t vec = load_unaligned((integer_t*)src);
const integer_t esc_vec = set1_epi8(0x1b);
const integer_t esc_cmp = cmpeq_epi8(vec, esc_vec);
const integer_t zero = create_zero_integer(), one = set1_epi8(1), two = set1_epi8(2), three = set1_epi8(3), four = set1_epi8(4);
const integer_t vec_c2 = set1_epi8(0xc2), vec_e3 = set1_epi8(0xe3), vec_f4 = set1_epi8(0xf4);
const uint8_t *limit = src_data + src_len, *p = src_data, *start_of_current_chunk = src_data;
bool sentinel_found = false;
int num_of_bytes_to_first_esc = bytes_to_first_match(esc_cmp);
if (num_of_bytes_to_first_esc > -1 && (unsigned)num_of_bytes_to_first_esc < src_sz) {
sentinel_found = true;
src_sz = num_of_bytes_to_first_esc;
d->num_consumed += src_sz + 1; // esc is also consumed
} else d->num_consumed += src_sz;
// use scalar decode for short input
if (src_sz < 4) {
scalar_decode_all(d, src, src_sz); return sentinel_found;
}
if (src_sz < sizeof(integer_t)) vec = zero_last_n_bytes(vec, sizeof(integer_t) - src_sz);
unsigned chunk_src_sz = 0;
unsigned num_of_trailing_bytes = 0;
bool check_for_trailing_bytes = true;
// Check if we have pure ASCII and use fast path
debug_register(vec);
int32_t ascii_mask;
while (p < limit && !sentinel_found) {
chunk_src_sz = MIN((size_t)(limit - p), sizeof(integer_t));
integer_t vec = load_unaligned((integer_t*)p);
start_of_current_chunk = p;
p += chunk_src_sz;
const integer_t esc_cmp = cmpeq_epi8(vec, esc_vec);
int num_of_bytes_to_first_esc = bytes_to_first_match(esc_cmp);
if (num_of_bytes_to_first_esc > -1 && (unsigned)num_of_bytes_to_first_esc < chunk_src_sz) {
sentinel_found = true;
chunk_src_sz = num_of_bytes_to_first_esc;
d->num_consumed += chunk_src_sz + 1; // esc is also consumed
if (!chunk_src_sz) continue;
} else d->num_consumed += chunk_src_sz;
if (chunk_src_sz < sizeof(integer_t)) vec = zero_last_n_bytes(vec, sizeof(integer_t) - chunk_src_sz);
num_of_trailing_bytes = 0;
bool check_for_trailing_bytes = !sentinel_found;
debug_register(vec);
int32_t ascii_mask;
#define abort_with_invalid_utf8() { \
scalar_decode_all(d, start_of_current_chunk, chunk_src_sz + num_of_trailing_bytes); \
d->num_consumed += num_of_trailing_bytes; \
break; \
}
#define handle_trailing_bytes() if (num_of_trailing_bytes) { \
if (p >= limit) { \
scalar_decode_all(d, p - num_of_trailing_bytes, num_of_trailing_bytes); \
d->num_consumed += num_of_trailing_bytes; \
break; \
} \
p -= num_of_trailing_bytes; \
}
start_classification:
ascii_mask = movemask_epi8(vec);
if (!ascii_mask) { // no bytes with high bit (0x80) set, so just plain ASCII
FUNC(output_plain_ascii)(d, vec, src_sz);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
return sentinel_found;
}
// Classify the bytes
integer_t state = set1_epi8(0x80);
const integer_t vec_signed = add_epi8(vec, state); // needed because cmplt_epi8 works only on signed chars
// Check if we have pure ASCII and use fast path
ascii_mask = movemask_epi8(vec);
if (!ascii_mask) { // no bytes with high bit (0x80) set, so just plain ASCII
FUNC(output_plain_ascii)(d, vec, chunk_src_sz);
handle_trailing_bytes();
continue;
}
// Classify the bytes
integer_t state = set1_epi8(0x80);
const integer_t vec_signed = add_epi8(vec, state); // needed because cmplt_epi8 works only on signed chars
const integer_t bytes_indicating_start_of_two_byte_sequence = cmplt_epi8(set1_epi8(0xc0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state, set1_epi8(0xc2), bytes_indicating_start_of_two_byte_sequence);
// state now has 0xc2 on all bytes that start a 2 or more byte sequence and 0x80 on the rest
const integer_t bytes_indicating_start_of_three_byte_sequence = cmplt_epi8(set1_epi8(0xe0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state, set1_epi8(0xe3), bytes_indicating_start_of_three_byte_sequence);
const integer_t bytes_indicating_start_of_four_byte_sequence = cmplt_epi8(set1_epi8(0xf0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state, set1_epi8(0xf4), bytes_indicating_start_of_four_byte_sequence);
// state now has 0xc2 on all bytes that start a 2 byte sequence, 0xe3 on start of 3-byte sequence, 0xf4 on 4-byte start and 0x80 on rest
debug_register(state);
integer_t mask = and_si(state, set1_epi8(0xf8)); // keep upper 5 bits of state
debug_register(mask);
integer_t count = and_si(state, set1_epi8(0x7)); // keep lower 3 bits of state
debug_register(count);
const integer_t zero = create_zero_integer(), one = set1_epi8(1), two = set1_epi8(2), three = set1_epi8(3);
// count contains the number of bytes in the sequence for the start byte of every sequence and zero elsewhere
// shift 02 bytes by 1 and subtract 1
integer_t count_subs1 = subtract_saturate_epu8(count, one);
integer_t counts = add_epi8(count, shift_right_by_one_byte(count_subs1));
// shift 03 and 04 bytes by 2 and subtract 2
counts = add_epi8(counts, shift_right_by_two_bytes(subtract_saturate_epu8(counts, two)));
// counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes
debug_register(counts);
// check for an incomplete trailing utf8 sequence
if (check_for_trailing_bytes && !is_zero(cmplt_epi8(one, and_si(counts, cmpeq_epi8(numbered_bytes(), set1_epi8(src_sz - 1)))))) {
// The value of counts at the last byte is > 1 indicating we have a trailing incomplete sequence
check_for_trailing_bytes = false;
if (src[src_sz-1] >= 0xc0) num_of_trailing_bytes = 1; // 2-, 3- and 4-byte characters with only 1 byte left
else if (src_sz > 1 && src[src_sz-2] >= 0xe0) num_of_trailing_bytes = 2; // 3- and 4-byte characters with only 1 byte left
else if (src_sz > 2 && src[src_sz-3] >= 0xf0) num_of_trailing_bytes = 3; // 4-byte characters with only 3 bytes left
src_sz -= num_of_trailing_bytes;
vec = zero_last_n_bytes(vec, sizeof(integer_t) - src_sz);
goto start_classification;
}
// Only ASCII chars should have corresponding byte of counts == 0
if (ascii_mask != movemask_epi8(cmpgt_epi8(counts, zero))) goto invalid_utf8;
// The difference between a byte in counts and the next one should be negative,
// zero, or one. Any other value means there is not enough continuation bytes.
if (!is_zero(cmpgt_epi8(subtract_epi8(shift_right_by_one_byte(counts), counts), one))) goto invalid_utf8;
const integer_t bytes_indicating_start_of_two_byte_sequence = cmplt_epi8(set1_epi8(0xc0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state, vec_c2, bytes_indicating_start_of_two_byte_sequence);
// state now has 0xc2 on all bytes that start a 2 or more byte sequence and 0x80 on the rest
const integer_t bytes_indicating_start_of_three_byte_sequence = cmplt_epi8(set1_epi8(0xe0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state, vec_e3, bytes_indicating_start_of_three_byte_sequence);
const integer_t bytes_indicating_start_of_four_byte_sequence = cmplt_epi8(set1_epi8(0xf0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state, vec_f4, bytes_indicating_start_of_four_byte_sequence);
// state now has 0xc2 on all bytes that start a 2 byte sequence, 0xe3 on start of 3-byte, 0xf4 on 4-byte start and 0x80 on rest
debug_register(state);
integer_t mask = and_si(state, set1_epi8(0xf8)); // keep upper 5 bits of state
debug_register(mask);
integer_t count = and_si(state, set1_epi8(0x7)); // keep lower 3 bits of state
debug_register(count);
// count contains the number of bytes in the sequence for the start byte of every sequence and zero elsewhere
// shift 02 bytes by 1 and subtract 1
integer_t count_subs1 = subtract_saturate_epu8(count, one);
integer_t counts = add_epi8(count, shift_right_by_one_byte(count_subs1));
// shift 03 and 04 bytes by 2 and subtract 2
counts = add_epi8(counts, shift_right_by_two_bytes(subtract_saturate_epu8(counts, two)));
// counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes
debug_register(counts);
// check for an incomplete trailing utf8 sequence
if (check_for_trailing_bytes && !is_zero(cmplt_epi8(one, and_si(counts, cmpeq_epi8(numbered_bytes(), set1_epi8(chunk_src_sz - 1)))))) {
// The value of counts at the last byte is > 1 indicating we have a trailing incomplete sequence
check_for_trailing_bytes = false;
if (start_of_current_chunk[chunk_src_sz-1] >= 0xc0) num_of_trailing_bytes = 1; // 2-, 3- and 4-byte characters with only 1 byte left
else if (chunk_src_sz > 1 && start_of_current_chunk[chunk_src_sz-2] >= 0xe0) num_of_trailing_bytes = 2; // 3- and 4-byte characters with only 1 byte left
else if (chunk_src_sz > 2 && start_of_current_chunk[chunk_src_sz-3] >= 0xf0) num_of_trailing_bytes = 3; // 4-byte characters with only 3 bytes left
chunk_src_sz -= num_of_trailing_bytes;
d->num_consumed -= num_of_trailing_bytes;
if (!chunk_src_sz) { abort_with_invalid_utf8(); }
vec = zero_last_n_bytes(vec, sizeof(integer_t) - chunk_src_sz);
goto start_classification;
}
// Only ASCII chars should have corresponding byte of counts == 0
if (ascii_mask != movemask_epi8(cmpgt_epi8(counts, zero))) { abort_with_invalid_utf8(); }
// The difference between a byte in counts and the next one should be negative,
// zero, or one. Any other value means there is not enough continuation bytes.
if (!is_zero(cmpgt_epi8(subtract_epi8(shift_right_by_one_byte(counts), counts), one))) { abort_with_invalid_utf8(); }
// Process the bytes storing the three resulting bytes that make up the unicode codepoint
// mask all control bits so that we have only useful bits left
vec = andnot_si(mask, vec);
debug_register(vec);
// Process the bytes storing the three resulting bytes that make up the unicode codepoint
// mask all control bits so that we have only useful bits left
vec = andnot_si(mask, vec);
debug_register(vec);
// Now calculate the three output vectors
// Now calculate the three output vectors
// The lowest byte is made up of 6 bits from locations with counts == 1 and the lowest two bits from locations with count == 2
// In addition, the ASCII bytes are copied unchanged from vec
integer_t vec_non_ascii = andnot_si(cmpeq_epi8(counts, zero), vec);
debug_register(vec_non_ascii);
integer_t output1 = blendv_epi8(vec,
or_si(
// there are no count == 1 locations without a count == 2 location to its left so we dont need to AND with count2_locations
vec, and_si(shift_left_by_bits16(shift_right_by_one_byte(vec_non_ascii), 6), set1_epi8(0xc0))
),
cmpeq_epi8(counts, one)
);
debug_register(output1);
// The lowest byte is made up of 6 bits from locations with counts == 1 and the lowest two bits from locations with count == 2
// In addition, the ASCII bytes are copied unchanged from vec
integer_t vec_non_ascii = andnot_si(cmpeq_epi8(counts, zero), vec);
debug_register(vec_non_ascii);
integer_t output1 = blendv_epi8(vec,
or_si(
// there are no count == 1 locations without a count == 2 location to its left so we dont need to AND with count2_locations
vec, and_si(shift_left_by_bits16(shift_right_by_one_byte(vec_non_ascii), 6), set1_epi8(0xc0))
),
cmpeq_epi8(counts, one)
);
debug_register(output1);
// The next byte is made up of 4 bits (5, 4, 3, 2) from locations with count == 2 and the first 4 bits from locations with count == 3
integer_t count2_locations = cmpeq_epi8(counts, two), count3_locations = cmpeq_epi8(counts, three);
integer_t output2 = and_si(vec, count2_locations);
output2 = shift_right_by_bits32(output2, 2); // selects the bits 5, 4, 3, 2
// select the first 4 bits from locs with count == 3 by shifting count 3 locations right by one byte and left by 4 bits
output2 = or_si(output2,
and_si(set1_epi8(0xf0),
shift_left_by_bits16(shift_right_by_one_byte(and_si(count3_locations, vec_non_ascii)), 4)
)
);
output2 = and_si(output2, count2_locations); // keep only the count2 bytes
output2 = shift_right_by_one_byte(output2);
debug_register(output2);
// The next byte is made up of 4 bits (5, 4, 3, 2) from locations with count == 2 and the first 4 bits from locations with count == 3
integer_t count2_locations = cmpeq_epi8(counts, two), count3_locations = cmpeq_epi8(counts, three);
integer_t output2 = and_si(vec, count2_locations);
output2 = shift_right_by_bits32(output2, 2); // selects the bits 5, 4, 3, 2
// select the first 4 bits from locs with count == 3 by shifting count 3 locations right by one byte and left by 4 bits
output2 = or_si(output2,
and_si(set1_epi8(0xf0),
shift_left_by_bits16(shift_right_by_one_byte(and_si(count3_locations, vec_non_ascii)), 4)
)
);
output2 = and_si(output2, count2_locations); // keep only the count2 bytes
output2 = shift_right_by_one_byte(output2);
debug_register(output2);
// The last byte is made up of bits 5 and 6 from count == 3 and 3 bits from count == 4
integer_t output3 = and_si(three, shift_right_by_bits32(vec, 4)); // bits 5 and 6 from count == 3
integer_t count4_locations = cmpeq_epi8(counts, set1_epi8(4));
// 3 bits from count == 4 locations, placed at count == 3 locations shifted left by 2 bits
output3 = or_si(output3,
and_si(set1_epi8(0xfc),
shift_left_by_bits16(shift_right_by_one_byte(and_si(count4_locations, vec_non_ascii)), 2)
)
);
output3 = and_si(output3, count3_locations); // keep only count3 bytes
output3 = shift_right_by_two_bytes(output3);
debug_register(output3);
// The last byte is made up of bits 5 and 6 from count == 3 and 3 bits from count == 4
integer_t output3 = and_si(three, shift_right_by_bits32(vec, 4)); // bits 5 and 6 from count == 3
integer_t count4_locations = cmpeq_epi8(counts, four);
// 3 bits from count == 4 locations, placed at count == 3 locations shifted left by 2 bits
output3 = or_si(output3,
and_si(set1_epi8(0xfc),
shift_left_by_bits16(shift_right_by_one_byte(and_si(count4_locations, vec_non_ascii)), 2)
)
);
output3 = and_si(output3, count3_locations); // keep only count3 bytes
output3 = shift_right_by_two_bytes(output3);
debug_register(output3);
// Shuffle bytes to remove continuation bytes
integer_t shifts = count_subs1; // number of bytes we need to skip for each UTF-8 sequence
// propagate the shifts to all subsequent bytes by shift and add
shifts = add_epi8(shifts, shift_right_by_one_byte(shifts));
shifts = add_epi8(shifts, shift_right_by_two_bytes(shifts));
shifts = add_epi8(shifts, shift_right_by_four_bytes(shifts));
shifts = add_epi8(shifts, shift_right_by_eight_bytes(shifts));
// Shuffle bytes to remove continuation bytes
integer_t shifts = count_subs1; // number of bytes we need to skip for each UTF-8 sequence
// propagate the shifts to all subsequent bytes by shift and add
shifts = add_epi8(shifts, shift_right_by_one_byte(shifts));
shifts = add_epi8(shifts, shift_right_by_two_bytes(shifts));
shifts = add_epi8(shifts, shift_right_by_four_bytes(shifts));
shifts = add_epi8(shifts, shift_right_by_eight_bytes(shifts));
#if KITTY_SIMD_LEVEL == 256
shifts = add_epi8(shifts, shift_right_by_sixteen_bytes(shifts));
shifts = add_epi8(shifts, shift_right_by_sixteen_bytes(shifts));
#endif
// zero the shifts for discarded continuation bytes
shifts = and_si(shifts, cmplt_epi8(counts, two));
// now we need to convert shifts into a mask for the shuffle. The mask has each byte of the
// form 0000xxxx the lower four bits indicating the destination location for the byte. For 256 bit shuffle we use lower 5 bits.
// First we move the numbers in shifts to discard the unwanted UTF-8 sequence bytes. We note that the numbers
// are bounded by sizeof(integer_t) and so we need at most 4 (for 128 bit) or 5 (for 256 bit) moves. The numbers are
// monotonic from left to right and change value only at the end of a UTF-8 sequence. We move them leftwards, accumulating the
// moves bit-by-bit.
// zero the shifts for discarded continuation bytes
shifts = and_si(shifts, cmplt_epi8(counts, two));
// now we need to convert shifts into a mask for the shuffle. The mask has each byte of the
// form 0000xxxx the lower four bits indicating the destination location for the byte. For 256 bit shuffle we use lower 5 bits.
// First we move the numbers in shifts to discard the unwanted UTF-8 sequence bytes. We note that the numbers
// are bounded by sizeof(integer_t) and so we need at most 4 (for 128 bit) or 5 (for 256 bit) moves. The numbers are
// monotonic from left to right and change value only at the end of a UTF-8 sequence. We move them leftwards, accumulating the
// moves bit-by-bit.
#define move(shifts, amt, which_bit) blendv_epi8(shifts, shift_left_by_##amt(shifts), shift_left_by_##amt(shift_left_by_bits16(shifts, 8 - which_bit)))
shifts = move(shifts, one_byte, 1);
shifts = move(shifts, two_bytes, 2);
shifts = move(shifts, four_bytes, 3);
shifts = move(shifts, eight_bytes, 4);
shifts = move(shifts, one_byte, 1);
shifts = move(shifts, two_bytes, 2);
shifts = move(shifts, four_bytes, 3);
shifts = move(shifts, eight_bytes, 4);
#if KITTY_SIMD_LEVEL == 256
shifts = move(shifts, sixteen_bytes, 5);
shifts = move(shifts, sixteen_bytes, 5);
#endif
#undef move
// convert the shifts into a suitable mask for shuffle by adding the byte number to each byte
shifts = add_epi8(shifts, numbered_bytes());
debug_register(shifts);
// convert the shifts into a suitable mask for shuffle by adding the byte number to each byte
shifts = add_epi8(shifts, numbered_bytes());
debug_register(shifts);
output1 = shuffle_epi8(output1, shifts);
output2 = shuffle_epi8(output2, shifts);
output3 = shuffle_epi8(output3, shifts);
debug_register(output1);
debug_register(output2);
debug_register(output3);
output1 = shuffle_epi8(output1, shifts);
output2 = shuffle_epi8(output2, shifts);
output3 = shuffle_epi8(output3, shifts);
debug_register(output1);
debug_register(output2);
debug_register(output3);
const unsigned num_of_discarded_bytes = sum_bytes(count_subs1);
const unsigned num_codepoints = src_sz - num_of_discarded_bytes;
debug("num_of_discarded_bytes: %u num_codepoints: %u\n", num_of_discarded_bytes, num_codepoints);
FUNC(output_unicode)(d, output1, output2, output3, num_codepoints);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
return sentinel_found;
invalid_utf8:
scalar_decode_all(d, src, src_sz + num_of_trailing_bytes);
const unsigned num_of_discarded_bytes = sum_bytes(count_subs1);
const unsigned num_codepoints = chunk_src_sz - num_of_discarded_bytes;
debug("num_of_discarded_bytes: %u num_codepoints: %u\n", num_of_discarded_bytes, num_codepoints);
FUNC(output_unicode)(d, output1, output2, output3, num_codepoints);
handle_trailing_bytes();
}
return sentinel_found;
#undef abort_with_invalid_utf8
#undef handle_trailing_bytes
}