Ensure temp files and other resources are cleaned up even if kitty crashes or is SIGKILLed

This commit is contained in:
Kovid Goyal
2025-01-05 12:51:59 +05:30
parent 48d5c90bb8
commit 334adf9c1a
9 changed files with 93 additions and 32 deletions

View File

@@ -85,15 +85,17 @@ def set_cwd_in_cmdline(cwd: str, argv: List[str]) -> None:
def create_shared_memory(data: Any, prefix: str) -> str: def create_shared_memory(data: Any, prefix: str) -> str:
import atexit
import json import json
import atexit
from kitty.shm import SharedMemory from kitty.shm import SharedMemory
from kitty.fast_data_types import get_boss
db = json.dumps(data).encode('utf-8') db = json.dumps(data).encode('utf-8')
with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, prefix=prefix) as shm: with SharedMemory(size=len(db) + SharedMemory.num_bytes_for_size, prefix=prefix) as shm:
shm.write_data_with_size(db) shm.write_data_with_size(db)
shm.flush() shm.flush()
atexit.register(shm.unlink) atexit.register(shm.close) # keeps shm alive till exit
get_boss().atexit.shm_unlink(shm.name)
return shm.name return shm.name

View File

@@ -2,12 +2,12 @@
# License: GPL v3 Copyright: 2016, Kovid Goyal <kovid at kovidgoyal.net> # License: GPL v3 Copyright: 2016, Kovid Goyal <kovid at kovidgoyal.net>
# Imports {{{ # Imports {{{
import atexit
import base64 import base64
import json import json
import os import os
import re import re
import socket import socket
import subprocess
import sys import sys
from collections.abc import Container, Generator, Iterable, Iterator, Sequence from collections.abc import Container, Generator, Iterable, Iterator, Sequence
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
@@ -16,6 +16,7 @@ from gettext import gettext as _
from gettext import ngettext from gettext import ngettext
from time import sleep from time import sleep
from typing import ( from typing import (
IO,
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
@@ -137,7 +138,6 @@ from .utils import (
parse_os_window_state, parse_os_window_state,
parse_uri_list, parse_uri_list,
platform_window_id, platform_window_id,
remove_socket_file,
safe_print, safe_print,
sanitize_url_for_dispay_to_user, sanitize_url_for_dispay_to_user,
startup_notification_handler, startup_notification_handler,
@@ -147,6 +147,7 @@ from .utils import (
from .window import CommandOutput, CwdRequest, Window from .window import CommandOutput, CwdRequest, Window
if TYPE_CHECKING: if TYPE_CHECKING:
from .rc.base import ResponseType from .rc.base import ResponseType
# }}} # }}}
@@ -165,16 +166,46 @@ class OSWindowDict(TypedDict):
background_opacity: float background_opacity: float
def listen_on(spec: str) -> tuple[int, str]: class Atexit:
def __init__(self) -> None:
self.worker: Optional[subprocess.Popen[bytes]] = None
def _write_line(self, line: str) -> None:
if '\n' in line:
raise ValueError('Newlines not allowed in atexit arguments: {path!r}')
w = self.worker
if w is None:
w = self.worker = subprocess.Popen([kitten_exe(), '__atexit__'], stdin=subprocess.PIPE, stdout=subprocess.DEVNULL, close_fds=True)
assert w.stdin is not None
os.set_inheritable(w.stdin.fileno(), False)
assert w.stdin is not None
w.stdin.write((line + '\n').encode())
w.stdin.flush()
def unlink(self, path: str) -> None:
self._write_line(f'unlink {path}')
def shm_unlink(self, path: str) -> None:
self._write_line(f'shm_unlink {path}')
def rmtree(self, path: str) -> None:
self._write_line(f'rmtree {path}')
def listen_on(spec: str, robust_atexit: Atexit) -> tuple[int, str]:
import socket import socket
family, address, socket_path = parse_address_spec(spec) family, address, socket_path = parse_address_spec(spec)
s = socket.socket(family) s = socket.socket(family)
atexit.register(remove_socket_file, s, socket_path)
s.bind(address) s.bind(address)
if family == socket.AF_UNIX and socket_path:
robust_atexit.unlink(socket_path)
s.listen() s.listen()
if isinstance(address, tuple): # tcp socket if isinstance(address, tuple): # tcp socket
h, resolved_port = s.getsockname()[:2] h, resolved_port = s.getsockname()[:2]
spec = spec.rpartition(':')[0] + f':{resolved_port}' spec = spec.rpartition(':')[0] + f':{resolved_port}'
import atexit
atexit.register(s.close) # prevents s from being garbage collected
return s.fileno(), spec return s.fileno(), spec
@@ -320,6 +351,7 @@ class Boss:
global_shortcuts: dict[str, SingleKey], global_shortcuts: dict[str, SingleKey],
talk_fd: int = -1, talk_fd: int = -1,
): ):
self.atexit = Atexit()
set_layout_options(opts) set_layout_options(opts)
self.clipboard = Clipboard() self.clipboard = Clipboard()
self.window_for_dispatch: Optional[Window] = None self.window_for_dispatch: Optional[Window] = None
@@ -353,7 +385,7 @@ class Boss:
listen_fd = -1 listen_fd = -1
if args.listen_on and self.allow_remote_control in ('y', 'socket', 'socket-only', 'password'): if args.listen_on and self.allow_remote_control in ('y', 'socket', 'socket-only', 'password'):
try: try:
listen_fd, self.listening_on = listen_on(args.listen_on) listen_fd, self.listening_on = listen_on(args.listen_on, self.atexit)
except Exception: except Exception:
self.misc_config_errors.append(f'Invalid listen_on={args.listen_on}, ignoring') self.misc_config_errors.append(f'Invalid listen_on={args.listen_on}, ignoring')
log_error(self.misc_config_errors[-1]) log_error(self.misc_config_errors[-1])
@@ -2393,7 +2425,6 @@ class Boss:
notify_on_death: Optional[Callable[[int, Optional[Exception]], None]] = None, # guaranteed to be called only after event loop tick notify_on_death: Optional[Callable[[int, Optional[Exception]], None]] = None, # guaranteed to be called only after event loop tick
stdout: Optional[int] = None, stderr: Optional[int] = None, stdout: Optional[int] = None, stderr: Optional[int] = None,
) -> None: ) -> None:
import subprocess
env = env or None env = env or None
if env: if env:
env_ = default_env().copy() env_ = default_env().copy()

