mirror of
https://github.com/kovidgoyal/kitty
synced 2026-06-08 22:28:24 +02:00
Allow interleaving file transmissions
This commit is contained in:
@@ -6,8 +6,9 @@ import copy
|
|||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from base64 import standard_b64decode
|
from base64 import standard_b64decode, standard_b64encode
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
from functools import partial
|
||||||
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from kitty.fast_data_types import OSC, get_boss
|
from kitty.fast_data_types import OSC, get_boss
|
||||||
@@ -66,6 +67,24 @@ class FileTransmissionCommand:
|
|||||||
dest: str = ''
|
dest: str = ''
|
||||||
data: bytes = b''
|
data: bytes = b''
|
||||||
|
|
||||||
|
def serialize(self) -> str:
|
||||||
|
ans = [f'action={self.action.name}']
|
||||||
|
if self.container_fmt is not Container.none:
|
||||||
|
ans.append(f'container_fmt={self.container_fmt.name}')
|
||||||
|
if self.compression is not Compression.none:
|
||||||
|
ans.append(f'compression={self.compression.name}')
|
||||||
|
for x in ('id', 'secret', 'mime', 'quiet'):
|
||||||
|
val = getattr(self, x)
|
||||||
|
if val:
|
||||||
|
ans.append(f'{x}={val}')
|
||||||
|
if self.dest:
|
||||||
|
val = standard_b64encode(self.dest.encode('utf-8')).decode('ascii')
|
||||||
|
ans.append(f'dest={val}')
|
||||||
|
if self.data:
|
||||||
|
val = standard_b64encode(self.data).decode('ascii')
|
||||||
|
ans.append(f'data={val}')
|
||||||
|
return ';'.join(ans)
|
||||||
|
|
||||||
|
|
||||||
def parse_command(data: str) -> FileTransmissionCommand:
|
def parse_command(data: str) -> FileTransmissionCommand:
|
||||||
ans = FileTransmissionCommand()
|
ans = FileTransmissionCommand()
|
||||||
@@ -173,15 +192,23 @@ class ZipExtractor:
|
|||||||
self.zf.extract(zinfo, targetpath)
|
self.zf.extract(zinfo, targetpath)
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveCommand:
|
||||||
|
ftc: FileTransmissionCommand
|
||||||
|
file: Optional[IO[bytes]] = None
|
||||||
|
dest: str = ''
|
||||||
|
decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor()
|
||||||
|
|
||||||
|
def __init__(self, ftc: FileTransmissionCommand) -> None:
|
||||||
|
self.ftc = ftc
|
||||||
|
|
||||||
|
|
||||||
class FileTransmission:
|
class FileTransmission:
|
||||||
|
|
||||||
active_cmd: Optional[FileTransmissionCommand] = None
|
active_cmds: Dict[str, ActiveCommand]
|
||||||
active_file: Optional[IO[bytes]] = None
|
|
||||||
active_dest: str = ''
|
|
||||||
active_decompressor: Union[IdentityDecompressor, ZlibDecompressor] = IdentityDecompressor()
|
|
||||||
|
|
||||||
def __init__(self, window_id: int):
|
def __init__(self, window_id: int):
|
||||||
self.window_id = window_id
|
self.window_id = window_id
|
||||||
|
self.active_cmds = {}
|
||||||
|
|
||||||
def handle_serialized_command(self, data: str) -> None:
|
def handle_serialized_command(self, data: str) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -189,19 +216,22 @@ class FileTransmission:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_error(f'Failed to parse file transmission command with error: {e}')
|
log_error(f'Failed to parse file transmission command with error: {e}')
|
||||||
return
|
return
|
||||||
if self.active_cmd is not None:
|
if cmd.id in self.active_cmds and cmd.action not in (Action.data, Action.end_data):
|
||||||
if cmd.action not in (Action.data, Action.end_data):
|
|
||||||
log_error('File transmission command received while another is in flight, aborting')
|
log_error('File transmission command received while another is in flight, aborting')
|
||||||
self.abort_in_flight()
|
del self.active_cmds[cmd.id]
|
||||||
|
|
||||||
if cmd.action is Action.send:
|
if cmd.action is Action.send:
|
||||||
|
self.active_cmds[cmd.id] = ActiveCommand(cmd)
|
||||||
self.start_send(cmd)
|
self.start_send(cmd)
|
||||||
elif cmd.action in (Action.data, Action.end_data):
|
elif cmd.action in (Action.data, Action.end_data):
|
||||||
|
if cmd.id not in self.active_cmds:
|
||||||
|
log_error('File transmission data command received with unknown id')
|
||||||
|
return
|
||||||
self.add_data(cmd)
|
self.add_data(cmd)
|
||||||
if cmd.action is Action.end_data and self.active_cmd is not None:
|
if cmd.action is Action.end_data and cmd.id in self.active_cmds:
|
||||||
self.commit()
|
self.commit(cmd.id)
|
||||||
|
|
||||||
def send_response(self, **fields: str) -> None:
|
def send_response(self, ac: Optional[FileTransmissionCommand], **fields: str) -> None:
|
||||||
ac = self.active_cmd
|
|
||||||
if ac is None:
|
if ac is None:
|
||||||
return
|
return
|
||||||
if 'id' not in fields and ac.id:
|
if 'id' not in fields and ac.id:
|
||||||
@@ -215,88 +245,74 @@ class FileTransmission:
|
|||||||
window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
|
window.screen.send_escape_code_to_child(OSC, ';'.join(f'{k}={v}' for k, v in fields.items()))
|
||||||
|
|
||||||
def start_send(self, cmd: FileTransmissionCommand) -> None:
|
def start_send(self, cmd: FileTransmissionCommand) -> None:
|
||||||
self.active_cmd = cmd
|
|
||||||
boss = get_boss()
|
boss = get_boss()
|
||||||
window = boss.window_id_map.get(self.window_id)
|
window = boss.window_id_map.get(self.window_id)
|
||||||
if window is not None:
|
if window is not None:
|
||||||
boss._run_kitten(
|
boss._run_kitten(
|
||||||
'transfer_ask', ['put', 'multiple' if cmd.container_fmt else 'single', cmd.dest],
|
'transfer_ask', ['put', 'multiple' if cmd.container_fmt else 'single', cmd.dest],
|
||||||
window=window, custom_callback=self.handle_send_confirmation
|
window=window, custom_callback=partial(self.handle_send_confirmation, cmd.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_send_confirmation(self, data: 'Response', *a: Any) -> None:
|
def handle_send_confirmation(self, cmd_id: str, data: 'Response', *a: Any) -> None:
|
||||||
cmd = self.active_cmd
|
cmd = self.active_cmds.get(cmd_id)
|
||||||
if cmd is None:
|
if cmd is None:
|
||||||
return
|
return
|
||||||
if data['allowed']:
|
if data['allowed']:
|
||||||
self.active_dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest'])))
|
cmd.dest = os.path.abspath(os.path.realpath(os.path.abspath(data['dest'])))
|
||||||
self.active_decompressor = ZlibDecompressor() if cmd.compression is Compression.zlib else IdentityDecompressor()
|
cmd.decompressor = ZlibDecompressor() if cmd.ftc.compression is Compression.zlib else IdentityDecompressor()
|
||||||
if cmd.quiet:
|
if cmd.ftc.quiet:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.active_cmd = None
|
del self.active_cmds[cmd_id]
|
||||||
self.active_dest = ''
|
if cmd.ftc.quiet > 1:
|
||||||
if cmd.quiet > 1:
|
|
||||||
return
|
return
|
||||||
self.send_response(status='OK' if data['allowed'] else 'EPERM:User refused the transfer')
|
self.send_response(cmd.ftc, status='OK' if data['allowed'] else 'EPERM:User refused the transfer')
|
||||||
|
|
||||||
def send_fail_on_os_error(self, err: OSError, msg: str) -> None:
|
def send_fail_on_os_error(self, ac: Optional[FileTransmissionCommand], err: OSError, msg: str) -> None:
|
||||||
ac = self.active_cmd
|
|
||||||
if ac is None or ac.quiet < 2:
|
if ac is None or ac.quiet < 2:
|
||||||
return
|
return
|
||||||
errname = errno.errorcode.get(err.errno, 'EFAIL')
|
errname = errno.errorcode.get(err.errno, 'EFAIL')
|
||||||
self.send_response(status=f'{errname}:{msg}')
|
self.send_response(ac, status=f'{errname}:{msg}')
|
||||||
|
|
||||||
def add_data(self, cmd: FileTransmissionCommand) -> None:
|
def add_data(self, cmd: FileTransmissionCommand) -> None:
|
||||||
ac = self.active_cmd
|
ac = self.active_cmds.get(cmd.id)
|
||||||
if ac is None or not self.active_dest:
|
|
||||||
return
|
def abort_in_flight() -> None:
|
||||||
if self.active_file is None:
|
self.active_cmds.pop(cmd.id, None)
|
||||||
|
|
||||||
|
if ac is None or not ac.dest:
|
||||||
|
return abort_in_flight()
|
||||||
|
|
||||||
|
if ac.file is None:
|
||||||
try:
|
try:
|
||||||
os.makedirs(os.path.dirname(self.active_dest), exist_ok=True)
|
os.makedirs(os.path.dirname(ac.dest), exist_ok=True)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
self.send_fail_on_os_error(e, 'Creating destination directory failed')
|
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination directory failed')
|
||||||
return self.abort_in_flight()
|
return abort_in_flight()
|
||||||
if ac.container_fmt is Container.none:
|
if ac.ftc.container_fmt is Container.none:
|
||||||
try:
|
try:
|
||||||
self.active_file = open(self.active_dest, 'wb')
|
ac.file = open(ac.dest, 'wb')
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
self.send_fail_on_os_error(e, 'Creating destination file failed')
|
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination file failed')
|
||||||
return self.abort_in_flight()
|
return abort_in_flight()
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.active_file = tempfile.TemporaryFile(dir=os.path.dirname(self.active_dest))
|
ac.file = tempfile.TemporaryFile(dir=os.path.dirname(ac.dest))
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
self.send_fail_on_os_error(e, 'Creating destination temp file failed')
|
self.send_fail_on_os_error(ac.ftc, e, 'Creating destination temp file failed')
|
||||||
return self.abort_in_flight()
|
return abort_in_flight()
|
||||||
data = self.active_decompressor(cmd.data, cmd.action is Action.end_data)
|
data = ac.decompressor(cmd.data, cmd.action is Action.end_data)
|
||||||
try:
|
try:
|
||||||
self.active_file.write(data)
|
ac.file.write(data)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
self.send_fail_on_os_error(e, 'Writing to destination file failed')
|
self.send_fail_on_os_error(ac.ftc, e, 'Writing to destination file failed')
|
||||||
return self.abort_in_flight()
|
return abort_in_flight()
|
||||||
|
|
||||||
def commit(self) -> None:
|
def commit(self, cmd_id: str) -> None:
|
||||||
cmd = self.active_cmd
|
cmd = self.active_cmds.pop(cmd_id, None)
|
||||||
if cmd is None:
|
if cmd is not None and cmd.ftc.container_fmt and cmd.file is not None:
|
||||||
return
|
cmd.file.seek(0, os.SEEK_SET)
|
||||||
try:
|
Container.extractor_for_container_fmt(cmd.file, cmd.ftc.container_fmt)(cmd.dest)
|
||||||
if cmd.container_fmt and self.active_file is not None:
|
|
||||||
self.active_file.seek(0, os.SEEK_SET)
|
|
||||||
Container.extractor_for_container_fmt(self.active_file, cmd.container_fmt)(self.active_dest)
|
|
||||||
finally:
|
|
||||||
self.active_cmd = None
|
|
||||||
self.active_dest = ''
|
|
||||||
if self.active_file is not None:
|
|
||||||
self.active_file.close()
|
|
||||||
self.active_file = None
|
|
||||||
|
|
||||||
def abort_in_flight(self) -> None:
|
|
||||||
self.active_cmd = None
|
|
||||||
self.active_dest = ''
|
|
||||||
if self.active_file is not None:
|
|
||||||
self.active_file.close()
|
|
||||||
self.active_file = None
|
|
||||||
|
|
||||||
|
|
||||||
class TestFileTransmission(FileTransmission):
|
class TestFileTransmission(FileTransmission):
|
||||||
@@ -310,5 +326,5 @@ class TestFileTransmission(FileTransmission):
|
|||||||
self.test_responses.append(fields)
|
self.test_responses.append(fields)
|
||||||
|
|
||||||
def start_send(self, cmd: FileTransmissionCommand) -> None:
|
def start_send(self, cmd: FileTransmissionCommand) -> None:
|
||||||
self.active_cmd = cmd
|
dest = cmd.dest or self.test_dest
|
||||||
self.handle_send_confirmation({'dest': self.test_dest, 'allowed': bool(self.test_dest)})
|
self.handle_send_confirmation(cmd.id, {'dest': dest, 'allowed': bool(dest)})
|
||||||
|
|||||||
43
kitty_tests/file_transmission.py
Normal file
43
kitty_tests/file_transmission.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# vim:fileencoding=utf-8
|
||||||
|
# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from kitty.file_transmission import (
|
||||||
|
Action, Compression, Container, FileTransmissionCommand,
|
||||||
|
TestFileTransmission as FileTransmission
|
||||||
|
)
|
||||||
|
|
||||||
|
from . import BaseTest
|
||||||
|
|
||||||
|
|
||||||
|
def serialized_cmd(**fields) -> str:
|
||||||
|
for k, A in (('action', Action), ('container_fmt', Container), ('compression', Compression)):
|
||||||
|
if k in fields:
|
||||||
|
fields[k] = A[fields[k]]
|
||||||
|
if isinstance(fields.get('data'), str):
|
||||||
|
fields['data'] = fields['data'].encode('utf-8')
|
||||||
|
ans = FileTransmissionCommand()
|
||||||
|
for k in fields:
|
||||||
|
setattr(ans, k, fields[k])
|
||||||
|
return ans.serialize()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileTransmission(BaseTest):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tdir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.tdir)
|
||||||
|
|
||||||
|
def test_file_put(self):
|
||||||
|
ft = FileTransmission()
|
||||||
|
ft.handle_serialized_command(serialized_cmd(action='send', id='1', dest=os.path.join(self.tdir, '1.bin')))
|
||||||
|
self.assertIn('1', ft.active_cmds)
|
||||||
|
self.ae(os.path.basename(ft.active_cmds['1'].dest), '1.bin')
|
||||||
|
self.assertIsNone(ft.active_cmds['1'].file)
|
||||||
Reference in New Issue
Block a user