From 40786427b02f2736b94c69df004c3303e95b296b Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Wed, 1 Sep 2021 09:22:04 +0530 Subject: [PATCH] More work on file transmission --- kitty/file_transmission.py | 211 +++++++++++++++++++++++++++---- kitty_tests/file_transmission.py | 177 ++++++++++++++------------ 2 files changed, 282 insertions(+), 106 deletions(-) diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py index 203ce5dd2..7e2dc9946 100644 --- a/kitty/file_transmission.py +++ b/kitty/file_transmission.py @@ -3,14 +3,18 @@ # License: GPLv3 Copyright: 2021, Kovid Goyal import errno +import os +import stat +import tempfile from base64 import standard_b64decode, standard_b64encode from collections import deque +from contextlib import suppress from dataclasses import Field, dataclass, field, fields from enum import Enum, auto from functools import partial from gettext import gettext as _ from time import monotonic -from typing import Any, Deque, Dict, List, Optional, Tuple, Union +from typing import IO, Any, Callable, Deque, Dict, List, Optional, Tuple, Union from kitty.fast_data_types import OSC, add_timer, get_boss @@ -19,7 +23,13 @@ from .utils import log_error, sanitize_control_codes EXPIRE_TIME = 10 # minutes -class Action(Enum): +class NameReprEnum(Enum): + + def __repr__(self) -> str: + return f'<{self.__class__.__name__}.{self.name}>' + + +class Action(NameReprEnum): send = auto() file = auto() data = auto() @@ -28,27 +38,28 @@ class Action(Enum): invalid = auto() cancel = auto() status = auto() + finish = auto() -class Compression(Enum): +class Compression(NameReprEnum): zlib = auto() none = auto() -class FileType(Enum): +class FileType(NameReprEnum): regular = auto() directory = auto() symlink = auto() link = auto() -class TransmissionType(Enum): +class TransmissionType(NameReprEnum): simple = auto() resume = auto() rsync = auto() -ErrorCode = Enum('ErrorCode', 'OK EINVAL EPERM') +ErrorCode = Enum('ErrorCode', 'OK EINVAL EPERM EISDIR') class TransmissionError(Exception): @@ -57,19 +68,22 @@ class TransmissionError(Exception): self, code: Union[ErrorCode, str] = ErrorCode.EINVAL, msg: str = 'Generic error', transmit: bool = True, - file_id: str = '' + file_id: str = '', + name: str = '' ) -> None: Exception.__init__(self, msg) self.transmit = transmit self.file_id = file_id self.human_msg = msg self.code = code + self.name = name def as_escape_code(self, request_id: str = '') -> str: - code = self.code if isinstance(self.code, str) else self.code.name + name = self.code if isinstance(self.code, str) else self.code.name + if self.human_msg: + name += ':' + self.human_msg return FileTransmissionCommand( - action=Action.status, id=request_id, file_id=self.file_id, - name=f'{code}:{self.human_msg}' + action=Action.status, id=request_id, file_id=self.file_id, status=name, name=self.name ).serialize() @@ -83,12 +97,23 @@ class FileTransmissionCommand: id: str = '' file_id: str = '' secret: str = '' - mime: str = '' quiet: int = 0 mtime: int = -1 permissions: int = -1 data: bytes = b'' name: str = field(default='', metadata={'base64': True}) + status: str = field(default='', metadata={'base64': True}) + + def asdict(self, keep_defaults: bool = False) -> Dict[str, Union[str, int, bytes]]: + ans = {} + for k in fields(self): + val = getattr(self, k.name) + if not keep_defaults and val == k.default: + continue + if issubclass(k.type, Enum): + val = val.name + ans[k.name] = val + return ans def serialize(self) -> str: ans = [] @@ -103,7 +128,7 @@ class FileTransmissionCommand: ans.append(f'{k.name}={ev}') elif k.type is str: if k.metadata.get('base64'): - sval = standard_b64encode(self.name.encode('utf-8')).decode('ascii') + sval = standard_b64encode(val.encode('utf-8')).decode('ascii') else: sval = val ans.append(f'{k.name}={sanitize_control_codes(sval)}') @@ -171,17 +196,95 @@ class DestFile: def __init__(self, ftc: FileTransmissionCommand) -> None: self.name = ftc.name + if not os.path.isabs(self.name): + self.name = os.path.join(tempfile.gettempdir(), self.name) self.mtime = ftc.mtime + self.file_id = ftc.file_id self.permissions = ftc.permissions + if self.permissions != FileTransmissionCommand.permissions: + self.permissions = stat.S_IMODE(self.permissions) self.ftype = ftc.ftype self.ttype = ftc.ttype + self.link_target = b'' self.needs_data_sent = self.ttype is not TransmissionType.simple - self.decompressor = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor() + self.decompressor: Union[ZlibDecompressor, IdentityDecompressor] = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor() self.closed = self.ftype is FileType.directory + self.actual_file: Optional[IO[bytes]] = None + + def __repr__(self) -> str: + return f'DestFile(name={self.name}, file_id={self.file_id}, actual_file={self.actual_file})' def close(self) -> None: if not self.closed: self.closed = True + if self.actual_file is not None: + self.actual_file.close() + self.actual_file = None + + def make_parent_dirs(self) -> str: + d = os.path.dirname(self.name) + if d: + os.makedirs(d, exist_ok=True) + return d + + def apply_metadata(self, is_symlink: bool = False) -> None: + if self.permissions != FileTransmissionCommand.permissions: + if is_symlink: + with suppress(NotImplementedError): + os.chmod(self.name, self.permissions, follow_symlinks=False) + else: + os.chmod(self.name, self.permissions) + if self.mtime != FileTransmissionCommand.mtime: + if is_symlink: + with suppress(NotImplementedError): + os.utime(self.name, ns=(self.mtime, self.mtime), follow_symlinks=False) + else: + os.utime(self.name, ns=(self.mtime, self.mtime)) + + def write_data(self, all_files: Dict[str, 'DestFile'], data: bytes, is_last: bool) -> None: + if self.ftype is FileType.directory: + raise TransmissionError(code=ErrorCode.EISDIR, file_id=self.file_id, msg='Cannot write data to a directory entry') + if self.closed: + raise TransmissionError(file_id=self.file_id, msg='Cannot write to a closed file') + if self.ftype in (FileType.symlink, FileType.link): + self.link_target += data + if is_last: + lt = self.link_target.decode('utf-8', 'replace') + base = self.make_parent_dirs() + if lt.startswith('fid:'): + lt = all_files[lt[4:]].name + if self.ftype is FileType.symlink: + try: + cp = os.path.commonpath((self.name, lt)) + except ValueError: + pass + else: + if cp: + lt = os.path.relpath(lt, cp) + elif lt.startswith('fid_abs:'): + lt = all_files[lt[8:]].name + elif lt.startswith('path:'): + lt = lt[5:] + if not os.path.isabs(lt) and self.ftype is FileType.link: + lt = os.path.join(base, lt) + lt = lt.replace('/', os.sep) + else: + raise TransmissionError(msg='Unknown link target type', file_id=self.file_id) + if self.ftype is FileType.symlink: + os.symlink(lt, self.name) + else: + os.link(lt, self.name) + self.close() + self.apply_metadata(is_symlink=True) + elif self.ftype is FileType.regular: + if self.actual_file is None: + self.make_parent_dirs() + self.actual_file = open(self.name, 'wb') + data = self.decompressor(data, is_last=is_last) + self.actual_file.write(data) + if is_last: + self.close() + self.apply_metadata() class ActiveReceive: @@ -209,13 +312,33 @@ class ActiveReceive: self.close() def start_file(self, ftc: FileTransmissionCommand) -> DestFile: + self.last_activity_at = monotonic() if ftc.file_id in self.files: raise TransmissionError( msg=f'The file_id {ftc.file_id} already exists', file_id=ftc.file_id, ) - self.files[ftc.file_id] = result = DestFile(ftc) - return result + self.files[ftc.file_id] = df = DestFile(ftc) + return df + + def add_data(self, ftc: FileTransmissionCommand) -> DestFile: + self.last_activity_at = monotonic() + df = self.files.get(ftc.file_id) + if df is None: + raise TransmissionError(file_id=ftc.file_id, msg='Cannot write to a file without first starting it') + try: + df.write_data(self.files, ftc.data, ftc.action is Action.end_data) + except Exception: + df.close() + raise + return df + + def commit(self, send_os_error: Callable[[OSError, str, 'ActiveReceive', str], None]) -> None: + directories = sorted((df for df in self.files.values() if df.ftype is FileType.directory), key=lambda x: len(x.name), reverse=True) + for df in directories: + with suppress(OSError): + # we ignore failures to apply directory metadata as we have already sent an OK for the dir + df.apply_metadata() class FileTransmission: @@ -285,7 +408,7 @@ class FileTransmission: if cmd.action is not Action.send: log_error(f'File transmission command received for unknown or rejected id: {cmd.id}, ignoring') return - ar = ActiveReceive(cmd.id, cmd.quiet) + ar = self.active_receives[cmd.id] = ActiveReceive(cmd.id, cmd.quiet) self.start_receive(ar.id) return @@ -293,7 +416,7 @@ class FileTransmission: self.drop_receive(ar.id) elif cmd.action is Action.file: try: - ar.start_file(cmd) + df = ar.start_file(cmd) except TransmissionError as err: if ar.send_errors: self.send_transmission_error(ar.id, err) @@ -302,14 +425,47 @@ class FileTransmission: if ar.send_errors: te = TransmissionError(file_id=cmd.file_id, msg=str(err)) self.send_transmission_error(ar.id, te) + else: + if df.ftype is FileType.directory: + try: + os.makedirs(df.name, exist_ok=True) + except OSError as err: + self.send_fail_on_os_error(err, 'Failed to create directory', ar, df.file_id) + else: + self.send_status_response(ErrorCode.OK, ar.id, df.file_id, name=df.name) elif cmd.action in (Action.data, Action.end_data): - pass + try: + df = ar.add_data(cmd) + if df.closed and ar.send_acknowledgements: + self.send_status_response(code=ErrorCode.OK, request_id=ar.id, file_id=df.file_id, name=df.name) + except TransmissionError as err: + if ar.send_errors: + self.send_transmission_error(ar.id, err) + except Exception as err: + log_error(f'Transmission protocol failed to write data to file with error: {err}') + if ar.send_errors: + te = TransmissionError(file_id=cmd.file_id, msg=str(err)) + self.send_transmission_error(ar.id, te) + elif cmd.action is Action.finish: + try: + ar.commit(self.send_fail_on_os_error) + except TransmissionError as err: + if ar.send_errors: + self.send_transmission_error(ar.id, err) + except Exception as err: + log_error(f'Transmission protocol failed to commit receive with error: {err}') + if ar.send_errors: + te = TransmissionError(msg=str(err)) + self.send_transmission_error(ar.id, te) + finally: + self.drop_receive(ar.id) def send_status_response( self, code: Union[ErrorCode, str] = ErrorCode.EINVAL, - request_id: str = '', file_id: str = '', msg: str = '' + request_id: str = '', file_id: str = '', msg: str = '', + name: str = '', ) -> bool: - err = TransmissionError(code=code, msg=msg, file_id=file_id) + err = TransmissionError(code=code, msg=msg, file_id=file_id, name=name) data = err.as_escape_code(request_id) return self.write_osc_to_child(request_id, data) @@ -351,9 +507,11 @@ class FileTransmission: else: self.drop_receive(ar.id) if ar.accepted: - self.send_status_response(code=ErrorCode.OK, request_id=ar.id) + if ar.send_acknowledgements: + self.send_status_response(code=ErrorCode.OK, request_id=ar.id) else: - self.send_status_response(code=ErrorCode.EPERM, request_id=ar.id, msg='User refused the transfer') + if ar.send_errors: + self.send_status_response(code=ErrorCode.EPERM, request_id=ar.id, msg='User refused the transfer') def send_fail_on_os_error(self, err: OSError, msg: str, ar: ActiveReceive, file_id: str = '') -> None: if not ar.send_errors: @@ -361,17 +519,20 @@ class FileTransmission: errname = errno.errorcode.get(err.errno, 'EFAIL') self.send_status_response(code=errname, msg=msg, request_id=ar.id, file_id=file_id) + def active_file(self, rid: str = '', file_id: str = '') -> DestFile: + return self.active_receives[rid].files[file_id] + class TestFileTransmission(FileTransmission): def __init__(self, allow: bool = True) -> None: super().__init__(0) - self.test_responses: List[FileTransmissionCommand] = [] + self.test_responses: List[dict] = [] self.allow = allow def write_osc_to_child(self, request_id: str, payload: str, appendleft: bool = False) -> bool: - self.test_responses.append(FileTransmissionCommand.deserialize(payload)) + self.test_responses.append(FileTransmissionCommand.deserialize(payload).asdict()) return True def start_receive(self, aid: str) -> None: - self.handle_send_confirmation(aid, {'response': 'y' if self.allow else 'm'}) + self.handle_send_confirmation(aid, {'response': 'y' if self.allow else 'n'}) diff --git a/kitty_tests/file_transmission.py b/kitty_tests/file_transmission.py index 24923a833..4b18ddc8b 100644 --- a/kitty_tests/file_transmission.py +++ b/kitty_tests/file_transmission.py @@ -5,11 +5,9 @@ import os import shutil -import tarfile +import stat import tempfile -import zipfile import zlib -from io import BytesIO from kitty.file_transmission import ( Action, Compression, FileTransmissionCommand, FileType, @@ -19,6 +17,19 @@ from kitty.file_transmission import ( from . import BaseTest +def response(id='', msg='', file_id='', name='', action='status', status=''): + ans = {'action': 'status'} + if id: + ans['id'] = id + if file_id: + ans['file_id'] = file_id + if name: + ans['name'] = name + if status: + ans['status'] = status + return ans + + def names_in(path): for dirpath, dirnames, filenames in os.walk(path): for d in dirnames + filenames: @@ -39,125 +50,129 @@ class TestFileTransmission(BaseTest): def setUp(self): self.tdir = os.path.realpath(tempfile.mkdtemp()) + self.responses = [] def tearDown(self): shutil.rmtree(self.tdir) + self.responses = [] def clean_tdir(self): shutil.rmtree(self.tdir) self.tdir = os.path.realpath(tempfile.mkdtemp()) + def assertResponses(self, ft, **kw): + self.responses.append(response(**kw)) + self.ae(ft.test_responses, self.responses) + def assertPathEqual(self, a, b): a = os.path.abspath(os.path.realpath(a)) b = os.path.abspath(os.path.realpath(b)) self.ae(a, b) def test_file_put(self): - return # disabled pending rewrite # send refusal for quiet in (0, 1, 2): - ft = FileTransmission() + ft = FileTransmission(allow=False) ft.handle_serialized_command(serialized_cmd(action='send', id='x', quiet=quiet)) - self.ae(ft.test_responses, [] if quiet == 2 else [{'status': 'EPERM:User refused the transfer', 'id': 'x'}]) - self.assertFalse(ft.active_cmds) + self.ae(ft.test_responses, [] if quiet == 2 else [response(id='x', status='EPERM:User refused the transfer')]) + self.assertFalse(ft.active_receives) # simple single file send for quiet in (0, 1, 2): ft = FileTransmission() dest = os.path.join(self.tdir, '1.bin') - ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, quiet=quiet)) - self.assertIn('', ft.active_cmds) - self.ae(os.path.basename(ft.active_cmds[''].dest), '1.bin') - self.assertIsNone(ft.active_cmds[''].file) - self.ae(ft.test_responses, [] if quiet else [{'status': 'OK'}]) + ft.handle_serialized_command(serialized_cmd(action='send', quiet=quiet)) + self.assertIn('', ft.active_receives) + ft.handle_serialized_command(serialized_cmd(action='file', name=dest, quiet=quiet)) + self.assertPathEqual(ft.active_file().name, dest) + self.assertIsNone(ft.active_file().actual_file) + self.ae(ft.test_responses, [] if quiet else [response(status='OK')]) ft.handle_serialized_command(serialized_cmd(action='data', data='abcd')) - self.assertPathEqual(ft.active_cmds[''].file.name, dest) + self.assertPathEqual(ft.active_file().actual_file.name, dest) ft.handle_serialized_command(serialized_cmd(action='end_data', data='123')) - self.assertFalse(ft.active_cmds) - self.ae(ft.test_responses, [] if quiet else [{'status': 'OK'}, {'status': 'COMPLETED'}]) + self.ae(ft.test_responses, [] if quiet else [response(status='OK'), response(status='OK', name=dest)]) + self.assertTrue(ft.active_receives) + ft.handle_serialized_command(serialized_cmd(action='finish')) + self.assertFalse(ft.active_receives) with open(dest) as f: self.ae(f.read(), 'abcd123') # cancel a send ft = FileTransmission() dest = os.path.join(self.tdir, '2.bin') - ft.handle_serialized_command(serialized_cmd(action='send', dest=dest)) - self.ae(ft.test_responses, [{'status': 'OK'}]) + ft.handle_serialized_command(serialized_cmd(action='send')) + self.ae(ft.test_responses, [response(status='OK')]) + ft.handle_serialized_command(serialized_cmd(action='file', name=dest)) + self.assertPathEqual(ft.active_file().name, dest) ft.handle_serialized_command(serialized_cmd(action='data', data='abcd')) self.assertTrue(os.path.exists(dest)) ft.handle_serialized_command(serialized_cmd(action='cancel')) - self.ae(ft.test_responses, [{'status': 'OK'}]) - self.assertFalse(os.path.exists(dest)) - self.assertFalse(ft.active_cmds) + self.ae(ft.test_responses, [response(status='OK')]) + self.assertFalse(ft.active_receives) # compress with zlib ft = FileTransmission() dest = os.path.join(self.tdir, '3.bin') - ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, compression='zlib')) - self.ae(ft.test_responses, [{'status': 'OK'}]) - odata = 'abcd' * 1024 - data = zlib.compress(odata.encode('ascii')) + ft.handle_serialized_command(serialized_cmd(action='send')) + self.ae(ft.test_responses, [response(status='OK')]) + ft.handle_serialized_command(serialized_cmd(action='file', name=dest, compression='zlib')) + self.assertPathEqual(ft.active_file().name, dest) + odata = b'abcd' * 1024 + data = zlib.compress(odata) ft.handle_serialized_command(serialized_cmd(action='data', data=data[:len(data)//2])) self.assertTrue(os.path.exists(dest)) ft.handle_serialized_command(serialized_cmd(action='end_data', data=data[len(data)//2:])) - with open(dest) as f: + self.ae(ft.test_responses, [response(status='OK'), response(status='OK', name=dest)]) + ft.handle_serialized_command(serialized_cmd(action='finish')) + with open(dest, 'rb') as f: self.ae(f.read(), odata) - self.ae(ft.test_responses, [{'status': 'OK'}, {'status': 'COMPLETED'}]) del odata del data - # zip send + # multi file send self.clean_tdir() - buf = BytesIO() - with zipfile.ZipFile(buf, 'w') as zf: - zf.writestr('one.txt', '1' * 1111) - zf.writestr('two/one', '2' * 2222) - zf.writestr('onex/../../three', '3333') - zf.writestr('/onex', '3333') + self.responses = [] ft = FileTransmission() - dest = os.path.join(self.tdir, 'zf') - ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, container_fmt='zip')) - self.ae(ft.test_responses, [{'status': 'OK'}]) - ft.handle_serialized_command(serialized_cmd(action='end_data', data=buf.getvalue())) - self.ae(ft.test_responses, [{'status': 'OK'}, {'status': 'COMPLETED'}]) - with open(os.path.join(dest, 'one.txt')) as f: - self.ae(f.read(), '1' * 1111) - with open(os.path.join(dest, 'two', 'one')) as f: - self.ae(f.read(), '2' * 2222) - self.ae({'zf', 'zf/two', 'zf/one.txt', 'zf/two/one'}, set(names_in(self.tdir))) + dest = os.path.join(self.tdir, '2.bin') + ft.handle_serialized_command(serialized_cmd(action='send')) + self.assertResponses(ft, status='OK') + fid = 0 - # tar send - for mode in ('', 'gz', 'bz2', 'xz'): - buf = BytesIO() - with tarfile.open(fileobj=buf, mode=f'w:{mode}') as tf: - def a(name, data, mode=0o717, lt=None): - ti = tarfile.TarInfo(name) - ti.mtime = 13 - ti.size = len(data) - ti.mode = mode - if lt: - ti.linkname = data - ti.type = lt - tf.addfile(ti) - else: - tf.addfile(ti, BytesIO(data.encode('utf-8'))) - a('a.txt', 'abcd') - a('/b.txt', 'abcd') - a('../c.txt', 'abcd') - a('sym', 'a.txt', lt=tarfile.SYMTYPE) - a('asym', '/abstarget', lt=tarfile.SYMTYPE) - a('link', 'a.txt', lt=tarfile.LNKTYPE) - self.clean_tdir() - ft = FileTransmission() - dest = os.path.join(self.tdir, 'tf') - ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, container_fmt='t' + (mode or 'ar'))) - self.ae(ft.test_responses, [{'status': 'OK'}]) - ft.handle_serialized_command(serialized_cmd(action='end_data', data=buf.getvalue())) - self.ae(ft.test_responses, [{'status': 'OK'}, {'status': 'COMPLETED'}]) - with open(os.path.join(dest, 'a.txt')) as f: - self.ae(f.read(), 'abcd') - st = os.stat(f.name) - self.ae(st.st_mode & 0b111111111, 0o717) - self.ae(st.st_mtime, 13) - self.assertPathEqual(os.path.join(dest, 'sym'), f.name) - self.assertPathEqual(os.path.join(dest, 'asym'), '/abstarget') - self.assertTrue(os.path.samefile(f.name, os.path.join(dest, 'link'))) - self.ae({'tf', 'tf/a.txt', 'tf/sym', 'tf/asym', 'tf/link'}, set(names_in(self.tdir))) - self.ae(len(os.listdir(self.tdir)), 1) + def send(name, data, **kw): + nonlocal fid + fid += 1 + kw['action'] = 'file' + kw['file_id'] = str(fid) + kw['name'] = name + ft.handle_serialized_command(serialized_cmd(**kw)) + if data: + ft.handle_serialized_command(serialized_cmd(action='end_data', file_id=str(fid), data=data)) + self.assertResponses(ft, status='OK', name=name, file_id=str(fid)) + + send(dest, b'xyz', permissions=0o777, mtime=13) + st = os.stat(dest) + self.ae(st.st_nlink, 1) + self.ae(stat.S_IMODE(st.st_mode), 0o777) + self.ae(st.st_mtime_ns, 13) + send(dest + 's1', 'path:' + os.path.basename(dest), permissions=0o777, mtime=17, ftype='symlink') + st = os.stat(dest + 's1', follow_symlinks=False) + self.ae(stat.S_IMODE(st.st_mode), 0o777) + self.ae(st.st_mtime_ns, 17) + self.ae(os.readlink(dest + 's1'), os.path.basename(dest)) + send(dest + 's2', 'fid:1', ftype='symlink') + self.ae(os.readlink(dest + 's2'), os.path.basename(dest)) + send(dest + 's3', 'fid_abs:1', ftype='symlink') + self.assertPathEqual(os.readlink(dest + 's3'), dest) + send(dest + 'l1', 'path:' + os.path.basename(dest), ftype='link') + self.ae(os.stat(dest).st_nlink, 2) + send(dest + 'l2', 'fid:1', ftype='link') + self.ae(os.stat(dest).st_nlink, 3) + send(dest + 'd1/1', 'in_dir') + send(dest + 'd1', '', ftype='directory', mtime=29) + send(dest + 'd2', '', ftype='directory', mtime=29) + with open(dest + 'd1/1') as f: + self.ae(f.read(), 'in_dir') + self.assertTrue(os.path.isdir(dest + 'd1')) + self.assertTrue(os.path.isdir(dest + 'd2')) + + ft.handle_serialized_command(serialized_cmd(action='finish')) + self.ae(os.stat(dest + 'd1').st_mtime_ns, 29) + self.ae(os.stat(dest + 'd2').st_mtime_ns, 29) + self.assertFalse(ft.active_receives)