From 71fd4d3c57945bcf39d82036fb54847859420933 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 28 Jun 2026 13:06:22 +0530 Subject: [PATCH] More work on slang compiler --- kitty/shaders/blit.slang | 14 +++---- kitty/shaders/slang.py | 83 +++++++++++++++++++++++++++++++++++----- 2 files changed, 79 insertions(+), 18 deletions(-) diff --git a/kitty/shaders/blit.slang b/kitty/shaders/blit.slang index 56cf96674..9be456d97 100644 --- a/kitty/shaders/blit.slang +++ b/kitty/shaders/blit.slang @@ -16,18 +16,16 @@ public struct BlitOutput { #define bottom 3 // Static constant array mapping vertex IDs -static const int2 vertex_pos_map[4] = -{ - int2(right, top), - int2(right, bottom), - int2(left, bottom), - int2(left, top) +static const int2 vertex_pos_map[4] = { + {right, top}, + {right, bottom}, + {left, bottom}, + {left, top} }; public BlitOutput get_coords_for_blit(uint vertex_id, float4 src_rect, float4 dest_rect) { - BlitOutput output; int2 pos = vertex_pos_map[vertex_id]; - // Indexing into the float4 vectors using pos.x and pos.y + BlitOutput output; output.texcoord = float2(src_rect[pos.x], src_rect[pos.y]); output.position = float2(dest_rect[pos.x], dest_rect[pos.y]); return output; diff --git a/kitty/shaders/slang.py b/kitty/shaders/slang.py index 67533dac4..26e003381 100644 --- a/kitty/shaders/slang.py +++ b/kitty/shaders/slang.py @@ -163,7 +163,7 @@ def commands_to_compile_dir_to_ir(sources: dict[str, SlangFile], src_dir: str, o ]) -def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]: +def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[tuple[str, str, list[str], SlangFile]]: cmdbase = list(slangc) for name, sfile in sources.items(): if not sfile.entry_points: @@ -171,11 +171,29 @@ def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, d parts = name.split('.') base_dest = os.path.join(dest_dir, *parts) slang_module = f'{base_dest}.slang-module' - output_mtime = future() cmd = cmdbase + ['-I', dest_dir, slang_module] + yield base_dest, slang_module, cmd, sfile + + +def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_files: list[str]) -> Iterator[Command]: + for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir): + dest = f'{base_dest}.spv' + cmd += ['-target', 'spirv', '-o', dest] + output_mtime = safe_mtime(dest) + module_mtime = os.path.getmtime(slang_module) + needs_build = output_mtime < module_mtime + if needs_build: + built_files.append(dest) + yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to SPIR-V ...', cmd) + + +def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]: + for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir): + output_mtime = future() dest_files = [] + cmd.extend(('-line-directive-mode', 'none')) for ep in sfile.entry_points: - dest = f'{base_dest}-{ep.stage.name}.glsl' + dest = f'{base_dest}.{ep.stage.name}.glsl' cmd += ['-entry', ep.name, '-stage', ep.stage.name, '-target', 'glsl', '-profile', 'glsl_330', '-o', dest] dest_files.append(dest) output_mtime = min(output_mtime, safe_mtime(dest)) @@ -183,18 +201,58 @@ def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, d needs_build = output_mtime < module_mtime if needs_build: built_glsl_files.extend(dest_files) - yield Command(needs_build, f'Linking |{name}.slang-module| to GLSL ...', cmd) + yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to GLSL ...', cmd) def fixup_opengl_code(glsl_code: str) -> str: lines = [] + in_uniform_block = False + in_uniform_block_contents = False + uniform_blocks = {} + current_uniform_names: list[str] = [] + uniform_names = {} for line in glsl_code.splitlines(): - if line.startswith('#version '): - line = '#version 330 core' - elif line.startswith('#extension ') or line in ('layout(row_major) buffer;', 'layout(push_constant)'): - line = '// ' + line + if in_uniform_block: + if in_uniform_block_contents: + if line.startswith('}'): + in_uniform_block = in_uniform_block_contents = False + uniform_blocks[line.lstrip('}').rstrip(';').strip()] = current_uniform_names + line = '// ' + line + current_uniform_names = [] + else: + line = line.strip() + name = line.split()[-1].rstrip(';') + current_uniform_names.append(name) + uniform_names[name] = name.rpartition('_')[0] + line = 'uniform ' + line + elif line.startswith('{'): # }} + line = '// ' + line + in_uniform_block_contents = True + current_uniform_names = [] + else: + if line.startswith('#version '): + line = '#version 330 core' + elif line.startswith('#extension ') or line in ('layout(row_major) buffer;', 'layout(push_constant)'): + line = '// ' + line + elif line.startswith('layout(location ='): + line = '// ' + line + else: + words = line.split() + if 'uniform' in words and line.startswith('layout('): # ) + in_uniform_block = True + in_uniform_block_contents = False + line = '// ' + line + elif line.startswith('const ') and '] = {' in line: # }] + line = line.replace('{', f'{words[1]}[](', 1) # }]) + line = line.removesuffix('};') + ');' lines.append(line) - return '\n'.join(lines) + ans = '\n'.join(lines) + for block_name, names in uniform_blocks.items(): + for u in names: + ans = ans.replace(f'{block_name}.{u}', u) + ans = ans.replace('gl_VertexIndex', 'gl_VertexID') + ans = ans.replace('gl_BaseVertex', '0') + return ans def fixup_opengl_files(*paths: str) -> None: @@ -237,7 +295,12 @@ def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: Paralle parallel_run(commands_to_compile_dir_to_ir(source_tree, src_dir, build_dir)) # Copy IR to dest_dir copy_files_preserving_structure(build_dir, dest_dir, '.slang-module') - # Now glsl shaders + # Now Vulkan shaders + built_spirv_files: list[str] = [] + spirv_commands = commands_to_compile_to_spirv(source_tree, build_dir, dest_dir, built_spirv_files) + parallel_run(spirv_commands) + + # Now glsl files built_glsl_files: list[str] = [] glsl_commands = commands_to_compile_to_glsl(source_tree, build_dir, dest_dir, built_glsl_files)