Function to specialize cell shader at runtime based on options

This commit is contained in:
Kovid Goyal
2026-07-02 17:13:16 +05:30
parent a98f7448da
commit 3982dc04d6

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2026, Kovid Goyal <kovid at kovidgoyal.net>
import fcntl
import json
import os
import re
@@ -9,7 +10,7 @@ import shutil
import sys
import time
from collections import OrderedDict
from contextlib import suppress
from contextlib import contextmanager, suppress
from enum import StrEnum
from functools import lru_cache
from itertools import chain
@@ -32,9 +33,10 @@ from kitty.fast_data_types import (
MARK_MASK,
REVERSE,
STRIKETHROUGH,
get_boss,
get_options,
)
from kitty.options.types import defaults
from kitty.options.types import Options, defaults
@lru_cache(maxsize=64)
@@ -106,6 +108,8 @@ class SlangFile(NamedTuple):
specializable_variables: MappingProxyType[str, str] = MappingProxyType({})
disable_warnings: frozenset[str] = frozenset()
opts: Options | None = None
def asdict(self, skip_source: bool = False) -> dict[str, Any]:
' Return a dict useable for serialization to JSON '
ans = self._asdict()
@@ -114,7 +118,9 @@ class SlangFile(NamedTuple):
ans['specializable_variables'] = dict(ans['specializable_variables'])
ans['disable_warnings'] = tuple(ans['disable_warnings'])
if skip_source:
ans['path'] = ans['text'] = ''
ans['text'] = ''
ans['path'] = os.path.basename(ans['path'])
del ans['opts']
return ans
@classmethod
@@ -147,6 +153,9 @@ class SlangFile(NamedTuple):
ans['COLOR_IS_RGB'] = str(COLOR_IS_RGB)
return MappingProxyType(ans)
def get_options(self) -> Options:
return self.opts or defaults
@property
def specializations(self) -> Iterator[Specialization]:
def s(name: str = '', **kwargs: str) -> Specialization:
@@ -158,9 +167,7 @@ class SlangFile(NamedTuple):
yield s('alpha_mask', is_alpha_mask='true')
yield s('premult', texture_is_not_premultiplied='true')
case 'cell.slang':
opts = defaults
with suppress(RuntimeError):
opts = get_options() or defaults
opts = self.get_options()
text_fg_override_threshold: float = opts.text_fg_override_threshold[0]
match opts.text_fg_override_threshold[1]:
case '%':
@@ -350,7 +357,7 @@ def serialize_source_metadata(sources: dict[str, SlangFile], dest_dir: str) -> N
f.write(json.dumps(sfile.asdict(skip_source=True), indent=2, sort_keys=True))
def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_files: list[str]) -> Iterator[Command]:
def commands_to_compile_to_spirv(sources: dict[str, SlangFile], dest_dir: str, built_files: list[str]) -> Iterator[Command]:
# glsl 450 is vulkan 1.1 and spirv 1.3 released 2008
base_cmd = ['-target', 'spirv', '-profile', 'glsl_450', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name']
for base_dest, slang_module, scmd, sfile in iter_entry_point_shaders(sources, dest_dir):
@@ -369,7 +376,7 @@ def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str,
# GLSL {{{
def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]:
def commands_to_compile_to_glsl(sources: dict[str, SlangFile], dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]:
glsl_version = max(150, GLSL_VERSION) # slangc fails with glsl_140 https://github.com/shader-slang/slang/issues/11898
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, dest_dir):
module_mtime = os.path.getmtime(slang_module)
@@ -514,7 +521,7 @@ def copy_files_preserving_structure(source_dir: str, dest_dir: str, extension: s
shutil.copy2(file_path, target_path)
def create_specialisations(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[Command]:
def create_specialisations(sources: dict[str, SlangFile], dest_dir: str) -> Iterator[Command]:
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, dest_dir):
if sfile.entry_points and sfile.specializations:
for sp in sfile.specializations:
@@ -553,13 +560,13 @@ def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: Paralle
# Copy IR to dest_dir
copy_files_preserving_structure(build_dir, dest_dir, '.slang-module')
# Create the specializations
parallel_run(create_specialisations(source_tree, build_dir, dest_dir))
parallel_run(create_specialisations(source_tree, dest_dir))
# Now Vulkan shaders
built_spirv_files: list[str] = []
spirv_commands = commands_to_compile_to_spirv(source_tree, build_dir, dest_dir, built_spirv_files)
spirv_commands = commands_to_compile_to_spirv(source_tree, dest_dir, built_spirv_files)
# Now glsl files
built_glsl_files: list[str] = []
glsl_commands = commands_to_compile_to_glsl(source_tree, build_dir, dest_dir, built_glsl_files)
glsl_commands = commands_to_compile_to_glsl(source_tree, dest_dir, built_glsl_files)
# Now run all commands
parallel_run(chain(spirv_commands, glsl_commands))
fixup_opengl_files(*built_glsl_files)
@@ -568,6 +575,106 @@ def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: Paralle
parallel_run((True, f'Validating |{os.path.basename(x)}| ...', validation_command_for_file(x)) for x in built_glsl_files)
@contextmanager
def lock_directory(target_dir: str) -> Iterator[None]:
'''
Context manager to exclusively lock a directory using a hidden lock file.
Works across all Unix-like operating systems.
'''
os.makedirs(target_dir, exist_ok=True)
lock_file_path = os.path.join(target_dir, '.shaders.lock')
lock_fd = os.open(lock_file_path, os.O_CREAT | os.O_WRONLY)
try:
fcntl.flock(lock_fd, fcntl.LOCK_EX)
yield
finally:
fcntl.flock(lock_fd, fcntl.LOCK_UN)
os.close(lock_fd)
def run_commands(cmds: Iterable[Command], cwd: str | None = None) -> None:
import subprocess
workers = []
for c in cmds:
if c.needs_build:
try:
p = subprocess.Popen(c.cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, cwd=cwd)
except FileNotFoundError:
raise Exception(f'Could not find slangc compiler ({slangc}) in PATH: {os.environ.get("PATH")}')
workers.append((c, p))
errors = []
for (c, p) in workers:
if p.wait() != 0:
assert p.stderr is not None
stderr = p.stderr.read().decode('utf-8', 'replace')
errors.append((c, stderr))
if errors:
raise Exception(f'Compiling shader failed. Command that was run: {errors[0][0]}\n{errors[0][1]}')
def specialize_shaders_to(sources: dict[str, SlangFile], dest_dir: str) -> None:
for name in sources:
shutil.copy2(os.path.join(shaders_dir, f'{name}.slang-module'), dest_dir)
specialisation_cmds = create_specialisations(sources, dest_dir)
run_commands(specialisation_cmds, dest_dir)
_ = []
spirv = commands_to_compile_to_spirv(sources, dest_dir, _)
glsl_built_files: list[str] = []
glsl = commands_to_compile_to_glsl(sources, dest_dir, glsl_built_files)
run_commands(chain(spirv, glsl), dest_dir)
fixup_opengl_files(*glsl_built_files)
@lru_cache(maxsize=2)
def per_process_cache_dir() -> str:
' A dir that has the lifetime of this process '
import atexit
from tempfile import mkdtemp
ans = mkdtemp()
boss = get_boss()
try:
boss.atexit.rmtree(ans)
except Exception: # happens if no boss exists
atexit.register(shutil.rmtree, ans)
return ans
specialize_cache: dict[str, tuple[tuple[Specialization, ...], dict[str, bytes]]] = {}
def specialize_cell_shader(
create_cache_dir: Callable[[], str] = per_process_cache_dir,
opts: Options | None = None
) -> dict[str, bytes]:
' Specialize the cell shader based on the specified options '
with open(os.path.join(shaders_dir, 'cell.json')) as f:
builtin_sfile = SlangFile.fromdict(json.load(f))
d = builtin_sfile._asdict()
if opts is None:
with suppress(RuntimeError):
opts = get_options()
d['opts'] = opts
sfile = SlangFile(**d)
builtin, current = tuple(builtin_sfile.specializations), tuple(sfile.specializations)
if builtin == current: # options not changed from defaults
return {}
dest_dir = create_cache_dir()
cache_key = f'cell-{dest_dir}'
if (cx := specialize_cache.get(cache_key)) and cx[0] == current:
return cx[1]
with lock_directory(dest_dir):
ensure_cache_dir(dest_dir)
specialize_shaders_to({'cell': sfile}, dest_dir)
ans = {}
for x in os.listdir(dest_dir):
if x.rpartition('.')[2] in ('spv', 'glsl', 'msl'):
with open(os.path.join(dest_dir, x), 'rb') as fb:
ans[x] = fb.read()
specialize_cache[cache_key] = (current, ans)
return ans
def main() -> None:
if not shutil.which(slangc[0]):
raise SystemExit(f'The shader slang compiler ({slangc[0]}) not in PATH: {os.environ.get("PATH")}')