[Optimization][Feature] Supports multiple batches of DSK-DSA. (#6930)

* support DSA_MULTI_BATCH

* update test topk

* update dsk-dsa
This commit is contained in:
AIbin
2026-03-20 15:59:22 +08:00
committed by GitHub
parent 1c38da2118
commit bf7e2424d0
5 changed files with 262 additions and 61 deletions
+125 -10
View File
@@ -467,6 +467,98 @@ def compute_slot_mapping(
return slot_mapping.cast(paddle.int64)
import triton
import triton.language as tl
@triton.jit()
def extract_kernel(
    q,                 # flattened q buffer; row i starts at offset i * HIDDEN_DIM
    weight,            # flattened weight buffer; row i starts at offset i * WEIGHT_DIM
    cu_seqlens_q,      # per-batch cumulative query offsets (token index where each batch starts)
    seq_lens_encoder,  # NOTE(review): never read by this kernel; kept for host-side signature parity
    seq_lens_decoder,  # per-batch decoder cache length; <= 0 marks a non-decoder batch slot
    output,            # dense destination for the extracted q row, one slot per batch
    out_weight,        # dense destination for the extracted weight row, one slot per batch
    cache_seqlens,     # written with decoder cache length + 1 for active decoder batches
    HIDDEN_DIM: tl.constexpr,   # q row width in elements
    WEIGHT_DIM: tl.constexpr,   # weight row width in elements
    BLOCK_SIZE: tl.constexpr,   # power-of-two >= HIDDEN_DIM (tl.arange needs a pow2 extent)
):
    # One program per batch slot (grid = (max_bsz,)).
    batch_id = tl.program_id(axis=0)
    cache_kv_len = tl.load(seq_lens_decoder + batch_id)
    # This batch is not in the decode phase, so there is nothing to extract.
    if cache_kv_len <= 0:
        return
    # Token index of this batch's first (and, for decode, only) query token.
    cu_len_this_batch = tl.load(cu_seqlens_q + batch_id)
    read_offsets = tl.arange(0, BLOCK_SIZE)
    # NOTE(review): tl.arange requires a power-of-two extent; this assumes
    # WEIGHT_DIM is a power of two -- confirm at the call site.
    read_weight_offsets = tl.arange(0, WEIGHT_DIM)
    # Advance the base pointers to this batch's token row.
    q += cu_len_this_batch * HIDDEN_DIM
    weight += cu_len_this_batch * WEIGHT_DIM
    # Masked loads: BLOCK_SIZE may exceed HIDDEN_DIM, so clip the tail.
    row_data = tl.load(q + read_offsets, mask=read_offsets < HIDDEN_DIM)
    weight_row_data = tl.load(weight + read_weight_offsets, mask=read_weight_offsets < WEIGHT_DIM)
    # Write into the dense per-batch output slot (slot index == batch_id).
    output += batch_id * HIDDEN_DIM
    out_weight += batch_id * WEIGHT_DIM
    tl.store(output + read_offsets, row_data, mask=read_offsets < HIDDEN_DIM)
    tl.store(out_weight + read_weight_offsets, weight_row_data, mask=read_weight_offsets < WEIGHT_DIM)
    # Effective KV length for this step: cached tokens plus the current one.
    tl.store(cache_seqlens + batch_id, cache_kv_len + 1)
def extract_decoder_token_from_q(
    q: paddle.Tensor,
    weight: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
):
    """Gather each decoder batch's single query token into dense per-batch buffers.

    For every batch slot with ``seq_lens_decoder[i] > 0`` the kernel copies the
    q row and weight row at token offset ``cu_seqlens_q[i]`` into slot ``i`` of
    the outputs, and records the effective KV length for this step.

    Args:
        q: ``[num_tokens, hidden_dim]`` flattened query activations.
        weight: ``[num_tokens, weight_dim]`` per-token weights.
        cu_seqlens_q: cumulative query offsets, one entry per batch slot.
        seq_lens_encoder: unused by the kernel; kept for signature parity.
        seq_lens_decoder: ``[max_bsz]`` decoder cache lengths; ``<= 0`` marks
            inactive (non-decoder) slots.

    Returns:
        A tuple ``(out, out_weight, cache_seqlens)`` where ``out`` is
        ``[max_bsz, hidden_dim]`` and ``out_weight`` is ``[max_bsz, weight_dim]``
        (zero-filled for inactive slots), and ``cache_seqlens[i]`` is
        ``seq_lens_decoder[i] + 1`` for active decoder slots, ``0`` otherwise.
    """
    assert len(q.shape) == 2
    assert len(weight.shape) == 2
    assert len(cu_seqlens_q.shape) == 1
    assert len(seq_lens_encoder.shape) == 1
    assert len(seq_lens_decoder.shape) == 1

    max_bsz = seq_lens_decoder.shape[0]
    hidden_dim = q.shape[-1]
    weight_dim = weight.shape[-1]

    # Pre-zero the outputs so inactive batch slots stay well-defined.
    out = paddle.zeros([max_bsz, hidden_dim], dtype=q.dtype)
    out_weight = paddle.zeros([max_bsz, weight_dim], dtype=weight.dtype)
    cache_seqlens = paddle.zeros_like(seq_lens_decoder)

    # tl.arange requires a power-of-two extent, so round hidden_dim up.
    # NOTE(review): weight_dim is passed unrounded as a tl.arange extent inside
    # the kernel and is therefore assumed to be a power of two -- confirm.
    BLOCK_SIZE = triton.next_power_of_2(hidden_dim)
    grid = (max_bsz,)  # one program per batch slot
    extract_kernel[grid](
        q,
        weight,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        out,
        out_weight,
        cache_seqlens,
        hidden_dim,
        weight_dim,
        BLOCK_SIZE,
    )
    return out, out_weight, cache_seqlens
class Indexer(nn.Layer):
def __init__(
self,
@@ -587,7 +679,17 @@ class Indexer(nn.Layer):
# ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(),forward_meta.attn_mask_offsets[1::2].contiguous()
num_tokens = q_fp8.shape[0]
ks = paddle.zeros(num_tokens, dtype=paddle.int32)
ke = paddle.arange(num_tokens, dtype=paddle.int32) + 1 # + (seq_len_kv - seq_len)
ks_topk = paddle.zeros(num_tokens, dtype=paddle.int32)
ke = paddle.zeros(num_tokens, dtype=paddle.int32)
bsz = forward_meta.seq_lens_this_time.shape[0]
for i in range(bsz):
if forward_meta.seq_lens_encoder[i] > 0:
token_start_k = forward_meta.cu_seqlens_k[i]
token_end_k = forward_meta.cu_seqlens_k[i + 1]
ks[token_start_k:token_end_k] = forward_meta.cu_seqlens_k[i]
ke[token_start_k:token_end_k] = paddle.arange(token_start_k, token_end_k, dtype=paddle.int32) + 1
max_seqlen_k = (ke - ks).max().item()
logits = deep_gemm.fp8_mqa_logits(
@@ -604,26 +706,34 @@ class Indexer(nn.Layer):
radix_topk_ragged_transform(
logits.contiguous(),
indexer_top_k,
ks, # self.offsets,
ke, # mask.contiguous(),#self.lengths,
ks_topk, # self.offsets,
ke - ks + 1, # mask.contiguous(),#self.lengths,
None, # forward_meta.seq_lens_decoder,
None, # forward_meta.batch_id_per_token,
None,
None, # self.buffer
0,
self.index_topk,
1,
)
if forward_meta.max_len_tensor_cpu[2]:
seq_len_kv = forward_meta.seq_lens_decoder + forward_meta.seq_lens_this_time
decoder_q, decoder_weight, cache_seqlens = extract_decoder_token_from_q(
q_fp8.reshape(-1, self.index_n_heads * self.index_head_dim),
weights,
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
)
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(seq_len_kv, 64, deep_gemm.get_num_sms())
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(cache_seqlens, 64, deep_gemm.get_num_sms())
logits = deep_gemm.fp8_paged_mqa_logits(
q_fp8.unsqueeze(1),
decoder_q.reshape(-1, 1, self.index_n_heads, self.index_head_dim),
self.indexer_cache.unsqueeze(2),
weights,
seq_len_kv,
decoder_weight,
cache_seqlens,
forward_meta.block_tables,
schedule_metadata,
self.max_model_len,
@@ -635,11 +745,13 @@ class Indexer(nn.Layer):
indexer_top_k,
self.offsets, # unused
self.lengths, # unused
seq_len_kv,
cache_seqlens,
forward_meta.batch_id_per_token,
forward_meta.block_tables,
None, # self.buffer
forward_meta.block_tables.shape[1],
self.index_topk,
1,
1, # q_head
)
return indexer_top_k
@@ -673,6 +785,9 @@ class DeepseekV32DSAAttention(nn.Layer):
self.attn_softmax_scale = self.qk_head_dim**-0.5
self.rope_theta = fd_config.model_config.rope_theta
if fd_config.model_config.model_type == "glm_moe_dsa":
self.rope_theta = fd_config.model_config.rope_parameters["rope_theta"]
self.rms_norm_eps = fd_config.model_config.rms_norm_eps
assert self.q_lora_rank is not None, "self.q_lora_rank is None, Please Check your config."