mirror of
https://github.com/kovidgoyal/kitty
synced 2026-07-02 12:44:01 +02:00
Switch to SIMD accelerated base64 decoding for clipboard and notification requests
This commit is contained in:
@@ -14,6 +14,7 @@ from .fast_data_types import (
|
||||
ESC_OSC,
|
||||
GLFW_CLIPBOARD,
|
||||
GLFW_PRIMARY_SELECTION,
|
||||
StreamingBase64Decoder,
|
||||
find_in_memoryview,
|
||||
get_boss,
|
||||
get_clipboard_mime,
|
||||
@@ -236,6 +237,7 @@ class WriteRequest:
|
||||
self, is_primary_selection: bool = False, protocol_type: ProtocolType = ProtocolType.osc_52, id: str = '',
|
||||
rollover_size: int = 16 * 1024 * 1024, max_size: int = -1,
|
||||
) -> None:
|
||||
self.decoder = StreamingBase64Decoder(8 * 1024)
|
||||
self.id = id
|
||||
self.is_primary_selection = is_primary_selection
|
||||
self.protocol_type = protocol_type
|
||||
@@ -243,7 +245,6 @@ class WriteRequest:
|
||||
self.tempfile = Tempfile(max_size=rollover_size)
|
||||
self.mime_map: Dict[str, MimePos] = {}
|
||||
self.currently_writing_mime = ''
|
||||
self.current_leftover_bytes = memoryview(b'')
|
||||
self.max_size = (get_options().clipboard_max_size * 1024 * 1024) if max_size < 0 else max_size
|
||||
self.aliases: Dict[str, str] = {}
|
||||
self.committed = False
|
||||
@@ -276,51 +277,29 @@ class WriteRequest:
|
||||
if not self.currently_writing_mime:
|
||||
self.mime_map[mime] = MimePos(self.tempfile.tell(), -1)
|
||||
self.currently_writing_mime = mime
|
||||
self.write_base64_data(data)
|
||||
|
||||
def write_saving_leftover_bytes(data: bytes) -> None:
|
||||
if len(data) == 0:
|
||||
return
|
||||
extra = len(data) % 4
|
||||
if extra > 0:
|
||||
mv = memoryview(data)
|
||||
self.current_leftover_bytes = memoryview(bytes(mv[-extra:]))
|
||||
mv = mv[:-extra]
|
||||
if len(mv) > 0:
|
||||
self.write_base64_data(mv)
|
||||
else:
|
||||
self.write_base64_data(data)
|
||||
|
||||
if len(self.current_leftover_bytes) > 0:
|
||||
extra = 4 - len(self.current_leftover_bytes)
|
||||
if len(data) >= extra:
|
||||
self.write_base64_data(memoryview(bytes(self.current_leftover_bytes) + data[:extra]))
|
||||
self.current_leftover_bytes = memoryview(b'')
|
||||
data = memoryview(data)[extra:]
|
||||
write_saving_leftover_bytes(data)
|
||||
else:
|
||||
self.current_leftover_bytes = memoryview(bytes(self.current_leftover_bytes) + data)
|
||||
else:
|
||||
write_saving_leftover_bytes(data)
|
||||
@property
|
||||
def current_leftover_bytes(self) -> memoryview:
|
||||
return self.decoder.leftover_bytes()
|
||||
|
||||
def flush_base64_data(self) -> None:
|
||||
if self.currently_writing_mime:
|
||||
b = self.current_leftover_bytes
|
||||
padding = 4 - len(b)
|
||||
if padding in (1, 2):
|
||||
self.write_base64_data(memoryview(bytes(b) + b'=' * padding))
|
||||
self.decoder.flush()
|
||||
if len(self.decoder):
|
||||
self.write_base64_data(b'')
|
||||
start = self.mime_map[self.currently_writing_mime][0]
|
||||
self.mime_map[self.currently_writing_mime] = MimePos(start, self.tempfile.tell() - start)
|
||||
self.currently_writing_mime = ''
|
||||
self.current_leftover_bytes = memoryview(b'')
|
||||
|
||||
def write_base64_data(self, b: bytes) -> None:
|
||||
from base64 import standard_b64decode
|
||||
if not self.max_size_exceeded:
|
||||
d = standard_b64decode(b)
|
||||
self.tempfile.write(d)
|
||||
if self.max_size > 0 and self.tempfile.tell() > (self.max_size * 1024 * 1024):
|
||||
log_error(f'Clipboard write request has more data than allowed by clipboard_max_size ({self.max_size}), truncating')
|
||||
self.max_size_exceeded = True
|
||||
self.decoder.add(b)
|
||||
if len(self.decoder):
|
||||
self.tempfile.write(self.decoder.take_output())
|
||||
if self.max_size > 0 and self.tempfile.tell() > (self.max_size * 1024 * 1024):
|
||||
log_error(f'Clipboard write request has more data than allowed by clipboard_max_size ({self.max_size}), truncating')
|
||||
self.max_size_exceeded = True
|
||||
|
||||
def data_for(self, mime: str = 'text/plain', offset: int = 0, size: int = -1) -> bytes:
|
||||
start, full_size = self.mime_map[mime]
|
||||
|
||||
@@ -105,6 +105,138 @@ pybase64_decode(PyObject UNUSED *self, PyObject *args) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
typedef struct StreamingBase64Decoder {
|
||||
PyObject_HEAD
|
||||
PyObject *output;
|
||||
size_t output_sz, output_capacity, num_leftover_bytes, initial_capacity;
|
||||
unsigned char leftover_bytes[8];
|
||||
} StreamingBase64Decoder;
|
||||
|
||||
static int
|
||||
StreamingBase64Decoder_init(PyObject *s, PyObject *args, PyObject *kwds) {
|
||||
static char *kwlist[] = {"initial_capacity", NULL};
|
||||
unsigned long initial_capacity = 8 * 1024;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|k", kwlist, &initial_capacity)) return -1;
|
||||
StreamingBase64Decoder *self = (StreamingBase64Decoder*)s;
|
||||
self->output = PyBytes_FromStringAndSize(NULL, initial_capacity);
|
||||
if (!self->output) return -1;
|
||||
self->output_capacity = initial_capacity;
|
||||
self->initial_capacity = initial_capacity;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void
|
||||
StreamingBase64Decoder_dealloc(PyObject *self) {
|
||||
StreamingBase64Decoder *h = (StreamingBase64Decoder*)self;
|
||||
Py_CLEAR(h->output);
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
|
||||
static bool
|
||||
write_base64_data(StreamingBase64Decoder *self, const void *data, size_t len) {
|
||||
if (!len) return true;
|
||||
size_t sz = required_buffer_size_for_base64_decode(len);
|
||||
if ((self->output_sz + sz) > self->output_capacity) {
|
||||
size_t cap = MAX(self->output_capacity * 2, self->output_sz + sz + self->initial_capacity);
|
||||
if (_PyBytes_Resize(&self->output, cap) != 0) return false;
|
||||
self->output_capacity = cap;
|
||||
}
|
||||
if (!base64_decode8(data, len, (unsigned char*)(PyBytes_AS_STRING(self->output) + self->output_sz), &sz)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Invalid base64 input data");
|
||||
return false;
|
||||
}
|
||||
self->output_sz += sz;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
write_saving_leftover_bytes(StreamingBase64Decoder *self, const unsigned char *data, size_t len) {
|
||||
size_t extra = len % 4;
|
||||
if (!write_base64_data(self, data, len - extra)) return false;
|
||||
self->num_leftover_bytes = extra;
|
||||
if (extra) memcpy(self->leftover_bytes, data + len - extra, extra);
|
||||
return true;
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
StreamingBase64Decoder_add(StreamingBase64Decoder *self, PyObject *a) {
|
||||
RAII_PY_BUFFER(data);
|
||||
if (PyObject_GetBuffer(a, &data, PyBUF_SIMPLE) != 0) return NULL;
|
||||
if (!data.buf || !data.len) return PyLong_FromLong(0);
|
||||
unsigned char *d = data.buf; size_t dlen = data.len;
|
||||
size_t before = self->output_sz;
|
||||
if (self->num_leftover_bytes) {
|
||||
size_t extra = 4 - self->num_leftover_bytes;
|
||||
if (dlen >= extra) {
|
||||
memcpy(self->leftover_bytes + self->num_leftover_bytes, d, extra);
|
||||
if (!write_base64_data(self, self->leftover_bytes, self->num_leftover_bytes + extra)) return NULL;
|
||||
self->num_leftover_bytes = 0;
|
||||
d += extra; dlen -= extra;
|
||||
if (!write_saving_leftover_bytes(self, d, dlen)) return NULL;
|
||||
} else {
|
||||
memcpy(self->leftover_bytes + self->num_leftover_bytes, d, dlen);
|
||||
self->num_leftover_bytes += dlen;
|
||||
}
|
||||
} else if (!write_saving_leftover_bytes(self, d, dlen)) return NULL;
|
||||
return PyLong_FromSize_t(self->output_sz - before);
|
||||
}
|
||||
|
||||
static Py_ssize_t
|
||||
StreamingBase64Decoder_len(PyObject *s) { return ((StreamingBase64Decoder*)s)->output_sz; }
|
||||
|
||||
static PyObject*
|
||||
StreamingBase64Decoder_leftover_bytes(StreamingBase64Decoder *self, PyObject *a UNUSED) {
|
||||
return PyMemoryView_FromMemory((char*)self->leftover_bytes, self->num_leftover_bytes, PyBUF_READ);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
StreamingBase64Decoder_flush(StreamingBase64Decoder *self, PyObject *args UNUSED) {
|
||||
size_t padding = 4 - self->num_leftover_bytes;
|
||||
switch(padding) {
|
||||
case 1: self->leftover_bytes[self->num_leftover_bytes++] = '='; break;
|
||||
case 2: self->leftover_bytes[self->num_leftover_bytes++] = '='; self->leftover_bytes[self->num_leftover_bytes++] = '='; break;
|
||||
}
|
||||
write_base64_data(self, self->leftover_bytes, self->num_leftover_bytes);
|
||||
self->num_leftover_bytes = 0;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
StreamingBase64Decoder_copy_output(StreamingBase64Decoder *self, PyObject *args UNUSED) {
|
||||
return PyBytes_FromStringAndSize(PyBytes_AS_STRING(self->output), self->output_sz);
|
||||
}
|
||||
|
||||
static PyObject*
|
||||
StreamingBase64Decoder_take_output(StreamingBase64Decoder *self, PyObject *args UNUSED) {
|
||||
RAII_PyObject(newbuf, PyBytes_FromStringAndSize(NULL, self->initial_capacity));
|
||||
if (!newbuf) return NULL;
|
||||
if (_PyBytes_Resize(&self->output, self->output_sz) != 0) return NULL;
|
||||
PyObject *ans = self->output;
|
||||
self->output = Py_NewRef(newbuf); self->output_sz = 0; self->output_capacity = self->initial_capacity;
|
||||
return ans;
|
||||
}
|
||||
|
||||
static PyTypeObject StreamingBase64Decoder_Type = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0)
|
||||
.tp_name = "kitty.fast_data_types.StreamingBase64Decoder",
|
||||
.tp_basicsize = sizeof(StreamingBase64Decoder),
|
||||
.tp_dealloc = StreamingBase64Decoder_dealloc,
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
.tp_doc = "StreamingBase64Decoder",
|
||||
.tp_methods = (PyMethodDef[]){
|
||||
{"add", (PyCFunction)StreamingBase64Decoder_add, METH_O, ""},
|
||||
{"flush", (PyCFunction)StreamingBase64Decoder_flush, METH_NOARGS, ""},
|
||||
{"take_output", (PyCFunction)StreamingBase64Decoder_take_output, METH_NOARGS, ""},
|
||||
{"copy_output", (PyCFunction)StreamingBase64Decoder_copy_output, METH_NOARGS, ""},
|
||||
{"leftover_bytes", (PyCFunction)StreamingBase64Decoder_leftover_bytes, METH_NOARGS, ""},
|
||||
{NULL, NULL, 0, NULL},
|
||||
},
|
||||
.tp_new = PyType_GenericNew,
|
||||
.tp_init = StreamingBase64Decoder_init,
|
||||
.tp_as_sequence = &(PySequenceMethods){
|
||||
.sq_length = StreamingBase64Decoder_len,
|
||||
},
|
||||
};
|
||||
|
||||
static PyObject*
|
||||
pyset_iutf8(PyObject UNUSED *self, PyObject *args) {
|
||||
@@ -611,5 +743,8 @@ PyInit_fast_data_types(void) {
|
||||
PyModule_AddIntConstant(m, "SHM_NAME_MAX", MIN(1023, PATH_MAX));
|
||||
#endif
|
||||
|
||||
if (PyType_Ready(&StreamingBase64Decoder_Type) < 0) return NULL;
|
||||
if (PyModule_AddObject(m, "StreamingBase64Decoder", (PyObject *) &StreamingBase64Decoder_Type) < 0) return NULL;
|
||||
|
||||
return m;
|
||||
}
|
||||
|
||||
@@ -52,6 +52,10 @@ static inline void cleanup_free(void *p) { free(*(void**)p); }
|
||||
// Cleanup handler used by RAII_PyObject: clears the reference held in *p
static inline void
cleanup_decref(PyObject **p) {
    Py_CLEAR(*p);
}
// Declare a PyObject* local that is passed to cleanup_decref when it goes out of scope
#define RAII_PyObject(name, initializer) __attribute__((cleanup(cleanup_decref))) PyObject *name = initializer
// Declare a zero initialized Py_buffer local that is passed to PyBuffer_Release when it goes out of scope
#define RAII_PY_BUFFER(name) __attribute__((cleanup(PyBuffer_Release))) Py_buffer name = {0}
|
||||
#if PY_VERSION_HEX < 0x030a0000
|
||||
static inline PyObject* Py_NewRef(PyObject *o) { Py_INCREF(o); return o; }
|
||||
static inline PyObject* Py_XNewRef(PyObject *o) { Py_XINCREF(o); return o; }
|
||||
#endif
|
||||
|
||||
typedef unsigned long long id_type;
|
||||
typedef uint32_t char_type;
|
||||
|
||||
@@ -305,6 +305,9 @@ WINDOW_MINIMIZED: int
|
||||
# }}}
|
||||
|
||||
|
||||
ReadOnlyBuffer = Union[bytes, bytearray, memoryview]
|
||||
|
||||
|
||||
def encode_key_for_tty(
|
||||
key: int = 0,
|
||||
shifted_key: int = 0,
|
||||
@@ -1700,6 +1703,16 @@ class MousePosition(TypedDict):
|
||||
def get_mouse_data_for_window(os_window_id: int, tab_id: int, window_id: int) -> Optional[MousePosition]: ...
|
||||
|
||||
|
||||
class StreamingBase64Decoder:
    """Type stub for the streaming base64 decoder implemented in C."""

    def __init__(self, initial_capacity: int = 8 * 1024) -> None:
        """Set the initial output buffer capacity."""

    def add(self, data: ReadOnlyBuffer) -> int:
        """Add the base64 data."""

    def flush(self) -> None:
        """Indicate end of base64 data.

        Left over bytes are processed as if they were followed by padding.
        """

    def take_output(self) -> bytes:
        """Take the output so far. The decoder no longer references this output."""

    def copy_output(self) -> bytes:
        """Copy the output so far."""

    def __len__(self) -> int:
        """Return the length of the current output."""

    def leftover_bytes(self) -> memoryview:
        """Return the currently leftover bytes that will be consumed by flush()."""
|
||||
|
||||
|
||||
class DiskCache:
|
||||
small_hole_threshold: int
|
||||
defrag_factor: int
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, FrozenSet, Iterator, List, NamedTuple, O
|
||||
from weakref import ReferenceType, ref
|
||||
|
||||
from .constants import cache_dir, is_macos, logo_png_file
|
||||
from .fast_data_types import ESC_OSC, base64_decode, current_focused_os_window_id, get_boss
|
||||
from .fast_data_types import ESC_OSC, StreamingBase64Decoder, current_focused_os_window_id, get_boss
|
||||
from .types import run_once
|
||||
from .typing import WindowType
|
||||
from .utils import get_custom_window_icon, log_error, sanitize_control_codes
|
||||
@@ -146,7 +146,7 @@ class DataStore:
|
||||
class EncodedDataStore:
|
||||
|
||||
def __init__(self, data_store: DataStore) -> None:
|
||||
self.current_leftover_bytes = memoryview(b'')
|
||||
self.decoder = StreamingBase64Decoder(initial_capacity=4096)
|
||||
self.data_store = data_store
|
||||
|
||||
@property
|
||||
@@ -162,41 +162,14 @@ class EncodedDataStore:
|
||||
def add_base64_data(self, data: Union[str, bytes]) -> None:
|
||||
if isinstance(data, str):
|
||||
data = data.encode('ascii')
|
||||
|
||||
def write_saving_leftover_bytes(data: bytes) -> None:
|
||||
if len(data) == 0:
|
||||
return
|
||||
extra = len(data) % 4
|
||||
if extra > 0:
|
||||
mv = memoryview(data)
|
||||
self.current_leftover_bytes = memoryview(bytes(mv[-extra:]))
|
||||
mv = mv[:-extra]
|
||||
if len(mv) > 0:
|
||||
self._write_base64_data(mv)
|
||||
else:
|
||||
self._write_base64_data(data)
|
||||
|
||||
if len(self.current_leftover_bytes) > 0:
|
||||
extra = 4 - len(self.current_leftover_bytes)
|
||||
if len(data) >= extra:
|
||||
self._write_base64_data(memoryview(bytes(self.current_leftover_bytes) + data[:extra]))
|
||||
self.current_leftover_bytes = memoryview(b'')
|
||||
data = memoryview(data)[extra:]
|
||||
write_saving_leftover_bytes(data)
|
||||
else:
|
||||
self.current_leftover_bytes = memoryview(bytes(self.current_leftover_bytes) + data)
|
||||
else:
|
||||
write_saving_leftover_bytes(data)
|
||||
|
||||
def _write_base64_data(self, b: bytes) -> None:
|
||||
self.data_store(base64_decode(b))
|
||||
self.decoder.add(data)
|
||||
if len(self.decoder) >= self.data_store.max_size:
|
||||
self.data_store(self.decoder.take_output())
|
||||
|
||||
def flush_encoded_data(self) -> None:
|
||||
b = self.current_leftover_bytes
|
||||
self.current_leftover_bytes = memoryview(b'')
|
||||
padding = 4 - len(b)
|
||||
if padding in (1, 2):
|
||||
self._write_base64_data(memoryview(bytes(b) + b'=' * padding))
|
||||
self.decoder.flush()
|
||||
if len(self.decoder):
|
||||
self.data_store(self.decoder.take_output())
|
||||
|
||||
def finalise(self) -> bytes:
|
||||
self.flush_encoded_data()
|
||||
|
||||
@@ -214,14 +214,19 @@ def do_test(self: 'TestNotifications', tdir: str) -> None:
|
||||
self.ae(ch.responses, [f'99;i=0:p=?;{qr}'])
|
||||
|
||||
# Test MIME streaming
|
||||
text = 'some reasonably long text to test MIME streaming with'
|
||||
encoded = standard_b64encode(text.encode()).decode()
|
||||
for ch in encoded:
|
||||
h(f'i=s:e=1:d=0;{ch}')
|
||||
h(f'i=s:e=1:d=0:p=body;{encoded[:13]}')
|
||||
h(f'i=s:e=1:d=0:p=body;{encoded[13:]}')
|
||||
h('i=s')
|
||||
self.ae(di.notifications, [n(text, text)])
|
||||
for padding in (True, False):
|
||||
for extra in ('a', 'ab', 'abc', 'abcd'):
|
||||
text = 'some reasonably long text to test MIME streaming with: '
|
||||
encoded = standard_b64encode(text.encode()).decode()
|
||||
if not padding:
|
||||
encoded = encoded.rstrip('=')
|
||||
for t in encoded:
|
||||
h(f'i=s:e=1:d=0;{t}')
|
||||
h(f'i=s:e=1:d=0:p=body;{encoded[:13]}')
|
||||
h(f'i=s:e=1:d=0:p=body;{encoded[13:]}')
|
||||
h('i=s')
|
||||
self.ae(di.notifications, [n(text, text)])
|
||||
reset()
|
||||
|
||||
# Test Disk Cache
|
||||
dc = IconDataCache(base_cache_dir=tdir, max_cache_size=4)
|
||||
|
||||
Reference in New Issue
Block a user