Basic file transfer works for sending

This commit is contained in:
Kovid Goyal
2021-09-12 10:02:23 +05:30
parent efaeb0f0b2
commit 31ead3f088
3 changed files with 81 additions and 47 deletions

View File

@@ -23,9 +23,10 @@ from kitty.file_transmission import (
from kitty.types import run_once from kitty.types import run_once
from ..tui.handler import Handler from ..tui.handler import Handler
from ..tui.loop import Loop from ..tui.loop import Loop, debug
_cwd = _home = '' _cwd = _home = ''
debug
def abspath(path: str) -> str: def abspath(path: str) -> str:
@@ -132,6 +133,7 @@ class FileState(NameReprEnum):
waiting_for_data = auto() waiting_for_data = auto()
transmitting = auto() transmitting = auto()
finished = auto() finished = auto()
acknowledged = auto()
class File: class File:
@@ -155,29 +157,36 @@ class File:
self.symbolic_link_target = '' self.symbolic_link_target = ''
self.stat_result = stat_result self.stat_result = stat_result
self.file_type = file_type self.file_type = file_type
self.compression = Compression.zlib if self.file_size > 2048 else Compression.none self.compression = Compression.zlib if self.file_type is FileType.regular and self.file_size > 4096 else Compression.none
self.compressor: Union[ZlibCompressor, IdentityCompressor] = ZlibCompressor() if self.compression is Compression.zlib else IdentityCompressor()
self.remote_final_path = '' self.remote_final_path = ''
self.remote_initial_size = -1 self.remote_initial_size = -1
self.err_msg = '' self.err_msg = ''
self.actual_file: Optional[IO[bytes]] = None self.actual_file: Optional[IO[bytes]] = None
self.transmitted_bytes = 0 self.transmitted_bytes = 0
def next_chunk(self, sz: int = 4096) -> bytes: def next_chunk(self, sz: int = 1024 * 1024) -> Tuple[bytes, int]:
if self.file_type is FileType.symlink: if self.file_type is FileType.symlink:
self.state = FileState.finished self.state = FileState.finished
return self.symbolic_link_target.encode('utf-8') ans = self.symbolic_link_target.encode('utf-8')
return ans, len(ans)
if self.file_type is FileType.link: if self.file_type is FileType.link:
self.state = FileState.finished self.state = FileState.finished
return self.hard_link_target.encode('utf-8') ans = self.hard_link_target.encode('utf-8')
return ans, len(ans)
if self.actual_file is None: if self.actual_file is None:
self.actual_file = open(self.expanded_local_path, 'rb') self.actual_file = open(self.expanded_local_path, 'rb')
chunk = self.actual_file.read(sz) chunk = self.actual_file.read(sz)
uncompressed_sz = len(chunk)
is_last = not chunk or self.actual_file.tell() >= self.file_size is_last = not chunk or self.actual_file.tell() >= self.file_size
cchunk = self.compressor.compress(chunk)
if is_last and not isinstance(self.compressor, IdentityCompressor):
cchunk += self.compressor.flush()
if is_last: if is_last:
self.state = FileState.finished self.state = FileState.finished
self.actual_file.close() self.actual_file.close()
self.actual_file = None self.actual_file = None
return chunk return cchunk, uncompressed_sz
def metadata_command(self) -> FileTransmissionCommand: def metadata_command(self) -> FileTransmissionCommand:
return FileTransmissionCommand( return FileTransmissionCommand(
@@ -186,22 +195,6 @@ class File:
file_id=self.file_id, file_id=self.file_id,
) )
def data_commands(self) -> Iterator[FileTransmissionCommand]:
if self.file_type is FileType.symlink:
yield FileTransmissionCommand(action=Action.end_data, data=self.symbolic_link_target.encode('utf-8'), file_id=self.file_id)
elif self.file_type is FileType.link:
yield FileTransmissionCommand(action=Action.end_data, data=f'fid:{self.hard_link_target}'.encode('utf-8'), file_id=self.file_id)
elif self.file_type is FileType.regular:
compressor: Union[IdentityCompressor, ZlibCompressor] = ZlibCompressor() if self.compression is Compression.zlib else IdentityCompressor()
with open(self.local_path, 'rb') as f:
keep_going = True
while keep_going:
data = f.read(4096)
keep_going = bool(data)
data = compressor.compress(data) if data else compressor.flush()
if data or not keep_going:
yield FileTransmissionCommand(action=Action.data if keep_going else Action.end_data, data=data, file_id=self.file_id)
def process(cli_opts: TransferCLIOptions, paths: Iterable[str], remote_base: str) -> Iterator[File]: def process(cli_opts: TransferCLIOptions, paths: Iterable[str], remote_base: str) -> Iterator[File]:
counter = count(1) counter = count(1)
@@ -293,7 +286,6 @@ class SendState(NameReprEnum):
permission_granted = auto() permission_granted = auto()
permission_denied = auto() permission_denied = auto()
canceled = auto() canceled = auto()
finished = auto()
class SendManager: class SendManager:
@@ -304,14 +296,19 @@ class SendManager:
self.fid_map = {f.file_id: f for f in self.files} self.fid_map = {f.file_id: f for f in self.files}
self.request_id = request_id self.request_id = request_id
self.state = SendState.waiting_for_permission self.state = SendState.waiting_for_permission
self.all_done = False self.all_acknowledged = False
self.all_started = False self.all_started = False
self.active_idx: Optional[int] = None self.active_idx: Optional[int] = None
self.current_chunk_uncompressed_sz: Optional[int] = None
self.prefix = f'\x1b]{FILE_TRANSFER_CODE};id={self.request_id};'
self.suffix = '\x1b\\'
@property @property
def active_file(self) -> Optional[File]: def active_file(self) -> Optional[File]:
if self.active_idx is not None: if self.active_idx is not None:
return self.files[self.active_idx] ans = self.files[self.active_idx]
if ans.state is FileState.transmitting:
return ans
def activate_next_ready_file(self) -> Optional[File]: def activate_next_ready_file(self) -> Optional[File]:
for i, f in enumerate(self.files): for i, f in enumerate(self.files):
@@ -319,34 +316,46 @@ class SendManager:
self.active_idx = i self.active_idx = i
self.update_collective_statuses() self.update_collective_statuses()
return f return f
self.active_idx = None
self.update_collective_statuses() self.update_collective_statuses()
def update_collective_statuses(self) -> None: def update_collective_statuses(self) -> None:
found_not_started = found_not_done = False found_not_started = found_not_done = False
for f in self.files: for f in self.files:
if f.state is not FileState.finished: if f.state is not FileState.acknowledged:
found_not_done = True found_not_done = True
if f.state is FileState.waiting_for_start: if f.state is FileState.waiting_for_start:
found_not_started = True found_not_started = True
if found_not_started and found_not_done: if found_not_started and found_not_done:
break break
self.all_done = not found_not_done self.all_acknowledged = not found_not_done
self.all_started = not found_not_started self.all_started = not found_not_started
def start_transfer(self) -> str: def start_transfer(self) -> str:
return FileTransmissionCommand(action=Action.send, password=self.password).serialize() return FileTransmissionCommand(action=Action.send, password=self.password).serialize()
def next_chunk(self) -> str: def next_chunks(self) -> Iterator[str]:
if self.active_file is None: if self.active_file is None:
self.activate_next_ready_file() self.activate_next_ready_file()
af = self.active_file af = self.active_file
if af is None: if af is None:
return '' return
chunk = af.next_chunk() chunk = b''
self.current_chunk_uncompressed_sz = 0
while af.state is not FileState.finished and not chunk:
chunk, usz = af.next_chunk()
self.current_chunk_uncompressed_sz += usz
is_last = af.state is FileState.finished is_last = af.state is FileState.finished
if is_last: if is_last:
self.activate_next_ready_file() self.activate_next_ready_file()
return FileTransmissionCommand(action=Action.end_data if is_last else Action.data, file_id=af.file_id, data=chunk).serialize() mv = memoryview(chunk)
pos = 0
limit = len(chunk)
while pos < limit:
cc = mv[pos:pos + 4096]
pos += 4096
final = is_last and pos >= limit
yield FileTransmissionCommand(action=Action.end_data if final else Action.data, file_id=af.file_id, data=cc).serialize()
def send_file_metadata(self) -> Iterator[str]: def send_file_metadata(self) -> Iterator[str]:
for f in self.files: for f in self.files:
@@ -363,7 +372,7 @@ class SendManager:
else: else:
if ftc.name and not file.remote_final_path: if ftc.name and not file.remote_final_path:
file.remote_final_path = ftc.name file.remote_final_path = ftc.name
file.state = FileState.finished file.state = FileState.acknowledged
if ftc.status != 'OK': if ftc.status != 'OK':
file.err_msg = ftc.status file.err_msg = ftc.status
if file is self.active_file: if file is self.active_file:
@@ -387,19 +396,22 @@ class Send(Handler):
self.cli_opts = cli_opts self.cli_opts = cli_opts
self.transmit_started = False self.transmit_started = False
self.file_metadata_sent = False self.file_metadata_sent = False
self.quit_after_write_code: Optional[int] = None
def send_payload(self, payload: str) -> None: def send_payload(self, payload: str) -> None:
self.write(f'\x1b]{FILE_TRANSFER_CODE};id={self.manager.request_id};') self.write(self.manager.prefix)
self.write(payload) self.write(payload)
self.write(b'\x1b\\') self.write(self.manager.suffix)
def on_file_transfer_response(self, ftc: FileTransmissionCommand) -> None: def on_file_transfer_response(self, ftc: FileTransmissionCommand) -> None:
if self.quit_after_write_code is not None:
return
if ftc.id != self.manager.request_id: if ftc.id != self.manager.request_id:
return return
if ftc.status == 'CANCELED': if ftc.status == 'CANCELED':
self.quit_loop(1) self.quit_loop(1)
return return
if self.manager.state in (SendState.finished, SendState.canceled): if self.manager.state is SendState.canceled:
return return
before = self.manager.state before = self.manager.state
self.manager.on_file_transfer_response(ftc) self.manager.on_file_transfer_response(ftc)
@@ -428,17 +440,20 @@ class Send(Handler):
self.transmit_next_chunk() self.transmit_next_chunk()
def transmit_next_chunk(self) -> None: def transmit_next_chunk(self) -> None:
chunk = self.manager.next_chunk() for chunk in self.manager.next_chunks():
if chunk:
self.send_payload(chunk) self.send_payload(chunk)
else: else:
if self.manager.active_file is None: if self.manager.all_acknowledged:
self.transfer_finished() self.transfer_finished()
def transfer_finished(self) -> None: def transfer_finished(self) -> None:
self.quit_loop(0) self.send_payload(FileTransmissionCommand(action=Action.finish).serialize())
self.quit_after_write_code = 0
def on_writing_finished(self) -> None: def on_writing_finished(self) -> None:
if self.quit_after_write_code is not None:
self.quit_loop(self.quit_after_write_code)
return
if self.manager.state is SendState.permission_granted: if self.manager.state is SendState.permission_granted:
self.loop_tick() self.loop_tick()
@@ -464,11 +479,15 @@ class Send(Handler):
self.file_metadata_sent = True self.file_metadata_sent = True
def on_term(self) -> None: def on_term(self) -> None:
if self.quit_after_write_code is not None:
return
self.cmd.styled('Terminate requested, cancelling transfer, transferred files are in undefined state', fg='red') self.cmd.styled('Terminate requested, cancelling transfer, transferred files are in undefined state', fg='red')
self.print() self.print()
self.abort_transfer(delay=2) self.abort_transfer(delay=2)
def on_interrupt(self) -> None: def on_interrupt(self) -> None:
if self.quit_after_write_code is not None:
return
if self.manager.state is SendState.canceled: if self.manager.state is SendState.canceled:
self.print('Waiting for canceled acknowledgement from terminal, will abort in a few seconds if no response received') self.print('Waiting for canceled acknowledgement from terminal, will abort in a few seconds if no response received')
return return

View File

@@ -82,7 +82,7 @@ class TransmissionError(Exception):
self.name = name self.name = name
self.size = size self.size = size
def as_escape_code(self, request_id: str = '') -> str: def as_escape_code(self, request_id: str) -> str:
name = self.code if isinstance(self.code, str) else self.code.name name = self.code if isinstance(self.code, str) else self.code.name
if self.human_msg: if self.human_msg:
name += ':' + self.human_msg name += ':' + self.human_msg
@@ -105,9 +105,19 @@ class FileTransmissionCommand:
mtime: int = -1 mtime: int = -1
permissions: int = -1 permissions: int = -1
size: int = -1 size: int = -1
data: bytes = b''
name: str = field(default='', metadata={'base64': True}) name: str = field(default='', metadata={'base64': True})
status: str = field(default='', metadata={'base64': True}) status: str = field(default='', metadata={'base64': True})
data: bytes = field(default=b'', repr=False)
def __repr__(self) -> str:
ans = []
for k in fields(self):
if not k.repr:
continue
val = getattr(self, k.name)
if val != k.default:
ans.append(f'{k.name}={val!r}')
return 'FTC(' + ', '.join(ans) + ')'
def asdict(self, keep_defaults: bool = False) -> Dict[str, Union[str, int, bytes]]: def asdict(self, keep_defaults: bool = False) -> Dict[str, Union[str, int, bytes]]:
ans = {} ans = {}
@@ -222,6 +232,7 @@ class DestFile:
self.decompressor: Union[ZlibDecompressor, IdentityDecompressor] = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor() self.decompressor: Union[ZlibDecompressor, IdentityDecompressor] = ZlibDecompressor() if ftc.compression is Compression.zlib else IdentityDecompressor()
self.closed = self.ftype is FileType.directory self.closed = self.ftype is FileType.directory
self.actual_file: Optional[IO[bytes]] = None self.actual_file: Optional[IO[bytes]] = None
self.failed = False
def __repr__(self) -> str: def __repr__(self) -> str:
return f'DestFile(name={self.name}, file_id={self.file_id}, actual_file={self.actual_file})' return f'DestFile(name={self.name}, file_id={self.file_id}, actual_file={self.actual_file})'
@@ -351,9 +362,12 @@ class ActiveReceive:
df = self.files.get(ftc.file_id) df = self.files.get(ftc.file_id)
if df is None: if df is None:
raise TransmissionError(file_id=ftc.file_id, msg='Cannot write to a file without first starting it') raise TransmissionError(file_id=ftc.file_id, msg='Cannot write to a file without first starting it')
if df.failed:
return df
try: try:
df.write_data(self.files, ftc.data, ftc.action is Action.end_data) df.write_data(self.files, ftc.data, ftc.action is Action.end_data)
except Exception: except Exception:
df.failed = True
df.close() df.close()
raise raise
return df return df
@@ -473,6 +487,8 @@ class FileTransmission:
elif cmd.action in (Action.data, Action.end_data): elif cmd.action in (Action.data, Action.end_data):
try: try:
df = ar.add_data(cmd) df = ar.add_data(cmd)
if df.failed:
return
if df.closed and ar.send_acknowledgements: if df.closed and ar.send_acknowledgements:
self.send_status_response(code=ErrorCode.OK, request_id=ar.id, file_id=df.file_id, name=df.name) self.send_status_response(code=ErrorCode.OK, request_id=ar.id, file_id=df.file_id, name=df.name)
except TransmissionError as err: except TransmissionError as err:
@@ -508,7 +524,7 @@ class FileTransmission:
def send_transmission_error(self, request_id: str, err: TransmissionError) -> bool: def send_transmission_error(self, request_id: str, err: TransmissionError) -> bool:
if err.transmit: if err.transmit:
return self.write_osc_to_child(request_id, err.as_escape_code()) return self.write_osc_to_child(request_id, err.as_escape_code(request_id))
return True return True
def write_osc_to_child(self, request_id: str, payload: str, appendleft: bool = False) -> bool: def write_osc_to_child(self, request_id: str, payload: str, appendleft: bool = False) -> bool:

View File

@@ -127,17 +127,16 @@ class TestFileTransmission(BaseTest):
self.ae(ft.test_responses, [response(status='OK')]) self.ae(ft.test_responses, [response(status='OK')])
ft.handle_serialized_command(serialized_cmd(action='file', name=dest, compression='zlib')) ft.handle_serialized_command(serialized_cmd(action='file', name=dest, compression='zlib'))
self.assertPathEqual(ft.active_file().name, dest) self.assertPathEqual(ft.active_file().name, dest)
odata = b'abcd' * 1024 odata = b'abcd' * 1024 + b'xyz'
data = zlib.compress(odata) c = zlib.compressobj()
ft.handle_serialized_command(serialized_cmd(action='data', data=data[:len(data)//2])) ft.handle_serialized_command(serialized_cmd(action='data', data=c.compress(odata)))
self.assertTrue(os.path.exists(dest)) self.assertTrue(os.path.exists(dest))
ft.handle_serialized_command(serialized_cmd(action='end_data', data=data[len(data)//2:])) ft.handle_serialized_command(serialized_cmd(action='end_data', data=c.flush()))
self.ae(ft.test_responses, [response(status='OK'), response(status='STARTED', name=dest), response(status='OK', name=dest)]) self.ae(ft.test_responses, [response(status='OK'), response(status='STARTED', name=dest), response(status='OK', name=dest)])
ft.handle_serialized_command(serialized_cmd(action='finish')) ft.handle_serialized_command(serialized_cmd(action='finish'))
with open(dest, 'rb') as f: with open(dest, 'rb') as f:
self.ae(f.read(), odata) self.ae(f.read(), odata)
del odata del odata
del data
# overwriting # overwriting
self.clean_tdir() self.clean_tdir()