[ExecuTorch][WebGPU] Add et_vk.embedding_q4gsw (4-bit groupwise-symmetric quantized embedding) by JulianCloudNTH · Pull Request #20263 · pytorch/executorch · GitHub
Skip to content

[ExecuTorch][WebGPU] Add et_vk.embedding_q4gsw (4-bit groupwise-symmetric quantized embedding)#20263

Merged
meta-codesync[bot] merged 5 commits into
gh/JulianCloudNTH/25/basefrom
gh/JulianCloudNTH/25/head
Jun 22, 2026
Merged

[ExecuTorch][WebGPU] Add et_vk.embedding_q4gsw (4-bit groupwise-symmetric quantized embedding)#20263
meta-codesync[bot] merged 5 commits into
gh/JulianCloudNTH/25/basefrom
gh/JulianCloudNTH/25/head

Conversation

@JulianCloudNTH

@JulianCloudNTH JulianCloudNTH commented Jun 13, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Adds the WebGPU backend handler for et_vk.embedding_q4gsw.default (a 4-bit groupwise-symmetric quantized embedding gather) plus the host-side integer-input infra it requires.

The op is a single compute dispatch composed of one stage: one thread per 32-element block of each gathered row dequantizes the packed 4-bit table (q = (nibble - 8) * scale; even dim = high nibble, odd dim = low) into the fp32 output, mirroring the Vulkan embedding_q4gsw reference (flat buffer-backed weight; is_linear_weight=true is unsupported and throws). The workgroup size is a wg_size pipeline-override constant clamped to the device limit via WebGPUUtils::clamp_workgroup_size, the 1D dispatch count goes through WebGPUUtils::compute_1d_workgroup_count (validated before any GPU-object allocation), and the embedded WGSL string header is generated by gen_wgsl_headers.py.

Embedding indices arrive as int64 at the program boundary but the serialized graph stores them as int32, so the shared input path is extended with a host-side InputData view ({data, nbytes, host_is_int64}) and copy_inputs gains three branches: a byte-for-byte fast path when host and GPU sizes match, an int64->int32 narrowing copy when the buffer is int32 and the host input is twice as wide (mirrors the Vulkan kLong->kInt staging cast), and a fail-loud throw otherwise. WebGPUTensor gains elem_size/is_int to drive the narrowing decision, and update_symints_from_inputs takes the same InputData vector so execute() builds a single input list consumed by both.
@exported-using-ghexport

Differential Revision: D108428753

Differential Revision: D108428753

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 13, 2026

Copy link
Copy Markdown

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

[ghstack-poisoned]
@JulianCloudNTH

Copy link
Copy Markdown
Contributor Author

@claude review

@claude

claude Bot commented Jun 15, 2026

Copy link
Copy Markdown

Claude finished @JulianCloudNTH's task in 2m 1s —— View job


Code Review: et_vk.embedding_q4gsw (WebGPU)

  • Read PR diff and changed files
  • Review op implementation (EmbeddingQ4gsw.cpp + WGSL)
  • Review host-side input infra (WebGPUGraph, copy_inputs, InputData)
  • Review tests
  • Post review feedback

Overall this is a clean, well-validated addition. The op does extensive up-front shape/byte guards before allocating any GPU object, the WGSL nibble unpacking matches the documented pack convention, and the int64→int32 narrowing path is well reasoned. One correctness bug and a few smaller notes below.


🔴 copy_inputs can throw across the backend boundary

WebGPUBackend::execute deliberately wraps the symint refresh in a try/catch with the comment "Fail loud as a runtime Error so a throw never crosses the backend boundary" — but copy_inputs is called outside that block:

graph->copy_inputs(inputs);          // WebGPUBackend.cpp:109  — NOT guarded

try {
  graph->update_symints_from_inputs(inputs);   // :112-118 — guarded
  graph->propagate_resize();
} catch (const std::exception& e) { ... return Error::Internal; }

This PR newly makes copy_inputs throw std::runtime_error (the "fail-loud" branch in WebGPUGraph.cpp). Before this PR copy_inputs was non-throwing, so the unguarded call was fine — now an unsupported/mismatched input will let a C++ exception escape execute(), which returns Error and is not expected to throw. That contradicts the design comment two lines below and can crash the runtime instead of returning a clean error.

Suggest moving the copy_inputs call inside the existing try/catch (and the input-gathering loop is fine where it is):

try {
  graph->copy_inputs(inputs);
  graph->update_symints_from_inputs(inputs);
  graph->propagate_resize();
} catch (const std::exception& e) {
  ET_LOG(Error, "WebGPU input copy/symint refresh failed: %s", e.what());
  return Error::Internal;
}

