mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Others] clean code (#6839)
Co-authored-by: “liuruian” <liuruian@baidu.com>
This commit is contained in:
@@ -344,47 +344,28 @@ class DSAAttentionBackend(AttentionBackend):
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import dsk_attn_write_cache
|
||||
|
||||
k_range = paddle.tensor(200.0)
|
||||
scale = paddle.abs(compressed_kv).max() / k_range
|
||||
|
||||
slot_mapping = compute_slot_mapping(
|
||||
forward_meta.block_tables,
|
||||
forward_meta.position_ids,
|
||||
forward_meta.batch_id_per_token,
|
||||
self.block_size,
|
||||
)
|
||||
|
||||
dsk_attn_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
slot_mapping,
|
||||
scale.cast(paddle.float32),
|
||||
"fp8_ds_mla",
|
||||
)
|
||||
|
||||
fmha_out_prefill = None
|
||||
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
|
||||
|
||||
# def calc_kv_scales(self, q: paddle.Tensor, kv_c_normed: paddle.Tensor, k_pe: paddle.Tensor) -> None:
|
||||
# """Optional scale calculation for MLA inputs.
|
||||
|
||||
# Mirrors Attention.calc_kv_scales. Not all MLA backends require this
|
||||
# """
|
||||
# # Use safe defaults if ranges are not present
|
||||
# q_range = paddle.tensor(200.0)
|
||||
# k_range = paddle.tensor(200.0)
|
||||
# v_range = paddle.tensor(100.0)
|
||||
|
||||
# self._q_scale.copy_(paddle.abs(q).max() / q_range)
|
||||
|
||||
# kv_abs_max = paddle.abs(kv_c_normed).max()
|
||||
# self._k_scale.copy_(kv_abs_max / k_range)
|
||||
# self._v_scale.copy_(kv_abs_max / v_range)
|
||||
# self._q_scale_float = self._q_scale.item()
|
||||
# self._k_scale_float = self._k_scale.item()
|
||||
# self._v_scale_float = self._v_scale.item()
|
||||
# self.calculate_kv_scales = False
|
||||
|
||||
metadata.slot_mapping = compute_slot_mapping(
|
||||
forward_meta.block_tables,
|
||||
forward_meta.position_ids,
|
||||
forward_meta.batch_id_per_token,
|
||||
self.block_size,
|
||||
)
|
||||
k_range = paddle.tensor(200.0)
|
||||
scale = paddle.abs(compressed_kv).max() / k_range
|
||||
|
||||
dsk_attn_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
metadata.slot_mapping,
|
||||
scale.cast(paddle.float32),
|
||||
"fp8_ds_mla",
|
||||
True,
|
||||
)
|
||||
|
||||
fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd(
|
||||
q, # q_input.contiguous(),
|
||||
k, # kv.unsqueeze(1),
|
||||
@@ -392,31 +373,10 @@ class DSAAttentionBackend(AttentionBackend):
|
||||
sm_scale=self.attn_softmax_scale,
|
||||
)
|
||||
|
||||
return fmha_out_prefill
|
||||
|
||||
# Decode
|
||||
# if k is None:
|
||||
if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time
|
||||
|
||||
metadata.slot_mapping = compute_slot_mapping(
|
||||
forward_meta.block_tables,
|
||||
forward_meta.position_ids,
|
||||
forward_meta.batch_id_per_token,
|
||||
self.block_size,
|
||||
)
|
||||
k_range = paddle.tensor(200.0)
|
||||
scale = paddle.abs(compressed_kv).max() / k_range
|
||||
|
||||
dsk_attn_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
latent_cache,
|
||||
metadata.slot_mapping,
|
||||
scale.cast(paddle.float32),
|
||||
"fp8_ds_mla",
|
||||
False,
|
||||
)
|
||||
|
||||
tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
|
||||
|
||||
fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache(
|
||||
@@ -438,4 +398,26 @@ class DSAAttentionBackend(AttentionBackend):
|
||||
None, # extra_topk_length: Optional[torch.Tensor] = None
|
||||
)
|
||||
|
||||
return fmha_out_decode
|
||||
if fmha_out_prefill is not None:
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
merge_prefill_decode_output,
|
||||
)
|
||||
|
||||
merge_prefill_decode_output(
|
||||
fmha_out_prefill,
|
||||
fmha_out_decode,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
self.num_heads * 4,
|
||||
128,
|
||||
1,
|
||||
)
|
||||
|
||||
return fmha_out_prefill
|
||||
else:
|
||||
return fmha_out_decode
|
||||
|
||||
return fmha_out_prefill
|
||||
|
||||
Reference in New Issue
Block a user