mirror of
https://github.com/kovidgoyal/kitty
synced 2026-06-08 22:28:24 +02:00
More work on transmission
This commit is contained in:
@@ -2,18 +2,19 @@
|
||||
# vim:fileencoding=utf-8
|
||||
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
import copy
|
||||
import errno
|
||||
import os
|
||||
import tempfile
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from time import monotonic
|
||||
from collections import deque
|
||||
from dataclasses import Field, dataclass, field, fields
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from typing import IO, Any, Dict, List, Optional, Union
|
||||
from gettext import gettext as _
|
||||
from time import monotonic
|
||||
from typing import IO, Any, Deque, Dict, List, Optional, Tuple
|
||||
|
||||
from kitty.fast_data_types import OSC, get_boss
|
||||
from kitty.fast_data_types import OSC, add_timer, get_boss
|
||||
|
||||
from .utils import log_error, sanitize_control_codes
|
||||
|
||||
@@ -28,6 +29,7 @@ class Action(Enum):
|
||||
receive = auto()
|
||||
invalid = auto()
|
||||
cancel = auto()
|
||||
status = auto()
|
||||
|
||||
|
||||
class Compression(Enum):
|
||||
@@ -48,84 +50,105 @@ class TransmisstionType(Enum):
|
||||
rsync = auto()
|
||||
|
||||
|
||||
class ErrorCode(Enum):
|
||||
EINVAL = auto()
|
||||
OK = auto()
|
||||
|
||||
|
||||
class TransmissionError(Exception):
|
||||
|
||||
def __init__(
|
||||
self, code: ErrorCode = ErrorCode.EINVAL,
|
||||
msg: str = 'Generic error',
|
||||
transmit: bool = True,
|
||||
file_id: str = ''
|
||||
) -> None:
|
||||
Exception.__init__(self, msg)
|
||||
self.transmit = transmit
|
||||
self.file_id = file_id
|
||||
self.human_msg = msg
|
||||
self.code = code
|
||||
|
||||
def as_escape_code(self, request_id: str = '') -> str:
|
||||
return FileTransmissionCommand(
|
||||
action=Action.status, id=request_id, file_id=self.file_id,
|
||||
name=f'{self.code.name}:{self.human_msg}'
|
||||
).serialize()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileTransmissionCommand:
|
||||
|
||||
action = Action.invalid
|
||||
compression = Compression.none
|
||||
ftype = FileType.regular
|
||||
ttype = TransmisstionType.simple
|
||||
action: Action = Action.invalid
|
||||
compression: Compression = Compression.none
|
||||
ftype: FileType = FileType.regular
|
||||
ttype: TransmisstionType = TransmisstionType.simple
|
||||
id: str = ''
|
||||
file_id: str = ''
|
||||
secret: str = ''
|
||||
mime: str = ''
|
||||
quiet: int = 0
|
||||
name: str = ''
|
||||
mtime: int = -1
|
||||
permissions: int = -1
|
||||
data: bytes = b''
|
||||
name: str = field(default='', metadata={'base64': True})
|
||||
|
||||
def serialize(self) -> str:
|
||||
ans = [f'action={self.action.name}']
|
||||
if self.compression is not Compression.none:
|
||||
ans.append(f'compression={self.compression.name}')
|
||||
if self.ftype is not FileType.regular:
|
||||
ans.append(f'ftype={self.ftype.name}')
|
||||
if self.ttype is not TransmisstionType.simple:
|
||||
ans.append(f'ttype={self.ttype.name}')
|
||||
for x in ('id', 'file_id', 'secret', 'mime', 'quiet'):
|
||||
val = getattr(self, x)
|
||||
if val:
|
||||
ans.append(f'{x}={val}')
|
||||
for k in ('mtime', 'permissions'):
|
||||
val = getattr(self, k)
|
||||
if val >= 0:
|
||||
ans.append(f'{k}={val}')
|
||||
if self.name:
|
||||
val = standard_b64encode(self.name.encode('utf-8')).decode('ascii')
|
||||
ans.append(f'name={val}')
|
||||
if self.data:
|
||||
val = standard_b64encode(self.data).decode('ascii')
|
||||
ans.append(f'data={val}')
|
||||
ans = []
|
||||
for k in fields(self):
|
||||
val = getattr(self, k.name)
|
||||
if val == k.default:
|
||||
continue
|
||||
if issubclass(k.type, Enum):
|
||||
ans.append(f'{k.name}={val.name}')
|
||||
elif k.type is bytes:
|
||||
ev = standard_b64encode(val).decode('ascii')
|
||||
ans.append(f'{k.name}={ev}')
|
||||
elif k.type is str:
|
||||
if k.metadata.get('base64'):
|
||||
sval = standard_b64encode(self.name.encode('utf-8')).decode('ascii')
|
||||
else:
|
||||
sval = val
|
||||
ans.append(f'{k.name}={sanitize_control_codes(sval)}')
|
||||
elif k.type is int:
|
||||
ans.append(f'{k.name}={val}')
|
||||
else:
|
||||
raise KeyError(f'Field of unknown type: {k.name}')
|
||||
|
||||
def escape_semicolons(x: str) -> str:
|
||||
return x.replace(';', ';;')
|
||||
|
||||
return ';'.join(map(escape_semicolons, ans))
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: str) -> 'FileTransmissionCommand':
|
||||
ans = FileTransmissionCommand()
|
||||
parts = (x.replace('\0', ';').partition('=')[::2] for x in data.replace(';;', '\0').split(';'))
|
||||
if not hasattr(cls, 'fmap'):
|
||||
setattr(cls, 'fmap', {k.name: k for k in fields(cls)})
|
||||
fmap: Dict[str, Field] = getattr(cls, 'fmap')
|
||||
|
||||
def parse_command(data: str) -> FileTransmissionCommand:
|
||||
ans = FileTransmissionCommand()
|
||||
parts = data.replace(';;', '\0').split(';')
|
||||
for k, v in parts:
|
||||
field = fmap.get(k)
|
||||
if field is None:
|
||||
continue
|
||||
if issubclass(field.type, Enum):
|
||||
setattr(ans, field.name, field.type[v])
|
||||
elif field.type is bytes:
|
||||
setattr(ans, field.name, standard_b64decode(v))
|
||||
elif field.type is int:
|
||||
setattr(ans, field.name, int(v))
|
||||
elif field.type is str:
|
||||
if field.metadata.get('base64'):
|
||||
sval = standard_b64decode(v).decode('utf-8')
|
||||
else:
|
||||
sval = v
|
||||
setattr(ans, field.name, sanitize_control_codes(sval))
|
||||
|
||||
for i, x in enumerate(parts):
|
||||
k, v = x.replace('\0', ';').partition('=')[::2]
|
||||
if k == 'action':
|
||||
ans.action = Action[v]
|
||||
elif k == 'compression':
|
||||
ans.compression = Compression[v]
|
||||
elif k == 'ftype':
|
||||
ans.ftype = FileType[v]
|
||||
elif k == 'ttype':
|
||||
ans.ttype = TransmisstionType[v]
|
||||
elif k in ('secret', 'mime', 'id', 'file_id'):
|
||||
setattr(ans, k, sanitize_control_codes(v))
|
||||
elif k in ('quiet',):
|
||||
setattr(ans, k, int(v))
|
||||
elif k in ('mtime', 'permissions'):
|
||||
mt = int(v)
|
||||
if mt >= 0:
|
||||
setattr(ans, k, mt)
|
||||
elif k in ('name', 'data'):
|
||||
val = standard_b64decode(v)
|
||||
if k == 'name':
|
||||
ans.name = sanitize_control_codes(val.decode('utf-8'))
|
||||
else:
|
||||
ans.data = val
|
||||
if ans.action is Action.invalid:
|
||||
raise ValueError('No valid action specified in file transmission command')
|
||||
|
||||
if ans.action is Action.invalid:
|
||||
raise ValueError('No valid action specified in file transmission command')
|
||||
|
||||
return ans
|
||||
return ans
|
||||
|
||||
|
||||
class IdentityDecompressor:
|
||||
@@ -157,9 +180,11 @@ class DestFile:
|
||||
self.ttype = ftc.ttype
|
||||
self.needs_data_sent = self.ttype is not TransmisstionType.simple
|
||||
self.decompressor = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor()
|
||||
self.closed = self.ftype is FileType.directory
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
|
||||
|
||||
class ActiveReceive:
|
||||
@@ -167,10 +192,12 @@ class ActiveReceive:
|
||||
files: Dict[str, DestFile]
|
||||
accepted: bool = False
|
||||
|
||||
def __init__(self, id: str) -> None:
|
||||
def __init__(self, id: str, quiet: int) -> None:
|
||||
self.id = id
|
||||
self.files = {}
|
||||
self.last_activity_at = monotonic()
|
||||
self.send_acknowledgements = quiet < 1
|
||||
self.send_errors = quiet < 2
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
@@ -186,7 +213,10 @@ class ActiveReceive:
|
||||
|
||||
def start_file(self, ftc: FileTransmissionCommand) -> DestFile:
|
||||
if ftc.file_id in self.files:
|
||||
raise KeyError(f'The file_id {ftc.file_id} already exists')
|
||||
raise TransmissionError(
|
||||
msg=f'The file_id {ftc.file_id} already exists',
|
||||
file_id=ftc.file_id,
|
||||
)
|
||||
self.files[ftc.file_id] = result = DestFile(ftc)
|
||||
return result
|
||||
|
||||
@@ -198,6 +228,24 @@ class FileTransmission:
|
||||
def __init__(self, window_id: int):
|
||||
self.window_id = window_id
|
||||
self.active_receives = {}
|
||||
self.pending_receive_responses: Deque[Tuple[str, str]] = deque()
|
||||
self.pending_timer: Optional[int] = None
|
||||
|
||||
def start_pending_timer(self) -> None:
|
||||
if self.pending_timer is None:
|
||||
self.pending_timer = add_timer(self.try_pending, 0.2, False)
|
||||
|
||||
def try_pending(self, timer_id: Optional[int]) -> None:
|
||||
self.pending_timer = None
|
||||
while self.pending_receive_responses:
|
||||
request_id, payload = self.pending_receive_responses.popleft()
|
||||
ar = self.active_receives.get(request_id)
|
||||
if ar is None:
|
||||
continue
|
||||
if not self.write_osc_to_child(request_id, payload, appendleft=True):
|
||||
break
|
||||
ar.last_activity_at = monotonic()
|
||||
self.prune_expired()
|
||||
|
||||
def __del__(self) -> None:
|
||||
for ar in self.active_receives.values():
|
||||
@@ -217,7 +265,7 @@ class FileTransmission:
|
||||
def handle_serialized_command(self, data: str) -> None:
|
||||
self.prune_expired()
|
||||
try:
|
||||
cmd = parse_command(data)
|
||||
cmd = FileTransmissionCommand.deserialize(data)
|
||||
except Exception as e:
|
||||
log_error(f'Failed to parse file transmission command with error: {e}')
|
||||
return
|
||||
@@ -226,11 +274,11 @@ class FileTransmission:
|
||||
|
||||
def handle_receive_cmd(self, cmd: FileTransmissionCommand) -> None:
|
||||
if cmd.id in self.active_receives:
|
||||
ar = self.active_receives[cmd.id]
|
||||
if cmd.action is Action.send:
|
||||
log_error('File transmission send received for already active id, aborting')
|
||||
self.drop_receive(cmd.id)
|
||||
return
|
||||
ar = self.active_receives[cmd.id]
|
||||
if not ar.accepted:
|
||||
log_error(f'File transmission command received for rejected id: {cmd.id}, aborting')
|
||||
self.drop_receive(cmd.id)
|
||||
@@ -240,14 +288,23 @@ class FileTransmission:
|
||||
if cmd.action is not Action.send:
|
||||
log_error(f'File transmission command received for unknown or rejected id: {cmd.id}, ignoring')
|
||||
return
|
||||
ar = ActiveReceive(cmd.id)
|
||||
ar = ActiveReceive(cmd.id, cmd.quiet)
|
||||
self.start_receive(ar.id)
|
||||
return
|
||||
|
||||
if cmd.action is Action.cancel:
|
||||
self.drop_receive(ar.id)
|
||||
elif cmd.action is Action.file:
|
||||
ar.start_file(cmd)
|
||||
try:
|
||||
ar.start_file(cmd)
|
||||
except TransmissionError as err:
|
||||
if ar.send_errors:
|
||||
self.send_transmission_error(ar.id, err)
|
||||
except Exception as err:
|
||||
log_error(f'Transmission protocol failed to start file with error: {err}')
|
||||
if ar.send_errors:
|
||||
te = TransmissionError(file_id=cmd.file_id, msg=str(err))
|
||||
self.send_transmission_error(ar.id, te)
|
||||
elif cmd.action in (Action.data, Action.end_data):
|
||||
try:
|
||||
self.add_data(ar, cmd)
|
||||
@@ -260,16 +317,26 @@ class FileTransmission:
|
||||
except Exception:
|
||||
self.drop_receive(cmd.id)
|
||||
|
||||
def send_response(self, id: str = '', **fields: str) -> bool:
|
||||
if 'id' not in fields and id:
|
||||
fields['id'] = id
|
||||
return self.write_response_to_child(fields)
|
||||
def send_status_response(self, code: ErrorCode = ErrorCode.EINVAL, request_id: str = '', file_id: str = '', msg: str = '') -> bool:
|
||||
err = TransmissionError(code=code, msg=msg, file_id=file_id)
|
||||
data = err.as_escape_code(request_id)
|
||||
return self.write_osc_to_child(request_id, data)
|
||||
|
||||
def write_response_to_child(self, fields: Dict[str, str]) -> bool:
|
||||
def send_transmission_error(self, request_id: str, err: TransmissionError) -> bool:
|
||||
return self.write_osc_to_child(request_id, err.as_escape_code())
|
||||
|
||||
def write_osc_to_child(self, request_id: str, payload: str, appendleft: bool = False) -> bool:
|
||||
boss = get_boss()
|
||||
window = boss.window_id_map.get(self.window_id)
|
||||
if window is not None:
|
||||
return window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
|
||||
queued = window.screen.send_escape_code_to_child(OSC, payload)
|
||||
if not queued:
|
||||
if appendleft:
|
||||
self.pending_receive_responses.appendleft((request_id, payload))
|
||||
else:
|
||||
self.pending_receive_responses.append((request_id, payload))
|
||||
self.start_pending_timer()
|
||||
return queued
|
||||
return False
|
||||
|
||||
def start_receive(self, ar_id: str) -> None:
|
||||
@@ -361,11 +428,11 @@ class TestFileTransmission(FileTransmission):
|
||||
|
||||
def __init__(self, allow: bool = True) -> None:
|
||||
super().__init__(0)
|
||||
self.test_responses: List[Dict[str, str]] = []
|
||||
self.test_responses: List[FileTransmissionCommand] = []
|
||||
self.allow = allow
|
||||
|
||||
def write_response_to_child(self, fields: Dict[str, str]) -> bool:
|
||||
self.test_responses.append(fields)
|
||||
def write_osc_to_child(self, data: str) -> bool:
|
||||
self.test_responses.append(FileTransmissionCommand.deserialize(data))
|
||||
return True
|
||||
|
||||
def start_receive(self, aid: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user