[Cherry-Pick] [BugFix] fix mtp split kv attetion (#5921)

* [BugFix] fix mtp split kv attetion

* clean code

* clean code
This commit is contained in:
lizhenyun01
2026-01-07 19:50:02 +08:00
committed by GitHub
parent 7cdffced2d
commit 0b630fc3c1
5 changed files with 170 additions and 273 deletions
@@ -2414,7 +2414,8 @@ template <typename T,
uint32_t bdy,
uint32_t HEAD_DIM,
typename OutT = T,
bool ENABLE_PREFILL = true>
bool ENABLE_PREFILL = true,
bool DECODE_ONLY = true>
__global__ void merge_multi_chunks_v2_kernel(
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
@@ -2458,15 +2459,16 @@ __global__ void merge_multi_chunks_v2_kernel(
if (ENABLE_PREFILL) {
seq_len_kv += seq_len_q;
if (seq_len_kv == 0) continue;
const int seq_len_enc = seq_lens_encoder[bid];
if (seq_len_enc <= 0) {
continue;
}
} else {
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
}
if constexpr (DECODE_ONLY) {
const int seq_len_enc = seq_lens_encoder[bid];
if (seq_len_enc > 0) {
continue;
}
}
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;