Move specialisations into compiler code

Allows for dynamic specialisations
This commit is contained in:
Kovid Goyal
2026-06-30 08:11:18 +05:30
parent bbec9d5bbd
commit a12883abf3
3 changed files with 94 additions and 40 deletions

View File

@@ -8,8 +8,6 @@ import utils;
extern static const bool is_alpha_mask = false;
extern static const bool texture_is_not_premultiplied = false;
// specialize: alpha_mask: is_alpha_mask=true
// specialize: premult: texture_is_not_premultiplied=true
struct VSOutput

View File

@@ -10,9 +10,26 @@ from contextlib import suppress
from enum import Enum
from functools import lru_cache
from pathlib import Path
from types import MappingProxyType
from typing import Callable, Iterable, Iterator, NamedTuple
from kitty.constants import slangc
from kitty.fast_data_types import (
BLINK,
COLOR_IS_INDEX,
COLOR_IS_RGB,
COLOR_IS_SPECIAL,
COLOR_NOT_SET,
DECORATION,
DECORATION_MASK,
DIM,
MARK,
MARK_MASK,
REVERSE,
STRIKETHROUGH,
get_options,
)
from kitty.options.types import defaults
class Stage(Enum):
@@ -27,22 +44,70 @@ class EntryPoint(NamedTuple):
class Specialization(NamedTuple):
name: str
variables: dict[str, str]
variables: MappingProxyType[str, str]
class SlangFile(NamedTuple):
path: str
text: str
imports: frozenset[str]
entry_points: frozenset[EntryPoint]
module: str
specializable_variables: dict[str, str]
specializations: tuple[Specialization, ...]
path: str = ''
text: str = ''
imports: frozenset[str] = frozenset()
entry_points: frozenset[EntryPoint] = frozenset()
module: str = ''
specializable_variables: MappingProxyType[str, str] = MappingProxyType({})
@property
def should_compile_to_ir(self) -> bool:
    """Report whether this file should be compiled to IR.

    True when the file has a non-empty ``module`` name or a non-empty
    set of ``entry_points``; False when both are empty.
    """
    if self.module:
        return True
    return bool(self.entry_points)
