diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py index e1233763b..587492043 100644 --- a/kitty/file_transmission.py +++ b/kitty/file_transmission.py @@ -6,8 +6,9 @@ import copy import errno import os import tempfile -from base64 import standard_b64decode +from base64 import standard_b64decode, standard_b64encode from enum import Enum, auto +from functools import partial from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Union from kitty.fast_data_types import OSC, get_boss @@ -66,6 +67,24 @@ class FileTransmissionCommand: dest: str = '' data: bytes = b'' + def serialize(self) -> str: + ans = [f'action={self.action.name}'] + if self.container_fmt is not Container.none: + ans.append(f'container_fmt={self.container_fmt.name}') + if self.compression is not Compression.none: + ans.append(f'compression={self.compression.name}') + for x in ('id', 'secret', 'mime', 'quiet'): + val = getattr(self, x) + if val: + ans.append(f'{x}={val}') + if self.dest: + val = standard_b64encode(self.dest.encode('utf-8')).decode('ascii') + ans.append(f'dest={val}') + if self.data: + val = standard_b64encode(self.data).decode('ascii') + ans.append(f'data={val}') + return ';'.join(ans) + def parse_command(data: str) -> FileTransmissionCommand: ans = FileTransmissionCommand() @@ -173,15 +192,23 @@ class ZipExtractor: self.zf.extract(zinfo, targetpath) +class ActiveCommand: + ftc: FileTransmissionCommand + file: Optional[IO[bytes]] = None + dest: str = '' + decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor() + + def __init__(self, ftc: FileTransmissionCommand) -> None: + self.ftc = ftc + + class FileTransmission: - active_cmd: Optional[FileTransmissionCommand] = None - active_file: Optional[IO[bytes]] = None - active_dest: str = '' - active_decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor() + active_cmds: Dict[str, ActiveCommand] def __init__(self, window_id: int): self.window_id = window_id + self.active_cmds = {} def handle_serialized_command(self, data: str) -> None: try: @@ -189,19 +216,22 @@ class FileTransmission: except Exception as e: log_error(f'Failed to parse file transmission command with error: {e}') return - if self.active_cmd is not None: - if cmd.action not in (Action.data, Action.end_data): - log_error('File transmission command received while another is in flight, aborting') - self.abort_in_flight() + if cmd.id in self.active_cmds and cmd.action not in (Action.data, Action.end_data): + log_error('File transmission command received while another is in flight, aborting') + del self.active_cmds[cmd.id] + if cmd.action is Action.send: + self.active_cmds[cmd.id] = ActiveCommand(cmd) self.start_send(cmd) elif cmd.action in (Action.data, Action.end_data): + if cmd.id not in self.active_cmds: + log_error('File transmission data command received with unknown id') + return self.add_data(cmd) - if cmd.action is Action.end_data and self.active_cmd is not None: - self.commit() + if cmd.action is Action.end_data and cmd.id in self.active_cmds: + self.commit(cmd.id) - def send_response(self, **fields: str) -> None: - ac = self.active_cmd + def send_response(self, ac: Optional[FileTransmissionCommand], **fields: str) -> None: if ac is None: return if 'id' not in fields and ac.id: @@ -215,88 +245,74 @@ class FileTransmission: window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items())) def start_send(self, cmd: FileTransmissionCommand) -> None: - self.active_cmd = cmd boss = get_boss() window = boss.window_id_map.get(self.window_id) if window is not None: boss._run_kitten( 'transfer_ask', ['put', 'multiple' if cmd.container_fmt else 'single', cmd.dest], - window=window, custom_callback=self.handle_send_confirmation + window=window, custom_callback=partial(self.handle_send_confirmation, cmd.id), ) - def handle_send_confirmation(self, data: 'Response', *a: Any) -> None: - cmd = self.active_cmd + def handle_send_confirmation(self, cmd_id: str, data: 'Response', *a: Any) -> None: + cmd = self.active_cmds.get(cmd_id) if cmd is None: return if data['allowed']: - self.active_dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest']))) - self.active_decompressor = ZlibDecompressor() if cmd.compression is Compression.zlib else IdentityDecompressor() - if cmd.quiet: + cmd.dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest']))) + cmd.decompressor = ZlibDecompressor() if cmd.ftc.compression is Compression.zlib else IdentityDecompressor() + if cmd.ftc.quiet: return else: - self.active_cmd = None - self.active_dest = '' - if cmd.quiet > 1: + del self.active_cmds[cmd_id] + if cmd.ftc.quiet > 1: return - self.send_response(status='OK' if data['allowed'] else 'EPERM:User refused the transfer') + self.send_response(cmd.ftc, status='OK' if data['allowed'] else 'EPERM:User refused the transfer') - def send_fail_on_os_error(self, err: OSError, msg: str) -> None: - ac = self.active_cmd + def send_fail_on_os_error(self, ac: Optional[FileTransmissionCommand], err: OSError, msg: str) -> None: if ac is None or ac.quiet < 2: return errname = errno.errorcode.get(err.errno, 'EFAIL') - self.send_response(status=f'{errname}:{msg}') + self.send_response(ac, status=f'{errname}:{msg}') def add_data(self, cmd: FileTransmissionCommand) -> None: - ac = self.active_cmd - if ac is None or not self.active_dest: - return - if self.active_file is None: + ac = self.active_cmds.get(cmd.id) + + def abort_in_flight() -> None: + self.active_cmds.pop(cmd.id, None) + + if ac is None or not ac.dest: + return abort_in_flight() + + if ac.file is None: try: - os.makedirs(os.path.dirname(self.active_dest), exist_ok=True) + os.makedirs(os.path.dirname(ac.dest), exist_ok=True) except OSError as e: - self.send_fail_on_os_error(e, 'Creating destination directory failed') - return self.abort_in_flight() - if ac.container_fmt is Container.none: + self.send_fail_on_os_error(ac.ftc, e, 'Creating destination directory failed') + return abort_in_flight() + if ac.ftc.container_fmt is Container.none: try: - self.active_file = open(self.active_dest, 'wb') + ac.file = open(ac.dest, 'wb') except OSError as e: - self.send_fail_on_os_error(e, 'Creating destination file failed') - return self.abort_in_flight() + self.send_fail_on_os_error(ac.ftc, e, 'Creating destination file failed') + return abort_in_flight() else: try: - self.active_file = tempfile.TemporaryFile(dir=os.path.dirname(self.active_dest)) + ac.file = tempfile.TemporaryFile(dir=os.path.dirname(ac.dest)) except OSError as e: - self.send_fail_on_os_error(e, 'Creating destination temp file failed') - return self.abort_in_flight() - data = self.active_decompressor(cmd.data, cmd.action is Action.end_data) + self.send_fail_on_os_error(ac.ftc, e, 'Creating destination temp file failed') + return abort_in_flight() + data = ac.decompressor(cmd.data, cmd.action is Action.end_data) try: - self.active_file.write(data) + ac.file.write(data) except OSError as e: - self.send_fail_on_os_error(e, 'Writing to destination file failed') - return self.abort_in_flight() + self.send_fail_on_os_error(ac.ftc, e, 'Writing to destination file failed') + return abort_in_flight() - def commit(self) -> None: - cmd = self.active_cmd - if cmd is None: - return - try: - if cmd.container_fmt and self.active_file is not None: - self.active_file.seek(0, os.SEEK_SET) - Container.extractor_for_container_fmt(self.active_file, cmd.container_fmt)(self.active_dest) - finally: - self.active_cmd = None - self.active_dest = '' - if self.active_file is not None: - self.active_file.close() - self.active_file = None - - def abort_in_flight(self) -> None: - self.active_cmd = None - self.active_dest = '' - if self.active_file is not None: - self.active_file.close() - self.active_file = None + def commit(self, cmd_id: str) -> None: + cmd = self.active_cmds.pop(cmd_id, None) + if cmd is not None and cmd.ftc.container_fmt and cmd.file is not None: + cmd.file.seek(0, os.SEEK_SET) + Container.extractor_for_container_fmt(cmd.file, cmd.ftc.container_fmt)(cmd.dest) class TestFileTransmission(FileTransmission): @@ -310,5 +326,5 @@ class TestFileTransmission(FileTransmission): self.test_responses.append(fields) def start_send(self, cmd: FileTransmissionCommand) -> None: - self.active_cmd = cmd - self.handle_send_confirmation({'dest': self.test_dest, 'allowed': bool(self.test_dest)}) + dest = cmd.dest or self.test_dest + self.handle_send_confirmation(cmd.id, {'dest': dest, 'allowed': bool(dest)}) diff --git a/kitty_tests/file_transmission.py b/kitty_tests/file_transmission.py new file mode 100644 index 000000000..f09880f49 --- /dev/null +++ b/kitty_tests/file_transmission.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# vim:fileencoding=utf-8 +# License: GPLv3 Copyright: 2021, Kovid Goyal + + +import os +import shutil +import tempfile + +from kitty.file_transmission import ( + Action, Compression, Container, FileTransmissionCommand, + TestFileTransmission as FileTransmission +) + +from . import BaseTest + + +def serialized_cmd(**fields) -> str: + for k, A in (('action', Action), ('container_fmt', Container), ('compression', Compression)): + if k in fields: + fields[k] = A[fields[k]] + if isinstance(fields.get('data'), str): + fields['data'] = fields['data'].encode('utf-8') + ans = FileTransmissionCommand() + for k in fields: + setattr(ans, k, fields[k]) + return ans.serialize() + + +class TestFileTransmission(BaseTest): + + def setUp(self): + self.tdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tdir) + + def test_file_put(self): + ft = FileTransmission() + ft.handle_serialized_command(serialized_cmd(action='send', id='1', dest=os.path.join(self.tdir, '1.bin'))) + self.assertIn('1', ft.active_cmds) + self.ae(os.path.basename(ft.active_cmds['1'].dest), '1.bin') + self.assertIsNone(ft.active_cmds['1'].file)