Allow interleaving file transmissions

This commit is contained in:
Kovid Goyal
2021-08-22 22:08:03 +05:30
parent ebcd053bf3
commit d548b21be2
2 changed files with 127 additions and 68 deletions

View File

@@ -6,8 +6,9 @@ import copy
import errno
import os
import tempfile
from base64 import standard_b64decode
from base64 import standard_b64decode, standard_b64encode
from enum import Enum, auto
from functools import partial
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Union
from kitty.fast_data_types import OSC, get_boss
@@ -66,6 +67,24 @@ class FileTransmissionCommand:
dest: str = ''
data: bytes = b''
def serialize(self) -> str:
ans = [f'action={self.action.name}']
if self.container_fmt is not Container.none:
ans.append(f'container_fmt={self.container_fmt.name}')
if self.compression is not Compression.none:
ans.append(f'compression={self.compression.name}')
for x in ('id', 'secret', 'mime', 'quiet'):
val = getattr(self, x)
if val:
ans.append(f'{x}={val}')
if self.dest:
val = standard_b64encode(self.dest.encode('utf-8')).decode('ascii')
ans.append(f'dest={val}')
if self.data:
val = standard_b64encode(self.data).decode('ascii')
ans.append(f'data={val}')
return ';'.join(ans)
def parse_command(data: str) -> FileTransmissionCommand:
ans = FileTransmissionCommand()
@@ -173,15 +192,23 @@ class ZipExtractor:
self.zf.extract(zinfo, targetpath)
class ActiveCommand:
ftc: FileTransmissionCommand
file: Optional[IO[bytes]] = None
dest: str = ''
decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor()
def __init__(self, ftc: FileTransmissionCommand) -> None:
self.ftc = ftc
class FileTransmission:
active_cmd: Optional[FileTransmissionCommand] = None
active_file: Optional[IO[bytes]] = None
active_dest: str = ''
active_decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor()
active_cmds: Dict[str, ActiveCommand]
def __init__(self, window_id: int):
self.window_id = window_id
self.active_cmds = {}
def handle_serialized_command(self, data: str) -> None:
try:
@@ -189,19 +216,22 @@ class FileTransmission:
except Exception as e:
log_error(f'Failed to parse file transmission command with error: {e}')
return
if self.active_cmd is not None:
if cmd.action not in (Action.data, Action.end_data):
log_error('File transmission command received while another is in flight, aborting')
self.abort_in_flight()
if cmd.id in self.active_cmds and cmd.action not in (Action.data, Action.end_data):
log_error('File transmission command received while another is in flight, aborting')
del self.active_cmds[cmd.id]
if cmd.action is Action.send:
self.active_cmds[cmd.id] = ActiveCommand(cmd)
self.start_send(cmd)
elif cmd.action in (Action.data, Action.end_data):
if cmd.id not in self.active_cmds:
log_error('File transmission data command received with unknown id')
return
self.add_data(cmd)
if cmd.action is Action.end_data and self.active_cmd is not None:
self.commit()
if cmd.action is Action.end_data and cmd.id in self.active_cmds:
self.commit(cmd.id)
def send_response(self, **fields: str) -> None:
ac = self.active_cmd
def send_response(self, ac: Optional[FileTransmissionCommand], **fields: str) -> None:
if ac is None:
return
if 'id' not in fields and ac.id:
@@ -215,88 +245,74 @@ class FileTransmission:
window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
def start_send(self, cmd: FileTransmissionCommand) -> None:
self.active_cmd = cmd
boss = get_boss()
window = boss.window_id_map.get(self.window_id)
if window is not None:
boss._run_kitten(
'transfer_ask', ['put', 'multiple' if cmd.container_fmt else 'single', cmd.dest],
window=window, custom_callback=self.handle_send_confirmation
window=window, custom_callback=partial(self.handle_send_confirmation, cmd.id),
)
def handle_send_confirmation(self, data: 'Response', *a: Any) -> None:
cmd = self.active_cmd
def handle_send_confirmation(self, cmd_id: str, data: 'Response', *a: Any) -> None:
cmd = self.active_cmds.get(cmd_id)
if cmd is None:
return
if data['allowed']:
self.active_dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest'])))
self.active_decompressor = ZlibDecompressor() if cmd.compression is Compression.zlib else IdentityDecompressor()
if cmd.quiet:
cmd.dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest'])))
cmd.decompressor = ZlibDecompressor() if cmd.ftc.compression is Compression.zlib else IdentityDecompressor()
if cmd.ftc.quiet:
return
else:
self.active_cmd = None
self.active_dest = ''
if cmd.quiet > 1:
del self.active_cmds[cmd_id]
if cmd.ftc.quiet > 1:
return
self.send_response(status='OK' if data['allowed'] else 'EPERM:User refused the transfer')
self.send_response(cmd.ftc, status='OK' if data['allowed'] else 'EPERM:User refused the transfer')
def send_fail_on_os_error(self, err: OSError, msg: str) -> None:
ac = self.active_cmd
def send_fail_on_os_error(self, ac: Optional[FileTransmissionCommand], err: OSError, msg: str) -> None:
if ac is None or ac.quiet < 2:
return
errname = errno.errorcode.get(err.errno, 'EFAIL')
self.send_response(status=f'{errname}:{msg}')
self.send_response(ac, status=f'{errname}:{msg}')
def add_data(self, cmd: FileTransmissionCommand) -> None:
ac = self.active_cmd
if ac is None or not self.active_dest:
return
if self.active_file is None:
ac = self.active_cmds.get(cmd.id)
def abort_in_flight() -> None:
self.active_cmds.pop(cmd.id, None)
if ac is None or not ac.dest:
return abort_in_flight()
if ac.file is None:
try:
os.makedirs(os.path.dirname(self.active_dest), exist_ok=True)
os.makedirs(os.path.dirname(ac.dest), exist_ok=True)
except OSError as e:
self.send_fail_on_os_error(e, 'Creating destination directory failed')
return self.abort_in_flight()
if ac.container_fmt is Container.none:
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination directory failed')
return abort_in_flight()
if ac.ftc.container_fmt is Container.none:
try:
self.active_file = open(self.active_dest, 'wb')
ac.file = open(ac.dest, 'wb')
except OSError as e:
self.send_fail_on_os_error(e, 'Creating destination file failed')
return self.abort_in_flight()
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination file failed')
return abort_in_flight()
else:
try:
self.active_file = tempfile.TemporaryFile(dir=os.path.dirname(self.active_dest))
ac.file = tempfile.TemporaryFile(dir=os.path.dirname(ac.dest))
except OSError as e:
self.send_fail_on_os_error(e, 'Creating destination temp file failed')
return self.abort_in_flight()
data = self.active_decompressor(cmd.data, cmd.action is Action.end_data)
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination temp file failed')
return abort_in_flight()
data = ac.decompressor(cmd.data, cmd.action is Action.end_data)
try:
self.active_file.write(data)
ac.file.write(data)
except OSError as e:
self.send_fail_on_os_error(e, 'Writing to destination file failed')
return self.abort_in_flight()
self.send_fail_on_os_error(ac.ftc, e, 'Writing to destination file failed')
return abort_in_flight()
def commit(self) -> None:
cmd = self.active_cmd
if cmd is None:
return
try:
if cmd.container_fmt and self.active_file is not None:
self.active_file.seek(0, os.SEEK_SET)
Container.extractor_for_container_fmt(self.active_file, cmd.container_fmt)(self.active_dest)
finally:
self.active_cmd = None
self.active_dest = ''
if self.active_file is not None:
self.active_file.close()
self.active_file = None
def abort_in_flight(self) -> None:
self.active_cmd = None
self.active_dest = ''
if self.active_file is not None:
self.active_file.close()
self.active_file = None
def commit(self, cmd_id: str) -> None:
cmd = self.active_cmds.pop(cmd_id, None)
if cmd is not None and cmd.ftc.container_fmt and cmd.file is not None:
cmd.file.seek(0, os.SEEK_SET)
Container.extractor_for_container_fmt(cmd.file, cmd.ftc.container_fmt)(cmd.dest)
class TestFileTransmission(FileTransmission):
@@ -310,5 +326,5 @@ class TestFileTransmission(FileTransmission):
self.test_responses.append(fields)
def start_send(self, cmd: FileTransmissionCommand) -> None:
self.active_cmd = cmd
self.handle_send_confirmation({'dest': self.test_dest, 'allowed': bool(self.test_dest)})
dest = cmd.dest or self.test_dest
self.handle_send_confirmation(cmd.id, {'dest': dest, 'allowed': bool(dest)})

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python
# vim:fileencoding=utf-8
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
import os
import shutil
import tempfile
from kitty.file_transmission import (
Action, Compression, Container, FileTransmissionCommand,
TestFileTransmission as FileTransmission
)
from . import BaseTest
def serialized_cmd(**fields) -> str:
for k, A in (('action', Action), ('container_fmt', Container), ('compression', Compression)):
if k in fields:
fields[k] = A[fields[k]]
if isinstance(fields.get('data'), str):
fields['data'] = fields['data'].encode('utf-8')
ans = FileTransmissionCommand()
for k in fields:
setattr(ans, k, fields[k])
return ans.serialize()
class TestFileTransmission(BaseTest):
def setUp(self):
self.tdir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tdir)
def test_file_put(self):
ft = FileTransmission()
ft.handle_serialized_command(serialized_cmd(action='send', id='1', dest=os.path.join(self.tdir, '1.bin')))
self.assertIn('1', ft.active_cmds)
self.ae(os.path.basename(ft.active_cmds['1'].dest), '1.bin')
self.assertIsNone(ft.active_cmds['1'].file)