@property
def specializations(self) -> Iterator[Specialization]:
def s(name: str = '', **kwargs: str) -> Specialization:
return Specialization(name, MappingProxyType(kwargs))
match os.path.basename(self.path):
case 'graphics.slang':
yield s()
yield s('alpha_mask', is_alpha_mask='true')
yield s('premult', texture_is_not_premultiplied='true')
case 'cell.slang':
opts = get_options() or defaults
text_fg_override_threshold: float = opts.text_fg_override_threshold[0]
match opts.text_fg_override_threshold[1]:
case '%':
text_fg_override_threshold = max(0, min(text_fg_override_threshold, 100.0)) * 0.01
algo = '1'
case 'ratio':
text_fg_override_threshold = max(0, min(text_fg_override_threshold, 21.0))
algo = '2'
base = {k:str(v) for k, v in dict(
REVERSE_SHIFT=REVERSE,
STRIKE_SHIFT=STRIKETHROUGH,
DIM_SHIFT=DIM,
BLINK_SHIFT=BLINK,
DECORATION_SHIFT=DECORATION,
MARK_SHIFT=MARK,
MARK_MASK=MARK_MASK,
DECORATION_MASK=DECORATION_MASK,
COLOR_NOT_SET=COLOR_NOT_SET,
COLOR_IS_SPECIAL=COLOR_IS_SPECIAL,
COLOR_IS_INDEX=COLOR_IS_INDEX,
COLOR_IS_RGB=COLOR_IS_RGB,
ONLY_FOREGROUND='false',
ONLY_BACKGROUND='false',
DO_FG_OVERRIDE='true' if text_fg_override_threshold else 'false',
FG_OVERRIDE_ALGO=algo,
FG_OVERRIDE_THRESHOLD=text_fg_override_threshold,
TEXT_NEW_GAMMA='false' if opts.text_composition_strategy == 'legacy' else 'true',
).items()}
yield s('', **base)
base['ONLY_FOREGROUND'] = 'true'
yield s('fg', **base)
base['ONLY_FOREGROUND'] = 'false'
base['ONLY_BACKGROUND'] = 'true'
yield s('bg', **base)
case _:
yield s()
def parse_slang_text(text: str, path: str = '') -> SlangFile:
text = re.sub(r'/\*[\s\S]*?\*/', '', text)
@@ -50,18 +115,10 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile:
module = ''
found_entry_point = ''
specializable_variables = {}
specializations = []
for line in text.splitlines():
line = line.strip()
if not line:
continue
if line.startswith('// specialize: '):
var, sep, spec = line.partition(':')[2].strip().partition(':')
variables = {}
for x in spec.split():
name, sep, val = x.partition('=')
variables[name] = val
specializations.append(Specialization(var, variables))
if line.startswith('//'):
continue
words = line.split()
@@ -92,8 +149,7 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile:
text = words[0].partition('(')[2].partition(')')[0].strip()
found_entry_point = text[1:-1]
return SlangFile(
path, text, frozenset(imports), frozenset(entry_points), module, specializable_variables,
tuple(specializations))
path, text, frozenset(imports), frozenset(entry_points), module, MappingProxyType(specializable_variables))
@lru_cache(4096)
@@ -201,7 +257,7 @@ def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest
def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_files: list[str]) -> Iterator[Command]:
base_cmd = ['-target', 'spirv', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name']
for base_dest, slang_module, scmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
for x in (Specialization('', {}),) + sfile.specializations:
for x in sfile.specializations:
cmd = list(scmd)
dest = f'{base_dest}.{x.name}.spv' if x.name else f'{base_dest}.spv'
if x.name:
@@ -221,7 +277,7 @@ def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, d
module_mtime = os.path.getmtime(slang_module)
extra_cmd = ['-line-directive-mode', 'none', '-target', 'glsl', '-profile', 'glsl_330']
for ep in sfile.entry_points:
for sp in (Specialization('', {}),) + sfile.specializations:
for sp in sfile.specializations:
dest = f'{base_dest}.{ep.stage.name}.glsl'
c = list(cmd)
if sp.name:

View File

@@ -29,26 +29,26 @@ void drawTriangle(float4 pos : POSITION) {
float4 psMain() : SV_Target {
return float4(1, 0, 0, 1);
}
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.vertex, 'drawTriangle'), EntryPoint(Stage.fragment, 'psMain')}), '', {}, ()))
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.vertex, 'drawTriangle'), EntryPoint(Stage.fragment, 'psMain')})))
# Empty source
check('', SlangFile('', '', frozenset(), frozenset(), '', {}, ()))
check('', SlangFile())
# Only line comments and block comments, no code
check('// just a comment\n/* block comment */', SlangFile('', '', frozenset(), frozenset(), '', {}, ()))
check('// just a comment\n/* block comment */', SlangFile('', '', frozenset(), frozenset()))
# Module and import declarations
check('''
module mymodule;
import utils;
import helpers;
''', SlangFile('', '', frozenset({'utils', 'helpers'}), frozenset(), 'mymodule', {}, ()))
''', SlangFile('', '', frozenset({'utils', 'helpers'}), frozenset(), 'mymodule'))
# pixel stage maps to Stage.fragment
check('''
[shader("pixel")]
float4 pixelMain() : SV_Target { return float4(0); }
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'pixelMain')}), '', {}, ()))
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'pixelMain')})))
# Block comment stripping removes multi-line comments before parsing
check('''
@@ -56,7 +56,7 @@ float4 pixelMain() : SV_Target { return float4(0); }
spanning multiple lines */
[shader("vertex")]
void vertMain() {}
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.vertex, 'vertMain')}), '', {}, ()))
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.vertex, 'vertMain')})))
# Block comment containing a shader attribute must not create a false entry point
check('''
@@ -64,7 +64,7 @@ void vertMain() {}
void shouldNotBeDetected() {} */
[shader("fragment")]
void fragMain() {}
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'fragMain')}), '', {}, ()))
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'fragMain')})))
# Multiple [attr] lines between [shader(...)] and the function declaration are skipped
check('''
@@ -72,7 +72,7 @@ void fragMain() {}
[numthreads(4, 4, 1)]
[SomeOtherAttribute]
float4 fragMain() : SV_Target { return float4(0); }
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'fragMain')}), '', {}, ()))
''', SlangFile('', '', frozenset(), frozenset({EntryPoint(Stage.fragment, 'fragMain')})))
# Multiple entry points: vertex, pixel, and fragment stages
check('''
@@ -88,7 +88,7 @@ float4 fsMain() : SV_Target { return float4(0); }
EntryPoint(Stage.vertex, 'vsMain'),
EntryPoint(Stage.fragment, 'psMain'),
EntryPoint(Stage.fragment, 'fsMain'),
}), '', {}, ()))
})))
# module, imports and entry points together
check('''
@@ -97,14 +97,14 @@ import common;
[shader("vertex")]
void vsMain() {}
''', SlangFile('', '', frozenset({'common'}), frozenset({EntryPoint(Stage.vertex, 'vsMain')}), 'myshader', {}, ()))
''', SlangFile('', '', frozenset({'common'}), frozenset({EntryPoint(Stage.vertex, 'vsMain')}), 'myshader'))
def test_slang_ordering(self):
# Test topological_sort with a manually constructed linear chain: a <- b <- c
graph: dict[str, SlangFile] = {
'a': SlangFile('', '', frozenset(), frozenset(), 'a', {}, ()),
'b': SlangFile('', '', frozenset({'a'}), frozenset(), 'b', {}, ()),
'c': SlangFile('', '', frozenset({'b'}), frozenset(), 'c', {}, ()),
'a': SlangFile('', '', frozenset(), frozenset(), 'a'),
'b': SlangFile('', '', frozenset({'a'}), frozenset(), 'b'),
'c': SlangFile('', '', frozenset({'b'}), frozenset(), 'c'),
}
order = topological_sort(graph)
self.assertLess(order.index('a'), order.index('b'))
@@ -112,10 +112,10 @@ void vsMain() {}
# Diamond dependency: base <- left, base <- right, left + right <- top
diamond: dict[str, SlangFile] = {
'base': SlangFile('', '', frozenset(), frozenset(), 'base', {}, ()),
'left': SlangFile('', '', frozenset({'base'}), frozenset(), 'left', {}, ()),
'right': SlangFile('', '', frozenset({'base'}), frozenset(), 'right', {}, ()),
'top': SlangFile('', '', frozenset({'left', 'right'}), frozenset(), 'top', {}, ()),
'base': SlangFile('', '', frozenset(), frozenset(), 'base'),
'left': SlangFile('', '', frozenset({'base'}), frozenset(), 'left'),
'right': SlangFile('', '', frozenset({'base'}), frozenset(), 'right'),
'top': SlangFile('', '', frozenset({'left', 'right'}), frozenset(), 'top'),
}
order2 = topological_sort(diamond)
self.assertLess(order2.index('base'), order2.index('left'))
@@ -125,7 +125,7 @@ void vsMain() {}
# Node with an import not present in the graph is silently skipped
partial: dict[str, SlangFile] = {
'x': SlangFile('', '', frozenset({'missing'}), frozenset(), 'x', {}, ()),
'x': SlangFile('', '', frozenset({'missing'}), frozenset(), 'x'),
}
self.assertEqual(topological_sort(partial), ['x'])