{{ message }}
[ExecuTorch][WebGPU] SDPA: skip QK contraction for fully-masked causal tiles#20492
Merged
meta-codesync[bot] merged 2 commits intoJun 25, 2026
Merged
Conversation
This was referenced Jun 24, 2026
psiddh
approved these changes
Jun 24, 2026
SS-JIA
approved these changes
Jun 25, 2026
SS-JIA
left a comment
Contributor
There was a problem hiding this comment.
Review automatically exported from Phabricator review in Meta.
1a9fe0a
into
gh/JulianCloudNTH/62/base
184 of 195 checks passed
JulianCloudNTH
added a commit
that referenced
this pull request
Jun 25, 2026
…l tiles Pull Request resolved: #20492 **Skip the QK contraction for fully-masked causal tiles** — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output). **Problem**: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`. **Solution**: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel: - **Before**: every `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`. - **After**: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before. **Implementation**: - Add `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition. - Store loop unchanged — runs unconditionally, so no scratch entry is left stale. - Mirrors Vulkan `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`). **Constraints**: - No KV-cache, host, dispatch, or uniform change (all tiles still launch; the skip is in-shader). - Prefill-only: decode `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`). - `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`. Co-authored with Claude Code. ghstack-source-id: 396792509 @exported-using-ghexport Differential Revision: [D109517773](https://our.internmc.facebook.com/intern/diff/D109517773/)
JulianCloudNTH
added a commit
that referenced
this pull request
Jun 25, 2026
…l tiles Pull Request resolved: #20492 **Skip the QK contraction for fully-masked causal tiles** — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output). **Problem**: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`. **Solution**: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel: - **Before**: every `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`. - **After**: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before. **Implementation**: - Add `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition. - Store loop unchanged — runs unconditionally, so no scratch entry is left stale. - Mirrors Vulkan `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`). **Constraints**: - No KV-cache, host, dispatch, or uniform change (all tiles still launch; the skip is in-shader). - Prefill-only: decode `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`). - `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`. Co-authored with Claude Code. ghstack-source-id: 396792509 @exported-using-ghexport Differential Revision: [D109517773](https://our.internmc.facebook.com/intern/diff/D109517773/)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.

Stack from ghstack (oldest at bottom):
Skip the QK contraction for fully-masked causal tiles — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output).
Problem: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full
d4dot product before masking the result toNEG_INF.Solution: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel:
(s0, c0)tile runs the fulld4dot-product loop, thenstore_qkmasks above-diagonal elements toNEG_INF.c0 > s0 + TM-1 + input_pos) breaks thed4loop immediately (accstays 0);store_qkmasks every element toNEG_INFexactly as before.Implementation:
skip_tile = c0 > s0 + (TM - 1) + params.input_pos, folded into thed4loop break condition.sdpa_compute_attn_weights_tiled.glsl(tile_in_mask_region).Constraints:
S=1never triggers it (c0 <= input_pos < input_pos + TM - 1).NEG_INFstays the WGSL-safe-1.0e30(WGSL forbids a literal-inf); does not copy Vulkan's-1.0/0.0.Co-authored with Claude Code.
@exported-using-ghexport
Differential Revision: D109517773
Differential Revision: D109517773