mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
support dsv3 use flashmla (#6593)
This commit is contained in:
@@ -362,7 +362,10 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)[0]
|
||||
|
||||
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
|
||||
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
|
||||
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
|
||||
|
||||
if need_do_prefill: # max_enc_len_this_time
|
||||
key_value = self.kv_b_proj(compressed_kv)
|
||||
key_value.reshape_(
|
||||
[
|
||||
@@ -393,10 +396,9 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
|
||||
fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
|
||||
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
|
||||
|
||||
fmha_out = fmha_out_prefill
|
||||
|
||||
if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time
|
||||
if need_do_decode: # max_dec_len_this_time
|
||||
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
|
||||
|
||||
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
|
||||
@@ -427,10 +429,10 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
|
||||
)
|
||||
|
||||
if fmha_out is None:
|
||||
fmha_out = fmha_out_decode
|
||||
if need_do_prefill:
|
||||
fmha_out += fmha_out_decode
|
||||
else:
|
||||
fmha_out = fmha_out + fmha_out_decode
|
||||
fmha_out = fmha_out_decode
|
||||
|
||||
output = self.o_proj(fmha_out)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user