mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[Optimization] [OP] [Models] dsk del prefill mask (#7313)
* dsk del prefill mask * dsk support 1M+ seq_len rope * update rope tests
This commit is contained in:
@@ -116,9 +116,10 @@ class TestFusedRotaryPositionEncoding(unittest.TestCase):
|
||||
self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True)
|
||||
|
||||
def test_large_num_tokens(self):
|
||||
self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False)
|
||||
|
||||
def test_exceed_max_tokens(self):
|
||||
"""
|
||||
测试算子支持大量 tokens(超过 65535)
|
||||
算子使用 2D grid,理论上可支持 65535*65535 个 tokens
|
||||
"""
|
||||
num_tokens, num_heads, head_size = 65537, 1, 4
|
||||
num_kv_heads, rot_dim = 1, 4
|
||||
query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
|
||||
@@ -126,8 +127,10 @@ class TestFusedRotaryPositionEncoding(unittest.TestCase):
|
||||
position_ids_np = np.arange(num_tokens, dtype="int32")
|
||||
cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False)
|
||||
# 不应该抛出异常,算子应该能处理大量 tokens
|
||||
query_out, key_out = self._run_op(
|
||||
query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user