Switch to SIMD accelerated base64 decoding for clipboard and notification requests

This commit is contained in:
Kovid Goyal
2024-07-26 14:31:52 +05:30
parent ea112a6592
commit 3d0747e713
6 changed files with 188 additions and 79 deletions

View File

@@ -14,6 +14,7 @@ from .fast_data_types import (
ESC_OSC,
GLFW_CLIPBOARD,
GLFW_PRIMARY_SELECTION,
StreamingBase64Decoder,
find_in_memoryview,
get_boss,
get_clipboard_mime,
@@ -236,6 +237,7 @@ class WriteRequest:
self, is_primary_selection: bool = False, protocol_type: ProtocolType = ProtocolType.osc_52, id: str = '',
rollover_size: int = 16 * 1024 * 1024, max_size: int = -1,
) -> None:
self.decoder = StreamingBase64Decoder(8 * 1024)
self.id = id
self.is_primary_selection = is_primary_selection
self.protocol_type = protocol_type
@@ -243,7 +245,6 @@ class WriteRequest:
self.tempfile = Tempfile(max_size=rollover_size)
self.mime_map: Dict[str, MimePos] = {}
self.currently_writing_mime = ''
self.current_leftover_bytes = memoryview(b'')
self.max_size = (get_options().clipboard_max_size * 1024 * 1024) if max_size < 0 else max_size
self.aliases: Dict[str, str] = {}
self.committed = False
@@ -276,51 +277,29 @@ class WriteRequest:
if not self.currently_writing_mime:
self.mime_map[mime] = MimePos(self.tempfile.tell(), -1)
self.currently_writing_mime = mime
self.write_base64_data(data)
def write_saving_leftover_bytes(data: bytes) -> None:
if len(data) == 0:
return
extra = len(data) % 4
if extra > 0:
mv = memoryview(data)
self.current_leftover_bytes = memoryview(bytes(mv[-extra:]))
mv = mv[:-extra]
if len(mv) > 0:
self.write_base64_data(mv)
else:
self.write_base64_data(data)
if len(self.current_leftover_bytes) > 0:
extra = 4 - len(self.current_leftover_bytes)
if len(data) >= extra:
self.write_base64_data(memoryview(bytes(self.current_leftover_bytes) + data[:extra]))
self.current_leftover_bytes = memoryview(b'')
data = memoryview(data)[extra:]
write_saving_leftover_bytes(data)
else:
self.current_leftover_bytes = memoryview(bytes(self.current_leftover_bytes) + data)
else:
write_saving_leftover_bytes(data)
@property
def current_leftover_bytes(self) -> memoryview:
return self.decoder.leftover_bytes()
def flush_base64_data(self) -> None:
if self.currently_writing_mime:
b = self.current_leftover_bytes
padding = 4 - len(b)
if padding in (1, 2):
self.write_base64_data(memoryview(bytes(b) + b'=' * padding))
self.decoder.flush()
if len(self.decoder):
self.write_base64_data(b'')
start = self.mime_map[self.currently_writing_mime][0]
self.mime_map[self.currently_writing_mime] = MimePos(start, self.tempfile.tell() - start)
self.currently_writing_mime = ''
self.current_leftover_bytes = memoryview(b'')
def write_base64_data(self, b: bytes) -> None:
from base64 import standard_b64decode
if not self.max_size_exceeded:
d = standard_b64decode(b)
self.tempfile.write(d)
if self.max_size > 0 and self.tempfile.tell() > (self.max_size * 1024 * 1024):
log_error(f'Clipboard write request has more data than allowed by clipboard_max_size ({self.max_size}), truncating')
self.max_size_exceeded = True
self.decoder.add(b)
if len(self.decoder):
self.tempfile.write(self.decoder.take_output())
if self.max_size > 0 and self.tempfile.tell() > (self.max_size * 1024 * 1024):
log_error(f'Clipboard write request has more data than allowed by clipboard_max_size ({self.max_size}), truncating')
self.max_size_exceeded = True
def data_for(self, mime: str = 'text/plain', offset: int = 0, size: int = -1) -> bytes:
start, full_size = self.mime_map[mime]

View File

