More work on transmission

This commit is contained in:
Kovid Goyal
2021-08-29 21:09:20 +05:30
parent 1d9425ecdc
commit 4ddd2bf980

View File

@@ -2,18 +2,19 @@
# vim:fileencoding=utf-8
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
import copy
import errno
import os
import tempfile
from base64 import standard_b64decode, standard_b64encode
from time import monotonic
from collections import deque
from dataclasses import Field, dataclass, field, fields
from enum import Enum, auto
from functools import partial
from typing import IO, Any, Dict, List, Optional, Union
from gettext import gettext as _
from time import monotonic
from typing import IO, Any, Deque, Dict, List, Optional, Tuple
from kitty.fast_data_types import OSC, get_boss
from kitty.fast_data_types import OSC, add_timer, get_boss
from .utils import log_error, sanitize_control_codes
@@ -28,6 +29,7 @@ class Action(Enum):
receive = auto()
invalid = auto()
cancel = auto()
status = auto()
class Compression(Enum):
@@ -48,84 +50,105 @@ class TransmisstionType(Enum):
rsync = auto()
class ErrorCode(Enum):
EINVAL = auto()
OK = auto()
class TransmissionError(Exception):
def __init__(
self, code: ErrorCode = ErrorCode.EINVAL,
msg: str = 'Generic error',
transmit: bool = True,
file_id: str = ''
) -> None:
Exception.__init__(self, msg)
self.transmit = transmit
self.file_id = file_id
self.human_msg = msg
self.code = code
def as_escape_code(self, request_id: str = '') -> str:
return FileTransmissionCommand(
action=Action.status, id=request_id, file_id=self.file_id,
name=f'{self.code.name}:{self.human_msg}'
).serialize()
@dataclass
class FileTransmissionCommand:
action = Action.invalid
compression = Compression.none
ftype = FileType.regular
ttype = TransmisstionType.simple
action: Action = Action.invalid
compression: Compression = Compression.none
ftype: FileType = FileType.regular
ttype: TransmisstionType = TransmisstionType.simple
id: str = ''
file_id: str = ''
secret: str = ''
mime: str = ''
quiet: int = 0
name: str = ''
mtime: int = -1
permissions: int = -1
data: bytes = b''
name: str = field(default='', metadata={'base64': True})
def serialize(self) -> str:
ans = [f'action={self.action.name}']
if self.compression is not Compression.none:
ans.append(f'compression={self.compression.name}')
if self.ftype is not FileType.regular:
ans.append(f'ftype={self.ftype.name}')
if self.ttype is not TransmisstionType.simple:
ans.append(f'ttype={self.ttype.name}')
for x in ('id', 'file_id', 'secret', 'mime', 'quiet'):
val = getattr(self, x)
if val:
ans.append(f'{x}={val}')
for k in ('mtime', 'permissions'):
val = getattr(self, k)
if val >= 0:
ans.append(f'{k}={val}')
if self.name:
val = standard_b64encode(self.name.encode('utf-8')).decode('ascii')
ans.append(f'name={val}')
if self.data:
val = standard_b64encode(self.data).decode('ascii')
ans.append(f'data={val}')
ans = []
for k in fields(self):
val = getattr(self, k.name)
if val == k.default:
continue
if issubclass(k.type, Enum):
ans.append(f'{k.name}={val.name}')
elif k.type is bytes:
ev = standard_b64encode(val).decode('ascii')
ans.append(f'{k.name}={ev}')
elif k.type is str:
if k.metadata.get('base64'):
sval = standard_b64encode(self.name.encode('utf-8')).decode('ascii')
else:
sval = val
ans.append(f'{k.name}={sanitize_control_codes(sval)}')
elif k.type is int:
ans.append(f'{k.name}={val}')
else:
raise KeyError(f'Field of unknown type: {k.name}')
def escape_semicolons(x: str) -> str:
return x.replace(';', ';;')
return ';'.join(map(escape_semicolons, ans))
@classmethod
def deserialize(cls, data: str) -> 'FileTransmissionCommand':
ans = FileTransmissionCommand()
parts = (x.replace('\0', ';').partition('=')[::2] for x in data.replace(';;', '\0').split(';'))
if not hasattr(cls, 'fmap'):
setattr(cls, 'fmap', {k.name: k for k in fields(cls)})
fmap: Dict[str, Field] = getattr(cls, 'fmap')
def parse_command(data: str) -> FileTransmissionCommand:
ans = FileTransmissionCommand()
parts = data.replace(';;', '\0').split(';')
for k, v in parts:
field = fmap.get(k)
if field is None:
continue
if issubclass(field.type, Enum):
setattr(ans, field.name, field.type[v])
elif field.type is bytes:
setattr(ans, field.name, standard_b64decode(v))
elif field.type is int:
setattr(ans, field.name, int(v))
elif field.type is str:
if field.metadata.get('base64'):
sval = standard_b64decode(v).decode('utf-8')
else:
sval = v
setattr(ans, field.name, sanitize_control_codes(sval))
for i, x in enumerate(parts):
k, v = x.replace('\0', ';').partition('=')[::2]
if k == 'action':
ans.action = Action[v]
elif k == 'compression':
ans.compression = Compression[v]
elif k == 'ftype':
ans.ftype = FileType[v]
elif k == 'ttype':
ans.ttype = TransmisstionType[v]
elif k in ('secret', 'mime', 'id', 'file_id'):
setattr(ans, k, sanitize_control_codes(v))
elif k in ('quiet',):
setattr(ans, k, int(v))
elif k in ('mtime', 'permissions'):
mt = int(v)
if mt >= 0:
setattr(ans, k, mt)
elif k in ('name', 'data'):
val = standard_b64decode(v)
if k == 'name':
ans.name = sanitize_control_codes(val.decode('utf-8'))
else:
ans.data = val
if ans.action is Action.invalid:
raise ValueError('No valid action specified in file transmission command')
if ans.action is Action.invalid:
raise ValueError('No valid action specified in file transmission command')
return ans
return ans
class IdentityDecompressor:
@@ -157,9 +180,11 @@ class DestFile:
self.ttype = ftc.ttype
self.needs_data_sent = self.ttype is not TransmisstionType.simple
self.decompressor = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor()
self.closed = self.ftype is FileType.directory
def close(self) -> None:
pass
if not self.closed:
self.closed = True
class ActiveReceive:
@@ -167,10 +192,12 @@ class ActiveReceive:
files: Dict[str, DestFile]
accepted: bool = False
def __init__(self, id: str) -> None:
def __init__(self, id: str, quiet: int) -> None:
self.id = id
self.files = {}
self.last_activity_at = monotonic()
self.send_acknowledgements = quiet < 1
self.send_errors = quiet < 2
@property
def is_expired(self) -> bool:
@@ -186,7 +213,10 @@ class ActiveReceive:
def start_file(self, ftc: FileTransmissionCommand) -> DestFile:
if ftc.file_id in self.files:
raise KeyError(f'The file_id {ftc.file_id} already exists')
raise TransmissionError(
msg=f'The file_id {ftc.file_id} already exists',
file_id=ftc.file_id,
)
self.files[ftc.file_id] = result = DestFile(ftc)
return result
@@ -198,6 +228,24 @@ class FileTransmission:
def __init__(self, window_id: int):
self.window_id = window_id
self.active_receives = {}
self.pending_receive_responses: Deque[Tuple[str, str]] = deque()
self.pending_timer: Optional[int] = None
def start_pending_timer(self) -> None:
if self.pending_timer is None:
self.pending_timer = add_timer(self.try_pending, 0.2, False)
def try_pending(self, timer_id: Optional[int]) -> None:
self.pending_timer = None
while self.pending_receive_responses:
request_id, payload = self.pending_receive_responses.popleft()
ar = self.active_receives.get(request_id)
if ar is None:
continue
if not self.write_osc_to_child(request_id, payload, appendleft=True):
break
ar.last_activity_at = monotonic()
self.prune_expired()
def __del__(self) -> None:
for ar in self.active_receives.values():
@@ -217,7 +265,7 @@ class FileTransmission:
def handle_serialized_command(self, data: str) -> None:
self.prune_expired()
try:
cmd = parse_command(data)
cmd = FileTransmissionCommand.deserialize(data)
except Exception as e:
log_error(f'Failed to parse file transmission command with error: {e}')
return
@@ -226,11 +274,11 @@ class FileTransmission:
def handle_receive_cmd(self, cmd: FileTransmissionCommand) -> None:
if cmd.id in self.active_receives:
ar = self.active_receives[cmd.id]
if cmd.action is Action.send:
log_error('File transmission send received for already active id, aborting')
self.drop_receive(cmd.id)
return
ar = self.active_receives[cmd.id]
if not ar.accepted:
log_error(f'File transmission command received for rejected id: {cmd.id}, aborting')
self.drop_receive(cmd.id)
@@ -240,14 +288,23 @@ class FileTransmission:
if cmd.action is not Action.send:
log_error(f'File transmission command received for unknown or rejected id: {cmd.id}, ignoring')
return
ar = ActiveReceive(cmd.id)
ar = ActiveReceive(cmd.id, cmd.quiet)
self.start_receive(ar.id)
return
if cmd.action is Action.cancel:
self.drop_receive(ar.id)
elif cmd.action is Action.file:
ar.start_file(cmd)
try:
ar.start_file(cmd)
except TransmissionError as err:
if ar.send_errors:
self.send_transmission_error(ar.id, err)
except Exception as err:
log_error(f'Transmission protocol failed to start file with error: {err}')
if ar.send_errors:
te = TransmissionError(file_id=cmd.file_id, msg=str(err))
self.send_transmission_error(ar.id, te)
elif cmd.action in (Action.data, Action.end_data):
try:
self.add_data(ar, cmd)
@@ -260,16 +317,26 @@ class FileTransmission:
except Exception:
self.drop_receive(cmd.id)
def send_response(self, id: str = '', **fields: str) -> bool:
if 'id' not in fields and id:
fields['id'] = id
return self.write_response_to_child(fields)
def send_status_response(self, code: ErrorCode = ErrorCode.EINVAL, request_id: str = '', file_id: str = '', msg: str = '') -> bool:
err = TransmissionError(code=code, msg=msg, file_id=file_id)
data = err.as_escape_code(request_id)
return self.write_osc_to_child(request_id, data)
def write_response_to_child(self, fields: Dict[str, str]) -> bool:
def send_transmission_error(self, request_id: str, err: TransmissionError) -> bool:
return self.write_osc_to_child(request_id, err.as_escape_code())
def write_osc_to_child(self, request_id: str, payload: str, appendleft: bool = False) -> bool:
boss = get_boss()
window = boss.window_id_map.get(self.window_id)
if window is not None:
return window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
queued = window.screen.send_escape_code_to_child(OSC, payload)
if not queued:
if appendleft:
self.pending_receive_responses.appendleft((request_id, payload))
else:
self.pending_receive_responses.append((request_id, payload))
self.start_pending_timer()
return queued
return False
def start_receive(self, ar_id: str) -> None:
@@ -361,11 +428,11 @@ class TestFileTransmission(FileTransmission):
def __init__(self, allow: bool = True) -> None:
super().__init__(0)
self.test_responses: List[Dict[str, str]] = []
self.test_responses: List[FileTransmissionCommand] = []
self.allow = allow
def write_response_to_child(self, fields: Dict[str, str]) -> bool:
self.test_responses.append(fields)
def write_osc_to_child(self, data: str) -> bool:
self.test_responses.append(FileTransmissionCommand.deserialize(data))
return True
def start_receive(self, aid: str) -> None: