Files
FastDeploy/tests/operators/test_draft_model_preprocess.py
T

278 lines
11 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import draft_model_preprocess
def draft_model_preprocess_ref(
    draft_tokens,
    input_ids,
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_idx,
    not_need_stop,
    pre_ids,
    accept_tokens,
    accept_num,
    target_model_seq_lens_encoder,
    target_model_seq_lens_decoder,
    target_model_step_idx,
    target_model_stop_flags,
    max_dec_len,
    target_model_draft_tokens,
    num_model_step,
    is_splitwise_prefill,
):
    """Pure-Python reference for the ``draft_model_preprocess`` GPU op.

    All tensor arguments are updated in place, one batch row at a time, so the
    result can be compared with the op run on cloned inputs. The MTP (draft
    model) state is "shadow state": it is re-derived from the target model's
    state on every call.

    For each row ``b`` in ``range(seq_lens_this_time.shape[0])``:

    * ``target_model_draft_tokens[b, 1:]`` is reset to -1 (column 0 is kept).
    * The row is skipped if ``target_model_stop_flags[b]`` is set.
    * When ``is_splitwise_prefill`` is false, the row is also skipped if
      ``target_model_step_idx[b] + num_model_step >= max_dec_len[b]``.
    * A skipped row gets ``stop_flags[b] = True``, and ``seq_lens_this_time``,
      ``seq_lens_decoder``, ``seq_lens_encoder`` and ``step_idx`` are set to 0.
    * Otherwise ``stop_flags[b] = False`` and the row is prepared as either:

      - prefill (``seq_lens_encoder[b] > 0``): ``accept_tokens[b, 0]`` is
        written to ``pre_ids[b, 0]`` and to the last prompt slot of
        ``input_ids[b]``. Then ``seq_lens_this_time[b] = seq_lens_encoder[b]``
        and ``step_idx[b] = target_step - 1``.
      - decode: the ``n = accept_num[b]`` accepted tokens are copied into
        ``draft_tokens[b, :n]`` and ``pre_ids[b, target_step - n:target_step]``.
        ``seq_lens_decoder`` and ``step_idx`` are rolled back by ``n`` from the
        target model's values, and ``seq_lens_this_time[b] = n``.

    Finally ``not_need_stop[0]`` is set to whether any row is still running.

    ``target_model_seq_lens_encoder`` is accepted for signature parity with the
    op but is not read here.

    NOTE(review): an earlier docstring said that on a splitwise prefill node
    decode requests are marked stopped. This implementation does not do that:
    ``is_splitwise_prefill`` only disables the ``max_dec_len`` check. Confirm
    which behavior the CUDA kernel implements.
    """
    batch_size = seq_lens_this_time.shape[0]
    draft_width = target_model_draft_tokens.shape[1]
    running_rows = 0

    for row in range(batch_size):
        # Row handles: the in-place writes below rely on row indexing
        # returning a writable view of the caller's tensor.
        accepted_row = accept_tokens[row]
        draft_row = draft_tokens[row]
        n_accepted = int(accept_num[row])
        input_row = input_ids[row]
        target_draft_row = target_model_draft_tokens[row]
        pre_ids_row = pre_ids[row]
        target_step = int(target_model_step_idx[row])
        encoder_len = int(seq_lens_encoder[row])

        # Keep the first target draft token; invalidate the remainder.
        target_draft_row[1:draft_width] = -1

        skip = bool(target_model_stop_flags[row]) or (
            not is_splitwise_prefill and target_step + num_model_step >= int(max_dec_len[row])
        )
        if skip:
            stop_flags[row] = True
            seq_lens_this_time[row] = 0
            seq_lens_decoder[row] = 0
            seq_lens_encoder[row] = 0
            step_idx[row] = 0
            continue

        running_rows += 1
        stop_flags[row] = False

        if encoder_len > 0:
            # prefill | chunk_prefill | prompt_cache | recover after preempted:
            # the target model's first token closes the prompt.
            first_token = int(accepted_row[0])
            pre_ids_row[0] = first_token
            input_row[encoder_len - 1] = first_token
            seq_lens_this_time[row] = encoder_len
            # Shadow state: prefill just finished.
            step_idx[row] = target_step - 1
            continue

        # Decode: roll the shadow state back by the number of accepted tokens.
        seq_lens_decoder[row] = int(target_model_seq_lens_decoder[row]) - n_accepted
        step_idx[row] = target_step - n_accepted
        for i in range(n_accepted):
            token = int(accepted_row[i])
            draft_row[i] = token
            # Accepted tokens fill the n_accepted slots ending just before target_step.
            pre_ids_row[target_step - n_accepted + i] = token
        seq_lens_this_time[row] = n_accepted

    not_need_stop[0] = running_rows > 0
