Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Optimization][Feature] Supports multiple batches of DSK-DSA. (#6930)

* support DSA_MUTI_BATCH
* update test topk
* update dsk-dsa
@@ -467,6 +467,98 @@ def compute_slot_mapping(
    return slot_mapping.cast(paddle.int64)


import triton
import triton.language as tl


@triton.jit()
def extract_kernel(
    q,
    weight,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    output,
    out_weight,
    cache_seqlens,
    HIDDEN_DIM: tl.constexpr,
    WEIGHT_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    batch_id = tl.program_id(axis=0)
    cache_kv_len = tl.load(seq_lens_decoder + batch_id)

    # This batch is not in the decode phase, so there is nothing to do here.
    if cache_kv_len <= 0:
        return

    cu_len_this_batch = tl.load(cu_seqlens_q + batch_id)

    read_offsets = tl.arange(0, BLOCK_SIZE)
    read_weight_offsets = tl.arange(0, WEIGHT_DIM)

    # Point at this batch's single decode query token inside the ragged q / weight buffers.
    q += cu_len_this_batch * HIDDEN_DIM
    weight += cu_len_this_batch * WEIGHT_DIM

    row_data = tl.load(q + read_offsets, mask=read_offsets < HIDDEN_DIM)
    weight_row_data = tl.load(weight + read_weight_offsets, mask=read_weight_offsets < WEIGHT_DIM)

    # Write the extracted row into the dense per-batch output buffers.
    output += batch_id * HIDDEN_DIM
    out_weight += batch_id * WEIGHT_DIM

    tl.store(output + read_offsets, row_data, mask=read_offsets < HIDDEN_DIM)
    tl.store(out_weight + read_weight_offsets, weight_row_data, mask=read_weight_offsets < WEIGHT_DIM)

    tl.store(cache_seqlens + batch_id, cache_kv_len + 1)

def extract_decoder_token_from_q(
    q: paddle.Tensor,
    weight: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
):
    assert len(q.shape) == 2
    assert len(weight.shape) == 2
    assert len(cu_seqlens_q.shape) == 1
    assert len(seq_lens_encoder.shape) == 1
    assert len(seq_lens_decoder.shape) == 1

    max_bsz = seq_lens_decoder.shape[0]

    hidden_dim = q.shape[-1]
    weight_dim = weight.shape[-1]

    # if q.shape[0] <= max_bsz:
    #     max_bsz = q.shape[0]

    # One dense output row (and weight row) per batch slot.
    out = paddle.zeros([max_bsz, hidden_dim], dtype=q.dtype)
    out_weight = paddle.zeros([max_bsz, weight_dim], dtype=weight.dtype)

    cache_seqlens = paddle.zeros_like(seq_lens_decoder)

    # tl.arange needs a power-of-two extent; the kernel masks out the tail beyond hidden_dim.
    BLOCK_SIZE = triton.next_power_of_2(hidden_dim)

    # One program instance per batch slot.
    grid = (max_bsz,)

    extract_kernel[grid](
        q,
        weight,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        out,
        out_weight,
        cache_seqlens,
        hidden_dim,
        weight_dim,
        BLOCK_SIZE,
    )

    return out, out_weight, cache_seqlens

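For reference, what this helper computes is a simple per-batch gather: for every decode-phase request it copies that request's single query row (and its weight row) from the ragged, token-major buffers into dense per-batch buffers, and records the KV length including the current step. A minimal pure-Paddle sketch of the same contract (ours, not part of this commit; the `_ref` name is made up), which can serve as a reference when unit-testing the Triton path:

def extract_decoder_token_from_q_ref(q, weight, cu_seqlens_q, seq_lens_decoder):
    max_bsz = seq_lens_decoder.shape[0]
    out = paddle.zeros([max_bsz, q.shape[-1]], dtype=q.dtype)
    out_weight = paddle.zeros([max_bsz, weight.shape[-1]], dtype=weight.dtype)
    cache_seqlens = paddle.zeros_like(seq_lens_decoder)
    for b in range(max_bsz):
        cache_kv_len = int(seq_lens_decoder[b])
        if cache_kv_len <= 0:
            continue  # prefill or empty slot: leave the row as zeros
        row = int(cu_seqlens_q[b])  # a decode request contributes exactly one query token
        out[b] = q[row]
        out_weight[b] = weight[row]
        cache_seqlens[b] = cache_kv_len + 1  # KV length including the token generated this step
    return out, out_weight, cache_seqlens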

class Indexer(nn.Layer):
    def __init__(
        self,
@@ -587,7 +679,17 @@ class Indexer(nn.Layer):
        # ks,ke = forward_meta.attn_mask_offsets[::2].contiguous(), forward_meta.attn_mask_offsets[1::2].contiguous()
        num_tokens = q_fp8.shape[0]
        ks = paddle.zeros(num_tokens, dtype=paddle.int32)
        ke = paddle.arange(num_tokens, dtype=paddle.int32) + 1  # + (seq_len_kv - seq_len)
        ks_topk = paddle.zeros(num_tokens, dtype=paddle.int32)
        ke = paddle.zeros(num_tokens, dtype=paddle.int32)

        bsz = forward_meta.seq_lens_this_time.shape[0]
        for i in range(bsz):
            if forward_meta.seq_lens_encoder[i] > 0:
                token_start_k = forward_meta.cu_seqlens_k[i]
                token_end_k = forward_meta.cu_seqlens_k[i + 1]
                ks[token_start_k:token_end_k] = forward_meta.cu_seqlens_k[i]
                ke[token_start_k:token_end_k] = paddle.arange(token_start_k, token_end_k, dtype=paddle.int32) + 1

        max_seqlen_k = (ke - ks).max().item()
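        # Worked example (ours, not part of the diff): with two prefill requests of
        # lengths 3 and 2, cu_seqlens_k = [0, 3, 5] and the loop above produces
        #     ks = [0, 0, 0, 3, 3]   # each token's KV window starts at its own segment start
        #     ke = [1, 2, 3, 4, 5]   # causal: the window ends at the token's own position (inclusive)
        # so max_seqlen_k = 3, the longest per-token KV window in the batch.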

        logits = deep_gemm.fp8_mqa_logits(
@@ -604,26 +706,34 @@ class Indexer(nn.Layer):
        radix_topk_ragged_transform(
            logits.contiguous(),
            indexer_top_k,
            ks,  # self.offsets,
            ke,  # mask.contiguous(),  # self.lengths,
            ks_topk,  # self.offsets,
            ke - ks + 1,  # mask.contiguous(),  # self.lengths,
            None,  # forward_meta.seq_lens_decoder,
            None,  # forward_meta.batch_id_per_token,
            None,
            None,  # self.buffer
            0,
            self.index_topk,
            1,
        )
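        # The prefill path above ranks the MQA logits over the ragged per-token windows (ks/ke);
        # the decode path below instead scores against the paged indexer cache, one query token
        # per request, using the dense buffers packed by extract_decoder_token_from_q.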

        if forward_meta.max_len_tensor_cpu[2]:

            seq_len_kv = forward_meta.seq_lens_decoder + forward_meta.seq_lens_this_time
            decoder_q, decoder_weight, cache_seqlens = extract_decoder_token_from_q(
                q_fp8.reshape(-1, self.index_n_heads * self.index_head_dim),
                weights,
                forward_meta.cu_seqlens_q,
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
            )

            schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(seq_len_kv, 64, deep_gemm.get_num_sms())
            schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(cache_seqlens, 64, deep_gemm.get_num_sms())

            logits = deep_gemm.fp8_paged_mqa_logits(
                q_fp8.unsqueeze(1),
                decoder_q.reshape(-1, 1, self.index_n_heads, self.index_head_dim),
                self.indexer_cache.unsqueeze(2),
                weights,
                seq_len_kv,
                decoder_weight,
                cache_seqlens,
                forward_meta.block_tables,
                schedule_metadata,
                self.max_model_len,
@@ -635,11 +745,13 @@ class Indexer(nn.Layer):
                indexer_top_k,
                self.offsets,  # unused
                self.lengths,  # unused
                seq_len_kv,
                cache_seqlens,
                forward_meta.batch_id_per_token,
                forward_meta.block_tables,
                None,  # self.buffer
                forward_meta.block_tables.shape[1],
                self.index_topk,
                1,
                1,  # q_head
            )

        return indexer_top_k
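A note on the decode path: `extract_decoder_token_from_q` packs one query row per decode request and returns `cache_seqlens`, which now feeds the paged-MQA schedule and logits in place of the previously used `seq_len_kv`. On decode rows the two coincide, since each decode step contributes exactly one new token; an illustrative check of that invariant (ours, not part of the commit):

    for b in range(forward_meta.seq_lens_decoder.shape[0]):
        if int(forward_meta.seq_lens_decoder[b]) > 0:
            # extract_kernel wrote cache_seqlens[b] = seq_lens_decoder[b] + 1
            assert int(cache_seqlens[b]) == int(forward_meta.seq_lens_decoder[b]) + 1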
@@ -673,6 +785,9 @@ class DeepseekV32DSAAttention(nn.Layer):

        self.attn_softmax_scale = self.qk_head_dim**-0.5
        self.rope_theta = fd_config.model_config.rope_theta
        if fd_config.model_config.model_type == "glm_moe_dsa":
            self.rope_theta = fd_config.model_config.rope_parameters["rope_theta"]

        self.rms_norm_eps = fd_config.model_config.rms_norm_eps

        assert self.q_lora_rank is not None, "self.q_lora_rank is None, Please Check your config."
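The glm_moe_dsa branch above reflects that this model type nests its RoPE base under `rope_parameters` instead of exposing `rope_theta` directly; a sketch of the two config shapes this code accepts (values illustrative, not from the diff):

    # common case:      model_config.rope_theta = 10000.0
    # glm_moe_dsa case:  model_config.rope_parameters = {"rope_theta": 10000.0, ...}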