[Kernels] Use causal_conv1d kernel from Tri Dao for 10-90% speedup #6625
gabrieldemarmiesse wants to merge 2 commits into
BEGIN_PUBLIC
[Kernels][GPU] Add causal_conv1d forward GPU benchmark

Adds a kernel-time benchmark for the channel-first causal_conv1d forward GPU kernel (state_space). It mirrors the validated test launch config (kNThreads=128, kNElts=4), times the kernel via the Bench/Bencher iter_custom harness, and reports achieved memory bandwidth (the op is memory-bound), computed from 2 * batch * dim * seqlen * sizeof(dtype) bytes moved.

dtype and conv width are compile-time defines (default bfloat16, width=4 to match the common Mamba config); batch, dim, seqlen and the SiLU activation flag are runtime args.

Since causal_conv1d lives in //max:state_space, which the globbed GPU benchmark deps don't include, the target is declared explicitly (and excluded from the glob), following the existing bench_conv2d/bench_conv3d pattern.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Gabriel <gabriel@kyutai.org>
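As a rough illustration of the bandwidth figure the benchmark reports, the sketch below divides the byte count from the commit message by a kernel time. The function name, the shape, and the timing are hypothetical and are not taken from the benchmark source; the formula counts one read of x and one write of the output and ignores weight and bias traffic.

```python
def achieved_bandwidth_gbs(batch, dim, seqlen, dtype_bytes, kernel_seconds):
    """Return achieved bandwidth in GB/s for one forward call.

    Bytes moved follow the commit message: 2 * batch * dim * seqlen * sizeof(dtype),
    i.e. one read of the input and one write of the output.
    """
    bytes_moved = 2 * batch * dim * seqlen * dtype_bytes
    return bytes_moved / kernel_seconds / 1e9


# Hypothetical shape and timing, chosen only to exercise the formula (bf16 = 2 bytes).
example = achieved_bandwidth_gbs(
    batch=8, dim=4096, seqlen=2048, dtype_bytes=2, kernel_seconds=1.6e-4
)
print(f"{example:.1f} GB/s")
```

Comparing this number against the device's HBM peak gives the "percent of peak" figures quoted for the kernels.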
BEGIN_PUBLIC
[Kernels][GPU] Vectorize channel-first causal_conv1d forward kernel

Replaces the channel-first forward GPU kernel (causal_conv1d_channel_first_fwd_gpu) with the Tri Dao-style vectorized algorithm: grid = (dim, batch) so one block per (B, C) walks the whole seqlen (loading weight/bias once), 16-byte LDG (kNElts = 16/sizeof(dtype), i.e. 8 for bf16/fp16 vs the old fixed 4), and a shared-memory ring-buffer for the (W-1) halo shared across chunks.

The old kernel loaded x element-by-element and never emitted 128-bit vector loads; on an H100 (bf16) it sustained ~1050 GB/s (~52% of HBM peak). The new kernel sustains ~1700-1875 GB/s (~85-90% of peak), about 1.6-1.9x faster across representative shapes.

Correctness is preserved for arbitrary strides: a block-uniform runtime check selects the 16-byte vector path only when the layout is contiguous-in-L and aligned, and falls back to scalar strided loads/stores otherwise (the non-contiguous layout the Mamba pipeline passes after permute+split). Bias is optional with the existing safety check, SiLU stays a runtime flag, widths 1-4 and mixed dtypes are supported. The kernel signature is unchanged; only the launch config (grid shape, kNElts) changes at the call sites (op wrapper, test, benchmark).

Tests in test_causal_conv1d.mojo are extended to cover bf16 (the production dtype), multi-chunk seqlens, and non-contiguous-L inputs, all validated against the unchanged CPU reference. Channel-last, CPU, update, and varlen kernels are untouched.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Gabriel <gabriel@kyutai.org>
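To make the chunking and halo idea concrete, here is a minimal pure-Python sketch for a single (batch, channel) row. It is not the Mojo kernel: all function names are hypothetical, the tap ordering (the last weight multiplies the current input) is an assumption based on the usual causal_conv1d convention, and a plain list stands in for the shared-memory ring buffer.

```python
import math


def elts_per_vector_load(dtype_bytes):
    """Elements covered by one 16-byte load: 8 for bf16/fp16, 4 for fp32."""
    return 16 // dtype_bytes


def silu(v):
    return v / (1.0 + math.exp(-v))


def causal_conv1d_ref(x, weight, bias=None, use_silu=False):
    """Naive reference: out[t] = bias + sum_k weight[k] * x[t - (W-1) + k]."""
    width = len(weight)
    out = []
    for t in range(len(x)):
        acc = bias if bias is not None else 0.0
        for k in range(width):
            src = t - (width - 1) + k
            if src >= 0:  # positions before the sequence start count as zero
                acc += weight[k] * x[src]
        out.append(silu(acc) if use_silu else acc)
    return out


def causal_conv1d_chunked(x, weight, bias=None, use_silu=False, chunk=8):
    """Same result, computed chunk by chunk with a (W-1)-element halo."""
    width = len(weight)
    halo = [0.0] * (width - 1)  # last W-1 inputs of the previous chunk
    out = []
    for start in range(0, len(x), chunk):
        block = x[start:start + chunk]
        window = halo + block  # halo supplies the inputs left of this chunk
        for i in range(len(block)):
            acc = bias if bias is not None else 0.0
            for k in range(width):
                acc += weight[k] * window[i + k]
            out.append(silu(acc) if use_silu else acc)
        if width > 1:
            halo = window[-(width - 1):]  # carry the tail into the next chunk
    return out


# Illustrative inputs: 19 elements so the last chunk is partial.
x_demo = [float(i) for i in range(1, 20)]
w_demo = [0.5, -1.0, 2.0, 0.25]
ref_out = causal_conv1d_ref(x_demo, w_demo, bias=0.1, use_silu=True)
chunked_out = causal_conv1d_chunked(x_demo, w_demo, bias=0.1, use_silu=True, chunk=8)
```

In the kernel described above, each chunk would presumably correspond to the elements a block loads per iteration (threads times kNElts), which is why the halo has to be handed from one chunk to the next.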
abe5534 to e508721
Thanks @ulmentflam for the context. Related to that, I have a question: a few functions were added but are never called anywhere. Do you have more information about what they're used for? I don't see any call site. Maybe they were part of a draft that was never used?
Yes, they are not needed for Mamba1, and it seems they landed without their tests. They are there for full parity with the forward-pass kernels of causal-conv1d. I have an internal repo with a full kernel implementation of the causal-conv1d library, and these functions were taken from it for Mamba support. Let me grab what I can surface from those tests. It is correct that the update kernels are all that Mamba needs, which is the channel_first_fwd suite.
Thanks, I'll let you do the cleanup and wiring, since I don't know which functions I should touch. Many of those functions have overlapping capabilities, and I guess maintainers don't want to maintain functions that are used only in unit tests. Ping me when they're wired to the rest of the codebase, and I can start from there! There is also massive code duplication, so if we could address that too, that would be great!
Will do. Most of the duplication was forced by what needed to be included at compile time and by issues with passing pointers, which I believe have been resolved.
I just did a cleanup pass on the selective scan kernel with some new (or new to me) syntax features, and I will do one here as well.
