mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
7a2e33098f
* [XPU] support speculate_pre_process * merge develop * fix codestype * fix mtp, support cu_seqlens_q_output * fix mtp, support cu_seqlens_q_output * fix test --------- Co-authored-by: lizan1999 <lizan03@baidu.com>
329 lines
13 KiB
Python
329 lines
13 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import paddle
|
|
|
|
from fastdeploy.model_executor.ops.xpu import speculate_pre_process
|
|
|
|
|
|
def speculate_pre_process_ref(
    input_ids,
    seq_lens,
    draft_tokens,
    seq_lens_encoder,
    max_seq_len,
    max_draft_tokens_per_batch,
    real_bsz,
    token_num,
):
    """
    Pure-Python/numpy reference implementation of SpeculatePreProcessKernel.

    Args:
        input_ids: flat int64 array of shape [real_bsz * max_seq_len].
        seq_lens: int32 array of per-batch sequence lengths.
        draft_tokens: flat int64 array of shape [real_bsz * max_draft_tokens_per_batch].
        seq_lens_encoder: int32 array; nonzero marks a prefill batch.
        max_seq_len: row stride of input_ids.
        max_draft_tokens_per_batch: row stride of draft_tokens.
        real_bsz: number of active batches.
        token_num: total token count, i.e. sum(seq_lens[:real_bsz]).

    Returns:
        ids_remove_padding: int64[token_num]
        batch_id_per_token: int32[token_num]
        cu_seqlens_q: int32[real_bsz + 1]
        cu_seqlens_k: int32[real_bsz + 1]
        seq_lens_output: int32[real_bsz]
        cu_seq_lens_q_output: int32[real_bsz + 1]
        batch_id_per_token_output: int32[real_bsz * max_draft_tokens_per_batch]
        real_output_token_num: int32[1]
    """
    # --- Part 1: ids_remove_padding, batch_id_per_token, cu_seqlens_q/k ---
    ids_remove_padding = np.zeros(token_num, dtype=np.int64)
    batch_id_per_token = np.zeros(token_num, dtype=np.int32)
    cu_seqlens_q = np.zeros(real_bsz + 1, dtype=np.int32)
    cu_seqlens_q[1:] = np.cumsum(seq_lens[:real_bsz])
    cu_seqlens_k = cu_seqlens_q.copy()

    for batch in range(real_bsz):
        offset = cu_seqlens_q[batch]
        # Decode batches (no encoder tokens) read from draft_tokens;
        # prefill batches read from input_ids.
        use_draft = max_draft_tokens_per_batch > 0 and seq_lens_encoder[batch] <= 0
        for pos in range(seq_lens[batch]):
            if use_draft:
                token = draft_tokens[batch * max_draft_tokens_per_batch + pos]
            else:
                token = input_ids[batch * max_seq_len + pos]
            ids_remove_padding[offset + pos] = token
            batch_id_per_token[offset + pos] = batch

    # --- Part 2: seq_lens_output ---
    seq_lens_output = np.zeros(real_bsz, dtype=np.int32)
    for batch in range(real_bsz):
        n = seq_lens[batch]
        if n == 0:
            seq_lens_output[batch] = 0          # inactive batch
        elif n == 1 or seq_lens_encoder[batch] != 0:
            seq_lens_output[batch] = 1          # single token or prefill
        else:
            seq_lens_output[batch] = n          # full draft window (decode)

    # --- Part 3: cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num ---
    cu_seq_lens_q_output = np.zeros(real_bsz + 1, dtype=np.int32)
    cu_seq_lens_q_output[1:] = np.cumsum(seq_lens_output)
    batch_id_per_token_output = np.zeros(real_bsz * max_draft_tokens_per_batch, dtype=np.int32)
    for batch in range(real_bsz):
        lo = cu_seq_lens_q_output[batch]
        hi = cu_seq_lens_q_output[batch + 1]
        batch_id_per_token_output[lo:hi] = batch

    real_output_token_num = np.array([cu_seq_lens_q_output[-1]], dtype=np.int32)

    return (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
        seq_lens_output,
        cu_seq_lens_q_output,
        batch_id_per_token_output,
        real_output_token_num,
    )
|
|
|
|
|
|
def build_inputs(
    real_bsz,
    max_seq_len,
    max_draft_tokens,
    seq_lens_list,
    seq_lens_encoder_list,
    draft_tokens_data=None,
    input_ids_data=None,
    seed=42,
):
    """
    Build a test-input dict from explicit seq_lens and seq_lens_encoder lists.

    draft_tokens_data / input_ids_data are optional; when omitted the
    corresponding tensor is filled with seeded random token ids in [1, 1000).
    """
    rng = np.random.default_rng(seed)
    seq_lens = np.array(seq_lens_list, dtype=np.int32)
    seq_lens_encoder = np.array(seq_lens_encoder_list, dtype=np.int32)

    # NOTE: input_ids must be drawn before draft_tokens so that seeded
    # random data stays reproducible across refactors.
    if input_ids_data is None:
        input_ids = rng.integers(1, 1000, size=(real_bsz, max_seq_len), dtype=np.int64)
    else:
        input_ids = np.array(input_ids_data, dtype=np.int64).reshape(real_bsz, max_seq_len)

    if draft_tokens_data is None:
        draft_tokens = rng.integers(1, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)
    else:
        draft_tokens = np.array(draft_tokens_data, dtype=np.int64).reshape(real_bsz, max_draft_tokens)

    return {
        "input_ids": input_ids,
        "seq_lens": seq_lens,
        "draft_tokens": draft_tokens,
        "seq_lens_encoder": seq_lens_encoder,
        # seq_lens_decoder is accepted by the op but unused in kernel logic.
        "seq_lens_decoder": np.zeros(real_bsz, dtype=np.int32),
        "max_seq_len": max_seq_len,
        "max_draft_tokens": max_draft_tokens,
        "token_num": int(np.sum(seq_lens)),
        "real_bsz": real_bsz,
    }