View File

@@ -47,6 +47,9 @@ class IconDataCache:
if not self.cache_dir: if not self.cache_dir:
self.cache_dir = os.path.join(self.base_cache_dir or cache_dir(), 'notifications-icons', str(os.getpid())) self.cache_dir = os.path.join(self.base_cache_dir or cache_dir(), 'notifications-icons', str(os.getpid()))
os.makedirs(self.cache_dir, exist_ok=True, mode=0o700) os.makedirs(self.cache_dir, exist_ok=True, mode=0o700)
b = get_boss()
if hasattr(b, 'atexit'):
b.atexit.rmtree(self.cache_dir)
return self.cache_dir return self.cache_dir
def __del__(self) -> None: def __del__(self) -> None:
@@ -834,7 +837,6 @@ class NotificationManager:
log: Log = Log(), log: Log = Log(),
debug: bool = False, debug: bool = False,
base_cache_dir: str = '', base_cache_dir: str = '',
cleanup_at_exit: bool = True,
): ):
global debug_desktop_integration global debug_desktop_integration
debug_desktop_integration = debug debug_desktop_integration = debug
@@ -856,9 +858,6 @@ class NotificationManager:
except Exception as e: except Exception as e:
self.log(f'Failed to load {script_path} with error: {e}') self.log(f'Failed to load {script_path} with error: {e}')
self.reset() self.reset()
if cleanup_at_exit:
import atexit
atexit.register(self.cleanup)
def reset(self) -> None: def reset(self) -> None:
self.icon_data_cache.clear() self.icon_data_cache.clear()

