mirror of
https://github.com/kovidgoyal/kitty
synced 2026-07-03 05:03:39 +02:00
Function to specialize cell shader at runtime based on options
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
# License: GPLv3 Copyright: 2026, Kovid Goyal <kovid at kovidgoyal.net>
|
||||
|
||||
import fcntl
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -9,7 +10,7 @@ import shutil
|
||||
import sys
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from contextlib import suppress
|
||||
from contextlib import contextmanager, suppress
|
||||
from enum import StrEnum
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
@@ -32,9 +33,10 @@ from kitty.fast_data_types import (
|
||||
MARK_MASK,
|
||||
REVERSE,
|
||||
STRIKETHROUGH,
|
||||
get_boss,
|
||||
get_options,
|
||||
)
|
||||
from kitty.options.types import defaults
|
||||
from kitty.options.types import Options, defaults
|
||||
|
||||
|
||||
@lru_cache(maxsize=64)
|
||||
@@ -106,6 +108,8 @@ class SlangFile(NamedTuple):
|
||||
specializable_variables: MappingProxyType[str, str] = MappingProxyType({})
|
||||
disable_warnings: frozenset[str] = frozenset()
|
||||
|
||||
opts: Options | None = None
|
||||
|
||||
def asdict(self, skip_source: bool = False) -> dict[str, Any]:
|
||||
' Return a dict useable for serialization to JSON '
|
||||
ans = self._asdict()
|
||||
@@ -114,7 +118,9 @@ class SlangFile(NamedTuple):
|
||||
ans['specializable_variables'] = dict(ans['specializable_variables'])
|
||||
ans['disable_warnings'] = tuple(ans['disable_warnings'])
|
||||
if skip_source:
|
||||
ans['path'] = ans['text'] = ''
|
||||
ans['text'] = ''
|
||||
ans['path'] = os.path.basename(ans['path'])
|
||||
del ans['opts']
|
||||
return ans
|
||||
|
||||
@classmethod
|
||||
@@ -147,6 +153,9 @@ class SlangFile(NamedTuple):
|
||||
ans['COLOR_IS_RGB'] = str(COLOR_IS_RGB)
|
||||
return MappingProxyType(ans)
|
||||
|
||||
def get_options(self) -> Options:
|
||||
return self.opts or defaults
|
||||
|
||||
@property
|
||||
def specializations(self) -> Iterator[Specialization]:
|
||||
def s(name: str = '', **kwargs: str) -> Specialization:
|
||||
@@ -158,9 +167,7 @@ class SlangFile(NamedTuple):
|
||||
yield s('alpha_mask', is_alpha_mask='true')
|
||||
yield s('premult', texture_is_not_premultiplied='true')
|
||||
case 'cell.slang':
|
||||
opts = defaults
|
||||
with suppress(RuntimeError):
|
||||
opts = get_options() or defaults
|
||||
opts = self.get_options()
|
||||
text_fg_override_threshold: float = opts.text_fg_override_threshold[0]
|
||||
match opts.text_fg_override_threshold[1]:
|
||||
case '%':
|
||||
@@ -350,7 +357,7 @@ def serialize_source_metadata(sources: dict[str, SlangFile], dest_dir: str) -> N
|
||||
f.write(json.dumps(sfile.asdict(skip_source=True), indent=2, sort_keys=True))
|
||||
|
||||
|
||||
def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_files: list[str]) -> Iterator[Command]:
|
||||
def commands_to_compile_to_spirv(sources: dict[str, SlangFile], dest_dir: str, built_files: list[str]) -> Iterator[Command]:
|
||||
# glsl 450 is vulkan 1.1 and spirv 1.3 released 2008
|
||||
base_cmd = ['-target', 'spirv', '-profile', 'glsl_450', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name']
|
||||
for base_dest, slang_module, scmd, sfile in iter_entry_point_shaders(sources, dest_dir):
|
||||
@@ -369,7 +376,7 @@ def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str,
|
||||
|
||||
|
||||
# GLSL {{{
|
||||
def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]:
|
||||
def commands_to_compile_to_glsl(sources: dict[str, SlangFile], dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]:
|
||||
glsl_version = max(150, GLSL_VERSION) # slangc fails with glsl_140 https://github.com/shader-slang/slang/issues/11898
|
||||
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, dest_dir):
|
||||
module_mtime = os.path.getmtime(slang_module)
|
||||
@@ -514,7 +521,7 @@ def copy_files_preserving_structure(source_dir: str, dest_dir: str, extension: s
|
||||
shutil.copy2(file_path, target_path)
|
||||
|
||||
|
||||
def create_specialisations(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[Command]:
|
||||
def create_specialisations(sources: dict[str, SlangFile], dest_dir: str) -> Iterator[Command]:
|
||||
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, dest_dir):
|
||||
if sfile.entry_points and sfile.specializations:
|
||||
for sp in sfile.specializations:
|
||||
@@ -553,13 +560,13 @@ def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: Paralle
|
||||
# Copy IR to dest_dir
|
||||
copy_files_preserving_structure(build_dir, dest_dir, '.slang-module')
|
||||
# Create the specializations
|
||||
parallel_run(create_specialisations(source_tree, build_dir, dest_dir))
|
||||
parallel_run(create_specialisations(source_tree, dest_dir))
|
||||
# Now Vulkan shaders
|
||||
built_spirv_files: list[str] = []
|
||||
spirv_commands = commands_to_compile_to_spirv(source_tree, build_dir, dest_dir, built_spirv_files)
|
||||
spirv_commands = commands_to_compile_to_spirv(source_tree, dest_dir, built_spirv_files)
|
||||
# Now glsl files
|
||||
built_glsl_files: list[str] = []
|
||||
glsl_commands = commands_to_compile_to_glsl(source_tree, build_dir, dest_dir, built_glsl_files)
|
||||
glsl_commands = commands_to_compile_to_glsl(source_tree, dest_dir, built_glsl_files)
|
||||
# Now run all commands
|
||||
parallel_run(chain(spirv_commands, glsl_commands))
|
||||
fixup_opengl_files(*built_glsl_files)
|
||||
@@ -568,6 +575,106 @@ def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: Paralle
|
||||
parallel_run((True, f'Validating |{os.path.basename(x)}| ...', validation_command_for_file(x)) for x in built_glsl_files)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def lock_directory(target_dir: str) -> Iterator[None]:
|
||||
'''
|
||||
Context manager to exclusively lock a directory using a hidden lock file.
|
||||
Works across all Unix-like operating systems.
|
||||
'''
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
lock_file_path = os.path.join(target_dir, '.shaders.lock')
|
||||
lock_fd = os.open(lock_file_path, os.O_CREAT | os.O_WRONLY)
|
||||
try:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_EX)
|
||||
yield
|
||||
finally:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_UN)
|
||||
os.close(lock_fd)
|
||||
|
||||
|
||||
def run_commands(cmds: Iterable[Command], cwd: str | None = None) -> None:
|
||||
import subprocess
|
||||
workers = []
|
||||
for c in cmds:
|
||||
if c.needs_build:
|
||||
try:
|
||||
p = subprocess.Popen(c.cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, cwd=cwd)
|
||||
except FileNotFoundError:
|
||||
raise Exception(f'Could not find slangc compiler ({slangc}) in PATH: {os.environ.get("PATH")}')
|
||||
workers.append((c, p))
|
||||
errors = []
|
||||
for (c, p) in workers:
|
||||
if p.wait() != 0:
|
||||
assert p.stderr is not None
|
||||
stderr = p.stderr.read().decode('utf-8', 'replace')
|
||||
errors.append((c, stderr))
|
||||
if errors:
|
||||
raise Exception(f'Compiling shader failed. Command that was run: {errors[0][0]}\n{errors[0][1]}')
|
||||
|
||||
|
||||
def specialize_shaders_to(sources: dict[str, SlangFile], dest_dir: str) -> None:
|
||||
for name in sources:
|
||||
shutil.copy2(os.path.join(shaders_dir, f'{name}.slang-module'), dest_dir)
|
||||
specialisation_cmds = create_specialisations(sources, dest_dir)
|
||||
run_commands(specialisation_cmds, dest_dir)
|
||||
_ = []
|
||||
spirv = commands_to_compile_to_spirv(sources, dest_dir, _)
|
||||
glsl_built_files: list[str] = []
|
||||
glsl = commands_to_compile_to_glsl(sources, dest_dir, glsl_built_files)
|
||||
run_commands(chain(spirv, glsl), dest_dir)
|
||||
fixup_opengl_files(*glsl_built_files)
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
def per_process_cache_dir() -> str:
|
||||
' A dir that has the lifetime of this process '
|
||||
import atexit
|
||||
from tempfile import mkdtemp
|
||||
ans = mkdtemp()
|
||||
boss = get_boss()
|
||||
try:
|
||||
boss.atexit.rmtree(ans)
|
||||
except Exception: # happens if no boss exists
|
||||
atexit.register(shutil.rmtree, ans)
|
||||
return ans
|
||||
|
||||
|
||||
specialize_cache: dict[str, tuple[tuple[Specialization, ...], dict[str, bytes]]] = {}
|
||||
|
||||
|
||||
def specialize_cell_shader(
|
||||
create_cache_dir: Callable[[], str] = per_process_cache_dir,
|
||||
opts: Options | None = None
|
||||
) -> dict[str, bytes]:
|
||||
' Specialize the cell shader based on the specified options '
|
||||
with open(os.path.join(shaders_dir, 'cell.json')) as f:
|
||||
builtin_sfile = SlangFile.fromdict(json.load(f))
|
||||
d = builtin_sfile._asdict()
|
||||
if opts is None:
|
||||
with suppress(RuntimeError):
|
||||
opts = get_options()
|
||||
d['opts'] = opts
|
||||
sfile = SlangFile(**d)
|
||||
builtin, current = tuple(builtin_sfile.specializations), tuple(sfile.specializations)
|
||||
if builtin == current: # options not changed from defaults
|
||||
return {}
|
||||
dest_dir = create_cache_dir()
|
||||
cache_key = f'cell-{dest_dir}'
|
||||
if (cx := specialize_cache.get(cache_key)) and cx[0] == current:
|
||||
return cx[1]
|
||||
|
||||
with lock_directory(dest_dir):
|
||||
ensure_cache_dir(dest_dir)
|
||||
specialize_shaders_to({'cell': sfile}, dest_dir)
|
||||
ans = {}
|
||||
for x in os.listdir(dest_dir):
|
||||
if x.rpartition('.')[2] in ('spv', 'glsl', 'msl'):
|
||||
with open(os.path.join(dest_dir, x), 'rb') as fb:
|
||||
ans[x] = fb.read()
|
||||
specialize_cache[cache_key] = (current, ans)
|
||||
return ans
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not shutil.which(slangc[0]):
|
||||
raise SystemExit(f'The shader slang compiler ({slangc[0]}) not in PATH: {os.environ.get("PATH")}')
|
||||
|
||||
Reference in New Issue
Block a user