mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Others] clean code (#6839)
Co-authored-by: “liuruian” <liuruian@baidu.com>
This commit is contained in:
@@ -41,7 +41,7 @@
|
|||||||
* Prefill stage: Write KV cache with DS MLA FP8 format
|
* Prefill stage: Write KV cache with DS MLA FP8 format
|
||||||
*/
|
*/
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
|
std::vector<paddle::Tensor> DSMLAWriteCacheFP8(
|
||||||
const AppendAttnMetaData& meta_data,
|
const AppendAttnMetaData& meta_data,
|
||||||
const paddle::Tensor& kv_nope,
|
const paddle::Tensor& kv_nope,
|
||||||
const paddle::Tensor& kv_pe,
|
const paddle::Tensor& kv_pe,
|
||||||
@@ -56,9 +56,6 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
|
|||||||
auto kv_lora_rank = 512; // DS MLA uses 512
|
auto kv_lora_rank = 512; // DS MLA uses 512
|
||||||
auto pe_dim = 64; // DS MLA uses 64
|
auto pe_dim = 64; // DS MLA uses 64
|
||||||
auto block_size = meta_data.block_size;
|
auto block_size = meta_data.block_size;
|
||||||
|
|
||||||
// Entry size for DS MLA FP8: 512 (fp8) + 16 (scales) + 128 (rope bf16) = 656
|
|
||||||
// bytes
|
|
||||||
const int entry_size = 656;
|
const int entry_size = 656;
|
||||||
|
|
||||||
// Launch kernel with 96 threads (64 for NoPE, 32 for RoPE)
|
// Launch kernel with 96 threads (64 for NoPE, 32 for RoPE)
|
||||||
@@ -86,52 +83,6 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Decode stage: Write KV cache with DS MLA FP8 format
|
|
||||||
*/
|
|
||||||
template <paddle::DataType T>
|
|
||||||
std::vector<paddle::Tensor> DecodeDSMLAWriteCacheFP8(
|
|
||||||
const AppendAttnMetaData& meta_data,
|
|
||||||
const paddle::Tensor& kv_nope,
|
|
||||||
const paddle::Tensor& kv_pe,
|
|
||||||
const paddle::Tensor& slot_mapping,
|
|
||||||
cudaStream_t& stream,
|
|
||||||
paddle::Tensor* kv_cache) {
|
|
||||||
typedef PDTraits<T> traits_;
|
|
||||||
typedef typename traits_::DataType DataType_;
|
|
||||||
typedef typename traits_::data_t data_t;
|
|
||||||
|
|
||||||
auto num_tokens = slot_mapping.dims()[0];
|
|
||||||
auto kv_lora_rank = 512;
|
|
||||||
auto pe_dim = 64;
|
|
||||||
auto block_size = meta_data.block_size;
|
|
||||||
const int entry_size = 656;
|
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
|
||||||
dim3 block(96);
|
|
||||||
|
|
||||||
const auto& kv_cache_dims = kv_cache->dims();
|
|
||||||
int block_stride = kv_cache->strides()[0];
|
|
||||||
int entry_stride = entry_size;
|
|
||||||
int kv_c_stride = kv_nope.strides()[0];
|
|
||||||
int k_pe_stride = kv_pe.strides()[0];
|
|
||||||
|
|
||||||
ds_mla::concat_and_cache_ds_mla_kernel<DataType_><<<grid, block, 0, stream>>>(
|
|
||||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
|
||||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
|
||||||
reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
|
|
||||||
slot_mapping.data<int64_t>(),
|
|
||||||
block_stride,
|
|
||||||
entry_stride,
|
|
||||||
kv_c_stride,
|
|
||||||
k_pe_stride,
|
|
||||||
kv_lora_rank,
|
|
||||||
pe_dim,
|
|
||||||
block_size);
|
|
||||||
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
//==============================================================================
|
//==============================================================================
|
||||||
// Standard MLA WriteCache Implementation
|
// Standard MLA WriteCache Implementation
|
||||||
//==============================================================================
|
//==============================================================================
|
||||||
@@ -297,8 +248,7 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
|||||||
const paddle::Tensor& kv_cache,
|
const paddle::Tensor& kv_cache,
|
||||||
const paddle::Tensor& slot_mapping,
|
const paddle::Tensor& slot_mapping,
|
||||||
const paddle::optional<paddle::Tensor>& scale,
|
const paddle::optional<paddle::Tensor>& scale,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str) {
|
||||||
const bool is_prefill) {
|
|
||||||
cudaStream_t stream = kv_pe.stream();
|
cudaStream_t stream = kv_pe.stream();
|
||||||
AppendAttnMetaData meta_data;
|
AppendAttnMetaData meta_data;
|
||||||
|
|
||||||
@@ -320,42 +270,22 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
|||||||
// FP8 DS MLA format
|
// FP8 DS MLA format
|
||||||
switch (kv_pe.dtype()) {
|
switch (kv_pe.dtype()) {
|
||||||
case paddle::DataType::BFLOAT16: {
|
case paddle::DataType::BFLOAT16: {
|
||||||
if (is_prefill) {
|
return DSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
|
||||||
return PrefillDSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
|
meta_data,
|
||||||
meta_data,
|
kv_nope,
|
||||||
kv_nope,
|
kv_pe,
|
||||||
kv_pe,
|
slot_mapping,
|
||||||
slot_mapping,
|
stream,
|
||||||
stream,
|
const_cast<paddle::Tensor*>(&kv_cache));
|
||||||
const_cast<paddle::Tensor*>(&kv_cache));
|
|
||||||
} else {
|
|
||||||
return DecodeDSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
|
|
||||||
meta_data,
|
|
||||||
kv_nope,
|
|
||||||
kv_pe,
|
|
||||||
slot_mapping,
|
|
||||||
stream,
|
|
||||||
const_cast<paddle::Tensor*>(&kv_cache));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case paddle::DataType::FLOAT16: {
|
case paddle::DataType::FLOAT16: {
|
||||||
if (is_prefill) {
|
return DSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
|
||||||
return PrefillDSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
|
meta_data,
|
||||||
meta_data,
|
kv_nope,
|
||||||
kv_nope,
|
kv_pe,
|
||||||
kv_pe,
|
slot_mapping,
|
||||||
slot_mapping,
|
stream,
|
||||||
stream,
|
const_cast<paddle::Tensor*>(&kv_cache));
|
||||||
const_cast<paddle::Tensor*>(&kv_cache));
|
|
||||||
} else {
|
|
||||||
return DecodeDSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
|
|
||||||
meta_data,
|
|
||||||
kv_nope,
|
|
||||||
kv_pe,
|
|
||||||
slot_mapping,
|
|
||||||
stream,
|
|
||||||
const_cast<paddle::Tensor*>(&kv_cache));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
PD_THROW("Unsupported dtype for DS MLA FP8 cache");
|
PD_THROW("Unsupported dtype for DS MLA FP8 cache");
|
||||||
@@ -464,7 +394,7 @@ PD_BUILD_STATIC_OP(ds_mla_write_cache)
|
|||||||
paddle::Optional("scale")})
|
paddle::Optional("scale")})
|
||||||
.Outputs({"kv_cache_out"})
|
.Outputs({"kv_cache_out"})
|
||||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||||
.Attrs({"cache_quant_type_str: std::string", "is_prefill: bool"})
|
.Attrs({"cache_quant_type_str: std::string"})
|
||||||
.SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel));
|
.SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel));
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(indexer_k_quant_and_cache)
|
PD_BUILD_STATIC_OP(indexer_k_quant_and_cache)
|
||||||
|
|||||||
@@ -1216,8 +1216,7 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
|||||||
const paddle::Tensor& kv_cache,
|
const paddle::Tensor& kv_cache,
|
||||||
const paddle::Tensor& slot_mapping,
|
const paddle::Tensor& slot_mapping,
|
||||||
const paddle::optional<paddle::Tensor>& scale,
|
const paddle::optional<paddle::Tensor>& scale,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str);
|
||||||
const bool is_prefill);
|
|
||||||
|
|
||||||
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
|
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
|
||||||
const paddle::Tensor& k,
|
const paddle::Tensor& k,
|
||||||
|
|||||||
@@ -344,47 +344,28 @@ class DSAAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
from fastdeploy.model_executor.ops.gpu import dsk_attn_write_cache
|
from fastdeploy.model_executor.ops.gpu import dsk_attn_write_cache
|
||||||
|
|
||||||
|
k_range = paddle.tensor(200.0)
|
||||||
|
scale = paddle.abs(compressed_kv).max() / k_range
|
||||||
|
|
||||||
|
slot_mapping = compute_slot_mapping(
|
||||||
|
forward_meta.block_tables,
|
||||||
|
forward_meta.position_ids,
|
||||||
|
forward_meta.batch_id_per_token,
|
||||||
|
self.block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
dsk_attn_write_cache(
|
||||||
|
compressed_kv,
|
||||||
|
k_pe,
|
||||||
|
latent_cache,
|
||||||
|
slot_mapping,
|
||||||
|
scale.cast(paddle.float32),
|
||||||
|
"fp8_ds_mla",
|
||||||
|
)
|
||||||
|
|
||||||
|
fmha_out_prefill = None
|
||||||
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
|
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
|
||||||
|
|
||||||
# def calc_kv_scales(self, q: paddle.Tensor, kv_c_normed: paddle.Tensor, k_pe: paddle.Tensor) -> None:
|
|
||||||
# """Optional scale calculation for MLA inputs.
|
|
||||||
|
|
||||||
# Mirrors Attention.calc_kv_scales. Not all MLA backends require this
|
|
||||||
# """
|
|
||||||
# # Use safe defaults if ranges are not present
|
|
||||||
# q_range = paddle.tensor(200.0)
|
|
||||||
# k_range = paddle.tensor(200.0)
|
|
||||||
# v_range = paddle.tensor(100.0)
|
|
||||||
|
|
||||||
# self._q_scale.copy_(paddle.abs(q).max() / q_range)
|
|
||||||
|
|
||||||
# kv_abs_max = paddle.abs(kv_c_normed).max()
|
|
||||||
# self._k_scale.copy_(kv_abs_max / k_range)
|
|
||||||
# self._v_scale.copy_(kv_abs_max / v_range)
|
|
||||||
# self._q_scale_float = self._q_scale.item()
|
|
||||||
# self._k_scale_float = self._k_scale.item()
|
|
||||||
# self._v_scale_float = self._v_scale.item()
|
|
||||||
# self.calculate_kv_scales = False
|
|
||||||
|
|
||||||
metadata.slot_mapping = compute_slot_mapping(
|
|
||||||
forward_meta.block_tables,
|
|
||||||
forward_meta.position_ids,
|
|
||||||
forward_meta.batch_id_per_token,
|
|
||||||
self.block_size,
|
|
||||||
)
|
|
||||||
k_range = paddle.tensor(200.0)
|
|
||||||
scale = paddle.abs(compressed_kv).max() / k_range
|
|
||||||
|
|
||||||
dsk_attn_write_cache(
|
|
||||||
compressed_kv,
|
|
||||||
k_pe,
|
|
||||||
latent_cache,
|
|
||||||
metadata.slot_mapping,
|
|
||||||
scale.cast(paddle.float32),
|
|
||||||
"fp8_ds_mla",
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
|
|
||||||
fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd(
|
fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd(
|
||||||
q, # q_input.contiguous(),
|
q, # q_input.contiguous(),
|
||||||
k, # kv.unsqueeze(1),
|
k, # kv.unsqueeze(1),
|
||||||
@@ -392,31 +373,10 @@ class DSAAttentionBackend(AttentionBackend):
|
|||||||
sm_scale=self.attn_softmax_scale,
|
sm_scale=self.attn_softmax_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
return fmha_out_prefill
|
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
# if k is None:
|
# if k is None:
|
||||||
if forward_meta.max_len_tensor_cpu[2]: # max_enc_len_this_time
|
if forward_meta.max_len_tensor_cpu[2]: # max_enc_len_this_time
|
||||||
|
|
||||||
metadata.slot_mapping = compute_slot_mapping(
|
|
||||||
forward_meta.block_tables,
|
|
||||||
forward_meta.position_ids,
|
|
||||||
forward_meta.batch_id_per_token,
|
|
||||||
self.block_size,
|
|
||||||
)
|
|
||||||
k_range = paddle.tensor(200.0)
|
|
||||||
scale = paddle.abs(compressed_kv).max() / k_range
|
|
||||||
|
|
||||||
dsk_attn_write_cache(
|
|
||||||
compressed_kv,
|
|
||||||
k_pe,
|
|
||||||
latent_cache,
|
|
||||||
metadata.slot_mapping,
|
|
||||||
scale.cast(paddle.float32),
|
|
||||||
"fp8_ds_mla",
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
|
|
||||||
tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
|
tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
|
||||||
|
|
||||||
fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache(
|
fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache(
|
||||||
@@ -438,4 +398,26 @@ class DSAAttentionBackend(AttentionBackend):
|
|||||||
None, # extra_topk_length: Optional[torch.Tensor] = None
|
None, # extra_topk_length: Optional[torch.Tensor] = None
|
||||||
)
|
)
|
||||||
|
|
||||||
return fmha_out_decode
|
if fmha_out_prefill is not None:
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
|
merge_prefill_decode_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
merge_prefill_decode_output(
|
||||||
|
fmha_out_prefill,
|
||||||
|
fmha_out_decode,
|
||||||
|
forward_meta.seq_lens_encoder,
|
||||||
|
forward_meta.seq_lens_decoder,
|
||||||
|
forward_meta.seq_lens_this_time,
|
||||||
|
forward_meta.cu_seqlens_q,
|
||||||
|
self.num_heads * 4,
|
||||||
|
128,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fmha_out_prefill
|
||||||
|
else:
|
||||||
|
return fmha_out_decode
|
||||||
|
|
||||||
|
return fmha_out_prefill
|
||||||
|
|||||||
@@ -145,7 +145,6 @@ class TestBasicPrefill(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True, # is_prefill
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# dsk_attn_write_cache 是 in-place 操作,直接修改 kv_cache
|
# dsk_attn_write_cache 是 in-place 操作,直接修改 kv_cache
|
||||||
@@ -168,7 +167,6 @@ class TestBasicDecode(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
False, # is_prefill
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# in-place 操作验证
|
# in-place 操作验证
|
||||||
@@ -193,7 +191,6 @@ class TestSingleToken(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
@@ -213,7 +210,6 @@ class TestLargeBatch(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
@@ -235,7 +231,6 @@ class TestUnalignedTokens(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
@@ -258,7 +253,6 @@ class TestQuantTypeFp8DsMla(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla", # 主要测试的量化类型
|
"fp8_ds_mla", # 主要测试的量化类型
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
@@ -306,7 +300,6 @@ class TestWithoutScale(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
None,
|
None,
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
@@ -326,7 +319,6 @@ class TestWithoutKvSignalData(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
@@ -349,7 +341,6 @@ class TestBfloat16Input(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
@@ -370,7 +361,6 @@ class TestFloat16Input(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -396,7 +386,6 @@ class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
paddle.device.synchronize()
|
paddle.device.synchronize()
|
||||||
@@ -413,7 +402,6 @@ class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest):
|
|||||||
tensors["slot_mapping"],
|
tensors["slot_mapping"],
|
||||||
tensors["scale"],
|
tensors["scale"],
|
||||||
"fp8_ds_mla",
|
"fp8_ds_mla",
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
paddle.device.synchronize()
|
paddle.device.synchronize()
|
||||||
|
|||||||
Reference in New Issue
Block a user