diff --git a/kitty/clipboard.py b/kitty/clipboard.py index d9f94f2bd..8776244f6 100644 --- a/kitty/clipboard.py +++ b/kitty/clipboard.py @@ -14,6 +14,7 @@ from .fast_data_types import ( ESC_OSC, GLFW_CLIPBOARD, GLFW_PRIMARY_SELECTION, + StreamingBase64Decoder, find_in_memoryview, get_boss, get_clipboard_mime, @@ -236,6 +237,7 @@ class WriteRequest: self, is_primary_selection: bool = False, protocol_type: ProtocolType = ProtocolType.osc_52, id: str = '', rollover_size: int = 16 * 1024 * 1024, max_size: int = -1, ) -> None: + self.decoder = StreamingBase64Decoder(8 * 1024) self.id = id self.is_primary_selection = is_primary_selection self.protocol_type = protocol_type @@ -243,7 +245,6 @@ class WriteRequest: self.tempfile = Tempfile(max_size=rollover_size) self.mime_map: Dict[str, MimePos] = {} self.currently_writing_mime = '' - self.current_leftover_bytes = memoryview(b'') self.max_size = (get_options().clipboard_max_size * 1024 * 1024) if max_size < 0 else max_size self.aliases: Dict[str, str] = {} self.committed = False @@ -276,51 +277,29 @@ class WriteRequest: if not self.currently_writing_mime: self.mime_map[mime] = MimePos(self.tempfile.tell(), -1) self.currently_writing_mime = mime + self.write_base64_data(data) - def write_saving_leftover_bytes(data: bytes) -> None: - if len(data) == 0: - return - extra = len(data) % 4 - if extra > 0: - mv = memoryview(data) - self.current_leftover_bytes = memoryview(bytes(mv[-extra:])) - mv = mv[:-extra] - if len(mv) > 0: - self.write_base64_data(mv) - else: - self.write_base64_data(data) - - if len(self.current_leftover_bytes) > 0: - extra = 4 - len(self.current_leftover_bytes) - if len(data) >= extra: - self.write_base64_data(memoryview(bytes(self.current_leftover_bytes) + data[:extra])) - self.current_leftover_bytes = memoryview(b'') - data = memoryview(data)[extra:] - write_saving_leftover_bytes(data) - else: - self.current_leftover_bytes = memoryview(bytes(self.current_leftover_bytes) + data) - else: - write_saving_leftover_bytes(data) + @property + def current_leftover_bytes(self) -> memoryview: + return self.decoder.leftover_bytes() def flush_base64_data(self) -> None: if self.currently_writing_mime: - b = self.current_leftover_bytes - padding = 4 - len(b) - if padding in (1, 2): - self.write_base64_data(memoryview(bytes(b) + b'=' * padding)) + self.decoder.flush() + if len(self.decoder): + self.write_base64_data(b'') start = self.mime_map[self.currently_writing_mime][0] self.mime_map[self.currently_writing_mime] = MimePos(start, self.tempfile.tell() - start) self.currently_writing_mime = '' - self.current_leftover_bytes = memoryview(b'') def write_base64_data(self, b: bytes) -> None: - from base64 import standard_b64decode if not self.max_size_exceeded: - d = standard_b64decode(b) - self.tempfile.write(d) - if self.max_size > 0 and self.tempfile.tell() > (self.max_size * 1024 * 1024): - log_error(f'Clipboard write request has more data than allowed by clipboard_max_size ({self.max_size}), truncating') - self.max_size_exceeded = True + self.decoder.add(b) + if len(self.decoder): + self.tempfile.write(self.decoder.take_output()) + if self.max_size > 0 and self.tempfile.tell() > (self.max_size * 1024 * 1024): + log_error(f'Clipboard write request has more data than allowed by clipboard_max_size ({self.max_size}), truncating') + self.max_size_exceeded = True def data_for(self, mime: str = 'text/plain', offset: int = 0, size: int = -1) -> bytes: start, full_size = self.mime_map[mime] diff --git a/kitty/data-types.c b/kitty/data-types.c index f23ebb0bd..8dcc3b332 100644 --- a/kitty/data-types.c +++ b/kitty/data-types.c @@ -105,6 +105,138 @@ pybase64_decode(PyObject UNUSED *self, PyObject *args) { return ans; } +typedef struct StreamingBase64Decoder { + PyObject_HEAD + PyObject *output; + size_t output_sz, output_capacity, num_leftover_bytes, initial_capacity; + unsigned char leftover_bytes[8]; +} StreamingBase64Decoder; + +static int +StreamingBase64Decoder_init(PyObject *s, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"initial_capacity", NULL}; + unsigned long initial_capacity = 8 * 1024; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|k", kwlist, &initial_capacity)) return -1; + StreamingBase64Decoder *self = (StreamingBase64Decoder*)s; + self->output = PyBytes_FromStringAndSize(NULL, initial_capacity); + if (!self->output) return -1; + self->output_capacity = initial_capacity; + self->initial_capacity = initial_capacity; + return 0; +} + +static void +StreamingBase64Decoder_dealloc(PyObject *self) { + StreamingBase64Decoder *h = (StreamingBase64Decoder*)self; + Py_CLEAR(h->output); + Py_TYPE(self)->tp_free(self); +} + +static bool +write_base64_data(StreamingBase64Decoder *self, const void *data, size_t len) { + if (!len) return true; + size_t sz = required_buffer_size_for_base64_decode(len); + if ((self->output_sz + sz) > self->output_capacity) { + size_t cap = MAX(self->output_capacity * 2, self->output_sz + sz + self->initial_capacity); + if (_PyBytes_Resize(&self->output, cap) != 0) return false; + self->output_capacity = cap; + } + if (!base64_decode8(data, len, (unsigned char*)(PyBytes_AS_STRING(self->output) + self->output_sz), &sz)) { + PyErr_SetString(PyExc_ValueError, "Invalid base64 input data"); + return false; + } + self->output_sz += sz; + return true; +} + +static bool +write_saving_leftover_bytes(StreamingBase64Decoder *self, const unsigned char *data, size_t len) { + size_t extra = len % 4; + if (!write_base64_data(self, data, len - extra)) return false; + self->num_leftover_bytes = extra; + if (extra) memcpy(self->leftover_bytes, data + len - extra, extra); + return true; +} + +static PyObject* +StreamingBase64Decoder_add(StreamingBase64Decoder *self, PyObject *a) { + RAII_PY_BUFFER(data); + if (PyObject_GetBuffer(a, &data, PyBUF_SIMPLE) != 0) return NULL; + if (!data.buf || !data.len) return PyLong_FromLong(0); + unsigned char *d = data.buf; size_t dlen = data.len; + size_t before = self->output_sz; + if (self->num_leftover_bytes) { + size_t extra = 4 - self->num_leftover_bytes; + if (dlen >= extra) { + memcpy(self->leftover_bytes + self->num_leftover_bytes, d, extra); + if (!write_base64_data(self, self->leftover_bytes, self->num_leftover_bytes + extra)) return NULL; + self->num_leftover_bytes = 0; + d += extra; dlen -= extra; + if (!write_saving_leftover_bytes(self, d, dlen)) return NULL; + } else { + memcpy(self->leftover_bytes + self->num_leftover_bytes, d, dlen); + self->num_leftover_bytes += dlen; + } + } else if (!write_saving_leftover_bytes(self, d, dlen)) return NULL; + return PyLong_FromSize_t(self->output_sz - before); +} + +static Py_ssize_t +StreamingBase64Decoder_len(PyObject *s) { return ((StreamingBase64Decoder*)s)->output_sz; } + +static PyObject* +StreamingBase64Decoder_leftover_bytes(StreamingBase64Decoder *self, PyObject *a UNUSED) { + return PyMemoryView_FromMemory((char*)self->leftover_bytes, self->num_leftover_bytes, PyBUF_READ); +} + +static PyObject* +StreamingBase64Decoder_flush(StreamingBase64Decoder *self, PyObject *args UNUSED) { + size_t padding = 4 - self->num_leftover_bytes; + switch(padding) { + case 1: self->leftover_bytes[self->num_leftover_bytes++] = '='; break; + case 2: self->leftover_bytes[self->num_leftover_bytes++] = '='; self->leftover_bytes[self->num_leftover_bytes++] = '='; break; + } + write_base64_data(self, self->leftover_bytes, self->num_leftover_bytes); + self->num_leftover_bytes = 0; + Py_RETURN_NONE; +} + +static PyObject* +StreamingBase64Decoder_copy_output(StreamingBase64Decoder *self, PyObject *args UNUSED) { + return PyBytes_FromStringAndSize(PyBytes_AS_STRING(self->output), self->output_sz); +} + +static PyObject* +StreamingBase64Decoder_take_output(StreamingBase64Decoder *self, PyObject *args UNUSED) { + RAII_PyObject(newbuf, PyBytes_FromStringAndSize(NULL, self->initial_capacity)); + if (!newbuf) return NULL; + if (_PyBytes_Resize(&self->output, self->output_sz) != 0) return NULL; + PyObject *ans = self->output; + self->output = Py_NewRef(newbuf); self->output_sz = 0; self->output_capacity = self->initial_capacity; + return ans; +} + +static PyTypeObject StreamingBase64Decoder_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "kitty.fast_data_types.StreamingBase64Decoder", + .tp_basicsize = sizeof(StreamingBase64Decoder), + .tp_dealloc = StreamingBase64Decoder_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "StreamingBase64Decoder", + .tp_methods = (PyMethodDef[]){ + {"add", (PyCFunction)StreamingBase64Decoder_add, METH_O, ""}, + {"flush", (PyCFunction)StreamingBase64Decoder_flush, METH_NOARGS, ""}, + {"take_output", (PyCFunction)StreamingBase64Decoder_take_output, METH_NOARGS, ""}, + {"copy_output", (PyCFunction)StreamingBase64Decoder_copy_output, METH_NOARGS, ""}, + {"leftover_bytes", (PyCFunction)StreamingBase64Decoder_leftover_bytes, METH_NOARGS, ""}, + {NULL, NULL, 0, NULL}, + }, + .tp_new = PyType_GenericNew, + .tp_init = StreamingBase64Decoder_init, + .tp_as_sequence = &(PySequenceMethods){ + .sq_length = StreamingBase64Decoder_len, + }, +}; static PyObject* pyset_iutf8(PyObject UNUSED *self, PyObject *args) { @@ -611,5 +743,8 @@ PyInit_fast_data_types(void) { PyModule_AddIntConstant(m, "SHM_NAME_MAX", MIN(1023, PATH_MAX)); #endif + if (PyType_Ready(&StreamingBase64Decoder_Type) < 0) return NULL; + if (PyModule_AddObject(m, "StreamingBase64Decoder", (PyObject *) &StreamingBase64Decoder_Type) < 0) return NULL; + return m; } diff --git a/kitty/data-types.h b/kitty/data-types.h index 2eac7cea2..f4e856082 100644 --- a/kitty/data-types.h +++ b/kitty/data-types.h @@ -52,6 +52,10 @@ static inline void cleanup_free(void *p) { free(*(void**)p); } static inline void cleanup_decref(PyObject **p) { Py_CLEAR(*p); } #define RAII_PyObject(name, initializer) __attribute__((cleanup(cleanup_decref))) PyObject *name = initializer #define RAII_PY_BUFFER(name) __attribute__((cleanup(PyBuffer_Release))) Py_buffer name = {0} +#if PY_VERSION_HEX < 0x030a0000 +static inline PyObject* Py_NewRef(PyObject *o) { Py_INCREF(o); return o; } +static inline PyObject* Py_XNewRef(PyObject *o) { Py_XINCREF(o); return o; } +#endif typedef unsigned long long id_type; typedef uint32_t char_type; diff --git a/kitty/fast_data_types.pyi b/kitty/fast_data_types.pyi index 1e3e87519..4b748ec8c 100644 --- a/kitty/fast_data_types.pyi +++ b/kitty/fast_data_types.pyi @@ -305,6 +305,9 @@ WINDOW_MINIMIZED: int # }}} +ReadOnlyBuffer = Union[bytes, bytearray, memoryview] + + def encode_key_for_tty( key: int = 0, shifted_key: int = 0, @@ -1700,6 +1703,16 @@ class MousePosition(TypedDict): def get_mouse_data_for_window(os_window_id: int, tab_id: int, window_id: int) -> Optional[MousePosition]: ... +class StreamingBase64Decoder: + def __init__(self, initial_capacity: int = 8 *1024) -> None: ... # set the initial output buffer capacity + def add(self, data: ReadOnlyBuffer) -> int: ... # add the base64 data + def flush(self) -> None: ... # indicate end of base64 data, left over bytes are processed as if they were followed by padding + def take_output(self) -> bytes: ... # take the output so far. The decoder no longer references this output + def copy_output(self) -> bytes: ... # copy the output so far + def __len__(self) -> int: ... # return the length of the current output + def leftover_bytes(self) -> memoryview: ... # return the currently leftover bytes that will be consumed by flush() + + class DiskCache: small_hole_threshold: int defrag_factor: int diff --git a/kitty/notifications.py b/kitty/notifications.py index 5f42c29d5..c64a5fce4 100644 --- a/kitty/notifications.py +++ b/kitty/notifications.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, FrozenSet, Iterator, List, NamedTuple, O from weakref import ReferenceType, ref from .constants import cache_dir, is_macos, logo_png_file -from .fast_data_types import ESC_OSC, base64_decode, current_focused_os_window_id, get_boss +from .fast_data_types import ESC_OSC, StreamingBase64Decoder, current_focused_os_window_id, get_boss from .types import run_once from .typing import WindowType from .utils import get_custom_window_icon, log_error, sanitize_control_codes @@ -146,7 +146,7 @@ class DataStore: class EncodedDataStore: def __init__(self, data_store: DataStore) -> None: - self.current_leftover_bytes = memoryview(b'') + self.decoder = StreamingBase64Decoder(initial_capacity=4096) self.data_store = data_store @property @@ -162,41 +162,14 @@ class EncodedDataStore: def add_base64_data(self, data: Union[str, bytes]) -> None: if isinstance(data, str): data = data.encode('ascii') - - def write_saving_leftover_bytes(data: bytes) -> None: - if len(data) == 0: - return - extra = len(data) % 4 - if extra > 0: - mv = memoryview(data) - self.current_leftover_bytes = memoryview(bytes(mv[-extra:])) - mv = mv[:-extra] - if len(mv) > 0: - self._write_base64_data(mv) - else: - self._write_base64_data(data) - - if len(self.current_leftover_bytes) > 0: - extra = 4 - len(self.current_leftover_bytes) - if len(data) >= extra: - self._write_base64_data(memoryview(bytes(self.current_leftover_bytes) + data[:extra])) - self.current_leftover_bytes = memoryview(b'') - data = memoryview(data)[extra:] - write_saving_leftover_bytes(data) - else: - self.current_leftover_bytes = memoryview(bytes(self.current_leftover_bytes) + data) - else: - write_saving_leftover_bytes(data) - - def _write_base64_data(self, b: bytes) -> None: - self.data_store(base64_decode(b)) + self.decoder.add(data) + if len(self.decoder) >= self.data_store.max_size: + self.data_store(self.decoder.take_output()) def flush_encoded_data(self) -> None: - b = self.current_leftover_bytes - self.current_leftover_bytes = memoryview(b'') - padding = 4 - len(b) - if padding in (1, 2): - self._write_base64_data(memoryview(bytes(b) + b'=' * padding)) + self.decoder.flush() + if len(self.decoder): + self.data_store(self.decoder.take_output()) def finalise(self) -> bytes: self.flush_encoded_data() diff --git a/kitty_tests/notifications.py b/kitty_tests/notifications.py index 6ed9c9b9e..e039c31b4 100644 --- a/kitty_tests/notifications.py +++ b/kitty_tests/notifications.py @@ -214,14 +214,19 @@ def do_test(self: 'TestNotifications', tdir: str) -> None: self.ae(ch.responses, [f'99;i=0:p=?;{qr}']) # Test MIME streaming - text = 'some reasonably long text to test MIME streaming with' - encoded = standard_b64encode(text.encode()).decode() - for ch in encoded: - h(f'i=s:e=1:d=0;{ch}') - h(f'i=s:e=1:d=0:p=body;{encoded[:13]}') - h(f'i=s:e=1:d=0:p=body;{encoded[13:]}') - h('i=s') - self.ae(di.notifications, [n(text, text)]) + for padding in (True, False): + for extra in ('a', 'ab', 'abc', 'abcd'): + text = 'some reasonably long text to test MIME streaming with: ' + encoded = standard_b64encode(text.encode()).decode() + if not padding: + encoded = encoded.rstrip('=') + for t in encoded: + h(f'i=s:e=1:d=0;{t}') + h(f'i=s:e=1:d=0:p=body;{encoded[:13]}') + h(f'i=s:e=1:d=0:p=body;{encoded[13:]}') + h('i=s') + self.ae(di.notifications, [n(text, text)]) + reset() # Test Disk Cache dc = IconDataCache(base_cache_dir=tdir, max_cache_size=4)