Fix this →


🟡 Out-of-range indices read out of bounds (silent wrong output)

In the WGSL, token = u32(t_indices[indices_idx]) is used directly to index t_weight and t_scales with no clamp against vocab size. A token ≥ vocab (or a negative int32 that wraps to a huge u32) indexes past the weight/scales buffers. WebGPU bounds-checking will return 0 rather than crash, so the failure is silent. This mirrors the Vulkan reference's trust in valid indices, so it may be acceptable — but worth a one-line comment noting indices are assumed in-range, since unlike a CPU gather there's no bounds error to surface the mistake.

🟡 total_blocks narrowed to uint32 without a guard

total_blocks is computed as uint64_t but stored into the uint32_t EmbeddingParams::total_blocks and passed to compute_1d_workgroup_count via a static_cast<uint32_t>. There's no guard that it fits in 32 bits. All current test shapes are tiny (256 blocks), so this is theoretical, but a > UINT32_MAX check alongside the other guards would keep the fail-loud invariant consistent for large gathers.

🟢 Minor: per-execute allocation in the narrowing path

copy_inputs allocates a fresh std::vector<int32_t> narrowed(numel) on every execute for the int64→int32 case. Fine for correctness; if the embedding-gather is on a hot path you may later want a reusable scratch buffer. Not blocking.


Things that look good

  • EmbeddingParams static_assert(sizeof == 32) + explicit _pad matches std140 uniform alignment — nice.
  • Validation happens before any wgpuDeviceCreate* call, so failures don't leak GPU objects.
  • The Python test computes the golden two independent ways (et_vk reference vs. a hand-rolled dequant+F.embedding) in test_golden_matches_eager, so a buggy golden can't mask a buggy kernel — good defensive testing.
  • embedding_q4gsw_wgsl.h carries the wgsl-sha256 and DO NOT EDIT banner; assuming gen_wgsl_headers.py regenerated it, the embedded copy matches the .wgsl source (I diffed them — identical).
  • The native C++ test exercises the int64→int32 narrowing path end-to-end by feeding an int64 tensor at the boundary.

Question

  • The native test (test_webgpu_native.cpp) only runs the small shape (embed=64). The export helper supports llama1b (embed=2048, group=64) and the Python test_golden_matches_eager covers it, but the native GPU kernel is never run at that shape. Was that intentional (keep CI light), or worth adding so the per-group scale indexing is validated on-device too?

[ghstack-poisoned]
@JulianCloudNTH

Copy link
Copy Markdown
Contributor Author

@claude review

@claude

claude Bot commented Jun 15, 2026

Copy link
Copy Markdown

[ghstack-poisoned]
[ghstack-poisoned]
@meta-codesync meta-codesync Bot merged commit f3d16c3 into gh/JulianCloudNTH/25/base Jun 22, 2026
175 of 183 checks passed
@meta-codesync meta-codesync Bot deleted the gh/JulianCloudNTH/25/head branch June 22, 2026 06:46
@meta-codesync meta-codesync Bot temporarily deployed to cherry-pick-bot June 22, 2026 06:46 Inactive
JulianCloudNTH added a commit that referenced this pull request Jun 22, 2026
#20292)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.15.0)
(oldest at bottom):
* __->__ #20292
* #20265
* #20290
* #20264
* #20289
* #20263



Test suite for the `et_vk.prepack` constant-materialization op, split
into its
own diff (op below, tests above) per the per-op test-split convention.

The prepack op is how a serialized constant becomes a GPU tensor: the
constant
arrives as a CPU-side reference (sizes + a pointer into the .pte bytes),
and the
prepack node is the sole materialization — one CPU->GPU transfer
straight into
the consumer's buffer. The model `M(x) = x + w` (w a constant) routes
`w`
through a prepack node, so the delegate must run the materialization for
the
output to equal `x + w` rather than `x + 0`.
@exported-using-ghexport

Differential Revision:
[D108678631](https://our.internmc.facebook.com/intern/diff/D108678631/)

Differential Revision:
[D108678631](https://our.internmc.facebook.com/intern/diff/D108678631)
JulianCloudNTH added a commit that referenced this pull request Jun 22, 2026
…tric quantized embedding) (#20414)

This PR was created by the merge bot to help merge the original PR into
the main branch.
ghstack PR number: #20263 by
@JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments,
and reviews
ghstack PR base:
https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/25/base
ghstack PR head:
https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/25/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head:
https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/25/orig

@diff-train-skip-merge

---------

Co-authored-by: Julian Ng-Thow-Hing <juliannth@meta.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants