mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
ba01d7a823
* dsk del prefill mask * dsk support 1M+ seq_len rope * update rope tests
138 lines
5.6 KiB
Python
138 lines
5.6 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
import unittest
|
||
|
||
import numpy as np
|
||
import paddle
|
||
|
||
from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding
|
||
|
||
|
||
class TestFusedRotaryPositionEncoding(unittest.TestCase):
|
||
def setUp(self):
|
||
paddle.set_device("gpu")
|
||
np.random.seed(42)
|
||
|
||
def _make_cos_sin_cache(self, max_position: int, rot_dim: int) -> np.ndarray:
|
||
"""Generate cos/sin cache."""
|
||
assert rot_dim % 2 == 0, "rot_dim must be even"
|
||
half_dim = rot_dim // 2
|
||
inv_freq = 1.0 / (10000 ** (np.arange(0, half_dim).astype("float32") / half_dim))
|
||
positions = np.arange(max_position, dtype="float32")
|
||
freqs = np.outer(positions, inv_freq) # [max_position, half_dim]
|
||
cos_np = np.cos(freqs)
|
||
sin_np = np.sin(freqs)
|
||
return np.concatenate([cos_np, sin_np], axis=1).astype("float32")
|
||
|
||
def _ref_rotary(self, query, key, position_ids, cos_sin_cache, head_size, is_neox):
|
||
"""Numpy reference implementation."""
|
||
num_tokens, num_heads, _ = query.shape
|
||
num_kv_heads = key.shape[1]
|
||
rot_dim = cos_sin_cache.shape[1]
|
||
embed_dim = rot_dim // 2
|
||
|
||
query_ref = query.copy()
|
||
key_ref = key.copy()
|
||
|
||
for t in range(num_tokens):
|
||
pos = position_ids[t]
|
||
cos_ptr = cos_sin_cache[pos, :embed_dim]
|
||
sin_ptr = cos_sin_cache[pos, embed_dim:]
|
||
|
||
for h in range(num_heads):
|
||
arr = query_ref[t, h]
|
||
for i in range(embed_dim):
|
||
if is_neox:
|
||
x_idx, y_idx = i, embed_dim + i
|
||
cos, sin = cos_ptr[i], sin_ptr[i]
|
||
else:
|
||
x_idx, y_idx = 2 * i, 2 * i + 1
|
||
cos, sin = cos_ptr[i], sin_ptr[i]
|
||
x, y = arr[x_idx], arr[y_idx]
|
||
arr[x_idx] = x * cos - y * sin
|
||
arr[y_idx] = y * cos + x * sin
|
||
|
||
for h in range(num_kv_heads):
|
||
arr = key_ref[t, h]
|
||
for i in range(embed_dim):
|
||
if is_neox:
|
||
x_idx, y_idx = i, embed_dim + i
|
||
cos, sin = cos_ptr[i], sin_ptr[i]
|
||
else:
|
||
x_idx, y_idx = 2 * i, 2 * i + 1
|
||
cos, sin = cos_ptr[i], sin_ptr[i]
|
||
x, y = arr[x_idx], arr[y_idx]
|
||
arr[x_idx] = x * cos - y * sin
|
||
arr[y_idx] = y * cos + x * sin
|
||
|
||
return query_ref, key_ref
|
||
|
||
def _run_op(
|
||
self,
|
||
query_np: np.ndarray,
|
||
key_np: np.ndarray,
|
||
position_ids_np: np.ndarray,
|
||
cos_sin_cache_np: np.ndarray,
|
||
head_size: int,
|
||
is_neox: bool,
|
||
):
|
||
"""Run fused_rotary_position_encoding operator."""
|
||
query = paddle.to_tensor(query_np, dtype="float32")
|
||
key = paddle.to_tensor(key_np, dtype="float32")
|
||
position_ids = paddle.to_tensor(position_ids_np, dtype="int32")
|
||
cos_sin_cache = paddle.to_tensor(cos_sin_cache_np, dtype="float32")
|
||
|
||
fused_rotary_position_encoding(query, key, position_ids, cos_sin_cache, head_size, is_neox)
|
||
return query.numpy(), key.numpy()
|
||
|
||
def _check_correctness(self, num_tokens, num_heads, num_kv_heads, head_size, rot_dim, is_neox):
|
||
query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
|
||
key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
|
||
position_ids_np = np.arange(num_tokens, dtype="int32")
|
||
cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)
|
||
|
||
query_out, key_out = self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox)
|
||
query_ref, key_ref = self._ref_rotary(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox)
|
||
|
||
np.testing.assert_allclose(query_out, query_ref, rtol=1e-5, atol=1e-6)
|
||
np.testing.assert_allclose(key_out, key_ref, rtol=1e-5, atol=1e-6)
|
||
|
||
def test_basic_case(self):
|
||
self._check_correctness(num_tokens=4, num_heads=2, num_kv_heads=2, head_size=6, rot_dim=4, is_neox=False)
|
||
|
||
def test_neox_mode(self):
|
||
self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True)
|
||
|
||
def test_large_num_tokens(self):
|
||
"""
|
||
测试算子支持大量 tokens(超过 65535)
|
||
算子使用 2D grid,理论上可支持 65535*65535 个 tokens
|
||
"""
|
||
num_tokens, num_heads, head_size = 65537, 1, 4
|
||
num_kv_heads, rot_dim = 1, 4
|
||
query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32")
|
||
key_np = np.random.rand(num_tokens, num_kv_heads, head_size).astype("float32")
|
||
position_ids_np = np.arange(num_tokens, dtype="int32")
|
||
cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim)
|
||
|
||
# 不应该抛出异常,算子应该能处理大量 tokens
|
||
query_out, key_out = self._run_op(
|
||
query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False
|
||
)
|
||
|
||
|
||
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|