mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
Merge origin/release/2.6 and resolve worker_process conflict
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel(
|
||||
const int max_batch_size) {
|
||||
constexpr int VecSize = 4;
|
||||
const uint32_t tid = threadIdx.x, bid = blockIdx.x;
|
||||
int startend_row_vec[4];
|
||||
int startend_row_vec[2];
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
@@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel(
|
||||
const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
|
||||
|
||||
startend_row_vec[0] = this_batch_q_end;
|
||||
startend_row_vec[1] = cu_seqlens_q[max_batch_size];
|
||||
startend_row_vec[2] = 0;
|
||||
startend_row_vec[3] = this_batch_q_end;
|
||||
// startend_row_vec[1] = cu_seqlens_q[max_batch_size];
|
||||
// startend_row_vec[2] = 0;
|
||||
startend_row_vec[1] = this_batch_q_end;
|
||||
for (int this_batch_q_idx = this_batch_q_start;
|
||||
this_batch_q_idx < this_batch_q_end;
|
||||
++this_batch_q_idx) {
|
||||
@@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel(
|
||||
: this_batch_q_idx - this_batch_q_start + kv_len -
|
||||
(this_batch_q_len);
|
||||
if (cache_k_idx <= append_mask_k_end) {
|
||||
startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx);
|
||||
startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx);
|
||||
// 可提前跳出循环
|
||||
break;
|
||||
}
|
||||
}
|
||||
reinterpret_cast<int4*>(startend_row_indices_ptr +
|
||||
cu_seqlens_k_idx * 4)[0] =
|
||||
reinterpret_cast<int4*>(startend_row_vec)[0];
|
||||
reinterpret_cast<int2*>(startend_row_indices_ptr +
|
||||
cu_seqlens_k_idx * 2)[0] =
|
||||
reinterpret_cast<int2*>(startend_row_vec)[0];
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
@@ -82,7 +82,7 @@ std::vector<paddle::Tensor> get_attn_mask_q(
|
||||
const paddle::optional<paddle::Tensor>& attn_mask_kv,
|
||||
const int kv_token_num) {
|
||||
paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
|
||||
{1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());
|
||||
{1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place());
|
||||
const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
|
||||
constexpr int block_size = 512;
|
||||
int grid_size = div_up(kv_token_num, block_size);
|
||||
@@ -123,7 +123,7 @@ std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
|
||||
const std::vector<int64_t>& cu_seqlens_k_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
|
||||
const int kv_token_num) {
|
||||
return {{1, 1, kv_token_num, 4}};
|
||||
return {{1, 1, kv_token_num, 2}};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_attn_mask_q)
|
||||
|
||||
@@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
int64_t* next_tokens, // [bs, tokens_per_step]
|
||||
const int* max_think_lens, // [bs]
|
||||
int* max_reply_lens, // [bs]
|
||||
int64_t* step_idx, // [bs]
|
||||
const int64_t* step_idx, // [bs]
|
||||
const int64_t* eos_token_ids, // [eos_len]
|
||||
int* limit_status, // [bs]
|
||||
int* accept_num, // [bs]
|
||||
@@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
int new_accept_num = original_accept_num;
|
||||
|
||||
// 本 step 的 token offset 对应的绝对 step
|
||||
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
|
||||
const int64_t current_base_step = step_idx[bid] + 1;
|
||||
|
||||
for (int token_offset = 0; token_offset < original_accept_num;
|
||||
token_offset++) {
|
||||
@@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
// inject_token_ids[0])
|
||||
if (status == 0 &&
|
||||
(current_step - 1) ==
|
||||
max_think_len) { // current_step - 1 是因为 speculate_verify 里
|
||||
// step_idx + 1 了
|
||||
max_think_len) { // current_step - 1 : 已输出 current_step-1
|
||||
// 个thinking token
|
||||
status = (inject_len > 0) ? 1 : done_status;
|
||||
}
|
||||
} else if (max_think_len == 0) {
|
||||
@@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 step_idx / accept_num(被截断的 token 需要回退
|
||||
// step_idx)
|
||||
const int discarded_tokens = original_accept_num - new_accept_num;
|
||||
if (discarded_tokens > 0) {
|
||||
step_idx[bid] -= discarded_tokens;
|
||||
}
|
||||
|
||||
accept_num[bid] = new_accept_num;
|
||||
limit_status[bid] = status;
|
||||
max_reply_lens[bid] = max_reply_len;
|
||||
@@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength(
|
||||
const_cast<int64_t*>(next_tokens.data<int64_t>()),
|
||||
max_think_lens.data<int>(),
|
||||
const_cast<int*>(max_reply_lens.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
step_idx.data<int64_t>(),
|
||||
eos_token_ids.data<int64_t>(),
|
||||
const_cast<int*>(limit_status.data<int>()),
|
||||
const_cast<int*>(accept_num.data<int>()),
|
||||
|
||||
@@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
const int64_t step_idx_now = step_idx[bid];
|
||||
const int64_t min_token_limit = min_tokens[bid];
|
||||
|
||||
const bool can_stop = (step_idx_now >= min_token_limit);
|
||||
const bool can_stop = (step_idx_now + accept_num >= min_token_limit);
|
||||
if (!can_stop) return;
|
||||
if (!stop_flags[bid]) {
|
||||
int accept_idx = 0;
|
||||
/*
|
||||
accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based)
|
||||
accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾
|
||||
(pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。
|
||||
为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1]
|
||||
(当前轮最后一个 token),该 token 延迟到下一轮匹配。
|
||||
循环范围:accept_num > 0 时为 [-1, accept_num-2];
|
||||
accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。
|
||||
*/
|
||||
int accept_idx = -1;
|
||||
bool is_end = false;
|
||||
// 遍历起始位置
|
||||
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
|
||||
|
||||
// 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾
|
||||
// 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens
|
||||
// 中的匹配。两者共享同一套从后向前匹配逻辑。
|
||||
int loop_end = (accept_num > 0) ? accept_num - 2 : -1;
|
||||
for (; accept_idx <= loop_end && !is_end; accept_idx++) {
|
||||
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf("num %d < stop_seq_len %d\n",
|
||||
step_idx_now - accept_num + accept_idx + 1,
|
||||
step_idx_now + accept_idx + 1,
|
||||
stop_seq_len);
|
||||
#endif
|
||||
continue;
|
||||
}
|
||||
// 遍历一个 stop_seqs
|
||||
// 从后向前匹配 stop_seq 的每个 token
|
||||
for (int i = stop_seq_len - 1; i >= 0; --i) {
|
||||
int64_t cur_token_idx = -1;
|
||||
|
||||
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
|
||||
if (stop_seq_len - 1 - i < accept_idx) {
|
||||
int offset = stop_seq_len - 1 - i;
|
||||
int accept_tokens_idx = accept_idx - offset;
|
||||
|
||||
if (accept_tokens_idx >= 0) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf(
|
||||
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
|
||||
"accept_token_idx: "
|
||||
"%d\n",
|
||||
"accept_token_idx: %d\n",
|
||||
bid,
|
||||
tid,
|
||||
accept_idx,
|
||||
accept_idx - (stop_seq_len - 1 - i) - 1);
|
||||
accept_tokens_idx);
|
||||
#endif
|
||||
cur_token_idx =
|
||||
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
|
||||
cur_token_idx = accept_tokens_now[accept_tokens_idx];
|
||||
} else {
|
||||
int pre_ids_idx = step_idx_now + accept_tokens_idx;
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf(
|
||||
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
|
||||
"accept_idx:%d. "
|
||||
"pre_id_idx: %ld\n",
|
||||
"accept_idx:%d. pre_id_idx: %d\n",
|
||||
bid,
|
||||
tid,
|
||||
step_idx_now,
|
||||
accept_idx,
|
||||
step_idx_now - accept_num + accept_idx -
|
||||
(stop_seq_len - 1 - i));
|
||||
pre_ids_idx);
|
||||
#endif
|
||||
int pre_ids_idx =
|
||||
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
|
||||
// EC3
|
||||
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
|
||||
// 导致异常结束
|
||||
if (pre_ids_idx <= 0) {
|
||||
break;
|
||||
}
|
||||
if (pre_ids_idx < 0) break;
|
||||
cur_token_idx = pre_ids_now[pre_ids_idx];
|
||||
}
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
@@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
}
|
||||
if (is_end) {
|
||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||
printf("bid:%d end with accept_idx %d", bid, accept_idx);
|
||||
printf("bid:%d end with accept_idx %d\n", bid, accept_idx);
|
||||
#endif
|
||||
|
||||
accept_nums[bid] = accept_idx;
|
||||
accept_tokens_now[accept_idx - 1] = end_ids[0];
|
||||
// stop_flags[bid] = true;
|
||||
// accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置
|
||||
accept_nums[bid] = accept_idx + 1;
|
||||
accept_tokens_now[accept_idx] = end_ids[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||
int64_t *token_ids_all_now =
|
||||
&token_ids_all[batch_id * max_model_len + prompt_len];
|
||||
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||
int64_t base = cur_step_idx - output_len + 1;
|
||||
int64_t base = cur_step_idx - output_len;
|
||||
for (int i = 0; i < output_len; i++) {
|
||||
token_ids_all_now[base + i] = output_ids[i];
|
||||
}
|
||||
|
||||
@@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Whether to enable the decode caches requests for preallocating resource
|
||||
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
||||
|
||||
# Batched token timeout in EP
|
||||
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
|
||||
|
||||
# Max pre-fetch requests number in PD
|
||||
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
||||
|
||||
|
||||
@@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# 是否启用 decode 缓存请求以预分配资源
|
||||
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
||||
|
||||
# EP 中批处理 token 的超时时间
|
||||
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
|
||||
|
||||
# PD 中最大预取请求数量
|
||||
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
||||
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Test splitwise deployment
|
||||
# There are two methods for splitwise deployment:
|
||||
# v0: using splitwise_scheduler or dp_scheduler (deprecated)
|
||||
# v1: using local_scheduler + router
|
||||
|
||||
# prepare environment
|
||||
export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
||||
export FD_DEBUG=1
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
||||
|
||||
SCRIPT_PATH=$(readlink -f "$0")
|
||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
|
||||
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
|
||||
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
unset http_proxy && unset https_proxy
|
||||
source ${SCRIPT_DIR}/utils.sh
|
||||
|
||||
P_PORT=52400
|
||||
D_PORT=52500
|
||||
REDIS_PORT="${REDIS_PORT:-6379}"
|
||||
LOG_DATE=$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
ports=(
|
||||
$P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5))
|
||||
$D_PORT $((D_PORT + 1)) $((D_PORT + 2)) $((D_PORT + 3)) $((D_PORT + 4)) $((D_PORT + 5))
|
||||
$REDIS_PORT
|
||||
)
|
||||
check_ports "${ports[@]}" || {
|
||||
echo "❌ Some ports are in use. Please release them."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# start redis
|
||||
if ! redis-cli -p ${REDIS_PORT} ping &>/dev/null; then
|
||||
echo "Redis is not running. Starting redis-server..."
|
||||
redis-server --daemonize yes --port ${REDIS_PORT}
|
||||
sleep 1
|
||||
else
|
||||
echo "Redis is already running."
|
||||
fi
|
||||
sleep 1
|
||||
|
||||
# start prefill
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
export FD_LOG_DIR="log/$LOG_DATE/prefill"
|
||||
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port ${P_PORT} \
|
||||
--metrics-port $((P_PORT + 1)) \
|
||||
--engine-worker-queue-port $((P_PORT + 2)) \
|
||||
--cache-queue-port $((P_PORT + 3)) \
|
||||
--max-model-len 32768 \
|
||||
--num-gpu-blocks-override 1000 \
|
||||
--splitwise-role "prefill" \
|
||||
--cache-transfer-protocol "rdma" \
|
||||
--rdma-comm-ports $((P_PORT + 4)) \
|
||||
--pd-comm-port $((P_PORT + 5)) \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port ${REDIS_PORT} \
|
||||
--scheduler-ttl 9000 \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
|
||||
wait_for_health ${P_PORT}
|
||||
|
||||
# start decode
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export FD_LOG_DIR="log/$LOG_DATE/decode"
|
||||
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
|
||||
|
||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port ${D_PORT} \
|
||||
--metrics-port $((D_PORT + 1)) \
|
||||
--engine-worker-queue-port $((D_PORT + 2)) \
|
||||
--cache-queue-port $((D_PORT + 3)) \
|
||||
--max-model-len 32768 \
|
||||
--splitwise-role "decode" \
|
||||
--cache-transfer-protocol "rdma" \
|
||||
--rdma-comm-ports $((D_PORT + 4)) \
|
||||
--pd-comm-port $((D_PORT + 5)) \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port ${REDIS_PORT} \
|
||||
--scheduler-ttl 9000 \
|
||||
2>&1 >${FD_LOG_DIR}/nohup &
|
||||
|
||||
wait_for_health ${D_PORT}
|
||||
|
||||
|
||||
# send request
|
||||
sleep 10 # make sure server is registered to router
|
||||
echo "send request..."
|
||||
curl -X POST "http://0.0.0.0:${D_PORT}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"}
|
||||
],
|
||||
"max_tokens": 20,
|
||||
"stream": false
|
||||
}'
|
||||
@@ -2009,13 +2009,13 @@ class FDConfig:
|
||||
and self.router_config
|
||||
and self.router_config.router
|
||||
):
|
||||
# For RL scenario: version.yaml will be required for models in future releases.
|
||||
# For RL scenario, version.yaml is required for models
|
||||
# Temporarily enforce use router to be enabled.
|
||||
self.model_config.read_model_version()
|
||||
|
||||
self.read_from_config()
|
||||
self.postprocess()
|
||||
self.init_cache_info()
|
||||
self.init_pd_info()
|
||||
if test_mode:
|
||||
return
|
||||
self.check()
|
||||
@@ -2348,18 +2348,17 @@ class FDConfig:
|
||||
logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
logger.info("=============================================================")
|
||||
|
||||
def init_cache_info(self):
|
||||
def init_pd_info(self):
|
||||
"""
|
||||
initialize cache info
|
||||
initialize info for pd deployment
|
||||
"""
|
||||
# TODO: group the splitiwse params
|
||||
# There are two methods for splitwise deployment:
|
||||
# 1. v0 splitwise_scheduler or dp_scheduler
|
||||
# 2. v1 local_scheduler + router
|
||||
# 2. v1 local_scheduler + router (optional)
|
||||
self.splitwise_version = None
|
||||
if self.scheduler_config.name in ("splitwise", "dp"):
|
||||
self.splitwise_version = "v0"
|
||||
elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
|
||||
elif self.scheduler_config.name == "local":
|
||||
self.splitwise_version = "v1"
|
||||
|
||||
# the information for registering this server to router or splitwise_scheduler
|
||||
|
||||
@@ -592,10 +592,15 @@ class EngineArgs:
|
||||
raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1")
|
||||
|
||||
if self.splitwise_role != "mixed":
|
||||
if self.scheduler_name == "local" and self.router is None:
|
||||
if self.scheduler_name == "splitwise":
|
||||
raise ValueError(
|
||||
f"When using {self.splitwise_role} role and the {self.scheduler_name} "
|
||||
f"scheduler, please provide --router argument."
|
||||
"Setting scheduler_name as splitwise is not supported in pd deployment, "
|
||||
"please use router as scheduler."
|
||||
)
|
||||
if self.scheduler_name == "local" and self.router is None:
|
||||
console_logger.warning(
|
||||
f"Running {self.splitwise_role} role with {self.scheduler_name} "
|
||||
f"scheduler without --router. Router registration and request routing will be disabled."
|
||||
)
|
||||
|
||||
if not (
|
||||
|
||||
@@ -367,15 +367,6 @@ class EngineService:
|
||||
create=True,
|
||||
)
|
||||
|
||||
engine_forward_signal_data = np.zeros([1], dtype=np.int32)
|
||||
self.engine_forward_signal = IPCSignal(
|
||||
name="engine_forward_signal",
|
||||
array=engine_forward_signal_data,
|
||||
dtype=np.int32,
|
||||
suffix=current_suffix,
|
||||
create=True,
|
||||
)
|
||||
|
||||
# worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
|
||||
worker_healthy_live_recorded_time_array = np.zeros(
|
||||
shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
|
||||
@@ -1037,7 +1028,16 @@ class EngineService:
|
||||
with self._pause_cond:
|
||||
self._pause_cond.wait_for(lambda: not self.is_paused)
|
||||
try:
|
||||
if self.engine_worker_queue.exist_tasks():
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
if not is_fetching:
|
||||
is_fetching = True
|
||||
get_request_pool.submit(_fetch_request)
|
||||
|
||||
else:
|
||||
if len(self.resource_manager.waiting) == 0 and (not is_fetching):
|
||||
# Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool.
|
||||
try:
|
||||
is_fetching = True
|
||||
@@ -1048,18 +1048,6 @@ class EngineService:
|
||||
break
|
||||
else:
|
||||
raise
|
||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||
# Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished.
|
||||
# Once the forward pass finishes, these accumulated requests can be scheduled in larger,
|
||||
# more efficient batches.
|
||||
if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
else:
|
||||
# In mixed, todo: optimze cache swap, to decouple swap from scheduler
|
||||
if self.engine_worker_queue.exist_tasks():
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
|
||||
self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()
|
||||
@@ -1120,13 +1108,6 @@ class EngineService:
|
||||
elif not task.has_been_preempted_before:
|
||||
task.metrics.inference_start_time = time.time()
|
||||
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
|
||||
else:
|
||||
# When there are no actual tasks to schedule, send an empty task batch to EP workers.
|
||||
# This helps EP workers barrier for syncing tasks not hang.
|
||||
if self.cfg.parallel_config.enable_expert_parallel:
|
||||
self.engine_worker_queue.put_tasks(
|
||||
([], self.resource_manager.real_bsz)
|
||||
) # Empty (as idle tasks for ep)
|
||||
|
||||
# 4. Response error tasks
|
||||
if error_tasks:
|
||||
|
||||
@@ -109,7 +109,7 @@ class ExpertService:
|
||||
if envs.FD_ENABLE_RETURN_TEXT:
|
||||
self.engine.create_data_processor()
|
||||
if self.cfg.scheduler_config.name == "dp":
|
||||
self.cfg.init_cache_info()
|
||||
self.cfg.init_pd_info()
|
||||
self.engine.scheduler.start(local_data_parallel_id)
|
||||
|
||||
if ipc_signal_suffix is not None:
|
||||
@@ -122,7 +122,7 @@ class ExpertService:
|
||||
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
|
||||
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.cfg.init_cache_info()
|
||||
self.cfg.init_pd_info()
|
||||
role = self.cfg.scheduler_config.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
|
||||
|
||||
@@ -927,6 +927,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
if (
|
||||
self.config.cache_config.enable_prefix_caching
|
||||
and self.config.scheduler_config.splitwise_role != "decode"
|
||||
and self.config.scheduler_config.splitwise_role != "prefill"
|
||||
):
|
||||
self.cache_manager.update_cache_blocks(
|
||||
request, self.config.cache_config.block_size, request.num_computed_tokens
|
||||
@@ -1374,6 +1375,11 @@ class ResourceManagerV1(ResourceManager):
|
||||
self.stop_flags[request.idx] = False
|
||||
self.requests[request.request_id] = request
|
||||
self.req_dict[request.request_id] = allocated_position
|
||||
|
||||
self.cache_manager.update_cache_blocks(
|
||||
request, self.config.cache_config.block_size, request.need_prefill_tokens
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
self._free_blocks(request)
|
||||
|
||||
@@ -111,7 +111,7 @@ class ErnieX1ToolParser(ToolParser):
|
||||
)
|
||||
)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tools_called=len(tool_calls) > 0,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
except Exception:
|
||||
@@ -182,11 +182,13 @@ class ErnieX1ToolParser(ToolParser):
|
||||
logger.debug("attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
||||
if diff:
|
||||
if '"}' not in delta_text:
|
||||
if diff is not None:
|
||||
if "}" not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex("}")
|
||||
diff = delta_text[:end_loc]
|
||||
if not diff:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not " "been streamed yet: %s",
|
||||
diff,
|
||||
@@ -248,15 +250,15 @@ class ErnieX1ToolParser(ToolParser):
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
if not cur_arguments and not prev_arguments:
|
||||
if cur_arguments is None and prev_arguments is None:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
elif not cur_arguments and prev_arguments:
|
||||
elif cur_arguments is None and prev_arguments is not None:
|
||||
logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
elif cur_arguments and not prev_arguments:
|
||||
elif cur_arguments is not None and prev_arguments is None:
|
||||
function_name = current_tool_call.get("name")
|
||||
match = re.search(
|
||||
r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
|
||||
@@ -265,6 +267,19 @@ class ErnieX1ToolParser(ToolParser):
|
||||
)
|
||||
if match:
|
||||
cur_arguments_json = match.group(1)
|
||||
# When tool_call_portion is complete JSON, the regex
|
||||
# (.*) over-captures the outer closing brace of the
|
||||
# tool call object. Strip it from both
|
||||
# cur_arguments_json and delta_text, consistent with
|
||||
# the both-have-arguments branch handling.
|
||||
try:
|
||||
json.loads(tool_call_portion)
|
||||
if cur_arguments_json.endswith("}"):
|
||||
cur_arguments_json = cur_arguments_json[:-1]
|
||||
if delta_text.rstrip().endswith("}"):
|
||||
delta_text = delta_text.rstrip()[:-1]
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
|
||||
|
||||
@@ -287,7 +302,7 @@ class ErnieX1ToolParser(ToolParser):
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
|
||||
|
||||
elif cur_arguments and prev_arguments:
|
||||
elif cur_arguments is not None and prev_arguments is not None:
|
||||
try:
|
||||
json.loads(tool_call_portion)
|
||||
is_complete_json = True
|
||||
|
||||
+16
-11
@@ -145,6 +145,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
|
||||
# Whether to enable the decode caches requests for preallocating resource
|
||||
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
||||
# Batched token timeout in EP
|
||||
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
|
||||
# Max pre-fetch requests number in PD
|
||||
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
||||
# Enable or disable model caching.
|
||||
@@ -210,17 +212,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""),
|
||||
# Whether to enable low latency in mixed scenario
|
||||
"FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
|
||||
# Whether to use phi FP8 quantization,if 1,use paddle default.
|
||||
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
|
||||
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
|
||||
# intended for training alignment. Defaults to 0 (disabled).
|
||||
"FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
|
||||
# Whether to use phi MOE permute,if 1,use paddle op.
|
||||
"FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
|
||||
# Whether to use phi rms_norm,if 1,use paddle op.
|
||||
"FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))),
|
||||
# Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function
|
||||
"FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
|
||||
# Reserve output blocks for decoding requests when schedule new prefill requests
|
||||
"FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
|
||||
os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16")
|
||||
@@ -262,8 +253,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
|
||||
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
|
||||
),
|
||||
# train-infer consistency, used in RL
|
||||
# Whether to align RoPE and moe gate precision with training
|
||||
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
|
||||
# Whether to use phi FP8 quantization,if 1,use paddle default.
|
||||
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
|
||||
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
|
||||
# intended for training alignment. Defaults to 0 (disabled).
|
||||
"FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
|
||||
# Whether to use phi MOE permute,if 1,use paddle op.
|
||||
"FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
|
||||
# Whether to use phi rms_norm,if 1,use paddle op.
|
||||
"FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))),
|
||||
# Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function
|
||||
"FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
|
||||
# Whether to enable FP8 quantization with pow2scale.
|
||||
"FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ class GraphOptWrapper:
|
||||
def __call__(self, **kwargs):
|
||||
return self.graph_opt_backend(**kwargs)
|
||||
|
||||
def clear_grpah_opt_backend(self, fd_config):
|
||||
def clear_graph_opt_backend(self, fd_config):
|
||||
""" """
|
||||
# TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights
|
||||
assert (
|
||||
|
||||
@@ -95,7 +95,7 @@ def init_flash_attn_version():
|
||||
logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.")
|
||||
|
||||
if FLASH_ATTN_VERSION is None:
|
||||
if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()):
|
||||
if sm_version == 90 and 90 in paddle.version.cuda_archs():
|
||||
FLASH_ATTN_VERSION = 3
|
||||
logger.info("The current platform supports Flash Attention V3.")
|
||||
else:
|
||||
|
||||
@@ -188,7 +188,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
|
||||
else:
|
||||
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
ffn_out,
|
||||
using_pow2_scale=not disable_ue8m0_cast,
|
||||
using_pow2_scale=not disable_ue8m0_cast or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||
using_ue8m0_scale=not disable_ue8m0_cast,
|
||||
)
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
||||
@@ -355,7 +355,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
else:
|
||||
x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
@@ -581,7 +581,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
else:
|
||||
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
ffn_out,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0
|
||||
or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
||||
@@ -773,7 +774,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
else:
|
||||
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
|
||||
@@ -1247,7 +1247,7 @@ def python_op_fused_moe_kernel_paddle(
|
||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
|
||||
else:
|
||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x, using_pow2_scale=False, output_scale_transpose=False
|
||||
x, using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=False
|
||||
)
|
||||
x_scale = x_scale[: x.shape[0]]
|
||||
|
||||
@@ -1305,7 +1305,9 @@ def python_op_fused_moe_kernel_paddle(
|
||||
)
|
||||
else:
|
||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
|
||||
intermediate_cache2,
|
||||
using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||
output_scale_transpose=False,
|
||||
)
|
||||
x_scale = x_scale[: x_q.shape[0]]
|
||||
|
||||
|
||||
@@ -343,7 +343,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
else:
|
||||
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||
x,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||
output_scale_transpose=True,
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
|
||||
@@ -1306,9 +1306,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
class DeepSeekV3PretrainedModel(PretrainedModel):
|
||||
|
||||
@@ -701,9 +701,9 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.ernie.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
|
||||
@@ -829,9 +829,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.ernie.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
||||
|
||||
@@ -563,9 +563,9 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
class Glm4MoePretrainedModel(PretrainedModel):
|
||||
|
||||
@@ -369,3 +369,7 @@ class Glm4MTPForCausalLM(ModelForCasualLM):
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
@@ -417,9 +417,9 @@ class Qwen2ForCausalLM(ModelForCasualLM):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.qwen2.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
class Qwen2PretrainedModel(PretrainedModel):
|
||||
|
||||
@@ -341,9 +341,9 @@ class Qwen3ForCausalLM(ModelForCasualLM):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
class Qwen3PretrainedModel(PretrainedModel):
|
||||
|
||||
@@ -382,9 +382,9 @@ class Qwen3VLForConditionalGeneration(ModelForCasualLM):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
class Qwen3VLPretrainedModel(PretrainedModel):
|
||||
|
||||
@@ -453,9 +453,9 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
class Qwen3MoePretrainedModel(PretrainedModel):
|
||||
|
||||
@@ -68,6 +68,7 @@ class RolloutModelConfig:
|
||||
routing_replay_config: str = None,
|
||||
load_choices: str = "default_v1",
|
||||
lm_head_fp32: bool = False,
|
||||
moe_gate_fp32: bool = True,
|
||||
):
|
||||
# Required parameters
|
||||
self.model = model_name_or_path
|
||||
@@ -121,6 +122,7 @@ class RolloutModelConfig:
|
||||
self.routing_replay_config = routing_replay_config
|
||||
self.load_choices = load_choices
|
||||
self.lm_head_fp32 = lm_head_fp32
|
||||
self.moe_gate_fp32 = moe_gate_fp32
|
||||
|
||||
def __str__(self):
|
||||
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Dict, List, Optional
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.scheduler.data import ScheduledResponse
|
||||
from fastdeploy.scheduler.local_scheduler import LocalScheduler
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.utils import envs, get_logger
|
||||
|
||||
|
||||
class DPLocalScheduler(LocalScheduler):
|
||||
@@ -131,19 +131,52 @@ class DPLocalScheduler(LocalScheduler):
|
||||
Returns:
|
||||
List of Request objects ready for processing
|
||||
"""
|
||||
# DP scheduler is used in V1, there is no need to manage request fetching in the scheduler, resource_manager_v1 will do that.
|
||||
if available_blocks <= reserved_output_blocks or batch < 1:
|
||||
self.scheduler_logger.debug(
|
||||
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
|
||||
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
|
||||
f"max_num_batched_tokens={max_num_batched_tokens}"
|
||||
)
|
||||
return []
|
||||
required_total_blocks = 0
|
||||
current_prefill_tokens = 0
|
||||
start_batch_time = time.time()
|
||||
requests: List[Request] = []
|
||||
|
||||
with self.requests_not_empty:
|
||||
while True:
|
||||
batch_ids = self.requests_not_empty.wait_for(
|
||||
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1],
|
||||
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
|
||||
0.005,
|
||||
)
|
||||
if batch_ids:
|
||||
for request_id in batch_ids:
|
||||
request = self.requests[request_id]
|
||||
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
|
||||
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||
if required_total_blocks > available_blocks:
|
||||
break
|
||||
|
||||
requests.append(request.raw)
|
||||
self.ids_read_cursor += 1
|
||||
start_batch_time = time.time()
|
||||
if current_prefill_tokens > max_num_batched_tokens:
|
||||
break
|
||||
if len(requests) >= batch:
|
||||
break
|
||||
if (
|
||||
(current_prefill_tokens > max_num_batched_tokens)
|
||||
or (len(requests) >= batch)
|
||||
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
|
||||
):
|
||||
break
|
||||
|
||||
if batch_ids:
|
||||
if len(batch_ids) > 0 and len(requests) == 0:
|
||||
self.scheduler_logger.debug(
|
||||
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
|
||||
)
|
||||
|
||||
if len(requests) > 0:
|
||||
self.scheduler_logger.info(
|
||||
|
||||
@@ -53,9 +53,6 @@ class InternalAdapter:
|
||||
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
|
||||
|
||||
available_block_num = self.engine.resource_manager.available_block_num()
|
||||
unhandled_request_num = self.engine.scheduler.get_unhandled_request_num()
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
unhandled_request_num = max(unhandled_request_num, len(self.engine.resource_manager.waiting))
|
||||
server_info = {
|
||||
"splitwise_role": self.cfg.scheduler_config.splitwise_role,
|
||||
"block_size": int(self.cfg.cache_config.block_size),
|
||||
@@ -65,7 +62,7 @@ class InternalAdapter:
|
||||
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
|
||||
"max_batch_size": int(available_batch_size),
|
||||
"max_input_token_num": self.cfg.model_config.max_model_len,
|
||||
"unhandled_request_num": unhandled_request_num,
|
||||
"unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
|
||||
"available_batch": int(self.engine.resource_manager.available_batch()),
|
||||
}
|
||||
return server_info
|
||||
|
||||
@@ -27,7 +27,7 @@ import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig
|
||||
from fastdeploy.engine.pooling_params import PoolingParams
|
||||
from fastdeploy.engine.request import ImagePosition, Request, RequestType
|
||||
from fastdeploy.model_executor.graph_optimization.utils import (
|
||||
@@ -2110,6 +2110,12 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self._cached_sampler_output = sampler_output
|
||||
self._cached_post_process_event = post_process_event
|
||||
else:
|
||||
if (
|
||||
self.fd_config.speculative_config.method == SpecMethod.MTP
|
||||
and hasattr(self.proposer.model, "empty_input_forward")
|
||||
and self.parallel_config.use_ep
|
||||
):
|
||||
self._execute_empty_mtp_input(self.forward_meta)
|
||||
self._cached_model_output_data = None
|
||||
self._cached_sampler_output = None
|
||||
self._cached_post_process_event = None
|
||||
@@ -2403,6 +2409,16 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
# 5.1. Async cpy
|
||||
post_process_event = paddle.device.cuda.create_event()
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
# If one query is preempted, there is no sampled token for it, we use token_id PREEMPTED_TOKEN_ID to signal server, abort is finished.
|
||||
paddle.assign(
|
||||
paddle.where(
|
||||
self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1,
|
||||
PREEMPTED_TOKEN_ID,
|
||||
sampler_output.sampled_token_ids,
|
||||
),
|
||||
sampler_output.sampled_token_ids,
|
||||
)
|
||||
# if not self.speculative_decoding:
|
||||
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
|
||||
if self.speculative_decoding:
|
||||
@@ -2676,13 +2692,13 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
"""Dynamic model loader use to clear parameters use for RL"""
|
||||
# Clear CUDAGraph
|
||||
if self.use_cudagraph:
|
||||
self.model.clear_grpah_opt_backend()
|
||||
self.model.clear_graph_opt_backend()
|
||||
# Clear parameters and Send single
|
||||
self.dynamic_weight_manager.clear_parameters(
|
||||
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
||||
)
|
||||
if self.spec_method == SpecMethod.MTP:
|
||||
self.proposer.model.clear_grpah_opt_backend()
|
||||
self.proposer.model.clear_graph_opt_backend()
|
||||
self.proposer.clear_mtp_cache()
|
||||
self.clear_cache()
|
||||
paddle.device.cuda.empty_cache()
|
||||
@@ -2736,7 +2752,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
logger.info("GPU model runner's weight is already sleeping, no need to sleep again!")
|
||||
return
|
||||
if self.use_cudagraph:
|
||||
self.model.clear_grpah_opt_backend()
|
||||
self.model.clear_graph_opt_backend()
|
||||
if self.fd_config.parallel_config.enable_expert_parallel:
|
||||
self.dynamic_weight_manager.clear_deepep_buffer()
|
||||
self.dynamic_weight_manager.clear_model_weight()
|
||||
|
||||
@@ -2511,7 +2511,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
"""Dynamic model loader use to clear parameters use for RL"""
|
||||
# Clear CUDAGraph
|
||||
if self.use_cudagraph:
|
||||
self.model.clear_grpah_opt_backend()
|
||||
self.model.clear_graph_opt_backend()
|
||||
# Clear parameters and Send single
|
||||
self.dynamic_weight_manager.clear_parameters(
|
||||
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
||||
|
||||
@@ -457,6 +457,9 @@ class PaddleDisWorkerProc:
|
||||
# TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
|
||||
self.model_weights_signal = np.zeros([1], dtype=np.int32)
|
||||
while True:
|
||||
# run eplb
|
||||
self._run_eplb(tp_rank)
|
||||
|
||||
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
|
||||
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
|
||||
if self.ranks > 1:
|
||||
@@ -534,7 +537,7 @@ class PaddleDisWorkerProc:
|
||||
|
||||
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
|
||||
logger.debug(f"Rank: {self.local_rank} Detected new requests.")
|
||||
self.engine_forward_signal.value[0] = 1
|
||||
|
||||
tasks, read_finish = self.task_queue.get_tasks()
|
||||
# Only one of all tp_size client will get read_finish == True.
|
||||
if read_finish:
|
||||
@@ -543,22 +546,8 @@ class PaddleDisWorkerProc:
|
||||
self.task_queue.read_finish_flag.set(0)
|
||||
else:
|
||||
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
|
||||
# In EP parallel(corresponing to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival.
|
||||
# Only EP + DP prefill should barrier for data arrival.
|
||||
# In mixed mode and decoder in D, we should not barrier to influence decoding.
|
||||
if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill":
|
||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
||||
|
||||
req_dicts, control_reqs = [], []
|
||||
assert (
|
||||
len(tasks) > 0
|
||||
), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}"
|
||||
# In EP + DP prefill, empty task ([]) is delived in worker to barrier. For empty task, just skip and continue.
|
||||
# tasks[0] contains two part, ([req1, ...] ,real_bsz)
|
||||
# tasks[0][0] is [req1, ...]
|
||||
# if empty batch is delived, eval(tasks[0][0]) should be False ([]),
|
||||
# if batch with requests is delived, eval(tasks[0][0]) should be True, then to be processed as below.
|
||||
if tasks[0][0]:
|
||||
for req_dict, bsz in tasks:
|
||||
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
|
||||
control_reqs.append(req_dict[0])
|
||||
@@ -591,12 +580,6 @@ class PaddleDisWorkerProc:
|
||||
|
||||
# Process prefill inputs
|
||||
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
|
||||
else:
|
||||
if self.scheduler_config.splitwise_role == "prefill":
|
||||
if tp_size > 1:
|
||||
# Synchronize the signal for other workers
|
||||
self._tp_barrier_wait()
|
||||
continue
|
||||
|
||||
# Let the ep group run control method synchronically
|
||||
if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep:
|
||||
@@ -611,7 +594,6 @@ class PaddleDisWorkerProc:
|
||||
and not self.worker.model_runner.not_need_stop()
|
||||
):
|
||||
self._tp_barrier_wait() if tp_size > 1 else None
|
||||
self.engine_forward_signal.value[0] = 0
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
@@ -634,9 +616,6 @@ class PaddleDisWorkerProc:
|
||||
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
|
||||
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
|
||||
# run eplb
|
||||
self._run_eplb(tp_rank)
|
||||
self.engine_forward_signal.value[0] = 0
|
||||
|
||||
if (
|
||||
not self.parallel_config.use_ep
|
||||
|
||||
+1
-1
@@ -47,5 +47,5 @@ aistudio_sdk
|
||||
p2pstore
|
||||
py-cpuinfo
|
||||
flashinfer-python-paddle
|
||||
flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
|
||||
flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl
|
||||
transformers>=4.55.1,<5.0.0
|
||||
|
||||
@@ -214,29 +214,28 @@ def test_metrics_with_clear_and_reset():
|
||||
"""
|
||||
Test the metrics monitoring endpoint.
|
||||
"""
|
||||
pass # not stable, uncomment after bug fix
|
||||
# metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
||||
metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
||||
|
||||
# async_concurrency(n=10)
|
||||
async_concurrency(n=10)
|
||||
|
||||
# time.sleep(0.3)
|
||||
time.sleep(0.3)
|
||||
|
||||
# ===== clear_load_weight =====
|
||||
# clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight"
|
||||
# print("Calling clear_load_weight...")
|
||||
# r = requests.get(clear_url, timeout=30)
|
||||
# assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}"
|
||||
clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight"
|
||||
print("Calling clear_load_weight...")
|
||||
r = requests.get(clear_url, timeout=30)
|
||||
assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}"
|
||||
|
||||
# metrics = get_metrics_dict(metrics_url)
|
||||
# running = metrics["fastdeploy:num_requests_running"]
|
||||
# waiting = metrics["fastdeploy:num_requests_waiting"]
|
||||
metrics = get_metrics_dict(metrics_url)
|
||||
running = metrics["fastdeploy:num_requests_running"]
|
||||
waiting = metrics["fastdeploy:num_requests_waiting"]
|
||||
|
||||
# print(
|
||||
# "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
|
||||
# running,
|
||||
# "waiting:",
|
||||
# waiting,
|
||||
# )
|
||||
print(
|
||||
"ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
|
||||
running,
|
||||
"waiting:",
|
||||
waiting,
|
||||
)
|
||||
# assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight"
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,455 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Test splitwise deployment WITHOUT Router:
|
||||
# use local_scheduler, manually construct disaggregate_info,
|
||||
# send requests to both Prefill and Decode concurrently.
|
||||
# ENABLE_V1_KVCACHE_SCHEDULER=1, use rdma to transfer cache.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from utils.serving_utils import (
|
||||
FD_API_PORT,
|
||||
FD_CACHE_QUEUE_PORT,
|
||||
FD_ENGINE_QUEUE_PORT,
|
||||
FD_METRICS_PORT,
|
||||
check_service_health,
|
||||
clean,
|
||||
)
|
||||
|
||||
# Ports for PD disaggregation (no router port needed)
|
||||
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
|
||||
FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623))
|
||||
|
||||
# Prefill uses base ports, Decode uses base+1
|
||||
PORTS_TO_CLEAN = [
|
||||
FD_API_PORT,
|
||||
FD_ENGINE_QUEUE_PORT,
|
||||
FD_METRICS_PORT,
|
||||
FD_CACHE_QUEUE_PORT,
|
||||
FD_CONNECTOR_PORT,
|
||||
FD_RDMA_PORT,
|
||||
FD_API_PORT + 1,
|
||||
FD_ENGINE_QUEUE_PORT + 1,
|
||||
FD_METRICS_PORT + 1,
|
||||
FD_CACHE_QUEUE_PORT + 1,
|
||||
FD_CONNECTOR_PORT + 1,
|
||||
FD_RDMA_PORT + 1,
|
||||
]
|
||||
|
||||
|
||||
def _build_disaggregate_info() -> dict:
|
||||
"""Build disaggregate_info manually, replicating Router's handle_splitwise_request logic."""
|
||||
host_ip = os.getenv("FD_HOST_IP", "127.0.0.1")
|
||||
return {
|
||||
"prefill_ip": host_ip,
|
||||
"decode_ip": host_ip,
|
||||
"prefill_connector_port": FD_CONNECTOR_PORT,
|
||||
"decode_connector_port": FD_CONNECTOR_PORT + 1,
|
||||
"decode_device_ids": ["1"],
|
||||
"decode_rdma_ports": [FD_RDMA_PORT + 1],
|
||||
"transfer_protocol": "rdma",
|
||||
"decode_tp_size": 1,
|
||||
}
|
||||
|
||||
|
||||
def _send_pd_request(payload: dict, timeout: int = 120):
|
||||
"""
|
||||
Send request to both Prefill and Decode concurrently,
|
||||
replicate Router's fan-out forwarding behavior.
|
||||
Returns the Decode response (same as Router's return_result_url_index=-1).
|
||||
"""
|
||||
disaggregate_info = _build_disaggregate_info()
|
||||
|
||||
# Inject disaggregate_info and request_id (same as Router)
|
||||
payload = payload.copy()
|
||||
payload["disaggregate_info"] = disaggregate_info
|
||||
if "request_id" not in payload:
|
||||
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
|
||||
|
||||
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions"
|
||||
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions"
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# For streaming, use requests with stream=True for decode response
|
||||
if payload.get("stream", False):
|
||||
# Send to both concurrently (same as Router's fan-out), stream from decode
|
||||
import concurrent.futures
|
||||
|
||||
def _post_stream(url):
|
||||
return requests.post(url, headers=headers, json=payload, timeout=timeout, stream=True)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
prefill_future = executor.submit(_post_stream, prefill_url)
|
||||
decode_future = executor.submit(_post_stream, decode_url)
|
||||
# Return decode streaming response immediately
|
||||
decode_resp = decode_future.result()
|
||||
# Consume prefill response in background (don't block)
|
||||
try:
|
||||
prefill_future.result(timeout=timeout)
|
||||
except Exception:
|
||||
pass
|
||||
return decode_resp
|
||||
else:
|
||||
# Non-streaming: send to both, return decode response
|
||||
import concurrent.futures
|
||||
|
||||
def _post(url):
|
||||
return requests.post(url, headers=headers, json=payload, timeout=timeout)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
prefill_future = executor.submit(_post, prefill_url)
|
||||
decode_future = executor.submit(_post, decode_url)
|
||||
# Wait for both, return decode response
|
||||
decode_resp = decode_future.result()
|
||||
# Also check prefill didn't error (but don't block on it)
|
||||
try:
|
||||
prefill_future.result(timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
return decode_resp
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_and_run_server():
|
||||
"""
|
||||
Pytest fixture that runs once per test session:
|
||||
- Cleans ports before tests
|
||||
- Starts Prefill and Decode instances WITHOUT Router
|
||||
- Waits for both to be healthy
|
||||
- Tears down after all tests finish
|
||||
"""
|
||||
print("Pre-test port cleanup...")
|
||||
clean(PORTS_TO_CLEAN)
|
||||
|
||||
print("log dir clean")
|
||||
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
|
||||
shutil.rmtree("log_prefill")
|
||||
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
|
||||
shutil.rmtree("log_decode")
|
||||
|
||||
base_path = os.getenv("MODEL_PATH")
|
||||
if base_path:
|
||||
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
|
||||
else:
|
||||
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
|
||||
print(f"model_path: {model_path}")
|
||||
|
||||
base_log_dir = os.getenv("FD_LOG_DIR", "log")
|
||||
|
||||
# Prefill instance
|
||||
print("start prefill...")
|
||||
env_prefill = os.environ.copy()
|
||||
env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
env_prefill["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_prefill")
|
||||
|
||||
prefill_log_path = "prefill.log"
|
||||
prefill_cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
model_path,
|
||||
"--port",
|
||||
str(FD_API_PORT),
|
||||
"--engine-worker-queue-port",
|
||||
str(FD_ENGINE_QUEUE_PORT),
|
||||
"--metrics-port",
|
||||
str(FD_METRICS_PORT),
|
||||
"--cache-queue-port",
|
||||
str(FD_CACHE_QUEUE_PORT),
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--splitwise-role",
|
||||
"prefill",
|
||||
"--cache-transfer-protocol",
|
||||
"rdma",
|
||||
"--rdma-comm-ports",
|
||||
str(FD_RDMA_PORT),
|
||||
"--pd-comm-port",
|
||||
str(FD_CONNECTOR_PORT),
|
||||
# No --router flag
|
||||
]
|
||||
|
||||
with open(prefill_log_path, "w") as logfile:
|
||||
process_prefill = subprocess.Popen(
|
||||
prefill_cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True,
|
||||
env=env_prefill,
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
# Decode instance
|
||||
print("start decode...")
|
||||
env_decode = os.environ.copy()
|
||||
env_decode["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
env_decode["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_decode")
|
||||
|
||||
decode_log_path = "decode.log"
|
||||
decode_cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"fastdeploy.entrypoints.openai.api_server",
|
||||
"--model",
|
||||
model_path,
|
||||
"--port",
|
||||
str(FD_API_PORT + 1),
|
||||
"--engine-worker-queue-port",
|
||||
str(FD_ENGINE_QUEUE_PORT + 1),
|
||||
"--metrics-port",
|
||||
str(FD_METRICS_PORT + 1),
|
||||
"--cache-queue-port",
|
||||
str(FD_CACHE_QUEUE_PORT + 1),
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--splitwise-role",
|
||||
"decode",
|
||||
"--cache-transfer-protocol",
|
||||
"rdma",
|
||||
"--rdma-comm-ports",
|
||||
str(FD_RDMA_PORT + 1),
|
||||
"--pd-comm-port",
|
||||
str(FD_CONNECTOR_PORT + 1),
|
||||
# No --router flag
|
||||
]
|
||||
|
||||
with open(decode_log_path, "w") as logfile:
|
||||
process_decode = subprocess.Popen(
|
||||
decode_cmd,
|
||||
stdout=logfile,
|
||||
stderr=subprocess.STDOUT,
|
||||
start_new_session=True,
|
||||
env=env_decode,
|
||||
)
|
||||
|
||||
# Wait up to 300 seconds for both instances to be healthy
|
||||
for _ in range(60):
|
||||
prefill_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT}")
|
||||
decode_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT + 1}")
|
||||
if prefill_healthy and decode_healthy:
|
||||
print("Prefill and decode servers are both online")
|
||||
break
|
||||
time.sleep(5)
|
||||
else:
|
||||
print("[TIMEOUT] Servers failed to start in 5 minutes. Cleaning up...")
|
||||
try:
|
||||
os.killpg(process_prefill.pid, signal.SIGTERM)
|
||||
os.killpg(process_decode.pid, signal.SIGTERM)
|
||||
clean(PORTS_TO_CLEAN)
|
||||
except Exception as e:
|
||||
print(f"Failed to kill process group: {e}")
|
||||
raise RuntimeError("Prefill or decode server did not start")
|
||||
|
||||
yield # Run tests
|
||||
|
||||
print("\n===== Post-test server cleanup... =====")
|
||||
try:
|
||||
os.killpg(process_prefill.pid, signal.SIGTERM)
|
||||
os.killpg(process_decode.pid, signal.SIGTERM)
|
||||
clean(PORTS_TO_CLEAN)
|
||||
print(f"Prefill server (pid={process_prefill.pid}) terminated")
|
||||
print(f"Decode server (pid={process_decode.pid}) terminated")
|
||||
except Exception as e:
|
||||
print(f"Failed to terminate server: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def api_url(request):
|
||||
"""
|
||||
Returns the Decode API endpoint URL (where final responses come from).
|
||||
"""
|
||||
return f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def headers():
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
def get_stream_chunks(response):
|
||||
"""Parse streaming response into chunk list."""
|
||||
chunks = []
|
||||
|
||||
if response.status_code == 200:
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line:
|
||||
if line.startswith("data: "):
|
||||
line = line[len("data: ") :]
|
||||
|
||||
if line.strip() == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
chunk = json.loads(line)
|
||||
chunks.append(chunk)
|
||||
except Exception as e:
|
||||
print(f"Parse failed: {e}, line: {line}")
|
||||
else:
|
||||
print(f"Request failed, status: {response.status_code}")
|
||||
print("Response:", response.text)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def test_chat_usage_stream(api_url):
|
||||
"""Test streaming chat with usage"""
|
||||
payload = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
|
||||
],
|
||||
"max_tokens": 50,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
|
||||
response = _send_pd_request(payload)
|
||||
chunks = get_stream_chunks(response)
|
||||
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
|
||||
print("Decode Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
usage = chunks[-1]["usage"]
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_chat_usage_non_stream(api_url):
|
||||
"""Test non-streaming chat with usage"""
|
||||
payload = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
|
||||
],
|
||||
"max_tokens": 50,
|
||||
"stream": False,
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
|
||||
response = _send_pd_request(payload).json()
|
||||
usage = response["usage"]
|
||||
result = response["choices"][0]["message"]["content"]
|
||||
assert result != "", "结果为空"
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_non_chat_usage_stream(api_url):
|
||||
"""Test streaming completion (non-chat) with usage"""
|
||||
payload = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"prompt": "牛顿的三大运动定律是什么?",
|
||||
"max_tokens": 50,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
|
||||
# Send to /v1/completions endpoints
|
||||
disaggregate_info = _build_disaggregate_info()
|
||||
payload = payload.copy()
|
||||
payload["disaggregate_info"] = disaggregate_info
|
||||
if "request_id" not in payload:
|
||||
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
|
||||
|
||||
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions"
|
||||
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120)
|
||||
decode_future = executor.submit(
|
||||
requests.post, decode_url, json=payload, headers=headers, timeout=120, stream=True
|
||||
)
|
||||
response = decode_future.result()
|
||||
|
||||
chunks = get_stream_chunks(response)
|
||||
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
|
||||
print("Decode Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
usage = chunks[-1]["usage"]
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_non_chat_usage_non_stream(api_url):
|
||||
"""Test non-streaming completion (non-chat) with usage"""
|
||||
payload = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"prompt": "牛顿的三大运动定律是什么?",
|
||||
"max_tokens": 50,
|
||||
"stream": False,
|
||||
"metadata": {"min_tokens": 10},
|
||||
}
|
||||
|
||||
# Send to /v1/completions endpoints
|
||||
disaggregate_info = _build_disaggregate_info()
|
||||
payload = payload.copy()
|
||||
payload["disaggregate_info"] = disaggregate_info
|
||||
if "request_id" not in payload:
|
||||
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
|
||||
|
||||
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions"
|
||||
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120)
|
||||
decode_future = executor.submit(requests.post, decode_url, json=payload, headers=headers, timeout=120)
|
||||
response = decode_future.result().json()
|
||||
|
||||
usage = response["usage"]
|
||||
result = response["choices"][0]["text"]
|
||||
print("Decode Response:", result)
|
||||
assert result != "", "结果为空"
|
||||
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
@@ -1457,9 +1457,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
||||
task.metrics.scheduler_recv_req_time = time.time()
|
||||
|
||||
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
||||
eng.engine_worker_queue = Mock(
|
||||
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
|
||||
)
|
||||
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
|
||||
eng._send_error_response = Mock()
|
||||
|
||||
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")]))
|
||||
@@ -1493,9 +1491,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
||||
task.metrics.scheduler_recv_req_time = time.time()
|
||||
|
||||
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
||||
eng.engine_worker_queue = Mock(
|
||||
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
|
||||
)
|
||||
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
|
||||
|
||||
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], []))
|
||||
|
||||
@@ -1526,9 +1522,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
||||
task.metrics.scheduler_recv_req_time = time.time()
|
||||
|
||||
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
||||
eng.engine_worker_queue = Mock(
|
||||
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
|
||||
)
|
||||
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
|
||||
eng._send_error_response = Mock()
|
||||
|
||||
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)]))
|
||||
|
||||
@@ -60,6 +60,50 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
|
||||
return ErnieX1ToolParser(tokenizer=DummyTokenizer())
|
||||
|
||||
def _simulate_streaming(self, parser, deltas):
|
||||
"""Simulate a multi-step streaming flow.
|
||||
|
||||
Args:
|
||||
parser: ErnieX1ToolParser instance
|
||||
deltas: list of delta text strings, each representing one streaming step
|
||||
|
||||
Returns:
|
||||
list of results from each extract_tool_calls_streaming call
|
||||
"""
|
||||
results = []
|
||||
previous_text = ""
|
||||
token_id = 0
|
||||
previous_token_ids = []
|
||||
|
||||
for delta in deltas:
|
||||
current_text = previous_text + delta
|
||||
# When delta contains <tool_call> plus more content, use 2 tokens
|
||||
# so that the parser extracts tool_call_portion (line 163-164)
|
||||
if "<tool_call>" in delta and delta != "<tool_call>":
|
||||
n_tokens = 2
|
||||
else:
|
||||
n_tokens = 1
|
||||
|
||||
delta_token_ids = list(range(token_id + 1, token_id + 1 + n_tokens))
|
||||
token_id += n_tokens
|
||||
current_token_ids = previous_token_ids + delta_token_ids
|
||||
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
self.dummy_request,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
previous_text = current_text
|
||||
previous_token_ids = list(current_token_ids)
|
||||
|
||||
return results
|
||||
|
||||
# ==================== __init__ tests (lines 60-81) ====================
|
||||
|
||||
def test_init_sets_tokens_and_ids(self):
|
||||
@@ -116,6 +160,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
self.assertTrue(result.tools_called)
|
||||
self.assertEqual(result.tool_calls[0].function.arguments, "{}")
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(self):
|
||||
"""Cover: tool call with explicit empty arguments {}"""
|
||||
output = '<tool_call>{"name": "fn", "arguments": {}}</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertTrue(result.tools_called)
|
||||
self.assertEqual(result.tool_calls[0].function.name, "fn")
|
||||
self.assertEqual(result.tool_calls[0].function.arguments, "{}")
|
||||
|
||||
def test_extract_tool_calls_nested_arguments(self):
|
||||
"""Cover regex with nested braces in arguments"""
|
||||
output = '<tool_call>{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}</tool_call>'
|
||||
@@ -182,38 +234,24 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
def test_streaming_end_token_in_delta(self):
|
||||
"""Cover lines 149-156: </tool_call> appears in delta"""
|
||||
parser = self._new_parser()
|
||||
# First, start a tool call
|
||||
parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn"',
|
||||
[],
|
||||
[1, 10],
|
||||
[1, 10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "', # start + name + args key
|
||||
"v", # args value
|
||||
'"}}</tool_call>', # close with end token in delta
|
||||
],
|
||||
)
|
||||
# Now stream arguments
|
||||
parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
||||
', "arguments": {"k": "v',
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
self.dummy_request,
|
||||
)
|
||||
# Close with end token in delta
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}</tool_call>',
|
||||
'"}}</tool_call>',
|
||||
[1, 10, 20],
|
||||
[1, 10, 20, 2],
|
||||
[2],
|
||||
self.dummy_request,
|
||||
)
|
||||
# Should handle end token
|
||||
self.assertTrue(result is None or isinstance(result, DeltaMessage))
|
||||
# Step 1: name sent
|
||||
self.assertIsNotNone(results[0])
|
||||
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
|
||||
# Step 2: first-args branch, regex extracts '{"k": "v' as arguments_delta
|
||||
self.assertIsNotNone(results[1])
|
||||
self.assertEqual(results[1].tool_calls[0].function.arguments, '{"k": "v')
|
||||
# Step 3: end token in delta triggers close handling
|
||||
# delta before </tool_call> is '"}}', close branch: rindex('}')=2, diff='"}'
|
||||
self.assertIsNotNone(results[2])
|
||||
self.assertEqual(results[2].tool_calls[0].function.arguments, '"}')
|
||||
|
||||
# --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) ---
|
||||
|
||||
@@ -255,37 +293,29 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
def test_streaming_continue_tool_call_no_name_yet(self):
|
||||
"""Cover lines 174-176, 220-222: partial JSON without name yet"""
|
||||
parser = self._new_parser()
|
||||
# Start tool call
|
||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
||||
# Continue with partial content, no name parseable yet
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
"<tool_call>",
|
||||
'<tool_call>{"na',
|
||||
'{"na',
|
||||
[1],
|
||||
[1, 10],
|
||||
[10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
"<tool_call>", # start tool call
|
||||
'{"na', # partial content, no name yet
|
||||
],
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
self.assertIsNone(results[0])
|
||||
self.assertIsNone(results[1])
|
||||
|
||||
def test_streaming_continue_tool_call_with_name(self):
|
||||
"""Cover lines 174-176, 223-235: name becomes available"""
|
||||
parser = self._new_parser()
|
||||
# Start tool call
|
||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
||||
# Name appears
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
"<tool_call>",
|
||||
'<tool_call>{"name": "get_weather"',
|
||||
'{"name": "get_weather"',
|
||||
[1],
|
||||
[1, 10],
|
||||
[10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
"<tool_call>", # start tool call
|
||||
'{"name": "get_weather"', # name appears
|
||||
],
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
||||
self.assertIsNone(results[0])
|
||||
self.assertIsNotNone(results[1])
|
||||
self.assertEqual(results[1].tool_calls[0].function.name, "get_weather")
|
||||
self.assertTrue(parser.current_tool_name_sent)
|
||||
|
||||
# --- Lines 236-237: name not sent and function_name is None ---
|
||||
@@ -293,18 +323,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
def test_streaming_no_function_name(self):
|
||||
"""Cover lines 236-237: parsed JSON has no 'name' field"""
|
||||
parser = self._new_parser()
|
||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
||||
# Send JSON without name field
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
"<tool_call>",
|
||||
'<tool_call>{"arguments": {"k": "v"}}',
|
||||
'{"arguments": {"k": "v"}}',
|
||||
[1],
|
||||
[1, 10],
|
||||
[10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
"<tool_call>", # start tool call
|
||||
'{"arguments": {"k": "v"}}', # JSON without name field
|
||||
],
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
self.assertIsNone(results[1])
|
||||
|
||||
# --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) ---
|
||||
|
||||
@@ -333,9 +359,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
parser.streamed_args_for_tool = [""]
|
||||
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name":"fn","arguments":{"k":"v"}}',
|
||||
'<tool_call>{"name":"fn","arguments":{"k":"v"',
|
||||
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
|
||||
'"}}</tool_call>',
|
||||
"}}</tool_call>",
|
||||
[1, 10],
|
||||
[1, 10, 2],
|
||||
[2],
|
||||
@@ -343,9 +369,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertIsNotNone(result.tool_calls)
|
||||
self.assertEqual(result.tool_calls[0].function.arguments, "}")
|
||||
|
||||
def test_streaming_close_with_diff_no_end_marker(self):
|
||||
"""Cover lines 184-185: close with arguments but no '"}' in delta_text"""
|
||||
def test_streaming_text_after_completed_tool_call(self):
|
||||
"""Cover lines 143-147: text content after a completed tool call.
|
||||
|
||||
When start==end counts, prev_end==cur_end, and end_token not in delta,
|
||||
the parser treats delta as regular text content.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
parser.current_tool_id = 0
|
||||
parser.current_tool_name_sent = True
|
||||
@@ -353,7 +384,7 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
||||
# Simulate end token in delta but without '"}' pattern
|
||||
# We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta
|
||||
# so that we enter the elif at 178
|
||||
# so that we enter the text-content branch at line 143-147
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
|
||||
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call> text',
|
||||
@@ -363,8 +394,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
[30],
|
||||
self.dummy_request,
|
||||
)
|
||||
# balanced counts, prev_end==cur_end, end not in delta -> returns content (line 147)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
# balanced counts, prev_end==cur_end, end not in delta -> returns content (line 149)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.content, " text")
|
||||
|
||||
def test_streaming_close_no_arguments(self):
|
||||
"""Cover lines 182-183: close branch where prev arguments is None/empty"""
|
||||
@@ -382,8 +414,126 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
[2],
|
||||
self.dummy_request,
|
||||
)
|
||||
# diff is None (no arguments), so falls through to partial_json_parser
|
||||
self.assertTrue(result is None or isinstance(result, DeltaMessage))
|
||||
# diff is None (no arguments key in prev), falls through to partial_json_parser
|
||||
# parses complete JSON, cur_args=None, prev_args=None -> no-args -> delta=None
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_streaming_close_with_empty_dict_arguments(self):
|
||||
"""Regression: close branch must handle arguments={} (empty dict).
|
||||
|
||||
Before fix, `if diff:` was False for empty dict {}, so the close
|
||||
logic was skipped. After fix, `if diff is not None:` correctly
|
||||
enters the branch.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": ', # start + name + args key
|
||||
"{}", # empty dict value
|
||||
"}", # outer close brace
|
||||
"</tool_call>", # end token
|
||||
],
|
||||
)
|
||||
# Step 1: name sent
|
||||
# Step 2: first-args, cur_args={} is not None, prev_args=None
|
||||
# Without fix: not {} == True -> no-args branch -> returns None
|
||||
# With fix: enters first-args -> streams "{}" -> DeltaMessage
|
||||
self.assertIsNotNone(results[1])
|
||||
self.assertIsNotNone(results[1].tool_calls)
|
||||
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
|
||||
|
||||
def test_streaming_empty_arguments_with_outer_brace_in_same_token(self):
|
||||
"""Regression: when arguments={} and outer } arrive in the same token '{}}',
|
||||
regex (.*) over-captures the outer brace, producing '{}}'.
|
||||
|
||||
Real production data showed arguments='{}}}' for get_default_weather
|
||||
with empty arguments. This test reproduces that exact scenario.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "get_default_weather", "arguments": ', # start + name + args key
|
||||
"{}}", # empty args + outer close brace in same token
|
||||
"</tool_call>", # end token
|
||||
],
|
||||
)
|
||||
# Step 1: name sent
|
||||
self.assertIsNotNone(results[0])
|
||||
self.assertEqual(results[0].tool_calls[0].function.name, "get_default_weather")
|
||||
# Step 2: first-args branch, tool_call_portion is complete JSON
|
||||
# regex (.*) captures '{}}' but fix strips outer '}' -> '{}'
|
||||
self.assertIsNotNone(results[1])
|
||||
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
|
||||
# Step 3: end token, close branch
|
||||
# diff = prev_arguments = {} (not None), delta_text = '' (empty after split)
|
||||
# '}' not in '' -> returns None
|
||||
self.assertIsNone(results[2])
|
||||
|
||||
def test_streaming_close_with_number_ending_arguments(self):
|
||||
"""Regression: close branch must flush remaining args ending with number.
|
||||
|
||||
Before fix, '"}' not in delta was True for numbers, causing return None.
|
||||
After fix, rindex('}') correctly finds the closing brace.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": {"count": ', # start + name + args key
|
||||
"123", # number value
|
||||
"}}</tool_call>", # close braces + end token
|
||||
],
|
||||
)
|
||||
# Step 1: name sent
|
||||
# Step 2: first-args, streams {"count": 123
|
||||
# Step 3: close branch flushes remaining "}"
|
||||
streamed_args = [
|
||||
r.tool_calls[0].function.arguments
|
||||
for r in results
|
||||
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
|
||||
]
|
||||
combined = "".join(streamed_args)
|
||||
self.assertEqual(combined, '{"count": 123}')
|
||||
|
||||
def test_streaming_close_with_boolean_ending_arguments(self):
|
||||
"""Regression: close branch must flush remaining args ending with boolean."""
|
||||
parser = self._new_parser()
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": {"flag": ', # start + args key
|
||||
"true", # boolean value
|
||||
"}}</tool_call>", # close + end token
|
||||
],
|
||||
)
|
||||
streamed_args = [
|
||||
r.tool_calls[0].function.arguments
|
||||
for r in results
|
||||
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
|
||||
]
|
||||
combined = "".join(streamed_args)
|
||||
self.assertEqual(combined, '{"flag": true}')
|
||||
|
||||
def test_streaming_close_with_nested_object_ending(self):
|
||||
"""Regression: close branch must flush remaining args ending with nested '}'."""
|
||||
parser = self._new_parser()
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": {"nested": {"a": ', # start + args key
|
||||
"1", # nested value
|
||||
"}}}</tool_call>", # close all + end token
|
||||
],
|
||||
)
|
||||
streamed_args = [
|
||||
r.tool_calls[0].function.arguments
|
||||
for r in results
|
||||
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
|
||||
]
|
||||
combined = "".join(streamed_args)
|
||||
self.assertEqual(combined, '{"nested": {"a": 1}}')
|
||||
|
||||
# --- Lines 202-206: else branch (cur_start < cur_end, edge case) ---
|
||||
|
||||
@@ -404,23 +554,21 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
def test_streaming_malformed_json(self):
|
||||
"""Cover lines 213-215: MalformedJSON from partial parser"""
|
||||
parser = self._new_parser()
|
||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
||||
# Feed badly formed content
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
"<tool_call>",
|
||||
"<tool_call>{{{",
|
||||
"{{{",
|
||||
[1],
|
||||
[1, 10],
|
||||
[10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
"<tool_call>", # start tool call
|
||||
"{{{", # badly formed content
|
||||
],
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
self.assertIsNone(results[1])
|
||||
|
||||
def test_streaming_json_decode_error(self):
|
||||
"""Cover lines 216-218: JSONDecodeError from partial parser"""
|
||||
parser = self._new_parser()
|
||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
||||
# Step 1: start tool call normally
|
||||
self._simulate_streaming(parser, ["<tool_call>"])
|
||||
# Step 2: mock partial_json_parser to throw ValueError
|
||||
with patch(
|
||||
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
|
||||
side_effect=ValueError("bad json"),
|
||||
@@ -430,8 +578,8 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
"<tool_call>bad",
|
||||
"bad",
|
||||
[1],
|
||||
[1, 10],
|
||||
[10],
|
||||
[1, 2],
|
||||
[2],
|
||||
self.dummy_request,
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
@@ -469,30 +617,17 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
def test_streaming_first_arguments_with_regex_match(self):
|
||||
"""Cover lines 243-244, 257-286: first arguments appear, regex matches"""
|
||||
parser = self._new_parser()
|
||||
# Start tool call and send name
|
||||
parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'<tool_call>{"name": "get_weather"',
|
||||
'<tool_call>{"name": "get_weather"',
|
||||
[],
|
||||
[1, 10],
|
||||
[1, 10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "get_weather", "arguments": {"location": "', # start + name + args key
|
||||
"bei", # args value
|
||||
],
|
||||
)
|
||||
# Now stream arguments (first time)
|
||||
# Key must be complete (closing quote) so partial_json_parser returns truthy arguments.
|
||||
# delta must be a substring of the regex-extracted arguments portion (after "arguments":).
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "get_weather"',
|
||||
'<tool_call>{"name": "get_weather", "arguments": {"location": "bei',
|
||||
'"bei',
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
self.dummy_request,
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertIsNotNone(result.tool_calls)
|
||||
# Step 1: name sent
|
||||
# Step 2: first-args, regex finds "bei" in '{"location": "bei'
|
||||
self.assertIsNotNone(results[1])
|
||||
self.assertEqual(results[1].tool_calls[0].function.arguments, '{"location": "bei')
|
||||
|
||||
def test_streaming_first_arguments_no_regex_match(self):
|
||||
"""Cover lines 266-267: regex doesn't match, fallback to json.dumps"""
|
||||
@@ -522,67 +657,119 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
self.assertIsNotNone(result.tool_calls)
|
||||
|
||||
def test_streaming_first_arguments_delta_not_in_json(self):
|
||||
"""Cover lines 271-272: delta_text not found in cur_arguments_json"""
|
||||
"""Cover lines 275-276: delta_text not found in cur_arguments_json, returns None.
|
||||
When delta contains the arguments key itself (e.g. ', "arguments": {'),
|
||||
regex extracts cur_arguments_json='{' but delta ', "arguments": {' is not in '{'.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn"',
|
||||
[],
|
||||
[1, 10],
|
||||
[1, 10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn"', # start + partial name
|
||||
', "arguments": {', # delta introduces arguments key + open brace
|
||||
],
|
||||
)
|
||||
# Delta text that doesn't appear in the arguments JSON
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
|
||||
"ZZZZZ",
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
self.dummy_request,
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
# Step 1: name sent
|
||||
self.assertIsNotNone(results[0])
|
||||
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
|
||||
# Step 2: first-args branch, regex extracts cur_arguments_json='{'
|
||||
# delta_text=', "arguments": {' is NOT in '{' -> returns None
|
||||
self.assertIsNone(results[1])
|
||||
|
||||
# --- Lines 249-251: no cur_arguments and no prev_arguments ---
|
||||
|
||||
def test_streaming_no_arguments_at_all(self):
|
||||
"""Cover lines 249-251: both cur and prev arguments are empty/None"""
|
||||
parser = self._new_parser()
|
||||
parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn"',
|
||||
[],
|
||||
[1, 10],
|
||||
[1, 10],
|
||||
self.dummy_request,
|
||||
)
|
||||
# Continue with name only, no arguments
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn"}',
|
||||
"}",
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn"', # start + name
|
||||
"}", # close JSON, no arguments
|
||||
],
|
||||
)
|
||||
# prev_arguments=None, cur_arguments=None -> delta=None
|
||||
# then prev_tool_call_arr updated and returns delta (which is None)
|
||||
self.assertIsNone(result)
|
||||
self.assertIsNone(results[1])
|
||||
|
||||
def test_streaming_empty_dict_arguments_not_skipped(self):
|
||||
"""Regression: arguments={} (empty dict) must not be treated as no arguments.
|
||||
|
||||
Empty dict is falsy in Python (`not {} == True`). Before the fix,
|
||||
this caused empty arguments to enter the no-arguments branch,
|
||||
silently dropping them during streaming.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": ', # start + name + args key
|
||||
"{}", # empty dict value
|
||||
"}", # outer close brace
|
||||
],
|
||||
)
|
||||
# Step 1: name sent
|
||||
# Step 2: cur_arguments={} (not None), prev_arguments=None
|
||||
# With fix: enters first-arguments branch -> streams "{}"
|
||||
# Without fix: not {} == True -> no-arguments branch -> delta=None
|
||||
self.assertIsNotNone(results[1])
|
||||
self.assertIsNotNone(results[1].tool_calls)
|
||||
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
|
||||
|
||||
def test_streaming_empty_dict_prev_arguments_not_reset(self):
|
||||
"""Regression: prev_arguments={} must not be treated as no arguments.
|
||||
|
||||
When prev has {} and cur has a non-empty dict, the code should enter
|
||||
the both-have-arguments branch, not the first-arguments branch.
|
||||
|
||||
This scenario (arguments growing from {} to non-empty) is hard to
|
||||
produce naturally, so we build up state through a real flow then
|
||||
verify the branch behavior with one additional call.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
# Build up state naturally: prev_tool_call_arr gets arguments={}
|
||||
self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": ', # name + args key
|
||||
"{}", # empty dict value
|
||||
"}", # outer close
|
||||
],
|
||||
)
|
||||
# Verify state is correct
|
||||
self.assertEqual(parser.prev_tool_call_arr[0].get("arguments"), {})
|
||||
|
||||
# Now test: if more argument data arrives, prev_args={} should be
|
||||
# treated as "not None" -> enters both-have-arguments branch
|
||||
# Without fix: not {} == True -> first-arguments branch (wrong)
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "val',
|
||||
"al",
|
||||
[1, 2, 3],
|
||||
[1, 2, 3, 4],
|
||||
[4],
|
||||
self.dummy_request,
|
||||
)
|
||||
# both-have-arguments branch: delta_text="al" streamed as arguments
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.tool_calls[0].function.arguments, "al")
|
||||
|
||||
# --- Lines 253-255: cur_arguments reset (impossible branch) ---
|
||||
|
||||
def test_streaming_arguments_reset_mid_call(self):
|
||||
"""Cover lines 253-255: prev has arguments but cur doesn't (impossible case)"""
|
||||
"""Cover lines 253-255: prev has arguments but cur doesn't (impossible case).
|
||||
|
||||
This is an edge case that shouldn't happen in normal flow, but tests
|
||||
defensive handling when partial parser returns no arguments after
|
||||
previously having them.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
parser.current_tool_id = 0
|
||||
parser.current_tool_name_sent = True
|
||||
parser.streamed_args_for_tool = [""]
|
||||
# Simulate state where prev already had arguments
|
||||
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
||||
# Feed content where cur has no arguments but prev does
|
||||
# Mock parser to return no arguments (simulating the impossible reset)
|
||||
with patch(
|
||||
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
|
||||
return_value={"name": "fn"},
|
||||
@@ -591,9 +778,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
|
||||
'"}',
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
[1, 2],
|
||||
[1, 2, 3],
|
||||
[3],
|
||||
self.dummy_request,
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
@@ -603,110 +790,48 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
def test_streaming_incremental_arguments_incomplete(self):
|
||||
"""Cover lines 288-314: both prev and cur have arguments, JSON incomplete"""
|
||||
parser = self._new_parser()
|
||||
parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn"',
|
||||
[],
|
||||
[1, 10],
|
||||
[1, 10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v', # start + name + first args
|
||||
"a", # establishes prev_args
|
||||
"l", # incremental: both-have-args
|
||||
],
|
||||
)
|
||||
# First arguments - delta must appear in regex-extracted arguments portion
|
||||
parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
||||
'{"k": "v',
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
self.dummy_request,
|
||||
)
|
||||
# More argument tokens (both prev and cur have arguments now)
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "val',
|
||||
"al",
|
||||
[1, 10, 20],
|
||||
[1, 10, 20, 30],
|
||||
[30],
|
||||
self.dummy_request,
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.tool_calls[0].function.arguments, "al")
|
||||
# Step 1: name sent
|
||||
# Step 2: first-args branch
|
||||
# Step 3: both-have-args branch, streams "l"
|
||||
self.assertIsNotNone(results[2])
|
||||
self.assertEqual(results[2].tool_calls[0].function.arguments, "l")
|
||||
|
||||
def test_streaming_incremental_arguments_complete_json(self):
|
||||
"""Cover lines 289-305: complete JSON with trailing }"""
|
||||
parser = self._new_parser()
|
||||
parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn"',
|
||||
[],
|
||||
[1, 10],
|
||||
[1, 10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v', # start + name + first args
|
||||
"a", # establishes prev_args
|
||||
'"}}', # completes JSON
|
||||
],
|
||||
)
|
||||
# First arguments - delta must appear in regex-extracted arguments portion
|
||||
parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
||||
'{"k": "v',
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
self.dummy_request,
|
||||
)
|
||||
# Complete with closing braces - both prev and cur have arguments
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
|
||||
'"}}',
|
||||
[1, 10, 20],
|
||||
[1, 10, 20, 30],
|
||||
[30],
|
||||
self.dummy_request,
|
||||
)
|
||||
# is_complete_json=True, delta ends with }, should strip trailing }
|
||||
# After strip: '"' which is not empty, so returns DeltaMessage
|
||||
self.assertIsNotNone(result)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
# Step 3: both-have-args, complete JSON, strips trailing } -> streams '"}'
|
||||
self.assertIsNotNone(results[2])
|
||||
self.assertIsInstance(results[2], DeltaMessage)
|
||||
|
||||
def test_streaming_incremental_arguments_complete_empty_delta(self):
|
||||
"""Cover lines 304-305: complete JSON where delta becomes empty after strip"""
|
||||
parser = self._new_parser()
|
||||
parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn"',
|
||||
[],
|
||||
[1, 10],
|
||||
[1, 10],
|
||||
self.dummy_request,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"', # start + name + first args
|
||||
"}", # inner close (establishes prev_args)
|
||||
"}", # outer close: both-have-args, complete, delta stripped to ""
|
||||
],
|
||||
)
|
||||
# First arguments with proper delta
|
||||
parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn"',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
|
||||
'{"k": "v"}',
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
self.dummy_request,
|
||||
)
|
||||
# Send just the outer closing brace
|
||||
# tool_call_portion becomes complete JSON, delta="}" stripped to "" -> return None
|
||||
result = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
|
||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
|
||||
"}",
|
||||
[1, 10, 20],
|
||||
[1, 10, 20, 30],
|
||||
[30],
|
||||
self.dummy_request,
|
||||
)
|
||||
# is_complete_json=True, delta="}" -> stripped to "" -> return None
|
||||
self.assertIsNone(result)
|
||||
# Step 3: is_complete_json=True, delta="}" -> stripped to "" -> return None
|
||||
self.assertIsNone(results[2])
|
||||
|
||||
# --- Lines 316-319: prev_tool_call_arr update branches ---
|
||||
|
||||
@@ -759,95 +884,71 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
||||
def test_streaming_full_flow(self):
|
||||
"""Integration test: simulate a full streaming tool call flow"""
|
||||
parser = self._new_parser()
|
||||
req = self.dummy_request
|
||||
|
||||
# Step 1: text before tool call
|
||||
r = parser.extract_tool_calls_streaming("", "thinking", "thinking", [], [], [], req)
|
||||
self.assertEqual(r.content, "thinking")
|
||||
|
||||
# Step 2: tool_call start token
|
||||
r = parser.extract_tool_calls_streaming("thinking", "thinking<tool_call>", "<tool_call>", [], [1], [1], req)
|
||||
self.assertIsNone(r)
|
||||
|
||||
# Step 3: function name appears
|
||||
r = parser.extract_tool_calls_streaming(
|
||||
"thinking<tool_call>",
|
||||
'thinking<tool_call>{"name": "search"',
|
||||
'{"name": "search"',
|
||||
[1],
|
||||
[1, 10],
|
||||
[10],
|
||||
req,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
"thinking", # Step 1: text before tool call
|
||||
"<tool_call>", # Step 2: tool_call start token
|
||||
'{"name": "search", "arguments": {"query": "', # Step 3: name + args key
|
||||
"test", # Step 4: args value
|
||||
" data", # Step 5: more args
|
||||
],
|
||||
)
|
||||
self.assertIsNotNone(r)
|
||||
self.assertEqual(r.tool_calls[0].function.name, "search")
|
||||
|
||||
# Step 4: arguments start - delta must appear in regex-extracted arguments portion
|
||||
r = parser.extract_tool_calls_streaming(
|
||||
'thinking<tool_call>{"name": "search"',
|
||||
'thinking<tool_call>{"name": "search", "arguments": {"query": "test',
|
||||
'{"query": "test',
|
||||
[1, 10],
|
||||
[1, 10, 20],
|
||||
[20],
|
||||
req,
|
||||
)
|
||||
self.assertIsNotNone(r)
|
||||
|
||||
# Step 1: plain text
|
||||
self.assertEqual(results[0].content, "thinking")
|
||||
# Step 2: start token -> None
|
||||
self.assertIsNone(results[1])
|
||||
# Step 3: name sent
|
||||
self.assertIsNotNone(results[2])
|
||||
self.assertEqual(results[2].tool_calls[0].function.name, "search")
|
||||
# Step 4: first arguments
|
||||
self.assertIsNotNone(results[3])
|
||||
self.assertEqual(results[3].tool_calls[0].function.arguments, '{"query": "test')
|
||||
# Step 5: more arguments
|
||||
r = parser.extract_tool_calls_streaming(
|
||||
'thinking<tool_call>{"name": "search", "arguments": {"query": "test',
|
||||
'thinking<tool_call>{"name": "search", "arguments": {"query": "test data',
|
||||
" data",
|
||||
[1, 10, 20],
|
||||
[1, 10, 20, 30],
|
||||
[30],
|
||||
req,
|
||||
self.assertIsNotNone(results[4])
|
||||
self.assertEqual(results[4].tool_calls[0].function.arguments, " data")
|
||||
|
||||
def test_streaming_empty_arguments_full_flow(self):
|
||||
"""Integration: streaming tool call with arguments={} must not lose arguments.
|
||||
|
||||
Simulates a complete streaming flow where the tool call has empty
|
||||
arguments. Verifies the name is sent and arguments are streamed.
|
||||
"""
|
||||
parser = self._new_parser()
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn", "arguments": ', # Step 1: start + name + args key
|
||||
"{}", # Step 2: empty dict value
|
||||
"}", # Step 3: outer close
|
||||
"</tool_call>", # Step 4: end token
|
||||
],
|
||||
)
|
||||
self.assertIsNotNone(r)
|
||||
self.assertEqual(r.tool_calls[0].function.arguments, " data")
|
||||
# Step 1: name sent
|
||||
self.assertIsNotNone(results[0])
|
||||
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
|
||||
# Step 2: first-args with cur_args={}, streams "{}"
|
||||
self.assertIsNotNone(results[1])
|
||||
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
|
||||
# Step 4: close branch, delta_text="" after stripping </tool_call>
|
||||
# diff={} is not None, but "}" not in "" -> return None
|
||||
self.assertIsNone(results[2])
|
||||
self.assertIsNone(results[3])
|
||||
|
||||
def test_streaming_multiple_tool_calls(self):
|
||||
"""Integration test: two tool calls in one response"""
|
||||
parser = self._new_parser()
|
||||
req = self.dummy_request
|
||||
|
||||
# First tool call
|
||||
parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'<tool_call>{"name": "fn1"',
|
||||
'<tool_call>{"name": "fn1"',
|
||||
[],
|
||||
[1, 10],
|
||||
[1, 10],
|
||||
req,
|
||||
)
|
||||
self.assertEqual(parser.current_tool_id, 0)
|
||||
|
||||
# Close first tool
|
||||
parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn1"',
|
||||
'<tool_call>{"name": "fn1"}</tool_call>',
|
||||
"}</tool_call>",
|
||||
[1, 10],
|
||||
[1, 10, 2],
|
||||
[2],
|
||||
req,
|
||||
)
|
||||
|
||||
# Second tool call
|
||||
r = parser.extract_tool_calls_streaming(
|
||||
'<tool_call>{"name": "fn1"}</tool_call>',
|
||||
'<tool_call>{"name": "fn1"}</tool_call><tool_call>{"name": "fn2"',
|
||||
'<tool_call>{"name": "fn2"',
|
||||
[1, 10, 2],
|
||||
[1, 10, 2, 1, 20],
|
||||
[1, 20],
|
||||
req,
|
||||
results = self._simulate_streaming(
|
||||
parser,
|
||||
[
|
||||
'<tool_call>{"name": "fn1"', # First tool: start + name
|
||||
"}</tool_call>", # Close first tool
|
||||
'<tool_call>{"name": "fn2"', # Second tool: start + name
|
||||
],
|
||||
)
|
||||
self.assertEqual(parser.current_tool_id, 1)
|
||||
self.assertIsNotNone(r)
|
||||
self.assertEqual(r.tool_calls[0].function.name, "fn2")
|
||||
self.assertIsNotNone(results[2])
|
||||
self.assertEqual(results[2].tool_calls[0].function.name, "fn2")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -91,10 +91,10 @@ class TestModel1(paddle.nn.Layer):
|
||||
|
||||
return sublayer2_output
|
||||
|
||||
def clear_grpah_opt_backend(self):
|
||||
def clear_graph_opt_backend(self):
|
||||
""" """
|
||||
self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||
self.sublayer1.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
self.sublayer2.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||
|
||||
|
||||
class TestCUDAGrpahRecapture(unittest.TestCase):
|
||||
@@ -152,7 +152,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
|
||||
|
||||
# Destroy
|
||||
print_gpu_memory_use("before destroy", 0)
|
||||
self.test_model1.clear_grpah_opt_backend()
|
||||
self.test_model1.clear_graph_opt_backend()
|
||||
print_gpu_memory_use("after destroy", 0)
|
||||
|
||||
def recapture_and_replay(self, input_tensor1, forward_meta1):
|
||||
@@ -168,7 +168,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
|
||||
|
||||
# Destroy
|
||||
print_gpu_memory_use("before destroy", 0)
|
||||
self.test_model1.clear_grpah_opt_backend()
|
||||
self.test_model1.clear_graph_opt_backend()
|
||||
print_gpu_memory_use("after destroy", 0)
|
||||
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
self._fdconfig_patches = [
|
||||
patch.object(FDConfig, "read_from_config", return_value=None),
|
||||
patch.object(FDConfig, "postprocess", return_value=None),
|
||||
patch.object(FDConfig, "init_cache_info", return_value=None),
|
||||
patch.object(FDConfig, "init_pd_info", return_value=None),
|
||||
patch.object(FDConfig, "check", return_value=None),
|
||||
]
|
||||
for patcher in self._fdconfig_patches:
|
||||
|
||||
@@ -42,7 +42,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return paddle_inputs
|
||||
|
||||
|
||||
def run_kernel(paddle_inputs, inputs):
|
||||
def run_kernel(paddle_inputs):
|
||||
"""Call the CUDA kernel."""
|
||||
speculate_set_stop_value_multi_seqs(
|
||||
paddle_inputs["accept_tokens"],
|
||||
@@ -137,7 +137,18 @@ def gen_inputs(
|
||||
|
||||
|
||||
def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Python reference — must match CUDA kernel logic exactly."""
|
||||
"""Python reference — must match CUDA kernel logic exactly.
|
||||
|
||||
token_ids_all 布局 (新 step_idx 语义):
|
||||
pre_ids_now[k] = 第 k 个 output token (k >= 0, 0-indexed)
|
||||
最后一个 output token 在 pre_ids_now[step_idx - 1]
|
||||
step_idx = 历史已生成的 token 数量
|
||||
|
||||
核心设计:
|
||||
1. accept_idx 从 -1 开始,-1 表示检查 pre_ids 末尾(上一轮延迟的情况)
|
||||
2. 主循环检查 accept_idx <= accept_num - 2
|
||||
3. 匹配成功时: 保留 stop_seq 所有 token,在其后追加 eos
|
||||
"""
|
||||
accept_tokens = inputs["accept_tokens"].copy()
|
||||
accept_num = inputs["accept_num"].copy()
|
||||
stop_flags = inputs["stop_flags"].copy()
|
||||
@@ -166,27 +177,36 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
|
||||
step_idx_now = int(step_idx[bid])
|
||||
min_token_limit = int(min_tokens[bid])
|
||||
|
||||
can_stop = step_idx_now >= min_token_limit
|
||||
can_stop = step_idx_now + an >= min_token_limit
|
||||
if not can_stop:
|
||||
continue
|
||||
if stop_flags[bid]:
|
||||
continue
|
||||
|
||||
accept_idx = 0
|
||||
# CUDA kernel: accept_idx 从 -1 开始,检查 pre_ids 末尾
|
||||
accept_idx = -1
|
||||
is_end = False
|
||||
while accept_idx <= an - 1 and not is_end:
|
||||
|
||||
# loop_end = accept_num > 0 ? accept_num - 2 : -1
|
||||
loop_end = an - 2 if an > 0 else -1
|
||||
while accept_idx <= loop_end and not is_end:
|
||||
if step_idx_now + accept_idx + 1 < stop_seq_len:
|
||||
accept_idx += 1
|
||||
continue
|
||||
|
||||
# Check one stop_seq match
|
||||
# 从后向前匹配 stop_seq 的每个 token
|
||||
for i in range(stop_seq_len - 1, -1, -1):
|
||||
offset = stop_seq_len - 1 - i
|
||||
accept_tokens_idx = accept_idx - offset
|
||||
cur_token_idx = -1
|
||||
if stop_seq_len - 1 - i < accept_idx:
|
||||
cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]
|
||||
|
||||
if accept_tokens_idx >= 0:
|
||||
cur_token_idx = accept_tokens_now[accept_tokens_idx]
|
||||
else:
|
||||
pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i)
|
||||
if pre_ids_idx <= 0:
|
||||
# 新语义: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
# pre_ids_now[0] 是第 1 个 output token
|
||||
pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
if pre_ids_idx < 0:
|
||||
break
|
||||
cur_token_idx = pre_ids_now[pre_ids_idx]
|
||||
|
||||
@@ -199,9 +219,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
|
||||
accept_idx += 1
|
||||
|
||||
if is_end:
|
||||
accept_num[bid] = accept_idx
|
||||
accept_tokens[bid, accept_idx - 1] = end_ids[0]
|
||||
# stop_flags[bid] = True # kernel no longer sets stop_flags
|
||||
# accept_idx 已递增,指向 stop_seq 最后 token 的下一个位置
|
||||
# 保留 stop_seq 所有 token,在其后追加 eos
|
||||
accept_num[bid] = accept_idx + 1
|
||||
accept_tokens[bid, accept_idx] = end_ids[0]
|
||||
|
||||
return {
|
||||
"accept_tokens": accept_tokens,
|
||||
@@ -239,7 +260,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
|
||||
def _run_and_get(self, inputs):
|
||||
paddle_inputs = to_paddle_inputs(inputs)
|
||||
run_kernel(paddle_inputs, inputs)
|
||||
run_kernel(paddle_inputs)
|
||||
return get_outputs(paddle_inputs)
|
||||
|
||||
def _check_all_outputs(self, inputs, outputs):
|
||||
@@ -264,7 +285,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
self._run_full_test(test_cfg)
|
||||
|
||||
def test_match_in_accept_tokens_only(self):
|
||||
"""Stop seq found entirely within accept_tokens."""
|
||||
"""Stop seq found entirely within accept_tokens. Eos appended after stop_seq last token."""
|
||||
inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10)
|
||||
# Place stop seq [A, B, C] at accept_tokens positions [0,1,2]
|
||||
inputs["accept_num"][:] = 4
|
||||
@@ -276,9 +297,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# stop_seq [10, 20, 30] matches at accept_idx=2 (window ends at accept_tokens[2]=30)
|
||||
# After loop increment, accept_idx=3, accept_num=4, eos appended at accept_tokens[3]
|
||||
self.assertEqual(outputs["accept_num"][0], 4)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq
|
||||
|
||||
def test_match_spanning_pre_ids_and_accept(self):
|
||||
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens."""
|
||||
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens. Eos appended after stop_seq last token."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
@@ -290,12 +315,15 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 6
|
||||
inputs["accept_num"][:] = 3
|
||||
# Kernel matching at accept_idx=2 (3rd token, 0-indexed):
|
||||
# i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1]
|
||||
# i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0]
|
||||
# i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6]
|
||||
# So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]]
|
||||
inputs["token_ids_all"][0, 6] = 99
|
||||
# stop_seq = [99, 11, 22] (len=3)
|
||||
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
# pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||
# step_idx = 6 表示有 6 个历史 output token,在 pre_ids_now[0..5]
|
||||
# At accept_idx=1 (window ends at accept_tokens[1]=22):
|
||||
# i=2: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=22 vs stop_seq[2]=22 ✓
|
||||
# i=1: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=11 vs stop_seq[1]=11 ✓
|
||||
# i=0: offset=2, accept_tokens_idx=-1 -> pre_ids_idx=6+(-1)=5 -> pre_ids[5]=99 vs stop_seq[0]=99 ✓
|
||||
inputs["token_ids_all"][0, 5] = 99 # pre_ids_now[5] = 第 6 个 output token (0-indexed)
|
||||
inputs["accept_tokens"][0, :3] = [11, 22, 33]
|
||||
inputs["stop_seqs"][0, 0, :3] = [99, 11, 22]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
@@ -303,12 +331,14 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# Match at accept_idx=2, loop increments to 3
|
||||
# Match at accept_idx=1, loop increments to 2 -> accept_num=3, eos at accept_tokens[2]
|
||||
self.assertEqual(outputs["accept_num"][0], 3)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 2], -1)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq
|
||||
|
||||
def test_match_in_pre_ids_only(self):
|
||||
"""Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
|
||||
def test_match_in_pre_ids_only_not_detected(self):
|
||||
"""Stop seq ending purely in pre_ids history but NOT at the end position.
|
||||
The kernel only detects stop_seq at the very end of pre_ids via accept_idx=-1 check.
|
||||
Stop seq placed earlier in pre_ids should not be detected."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
@@ -320,15 +350,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 8
|
||||
inputs["accept_num"][:] = 3
|
||||
# pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70
|
||||
# stop_seq = [50, 60, 70], all 3 tokens are in pre_ids
|
||||
# For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check
|
||||
# i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70
|
||||
# i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60
|
||||
# i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50
|
||||
inputs["token_ids_all"][0, 6] = 50
|
||||
inputs["token_ids_all"][0, 7] = 60
|
||||
inputs["token_ids_all"][0, 8] = 70
|
||||
# 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||
# step_idx = 8 表示有 8 个历史 output token,在 pre_ids_now[0..7]
|
||||
# accept_idx=-1 会检查 pre_ids_now[7] 开始的 stop_seq
|
||||
# 把 stop_seq 放在 pre_ids_now[2,3,4] - 不会被检测到
|
||||
inputs["token_ids_all"][0, 2] = 50
|
||||
inputs["token_ids_all"][0, 3] = 60
|
||||
inputs["token_ids_all"][0, 4] = 70
|
||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
@@ -336,7 +364,8 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
self.assertEqual(outputs["accept_num"][0], 1)
|
||||
# No match: stop_seq is in pre_ids but not at the end, accept_num unchanged
|
||||
self.assertEqual(outputs["accept_num"][0], 3)
|
||||
|
||||
def test_already_stopped(self):
|
||||
"""Kernel skips sequences with stop_flags=True."""
|
||||
@@ -351,7 +380,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"])
|
||||
|
||||
def test_min_tokens_blocks_stop(self):
|
||||
"""Kernel skips stop check when step_idx < min_tokens."""
|
||||
"""Kernel skips stop check when step_idx + accept_num < min_tokens."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
@@ -363,20 +392,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 8
|
||||
inputs["accept_num"][:] = 3
|
||||
# Same setup that would match (like test_match_in_pre_ids_only)
|
||||
inputs["token_ids_all"][0, 6] = 50
|
||||
inputs["token_ids_all"][0, 7] = 60
|
||||
inputs["token_ids_all"][0, 8] = 70
|
||||
# Place stop_seq in pre_ids at end position (would be detected by accept_idx=-1)
|
||||
# pre_ids_now[0..7] = 8 个历史 output token
|
||||
# accept_idx=-1 检查 pre_ids_now[5,6,7] 对应 stop_seq[0,1,2]
|
||||
inputs["token_ids_all"][0, 5] = 50
|
||||
inputs["token_ids_all"][0, 6] = 60
|
||||
inputs["token_ids_all"][0, 7] = 70
|
||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop
|
||||
inputs["min_tokens"][:] = 100 # step_idx+accept_num=11 < 100, should NOT stop
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# min_tokens prevents stop, accept_num unchanged
|
||||
self.assertEqual(outputs["accept_num"][0], 3)
|
||||
|
||||
def test_min_tokens_allows_stop(self):
|
||||
"""Kernel allows stop when step_idx >= min_tokens."""
|
||||
"""Kernel allows stop when step_idx + accept_num >= min_tokens."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
@@ -388,15 +421,17 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 8
|
||||
inputs["accept_num"][:] = 3
|
||||
# Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only)
|
||||
inputs["token_ids_all"][0, 6] = 50
|
||||
inputs["token_ids_all"][0, 7] = 60
|
||||
inputs["token_ids_all"][0, 8] = 70
|
||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
# stop_seq [X, 50] spans pre_ids and accept_tokens[0].
|
||||
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
# At accept_idx=0 (window ends at accept_tokens[0]=50):
|
||||
# i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=50 vs stop_seq[1]=50 ✓
|
||||
# i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=8+(-1)=7 -> pre_ids[7]
|
||||
pre_val = int(inputs["token_ids_all"][0, 7]) # pre_ids_now[7]
|
||||
inputs["accept_tokens"][0, :3] = [50, 60, 70]
|
||||
inputs["stop_seqs"][0, 0, :2] = [pre_val, 50]
|
||||
inputs["stop_seqs_len"][0, 0] = 2
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop
|
||||
inputs["min_tokens"][:] = 5 # step_idx+accept_num=11 >= 5, should stop
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
|
||||
@@ -413,20 +448,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 8
|
||||
inputs["accept_num"][:] = 3
|
||||
# accept_tokens: stop_seq[20,30] matches at accept_idx=2:
|
||||
# i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK
|
||||
# i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK
|
||||
# accept_tokens: [20, 30, 40]
|
||||
# Second stop seq [20, 30] matches at accept_idx=1 (window ends at accept_tokens[1]=30):
|
||||
# i=1: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=30 vs stop_seq[1]=30 ✓
|
||||
# i=0: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=20 vs stop_seq[0]=20 ✓
|
||||
inputs["accept_tokens"][0, :3] = [20, 30, 40]
|
||||
# First stop seq doesn't match
|
||||
inputs["stop_seqs"][0, 0, :3] = [99, 98, 97]
|
||||
inputs["stop_seqs_len"][0, 0] = 3
|
||||
# Second stop seq matches
|
||||
# Second stop seq [20, 30] matches
|
||||
inputs["stop_seqs"][0, 1, :2] = [20, 30]
|
||||
inputs["stop_seqs_len"][0, 1] = 2
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# Match at accept_idx=1 -> accept_num=3, eos at accept_tokens[2]
|
||||
self.assertEqual(outputs["accept_num"][0], 3)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq
|
||||
|
||||
def test_nonzero_prompt_lens(self):
|
||||
"""Verify prompt_lens offset is applied correctly."""
|
||||
@@ -444,19 +483,104 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
||||
inputs["accept_num"][:] = 2
|
||||
inputs["accept_tokens"][0, :2] = [55, 66]
|
||||
# pre_ids_now starts at token_ids_all[0, prompt_len:]
|
||||
# stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx]
|
||||
# For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4
|
||||
# -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4]
|
||||
# For accept_idx=1 (second token is accept_tokens[0,0]=55):
|
||||
# i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55
|
||||
# i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5]
|
||||
target_val = int(inputs["token_ids_all"][0, prompt_len + 5])
|
||||
# pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||
# stop_seq = [X, 55] where X = pre_ids_now[5 + (-1)] = pre_ids_now[4]
|
||||
# At accept_idx=0 (window ends at accept_tokens[0]=55):
|
||||
# i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=55 vs stop_seq[1]=55 ✓
|
||||
# i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=5+(-1)=4 -> pre_ids[4]=token_ids_all[0, prompt_len+4]
|
||||
target_val = int(inputs["token_ids_all"][0, prompt_len + 4])
|
||||
inputs["stop_seqs"][0, 0, :2] = [target_val, 55]
|
||||
inputs["stop_seqs_len"][0, 0] = 2
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# Match at accept_idx=0 -> accept_num=2, eos at accept_tokens[1]
|
||||
self.assertEqual(outputs["accept_num"][0], 2)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 1], -1) # eos appended after stop_seq
|
||||
|
||||
def test_single_token_stop_seq_preserved(self):
|
||||
"""Single token stop_seq (like <|im_end|>) with eos appended after it."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
max_model_len=32,
|
||||
stop_seqs_bs=1,
|
||||
stop_seqs_max_len=1,
|
||||
seed=90,
|
||||
)
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 10
|
||||
inputs["accept_num"][:] = 4
|
||||
# accept_tokens: [a, b, <|im_end|>, d] where <|im_end|> has token id 999
|
||||
inputs["accept_tokens"][0, :4] = [100, 200, 999, 300]
|
||||
# stop_seq = [<|im_end|>] (single token)
|
||||
inputs["stop_seqs"][0, 0, 0] = 999
|
||||
inputs["stop_seqs_len"][0, 0] = 1
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# Match at accept_idx=2 (window ends at accept_tokens[2]=999)
|
||||
# After loop increment, accept_idx=3, accept_num=4, eos at accept_tokens[3]
|
||||
self.assertEqual(outputs["accept_num"][0], 4)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq
|
||||
|
||||
def test_stop_seq_at_last_position_not_detected(self):
|
||||
"""Stop seq at the last position of accept_tokens is NOT detected (deferred to next round)."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
max_model_len=32,
|
||||
stop_seqs_bs=1,
|
||||
stop_seqs_max_len=1,
|
||||
seed=100,
|
||||
)
|
||||
inputs["prompt_lens"][:] = 0
|
||||
inputs["step_idx"][:] = 10
|
||||
inputs["accept_num"][:] = 4
|
||||
# stop_seq [999] is at accept_tokens[3] (last valid position)
|
||||
# Since we only check up to accept_num - 2 = 2, this won't be detected
|
||||
inputs["accept_tokens"][0, :4] = [100, 200, 300, 999]
|
||||
inputs["stop_seqs"][0, 0, 0] = 999
|
||||
inputs["stop_seqs_len"][0, 0] = 1
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# No match because accept_idx only goes up to 2, and 999 is at position 3
|
||||
# accept_num unchanged
|
||||
self.assertEqual(outputs["accept_num"][0], 4)
|
||||
|
||||
def test_stop_seq_detected_from_previous_round(self):
|
||||
"""Stop seq at the end of pre_ids (from previous round) is detected via accept_idx=-1."""
|
||||
inputs = gen_inputs(
|
||||
real_bsz=1,
|
||||
accept_tokens_len=5,
|
||||
max_model_len=32,
|
||||
stop_seqs_bs=1,
|
||||
stop_seqs_max_len=1,
|
||||
seed=110,
|
||||
)
|
||||
inputs["prompt_lens"][:] = 0
|
||||
# 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||
# step_idx = 10 表示有 10 个历史 output token,在 pre_ids_now[0..9]
|
||||
# accept_idx=-1 检查 pre_ids_now[9] (最后一个历史 token)
|
||||
inputs["step_idx"][:] = 10
|
||||
inputs["token_ids_all"][0, 9] = 999 # pre_ids_now[9] = 第 10 个 output token (0-indexed)
|
||||
inputs["accept_num"][:] = 3
|
||||
inputs["accept_tokens"][0, :3] = [100, 200, 300]
|
||||
inputs["stop_seqs"][0, 0, 0] = 999
|
||||
inputs["stop_seqs_len"][0, 0] = 1
|
||||
inputs["stop_flags"][:] = False
|
||||
inputs["min_tokens"][:] = 0
|
||||
outputs = self._run_and_get(inputs)
|
||||
self._check_all_outputs(inputs, outputs)
|
||||
# stop_seq [999] was in pre_ids at end, accept_idx=-1 matches
|
||||
# After loop increment, accept_idx=0, accept_num=1, eos at accept_tokens[0]
|
||||
self.assertEqual(outputs["accept_num"][0], 1)
|
||||
self.assertEqual(outputs["accept_tokens"][0, 0], -1) # replaced with eos
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -261,7 +261,9 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Write history to token_ids_all (forward loop, mirrors kernel step 5)
|
||||
if output_len > 0:
|
||||
base_addr = int(prompt_lens[batch_id])
|
||||
base = cur_step_idx - output_len + 1
|
||||
# 新语义: step_idx 入口 = 历史数量,处理后 cur_step_idx = 历史 + output_len
|
||||
# 第一个 output token 写入位置 = cur_step_idx - output_len
|
||||
base = cur_step_idx - output_len
|
||||
for i in range(output_len):
|
||||
write_idx = base_addr + base + i
|
||||
if 0 <= write_idx < max_model_len:
|
||||
|
||||
@@ -411,6 +411,32 @@ class TestDPLocalScheduler(unittest.TestCase):
|
||||
self.assertEqual(scheduler.ids, ["fresh_req"])
|
||||
self.assertEqual(scheduler.ids_read_cursor, 1)
|
||||
|
||||
def test_get_requests_insufficient_resources(self):
|
||||
"""Test getting requests when resources are insufficient."""
|
||||
mock_logger.reset_mock()
|
||||
|
||||
# Test with insufficient blocks - mock the condition variable to avoid threading issues
|
||||
with patch.object(self.scheduler, "requests_not_empty"):
|
||||
requests = self.scheduler.get_requests(
|
||||
available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
|
||||
)
|
||||
|
||||
self.assertEqual(requests, [])
|
||||
# The logger should have been called for insufficient resources
|
||||
self.assertTrue(mock_logger.debug.called)
|
||||
# Check the message contains expected content
|
||||
call_args = mock_logger.debug.call_args[0][0]
|
||||
self.assertIn("insufficient", call_args.lower())
|
||||
|
||||
def test_get_requests_insufficient_batch(self):
|
||||
"""Test getting requests when batch size is insufficient."""
|
||||
with patch.object(self.scheduler, "requests_not_empty"):
|
||||
requests = self.scheduler.get_requests(
|
||||
available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0
|
||||
)
|
||||
|
||||
self.assertEqual(requests, [])
|
||||
|
||||
@patch("time.time")
|
||||
@patch.object(dp_scheduler_module, "envs")
|
||||
def test_get_requests_no_requests_available(self, mock_envs, mock_time):
|
||||
|
||||
@@ -25,9 +25,6 @@ class DummyEngine:
|
||||
"""Dummy Engine class to simulate the actual Engine for testing."""
|
||||
|
||||
class ResourceManager:
|
||||
def __init__(self):
|
||||
self.waiting = []
|
||||
|
||||
def available_batch(self):
|
||||
return 4
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ class TestConfig(unittest.TestCase):
|
||||
model_config=model_config,
|
||||
test_mode=True,
|
||||
)
|
||||
fd_config.init_cache_info()
|
||||
fd_config.init_pd_info()
|
||||
assert fd_config.register_info is not None
|
||||
|
||||
def test_fdconfig_postprocess_ports(self):
|
||||
|
||||
@@ -487,7 +487,7 @@ class TestSleepWakeupBehavior(unittest.TestCase):
|
||||
runner.local_rank = 0
|
||||
runner.device_id = 1
|
||||
runner.num_gpu_blocks = 8
|
||||
runner.model = Mock(clear_grpah_opt_backend=Mock())
|
||||
runner.model = Mock(clear_graph_opt_backend=Mock())
|
||||
runner.clear_cache = Mock()
|
||||
runner.initialize_kv_cache = Mock()
|
||||
runner.capture_model = Mock()
|
||||
@@ -523,7 +523,7 @@ class TestSleepWakeupBehavior(unittest.TestCase):
|
||||
|
||||
runner.sleep("weight,kv_cache")
|
||||
|
||||
runner.model.clear_grpah_opt_backend.assert_called_once()
|
||||
runner.model.clear_graph_opt_backend.assert_called_once()
|
||||
runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once()
|
||||
runner.dynamic_weight_manager.clear_model_weight.assert_called_once()
|
||||
runner.dynamic_weight_manager.clear_communication_group.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user