[Kernels] Use causal_conv1d kernel from Tri Dao for 10-90% speedup by gabrieldemarmiesse · Pull Request #6625 · modular/modular · GitHub

[Kernels] Use causal_conv1d kernel from Tri Dao for 10-90% speedup #6625

Draft
gabrieldemarmiesse wants to merge 2 commits into modular:main from gabrieldemarmiesse:causal-conv1d-fwd-kernel

Conversation

@gabrieldemarmiesse

Contributor

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • Performance improvement (includes benchmark results below)
  • Documentation update
  • New feature or public API (requires prior proposal or issue approval)
  • Refactor / internal cleanup (no user-visible change)
  • Build, CI, or tooling change

Motivation

What changed

Testing

Checklist

  • The linked issue above has been reviewed by a maintainer and is
    agreed-upon, or this is a trivial fix that does not need prior
    approval
  • PR is small and focused — I've split larger changes into a sequence of
    smaller PRs where possible (see
    pull request sizes)
  • I ran ./bazelw run format to format my changes
  • I added or updated tests to cover my changes
  • If AI tools assisted with this contribution, I have included an
    Assisted-by: trailer in my commit message or this PR description (see
    AI Tool Use Policy)

BEGIN_PUBLIC
[Kernels][GPU] Add causal_conv1d forward GPU benchmark

Adds a kernel-time benchmark for the channel-first causal_conv1d forward
GPU kernel (state_space). It mirrors the validated test launch config
(kNThreads=128, kNElts=4), times the kernel via the Bench/Bencher
iter_custom harness, and reports achieved memory bandwidth (the op is
memory-bound), counting 2 * batch * dim * seqlen * sizeof(dtype) bytes
moved per launch.

dtype and conv width are compile-time defines (default bfloat16, width=4
to match the common Mamba config); batch, dim, seqlen and the SiLU
activation flag are runtime args.

Since causal_conv1d lives in //max:state_space, which the globbed GPU
benchmark deps don't include, the target is declared explicitly (and
excluded from the glob) following the existing bench_conv2d/bench_conv3d
pattern.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Gabriel <gabriel@kyutai.org>
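
As a rough illustration of the bandwidth figure this benchmark reports, the sketch below converts a measured kernel time into achieved GB/s using the byte count from the commit message. The function name, the example shape, and the timing value are hypothetical and are not taken from the benchmark source. The factor of 2 is assumed to stand for one read of the input and one write of the output.

```python
def achieved_bandwidth_gbs(batch, dim, seqlen, dtype_bytes, kernel_seconds):
    """Return achieved bandwidth in GB/s (1 GB = 1e9 bytes) for one launch.

    Assumption: the kernel moves 2 * batch * dim * seqlen elements,
    i.e. it reads the input once and writes the output once.
    """
    if kernel_seconds <= 0:
        raise ValueError("kernel_seconds must be positive")
    bytes_moved = 2 * batch * dim * seqlen * dtype_bytes
    return bytes_moved / kernel_seconds / 1e9


# Hypothetical shape and timing, chosen only to show the arithmetic:
# bf16 (2 bytes per element), 200 microseconds of kernel time.
example_gbs = achieved_bandwidth_gbs(
    batch=8, dim=4096, seqlen=2048, dtype_bytes=2, kernel_seconds=200e-6
)
```

Dividing the result by the device's peak HBM bandwidth gives the "percent of peak" style of number quoted later in this PR.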
@gabrieldemarmiesse marked this pull request as ready for review May 29, 2026 12:36
@gabrieldemarmiesse requested a review from a team as a code owner May 29, 2026 12:36
@gabrieldemarmiesse changed the title from "[Kernels][GPU] Add causal_conv1d forward GPU benchmark" to "[Kernels] Use causal_conv1d kernel from Tri Dao for 10-90% speedup" May 29, 2026
@gabrieldemarmiesse marked this pull request as draft May 29, 2026 12:38
BEGIN_PUBLIC
[Kernels][GPU] Vectorize channel-first causal_conv1d forward kernel

Replaces the channel-first forward GPU kernel
(causal_conv1d_channel_first_fwd_gpu) with the Tri Dao-style vectorized
algorithm: grid = (dim, batch) so one block per (B, C) walks the whole
seqlen (loading weight/bias once), 16-byte LDG (kNElts = 16/sizeof(dtype),
i.e. 8 for bf16/fp16 vs the old fixed 4), and a shared-memory ring-buffer
for the (W-1) halo shared across chunks.

The old kernel loaded x element-by-element and never emitted 128-bit
vector loads; on an H100 (bf16) it sustained ~1050 GB/s (~52% of HBM
peak). The new kernel sustains ~1700-1875 GB/s (~85-90% of peak) — about
1.6-1.9x faster across representative shapes.

Correctness is preserved for arbitrary strides: a block-uniform runtime
check selects the 16-byte vector path only when the layout is
contiguous-in-L and aligned, and falls back to scalar strided
loads/stores otherwise (the non-contiguous layout the Mamba pipeline
passes after permute+split). Bias is optional with the existing safety
check, SiLU stays a runtime flag, widths 1-4 and mixed dtypes are
supported. The kernel signature is unchanged; only the launch config
(grid shape, kNElts) changes at the call sites (op wrapper, test,
benchmark).

Tests in test_causal_conv1d.mojo are extended to cover bf16 (the
production dtype), multi-chunk seqlens, and non-contiguous-L inputs, all
validated against the unchanged CPU reference.

Channel-last, CPU, update, and varlen kernels are untouched.
END_PUBLIC

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Gabriel <gabriel@kyutai.org>
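
The chunked walk with a (W-1)-element halo described in the commit message above can be sketched on the CPU as follows. This is a minimal pure-Python sketch under stated assumptions, not the Mojo kernel. The function names and the chunk size are hypothetical, and a plain list stands in for the shared-memory ring buffer. The weight ordering (the last tap multiplies the current sample) is also an assumption. Each call handles a single (batch, channel) row, which is what one block would process.

```python
import math


def elts_per_vector(dtype_bytes, vector_bytes=16):
    """Elements per 16-byte load, i.e. kNElts = 16 / sizeof(dtype)."""
    return vector_bytes // dtype_bytes


def silu(v):
    return v / (1.0 + math.exp(-v))


def causal_conv1d_ref(x, weight, bias=None, use_silu=False):
    """Direct reference: out[t] = bias + sum_k weight[k] * x[t-(W-1)+k],
    treating samples before the start of the sequence as zero."""
    width = len(weight)
    out = []
    for t in range(len(x)):
        acc = bias if bias is not None else 0.0
        for k in range(width):
            src = t - (width - 1) + k
            if src >= 0:
                acc += weight[k] * x[src]
        out.append(silu(acc) if use_silu else acc)
    return out


def causal_conv1d_chunked(x, weight, bias=None, use_silu=False, chunk=8):
    """Same result, computed chunk by chunk while carrying a halo of the
    last W-1 inputs from one chunk to the next."""
    width = len(weight)
    halo = [0.0] * (width - 1)  # zeros stand in for x[t] with t < 0
    out = []
    for start in range(0, len(x), chunk):
        block = x[start:start + chunk]
        window = halo + block  # previous tail followed by this chunk
        for i in range(len(block)):
            acc = bias if bias is not None else 0.0
            for k in range(width):
                acc += weight[k] * window[i + k]
            out.append(silu(acc) if use_silu else acc)
        # Keep the last W-1 inputs for the next chunk (empty when W == 1).
        halo = window[len(window) - (width - 1):]
    return out
```

Comparing `causal_conv1d_chunked` against `causal_conv1d_ref` for several chunk sizes, including sizes smaller than W-1, exercises the same multi-chunk case that the extended tests target.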
@gabrieldemarmiesse force-pushed the causal-conv1d-fwd-kernel branch from abe5534 to e508721 May 29, 2026 13:18
@ulmentflam

Contributor

@gabrieldemarmiesse

Contributor Author

Thanks @ulmentflam for the context. I have a related question. A few functions were added but are never called anywhere:

  - causal_conv1d_channel_last_fwd_cpu
  - causal_conv1d_channel_last_fwd_cpu_no_bias
  - causal_conv1d_channel_last_fwd_cpu_with_seq_idx
  - causal_conv1d_channel_last_fwd_cpu_no_bias_with_seq_idx
  - causal_conv1d_channel_last_fwd_gpu
  - causal_conv1d_channel_last_fwd_gpu_no_bias
  - causal_conv1d_channel_last_fwd_gpu_with_seq_idx
  - causal_conv1d_channel_last_fwd_gpu_no_bias_with_seq_idx

  - causal_conv1d_channel_first_fwd_cpu_no_bias
  - causal_conv1d_channel_first_fwd_gpu_no_bias
  - causal_conv1d_channel_first_fwd_gpu_with_seq_idx
  - causal_conv1d_channel_first_fwd_gpu_no_bias_with_seq_idx

  - causal_conv1d_update_cpu_no_bias (called only in test/state_space + test/gpu/state_space)
  - causal_conv1d_update_gpu_no_bias (called only in test/gpu/state_space)

Do you have more information about what they're used for? I don't see any call sites. Maybe they were part of a draft that was never used?

@ulmentflam

Contributor

Yes, they are not needed for Mamba1, and it seems they landed without their tests. They are there for full parity with the forward-pass kernels of causal-conv1d. I have an internal repo with a full kernel implementation of the causal-conv1d library, from which these functions were taken for Mamba support. Let me grab what I can surface from those tests. It is correct that the update kernels are all that Mamba needs, which is the channel_first_fwd suite.

@gabrieldemarmiesse

gabrieldemarmiesse commented May 29, 2026

Contributor Author

Thanks, I'll let you do the cleanup and wiring, since I don't know which functions I should touch. Many of those functions have overlapping capabilities, and I guess maintainers don't want to maintain functions that are used only in unit tests. Ping me when they're wired to the rest of the codebase, and I can start from there! There is also massive code duplication, so it would be great if we could address that too.

@ulmentflam

Contributor

Will do. Most of the duplication was forced by what needed to be included at compile time and by issues with passing pointers, which I believe have been resolved.

@ulmentflam

Contributor

I just did a cleanup pass on the selective scan kernel with some new (or new to me) syntax features, and I will do one here as well.

@ulmentflam

Contributor