class TestDraftModelPreprocess(unittest.TestCase):
    """Compare the ``draft_model_preprocess`` GPU op against the Python reference.

    Each case builds seeded random inputs, runs ``draft_model_preprocess_ref``
    on the originals and the op on clones (both mutate in place), and asserts
    that the two agree on the mutated outputs.
    """

    # Only the first NUM_CHECKED_OUTPUTS positional inputs (draft_tokens ...
    # pre_ids) are compared.
    # NOTE(review): the reference also rewrites target_model_draft_tokens
    # (index 16). Consider comparing it as well once it is confirmed that the
    # op writes it in place.
    NUM_CHECKED_OUTPUTS = 9

    def _build_inputs(self, bsz: int, is_splitwise_prefill: bool, all_decode: bool) -> tuple:
        """Build the positional argument tuple for the op under a fixed seed.

        Args:
            bsz: Batch size.
            is_splitwise_prefill: Passed through as the last op argument.
            all_decode: If True, every row is a decode row
                (``seq_lens_encoder == 0``). Otherwise ``seq_lens_encoder`` is
                random.

        The sequence of ``paddle.randint`` calls for each mode matches the
        original per-case builders, so the seeded tensors are unchanged.
        """
        paddle.seed(2022)
        draft_tokens_len = 4
        input_ids_len = 100
        max_draft_token = 10
        draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64")
        input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64")
        stop_flags = paddle.zeros([bsz], dtype="bool")
        seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32")
        if all_decode:
            seq_lens_encoder = paddle.zeros([bsz], dtype="int32")
            seq_lens_decoder = paddle.randint(max_draft_token + 1, 100, [bsz], dtype="int32")
        else:
            # NOTE: randint(0, input_ids_len) only rarely yields 0, so in practice
            # nearly every row here is a prefill row. Decode rows are exercised by
            # the all_decode cases.
            seq_lens_encoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
            seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
        step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
        not_need_stop = paddle.zeros([1], dtype="bool")
        pre_ids = input_ids.clone()
        accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
        # accept_num must not exceed draft_tokens_len to avoid out-of-bounds writes
        accept_num = paddle.randint(1, draft_tokens_len + 1, [bsz], dtype="int32")
        target_model_seq_lens_encoder = seq_lens_encoder.clone()
        target_model_seq_lens_decoder = paddle.randint(max_draft_token + 1, 100, [bsz], dtype="int32")
        target_model_step_idx = paddle.randint(max_draft_token + 1, 100, [bsz], dtype="int64")
        target_model_stop_flags = paddle.zeros([bsz], dtype="bool")
        max_dec_len = paddle.full([bsz], 200, dtype="int64")  # int64 to match CUDA kernel
        target_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64")
        num_model_step = max_draft_token
        return (
            draft_tokens,
            input_ids,
            stop_flags,
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
            step_idx,
            not_need_stop,
            pre_ids,
            accept_tokens,
            accept_num,
            target_model_seq_lens_encoder,
            target_model_seq_lens_decoder,
            target_model_step_idx,
            target_model_stop_flags,
            max_dec_len,
            target_model_draft_tokens,
            num_model_step,
            is_splitwise_prefill,
        )

    def _run_both(self, inputs):
        """Run the reference on ``inputs`` and the op on clones; return both argument lists."""
        # Both implementations modify their arguments in place, so the op gets its own copies.
        inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs]
        draft_model_preprocess_ref(*inputs)
        draft_model_preprocess(*inputs_clone)
        return inputs, inputs_clone

    def _run_case(self, is_splitwise_prefill: bool):
        """Run a bsz=10 case with random seq_lens_encoder."""
        return self._run_both(self._build_inputs(10, is_splitwise_prefill, all_decode=False))

    def _run_case_bsz(self, bsz: int, is_splitwise_prefill: bool):
        """Run an all-decode case with a configurable bsz."""
        return self._run_both(self._build_inputs(bsz, is_splitwise_prefill, all_decode=True))

    def _assert_outputs_match(self, results_ref, results_op, label: str):
        """Assert the reference and op agree on the first NUM_CHECKED_OUTPUTS tensors."""
        for i in range(self.NUM_CHECKED_OUTPUTS):
            np.testing.assert_equal(
                results_ref[i].numpy(),
                results_op[i].numpy(),
                err_msg=f"Mismatch at output index {i} ({label})",
            )

    def test_decode_mode(self):
        """Normal decode mode: is_splitwise_prefill=False"""
        results1, results2 = self._run_case(is_splitwise_prefill=False)
        self._assert_outputs_match(results1, results2, "decode mode")

    def test_splitwise_prefill_mode(self):
        """Splitwise prefill node: is_splitwise_prefill=True"""
        results1, results2 = self._run_case(is_splitwise_prefill=True)
        self._assert_outputs_match(results1, results2, "splitwise prefill mode")

    def test_max_bsz(self):
        """bsz == kBlockSize (1024) should succeed."""
        results1, results2 = self._run_case_bsz(bsz=1024, is_splitwise_prefill=False)
        self._assert_outputs_match(results1, results2, "bsz=1024")

    def test_bsz_exceeds_block_size(self):
        """bsz > kBlockSize (1024) should make the op raise."""
        inputs = self._build_inputs(1025, is_splitwise_prefill=False, all_decode=True)
        # Only the op call is inside assertRaises, so a failure in input
        # construction cannot make this test pass by accident.
        # NOTE(review): narrow `Exception` once the op's error type is pinned down.
        with self.assertRaises(Exception):
            draft_model_preprocess(*inputs)
if __name__ == "__main__":
    # Running the file directly executes the TestCase classes in this module
    # (verbosity=1 is unittest's default).
    unittest.main(verbosity=1)