View File

@@ -1,7 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# License: GPL v3 Copyright: 2016, Kovid Goyal <kovid at kovidgoyal.net> # License: GPL v3 Copyright: 2016, Kovid Goyal <kovid at kovidgoyal.net>
import atexit
import fcntl import fcntl
import math import math
import os import os
@@ -350,17 +349,6 @@ class startup_notification_handler:
end_startup_notification(self.ctx) end_startup_notification(self.ctx)
def remove_socket_file(s: 'Socket', path: Optional[str] = None, is_dir: Optional[Callable[[str], None]] = None) -> None:
with suppress(OSError):
s.close()
if path:
with suppress(OSError):
if is_dir:
is_dir(path)
else:
os.unlink(path)
def unix_socket_directories() -> Iterator[str]: def unix_socket_directories() -> Iterator[str]:
import tempfile import tempfile
home = os.path.expanduser('~') home = os.path.expanduser('~')

View File

@@ -10,6 +10,7 @@ import subprocess
import tempfile import tempfile
from kitty.constants import kitten_exe, kitty_exe from kitty.constants import kitten_exe, kitty_exe
from kitty.shm import SharedMemory
from . import BaseTest from . import BaseTest
@@ -51,8 +52,16 @@ raise SystemExit(p.wait())
open(os.path.join(sdir, 'f'), 'w').close() open(os.path.join(sdir, 'f'), 'w').close()
select.select(readers, [], [], 10) select.select(readers, [], [], 10)
self.ae(read(), str(i+2)) self.ae(read(), str(i+2))
shm = SharedMemory(size=64)
shm.write(b'1' * 64)
shm.flush()
p.stdin.write(f'shm_unlink {shm.name}\n'.encode())
p.stdin.flush()
self.ae(read(), str(i+3))
self.assertTrue(os.listdir(self.tdir)) self.assertTrue(os.listdir(self.tdir))
shm2 = SharedMemory(shm.name)
self.ae(shm2.read()[:64], b'1' * 64)
# Ensure child is ignoring signals # Ensure child is ignoring signals
os.kill(atexit_pid, signal.SIGINT) os.kill(atexit_pid, signal.SIGINT)
@@ -74,6 +83,7 @@ raise SystemExit(p.wait())
os.waitpid(atexit_pid, 0) os.waitpid(atexit_pid, 0)
except ChildProcessError: except ChildProcessError:
pass pass
self.assertRaises(FileNotFoundError, lambda: SharedMemory(shm.name))
r('close') r('close')
r('terminate') r('terminate')

View File

@@ -91,7 +91,7 @@ class NotificationManager(NotificationManager):
def do_test(self: 'TestNotifications', tdir: str) -> None: def do_test(self: 'TestNotifications', tdir: str) -> None:
di = DesktopIntegration(None) di = DesktopIntegration(None)
ch = Channel() ch = Channel()
nm = NotificationManager(di, ch, lambda *a, **kw: None, base_cache_dir=tdir, cleanup_at_exit=False) nm = NotificationManager(di, ch, lambda *a, **kw: None, base_cache_dir=tdir)
di.notification_manager = nm di.notification_manager = nm
def reset(): def reset():

View File

