mirror of
https://github.com/kovidgoyal/kitty
synced 2026-06-08 14:18:26 +02:00
Allow interleaving file transmissions
This commit is contained in:
@@ -6,8 +6,9 @@ import copy
|
||||
import errno
|
||||
import os
|
||||
import tempfile
|
||||
from base64 import standard_b64decode
|
||||
from base64 import standard_b64decode, standard_b64encode
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from kitty.fast_data_types import OSC, get_boss
|
||||
@@ -66,6 +67,24 @@ class FileTransmissionCommand:
|
||||
dest: str = ''
|
||||
data: bytes = b''
|
||||
|
||||
def serialize(self) -> str:
|
||||
ans = [f'action={self.action.name}']
|
||||
if self.container_fmt is not Container.none:
|
||||
ans.append(f'container_fmt={self.container_fmt.name}')
|
||||
if self.compression is not Compression.none:
|
||||
ans.append(f'compression={self.compression.name}')
|
||||
for x in ('id', 'secret', 'mime', 'quiet'):
|
||||
val = getattr(self, x)
|
||||
if val:
|
||||
ans.append(f'{x}={val}')
|
||||
if self.dest:
|
||||
val = standard_b64encode(self.dest.encode('utf-8')).decode('ascii')
|
||||
ans.append(f'dest={val}')
|
||||
if self.data:
|
||||
val = standard_b64encode(self.data).decode('ascii')
|
||||
ans.append(f'data={val}')
|
||||
return ';'.join(ans)
|
||||
|
||||
|
||||
def parse_command(data: str) -> FileTransmissionCommand:
|
||||
ans = FileTransmissionCommand()
|
||||
@@ -173,15 +192,23 @@ class ZipExtractor:
|
||||
self.zf.extract(zinfo, targetpath)
|
||||
|
||||
|
||||
class ActiveCommand:
|
||||
ftc: FileTransmissionCommand
|
||||
file: Optional[IO[bytes]] = None
|
||||
dest: str = ''
|
||||
decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor()
|
||||
|
||||
def __init__(self, ftc: FileTransmissionCommand) -> None:
|
||||
self.ftc = ftc
|
||||
|
||||
|
||||
class FileTransmission:
|
||||
|
||||
active_cmd: Optional[FileTransmissionCommand] = None
|
||||
active_file: Optional[IO[bytes]] = None
|
||||
active_dest: str = ''
|
||||
active_decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor()
|
||||
active_cmds: Dict[str, ActiveCommand]
|
||||
|
||||
def __init__(self, window_id: int):
|
||||
self.window_id = window_id
|
||||
self.active_cmds = {}
|
||||
|
||||
def handle_serialized_command(self, data: str) -> None:
|
||||
try:
|
||||
@@ -189,19 +216,22 @@ class FileTransmission:
|
||||
except Exception as e:
|
||||
log_error(f'Failed to parse file transmission command with error: {e}')
|
||||
return
|
||||
if self.active_cmd is not None:
|
||||
if cmd.action not in (Action.data, Action.end_data):
|
||||
if cmd.id in self.active_cmds and cmd.action not in (Action.data, Action.end_data):
|
||||
log_error('File transmission command received while another is in flight, aborting')
|
||||
self.abort_in_flight()
|
||||
del self.active_cmds[cmd.id]
|
||||
|
||||
if cmd.action is Action.send:
|
||||
self.active_cmds[cmd.id] = ActiveCommand(cmd)
|
||||
self.start_send(cmd)
|
||||
elif cmd.action in (Action.data, Action.end_data):
|
||||
if cmd.id not in self.active_cmds:
|
||||
log_error('File transmission data command received with unknown id')
|
||||
return
|
||||
self.add_data(cmd)
|
||||
if cmd.action is Action.end_data and self.active_cmd is not None:
|
||||
self.commit()
|
||||
if cmd.action is Action.end_data and cmd.id in self.active_cmds:
|
||||
self.commit(cmd.id)
|
||||
|
||||
def send_response(self, **fields: str) -> None:
|
||||
ac = self.active_cmd
|
||||
def send_response(self, ac: Optional[FileTransmissionCommand], **fields: str) -> None:
|
||||
if ac is None:
|
||||
return
|
||||
if 'id' not in fields and ac.id:
|
||||
@@ -215,88 +245,74 @@ class FileTransmission:
|
||||
window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
|
||||
|
||||
def start_send(self, cmd: FileTransmissionCommand) -> None:
|
||||
self.active_cmd = cmd
|
||||
boss = get_boss()
|
||||
window = boss.window_id_map.get(self.window_id)
|
||||
if window is not None:
|
||||
boss._run_kitten(
|
||||
'transfer_ask', ['put', 'multiple' if cmd.container_fmt else 'single', cmd.dest],
|
||||
window=window, custom_callback=self.handle_send_confirmation
|
||||
window=window, custom_callback=partial(self.handle_send_confirmation, cmd.id),
|
||||
)
|
||||
|
||||
def handle_send_confirmation(self, data: 'Response', *a: Any) -> None:
|
||||
cmd = self.active_cmd
|
||||
def handle_send_confirmation(self, cmd_id: str, data: 'Response', *a: Any) -> None:
|
||||
cmd = self.active_cmds.get(cmd_id)
|
||||
if cmd is None:
|
||||
return
|
||||
if data['allowed']:
|
||||
self.active_dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest'])))
|
||||
self.active_decompressor = ZlibDecompressor() if cmd.compression is Compression.zlib else IdentityDecompressor()
|
||||
if cmd.quiet:
|
||||
cmd.dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest'])))
|
||||
cmd.decompressor = ZlibDecompressor() if cmd.ftc.compression is Compression.zlib else IdentityDecompressor()
|
||||
if cmd.ftc.quiet:
|
||||
return
|
||||
else:
|
||||
self.active_cmd = None
|
||||
self.active_dest = ''
|
||||
if cmd.quiet > 1:
|
||||
del self.active_cmds[cmd_id]
|
||||
if cmd.ftc.quiet > 1:
|
||||
return
|
||||
self.send_response(status='OK' if data['allowed'] else 'EPERM:User refused the transfer')
|
||||
self.send_response(cmd.ftc, status='OK' if data['allowed'] else 'EPERM:User refused the transfer')
|
||||
|
||||
def send_fail_on_os_error(self, err: OSError, msg: str) -> None:
|
||||
ac = self.active_cmd
|
||||
def send_fail_on_os_error(self, ac: Optional[FileTransmissionCommand], err: OSError, msg: str) -> None:
|
||||
if ac is None or ac.quiet < 2:
|
||||
return
|
||||
errname = errno.errorcode.get(err.errno, 'EFAIL')
|
||||
self.send_response(status=f'{errname}:{msg}')
|
||||
self.send_response(ac, status=f'{errname}:{msg}')
|
||||
|
||||
def add_data(self, cmd: FileTransmissionCommand) -> None:
|
||||
ac = self.active_cmd
|
||||
if ac is None or not self.active_dest:
|
||||
return
|
||||
if self.active_file is None:
|
||||
ac = self.active_cmds.get(cmd.id)
|
||||
|
||||
def abort_in_flight() -> None:
|
||||
self.active_cmds.pop(cmd.id, None)
|
||||
|
||||
if ac is None or not ac.dest:
|
||||
return abort_in_flight()
|
||||
|
||||
if ac.file is None:
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.active_dest), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(ac.dest), exist_ok=True)
|
||||
except OSError as e:
|
||||
self.send_fail_on_os_error(e, 'Creating destination directory failed')
|
||||
return self.abort_in_flight()
|
||||
if ac.container_fmt is Container.none:
|
||||
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination directory failed')
|
||||
return abort_in_flight()
|
||||
if ac.ftc.container_fmt is Container.none:
|
||||
try:
|
||||
self.active_file = open(self.active_dest, 'wb')
|
||||
ac.file = open(ac.dest, 'wb')
|
||||
except OSError as e:
|
||||
self.send_fail_on_os_error(e, 'Creating destination file failed')
|
||||
return self.abort_in_flight()
|
||||
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination file failed')
|
||||
return abort_in_flight()
|
||||
else:
|
||||
try:
|
||||
self.active_file = tempfile.TemporaryFile(dir=os.path.dirname(self.active_dest))
|
||||
ac.file = tempfile.TemporaryFile(dir=os.path.dirname(ac.dest))
|
||||
except OSError as e:
|
||||
self.send_fail_on_os_error(e, 'Creating destination temp file failed')
|
||||
return self.abort_in_flight()
|
||||
data = self.active_decompressor(cmd.data, cmd.action is Action.end_data)
|
||||
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination temp file failed')
|
||||
return abort_in_flight()
|
||||
data = ac.decompressor(cmd.data, cmd.action is Action.end_data)
|
||||
try:
|
||||
self.active_file.write(data)
|
||||
ac.file.write(data)
|
||||
except OSError as e:
|
||||
self.send_fail_on_os_error(e, 'Writing to destination file failed')
|
||||
return self.abort_in_flight()
|
||||
self.send_fail_on_os_error(ac.ftc, e, 'Writing to destination file failed')
|
||||
return abort_in_flight()
|
||||
|
||||
def commit(self) -> None:
|
||||
cmd = self.active_cmd
|
||||
if cmd is None:
|
||||
return
|
||||
try:
|
||||
if cmd.container_fmt and self.active_file is not None:
|
||||
self.active_file.seek(0, os.SEEK_SET)
|
||||
Container.extractor_for_container_fmt(self.active_file, cmd.container_fmt)(self.active_dest)
|
||||
finally:
|
||||
self.active_cmd = None
|
||||
self.active_dest = ''
|
||||
if self.active_file is not None:
|
||||
self.active_file.close()
|
||||
self.active_file = None
|
||||
|
||||
def abort_in_flight(self) -> None:
|
||||
self.active_cmd = None
|
||||
self.active_dest = ''
|
||||
if self.active_file is not None:
|
||||
self.active_file.close()
|
||||
self.active_file = None
|
||||
def commit(self, cmd_id: str) -> None:
|
||||
cmd = self.active_cmds.pop(cmd_id, None)
|
||||
if cmd is not None and cmd.ftc.container_fmt and cmd.file is not None:
|
||||
cmd.file.seek(0, os.SEEK_SET)
|
||||
Container.extractor_for_container_fmt(cmd.file, cmd.ftc.container_fmt)(cmd.dest)
|
||||
|
||||
|
||||
class TestFileTransmission(FileTransmission):
|
||||
@@ -310,5 +326,5 @@ class TestFileTransmission(FileTransmission):
|
||||
self.test_responses.append(fields)
|
||||
|
||||
def start_send(self, cmd: FileTransmissionCommand) -> None:
|
||||
self.active_cmd = cmd
|
||||
self.handle_send_confirmation({'dest': self.test_dest, 'allowed': bool(self.test_dest)})
|
||||
dest = cmd.dest or self.test_dest
|
||||
self.handle_send_confirmation(cmd.id, {'dest': dest, 'allowed': bool(dest)})
|
||||
|
||||
43
kitty_tests/file_transmission.py
Normal file
43
kitty_tests/file_transmission.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python
|
||||
# vim:fileencoding=utf-8
|
||||
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from kitty.file_transmission import (
|
||||
Action, Compression, Container, FileTransmissionCommand,
|
||||
TestFileTransmission as FileTransmission
|
||||
)
|
||||
|
||||
from . import BaseTest
|
||||
|
||||
|
||||
def serialized_cmd(**fields) -> str:
|
||||
for k, A in (('action', Action), ('container_fmt', Container), ('compression', Compression)):
|
||||
if k in fields:
|
||||
fields[k] = A[fields[k]]
|
||||
if isinstance(fields.get('data'), str):
|
||||
fields['data'] = fields['data'].encode('utf-8')
|
||||
ans = FileTransmissionCommand()
|
||||
for k in fields:
|
||||
setattr(ans, k, fields[k])
|
||||
return ans.serialize()
|
||||
|
||||
|
||||
class TestFileTransmission(BaseTest):
|
||||
|
||||
def setUp(self):
|
||||
self.tdir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tdir)
|
||||
|
||||
def test_file_put(self):
|
||||
ft = FileTransmission()
|
||||
ft.handle_serialized_command(serialized_cmd(action='send', id='1', dest=os.path.join(self.tdir, '1.bin')))
|
||||
self.assertIn('1', ft.active_cmds)
|
||||
self.ae(os.path.basename(ft.active_cmds['1'].dest), '1.bin')
|
||||
self.assertIsNone(ft.active_cmds['1'].file)
|
||||
Reference in New Issue
Block a user