More work on file transmission

This commit is contained in:
Kovid Goyal
2021-09-01 09:22:04 +05:30
parent 59b84ae1a4
commit 40786427b0
2 changed files with 282 additions and 106 deletions

View File

@@ -3,14 +3,18 @@
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
import errno
import os
import stat
import tempfile
from base64 import standard_b64decode, standard_b64encode
from collections import deque
from contextlib import suppress
from dataclasses import Field, dataclass, field, fields
from enum import Enum, auto
from functools import partial
from gettext import gettext as _
from time import monotonic
from typing import Any, Deque, Dict, List, Optional, Tuple, Union
from typing import IO, Any, Callable, Deque, Dict, List, Optional, Tuple, Union
from kitty.fast_data_types import OSC, add_timer, get_boss
@@ -19,7 +23,13 @@ from .utils import log_error, sanitize_control_codes
EXPIRE_TIME = 10 # minutes
class Action(Enum):
class NameReprEnum(Enum):
def __repr__(self) -> str:
return f'<{self.__class__.__name__}.{self.name}>'
class Action(NameReprEnum):
send = auto()
file = auto()
data = auto()
@@ -28,27 +38,28 @@ class Action(Enum):
invalid = auto()
cancel = auto()
status = auto()
finish = auto()
class Compression(Enum):
class Compression(NameReprEnum):
zlib = auto()
none = auto()
class FileType(Enum):
class FileType(NameReprEnum):
regular = auto()
directory = auto()
symlink = auto()
link = auto()
class TransmissionType(Enum):
class TransmissionType(NameReprEnum):
simple = auto()
resume = auto()
rsync = auto()
ErrorCode = Enum('ErrorCode', 'OK EINVAL EPERM')
ErrorCode = Enum('ErrorCode', 'OK EINVAL EPERM EISDIR')
class TransmissionError(Exception):
@@ -57,19 +68,22 @@ class TransmissionError(Exception):
self, code: Union[ErrorCode, str] = ErrorCode.EINVAL,
msg: str = 'Generic error',
transmit: bool = True,
file_id: str = ''
file_id: str = '',
name: str = ''
) -> None:
Exception.__init__(self, msg)
self.transmit = transmit
self.file_id = file_id
self.human_msg = msg
self.code = code
self.name = name
def as_escape_code(self, request_id: str = '') -> str:
code = self.code if isinstance(self.code, str) else self.code.name
name = self.code if isinstance(self.code, str) else self.code.name
if self.human_msg:
name += ':' + self.human_msg
return FileTransmissionCommand(
action=Action.status, id=request_id, file_id=self.file_id,
name=f'{code}:{self.human_msg}'
action=Action.status, id=request_id, file_id=self.file_id, status=name, name=self.name
).serialize()
@@ -83,12 +97,23 @@ class FileTransmissionCommand:
id: str = ''
file_id: str = ''
secret: str = ''
mime: str = ''
quiet: int = 0
mtime: int = -1
permissions: int = -1
data: bytes = b''
name: str = field(default='', metadata={'base64': True})
status: str = field(default='', metadata={'base64': True})
def asdict(self, keep_defaults: bool = False) -> Dict[str, Union[str, int, bytes]]:
ans = {}
for k in fields(self):
val = getattr(self, k.name)
if not keep_defaults and val == k.default:
continue
if issubclass(k.type, Enum):
val = val.name
ans[k.name] = val
return ans
def serialize(self) -> str:
ans = []
@@ -103,7 +128,7 @@ class FileTransmissionCommand:
ans.append(f'{k.name}={ev}')
elif k.type is str:
if k.metadata.get('base64'):
sval = standard_b64encode(self.name.encode('utf-8')).decode('ascii')
sval = standard_b64encode(val.encode('utf-8')).decode('ascii')
else:
sval = val
ans.append(f'{k.name}={sanitize_control_codes(sval)}')
@@ -171,17 +196,95 @@ class DestFile:
def __init__(self, ftc: FileTransmissionCommand) -> None:
self.name = ftc.name
if not os.path.isabs(self.name):
self.name = os.path.join(tempfile.gettempdir(), self.name)
self.mtime = ftc.mtime
self.file_id = ftc.file_id
self.permissions = ftc.permissions
if self.permissions != FileTransmissionCommand.permissions:
self.permissions = stat.S_IMODE(self.permissions)
self.ftype = ftc.ftype
self.ttype = ftc.ttype
self.link_target = b''
self.needs_data_sent = self.ttype is not TransmissionType.simple
self.decompressor = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor()
self.decompressor: Union[ZlibDecompressor, IdentityDecompressor] = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor()
self.closed = self.ftype is FileType.directory
self.actual_file: Optional[IO[bytes]] = None
def __repr__(self) -> str:
return f'DestFile(name={self.name}, file_id={self.file_id}, actual_file={self.actual_file})'
def close(self) -> None:
if not self.closed:
self.closed = True
if self.actual_file is not None:
self.actual_file.close()
self.actual_file = None
def make_parent_dirs(self) -> str:
d = os.path.dirname(self.name)
if d:
os.makedirs(d, exist_ok=True)
return d
def apply_metadata(self, is_symlink: bool = False) -> None:
if self.permissions != FileTransmissionCommand.permissions:
if is_symlink:
with suppress(NotImplementedError):
os.chmod(self.name, self.permissions, follow_symlinks=False)
else:
os.chmod(self.name, self.permissions)
if self.mtime != FileTransmissionCommand.mtime:
if is_symlink:
with suppress(NotImplementedError):
os.utime(self.name, ns=(self.mtime, self.mtime), follow_symlinks=False)
else:
os.utime(self.name, ns=(self.mtime, self.mtime))
def write_data(self, all_files: Dict[str, 'DestFile'], data: bytes, is_last: bool) -> None:
if self.ftype is FileType.directory:
raise TransmissionError(code=ErrorCode.EISDIR, file_id=self.file_id, msg='Cannot write data to a directory entry')
if self.closed:
raise TransmissionError(file_id=self.file_id, msg='Cannot write to a closed file')
if self.ftype in (FileType.symlink, FileType.link):
self.link_target += data
if is_last:
lt = self.link_target.decode('utf-8', 'replace')
base = self.make_parent_dirs()
if lt.startswith('fid:'):
lt = all_files[lt[4:]].name
if self.ftype is FileType.symlink:
try:
cp = os.path.commonpath((self.name, lt))
except ValueError:
pass
else:
if cp:
lt = os.path.relpath(lt, cp)
elif lt.startswith('fid_abs:'):
lt = all_files[lt[8:]].name
elif lt.startswith('path:'):
lt = lt[5:]
if not os.path.isabs(lt) and self.ftype is FileType.link:
lt = os.path.join(base, lt)
lt = lt.replace('/', os.sep)
else:
raise TransmissionError(msg='Unknown link target type', file_id=self.file_id)
if self.ftype is FileType.symlink:
os.symlink(lt, self.name)
else:
os.link(lt, self.name)
self.close()
self.apply_metadata(is_symlink=True)
elif self.ftype is FileType.regular:
if self.actual_file is None:
self.make_parent_dirs()
self.actual_file = open(self.name, 'wb')
data = self.decompressor(data, is_last=is_last)
self.actual_file.write(data)
if is_last:
self.close()
self.apply_metadata()
class ActiveReceive:
@@ -209,13 +312,33 @@ class ActiveReceive:
self.close()
def start_file(self, ftc: FileTransmissionCommand) -> DestFile:
self.last_activity_at = monotonic()
if ftc.file_id in self.files:
raise TransmissionError(
msg=f'The file_id {ftc.file_id} already exists',
file_id=ftc.file_id,
)
self.files[ftc.file_id] = result = DestFile(ftc)
return result
self.files[ftc.file_id] = df = DestFile(ftc)
return df
def add_data(self, ftc: FileTransmissionCommand) -> DestFile:
self.last_activity_at = monotonic()
df = self.files.get(ftc.file_id)
if df is None:
raise TransmissionError(file_id=ftc.file_id, msg='Cannot write to a file without first starting it')
try:
df.write_data(self.files, ftc.data, ftc.action is Action.end_data)
except Exception:
df.close()
raise
return df
def commit(self, send_os_error: Callable[[OSError, str, 'ActiveReceive', str], None]) -> None:
directories = sorted((df for df in self.files.values() if df.ftype is FileType.directory), key=lambda x: len(x.name), reverse=True)
for df in directories:
with suppress(OSError):
# we ignore failures to apply directory metadata as we have already sent an OK for the dir
df.apply_metadata()
class FileTransmission:
@@ -285,7 +408,7 @@ class FileTransmission:
if cmd.action is not Action.send:
log_error(f'File transmission command received for unknown or rejected id: {cmd.id}, ignoring')
return
ar = ActiveReceive(cmd.id, cmd.quiet)
ar = self.active_receives[cmd.id] = ActiveReceive(cmd.id, cmd.quiet)
self.start_receive(ar.id)
return
@@ -293,7 +416,7 @@ class FileTransmission:
self.drop_receive(ar.id)
elif cmd.action is Action.file:
try:
ar.start_file(cmd)
df = ar.start_file(cmd)
except TransmissionError as err:
if ar.send_errors:
self.send_transmission_error(ar.id, err)
@@ -302,14 +425,47 @@ class FileTransmission:
if ar.send_errors:
te = TransmissionError(file_id=cmd.file_id, msg=str(err))
self.send_transmission_error(ar.id, te)
else:
if df.ftype is FileType.directory:
try:
os.makedirs(df.name, exist_ok=True)
except OSError as err:
self.send_fail_on_os_error(err, 'Failed to create directory', ar, df.file_id)
else:
self.send_status_response(ErrorCode.OK, ar.id, df.file_id, name=df.name)
elif cmd.action in (Action.data, Action.end_data):
pass
try:
df = ar.add_data(cmd)
if df.closed and ar.send_acknowledgements:
self.send_status_response(code=ErrorCode.OK, request_id=ar.id, file_id=df.file_id, name=df.name)
except TransmissionError as err:
if ar.send_errors:
self.send_transmission_error(ar.id, err)
except Exception as err:
log_error(f'Transmission protocol failed to write data to file with error: {err}')
if ar.send_errors:
te = TransmissionError(file_id=cmd.file_id, msg=str(err))
self.send_transmission_error(ar.id, te)
elif cmd.action is Action.finish:
try:
ar.commit(self.send_fail_on_os_error)
except TransmissionError as err:
if ar.send_errors:
self.send_transmission_error(ar.id, err)
except Exception as err:
log_error(f'Transmission protocol failed to commit receive with error: {err}')
if ar.send_errors:
te = TransmissionError(msg=str(err))
self.send_transmission_error(ar.id, te)
finally:
self.drop_receive(ar.id)
def send_status_response(
self, code: Union[ErrorCode, str] = ErrorCode.EINVAL,
request_id: str = '', file_id: str = '', msg: str = ''
request_id: str = '', file_id: str = '', msg: str = '',
name: str = '',
) -> bool:
err = TransmissionError(code=code, msg=msg, file_id=file_id)
err = TransmissionError(code=code, msg=msg, file_id=file_id, name=name)
data = err.as_escape_code(request_id)
return self.write_osc_to_child(request_id, data)
@@ -351,8 +507,10 @@ class FileTransmission:
else:
self.drop_receive(ar.id)
if ar.accepted:
if ar.send_acknowledgements:
self.send_status_response(code=ErrorCode.OK, request_id=ar.id)
else:
if ar.send_errors:
self.send_status_response(code=ErrorCode.EPERM, request_id=ar.id, msg='User refused the transfer')
def send_fail_on_os_error(self, err: OSError, msg: str, ar: ActiveReceive, file_id: str = '') -> None:
@@ -361,17 +519,20 @@ class FileTransmission:
errname = errno.errorcode.get(err.errno, 'EFAIL')
self.send_status_response(code=errname, msg=msg, request_id=ar.id, file_id=file_id)
def active_file(self, rid: str = '', file_id: str = '') -> DestFile:
return self.active_receives[rid].files[file_id]
class TestFileTransmission(FileTransmission):
def __init__(self, allow: bool = True) -> None:
super().__init__(0)
self.test_responses: List[FileTransmissionCommand] = []
self.test_responses: List[dict] = []
self.allow = allow
def write_osc_to_child(self, request_id: str, payload: str, appendleft: bool = False) -> bool:
self.test_responses.append(FileTransmissionCommand.deserialize(payload))
self.test_responses.append(FileTransmissionCommand.deserialize(payload).asdict())
return True
def start_receive(self, aid: str) -> None:
self.handle_send_confirmation(aid, {'response': 'y' if self.allow else 'm'})
self.handle_send_confirmation(aid, {'response': 'y' if self.allow else 'n'})

