From b83e4d88f45cf04c525b129fe0dea8a1a2b573e4 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Wed, 28 Jun 2023 20:52:35 +0530 Subject: [PATCH] Implement fast padding-less base64 encode/decode for python --- gen-apc-parsers.py | 5 +- kittens/transfer/ftc.go | 4 +- kitty/base64.h | 125 +++++++++++++++++++++++++++++++++ kitty/charsets.c | 33 ++------- kitty/data-types.c | 29 ++++++++ kitty/data-types.h | 1 - kitty/fast_data_types.pyi | 2 + kitty/file_transmission.py | 20 ++++-- kitty/parse-graphics-command.h | 8 +-- kitty/parser.c | 2 + kitty_tests/parser.py | 12 +++- 11 files changed, 197 insertions(+), 44 deletions(-) create mode 100644 kitty/base64.h diff --git a/gen-apc-parsers.py b/gen-apc-parsers.py index bf98c53a9..1df9beb07 100755 --- a/gen-apc-parsers.py +++ b/gen-apc-parsers.py @@ -111,8 +111,9 @@ def generate( payload_case = f''' case PAYLOAD: {{ sz = screen->parser_buf_pos - pos; - const char *err = base64_decode(screen->parser_buf + pos, sz, payload, sizeof(payload), &g.payload_sz); - if (err != NULL) {{ REPORT_ERROR("Failed to parse {command_class} command payload with error: %s", err); return; }} + g.payload_sz = sizeof(payload); + if (!base64_decode32(screen->parser_buf + pos, sz, payload, &g.payload_sz)) {{ + REPORT_ERROR("Failed to parse {command_class} command payload with error: %s", "output buffer for base64_decode too small"); return; }} pos = screen->parser_buf_pos; }} break; diff --git a/kittens/transfer/ftc.go b/kittens/transfer/ftc.go index 668ef6bae..7166ed303 100644 --- a/kittens/transfer/ftc.go +++ b/kittens/transfer/ftc.go @@ -249,7 +249,7 @@ func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand case reflect.String: switch field.Tag.Get("encoding") { case "base64": - b, err := base64.StdEncoding.DecodeString(serialized_val) + b, err := base64.RawStdEncoding.DecodeString(serialized_val) if err != nil { return fmt.Errorf("The field %#v has invalid base64 encoded value with error: %w", key, err) } @@ -260,7 +260,7 @@ func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand case reflect.Slice: switch val.Type().Elem().Kind() { case reflect.Uint8: - b, err := base64.StdEncoding.DecodeString(serialized_val) + b, err := base64.RawStdEncoding.DecodeString(serialized_val) if err != nil { return fmt.Errorf("The field %#v has invalid base64 encoded value with error: %w", key, err) } diff --git a/kitty/base64.h b/kitty/base64.h new file mode 100644 index 000000000..7417fbed4 --- /dev/null +++ b/kitty/base64.h @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2023 Kovid Goyal + * + * Distributed under terms of the GPL3 license. + */ + +#include +#include +#include + +#ifndef B64_INPUT_BITSIZE +#define B64_INPUT_BITSIZE 8 +#endif + +#if B64_INPUT_BITSIZE == 8 +#define INPUT_T uint8_t +#define inner_func base64_decode_inner8 +#define decode_func base64_decode8 +#define encode_func base64_encode8 +#else +#define INPUT_T uint32_t +#define inner_func base64_decode_inner32 +#define decode_func base64_decode32 +#define encode_func base64_encode32 +#endif + +bool decode_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, size_t *dest_sz); +bool encode_func(const unsigned char *src, size_t src_len, unsigned char *out, size_t *out_len, bool add_padding); +#ifndef B64_INCLUDED_ONCE +static inline size_t required_buffer_size_for_base64_decode(size_t src_sz) { return (src_sz / 4) * 3 + 4; } +static inline size_t required_buffer_size_for_base64_encode(size_t src_sz) { return (src_sz / 3) * 4 + 5; } +#endif + +#ifndef B64_INCLUDED_ONCE +#define B64_INCLUDED_ONCE +#endif + +#ifdef INCLUDE_BASE64_DEFINITIONS +#if B64_INPUT_BITSIZE == 8 +// standard decoding using + and / with = being the padding character +static uint8_t b64_decoding_table[256] = { +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 62, 0, 0, 0, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 +}; +#endif + +static void +inner_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, const size_t dest_sz) { + for (size_t i = 0, j = 0; i < src_sz;) { + uint32_t sextet_a = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; + uint32_t sextet_b = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; + uint32_t sextet_c = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; + uint32_t sextet_d = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; + uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); + + if (j < dest_sz) dest[j++] = (triple >> 2 * 8) & 0xFF; + if (j < dest_sz) dest[j++] = (triple >> 1 * 8) & 0xFF; + if (j < dest_sz) dest[j++] = (triple >> 0 * 8) & 0xFF; + } +} + +bool +decode_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, size_t *dest_sz) { + while (src_sz && src[src_sz-1] == '=') src_sz--; // remove trailing padding + if (!src_sz) { *dest_sz = 0; return true; } + const size_t dest_capacity = *dest_sz; + size_t extra = src_sz % 4; + src_sz -= extra; + *dest_sz = (src_sz / 4) * 3; + if (*dest_sz > dest_capacity) return false; + if (src_sz) inner_func(src, src_sz, dest, *dest_sz); + if (extra > 1) { + INPUT_T buf[4] = {0}; + for (size_t i = 0; i < extra; i++) buf[i] = src[src_sz+i]; + dest += *dest_sz; + *dest_sz += extra - 1; + if (*dest_sz > dest_capacity) return false; + inner_func(buf, extra, dest, extra-1); + } + if (*dest_sz + 1 > dest_capacity) return false; + dest[*dest_sz] = 0; // ensure zero-terminated + return true; +} + +#if B64_INPUT_BITSIZE == 8 +static const unsigned char base64_table[65] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +#endif + +bool +encode_func(const unsigned char *src, size_t src_len, unsigned char *out, size_t *out_len, bool add_padding) { + size_t required_len = required_buffer_size_for_base64_encode(src_len); + if (*out_len < required_len) return false; + + const unsigned char *end = src + src_len, *in = src; + unsigned char *pos = out; + while (end - in >= 3) { + *pos++ = base64_table[in[0] >> 2]; + *pos++ = base64_table[((in[0] & 0x03) << 4) | (in[1] >> 4)]; + *pos++ = base64_table[((in[1] & 0x0f) << 2) | (in[2] >> 6)]; + *pos++ = base64_table[in[2] & 0x3f]; + in += 3; + } + + if (end - in) { + *pos++ = base64_table[in[0] >> 2]; + if (end - in == 1) { + *pos++ = base64_table[(in[0] & 0x03) << 4]; + if (add_padding) *pos++ = '='; + } else { + *pos++ = base64_table[((in[0] & 0x03) << 4) | + (in[1] >> 4)]; + *pos++ = base64_table[(in[1] & 0x0f) << 2]; + } + if (add_padding) *pos++ = '='; + } + *pos = '\0'; + *out_len = pos - out; + return true; +} +#undef encode_func +#undef decode_func +#undef inner_func +#undef INPUT_T +#undef B64_INPUT_BITSIZE +#endif diff --git a/kitty/charsets.c b/kitty/charsets.c index ddcd80df6..9dbf664e0 100644 --- a/kitty/charsets.c +++ b/kitty/charsets.c @@ -290,31 +290,8 @@ encode_utf8(uint32_t ch, char* dest) { // Base64 -// standard decoding using + and / with = being the padding character -static uint8_t b64_decoding_table[256] = { -0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 62, 0, 0, 0, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 -}; - -static char length_error[96] = {0}; - -const char* -base64_decode(const uint32_t *src, size_t src_sz, uint8_t *dest, size_t dest_capacity, size_t *dest_sz) { - if (!src_sz) { *dest_sz = 0; return NULL; } - if (src_sz % 4 != 0) { snprintf(length_error, sizeof(length_error)-1, "base64 encoded data must have a length that is a multiple of four not: %zd", src_sz); return length_error;} - *dest_sz = (src_sz / 4) * 3; - if (src[src_sz - 1] == '=') (*dest_sz)--; - if (src[src_sz - 2] == '=') (*dest_sz)--; - if (*dest_sz > dest_capacity) return "output buffer too small"; - for (size_t i = 0, j = 0; i < src_sz;) { - uint32_t sextet_a = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; - uint32_t sextet_b = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; - uint32_t sextet_c = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; - uint32_t sextet_d = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; - uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); - - if (j < *dest_sz) dest[j++] = (triple >> 2 * 8) & 0xFF; - if (j < *dest_sz) dest[j++] = (triple >> 1 * 8) & 0xFF; - if (j < *dest_sz) dest[j++] = (triple >> 0 * 8) & 0xFF; - } - return NULL; -} +#define B64_INPUT_BITSIZE 8 +#define INCLUDE_BASE64_DEFINITIONS +#include "base64.h" +#define B64_INPUT_BITSIZE 32 +#include "base64.h" diff --git a/kitty/data-types.c b/kitty/data-types.c index 46a9542b5..a29b80832 100644 --- a/kitty/data-types.c +++ b/kitty/data-types.c @@ -13,6 +13,7 @@ #endif #include "data-types.h" +#include "base64.h" #include #include #include @@ -75,6 +76,32 @@ redirect_std_streams(PyObject UNUSED *self, PyObject *args) { Py_RETURN_NONE; } +static PyObject* +pybase64_encode(PyObject UNUSED *self, PyObject *args) { + int add_padding = 0; + const char *src; Py_ssize_t src_len; + if (!PyArg_ParseTuple(args, "y#|p", &src, &src_len, &add_padding)) return NULL; + size_t sz = required_buffer_size_for_base64_encode(src_len); + PyObject *ans = PyBytes_FromStringAndSize(NULL, sz); + if (!ans) return NULL; + base64_encode8((const unsigned char*)src, src_len, (unsigned char*)PyBytes_AS_STRING(ans), &sz, add_padding); + if (_PyBytes_Resize(&ans, sz) != 0) return NULL; + return ans; +} + +static PyObject* +pybase64_decode(PyObject UNUSED *self, PyObject *args) { + const char *src; Py_ssize_t src_len; + if (!PyArg_ParseTuple(args, "y#", &src, &src_len)) return NULL; + size_t sz = required_buffer_size_for_base64_decode(src_len); + PyObject *ans = PyBytes_FromStringAndSize(NULL, sz); + if (!ans) return NULL; + base64_decode8((const unsigned char*)src, src_len, (unsigned char*)PyBytes_AS_STRING(ans), &sz); + if (_PyBytes_Resize(&ans, sz) != 0) return NULL; + return ans; +} + + static PyObject* pyset_iutf8(PyObject UNUSED *self, PyObject *args) { int fd, on; @@ -306,6 +333,8 @@ static PyMethodDef module_methods[] = { {"raw_tty", raw_tty, METH_VARARGS, ""}, {"close_tty", close_tty, METH_VARARGS, ""}, {"set_iutf8_fd", (PyCFunction)pyset_iutf8, METH_VARARGS, ""}, + {"base64_encode", (PyCFunction)pybase64_encode, METH_VARARGS, ""}, + {"base64_decode", (PyCFunction)pybase64_decode, METH_VARARGS, ""}, {"thread_write", (PyCFunction)cm_thread_write, METH_VARARGS, ""}, {"parse_bytes", (PyCFunction)parse_bytes, METH_VARARGS, ""}, {"parse_bytes_dump", (PyCFunction)parse_bytes_dump, METH_VARARGS, ""}, diff --git a/kitty/data-types.h b/kitty/data-types.h index 44b31e0b0..3a99a6355 100644 --- a/kitty/data-types.h +++ b/kitty/data-types.h @@ -346,7 +346,6 @@ attrs_to_cursor(const CellAttrs attrs, Cursor *c) { // Global functions -const char* base64_decode(const uint32_t *src, size_t src_sz, uint8_t *dest, size_t dest_capacity, size_t *dest_sz); Line* alloc_line(void); Cursor* alloc_cursor(void); LineBuf* alloc_linebuf(unsigned int, unsigned int); diff --git a/kitty/fast_data_types.pyi b/kitty/fast_data_types.pyi index 3dd584199..5f0b37a45 100644 --- a/kitty/fast_data_types.pyi +++ b/kitty/fast_data_types.pyi @@ -1534,3 +1534,5 @@ def expand_ansi_c_escapes(test: str) -> str: ... def update_tab_bar_edge_colors(os_window_id: int) -> bool: ... def mask_kitty_signals_process_wide() -> None: ... def is_modifier_key(key: int) -> bool: ... +def base64_encode(src: bytes, add_padding: bool = False) -> bytes: ... +def base64_decode(src: bytes) -> bytes: ... diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py index 97b781e62..135086669 100644 --- a/kitty/file_transmission.py +++ b/kitty/file_transmission.py @@ -6,7 +6,7 @@ import os import re import stat import tempfile -from base64 import standard_b64decode, standard_b64encode +from base64 import standard_b64decode from collections import defaultdict, deque from contextlib import suppress from dataclasses import Field, dataclass, field, fields @@ -19,7 +19,7 @@ from typing import IO, Any, Callable, DefaultDict, Deque, Dict, Iterable, Iterat from kittens.transfer.librsync import LoadSignature, PatchFile, delta_for_file, signature_of_file from kittens.transfer.utils import IdentityCompressor, ZlibCompressor, abspath, expand_home, home_path -from kitty.fast_data_types import FILE_TRANSFER_CODE, OSC, add_timer, get_boss, get_options +from kitty.fast_data_types import FILE_TRANSFER_CODE, OSC, add_timer, base64_encode, get_boss, get_options from kitty.types import run_once from .utils import log_error @@ -247,6 +247,14 @@ def serialized_to_field_map() -> Dict[bytes, 'Field[Any]']: return ans +def b64decode(val: memoryview) -> bytes: + extra = len(val) % 4 + if extra != 0: + padding = b'=' * (4 - extra) + val = memoryview(bytes(val) + padding) + return standard_b64decode(val) + + @dataclass class FileTransmissionCommand: @@ -308,10 +316,10 @@ class FileTransmissionCommand: if issubclass(k.type, Enum): yield val.name elif k.type is bytes: - yield standard_b64encode(val) + yield base64_encode(val) elif k.type is str: if k.metadata.get('base64'): - yield standard_b64encode(val.encode('utf-8')) + yield base64_encode(val.encode('utf-8')) else: yield safe_string(val) elif k.type is int: @@ -335,12 +343,12 @@ class FileTransmissionCommand: if issubclass(field.type, Enum): setattr(ans, field.name, field.type[decode_utf8_buffer(val)]) elif field.type is bytes: - setattr(ans, field.name, standard_b64decode(val)) + setattr(ans, field.name, b64decode(val)) elif field.type is int: setattr(ans, field.name, int(val)) elif field.type is str: if field.metadata.get('base64'): - sval = standard_b64decode(val).decode('utf-8') + sval = b64decode(val).decode('utf-8') else: sval = safe_string(decode_utf8_buffer(val)) setattr(ans, field.name, safe_string(sval)) diff --git a/kitty/parse-graphics-command.h b/kitty/parse-graphics-command.h index d28a41fd8..62b55a44a 100644 --- a/kitty/parse-graphics-command.h +++ b/kitty/parse-graphics-command.h @@ -299,12 +299,12 @@ static inline void parse_graphics_code(Screen *screen, case PAYLOAD: { sz = screen->parser_buf_pos - pos; - const char *err = base64_decode(screen->parser_buf + pos, sz, payload, - sizeof(payload), &g.payload_sz); - if (err != NULL) { + g.payload_sz = sizeof(payload); + if (!base64_decode32(screen->parser_buf + pos, sz, payload, + &g.payload_sz)) { REPORT_ERROR( "Failed to parse GraphicsCommand command payload with error: %s", - err); + "output buffer for base64_decode too small"); return; } pos = screen->parser_buf_pos; diff --git a/kitty/parser.c b/kitty/parser.c index 140bcf5c1..85f6f12c3 100644 --- a/kitty/parser.c +++ b/kitty/parser.c @@ -8,6 +8,8 @@ #define _POSIX_C_SOURCE 200809L #include "data-types.h" +#define B64_INPUT_BITSIZE 32 +#include "base64.h" #include "control-codes.h" #include "screen.h" #include "graphics.h" diff --git a/kitty_tests/parser.py b/kitty_tests/parser.py index c6d3486d9..330acc85f 100644 --- a/kitty_tests/parser.py +++ b/kitty_tests/parser.py @@ -6,7 +6,7 @@ from base64 import standard_b64encode from binascii import hexlify from functools import partial -from kitty.fast_data_types import CURSOR_BLOCK, parse_bytes, parse_bytes_dump +from kitty.fast_data_types import CURSOR_BLOCK, parse_bytes, parse_bytes_dump, base64_decode, base64_encode from kitty.notify import NotificationCommand, handle_notification_cmd, notification_activated, reset_registry from . import BaseTest @@ -41,6 +41,16 @@ class TestParser(BaseTest): q.append(('draw', current)) self.ae(tuple(q), cmds) + def test_base64(self): + for src, expected in { + 'bGlnaHQgdw==': 'light w', + 'bGlnaHQgd28=': 'light wo', + 'bGlnaHQgd29y': 'light wor', + }.items(): + self.ae(base64_decode(src.encode()), expected.encode(), f'Decoding of {src} failed') + self.ae(base64_decode(src.replace('=', '').encode()), expected.encode(), f'Decoding of {src} failed') + self.ae(base64_encode(expected.encode()), src.replace('=', '').encode(), f'Encoding of {expected} failed') + def test_simple_parsing(self): s = self.create_screen() pb = partial(self.parse_bytes_dump, s)