@@ -105,6 +105,138 @@ pybase64_decode(PyObject UNUSED *self, PyObject *args) {
return ans;
}
typedef struct StreamingBase64Decoder {
PyObject_HEAD
PyObject *output;
size_t output_sz, output_capacity, num_leftover_bytes, initial_capacity;
unsigned char leftover_bytes[8];
} StreamingBase64Decoder;
static int
StreamingBase64Decoder_init(PyObject *s, PyObject *args, PyObject *kwds) {
static char *kwlist[] = {"initial_capacity", NULL};
unsigned long initial_capacity = 8 * 1024;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|k", kwlist, &initial_capacity)) return -1;
StreamingBase64Decoder *self = (StreamingBase64Decoder*)s;
self->output = PyBytes_FromStringAndSize(NULL, initial_capacity);
if (!self->output) return -1;
self->output_capacity = initial_capacity;
self->initial_capacity = initial_capacity;
return 0;
}
static void
StreamingBase64Decoder_dealloc(PyObject *self) {
StreamingBase64Decoder *h = (StreamingBase64Decoder*)self;
Py_CLEAR(h->output);
Py_TYPE(self)->tp_free(self);
}
static bool
write_base64_data(StreamingBase64Decoder *self, const void *data, size_t len) {
if (!len) return true;
size_t sz = required_buffer_size_for_base64_decode(len);
if ((self->output_sz + sz) > self->output_capacity) {
size_t cap = MAX(self->output_capacity * 2, self->output_sz + sz + self->initial_capacity);
if (_PyBytes_Resize(&self->output, cap) != 0) return false;
self->output_capacity = cap;
}
if (!base64_decode8(data, len, (unsigned char*)(PyBytes_AS_STRING(self->output) + self->output_sz), &sz)) {
PyErr_SetString(PyExc_ValueError, "Invalid base64 input data");
return false;
}
self->output_sz += sz;
return true;
}
static bool
write_saving_leftover_bytes(StreamingBase64Decoder *self, const unsigned char *data, size_t len) {
size_t extra = len % 4;
if (!write_base64_data(self, data, len - extra)) return false;
self->num_leftover_bytes = extra;
if (extra) memcpy(self->leftover_bytes, data + len - extra, extra);
return true;
}
static PyObject*
StreamingBase64Decoder_add(StreamingBase64Decoder *self, PyObject *a) {
RAII_PY_BUFFER(data);
if (PyObject_GetBuffer(a, &data, PyBUF_SIMPLE) != 0) return NULL;
if (!data.buf || !data.len) return PyLong_FromLong(0);
unsigned char *d = data.buf; size_t dlen = data.len;
size_t before = self->output_sz;
if (self->num_leftover_bytes) {
size_t extra = 4 - self->num_leftover_bytes;
if (dlen >= extra) {
memcpy(self->leftover_bytes + self->num_leftover_bytes, d, extra);
if (!write_base64_data(self, self->leftover_bytes, self->num_leftover_bytes + extra)) return NULL;
self->num_leftover_bytes = 0;
d += extra; dlen -= extra;
if (!write_saving_leftover_bytes(self, d, dlen)) return NULL;
} else {
memcpy(self->leftover_bytes + self->num_leftover_bytes, d, dlen);
self->num_leftover_bytes += dlen;
}
} else if (!write_saving_leftover_bytes(self, d, dlen)) return NULL;
return PyLong_FromSize_t(self->output_sz - before);
}
static Py_ssize_t
StreamingBase64Decoder_len(PyObject *s) { return ((StreamingBase64Decoder*)s)->output_sz; }
static PyObject*
StreamingBase64Decoder_leftover_bytes(StreamingBase64Decoder *self, PyObject *a UNUSED) {
return PyMemoryView_FromMemory((char*)self->leftover_bytes, self->num_leftover_bytes, PyBUF_READ);
}
static PyObject*
StreamingBase64Decoder_flush(StreamingBase64Decoder *self, PyObject *args UNUSED) {
size_t padding = 4 - self->num_leftover_bytes;
switch(padding) {
case 1: self->leftover_bytes[self->num_leftover_bytes++] = '='; break;
case 2: self->leftover_bytes[self->num_leftover_bytes++] = '='; self->leftover_bytes[self->num_leftover_bytes++] = '='; break;
}
write_base64_data(self, self->leftover_bytes, self->num_leftover_bytes);
self->num_leftover_bytes = 0;
Py_RETURN_NONE;
}
static PyObject*
StreamingBase64Decoder_copy_output(StreamingBase64Decoder *self, PyObject *args UNUSED) {
return PyBytes_FromStringAndSize(PyBytes_AS_STRING(self->output), self->output_sz);
}
static PyObject*
StreamingBase64Decoder_take_output(StreamingBase64Decoder *self, PyObject *args UNUSED) {
RAII_PyObject(newbuf, PyBytes_FromStringAndSize(NULL, self->initial_capacity));
if (!newbuf) return NULL;
if (_PyBytes_Resize(&self->output, self->output_sz) != 0) return NULL;
PyObject *ans = self->output;
self->output = Py_NewRef(newbuf); self->output_sz = 0; self->output_capacity = self->initial_capacity;
return ans;
}
static PyTypeObject StreamingBase64Decoder_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "kitty.fast_data_types.StreamingBase64Decoder",
.tp_basicsize = sizeof(StreamingBase64Decoder),
.tp_dealloc = StreamingBase64Decoder_dealloc,
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_doc = "StreamingBase64Decoder",
.tp_methods = (PyMethodDef[]){
{"add", (PyCFunction)StreamingBase64Decoder_add, METH_O, ""},
{"flush", (PyCFunction)StreamingBase64Decoder_flush, METH_NOARGS, ""},
{"take_output", (PyCFunction)StreamingBase64Decoder_take_output, METH_NOARGS, ""},
{"copy_output", (PyCFunction)StreamingBase64Decoder_copy_output, METH_NOARGS, ""},
{"leftover_bytes", (PyCFunction)StreamingBase64Decoder_leftover_bytes, METH_NOARGS, ""},
{NULL, NULL, 0, NULL},
},
.tp_new = PyType_GenericNew,
.tp_init = StreamingBase64Decoder_init,
.tp_as_sequence = &(PySequenceMethods){
.sq_length = StreamingBase64Decoder_len,
},
};
static PyObject*
pyset_iutf8(PyObject UNUSED *self, PyObject *args) {
@@ -611,5 +743,8 @@ PyInit_fast_data_types(void) {
PyModule_AddIntConstant(m, "SHM_NAME_MAX", MIN(1023, PATH_MAX));
#endif
if (PyType_Ready(&StreamingBase64Decoder_Type) < 0) return NULL;
if (PyModule_AddObject(m, "StreamingBase64Decoder", (PyObject *) &StreamingBase64Decoder_Type) < 0) return NULL;
return m;
}

View File

@@ -52,6 +52,10 @@ static inline void cleanup_free(void *p) { free(*(void**)p); }
static inline void cleanup_decref(PyObject **p) { Py_CLEAR(*p); }
#define RAII_PyObject(name, initializer) __attribute__((cleanup(cleanup_decref))) PyObject *name = initializer
#define RAII_PY_BUFFER(name) __attribute__((cleanup(PyBuffer_Release))) Py_buffer name = {0}
#if PY_VERSION_HEX < 0x030a0000
static inline PyObject* Py_NewRef(PyObject *o) { Py_INCREF(o); return o; }
static inline PyObject* Py_XNewRef(PyObject *o) { Py_XINCREF(o); return o; }
#endif
typedef unsigned long long id_type;
typedef uint32_t char_type;

View File

@@ -305,6 +305,9 @@ WINDOW_MINIMIZED: int
# }}}
ReadOnlyBuffer = Union[bytes, bytearray, memoryview]
def encode_key_for_tty(
key: int = 0,
shifted_key: int = 0,
@@ -1700,6 +1703,16 @@ class MousePosition(TypedDict):
def get_mouse_data_for_window(os_window_id: int, tab_id: int, window_id: int) -> Optional[MousePosition]: ...
class StreamingBase64Decoder:
def __init__(self, initial_capacity: int = 8 *1024) -> None: ... # set the initial output buffer capacity
def add(self, data: ReadOnlyBuffer) -> int: ... # add the base64 data
def flush(self) -> None: ... # indicate end of base64 data, left over bytes are processed as if they were followed by padding
def take_output(self) -> bytes: ... # take the output so far. The decoder no longer references this output
def copy_output(self) -> bytes: ... # copy the output so far
def __len__(self) -> int: ... # return the length of the current output
def leftover_bytes(self) -> memoryview: ... # return the currently leftover bytes that will be consumed by flush()
class DiskCache:
small_hole_threshold: int
defrag_factor: int

View File

@@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, FrozenSet, Iterator, List, NamedTuple, O
from weakref import ReferenceType, ref
from .constants import cache_dir, is_macos, logo_png_file
from .fast_data_types import ESC_OSC, base64_decode, current_focused_os_window_id, get_boss
from .fast_data_types import ESC_OSC, StreamingBase64Decoder, current_focused_os_window_id, get_boss
from .types import run_once
from .typing import WindowType
from .utils import get_custom_window_icon, log_error, sanitize_control_codes
@@ -146,7 +146,7 @@ class DataStore:
class EncodedDataStore:
def __init__(self, data_store: DataStore) -> None:
self.current_leftover_bytes = memoryview(b'')
self.decoder = StreamingBase64Decoder(initial_capacity=4096)
self.data_store = data_store
@property
@@ -162,41 +162,14 @@ class EncodedDataStore:
def add_base64_data(self, data: Union[str, bytes]) -> None:
if isinstance(data, str):
data = data.encode('ascii')
def write_saving_leftover_bytes(data: bytes) -> None:
if len(data) == 0:
return
extra = len(data) % 4
if extra > 0:
mv = memoryview(data)
self.current_leftover_bytes = memoryview(bytes(mv[-extra:]))
mv = mv[:-extra]
if len(mv) > 0:
self._write_base64_data(mv)
else:
self._write_base64_data(data)
if len(self.current_leftover_bytes) > 0:
extra = 4 - len(self.current_leftover_bytes)
if len(data) >= extra:
self._write_base64_data(memoryview(bytes(self.current_leftover_bytes) + data[:extra]))
self.current_leftover_bytes = memoryview(b'')
data = memoryview(data)[extra:]
write_saving_leftover_bytes(data)
else:
self.current_leftover_bytes = memoryview(bytes(self.current_leftover_bytes) + data)
else:
write_saving_leftover_bytes(data)
def _write_base64_data(self, b: bytes) -> None:
self.data_store(base64_decode(b))
self.decoder.add(data)
if len(self.decoder) >= self.data_store.max_size:
self.data_store(self.decoder.take_output())
def flush_encoded_data(self) -> None:
b = self.current_leftover_bytes
self.current_leftover_bytes = memoryview(b'')
padding = 4 - len(b)
if padding in (1, 2):
self._write_base64_data(memoryview(bytes(b) + b'=' * padding))
self.decoder.flush()
if len(self.decoder):
self.data_store(self.decoder.take_output())
def finalise(self) -> bytes:
self.flush_encoded_data()

View File

@@ -214,14 +214,19 @@ def do_test(self: 'TestNotifications', tdir: str) -> None:
self.ae(ch.responses, [f'99;i=0:p=?;{qr}'])
# Test MIME streaming
text = 'some reasonably long text to test MIME streaming with'
encoded = standard_b64encode(text.encode()).decode()
for ch in encoded:
h(f'i=s:e=1:d=0;{ch}')
h(f'i=s:e=1:d=0:p=body;{encoded[:13]}')
h(f'i=s:e=1:d=0:p=body;{encoded[13:]}')
h('i=s')
self.ae(di.notifications, [n(text, text)])
for padding in (True, False):
for extra in ('a', 'ab', 'abc', 'abcd'):
text = 'some reasonably long text to test MIME streaming with: '
encoded = standard_b64encode(text.encode()).decode()
if not padding:
encoded = encoded.rstrip('=')
for t in encoded:
h(f'i=s:e=1:d=0;{t}')
h(f'i=s:e=1:d=0:p=body;{encoded[:13]}')
h(f'i=s:e=1:d=0:p=body;{encoded[13:]}')
h('i=s')
self.ae(di.notifications, [n(text, text)])
reset()
# Test Disk Cache
dc = IconDataCache(base_cache_dir=tdir, max_cache_size=4)