From 334adf9c1a974b35292a5313050743606815881f Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 5 Jan 2025 12:51:59 +0530 Subject: [PATCH] Ensure temp files and other resources are cleaned up even if kitty crashes or is SIGKILLed --- kittens/ssh/utils.py | 6 +++-- kitty/boss.py | 43 +++++++++++++++++++++++++++++----- kitty/notifications.py | 7 +++--- kitty/utils.py | 12 ---------- kitty_tests/atexit.py | 10 ++++++++ kitty_tests/notifications.py | 2 +- tools/cmd/atexit/main.go | 16 +++++++++---- tools/utils/shm/shm_fs.go | 19 +++++++++++++-- tools/utils/shm/shm_syscall.go | 10 +++++++- 9 files changed, 93 insertions(+), 32 deletions(-) diff --git a/kittens/ssh/utils.py b/kittens/ssh/utils.py index e97eae731..023c03473 100644 --- a/kittens/ssh/utils.py +++ b/kittens/ssh/utils.py @@ -85,15 +85,17 @@ def set_cwd_in_cmdline(cwd: str, argv: List[str]) -> None: def create_shared_memory(data: Any, prefix: str) -> str: - import atexit import json + import atexit from kitty.shm import SharedMemory + from kitty.fast_data_types import get_boss db = json.dumps(data).encode('utf-8') with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, prefix=prefix) as shm: shm.write_data_with_size(db) shm.flush() - atexit.register(shm.unlink) + atexit.register(shm.close) # keeps shm alive till exit + get_boss().atexit.shm_unlink(shm.name) return shm.name diff --git a/kitty/boss.py b/kitty/boss.py index 5754a49a3..3a0bcc99d 100644 --- a/kitty/boss.py +++ b/kitty/boss.py @@ -2,12 +2,12 @@ # License: GPL v3 Copyright: 2016, Kovid Goyal # Imports {{{ -import atexit import base64 import json import os import re import socket +import subprocess import sys from collections.abc import Container, Generator, Iterable, Iterator, Sequence from contextlib import contextmanager, suppress @@ -16,6 +16,7 @@ from gettext import gettext as _ from gettext import ngettext from time import sleep from typing import ( + IO, TYPE_CHECKING, Any, Callable, @@ -137,7 +138,6 @@ from .utils import ( parse_os_window_state, parse_uri_list, platform_window_id, - remove_socket_file, safe_print, sanitize_url_for_dispay_to_user, startup_notification_handler, @@ -147,6 +147,7 @@ from .utils import ( from .window import CommandOutput, CwdRequest, Window if TYPE_CHECKING: + from .rc.base import ResponseType # }}} @@ -165,16 +166,46 @@ class OSWindowDict(TypedDict): background_opacity: float -def listen_on(spec: str) -> tuple[int, str]: +class Atexit: + + def __init__(self) -> None: + self.worker: Optional[subprocess.Popen[bytes]] = None + + def _write_line(self, line: str) -> None: + if '\n' in line: + raise ValueError('Newlines not allowed in atexit arguments: {path!r}') + w = self.worker + if w is None: + w = self.worker = subprocess.Popen([kitten_exe(), '__atexit__'], stdin=subprocess.PIPE, stdout=subprocess.DEVNULL, close_fds=True) + assert w.stdin is not None + os.set_inheritable(w.stdin.fileno(), False) + assert w.stdin is not None + w.stdin.write((line + '\n').encode()) + w.stdin.flush() + + def unlink(self, path: str) -> None: + self._write_line(f'unlink {path}') + + def shm_unlink(self, path: str) -> None: + self._write_line(f'shm_unlink {path}') + + def rmtree(self, path: str) -> None: + self._write_line(f'rmtree {path}') + + +def listen_on(spec: str, robust_atexit: Atexit) -> tuple[int, str]: import socket family, address, socket_path = parse_address_spec(spec) s = socket.socket(family) - atexit.register(remove_socket_file, s, socket_path) s.bind(address) + if family == socket.AF_UNIX and socket_path: + robust_atexit.unlink(socket_path) s.listen() if isinstance(address, tuple): # tcp socket h, resolved_port = s.getsockname()[:2] spec = spec.rpartition(':')[0] + f':{resolved_port}' + import atexit + atexit.register(s.close) # prevents s from being garbage collected return s.fileno(), spec @@ -320,6 +351,7 @@ class Boss: global_shortcuts: dict[str, SingleKey], talk_fd: int = -1, ): + self.atexit = Atexit() set_layout_options(opts) self.clipboard = Clipboard() self.window_for_dispatch: Optional[Window] = None @@ -353,7 +385,7 @@ class Boss: listen_fd = -1 if args.listen_on and self.allow_remote_control in ('y', 'socket', 'socket-only', 'password'): try: - listen_fd, self.listening_on = listen_on(args.listen_on) + listen_fd, self.listening_on = listen_on(args.listen_on, self.atexit) except Exception: self.misc_config_errors.append(f'Invalid listen_on={args.listen_on}, ignoring') log_error(self.misc_config_errors[-1]) @@ -2393,7 +2425,6 @@ class Boss: notify_on_death: Optional[Callable[[int, Optional[Exception]], None]] = None, # guaranteed to be called only after event loop tick stdout: Optional[int] = None, stderr: Optional[int] = None, ) -> None: - import subprocess env = env or None if env: env_ = default_env().copy() diff --git a/kitty/notifications.py b/kitty/notifications.py index 0ceac0f0a..208a61f4e 100644 --- a/kitty/notifications.py +++ b/kitty/notifications.py @@ -47,6 +47,9 @@ class IconDataCache: if not self.cache_dir: self.cache_dir = os.path.join(self.base_cache_dir or cache_dir(), 'notifications-icons', str(os.getpid())) os.makedirs(self.cache_dir, exist_ok=True, mode=0o700) + b = get_boss() + if hasattr(b, 'atexit'): + b.atexit.rmtree(self.cache_dir) return self.cache_dir def __del__(self) -> None: @@ -834,7 +837,6 @@ class NotificationManager: log: Log = Log(), debug: bool = False, base_cache_dir: str = '', - cleanup_at_exit: bool = True, ): global debug_desktop_integration debug_desktop_integration = debug @@ -856,9 +858,6 @@ class NotificationManager: except Exception as e: self.log(f'Failed to load {script_path} with error: {e}') self.reset() - if cleanup_at_exit: - import atexit - atexit.register(self.cleanup) def reset(self) -> None: self.icon_data_cache.clear() diff --git a/kitty/utils.py b/kitty/utils.py index 727c376a0..8f9354cbd 100644 --- a/kitty/utils.py +++ b/kitty/utils.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # License: GPL v3 Copyright: 2016, Kovid Goyal -import atexit import fcntl import math import os @@ -350,17 +349,6 @@ class startup_notification_handler: end_startup_notification(self.ctx) -def remove_socket_file(s: 'Socket', path: Optional[str] = None, is_dir: Optional[Callable[[str], None]] = None) -> None: - with suppress(OSError): - s.close() - if path: - with suppress(OSError): - if is_dir: - is_dir(path) - else: - os.unlink(path) - - def unix_socket_directories() -> Iterator[str]: import tempfile home = os.path.expanduser('~') diff --git a/kitty_tests/atexit.py b/kitty_tests/atexit.py index 2ff09124f..3caa4fbb4 100644 --- a/kitty_tests/atexit.py +++ b/kitty_tests/atexit.py @@ -10,6 +10,7 @@ import subprocess import tempfile from kitty.constants import kitten_exe, kitty_exe +from kitty.shm import SharedMemory from . import BaseTest @@ -51,8 +52,16 @@ raise SystemExit(p.wait()) open(os.path.join(sdir, 'f'), 'w').close() select.select(readers, [], [], 10) self.ae(read(), str(i+2)) + shm = SharedMemory(size=64) + shm.write(b'1' * 64) + shm.flush() + p.stdin.write(f'shm_unlink {shm.name}\n'.encode()) + p.stdin.flush() + self.ae(read(), str(i+3)) self.assertTrue(os.listdir(self.tdir)) + shm2 = SharedMemory(shm.name) + self.ae(shm2.read()[:64], b'1' * 64) # Ensure child is ignoring signals os.kill(atexit_pid, signal.SIGINT) @@ -74,6 +83,7 @@ raise SystemExit(p.wait()) os.waitpid(atexit_pid, 0) except ChildProcessError: pass + self.assertRaises(FileNotFoundError, lambda: SharedMemory(shm.name)) r('close') r('terminate') diff --git a/kitty_tests/notifications.py b/kitty_tests/notifications.py index 600a7432b..0031ae3bc 100644 --- a/kitty_tests/notifications.py +++ b/kitty_tests/notifications.py @@ -91,7 +91,7 @@ class NotificationManager(NotificationManager): def do_test(self: 'TestNotifications', tdir: str) -> None: di = DesktopIntegration(None) ch = Channel() - nm = NotificationManager(di, ch, lambda *a, **kw: None, base_cache_dir=tdir, cleanup_at_exit=False) + nm = NotificationManager(di, ch, lambda *a, **kw: None, base_cache_dir=tdir) di.notification_manager = nm def reset(): diff --git a/tools/cmd/atexit/main.go b/tools/cmd/atexit/main.go index fc7b696f9..401d4cefc 100644 --- a/tools/cmd/atexit/main.go +++ b/tools/cmd/atexit/main.go @@ -2,12 +2,15 @@ package atexit import ( "bufio" + "errors" "fmt" + "io/fs" "os" "os/signal" "strings" "kitty/tools/cli" + "kitty/tools/utils/shm" ) var _ = fmt.Print @@ -40,13 +43,18 @@ func main() (rc int, err error) { if action, rest, found := strings.Cut(line, " "); found { switch action { case "unlink": - if err := os.Remove(rest); err != nil { - fmt.Fprintln(os.Stderr, "Failed to remove:", rest, "with error:", err) + if err := os.Remove(rest); err != nil && !errors.Is(err, fs.ErrNotExist) { + fmt.Fprintln(os.Stderr, "Failed to unlink:", rest, "with error:", err) + rc = 1 + } + case "shm_unlink": + if err := shm.ShmUnlink(rest); err != nil && !errors.Is(err, fs.ErrNotExist) { + fmt.Fprintln(os.Stderr, "Failed to shm_unlink:", rest, "with error:", err) rc = 1 } case "rmtree": - if err := os.RemoveAll(rest); err != nil { - fmt.Fprintln(os.Stderr, "Failed to remove:", rest, "with error:", err) + if err := os.RemoveAll(rest); err != nil && !errors.Is(err, fs.ErrNotExist) { + fmt.Fprintln(os.Stderr, "Failed to rmtree:", rest, "with error:", err) rc = 1 } } diff --git a/tools/utils/shm/shm_fs.go b/tools/utils/shm/shm_fs.go index 729f817f2..6499d3fbb 100644 --- a/tools/utils/shm/shm_fs.go +++ b/tools/utils/shm/shm_fs.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "kitty/tools/utils" @@ -28,6 +29,16 @@ type file_based_mmap struct { special_name string } +func ShmUnlink(name string) error { + if runtime.GOOS == "openbsd" { + return os.Remove(openbsd_shm_path(name)) + } + if strings.HasPrefix(name, "/") { + name = name[1:] + } + return os.Remove(filepath.Join(SHM_DIR, name)) +} + func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool, special_name string) (MMap, error) { if truncate { err := truncate_or_unlink(f, size, os.Remove) @@ -106,11 +117,15 @@ func (self *file_based_mmap) Unlink() (err error) { func (self *file_based_mmap) IsFileSystemBacked() bool { return true } +func openbsd_shm_path(name string) string { + hash := sha256.Sum256(utils.UnsafeStringToBytes(name)) + return filepath.Join(SHM_DIR, utils.UnsafeBytesToString(hash[:])+".shm") +} + func file_path_from_name(name string) string { // See https://github.com/openbsd/src/blob/master/lib/libc/gen/shm_open.c if runtime.GOOS == "openbsd" { - hash := sha256.Sum256(utils.UnsafeStringToBytes(name)) - return filepath.Join(SHM_DIR, utils.UnsafeBytesToString(hash[:])+".shm") + return openbsd_shm_path(name) } return filepath.Join(SHM_DIR, name) } diff --git a/tools/utils/shm/shm_syscall.go b/tools/utils/shm/shm_syscall.go index 95408adb7..38bd3ffe5 100644 --- a/tools/utils/shm/shm_syscall.go +++ b/tools/utils/shm/shm_syscall.go @@ -37,7 +37,11 @@ func shm_unlink(name string) (err error) { _, _, errno := unix.Syscall(unix.SYS_SHM_UNLINK, uintptr(unsafe.Pointer(bname)), 0, 0) if errno != unix.EINTR { if errno != 0 { - err = fmt.Errorf("shm_unlink() failed with error: %w", errno) + if errno == unix.ENOENT { + err = fs.ErrNotExist + } else { + err = fmt.Errorf("shm_unlink() failed with error: %w", errno) + } } break } @@ -45,6 +49,10 @@ func shm_unlink(name string) (err error) { return } +func ShmUnlink(name string) error { + return shm_unlink(name) +} + func shm_open(name string, flags, perm int) (ans *os.File, err error) { bname := BytePtrFromString(name) var fd uintptr