webgpu: merge batchA into M dimension when batchB==1 #28197
xhcao wants to merge 1 commit into microsoft:main
Conversation
When M is small and batchA is large, each tile contains invalid elements; merging batchA into the M dimension reduces the workgroup count.
Pull request overview
This PR updates the WebGPU MatMul implementation to flatten A’s batch dimensions into the effective M dimension when B has no batching (batchB==1), aiming to reduce workgroup overhead for cases with small M and large batchA. It also adds WebGPU-specific regression tests for additional 3D batched MatMul shapes.
Changes:
- WebGPU MatMul: reshape A/B and treat the output as `{1, batchA*M, N}` when `batchA != 1 && batchB == 1` (applies to both the generic and Intel subgroup paths).
- Add WebGPU-only MatMul test cases covering 3D inputs with `batchA=3, M=2` and `N` in `{3, 4}`.
Reviewed changes
Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.
```diff
  // when B is a matrix (batch is 1), merge batchA into M dimension to improve performance for small M cases.
  if (batchA != 1 && batchB == 1) {
    // dimensions of A: [1,`batchA`, M, K]
    int64_t batchAndM = a_shape.SizeToDimension(a_shape.NumDimensions() - 1);
    TensorShapeVector dims_a = {1, batchAndM, helper.K()};
    // dimensions of B: [1,K,N]
    TensorShapeVector dims_b = {1, helper.K(), helper.N()};
    a_shape = TensorShape(dims_a);
    b_shape = TensorShape(dims_b);
-   output_shape = {1, batchA, helper.N()};
+   output_shape = {1, batchAndM, helper.N()};
  }
```
This reshape path is now enabled for all cases where batchA != 1 && batchB == 1, but the comment/PR motivation suggests it’s intended specifically for small-M scenarios. As written, it will also trigger for large-M workloads (e.g., other callers that reuse ComputeMatMul) and broadly changes dispatch geometry. Consider gating this with an explicit heuristic (e.g., helper.M() below a threshold and/or batchA above a threshold), or update the comment to clarify it’s intentionally unconditional.
```cpp
TensorShapeVector dims_a = {1, batchA, helper.K()};
// when B is a matrix (batch is 1), merge batchA into M dimension to improve performance for small M cases.
if (batchA != 1 && batchB == 1) {
  // dimensions of A: [1,`batchA`, M, K]
```
The shape comment is inaccurate: dims_a is set to {1, batchAndM, K} (flattening all leading dims including M), not [1, batchA, M, K]. Please update the comment to reflect the actual reshape so it’s clear which dimensions are being merged.