View File

@@ -5,11 +5,9 @@
import os
import shutil
import tarfile
import stat
import tempfile
import zipfile
import zlib
from io import BytesIO
from kitty.file_transmission import (
Action, Compression, FileTransmissionCommand, FileType,
@@ -19,6 +17,19 @@ from kitty.file_transmission import (
from . import BaseTest
def response(id='', msg='', file_id='', name='', action='status', status=''):
ans = {'action': 'status'}
if id:
ans['id'] = id
if file_id:
ans['file_id'] = file_id
if name:
ans['name'] = name
if status:
ans['status'] = status
return ans
def names_in(path):
for dirpath, dirnames, filenames in os.walk(path):
for d in dirnames + filenames:
@@ -39,125 +50,129 @@ class TestFileTransmission(BaseTest):
def setUp(self):
self.tdir = os.path.realpath(tempfile.mkdtemp())
self.responses = []
def tearDown(self):
shutil.rmtree(self.tdir)
self.responses = []
def clean_tdir(self):
shutil.rmtree(self.tdir)
self.tdir = os.path.realpath(tempfile.mkdtemp())
def assertResponses(self, ft, **kw):
self.responses.append(response(**kw))
self.ae(ft.test_responses, self.responses)
def assertPathEqual(self, a, b):
a = os.path.abspath(os.path.realpath(a))
b = os.path.abspath(os.path.realpath(b))
self.ae(a, b)
def test_file_put(self):
return # disabled pending rewrite
# send refusal
for quiet in (0, 1, 2):
ft = FileTransmission()
ft = FileTransmission(allow=False)
ft.handle_serialized_command(serialized_cmd(action='send', id='x', quiet=quiet))
self.ae(ft.test_responses, [] if quiet == 2 else [{'status': 'EPERM:User refused the transfer', 'id': 'x'}])
self.assertFalse(ft.active_cmds)
self.ae(ft.test_responses, [] if quiet == 2 else [response(id='x', status='EPERM:User refused the transfer')])
self.assertFalse(ft.active_receives)
# simple single file send
for quiet in (0, 1, 2):
ft = FileTransmission()
dest = os.path.join(self.tdir, '1.bin')
ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, quiet=quiet))
self.assertIn('', ft.active_cmds)
self.ae(os.path.basename(ft.active_cmds[''].dest), '1.bin')
self.assertIsNone(ft.active_cmds[''].file)
self.ae(ft.test_responses, [] if quiet else [{'status': 'OK'}])
ft.handle_serialized_command(serialized_cmd(action='send', quiet=quiet))
self.assertIn('', ft.active_receives)
ft.handle_serialized_command(serialized_cmd(action='file', name=dest, quiet=quiet))
self.assertPathEqual(ft.active_file().name, dest)
self.assertIsNone(ft.active_file().actual_file)
self.ae(ft.test_responses, [] if quiet else [response(status='OK')])
ft.handle_serialized_command(serialized_cmd(action='data', data='abcd'))
self.assertPathEqual(ft.active_cmds[''].file.name, dest)
self.assertPathEqual(ft.active_file().actual_file.name, dest)
ft.handle_serialized_command(serialized_cmd(action='end_data', data='123'))
self.assertFalse(ft.active_cmds)
self.ae(ft.test_responses, [] if quiet else [{'status': 'OK'}, {'status': 'COMPLETED'}])
self.ae(ft.test_responses, [] if quiet else [response(status='OK'), response(status='OK', name=dest)])
self.assertTrue(ft.active_receives)
ft.handle_serialized_command(serialized_cmd(action='finish'))
self.assertFalse(ft.active_receives)
with open(dest) as f:
self.ae(f.read(), 'abcd123')
# cancel a send
ft = FileTransmission()
dest = os.path.join(self.tdir, '2.bin')
ft.handle_serialized_command(serialized_cmd(action='send', dest=dest))
self.ae(ft.test_responses, [{'status': 'OK'}])
ft.handle_serialized_command(serialized_cmd(action='send'))
self.ae(ft.test_responses, [response(status='OK')])
ft.handle_serialized_command(serialized_cmd(action='file', name=dest))
self.assertPathEqual(ft.active_file().name, dest)
ft.handle_serialized_command(serialized_cmd(action='data', data='abcd'))
self.assertTrue(os.path.exists(dest))
ft.handle_serialized_command(serialized_cmd(action='cancel'))
self.ae(ft.test_responses, [{'status': 'OK'}])
self.assertFalse(os.path.exists(dest))
self.assertFalse(ft.active_cmds)
self.ae(ft.test_responses, [response(status='OK')])
self.assertFalse(ft.active_receives)
# compress with zlib
ft = FileTransmission()
dest = os.path.join(self.tdir, '3.bin')
ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, compression='zlib'))
self.ae(ft.test_responses, [{'status': 'OK'}])
odata = 'abcd' * 1024
data = zlib.compress(odata.encode('ascii'))
ft.handle_serialized_command(serialized_cmd(action='send'))
self.ae(ft.test_responses, [response(status='OK')])
ft.handle_serialized_command(serialized_cmd(action='file', name=dest, compression='zlib'))
self.assertPathEqual(ft.active_file().name, dest)
odata = b'abcd' * 1024
data = zlib.compress(odata)
ft.handle_serialized_command(serialized_cmd(action='data', data=data[:len(data)//2]))
self.assertTrue(os.path.exists(dest))
ft.handle_serialized_command(serialized_cmd(action='end_data', data=data[len(data)//2:]))
with open(dest) as f:
self.ae(ft.test_responses, [response(status='OK'), response(status='OK', name=dest)])
ft.handle_serialized_command(serialized_cmd(action='finish'))
with open(dest, 'rb') as f:
self.ae(f.read(), odata)
self.ae(ft.test_responses, [{'status': 'OK'}, {'status': 'COMPLETED'}])
del odata
del data
# zip send
# multi file send
self.clean_tdir()
buf = BytesIO()
with zipfile.ZipFile(buf, 'w') as zf:
zf.writestr('one.txt', '1' * 1111)
zf.writestr('two/one', '2' * 2222)
zf.writestr('onex/../../three', '3333')
zf.writestr('/onex', '3333')
self.responses = []
ft = FileTransmission()
dest = os.path.join(self.tdir, 'zf')
ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, container_fmt='zip'))
self.ae(ft.test_responses, [{'status': 'OK'}])
ft.handle_serialized_command(serialized_cmd(action='end_data', data=buf.getvalue()))
self.ae(ft.test_responses, [{'status': 'OK'}, {'status': 'COMPLETED'}])
with open(os.path.join(dest, 'one.txt')) as f:
self.ae(f.read(), '1' * 1111)
with open(os.path.join(dest, 'two', 'one')) as f:
self.ae(f.read(), '2' * 2222)
self.ae({'zf', 'zf/two', 'zf/one.txt', 'zf/two/one'}, set(names_in(self.tdir)))
dest = os.path.join(self.tdir, '2.bin')
ft.handle_serialized_command(serialized_cmd(action='send'))
self.assertResponses(ft, status='OK')
fid = 0
# tar send
for mode in ('', 'gz', 'bz2', 'xz'):
buf = BytesIO()
with tarfile.open(fileobj=buf, mode=f'w:{mode}') as tf:
def a(name, data, mode=0o717, lt=None):
ti = tarfile.TarInfo(name)
ti.mtime = 13
ti.size = len(data)
ti.mode = mode
if lt:
ti.linkname = data
ti.type = lt
tf.addfile(ti)
else:
tf.addfile(ti, BytesIO(data.encode('utf-8')))
a('a.txt', 'abcd')
a('/b.txt', 'abcd')
a('../c.txt', 'abcd')
a('sym', 'a.txt', lt=tarfile.SYMTYPE)
a('asym', '/abstarget', lt=tarfile.SYMTYPE)
a('link', 'a.txt', lt=tarfile.LNKTYPE)
self.clean_tdir()
ft = FileTransmission()
dest = os.path.join(self.tdir, 'tf')
ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, container_fmt='t' + (mode or 'ar')))
self.ae(ft.test_responses, [{'status': 'OK'}])
ft.handle_serialized_command(serialized_cmd(action='end_data', data=buf.getvalue()))
self.ae(ft.test_responses, [{'status': 'OK'}, {'status': 'COMPLETED'}])
with open(os.path.join(dest, 'a.txt')) as f:
self.ae(f.read(), 'abcd')
st = os.stat(f.name)
self.ae(st.st_mode & 0b111111111, 0o717)
self.ae(st.st_mtime, 13)
self.assertPathEqual(os.path.join(dest, 'sym'), f.name)
self.assertPathEqual(os.path.join(dest, 'asym'), '/abstarget')
self.assertTrue(os.path.samefile(f.name, os.path.join(dest, 'link')))
self.ae({'tf', 'tf/a.txt', 'tf/sym', 'tf/asym', 'tf/link'}, set(names_in(self.tdir)))
self.ae(len(os.listdir(self.tdir)), 1)
def send(name, data, **kw):
nonlocal fid
fid += 1
kw['action'] = 'file'
kw['file_id'] = str(fid)
kw['name'] = name
ft.handle_serialized_command(serialized_cmd(**kw))
if data:
ft.handle_serialized_command(serialized_cmd(action='end_data', file_id=str(fid), data=data))
self.assertResponses(ft, status='OK', name=name, file_id=str(fid))
send(dest, b'xyz', permissions=0o777, mtime=13)
st = os.stat(dest)
self.ae(st.st_nlink, 1)
self.ae(stat.S_IMODE(st.st_mode), 0o777)
self.ae(st.st_mtime_ns, 13)
send(dest + 's1', 'path:' + os.path.basename(dest), permissions=0o777, mtime=17, ftype='symlink')
st = os.stat(dest + 's1', follow_symlinks=False)
self.ae(stat.S_IMODE(st.st_mode), 0o777)
self.ae(st.st_mtime_ns, 17)
self.ae(os.readlink(dest + 's1'), os.path.basename(dest))
send(dest + 's2', 'fid:1', ftype='symlink')
self.ae(os.readlink(dest + 's2'), os.path.basename(dest))
send(dest + 's3', 'fid_abs:1', ftype='symlink')
self.assertPathEqual(os.readlink(dest + 's3'), dest)
send(dest + 'l1', 'path:' + os.path.basename(dest), ftype='link')
self.ae(os.stat(dest).st_nlink, 2)
send(dest + 'l2', 'fid:1', ftype='link')
self.ae(os.stat(dest).st_nlink, 3)
send(dest + 'd1/1', 'in_dir')
send(dest + 'd1', '', ftype='directory', mtime=29)
send(dest + 'd2', '', ftype='directory', mtime=29)
with open(dest + 'd1/1') as f:
self.ae(f.read(), 'in_dir')
self.assertTrue(os.path.isdir(dest + 'd1'))
self.assertTrue(os.path.isdir(dest + 'd2'))
ft.handle_serialized_command(serialized_cmd(action='finish'))
self.ae(os.stat(dest + 'd1').st_mtime_ns, 29)
self.ae(os.stat(dest + 'd2').st_mtime_ns, 29)
self.assertFalse(ft.active_receives)