@@ -2,12 +2,15 @@ package atexit
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io/fs"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
"kitty/tools/cli" "kitty/tools/cli"
"kitty/tools/utils/shm"
) )
var _ = fmt.Print var _ = fmt.Print
@@ -40,13 +43,18 @@ func main() (rc int, err error) {
if action, rest, found := strings.Cut(line, " "); found { if action, rest, found := strings.Cut(line, " "); found {
switch action { switch action {
case "unlink": case "unlink":
if err := os.Remove(rest); err != nil { if err := os.Remove(rest); err != nil && !errors.Is(err, fs.ErrNotExist) {
fmt.Fprintln(os.Stderr, "Failed to remove:", rest, "with error:", err) fmt.Fprintln(os.Stderr, "Failed to unlink:", rest, "with error:", err)
rc = 1
}
case "shm_unlink":
if err := shm.ShmUnlink(rest); err != nil && !errors.Is(err, fs.ErrNotExist) {
fmt.Fprintln(os.Stderr, "Failed to shm_unlink:", rest, "with error:", err)
rc = 1 rc = 1
} }
case "rmtree": case "rmtree":
if err := os.RemoveAll(rest); err != nil { if err := os.RemoveAll(rest); err != nil && !errors.Is(err, fs.ErrNotExist) {
fmt.Fprintln(os.Stderr, "Failed to remove:", rest, "with error:", err) fmt.Fprintln(os.Stderr, "Failed to rmtree:", rest, "with error:", err)
rc = 1 rc = 1
} }
} }

View File

@@ -12,6 +12,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"kitty/tools/utils" "kitty/tools/utils"
@@ -28,6 +29,16 @@ type file_based_mmap struct {
special_name string special_name string
} }
func ShmUnlink(name string) error {
if runtime.GOOS == "openbsd" {
return os.Remove(openbsd_shm_path(name))
}
if strings.HasPrefix(name, "/") {
name = name[1:]
}
return os.Remove(filepath.Join(SHM_DIR, name))
}
func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool, special_name string) (MMap, error) { func file_mmap(f *os.File, size uint64, access AccessFlags, truncate bool, special_name string) (MMap, error) {
if truncate { if truncate {
err := truncate_or_unlink(f, size, os.Remove) err := truncate_or_unlink(f, size, os.Remove)
@@ -106,11 +117,15 @@ func (self *file_based_mmap) Unlink() (err error) {
func (self *file_based_mmap) IsFileSystemBacked() bool { return true } func (self *file_based_mmap) IsFileSystemBacked() bool { return true }
func openbsd_shm_path(name string) string {
hash := sha256.Sum256(utils.UnsafeStringToBytes(name))
return filepath.Join(SHM_DIR, utils.UnsafeBytesToString(hash[:])+".shm")
}
func file_path_from_name(name string) string { func file_path_from_name(name string) string {
// See https://github.com/openbsd/src/blob/master/lib/libc/gen/shm_open.c // See https://github.com/openbsd/src/blob/master/lib/libc/gen/shm_open.c
if runtime.GOOS == "openbsd" { if runtime.GOOS == "openbsd" {
hash := sha256.Sum256(utils.UnsafeStringToBytes(name)) return openbsd_shm_path(name)
return filepath.Join(SHM_DIR, utils.UnsafeBytesToString(hash[:])+".shm")
} }
return filepath.Join(SHM_DIR, name) return filepath.Join(SHM_DIR, name)
} }

View File

@@ -37,7 +37,11 @@ func shm_unlink(name string) (err error) {
_, _, errno := unix.Syscall(unix.SYS_SHM_UNLINK, uintptr(unsafe.Pointer(bname)), 0, 0) _, _, errno := unix.Syscall(unix.SYS_SHM_UNLINK, uintptr(unsafe.Pointer(bname)), 0, 0)
if errno != unix.EINTR { if errno != unix.EINTR {
if errno != 0 { if errno != 0 {
err = fmt.Errorf("shm_unlink() failed with error: %w", errno) if errno == unix.ENOENT {
err = fs.ErrNotExist
} else {
err = fmt.Errorf("shm_unlink() failed with error: %w", errno)
}
} }
break break
} }
@@ -45,6 +49,10 @@ func shm_unlink(name string) (err error) {
return return
} }
func ShmUnlink(name string) error {
return shm_unlink(name)
}
func shm_open(name string, flags, perm int) (ans *os.File, err error) { func shm_open(name string, flags, perm int) (ans *os.File, err error) {
bname := BytePtrFromString(name) bname := BytePtrFromString(name)
var fd uintptr var fd uintptr