|
|
|
|
|
|
def run_and_compare(tc, inputs):
    """
    Run the device op and the Python reference on the same inputs and
    compare every output exactly.

    Args:
        tc: unittest.TestCase instance (kept for assertion context/messages).
        inputs: dict produced by build_inputs().
    """
    real_bsz = inputs["real_bsz"]
    max_seq_len = inputs["max_seq_len"]
    max_draft_tokens = inputs["max_draft_tokens"]
    token_num = inputs["token_num"]

    t_input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
    t_seq_lens = paddle.to_tensor(inputs["seq_lens"], dtype="int32")
    t_draft_tokens = paddle.to_tensor(inputs["draft_tokens"], dtype="int64")
    t_seq_lens_encoder = paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32")
    t_seq_lens_decoder = paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32")

    gpu_outs = speculate_pre_process(
        token_num, t_input_ids, t_seq_lens, t_draft_tokens, t_seq_lens_encoder, t_seq_lens_decoder
    )

    ref_outs = speculate_pre_process_ref(
        input_ids=inputs["input_ids"].reshape(-1),
        seq_lens=inputs["seq_lens"],
        draft_tokens=inputs["draft_tokens"].reshape(-1),
        seq_lens_encoder=inputs["seq_lens_encoder"],
        max_seq_len=max_seq_len,
        max_draft_tokens_per_batch=max_draft_tokens,
        real_bsz=real_bsz,
        token_num=token_num,
    )

    output_names = [
        "ids_remove_padding",
        "batch_id_per_token",
        "cu_seqlens_q",
        "cu_seqlens_k",
        "cu_seq_lens_q_output",
        "batch_id_per_token_output",
        "real_output_token_num",
    ]
    # GPU op returns 7 tensors; ref returns 8 (with seq_lens_output at index 4).
    # GPU output order: ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k,
    #                   cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num
    # Ref output order: ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k,
    #                   seq_lens_output, cu_seq_lens_q_output, batch_id_per_token_output,
    #                   real_output_token_num
    ref_indices = [0, 1, 2, 3, 5, 6, 7]  # skip seq_lens_output (index 4) for direct comparison
    for name, gpu_idx, ref_idx in zip(output_names, range(7), ref_indices):
        gpu_val = gpu_outs[gpu_idx].numpy()
        ref_val = ref_outs[ref_idx]
        # Trim batch_id_per_token_output to the valid portion (real_output_token_num).
        # The kernel only writes valid positions; beyond that the content is undefined.
        if name == "batch_id_per_token_output":
            valid_len = int(ref_outs[7][0])  # real_output_token_num
            gpu_val = gpu_val[:valid_len]
            ref_val = ref_val[:valid_len]
        # All outputs are integer tensors, so compare exactly. assert_allclose
        # applies a relative tolerance that could silently accept off-by-one
        # mismatches on large token ids.
        np.testing.assert_array_equal(
            gpu_val,
            ref_val,
            err_msg=f"Mismatch in output '{name}'",
        )
