diff --git a/kitty/shaders/graphics.slang b/kitty/shaders/graphics.slang index 28703532f..911259ebd 100644 --- a/kitty/shaders/graphics.slang +++ b/kitty/shaders/graphics.slang @@ -8,8 +8,6 @@ import utils; extern static const bool is_alpha_mask = false; extern static const bool texture_is_not_premultiplied = false; -// specialize: alpha_mask: is_alpha_mask=true -// specialize: premult: texture_is_not_premultiplied=true struct VSOutput diff --git a/kitty/shaders/slang.py b/kitty/shaders/slang.py index 8e64eeebd..116a2569c 100644 --- a/kitty/shaders/slang.py +++ b/kitty/shaders/slang.py @@ -10,9 +10,26 @@ from contextlib import suppress from enum import Enum from functools import lru_cache from pathlib import Path +from types import MappingProxyType from typing import Callable, Iterable, Iterator, NamedTuple from kitty.constants import slangc +from kitty.fast_data_types import ( + BLINK, + COLOR_IS_INDEX, + COLOR_IS_RGB, + COLOR_IS_SPECIAL, + COLOR_NOT_SET, + DECORATION, + DECORATION_MASK, + DIM, + MARK, + MARK_MASK, + REVERSE, + STRIKETHROUGH, + get_options, +) +from kitty.options.types import defaults class Stage(Enum): @@ -27,22 +44,70 @@ class EntryPoint(NamedTuple): class Specialization(NamedTuple): name: str - variables: dict[str, str] + variables: MappingProxyType[str, str] class SlangFile(NamedTuple): - path: str - text: str - imports: frozenset[str] - entry_points: frozenset[EntryPoint] - module: str - specializable_variables: dict[str, str] - specializations: tuple[Specialization, ...] + path: str = '' + text: str = '' + imports: frozenset[str] = frozenset() + entry_points: frozenset[EntryPoint] = frozenset() + module: str = '' + specializable_variables: MappingProxyType[str, str] = MappingProxyType({}) @property def should_compile_to_ir(self) -> bool: return bool(self.module or self.entry_points) + @property + def specializations(self) -> Iterator[Specialization]: + def s(name: str = '', **kwargs: str) -> Specialization: + return Specialization(name, MappingProxyType(kwargs)) + + match os.path.basename(self.path): + case 'graphics.slang': + yield s() + yield s('alpha_mask', is_alpha_mask='true') + yield s('premult', texture_is_not_premultiplied='true') + case 'cell.slang': + opts = get_options() or defaults + text_fg_override_threshold: float = opts.text_fg_override_threshold[0] + match opts.text_fg_override_threshold[1]: + case '%': + text_fg_override_threshold = max(0, min(text_fg_override_threshold, 100.0)) * 0.01 + algo = '1' + case 'ratio': + text_fg_override_threshold = max(0, min(text_fg_override_threshold, 21.0)) + algo = '2' + base = {k:str(v) for k, v in dict( + REVERSE_SHIFT=REVERSE, + STRIKE_SHIFT=STRIKETHROUGH, + DIM_SHIFT=DIM, + BLINK_SHIFT=BLINK, + DECORATION_SHIFT=DECORATION, + MARK_SHIFT=MARK, + MARK_MASK=MARK_MASK, + DECORATION_MASK=DECORATION_MASK, + COLOR_NOT_SET=COLOR_NOT_SET, + COLOR_IS_SPECIAL=COLOR_IS_SPECIAL, + COLOR_IS_INDEX=COLOR_IS_INDEX, + COLOR_IS_RGB=COLOR_IS_RGB, + ONLY_FOREGROUND='false', + ONLY_BACKGROUND='false', + DO_FG_OVERRIDE='true' if text_fg_override_threshold else 'false', + FG_OVERRIDE_ALGO=algo, + FG_OVERRIDE_THRESHOLD=text_fg_override_threshold, + TEXT_NEW_GAMMA='false' if opts.text_composition_strategy == 'legacy' else 'true', + ).items()} + yield s('', **base) + base['ONLY_FOREGROUND'] = 'true' + yield s('fg', **base) + base['ONLY_FOREGROUND'] = 'false' + base['ONLY_BACKGROUND'] = 'true' + yield s('bg', **base) + case _: + yield s() + def parse_slang_text(text: str, path: str = '') -> SlangFile: text = re.sub(r'/\*[\s\S]*?\*/', '', text) @@ -50,18 +115,10 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile: module = '' found_entry_point = '' specializable_variables = {} - specializations = [] for line in text.splitlines(): line = line.strip() if not line: continue - if line.startswith('// specialize: '): - var, sep, spec = line.partition(':')[2].strip().partition(':') - variables = {} - for x in spec.split(): - name, sep, val = x.partition('=') - variables[name] = val - specializations.append(Specialization(var, variables)) if line.startswith('//'): continue words = line.split() @@ -92,8 +149,7 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile: text = words[0].partition('(')[2].partition(')')[0].strip() found_entry_point = text[1:-1] return SlangFile( - path, text, frozenset(imports), frozenset(entry_points), module, specializable_variables, - tuple(specializations)) + path, text, frozenset(imports), frozenset(entry_points), module, MappingProxyType(specializable_variables)) @lru_cache(4096) @@ -201,7 +257,7 @@ def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_files: list[str]) -> Iterator[Command]: base_cmd = ['-target', 'spirv', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name'] for base_dest, slang_module, scmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir): - for x in (Specialization('', {}),) + sfile.specializations: + for x in sfile.specializations: cmd = list(scmd) dest = f'{base_dest}.{x.name}.spv' if x.name else f'{base_dest}.spv' if x.name: @@ -221,7 +277,7 @@ def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, d module_mtime = os.path.getmtime(slang_module) extra_cmd = ['-line-directive-mode', 'none', '-target', 'glsl', '-profile', 'glsl_330'] for ep in sfile.entry_points: - for sp in (Specialization('', {}),) + sfile.specializations: + for sp in sfile.specializations: dest = f'{base_dest}.{ep.stage.name}.glsl' c = list(cmd) if sp.name: diff --git a/kitty_tests/slang.py b/kitty_tests/slang.py index 121954e0e..201e00064 100644 --- a/kitty_tests/slang.py +++ b/kitty_tests/slang.py @@ -29,26 +29,26 @@ void drawTriangle(float4 pos : POSITION) { float4 psMain() : SV_Target { return float4(1, 0, 0, 1); } - ''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.vertex, 'drawTriangle'), EntryPoint(Stage.fragment, 'psMain')}), '', {}, ())) + ''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.vertex, 'drawTriangle'), EntryPoint(Stage.fragment, 'psMain')}))) # Empty source - check('', SlangFile('', '', frozenset(), frozenset(), '', {}, ())) + check('', SlangFile()) # Only line comments and block comments, no code - check('// just a comment\n/* block comment */', SlangFile('', '', frozenset(), frozenset(), '', {}, ())) + check('// just a comment\n/* block comment */', SlangFile('', '', frozenset(), frozenset())) # Module and import declarations check(''' module mymodule; import utils; import helpers; -''', SlangFile('', '', frozenset({'utils', 'helpers'}), frozenset(), 'mymodule', {}, ())) +''', SlangFile('', '', frozenset({'utils', 'helpers'}), frozenset(), 'mymodule')) # pixel stage maps to Stage.fragment check(''' [shader("pixel")] float4 pixelMain() : SV_Target { return float4(0); } -''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'pixelMain')}), '', {}, ())) +''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'pixelMain')}))) # Block comment stripping removes multi-line comments before parsing check(''' @@ -56,7 +56,7 @@ float4 pixelMain() : SV_Target { return float4(0); } spanning multiple lines */ [shader("vertex")] void vertMain() {} -''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.vertex, 'vertMain')}), '', {}, ())) +''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.vertex, 'vertMain')}))) # Block comment containing a shader attribute must not create a false entry point check(''' @@ -64,7 +64,7 @@ void vertMain() {} void shouldNotBeDetected() {} */ [shader("fragment")] void fragMain() {} -''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'fragMain')}), '', {}, ())) +''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'fragMain')}))) # Multiple [attr] lines between [shader(...)] and the function declaration are skipped check(''' @@ -72,7 +72,7 @@ void fragMain() {} [numthreads(4, 4, 1)] [SomeOtherAttribute] float4 fragMain() : SV_Target { return float4(0); } -''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'fragMain')}), '', {}, ())) +''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'fragMain')}))) # Multiple entry points: vertex, pixel, and fragment stages check(''' @@ -88,7 +88,7 @@ float4 fsMain() : SV_Target { return float4(0); } EntryPoint(Stage.vertex, 'vsMain'), EntryPoint(Stage.fragment, 'psMain'), EntryPoint(Stage.fragment, 'fsMain'), - }), '', {}, ())) + }))) # module, imports and entry points together check(''' @@ -97,14 +97,14 @@ import common; [shader("vertex")] void vsMain() {} -''', SlangFile('', '', frozenset({'common'}), frozenset({EntryPoint(Stage.vertex, 'vsMain')}), 'myshader', {}, ())) +''', SlangFile('', '', frozenset({'common'}), frozenset({EntryPoint(Stage.vertex, 'vsMain')}), 'myshader')) def test_slang_ordering(self): # Test topological_sort with a manually constructed linear chain: a <- b <- c graph: dict[str, SlangFile] = { - 'a': SlangFile('', '', frozenset(), frozenset(), 'a', {}, ()), - 'b': SlangFile('', '', frozenset({'a'}), frozenset(), 'b', {}, ()), - 'c': SlangFile('', '', frozenset({'b'}), frozenset(), 'c', {}, ()), + 'a': SlangFile('', '', frozenset(), frozenset(), 'a'), + 'b': SlangFile('', '', frozenset({'a'}), frozenset(), 'b'), + 'c': SlangFile('', '', frozenset({'b'}), frozenset(), 'c'), } order = topological_sort(graph) self.assertLess(order.index('a'), order.index('b')) @@ -112,10 +112,10 @@ void vsMain() {} # Diamond dependency: base <- left, base <- right, left + right <- top diamond: dict[str, SlangFile] = { - 'base': SlangFile('', '', frozenset(), frozenset(), 'base', {}, ()), - 'left': SlangFile('', '', frozenset({'base'}), frozenset(), 'left', {}, ()), - 'right': SlangFile('', '', frozenset({'base'}), frozenset(), 'right', {}, ()), - 'top': SlangFile('', '', frozenset({'left', 'right'}), frozenset(), 'top', {}, ()), + 'base': SlangFile('', '', frozenset(), frozenset(), 'base'), + 'left': SlangFile('', '', frozenset({'base'}), frozenset(), 'left'), + 'right': SlangFile('', '', frozenset({'base'}), frozenset(), 'right'), + 'top': SlangFile('', '', frozenset({'left', 'right'}), frozenset(), 'top'), } order2 = topological_sort(diamond) self.assertLess(order2.index('base'), order2.index('left')) @@ -125,7 +125,7 @@ void vsMain() {} # Node with an import not present in the graph is silently skipped partial: dict[str, SlangFile] = { - 'x': SlangFile('', '', frozenset({'missing'}), frozenset(), 'x', {}, ()), + 'x': SlangFile('', '', frozenset({'missing'}), frozenset(), 'x'), } self.assertEqual(topological_sort(partial), ['x'])