{{ message }}
[ROCm][Perf][Bugfix] DSv4 indexer: use platform FP8 dtype (fnuz) for Q-quant on gfx942#46730
Open
akii96 wants to merge 1 commit into
Open
[ROCm][Perf][Bugfix] DSv4 indexer: use platform FP8 dtype (fnuz) for Q-quant on gfx942#46730akii96 wants to merge 1 commit into
akii96 wants to merge 1 commit into
Conversation
Signed-off-by: Aakif Nawaz <aakif.nawaz@amd.com>
db4e561 to
0309e6c
Compare
Contributor
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.

Motivation
On gfx942 the DeepSeek-V4 Flash indexer quantizes Q and K to different FP8 types. K already uses the platform type (e4m3fnuz on gfx942, via current_platform.fp8_dtype()), but the fused RoPE+quant kernel in fused_indexer_q.py hardcodes Q to e4m3fn. The FP8 logits kernel then gets fnuz K with fn Q and falls back to a mixed-dtype path on every call. This change derives Q's type from current_platform.fp8_dtype() as well, so on gfx942 both are fnuz and the logits kernel runs its native fnuz/fnuz path.
This is gfx942-specific by design. is_fp8_fnuz() is true only for gfx94x, so on gfx950 the platform type is OCP e4m3fn and Q/K are already fn/fn. Nothing changes there, and the NVIDIA cutedsl and MXFP4 paths are untouched (the two new kernel constexprs are defaulted).
The fnuz quant max is set to 224.0 to match get_fp8_min_max() in quant_utils.py, the value the K cache already uses.
Results
End to end serving of DeepSeek-V4 Flash (TP4, gfx942 / MI300), mean TTFT, prefill-heavy (OSL=27, concurrency 4):
Bonus correctness check: the existing kernel test tests/kernels/test_fused_indexer_q_rope_quant.py matches the unfused reference bit for bit on 9 of 10 gfx942 shapes. The one miss is 3 of 8,380,416 values at float32 / 1023 tokens, from fused vs unfused RoPE rounding at FP8 boundaries, not a dtype or scale error.
Note
The same indexer dtype handling was included in the larger ROCm enablement PRs #41601 and #42033, both stalled on rebase since May. This is a minimal, standalone version of just that fix for the gfx942 path; main as of today still hardcodes Q to e4m3fn.
Repro: serve + bench commands (DeepSeek-V4 Flash, TP4, gfx942)
Serve (4x MI325):
HIP_VISIBLE_DEVICES=0,1,2,3 VLLM_ROCM_USE_AITER=1 \ vllm serve deepseek-ai/DeepSeek-V4-Flash \ --tensor-parallel-size 4 \ --gpu-memory-utilization 0.85 \ --kv-cache-dtype fp8_e4m3 \ --block-size 256 \ --max-model-len 132096 \ --max-num-batched-tokens 16384 \ --max-num-seqs 156 \ --async-scheduling \ --no-enable-prefix-caching \ --tokenizer-mode deepseek_v4 \ --reasoning-parser deepseek_v4 \ --tool-call-parser deepseek_v4 \ --enable-auto-tool-choice \ --disable-log-stats \ --host 0.0.0.0 --port 8000 vllm bench serve --backend vllm \ --model deepseek-ai/DeepSeek-V4-Flash \ --host localhost --port 8000 \ --dataset-name random --ignore-eos --trust-remote-code \ --seed 5678 \ --random-input-len 8192 \ --random-output-len 27 \ --max-concurrency 4 --num-prompts 12 --num-warmups 4 # 32K ISL: same command with --random-input-len 32768