|
|
|
|
|
|
class TestSpeculatePreProcess(unittest.TestCase):
    """Unit tests for speculate_pre_process custom operator."""

    # ----------------------------------------------------------------
    # Test 1: mixed batch covering all 4 seq_lens_output branches
    #   bid=0: seq_lens=0            => output=0 (skip)
    #   bid=1: seq_lens=1, encoder=0 => output=1, read draft_tokens
    #   bid=2: seq_lens=5, encoder=3 => output=1, read input_ids (prefill)
    #   bid=3: seq_lens=4, encoder=0 => output=4, read draft_tokens (decode)
    #   bid=4: seq_lens=1, encoder=2 => output=1, read input_ids (prefill single)
    #   bid=5: seq_lens=8, encoder=0 => output=8, read draft_tokens (decode saturated)
    # ----------------------------------------------------------------
    def test_mixed_batch_all_branches(self):
        """Compare op vs reference on a batch exercising every output branch."""
        inputs = build_inputs(
            real_bsz=6,
            max_seq_len=16,
            max_draft_tokens=8,
            seq_lens_list=[0, 1, 5, 4, 1, 8],
            seq_lens_encoder_list=[0, 0, 3, 0, 2, 0],
        )
        run_and_compare(self, inputs)

    # ----------------------------------------------------------------
    # Test 2: token_num=0 early return — verify no crash, 7 outputs
    # ----------------------------------------------------------------
    def test_all_zero_seq_lens(self):
        """With token_num=0 the op must still return 7 usable (copyable, sliceable) outputs."""
        real_bsz = 3
        t_input_ids = paddle.zeros([real_bsz, 8], dtype="int64")
        t_seq_lens = paddle.zeros([real_bsz], dtype="int32")
        t_draft_tokens = paddle.zeros([real_bsz, 4], dtype="int64")
        t_seq_lens_encoder = paddle.zeros([real_bsz], dtype="int32")
        t_seq_lens_decoder = paddle.zeros([real_bsz], dtype="int32")

        gpu_outs = speculate_pre_process(
            0, t_input_ids, t_seq_lens, t_draft_tokens, t_seq_lens_encoder, t_seq_lens_decoder
        )
        self.assertEqual(len(gpu_outs), 7)
        self.assertIsNotNone(gpu_outs[-3])
        self.assertIsNotNone(gpu_outs[-2])
        self.assertIsNotNone(gpu_outs[-1])
        # test copy: the trailing outputs must be real tensors even on the early-return path
        fake_cu_seqlens_q_output = paddle.empty([real_bsz + 1], dtype="int32")
        fake_batch_id_per_token_output = paddle.empty([real_bsz], dtype="int32")
        fake_cu_seqlens_q_output.copy_(gpu_outs[-3])
        fake_batch_id_per_token_output.copy_(gpu_outs[-2])
        # test slice: real_output_token_num must be a scalar usable as a slice bound
        fake_batch_id_per_token_output[: gpu_outs[-1].item()]

    # ----------------------------------------------------------------
    # Test 3: exact token values — manually verify ids_remove_padding
    #   bid=0: encoder=0 (decode)  => draft_tokens[0][0:3] = [10,11,12]
    #   bid=1: encoder=5 (prefill) => input_ids[1][0:2]    = [200,201]
    # ----------------------------------------------------------------
    def test_exact_token_values(self):
        """Pin exact output values for a tiny hand-computed batch."""
        inputs = build_inputs(
            real_bsz=2,
            max_seq_len=4,
            max_draft_tokens=4,
            seq_lens_list=[3, 2],
            seq_lens_encoder_list=[0, 5],
            draft_tokens_data=[[10, 11, 12, 13], [20, 21, 22, 23]],
            input_ids_data=[[100, 101, 102, 103], [200, 201, 202, 203]],
        )

        t_input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
        t_seq_lens = paddle.to_tensor(inputs["seq_lens"], dtype="int32")
        t_draft_tokens = paddle.to_tensor(inputs["draft_tokens"], dtype="int64")
        t_seq_lens_encoder = paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32")
        t_seq_lens_decoder = paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32")

        gpu_outs = speculate_pre_process(
            int(np.sum(inputs["seq_lens"])),
            t_input_ids,
            t_seq_lens,
            t_draft_tokens,
            t_seq_lens_encoder,
            t_seq_lens_decoder,
        )

        # Outputs are integers: compare exactly (assert_allclose would apply a
        # needless floating-point tolerance to integer token ids).
        np.testing.assert_array_equal(gpu_outs[0].numpy(), [10, 11, 12, 200, 201])
        np.testing.assert_array_equal(gpu_outs[1].numpy(), [0, 0, 0, 1, 1])
        np.testing.assert_array_equal(gpu_outs[2].numpy(), [0, 3, 5])
        np.testing.assert_array_equal(gpu_outs[6].numpy(), [4])  # real_output_token_num = 3+1

    # ----------------------------------------------------------------
    # Test 4: random stress test (2 configs covering small & medium batch)
    # ----------------------------------------------------------------
    def test_random_configs(self):
        """Randomized seq_lens / encoder flags across two batch-size configs."""
        configs = [
            {"real_bsz": 7, "max_seq_len": 32, "max_draft_tokens": 8, "seed": 200},
            {"real_bsz": 32, "max_seq_len": 128, "max_draft_tokens": 16, "seed": 400},
        ]
        for cfg in configs:
            with self.subTest(**cfg):
                rng = np.random.default_rng(cfg["seed"])
                real_bsz = cfg["real_bsz"]
                max_draft = cfg["max_draft_tokens"]
                seq_lens_list = rng.integers(0, max_draft + 1, size=real_bsz).tolist()
                seq_lens_encoder_list = rng.integers(0, 3, size=real_bsz).tolist()

                inputs = build_inputs(
                    real_bsz=real_bsz,
                    max_seq_len=cfg["max_seq_len"],
                    max_draft_tokens=max_draft,
                    seq_lens_list=seq_lens_list,
                    seq_lens_encoder_list=seq_lens_encoder_list,
                    seed=cfg["seed"],
                )
                # token_num=0 is covered by test_all_zero_seq_lens; skip here.
                if inputs["token_num"] == 0:
                    continue
                run_and_compare(self, inputs)
|
|
|
|
|
|
# Allow running this test module directly (e.g. `python test_speculate_pre_process.py`).
if __name__ == "__main__":
    unittest.main()
|