[Optimization] [OP] [Models] dsk del prefill mask (#7313)

* dsk del prefill mask

* dsk support 1M+ seq_len rope

* update rope tests
This commit is contained in:
AIbin
2026-04-11 19:32:27 +08:00
committed by GitHub
parent 076ab07528
commit ba01d7a823
4 changed files with 83 additions and 49 deletions
@@ -116,9 +116,10 @@ class TestFusedRotaryPositionEncoding(unittest.TestCase):
self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True)
def test_large_num_tokens(self):
self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False)
def test_exceed_max_tokens(self):
"""
测试算子支持大量 tokens(超过 65535
算子使用 2D grid,理论上可支持 65535*65535 个 tokens
"""
num_tokens, num_heads, head_size = 65537, 1, 4
num_kv_heads, rot_dim = 1, 4
query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
@@ -126,8 +127,10 @@ class TestFusedRotaryPositionEncoding(unittest.TestCase):
position_ids_np = np.arange(num_tokens, dtype="int32")
cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)
with self.assertRaises(Exception):
self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False)
# 不应该抛出异常,算子应该能处理大量 tokens
query_out, key_out = self._run_op(
query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
)
if __name__ == "__main__":