diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py index f51a0e87b..4167c6b03 100644 --- a/kitty/file_transmission.py +++ b/kitty/file_transmission.py @@ -2,18 +2,19 @@ # vim:fileencoding=utf-8 # License: GPLv3 Copyright: 2021, Kovid Goyal -import copy import errno import os import tempfile from base64 import standard_b64decode, standard_b64encode -from time import monotonic +from collections import deque +from dataclasses import Field, dataclass, field, fields from enum import Enum, auto from functools import partial -from typing import IO, Any, Dict, List, Optional, Union from gettext import gettext as _ +from time import monotonic +from typing import IO, Any, Deque, Dict, List, Optional, Tuple -from kitty.fast_data_types import OSC, get_boss +from kitty.fast_data_types import OSC, add_timer, get_boss from .utils import log_error, sanitize_control_codes @@ -28,6 +29,7 @@ class Action(Enum): receive = auto() invalid = auto() cancel = auto() + status = auto() class Compression(Enum): @@ -48,84 +50,105 @@ class TransmisstionType(Enum): rsync = auto() +class ErrorCode(Enum): + EINVAL = auto() + OK = auto() + + +class TransmissionError(Exception): + + def __init__( + self, code: ErrorCode = ErrorCode.EINVAL, + msg: str = 'Generic error', + transmit: bool = True, + file_id: str = '' + ) -> None: + Exception.__init__(self, msg) + self.transmit = transmit + self.file_id = file_id + self.human_msg = msg + self.code = code + + def as_escape_code(self, request_id: str = '') -> str: + return FileTransmissionCommand( + action=Action.status, id=request_id, file_id=self.file_id, + name=f'{self.code.name}:{self.human_msg}' + ).serialize() + + +@dataclass class FileTransmissionCommand: - action = Action.invalid - compression = Compression.none - ftype = FileType.regular - ttype = TransmisstionType.simple + action: Action = Action.invalid + compression: Compression = Compression.none + ftype: FileType = FileType.regular + ttype: TransmisstionType = TransmisstionType.simple id: str = '' file_id: str = '' secret: str = '' mime: str = '' quiet: int = 0 - name: str = '' mtime: int = -1 permissions: int = -1 data: bytes = b'' + name: str = field(default='', metadata={'base64': True}) def serialize(self) -> str: - ans = [f'action={self.action.name}'] - if self.compression is not Compression.none: - ans.append(f'compression={self.compression.name}') - if self.ftype is not FileType.regular: - ans.append(f'ftype={self.ftype.name}') - if self.ttype is not TransmisstionType.simple: - ans.append(f'ttype={self.ttype.name}') - for x in ('id', 'file_id', 'secret', 'mime', 'quiet'): - val = getattr(self, x) - if val: - ans.append(f'{x}={val}') - for k in ('mtime', 'permissions'): - val = getattr(self, k) - if val >= 0: - ans.append(f'{k}={val}') - if self.name: - val = standard_b64encode(self.name.encode('utf-8')).decode('ascii') - ans.append(f'name={val}') - if self.data: - val = standard_b64encode(self.data).decode('ascii') - ans.append(f'data={val}') + ans = [] + for k in fields(self): + val = getattr(self, k.name) + if val == k.default: + continue + if issubclass(k.type, Enum): + ans.append(f'{k.name}={val.name}') + elif k.type is bytes: + ev = standard_b64encode(val).decode('ascii') + ans.append(f'{k.name}={ev}') + elif k.type is str: + if k.metadata.get('base64'): + sval = standard_b64encode(self.name.encode('utf-8')).decode('ascii') + else: + sval = val + ans.append(f'{k.name}={sanitize_control_codes(sval)}') + elif k.type is int: + ans.append(f'{k.name}={val}') + else: + raise KeyError(f'Field of unknown type: {k.name}') def escape_semicolons(x: str) -> str: return x.replace(';', ';;') return ';'.join(map(escape_semicolons, ans)) + @classmethod + def deserialize(cls, data: str) -> 'FileTransmissionCommand': + ans = FileTransmissionCommand() + parts = (x.replace('\0', ';').partition('=')[::2] for x in data.replace(';;', '\0').split(';')) + if not hasattr(cls, 'fmap'): + setattr(cls, 'fmap', {k.name: k for k in fields(cls)}) + fmap: Dict[str, Field] = getattr(cls, 'fmap') -def parse_command(data: str) -> FileTransmissionCommand: - ans = FileTransmissionCommand() - parts = data.replace(';;', '\0').split(';') + for k, v in parts: + field = fmap.get(k) + if field is None: + continue + if issubclass(field.type, Enum): + setattr(ans, field.name, field.type[v]) + elif field.type is bytes: + setattr(ans, field.name, standard_b64decode(v)) + elif field.type is int: + setattr(ans, field.name, int(v)) + elif field.type is str: + if field.metadata.get('base64'): + sval = standard_b64decode(v).decode('utf-8') + else: + sval = v + setattr(ans, field.name, sanitize_control_codes(sval)) - for i, x in enumerate(parts): - k, v = x.replace('\0', ';').partition('=')[::2] - if k == 'action': - ans.action = Action[v] - elif k == 'compression': - ans.compression = Compression[v] - elif k == 'ftype': - ans.ftype = FileType[v] - elif k == 'ttype': - ans.ttype = TransmisstionType[v] - elif k in ('secret', 'mime', 'id', 'file_id'): - setattr(ans, k, sanitize_control_codes(v)) - elif k in ('quiet',): - setattr(ans, k, int(v)) - elif k in ('mtime', 'permissions'): - mt = int(v) - if mt >= 0: - setattr(ans, k, mt) - elif k in ('name', 'data'): - val = standard_b64decode(v) - if k == 'name': - ans.name = sanitize_control_codes(val.decode('utf-8')) - else: - ans.data = val + if ans.action is Action.invalid: + raise ValueError('No valid action specified in file transmission command') - if ans.action is Action.invalid: - raise ValueError('No valid action specified in file transmission command') - - return ans + return ans class IdentityDecompressor: @@ -157,9 +180,11 @@ class DestFile: self.ttype = ftc.ttype self.needs_data_sent = self.ttype is not TransmisstionType.simple self.decompressor = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor() + self.closed = self.ftype is FileType.directory def close(self) -> None: - pass + if not self.closed: + self.closed = True class ActiveReceive: @@ -167,10 +192,12 @@ class ActiveReceive: files: Dict[str, DestFile] accepted: bool = False - def __init__(self, id: str) -> None: + def __init__(self, id: str, quiet: int) -> None: self.id = id self.files = {} self.last_activity_at = monotonic() + self.send_acknowledgements = quiet < 1 + self.send_errors = quiet < 2 @property def is_expired(self) -> bool: @@ -186,7 +213,10 @@ class ActiveReceive: def start_file(self, ftc: FileTransmissionCommand) -> DestFile: if ftc.file_id in self.files: - raise KeyError(f'The file_id {ftc.file_id} already exists') + raise TransmissionError( + msg=f'The file_id {ftc.file_id} already exists', + file_id=ftc.file_id, + ) self.files[ftc.file_id] = result = DestFile(ftc) return result @@ -198,6 +228,24 @@ class FileTransmission: def __init__(self, window_id: int): self.window_id = window_id self.active_receives = {} + self.pending_receive_responses: Deque[Tuple[str, str]] = deque() + self.pending_timer: Optional[int] = None + + def start_pending_timer(self) -> None: + if self.pending_timer is None: + self.pending_timer = add_timer(self.try_pending, 0.2, False) + + def try_pending(self, timer_id: Optional[int]) -> None: + self.pending_timer = None + while self.pending_receive_responses: + request_id, payload = self.pending_receive_responses.popleft() + ar = self.active_receives.get(request_id) + if ar is None: + continue + if not self.write_osc_to_child(request_id, payload, appendleft=True): + break + ar.last_activity_at = monotonic() + self.prune_expired() def __del__(self) -> None: for ar in self.active_receives.values(): @@ -217,7 +265,7 @@ class FileTransmission: def handle_serialized_command(self, data: str) -> None: self.prune_expired() try: - cmd = parse_command(data) + cmd = FileTransmissionCommand.deserialize(data) except Exception as e: log_error(f'Failed to parse file transmission command with error: {e}') return @@ -226,11 +274,11 @@ class FileTransmission: def handle_receive_cmd(self, cmd: FileTransmissionCommand) -> None: if cmd.id in self.active_receives: + ar = self.active_receives[cmd.id] if cmd.action is Action.send: log_error('File transmission send received for already active id, aborting') self.drop_receive(cmd.id) return - ar = self.active_receives[cmd.id] if not ar.accepted: log_error(f'File transmission command received for rejected id: {cmd.id}, aborting') self.drop_receive(cmd.id) @@ -240,14 +288,23 @@ class FileTransmission: if cmd.action is not Action.send: log_error(f'File transmission command received for unknown or rejected id: {cmd.id}, ignoring') return - ar = ActiveReceive(cmd.id) + ar = ActiveReceive(cmd.id, cmd.quiet) self.start_receive(ar.id) return if cmd.action is Action.cancel: self.drop_receive(ar.id) elif cmd.action is Action.file: - ar.start_file(cmd) + try: + ar.start_file(cmd) + except TransmissionError as err: + if ar.send_errors: + self.send_transmission_error(ar.id, err) + except Exception as err: + log_error(f'Transmission protocol failed to start file with error: {err}') + if ar.send_errors: + te = TransmissionError(file_id=cmd.file_id, msg=str(err)) + self.send_transmission_error(ar.id, te) elif cmd.action in (Action.data, Action.end_data): try: self.add_data(ar, cmd) @@ -260,16 +317,26 @@ class FileTransmission: except Exception: self.drop_receive(cmd.id) - def send_response(self, id: str = '', **fields: str) -> bool: - if 'id' not in fields and id: - fields['id'] = id - return self.write_response_to_child(fields) + def send_status_response(self, code: ErrorCode = ErrorCode.EINVAL, request_id: str = '', file_id: str = '', msg: str = '') -> bool: + err = TransmissionError(code=code, msg=msg, file_id=file_id) + data = err.as_escape_code(request_id) + return self.write_osc_to_child(request_id, data) - def write_response_to_child(self, fields: Dict[str, str]) -> bool: + def send_transmission_error(self, request_id: str, err: TransmissionError) -> bool: + return self.write_osc_to_child(request_id, err.as_escape_code()) + + def write_osc_to_child(self, request_id: str, payload: str, appendleft: bool = False) -> bool: boss = get_boss() window = boss.window_id_map.get(self.window_id) if window is not None: - return window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items())) + queued = window.screen.send_escape_code_to_child(OSC, payload) + if not queued: + if appendleft: + self.pending_receive_responses.appendleft((request_id, payload)) + else: + self.pending_receive_responses.append((request_id, payload)) + self.start_pending_timer() + return queued return False def start_receive(self, ar_id: str) -> None: @@ -361,11 +428,11 @@ class TestFileTransmission(FileTransmission): def __init__(self, allow: bool = True) -> None: super().__init__(0) - self.test_responses: List[Dict[str, str]] = [] + self.test_responses: List[FileTransmissionCommand] = [] self.allow = allow - def write_response_to_child(self, fields: Dict[str, str]) -> bool: - self.test_responses.append(fields) + def write_osc_to_child(self, data: str) -> bool: + self.test_responses.append(FileTransmissionCommand.deserialize(data)) return True def start_receive(self, aid: str) -> None: