Files
FastDeploy/tests/operators/test_dsmla_writecache.py
T
周周周 820eb60ec6 [Others] clean code (#6839)
Co-authored-by: “liuruian” <liuruian@baidu.com>
2026-03-14 11:09:28 +08:00

424 lines
13 KiB
Python

#!/usr/bin/env python3
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DSMLAWriteCacheKernel 算子单元测试
测试 cpp_extensions.cc 中定义的 dsk_attn_write_cache 算子:
- 参数: kv_nope, kv_pe, kv_cache, slot_mapping, seq_lens, seq_lens_decoder,
batch_id_per_token, cu_seqlens_q, block_tables, kv_signal_data(optional),
scale(optional), cache_quant_type_str, max_seq_len, is_prefill
"""
import time
import unittest
import paddle
# DS MLA 常量定义
KV_LORA_RANK = 512 # NoPE 部分维度
PE_DIM = 64 # RoPE 部分维度
BLOCK_SIZE = 16 # 每个 block 的 token 数
# FP8 entry size: 512(fp8 nope) + 16(scales) + 128(rope bf16) = 656 bytes
FP8_ENTRY_SIZE = 656
def create_test_tensors(
    batch_size: int = 2,
    num_tokens: int = 16,
    max_num_blocks: int = 100,
    is_prefill: bool = True,
    dtype: str = "bfloat16",
):
    """Build all input tensors for a ``dsk_attn_write_cache`` call.

    Args:
        batch_size: number of requests in the batch.
        num_tokens: total token count across the batch. NOTE(review):
            tokens are split as ``num_tokens // batch_size`` per request,
            so when ``num_tokens`` is not divisible by ``batch_size`` the
            trailing remainder tokens get no ``batch_id_per_token`` entry
            — confirm this matches the kernel's expectations.
        max_num_blocks: number of blocks allocated in the KV cache.
        is_prefill: True for the prefill phase, False for decode.
        dtype: dtype of the ``kv_nope`` / ``kv_pe`` input tensors.

    Returns:
        dict: tensors and scalar settings keyed by op parameter name.
    """
    # Latent (NoPE) part and rotary (RoPE) part of the compressed KV.
    kv_nope = paddle.randn([num_tokens, KV_LORA_RANK], dtype=dtype)
    kv_pe = paddle.randn([num_tokens, PE_DIM], dtype=dtype)
    # KV cache layout: [num_blocks, num_kv_heads=1, block_size, entry_size].
    # Stored as raw bytes because fp8_ds_mla packs fp8 nope + scales + bf16 rope.
    kv_cache = paddle.zeros([max_num_blocks, 1, BLOCK_SIZE, FP8_ENTRY_SIZE], dtype="uint8")
    # One cache slot per token, assigned sequentially.
    slot_mapping = paddle.arange(num_tokens, dtype="int64")
    # Tokens are split evenly across requests.
    tokens_per_req = num_tokens // batch_size
    seq_lens = paddle.full([batch_size], tokens_per_req, dtype="int32")
    # Decoder sequence lengths: 0 during prefill, 1 during decode.
    seq_lens_decoder = paddle.full([batch_size], 0 if is_prefill else 1, dtype="int32")
    # Map each token to the index of the request it belongs to.
    batch_id_per_token = paddle.concat([paddle.full([tokens_per_req], i, dtype="int32") for i in range(batch_size)])
    # Cumulative query lengths [0, L, 2L, ...] — closed form for equal-length
    # requests instead of an accumulation loop.
    cu_seqlens_q = paddle.to_tensor([i * tokens_per_req for i in range(batch_size + 1)], dtype="int32")
    # Random block table: [batch_size, max_blocks_per_seq].
    max_blocks_per_seq = 10
    block_tables = paddle.randint(0, max_num_blocks, [batch_size, max_blocks_per_seq], dtype="int32")
    # Optional per-token quantization scale factor.
    scale = paddle.ones([num_tokens, 1], dtype="float32")
    return {
        "kv_nope": kv_nope,
        "kv_pe": kv_pe,
        "kv_cache": kv_cache,
        "slot_mapping": slot_mapping,
        "seq_lens": seq_lens,
        "seq_lens_decoder": seq_lens_decoder,
        "batch_id_per_token": batch_id_per_token,
        "cu_seqlens_q": cu_seqlens_q,
        "block_tables": block_tables,
        "scale": scale,
        "max_seq_len": 4096,
        "is_prefill": is_prefill,
    }
class BaseDSMLAWriteCacheTest(unittest.TestCase):
    """Shared setup for all dsk_attn_write_cache tests.

    Imports the custom op once per class; subclasses call it through
    ``self.dsk_attn_write_cache``. If the import fails, every test in
    the subclass is skipped with the recorded reason.
    """

    @classmethod
    def setUpClass(cls):
        """Select the GPU device and bind the custom op to the class."""
        paddle.set_device("gpu")
        try:
            from fastdeploy.model_executor.ops.gpu import dsk_attn_write_cache

            # Wrap in staticmethod so that ``self.dsk_attn_write_cache``
            # never becomes a bound method: if the op were a plain Python
            # function, a bare class attribute would silently prepend
            # ``self`` as the first argument. staticmethod is a no-op for
            # C-implemented callables, so this is safe either way.
            cls.dsk_attn_write_cache = staticmethod(dsk_attn_write_cache)
            cls.skip_tests = False
        except ImportError as e:
            cls.skip_tests = True
            cls.skip_reason = f"无法导入 dsk_attn_write_cache: {e}"

    def setUp(self):
        """Skip the current test when the op could not be imported."""
        if self.skip_tests:
            self.skipTest(self.skip_reason)
# ==================== 基础功能测试 ====================
class TestBasicPrefill(BaseDSMLAWriteCacheTest):
    """Exercise the op on the basic prefill path."""

    def test_basic_prefill(self):
        """Write 16 prefill tokens from 2 requests into the cache."""
        data = create_test_tensors(batch_size=2, num_tokens=16, is_prefill=True)
        args = (
            data["kv_nope"],
            data["kv_pe"],
            data["kv_cache"],
            data["slot_mapping"],
            data["scale"],
        )
        # The op writes into kv_cache in place and returns an empty list,
        # so we only check the call succeeded and the buffer stayed uint8.
        result = self.dsk_attn_write_cache(*args, "fp8_ds_mla")
        self.assertIsNotNone(result)
        self.assertEqual(data["kv_cache"].dtype, paddle.uint8)
class TestBasicDecode(BaseDSMLAWriteCacheTest):
    """Exercise the op on the basic decode path."""

    def test_basic_decode(self):
        """Write one decode token per request for two requests."""
        data = create_test_tensors(batch_size=2, num_tokens=2, is_prefill=False)
        keys = ("kv_nope", "kv_pe", "kv_cache", "slot_mapping", "scale")
        # In-place write: verify the call returns and the cache buffer
        # keeps its raw uint8 dtype.
        result = self.dsk_attn_write_cache(*(data[k] for k in keys), "fp8_ds_mla")
        self.assertIsNotNone(result)
        self.assertEqual(data["kv_cache"].dtype, paddle.uint8)
# ==================== 边界条件测试 ====================
class TestSingleToken(BaseDSMLAWriteCacheTest):
    """Boundary case: a single token in a single request."""

    def test_single_token(self):
        """The kernel must handle a one-token write."""
        data = create_test_tensors(batch_size=1, num_tokens=1)
        inputs = [data[k] for k in ("kv_nope", "kv_pe", "kv_cache", "slot_mapping", "scale")]
        self.assertIsNotNone(self.dsk_attn_write_cache(*inputs, "fp8_ds_mla"))
class TestLargeBatch(BaseDSMLAWriteCacheTest):
    """Boundary case: a large batch."""

    def test_large_batch(self):
        """32 requests / 512 tokens should write without error."""
        data = create_test_tensors(batch_size=32, num_tokens=512)
        out = self.dsk_attn_write_cache(
            data["kv_nope"],
            data["kv_pe"],
            data["kv_cache"],
            data["slot_mapping"],
            data["scale"],
            "fp8_ds_mla",
        )
        self.assertIsNotNone(out)
        # The cache buffer stays raw uint8 regardless of batch size.
        self.assertEqual(data["kv_cache"].dtype, paddle.uint8)
class TestUnalignedTokens(BaseDSMLAWriteCacheTest):
    """Boundary case: token count not a multiple of the block size."""

    def test_unaligned_tokens(self):
        """17 tokens is not divisible by BLOCK_SIZE (16)."""
        data = create_test_tensors(batch_size=1, num_tokens=17)
        keys = ("kv_nope", "kv_pe", "kv_cache", "slot_mapping", "scale")
        out = self.dsk_attn_write_cache(*(data[k] for k in keys), "fp8_ds_mla")
        self.assertIsNotNone(out)
# ==================== 量化类型测试 ====================
class TestQuantTypeFp8DsMla(BaseDSMLAWriteCacheTest):
    """Cover the primary quantization mode, fp8_ds_mla."""

    def test_quant_type_fp8_ds_mla(self):
        """Write with cache_quant_type_str set to the fp8_ds_mla mode."""
        data = create_test_tensors(batch_size=2, num_tokens=16)
        out = self.dsk_attn_write_cache(
            data["kv_nope"],
            data["kv_pe"],
            data["kv_cache"],
            data["slot_mapping"],
            data["scale"],
            "fp8_ds_mla",
        )
        self.assertIsNotNone(out)
class TestQuantTypeNone(BaseDSMLAWriteCacheTest):
    """Test the unquantized ("none") cache mode."""

    def test_quant_type_none(self):
        """Run the op with cache_quant_type_str="none" and scale=None."""
        tensors = create_test_tensors(batch_size=2, num_tokens=16)
        # Unquantized cache uses a different layout:
        # [num_blocks, 1, block_size, kv_lora_rank + pe_dim] in bfloat16.
        tensors["kv_cache"] = paddle.zeros([100, 1, BLOCK_SIZE, KV_LORA_RANK + PE_DIM], dtype="bfloat16")
        try:
            # NOTE: a stray trailing ``True`` positional argument was dropped
            # here — every sibling test calls the op with exactly these six
            # arguments, and the extra argument would make the call raise and
            # this test unconditionally skip.
            result = self.dsk_attn_write_cache(
                tensors["kv_nope"],
                tensors["kv_pe"],
                tensors["kv_cache"],
                tensors["slot_mapping"],
                None,  # scale may be None when not quantizing
                "none",
            )
            self.assertIsNotNone(result)
        except Exception as e:
            # Skip gracefully if the 'none' quant type is not implemented.
            self.skipTest(f"'none' quant type 可能未实现: {e}")
# ==================== 可选参数测试 ====================
class TestWithoutScale(BaseDSMLAWriteCacheTest):
    """Optional-argument case: the scale tensor is omitted."""

    def test_without_scale(self):
        """Passing scale=None must still succeed."""
        data = create_test_tensors(batch_size=2, num_tokens=16)
        out = self.dsk_attn_write_cache(
            data["kv_nope"],
            data["kv_pe"],
            data["kv_cache"],
            data["slot_mapping"],
            None,  # scale is an optional input
            "fp8_ds_mla",
        )
        self.assertIsNotNone(out)
class TestWithoutKvSignalData(BaseDSMLAWriteCacheTest):
    """Optional-argument case: kv_signal_data is not supplied."""

    def test_without_kv_signal_data(self):
        """The call works without the optional kv_signal_data input."""
        data = create_test_tensors(batch_size=2, num_tokens=16)
        keys = ("kv_nope", "kv_pe", "kv_cache", "slot_mapping", "scale")
        out = self.dsk_attn_write_cache(*(data[k] for k in keys), "fp8_ds_mla")
        self.assertIsNotNone(out)
# ==================== 数据类型测试 ====================
class TestBfloat16Input(BaseDSMLAWriteCacheTest):
    """Input-dtype case: bfloat16 kv_nope / kv_pe."""

    def test_bfloat16_input(self):
        """bfloat16 inputs must be accepted."""
        data = create_test_tensors(dtype="bfloat16")
        out = self.dsk_attn_write_cache(
            data["kv_nope"],
            data["kv_pe"],
            data["kv_cache"],
            data["slot_mapping"],
            data["scale"],
            "fp8_ds_mla",
        )
        self.assertIsNotNone(out)
class TestFloat16Input(BaseDSMLAWriteCacheTest):
    """Input-dtype case: float16 kv_nope / kv_pe."""

    def test_float16_input(self):
        """float16 inputs may or may not be supported; skip if not."""
        data = create_test_tensors(dtype="float16")
        keys = ("kv_nope", "kv_pe", "kv_cache", "slot_mapping", "scale")
        try:
            out = self.dsk_attn_write_cache(*(data[k] for k in keys), "fp8_ds_mla")
        except Exception as e:
            self.skipTest(f"float16 输入可能不支持: {e}")
        else:
            self.assertIsNotNone(out)
# ==================== 性能测试 ====================
class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest):
    """Timing sanity check for DSMLAWriteCacheKernel."""

    def test_warmup_and_benchmark(self):
        """Warm the kernel up, then time 100 iterations."""
        data = create_test_tensors(batch_size=16, num_tokens=256)
        call_args = (
            data["kv_nope"],
            data["kv_pe"],
            data["kv_cache"],
            data["slot_mapping"],
            data["scale"],
            "fp8_ds_mla",
        )
        # Warmup so one-time compilation/caching does not skew the timing.
        for _ in range(5):
            _ = self.dsk_attn_write_cache(*call_args)
        paddle.device.synchronize()
        # Timed section.
        num_iters = 100
        start = time.perf_counter()
        for _ in range(num_iters):
            _ = self.dsk_attn_write_cache(*call_args)
        paddle.device.synchronize()
        elapsed = time.perf_counter() - start
        avg_time_ms = elapsed / num_iters * 1000
        print(f"\n[Benchmark] 256 tokens, avg time: {avg_time_ms:.3f} ms")
        # Loose threshold; adjust for the target hardware if needed.
        self.assertLess(avg_time_ms, 10.0, "性能应在 10ms 内")
if __name__ == "__main__":
    # Print a banner, then hand control to the unittest runner.
    banner = "=" * 70
    print(banner)
    print("DSMLAWriteCacheKernel 单元测试")
    print(banner)
    unittest.main(verbosity=2)