mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-30 04:10:05 +08:00
[Cherry-Pick] [BugFix] fix mtp split kv attention (#5921)
* [BugFix] fix mtp split kv attention * clean code * clean code
This commit is contained in:
@@ -2414,7 +2414,8 @@ template <typename T,
|
||||
uint32_t bdy,
|
||||
uint32_t HEAD_DIM,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true>
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool DECODE_ONLY = true>
|
||||
__global__ void merge_multi_chunks_v2_kernel(
|
||||
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
|
||||
// head_dim]
|
||||
@@ -2458,15 +2459,16 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
if (ENABLE_PREFILL) {
|
||||
seq_len_kv += seq_len_q;
|
||||
if (seq_len_kv == 0) continue;
|
||||
|
||||
const int seq_len_enc = seq_lens_encoder[bid];
|
||||
if (seq_len_enc <= 0) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (seq_len_kv == 0) continue;
|
||||
seq_len_kv += seq_len_q;
|
||||
}
|
||||
if constexpr (DECODE_ONLY) {
|
||||
const int seq_len_enc = seq_lens_encoder[bid];
|
||||
if (seq_len_enc > 0) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
continue;
|
||||
|
||||
Reference in New Issue
Block a user