Merge origin/release/2.6 and resolve worker_process conflict

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-04-16 11:01:28 +00:00
committed by GitHub
47 changed files with 1386 additions and 753 deletions
+10 -10
View File
@@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel(
const int max_batch_size) {
constexpr int VecSize = 4;
const uint32_t tid = threadIdx.x, bid = blockIdx.x;
int startend_row_vec[4];
int startend_row_vec[2];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
@@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel(
const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
startend_row_vec[0] = this_batch_q_end;
startend_row_vec[1] = cu_seqlens_q[max_batch_size];
startend_row_vec[2] = 0;
startend_row_vec[3] = this_batch_q_end;
// startend_row_vec[1] = cu_seqlens_q[max_batch_size];
// startend_row_vec[2] = 0;
startend_row_vec[1] = this_batch_q_end;
for (int this_batch_q_idx = this_batch_q_start;
this_batch_q_idx < this_batch_q_end;
++this_batch_q_idx) {
@@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel(
: this_batch_q_idx - this_batch_q_start + kv_len -
(this_batch_q_len);
if (cache_k_idx <= append_mask_k_end) {
startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx);
startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx);
// 可提前跳出循环
break;
}
}
reinterpret_cast<int4*>(startend_row_indices_ptr +
cu_seqlens_k_idx * 4)[0] =
reinterpret_cast<int4*>(startend_row_vec)[0];
reinterpret_cast<int2*>(startend_row_indices_ptr +
cu_seqlens_k_idx * 2)[0] =
reinterpret_cast<int2*>(startend_row_vec)[0];
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
@@ -82,7 +82,7 @@ std::vector<paddle::Tensor> get_attn_mask_q(
const paddle::optional<paddle::Tensor>& attn_mask_kv,
const int kv_token_num) {
paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
{1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());
{1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place());
const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
constexpr int block_size = 512;
int grid_size = div_up(kv_token_num, block_size);
@@ -123,7 +123,7 @@ std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
const std::vector<int64_t>& cu_seqlens_k_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
const int kv_token_num) {
return {{1, 1, kv_token_num, 4}};
return {{1, 1, kv_token_num, 2}};
}
PD_BUILD_STATIC_OP(get_attn_mask_q)
@@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
int64_t* next_tokens, // [bs, tokens_per_step]
const int* max_think_lens, // [bs]
int* max_reply_lens, // [bs]
int64_t* step_idx, // [bs]
const int64_t* step_idx, // [bs]
const int64_t* eos_token_ids, // [eos_len]
int* limit_status, // [bs]
int* accept_num, // [bs]
@@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
int new_accept_num = original_accept_num;
// 本 step 的 token offset 对应的绝对 step
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
const int64_t current_base_step = step_idx[bid] + 1;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
@@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel(
// inject_token_ids[0]
if (status == 0 &&
(current_step - 1) ==
max_think_len) { // current_step - 1 是因为 speculate_verify 里
// step_idx + 1 了
max_think_len) { // current_step - 1 : 已输出 current_step-1
// 个thinking token
status = (inject_len > 0) ? 1 : done_status;
}
} else if (max_think_len == 0) {
@@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel(
}
}
// 更新 step_idx / accept_num(被截断的 token 需要回退
// step_idx
const int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
limit_status[bid] = status;
max_reply_lens[bid] = max_reply_len;
@@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int*>(max_reply_lens.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
step_idx.data<int64_t>(),
eos_token_ids.data<int64_t>(),
const_cast<int*>(limit_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
@@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
const int64_t step_idx_now = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
const bool can_stop = (step_idx_now >= min_token_limit);
const bool can_stop = (step_idx_now + accept_num >= min_token_limit);
if (!can_stop) return;
if (!stop_flags[bid]) {
int accept_idx = 0;
/*
accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based)
accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾
(pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。
为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1]
(当前轮最后一个 token),该 token 延迟到下一轮匹配。
循环范围:accept_num > 0 时为 [-1, accept_num-2]
accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。
*/
int accept_idx = -1;
bool is_end = false;
// 遍历起始位置
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
// 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾
// 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens
// 中的匹配。两者共享同一套从后向前匹配逻辑。
int loop_end = (accept_num > 0) ? accept_num - 2 : -1;
for (; accept_idx <= loop_end && !is_end; accept_idx++) {
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("num %d < stop_seq_len %d\n",
step_idx_now - accept_num + accept_idx + 1,
step_idx_now + accept_idx + 1,
stop_seq_len);
#endif
continue;
}
// 遍历一个 stop_seqs
// 从后向前匹配 stop_seq 的每个 token
for (int i = stop_seq_len - 1; i >= 0; --i) {
int64_t cur_token_idx = -1;
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
if (stop_seq_len - 1 - i < accept_idx) {
int offset = stop_seq_len - 1 - i;
int accept_tokens_idx = accept_idx - offset;
if (accept_tokens_idx >= 0) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
"accept_token_idx: "
"%d\n",
"accept_token_idx: %d\n",
bid,
tid,
accept_idx,
accept_idx - (stop_seq_len - 1 - i) - 1);
accept_tokens_idx);
#endif
cur_token_idx =
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
cur_token_idx = accept_tokens_now[accept_tokens_idx];
} else {
int pre_ids_idx = step_idx_now + accept_tokens_idx;
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
"accept_idx:%d. "
"pre_id_idx: %ld\n",
"accept_idx:%d. pre_id_idx: %d\n",
bid,
tid,
step_idx_now,
accept_idx,
step_idx_now - accept_num + accept_idx -
(stop_seq_len - 1 - i));
pre_ids_idx);
#endif
int pre_ids_idx =
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
// EC3
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
// 导致异常结束
if (pre_ids_idx <= 0) {
break;
}
if (pre_ids_idx < 0) break;
cur_token_idx = pre_ids_now[pre_ids_idx];
}
#ifdef DEBUG_SPEC_STOP_SEQS
@@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
}
if (is_end) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("bid:%d end with accept_idx %d", bid, accept_idx);
printf("bid:%d end with accept_idx %d\n", bid, accept_idx);
#endif
accept_nums[bid] = accept_idx;
accept_tokens_now[accept_idx - 1] = end_ids[0];
// stop_flags[bid] = true;
// accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置
accept_nums[bid] = accept_idx + 1;
accept_tokens_now[accept_idx] = end_ids[0];
}
}
}
@@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
int64_t *token_ids_all_now =
&token_ids_all[batch_id * max_model_len + prompt_len];
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
int64_t base = cur_step_idx - output_len + 1;
int64_t base = cur_step_idx - output_len;
for (int i = 0; i < output_len; i++) {
token_ids_all_now[base + i] = output_ids[i];
}
+3
View File
@@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to enable the decode caches requests for preallocating resource
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
# Batched token timeout in EP
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# Max pre-fetch requests number in PD
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
+3
View File
@@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# 是否启用 decode 缓存请求以预分配资源
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
# EP 中批处理 token 的超时时间
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# PD 中最大预取请求数量
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
-113
View File
@@ -1,113 +0,0 @@
#!/bin/bash
set -e
# Test splitwise deployment
# There are two methods for splitwise deployment:
# v0: using splitwise_scheduler or dp_scheduler (deprecated)
# v1: using local_scheduler + router
# prepare environment
export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=1
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
exit 1
fi
unset http_proxy && unset https_proxy
source ${SCRIPT_DIR}/utils.sh
P_PORT=52400
D_PORT=52500
REDIS_PORT="${REDIS_PORT:-6379}"
LOG_DATE=$(date +%Y%m%d_%H%M%S)
ports=(
$P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5))
$D_PORT $((D_PORT + 1)) $((D_PORT + 2)) $((D_PORT + 3)) $((D_PORT + 4)) $((D_PORT + 5))
$REDIS_PORT
)
check_ports "${ports[@]}" || {
echo "❌ Some ports are in use. Please release them."
exit 1
}
# start redis
if ! redis-cli -p ${REDIS_PORT} ping &>/dev/null; then
echo "Redis is not running. Starting redis-server..."
redis-server --daemonize yes --port ${REDIS_PORT}
sleep 1
else
echo "Redis is already running."
fi
sleep 1
# start prefill
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log/$LOG_DATE/prefill"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port ${P_PORT} \
--metrics-port $((P_PORT + 1)) \
--engine-worker-queue-port $((P_PORT + 2)) \
--cache-queue-port $((P_PORT + 3)) \
--max-model-len 32768 \
--num-gpu-blocks-override 1000 \
--splitwise-role "prefill" \
--cache-transfer-protocol "rdma" \
--rdma-comm-ports $((P_PORT + 4)) \
--pd-comm-port $((P_PORT + 5)) \
--scheduler-name "splitwise" \
--scheduler-host "127.0.0.1" \
--scheduler-port ${REDIS_PORT} \
--scheduler-ttl 9000 \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health ${P_PORT}
# start decode
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log/$LOG_DATE/decode"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port ${D_PORT} \
--metrics-port $((D_PORT + 1)) \
--engine-worker-queue-port $((D_PORT + 2)) \
--cache-queue-port $((D_PORT + 3)) \
--max-model-len 32768 \
--splitwise-role "decode" \
--cache-transfer-protocol "rdma" \
--rdma-comm-ports $((D_PORT + 4)) \
--pd-comm-port $((D_PORT + 5)) \
--scheduler-name "splitwise" \
--scheduler-host "127.0.0.1" \
--scheduler-port ${REDIS_PORT} \
--scheduler-ttl 9000 \
2>&1 >${FD_LOG_DIR}/nohup &
wait_for_health ${D_PORT}
# send request
sleep 10 # make sure server is registered to router
echo "send request..."
curl -X POST "http://0.0.0.0:${D_PORT}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "hello"}
],
"max_tokens": 20,
"stream": false
}'
+6 -7
View File
@@ -2009,13 +2009,13 @@ class FDConfig:
and self.router_config
and self.router_config.router
):
# For RL scenario: version.yaml will be required for models in future releases.
# For RL scenario, version.yaml is required for models
# Temporarily enforce use router to be enabled.
self.model_config.read_model_version()
self.read_from_config()
self.postprocess()
self.init_cache_info()
self.init_pd_info()
if test_mode:
return
self.check()
@@ -2348,18 +2348,17 @@ class FDConfig:
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============================================================")
def init_cache_info(self):
def init_pd_info(self):
"""
initialize cache info
initialize info for pd deployment
"""
# TODO: group the splitiwse params
# There are two methods for splitwise deployment:
# 1. v0 splitwise_scheduler or dp_scheduler
# 2. v1 local_scheduler + router
# 2. v1 local_scheduler + router (optional)
self.splitwise_version = None
if self.scheduler_config.name in ("splitwise", "dp"):
self.splitwise_version = "v0"
elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
elif self.scheduler_config.name == "local":
self.splitwise_version = "v1"
# the information for registering this server to router or splitwise_scheduler
+8 -3
View File
@@ -592,10 +592,15 @@ class EngineArgs:
raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1")
if self.splitwise_role != "mixed":
if self.scheduler_name == "local" and self.router is None:
if self.scheduler_name == "splitwise":
raise ValueError(
f"When using {self.splitwise_role} role and the {self.scheduler_name} "
f"scheduler, please provide --router argument."
"Setting scheduler_name as splitwise is not supported in pd deployment, "
"please use router as scheduler."
)
if self.scheduler_name == "local" and self.router is None:
console_logger.warning(
f"Running {self.splitwise_role} role with {self.scheduler_name} "
f"scheduler without --router. Router registration and request routing will be disabled."
)
if not (
+9 -28
View File
@@ -367,15 +367,6 @@ class EngineService:
create=True,
)
engine_forward_signal_data = np.zeros([1], dtype=np.int32)
self.engine_forward_signal = IPCSignal(
name="engine_forward_signal",
array=engine_forward_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
# worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
worker_healthy_live_recorded_time_array = np.zeros(
shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
@@ -1037,7 +1028,16 @@ class EngineService:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
try:
if self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
continue
if self.cfg.scheduler_config.splitwise_role != "mixed":
if not is_fetching:
is_fetching = True
get_request_pool.submit(_fetch_request)
else:
if len(self.resource_manager.waiting) == 0 and (not is_fetching):
# Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool.
try:
is_fetching = True
@@ -1048,18 +1048,6 @@ class EngineService:
break
else:
raise
if self.cfg.scheduler_config.splitwise_role != "mixed":
# Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished.
# Once the forward pass finishes, these accumulated requests can be scheduled in larger,
# more efficient batches.
if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0:
time.sleep(0.001)
continue
else:
# In mixed, todo: optimze cache swap, to decouple swap from scheduler
if self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
continue
if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()
@@ -1120,13 +1108,6 @@ class EngineService:
elif not task.has_been_preempted_before:
task.metrics.inference_start_time = time.time()
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
else:
# When there are no actual tasks to schedule, send an empty task batch to EP workers.
# This helps EP workers barrier for syncing tasks not hang.
if self.cfg.parallel_config.enable_expert_parallel:
self.engine_worker_queue.put_tasks(
([], self.resource_manager.real_bsz)
) # Empty (as idle tasks for ep)
# 4. Response error tasks
if error_tasks:
+2 -2
View File
@@ -109,7 +109,7 @@ class ExpertService:
if envs.FD_ENABLE_RETURN_TEXT:
self.engine.create_data_processor()
if self.cfg.scheduler_config.name == "dp":
self.cfg.init_cache_info()
self.cfg.init_pd_info()
self.engine.scheduler.start(local_data_parallel_id)
if ipc_signal_suffix is not None:
@@ -122,7 +122,7 @@ class ExpertService:
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
if self.cfg.scheduler_config.name == "splitwise":
self.cfg.init_cache_info()
self.cfg.init_pd_info()
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
@@ -927,6 +927,7 @@ class ResourceManagerV1(ResourceManager):
if (
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
and self.config.scheduler_config.splitwise_role != "prefill"
):
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
@@ -1374,6 +1375,11 @@ class ResourceManagerV1(ResourceManager):
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.need_prefill_tokens
)
return True
else:
self._free_blocks(request)
@@ -111,7 +111,7 @@ class ErnieX1ToolParser(ToolParser):
)
)
return ExtractedToolCallInformation(
tools_called=True,
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
)
except Exception:
@@ -182,11 +182,13 @@ class ErnieX1ToolParser(ToolParser):
logger.debug("attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
if '"}' not in delta_text:
if diff is not None:
if "}" not in delta_text:
return None
end_loc = delta_text.rindex("}")
diff = delta_text[:end_loc]
if not diff:
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not " "been streamed yet: %s",
diff,
@@ -248,15 +250,15 @@ class ErnieX1ToolParser(ToolParser):
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
cur_arguments = current_tool_call.get("arguments")
if not cur_arguments and not prev_arguments:
if cur_arguments is None and prev_arguments is None:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
elif not cur_arguments and prev_arguments:
elif cur_arguments is None and prev_arguments is not None:
logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.")
delta = None
elif cur_arguments and not prev_arguments:
elif cur_arguments is not None and prev_arguments is None:
function_name = current_tool_call.get("name")
match = re.search(
r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
@@ -265,6 +267,19 @@ class ErnieX1ToolParser(ToolParser):
)
if match:
cur_arguments_json = match.group(1)
# When tool_call_portion is complete JSON, the regex
# (.*) over-captures the outer closing brace of the
# tool call object. Strip it from both
# cur_arguments_json and delta_text, consistent with
# the both-have-arguments branch handling.
try:
json.loads(tool_call_portion)
if cur_arguments_json.endswith("}"):
cur_arguments_json = cur_arguments_json[:-1]
if delta_text.rstrip().endswith("}"):
delta_text = delta_text.rstrip()[:-1]
except Exception:
pass
else:
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
@@ -287,7 +302,7 @@ class ErnieX1ToolParser(ToolParser):
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
elif cur_arguments is not None and prev_arguments is not None:
try:
json.loads(tool_call_portion)
is_complete_json = True
+16 -11
View File
@@ -145,6 +145,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
# Whether to enable the decode caches requests for preallocating resource
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
# Batched token timeout in EP
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# Max pre-fetch requests number in PD
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
# Enable or disable model caching.
@@ -210,17 +212,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""),
# Whether to enable low latency in mixed scenario
"FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
# Whether to use phi FP8 quantization,if 1,use paddle default.
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
# intended for training alignment. Defaults to 0 (disabled).
"FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
# Whether to use phi MOE permute,if 1,use paddle op.
"FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
# Whether to use phi rms_norm,if 1,use paddle op.
"FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))),
# Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function
"FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
# Reserve output blocks for decoding requests when schedule new prefill requests
"FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16")
@@ -262,8 +253,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
),
# train-infer consistency, used in RL
# Whether to align RoPE and moe gate precision with training
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
# Whether to use phi FP8 quantization,if 1,use paddle default.
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
# intended for training alignment. Defaults to 0 (disabled).
"FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
# Whether to use phi MOE permute,if 1,use paddle op.
"FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
# Whether to use phi rms_norm,if 1,use paddle op.
"FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))),
# Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function
"FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
# Whether to enable FP8 quantization with pow2scale.
"FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))),
}
@@ -92,7 +92,7 @@ class GraphOptWrapper:
def __call__(self, **kwargs):
return self.graph_opt_backend(**kwargs)
def clear_grpah_opt_backend(self, fd_config):
def clear_graph_opt_backend(self, fd_config):
""" """
# TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights
assert (
@@ -95,7 +95,7 @@ def init_flash_attn_version():
logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.")
if FLASH_ATTN_VERSION is None:
if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()):
if sm_version == 90 and 90 in paddle.version.cuda_archs():
FLASH_ATTN_VERSION = 3
logger.info("The current platform supports Flash Attention V3.")
else:
@@ -188,7 +188,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
else:
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=not disable_ue8m0_cast,
using_pow2_scale=not disable_ue8m0_cast or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
using_ue8m0_scale=not disable_ue8m0_cast,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
@@ -355,7 +355,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
else:
x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
@@ -581,7 +581,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
else:
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
ffn_out,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0
or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
@@ -773,7 +774,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
else:
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
@@ -1247,7 +1247,7 @@ def python_op_fused_moe_kernel_paddle(
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
else:
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x, using_pow2_scale=False, output_scale_transpose=False
x, using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=False
)
x_scale = x_scale[: x.shape[0]]
@@ -1305,7 +1305,9 @@ def python_op_fused_moe_kernel_paddle(
)
else:
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
intermediate_cache2,
using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
output_scale_transpose=False,
)
x_scale = x_scale[: x_q.shape[0]]
@@ -343,7 +343,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
else:
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
output_scale_transpose=True,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
@@ -1306,9 +1306,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
)
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class DeepSeekV3PretrainedModel(PretrainedModel):
@@ -701,9 +701,9 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
self.ernie.clear_graph_opt_backend(fd_config=self.fd_config)
@ModelRegistry.register_model_class(
@@ -829,9 +829,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
self.ernie.clear_graph_opt_backend(fd_config=self.fd_config)
class Ernie4_5_VLPretrainedModel(PretrainedModel):
+2 -2
View File
@@ -563,9 +563,9 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Glm4MoePretrainedModel(PretrainedModel):
@@ -369,3 +369,7 @@ class Glm4MTPForCausalLM(ModelForCasualLM):
)
return hidden_states
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
+2 -2
View File
@@ -417,9 +417,9 @@ class Qwen2ForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config)
self.qwen2.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen2PretrainedModel(PretrainedModel):
+2 -2
View File
@@ -341,9 +341,9 @@ class Qwen3ForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen3PretrainedModel(PretrainedModel):
@@ -382,9 +382,9 @@ class Qwen3VLForConditionalGeneration(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen3VLPretrainedModel(PretrainedModel):
+2 -2
View File
@@ -453,9 +453,9 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen3MoePretrainedModel(PretrainedModel):
+2
View File
@@ -68,6 +68,7 @@ class RolloutModelConfig:
routing_replay_config: str = None,
load_choices: str = "default_v1",
lm_head_fp32: bool = False,
moe_gate_fp32: bool = True,
):
# Required parameters
self.model = model_name_or_path
@@ -121,6 +122,7 @@ class RolloutModelConfig:
self.routing_replay_config = routing_replay_config
self.load_choices = load_choices
self.lm_head_fp32 = lm_head_fp32
self.moe_gate_fp32 = moe_gate_fp32
def __str__(self):
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
+36 -3
View File
@@ -23,7 +23,7 @@ from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import get_logger
from fastdeploy.utils import envs, get_logger
class DPLocalScheduler(LocalScheduler):
@@ -131,19 +131,52 @@ class DPLocalScheduler(LocalScheduler):
Returns:
List of Request objects ready for processing
"""
# DP scheduler is used in V1, there is no need to manage request fetching in the scheduler, resource_manager_v1 will do that.
if available_blocks <= reserved_output_blocks or batch < 1:
self.scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}"
)
return []
required_total_blocks = 0
current_prefill_tokens = 0
start_batch_time = time.time()
requests: List[Request] = []
with self.requests_not_empty:
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1],
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
if batch_ids:
if len(batch_ids) > 0 and len(requests) == 0:
self.scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
)
if len(requests) > 0:
self.scheduler_logger.info(
@@ -53,9 +53,6 @@ class InternalAdapter:
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
available_block_num = self.engine.resource_manager.available_block_num()
unhandled_request_num = self.engine.scheduler.get_unhandled_request_num()
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
unhandled_request_num = max(unhandled_request_num, len(self.engine.resource_manager.waiting))
server_info = {
"splitwise_role": self.cfg.scheduler_config.splitwise_role,
"block_size": int(self.cfg.cache_config.block_size),
@@ -65,7 +62,7 @@ class InternalAdapter:
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
"max_batch_size": int(available_batch_size),
"max_input_token_num": self.cfg.model_config.max_model_len,
"unhandled_request_num": unhandled_request_num,
"unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
"available_batch": int(self.engine.resource_manager.available_batch()),
}
return server_info
+20 -4
View File
@@ -27,7 +27,7 @@ import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.request import ImagePosition, Request, RequestType
from fastdeploy.model_executor.graph_optimization.utils import (
@@ -2110,6 +2110,12 @@ class GPUModelRunner(ModelRunnerBase):
self._cached_sampler_output = sampler_output
self._cached_post_process_event = post_process_event
else:
if (
self.fd_config.speculative_config.method == SpecMethod.MTP
and hasattr(self.proposer.model, "empty_input_forward")
and self.parallel_config.use_ep
):
self._execute_empty_mtp_input(self.forward_meta)
self._cached_model_output_data = None
self._cached_sampler_output = None
self._cached_post_process_event = None
@@ -2403,6 +2409,16 @@ class GPUModelRunner(ModelRunnerBase):
# 5.1. Async cpy
post_process_event = paddle.device.cuda.create_event()
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
# If one query is preempted, there is no sampled token for it, we use token_id PREEMPTED_TOKEN_ID to signal server, abort is finished.
paddle.assign(
paddle.where(
self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1,
PREEMPTED_TOKEN_ID,
sampler_output.sampled_token_ids,
),
sampler_output.sampled_token_ids,
)
# if not self.speculative_decoding:
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
if self.speculative_decoding:
@@ -2676,13 +2692,13 @@ class GPUModelRunner(ModelRunnerBase):
"""Dynamic model loader use to clear parameters use for RL"""
# Clear CUDAGraph
if self.use_cudagraph:
self.model.clear_grpah_opt_backend()
self.model.clear_graph_opt_backend()
# Clear parameters and Send single
self.dynamic_weight_manager.clear_parameters(
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
)
if self.spec_method == SpecMethod.MTP:
self.proposer.model.clear_grpah_opt_backend()
self.proposer.model.clear_graph_opt_backend()
self.proposer.clear_mtp_cache()
self.clear_cache()
paddle.device.cuda.empty_cache()
@@ -2736,7 +2752,7 @@ class GPUModelRunner(ModelRunnerBase):
logger.info("GPU model runner's weight is already sleeping, no need to sleep again!")
return
if self.use_cudagraph:
self.model.clear_grpah_opt_backend()
self.model.clear_graph_opt_backend()
if self.fd_config.parallel_config.enable_expert_parallel:
self.dynamic_weight_manager.clear_deepep_buffer()
self.dynamic_weight_manager.clear_model_weight()
+1 -1
View File
@@ -2511,7 +2511,7 @@ class MetaxModelRunner(ModelRunnerBase):
"""Dynamic model loader use to clear parameters use for RL"""
# Clear CUDAGraph
if self.use_cudagraph:
self.model.clear_grpah_opt_backend()
self.model.clear_graph_opt_backend()
# Clear parameters and Send single
self.dynamic_weight_manager.clear_parameters(
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
+4 -25
View File
@@ -457,6 +457,9 @@ class PaddleDisWorkerProc:
# TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
self.model_weights_signal = np.zeros([1], dtype=np.int32)
while True:
# run eplb
self._run_eplb(tp_rank)
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.ranks > 1:
@@ -534,7 +537,7 @@ class PaddleDisWorkerProc:
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.debug(f"Rank: {self.local_rank} Detected new requests.")
self.engine_forward_signal.value[0] = 1
tasks, read_finish = self.task_queue.get_tasks()
# Only one of all tp_size client will get read_finish == True.
if read_finish:
@@ -543,22 +546,8 @@ class PaddleDisWorkerProc:
self.task_queue.read_finish_flag.set(0)
else:
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
# In EP parallel(corresponing to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival.
# Only EP + DP prefill should barrier for data arrival.
# In mixed mode and decoder in D, we should not barrier to influence decoding.
if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill":
paddle.distributed.barrier(self.parallel_config.ep_group)
req_dicts, control_reqs = [], []
assert (
len(tasks) > 0
), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}"
# In EP + DP prefill, empty task ([]) is delived in worker to barrier. For empty task, just skip and continue.
# tasks[0] contains two part, ([req1, ...] ,real_bsz)
# tasks[0][0] is [req1, ...]
# if empty batch is delived, eval(tasks[0][0]) should be False ([]),
# if batch with requests is delived, eval(tasks[0][0]) should be True, then to be processed as below.
if tasks[0][0]:
for req_dict, bsz in tasks:
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
control_reqs.append(req_dict[0])
@@ -591,12 +580,6 @@ class PaddleDisWorkerProc:
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
else:
if self.scheduler_config.splitwise_role == "prefill":
if tp_size > 1:
# Synchronize the signal for other workers
self._tp_barrier_wait()
continue
# Let the ep group run control method synchronically
if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep:
@@ -611,7 +594,6 @@ class PaddleDisWorkerProc:
and not self.worker.model_runner.not_need_stop()
):
self._tp_barrier_wait() if tp_size > 1 else None
self.engine_forward_signal.value[0] = 0
time.sleep(0.001)
continue
@@ -634,9 +616,6 @@ class PaddleDisWorkerProc:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
# run eplb
self._run_eplb(tp_rank)
self.engine_forward_signal.value[0] = 0
if (
not self.parallel_config.use_ep
+1 -1
View File
@@ -47,5 +47,5 @@ aistudio_sdk
p2pstore
py-cpuinfo
flashinfer-python-paddle
flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl
transformers>=4.55.1,<5.0.0
+16 -17
View File
@@ -214,29 +214,28 @@ def test_metrics_with_clear_and_reset():
"""
Test the metrics monitoring endpoint.
"""
pass # not stable, uncomment after bug fix
# metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
# async_concurrency(n=10)
async_concurrency(n=10)
# time.sleep(0.3)
time.sleep(0.3)
# ===== clear_load_weight =====
# clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight"
# print("Calling clear_load_weight...")
# r = requests.get(clear_url, timeout=30)
# assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}"
clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight"
print("Calling clear_load_weight...")
r = requests.get(clear_url, timeout=30)
assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}"
# metrics = get_metrics_dict(metrics_url)
# running = metrics["fastdeploy:num_requests_running"]
# waiting = metrics["fastdeploy:num_requests_waiting"]
metrics = get_metrics_dict(metrics_url)
running = metrics["fastdeploy:num_requests_running"]
waiting = metrics["fastdeploy:num_requests_waiting"]
# print(
# "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
# running,
# "waiting:",
# waiting,
# )
print(
"ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
running,
"waiting:",
waiting,
)
# assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight"
@@ -0,0 +1,455 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Test splitwise deployment WITHOUT Router:
# use local_scheduler, manually construct disaggregate_info,
# send requests to both Prefill and Decode concurrently.
# ENABLE_V1_KVCACHE_SCHEDULER=1, use rdma to transfer cache.
import json
import os
import shutil
import signal
import subprocess
import sys
import time
import uuid
import pytest
import requests
from utils.serving_utils import (
FD_API_PORT,
FD_CACHE_QUEUE_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
check_service_health,
clean,
)
# Ports for PD disaggregation (no router port needed)
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623))
# Prefill uses base ports, Decode uses base+1
PORTS_TO_CLEAN = [
FD_API_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
FD_CACHE_QUEUE_PORT,
FD_CONNECTOR_PORT,
FD_RDMA_PORT,
FD_API_PORT + 1,
FD_ENGINE_QUEUE_PORT + 1,
FD_METRICS_PORT + 1,
FD_CACHE_QUEUE_PORT + 1,
FD_CONNECTOR_PORT + 1,
FD_RDMA_PORT + 1,
]
def _build_disaggregate_info() -> dict:
"""Build disaggregate_info manually, replicating Router's handle_splitwise_request logic."""
host_ip = os.getenv("FD_HOST_IP", "127.0.0.1")
return {
"prefill_ip": host_ip,
"decode_ip": host_ip,
"prefill_connector_port": FD_CONNECTOR_PORT,
"decode_connector_port": FD_CONNECTOR_PORT + 1,
"decode_device_ids": ["1"],
"decode_rdma_ports": [FD_RDMA_PORT + 1],
"transfer_protocol": "rdma",
"decode_tp_size": 1,
}
def _send_pd_request(payload: dict, timeout: int = 120):
"""
Send request to both Prefill and Decode concurrently,
replicate Router's fan-out forwarding behavior.
Returns the Decode response (same as Router's return_result_url_index=-1).
"""
disaggregate_info = _build_disaggregate_info()
# Inject disaggregate_info and request_id (same as Router)
payload = payload.copy()
payload["disaggregate_info"] = disaggregate_info
if "request_id" not in payload:
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions"
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
# For streaming, use requests with stream=True for decode response
if payload.get("stream", False):
# Send to both concurrently (same as Router's fan-out), stream from decode
import concurrent.futures
def _post_stream(url):
return requests.post(url, headers=headers, json=payload, timeout=timeout, stream=True)
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
prefill_future = executor.submit(_post_stream, prefill_url)
decode_future = executor.submit(_post_stream, decode_url)
# Return decode streaming response immediately
decode_resp = decode_future.result()
# Consume prefill response in background (don't block)
try:
prefill_future.result(timeout=timeout)
except Exception:
pass
return decode_resp
else:
# Non-streaming: send to both, return decode response
import concurrent.futures
def _post(url):
return requests.post(url, headers=headers, json=payload, timeout=timeout)
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
prefill_future = executor.submit(_post, prefill_url)
decode_future = executor.submit(_post, decode_url)
# Wait for both, return decode response
decode_resp = decode_future.result()
# Also check prefill didn't error (but don't block on it)
try:
prefill_future.result(timeout=5)
except Exception:
pass
return decode_resp
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
"""
Pytest fixture that runs once per test session:
- Cleans ports before tests
- Starts Prefill and Decode instances WITHOUT Router
- Waits for both to be healthy
- Tears down after all tests finish
"""
print("Pre-test port cleanup...")
clean(PORTS_TO_CLEAN)
print("log dir clean")
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
shutil.rmtree("log_prefill")
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
shutil.rmtree("log_decode")
base_path = os.getenv("MODEL_PATH")
if base_path:
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
else:
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
print(f"model_path: {model_path}")
base_log_dir = os.getenv("FD_LOG_DIR", "log")
# Prefill instance
print("start prefill...")
env_prefill = os.environ.copy()
env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
env_prefill["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_prefill")
prefill_log_path = "prefill.log"
prefill_cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT),
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT),
"--metrics-port",
str(FD_METRICS_PORT),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT),
"--max-model-len",
"8192",
"--splitwise-role",
"prefill",
"--cache-transfer-protocol",
"rdma",
"--rdma-comm-ports",
str(FD_RDMA_PORT),
"--pd-comm-port",
str(FD_CONNECTOR_PORT),
# No --router flag
]
with open(prefill_log_path, "w") as logfile:
process_prefill = subprocess.Popen(
prefill_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True,
env=env_prefill,
)
time.sleep(1)
# Decode instance
print("start decode...")
env_decode = os.environ.copy()
env_decode["CUDA_VISIBLE_DEVICES"] = "1"
env_decode["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_decode")
decode_log_path = "decode.log"
decode_cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT + 1),
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT + 1),
"--metrics-port",
str(FD_METRICS_PORT + 1),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT + 1),
"--max-model-len",
"8192",
"--splitwise-role",
"decode",
"--cache-transfer-protocol",
"rdma",
"--rdma-comm-ports",
str(FD_RDMA_PORT + 1),
"--pd-comm-port",
str(FD_CONNECTOR_PORT + 1),
# No --router flag
]
with open(decode_log_path, "w") as logfile:
process_decode = subprocess.Popen(
decode_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True,
env=env_decode,
)
# Wait up to 300 seconds for both instances to be healthy
for _ in range(60):
prefill_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT}")
decode_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT + 1}")
if prefill_healthy and decode_healthy:
print("Prefill and decode servers are both online")
break
time.sleep(5)
else:
print("[TIMEOUT] Servers failed to start in 5 minutes. Cleaning up...")
try:
os.killpg(process_prefill.pid, signal.SIGTERM)
os.killpg(process_decode.pid, signal.SIGTERM)
clean(PORTS_TO_CLEAN)
except Exception as e:
print(f"Failed to kill process group: {e}")
raise RuntimeError("Prefill or decode server did not start")
yield # Run tests
print("\n===== Post-test server cleanup... =====")
try:
os.killpg(process_prefill.pid, signal.SIGTERM)
os.killpg(process_decode.pid, signal.SIGTERM)
clean(PORTS_TO_CLEAN)
print(f"Prefill server (pid={process_prefill.pid}) terminated")
print(f"Decode server (pid={process_decode.pid}) terminated")
except Exception as e:
print(f"Failed to terminate server: {e}")
@pytest.fixture(scope="session")
def api_url(request):
"""
Returns the Decode API endpoint URL (where final responses come from).
"""
return f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions"
@pytest.fixture
def headers():
return {"Content-Type": "application/json"}
def get_stream_chunks(response):
"""Parse streaming response into chunk list."""
chunks = []
if response.status_code == 200:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
line = line[len("data: ") :]
if line.strip() == "[DONE]":
break
try:
chunk = json.loads(line)
chunks.append(chunk)
except Exception as e:
print(f"Parse failed: {e}, line: {line}")
else:
print(f"Request failed, status: {response.status_code}")
print("Response:", response.text)
return chunks
def test_chat_usage_stream(api_url):
"""Test streaming chat with usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 50,
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
response = _send_pd_request(payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
print("Decode Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_chat_usage_non_stream(api_url):
"""Test non-streaming chat with usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 50,
"stream": False,
"metadata": {"min_tokens": 10},
}
response = _send_pd_request(payload).json()
usage = response["usage"]
result = response["choices"][0]["message"]["content"]
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_non_chat_usage_stream(api_url):
"""Test streaming completion (non-chat) with usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"prompt": "牛顿的三大运动定律是什么?",
"max_tokens": 50,
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
# Send to /v1/completions endpoints
disaggregate_info = _build_disaggregate_info()
payload = payload.copy()
payload["disaggregate_info"] = disaggregate_info
if "request_id" not in payload:
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions"
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions"
headers = {"Content-Type": "application/json"}
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120)
decode_future = executor.submit(
requests.post, decode_url, json=payload, headers=headers, timeout=120, stream=True
)
response = decode_future.result()
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
print("Decode Response:", result)
assert result != "", "结果为空"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_non_chat_usage_non_stream(api_url):
"""Test non-streaming completion (non-chat) with usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"prompt": "牛顿的三大运动定律是什么?",
"max_tokens": 50,
"stream": False,
"metadata": {"min_tokens": 10},
}
# Send to /v1/completions endpoints
disaggregate_info = _build_disaggregate_info()
payload = payload.copy()
payload["disaggregate_info"] = disaggregate_info
if "request_id" not in payload:
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions"
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions"
headers = {"Content-Type": "application/json"}
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120)
decode_future = executor.submit(requests.post, decode_url, json=payload, headers=headers, timeout=120)
response = decode_future.result().json()
usage = response["usage"]
result = response["choices"][0]["text"]
print("Decode Response:", result)
assert result != "", "结果为空"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
+3 -9
View File
@@ -1457,9 +1457,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
task.metrics.scheduler_recv_req_time = time.time()
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
eng.engine_worker_queue = Mock(
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
)
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
eng._send_error_response = Mock()
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")]))
@@ -1493,9 +1491,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
task.metrics.scheduler_recv_req_time = time.time()
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
eng.engine_worker_queue = Mock(
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
)
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], []))
@@ -1526,9 +1522,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
task.metrics.scheduler_recv_req_time = time.time()
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
eng.engine_worker_queue = Mock(
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
)
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
eng._send_error_response = Mock()
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)]))
@@ -60,6 +60,50 @@ class TestErnieX1ToolParser(unittest.TestCase):
return ErnieX1ToolParser(tokenizer=DummyTokenizer())
def _simulate_streaming(self, parser, deltas):
"""Simulate a multi-step streaming flow.
Args:
parser: ErnieX1ToolParser instance
deltas: list of delta text strings, each representing one streaming step
Returns:
list of results from each extract_tool_calls_streaming call
"""
results = []
previous_text = ""
token_id = 0
previous_token_ids = []
for delta in deltas:
current_text = previous_text + delta
# When delta contains <tool_call> plus more content, use 2 tokens
# so that the parser extracts tool_call_portion (line 163-164)
if "<tool_call>" in delta and delta != "<tool_call>":
n_tokens = 2
else:
n_tokens = 1
delta_token_ids = list(range(token_id + 1, token_id + 1 + n_tokens))
token_id += n_tokens
current_token_ids = previous_token_ids + delta_token_ids
result = parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta,
previous_token_ids,
current_token_ids,
delta_token_ids,
self.dummy_request,
)
results.append(result)
previous_text = current_text
previous_token_ids = list(current_token_ids)
return results
# ==================== __init__ tests (lines 60-81) ====================
def test_init_sets_tokens_and_ids(self):
@@ -116,6 +160,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
self.assertTrue(result.tools_called)
self.assertEqual(result.tool_calls[0].function.arguments, "{}")
def test_extract_tool_calls_empty_arguments(self):
"""Cover: tool call with explicit empty arguments {}"""
output = '<tool_call>{"name": "fn", "arguments": {}}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertTrue(result.tools_called)
self.assertEqual(result.tool_calls[0].function.name, "fn")
self.assertEqual(result.tool_calls[0].function.arguments, "{}")
def test_extract_tool_calls_nested_arguments(self):
"""Cover regex with nested braces in arguments"""
output = '<tool_call>{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}</tool_call>'
@@ -182,38 +234,24 @@ class TestErnieX1ToolParser(unittest.TestCase):
def test_streaming_end_token_in_delta(self):
"""Cover lines 149-156: </tool_call> appears in delta"""
parser = self._new_parser()
# First, start a tool call
parser.extract_tool_calls_streaming(
"",
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn"',
[],
[1, 10],
[1, 10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": {"k": "', # start + name + args key
"v", # args value
'"}}</tool_call>', # close with end token in delta
],
)
# Now stream arguments
parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn", "arguments": {"k": "v',
', "arguments": {"k": "v',
[1, 10],
[1, 10, 20],
[20],
self.dummy_request,
)
# Close with end token in delta
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn", "arguments": {"k": "v',
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}</tool_call>',
'"}}</tool_call>',
[1, 10, 20],
[1, 10, 20, 2],
[2],
self.dummy_request,
)
# Should handle end token
self.assertTrue(result is None or isinstance(result, DeltaMessage))
# Step 1: name sent
self.assertIsNotNone(results[0])
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
# Step 2: first-args branch, regex extracts '{"k": "v' as arguments_delta
self.assertIsNotNone(results[1])
self.assertEqual(results[1].tool_calls[0].function.arguments, '{"k": "v')
# Step 3: end token in delta triggers close handling
# delta before </tool_call> is '"}}', close branch: rindex('}')=2, diff='"}'
self.assertIsNotNone(results[2])
self.assertEqual(results[2].tool_calls[0].function.arguments, '"}')
# --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) ---
@@ -255,37 +293,29 @@ class TestErnieX1ToolParser(unittest.TestCase):
def test_streaming_continue_tool_call_no_name_yet(self):
"""Cover lines 174-176, 220-222: partial JSON without name yet"""
parser = self._new_parser()
# Start tool call
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
# Continue with partial content, no name parseable yet
result = parser.extract_tool_calls_streaming(
"<tool_call>",
'<tool_call>{"na',
'{"na',
[1],
[1, 10],
[10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
"<tool_call>", # start tool call
'{"na', # partial content, no name yet
],
)
self.assertIsNone(result)
self.assertIsNone(results[0])
self.assertIsNone(results[1])
def test_streaming_continue_tool_call_with_name(self):
"""Cover lines 174-176, 223-235: name becomes available"""
parser = self._new_parser()
# Start tool call
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
# Name appears
result = parser.extract_tool_calls_streaming(
"<tool_call>",
'<tool_call>{"name": "get_weather"',
'{"name": "get_weather"',
[1],
[1, 10],
[10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
"<tool_call>", # start tool call
'{"name": "get_weather"', # name appears
],
)
self.assertIsNotNone(result)
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
self.assertIsNone(results[0])
self.assertIsNotNone(results[1])
self.assertEqual(results[1].tool_calls[0].function.name, "get_weather")
self.assertTrue(parser.current_tool_name_sent)
# --- Lines 236-237: name not sent and function_name is None ---
@@ -293,18 +323,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
def test_streaming_no_function_name(self):
"""Cover lines 236-237: parsed JSON has no 'name' field"""
parser = self._new_parser()
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
# Send JSON without name field
result = parser.extract_tool_calls_streaming(
"<tool_call>",
'<tool_call>{"arguments": {"k": "v"}}',
'{"arguments": {"k": "v"}}',
[1],
[1, 10],
[10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
"<tool_call>", # start tool call
'{"arguments": {"k": "v"}}', # JSON without name field
],
)
self.assertIsNone(result)
self.assertIsNone(results[1])
# --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) ---
@@ -333,9 +359,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
parser.streamed_args_for_tool = [""]
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name":"fn","arguments":{"k":"v"}}',
'<tool_call>{"name":"fn","arguments":{"k":"v"',
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
'"}}</tool_call>',
"}}</tool_call>",
[1, 10],
[1, 10, 2],
[2],
@@ -343,9 +369,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
)
self.assertIsNotNone(result)
self.assertIsNotNone(result.tool_calls)
self.assertEqual(result.tool_calls[0].function.arguments, "}")
def test_streaming_close_with_diff_no_end_marker(self):
"""Cover lines 184-185: close with arguments but no '"}' in delta_text"""
def test_streaming_text_after_completed_tool_call(self):
"""Cover lines 143-147: text content after a completed tool call.
When start==end counts, prev_end==cur_end, and end_token not in delta,
the parser treats delta as regular text content.
"""
parser = self._new_parser()
parser.current_tool_id = 0
parser.current_tool_name_sent = True
@@ -353,7 +384,7 @@ class TestErnieX1ToolParser(unittest.TestCase):
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
# Simulate end token in delta but without '"}' pattern
# We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta
# so that we enter the elif at 178
# so that we enter the text-content branch at line 143-147
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call> text',
@@ -363,8 +394,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
[30],
self.dummy_request,
)
# balanced counts, prev_end==cur_end, end not in delta -> returns content (line 147)
self.assertIsInstance(result, DeltaMessage)
# balanced counts, prev_end==cur_end, end not in delta -> returns content (line 149)
self.assertIsNotNone(result)
self.assertEqual(result.content, " text")
def test_streaming_close_no_arguments(self):
"""Cover lines 182-183: close branch where prev arguments is None/empty"""
@@ -382,8 +414,126 @@ class TestErnieX1ToolParser(unittest.TestCase):
[2],
self.dummy_request,
)
# diff is None (no arguments), so falls through to partial_json_parser
self.assertTrue(result is None or isinstance(result, DeltaMessage))
# diff is None (no arguments key in prev), falls through to partial_json_parser
# parses complete JSON, cur_args=None, prev_args=None -> no-args -> delta=None
self.assertIsNone(result)
def test_streaming_close_with_empty_dict_arguments(self):
"""Regression: close branch must handle arguments={} (empty dict).
Before fix, `if diff:` was False for empty dict {}, so the close
logic was skipped. After fix, `if diff is not None:` correctly
enters the branch.
"""
parser = self._new_parser()
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": ', # start + name + args key
"{}", # empty dict value
"}", # outer close brace
"</tool_call>", # end token
],
)
# Step 1: name sent
# Step 2: first-args, cur_args={} is not None, prev_args=None
# Without fix: not {} == True -> no-args branch -> returns None
# With fix: enters first-args -> streams "{}" -> DeltaMessage
self.assertIsNotNone(results[1])
self.assertIsNotNone(results[1].tool_calls)
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
def test_streaming_empty_arguments_with_outer_brace_in_same_token(self):
"""Regression: when arguments={} and outer } arrive in the same token '{}}',
regex (.*) over-captures the outer brace, producing '{}}'.
Real production data showed arguments='{}}}' for get_default_weather
with empty arguments. This test reproduces that exact scenario.
"""
parser = self._new_parser()
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "get_default_weather", "arguments": ', # start + name + args key
"{}}", # empty args + outer close brace in same token
"</tool_call>", # end token
],
)
# Step 1: name sent
self.assertIsNotNone(results[0])
self.assertEqual(results[0].tool_calls[0].function.name, "get_default_weather")
# Step 2: first-args branch, tool_call_portion is complete JSON
# regex (.*) captures '{}}' but fix strips outer '}' -> '{}'
self.assertIsNotNone(results[1])
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
# Step 3: end token, close branch
# diff = prev_arguments = {} (not None), delta_text = '' (empty after split)
# '}' not in '' -> returns None
self.assertIsNone(results[2])
def test_streaming_close_with_number_ending_arguments(self):
"""Regression: close branch must flush remaining args ending with number.
Before fix, '"}' not in delta was True for numbers, causing return None.
After fix, rindex('}') correctly finds the closing brace.
"""
parser = self._new_parser()
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": {"count": ', # start + name + args key
"123", # number value
"}}</tool_call>", # close braces + end token
],
)
# Step 1: name sent
# Step 2: first-args, streams {"count": 123
# Step 3: close branch flushes remaining "}"
streamed_args = [
r.tool_calls[0].function.arguments
for r in results
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
]
combined = "".join(streamed_args)
self.assertEqual(combined, '{"count": 123}')
def test_streaming_close_with_boolean_ending_arguments(self):
"""Regression: close branch must flush remaining args ending with boolean."""
parser = self._new_parser()
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": {"flag": ', # start + args key
"true", # boolean value
"}}</tool_call>", # close + end token
],
)
streamed_args = [
r.tool_calls[0].function.arguments
for r in results
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
]
combined = "".join(streamed_args)
self.assertEqual(combined, '{"flag": true}')
def test_streaming_close_with_nested_object_ending(self):
"""Regression: close branch must flush remaining args ending with nested '}'."""
parser = self._new_parser()
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": {"nested": {"a": ', # start + args key
"1", # nested value
"}}}</tool_call>", # close all + end token
],
)
streamed_args = [
r.tool_calls[0].function.arguments
for r in results
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
]
combined = "".join(streamed_args)
self.assertEqual(combined, '{"nested": {"a": 1}}')
# --- Lines 202-206: else branch (cur_start < cur_end, edge case) ---
@@ -404,23 +554,21 @@ class TestErnieX1ToolParser(unittest.TestCase):
def test_streaming_malformed_json(self):
"""Cover lines 213-215: MalformedJSON from partial parser"""
parser = self._new_parser()
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
# Feed badly formed content
result = parser.extract_tool_calls_streaming(
"<tool_call>",
"<tool_call>{{{",
"{{{",
[1],
[1, 10],
[10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
"<tool_call>", # start tool call
"{{{", # badly formed content
],
)
self.assertIsNone(result)
self.assertIsNone(results[1])
def test_streaming_json_decode_error(self):
"""Cover lines 216-218: JSONDecodeError from partial parser"""
parser = self._new_parser()
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
# Step 1: start tool call normally
self._simulate_streaming(parser, ["<tool_call>"])
# Step 2: mock partial_json_parser to throw ValueError
with patch(
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
side_effect=ValueError("bad json"),
@@ -430,8 +578,8 @@ class TestErnieX1ToolParser(unittest.TestCase):
"<tool_call>bad",
"bad",
[1],
[1, 10],
[10],
[1, 2],
[2],
self.dummy_request,
)
self.assertIsNone(result)
@@ -469,30 +617,17 @@ class TestErnieX1ToolParser(unittest.TestCase):
def test_streaming_first_arguments_with_regex_match(self):
"""Cover lines 243-244, 257-286: first arguments appear, regex matches"""
parser = self._new_parser()
# Start tool call and send name
parser.extract_tool_calls_streaming(
"",
'<tool_call>{"name": "get_weather"',
'<tool_call>{"name": "get_weather"',
[],
[1, 10],
[1, 10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "get_weather", "arguments": {"location": "', # start + name + args key
"bei", # args value
],
)
# Now stream arguments (first time)
# Key must be complete (closing quote) so partial_json_parser returns truthy arguments.
# delta must be a substring of the regex-extracted arguments portion (after "arguments":).
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "get_weather"',
'<tool_call>{"name": "get_weather", "arguments": {"location": "bei',
'"bei',
[1, 10],
[1, 10, 20],
[20],
self.dummy_request,
)
self.assertIsNotNone(result)
self.assertIsNotNone(result.tool_calls)
# Step 1: name sent
# Step 2: first-args, regex finds "bei" in '{"location": "bei'
self.assertIsNotNone(results[1])
self.assertEqual(results[1].tool_calls[0].function.arguments, '{"location": "bei')
def test_streaming_first_arguments_no_regex_match(self):
"""Cover lines 266-267: regex doesn't match, fallback to json.dumps"""
@@ -522,67 +657,119 @@ class TestErnieX1ToolParser(unittest.TestCase):
self.assertIsNotNone(result.tool_calls)
def test_streaming_first_arguments_delta_not_in_json(self):
"""Cover lines 271-272: delta_text not found in cur_arguments_json"""
"""Cover lines 275-276: delta_text not found in cur_arguments_json, returns None.
When delta contains the arguments key itself (e.g. ', "arguments": {'),
regex extracts cur_arguments_json='{' but delta ', "arguments": {' is not in '{'.
"""
parser = self._new_parser()
parser.extract_tool_calls_streaming(
"",
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn"',
[],
[1, 10],
[1, 10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn"', # start + partial name
', "arguments": {', # delta introduces arguments key + open brace
],
)
# Delta text that doesn't appear in the arguments JSON
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
"ZZZZZ",
[1, 10],
[1, 10, 20],
[20],
self.dummy_request,
)
self.assertIsNone(result)
# Step 1: name sent
self.assertIsNotNone(results[0])
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
# Step 2: first-args branch, regex extracts cur_arguments_json='{'
# delta_text=', "arguments": {' is NOT in '{' -> returns None
self.assertIsNone(results[1])
# --- Lines 249-251: no cur_arguments and no prev_arguments ---
def test_streaming_no_arguments_at_all(self):
"""Cover lines 249-251: both cur and prev arguments are empty/None"""
parser = self._new_parser()
parser.extract_tool_calls_streaming(
"",
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn"',
[],
[1, 10],
[1, 10],
self.dummy_request,
)
# Continue with name only, no arguments
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn"}',
"}",
[1, 10],
[1, 10, 20],
[20],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn"', # start + name
"}", # close JSON, no arguments
],
)
# prev_arguments=None, cur_arguments=None -> delta=None
# then prev_tool_call_arr updated and returns delta (which is None)
self.assertIsNone(result)
self.assertIsNone(results[1])
def test_streaming_empty_dict_arguments_not_skipped(self):
"""Regression: arguments={} (empty dict) must not be treated as no arguments.
Empty dict is falsy in Python (`not {} == True`). Before the fix,
this caused empty arguments to enter the no-arguments branch,
silently dropping them during streaming.
"""
parser = self._new_parser()
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": ', # start + name + args key
"{}", # empty dict value
"}", # outer close brace
],
)
# Step 1: name sent
# Step 2: cur_arguments={} (not None), prev_arguments=None
# With fix: enters first-arguments branch -> streams "{}"
# Without fix: not {} == True -> no-arguments branch -> delta=None
self.assertIsNotNone(results[1])
self.assertIsNotNone(results[1].tool_calls)
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
def test_streaming_empty_dict_prev_arguments_not_reset(self):
"""Regression: prev_arguments={} must not be treated as no arguments.
When prev has {} and cur has a non-empty dict, the code should enter
the both-have-arguments branch, not the first-arguments branch.
This scenario (arguments growing from {} to non-empty) is hard to
produce naturally, so we build up state through a real flow then
verify the branch behavior with one additional call.
"""
parser = self._new_parser()
# Build up state naturally: prev_tool_call_arr gets arguments={}
self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": ', # name + args key
"{}", # empty dict value
"}", # outer close
],
)
# Verify state is correct
self.assertEqual(parser.prev_tool_call_arr[0].get("arguments"), {})
# Now test: if more argument data arrives, prev_args={} should be
# treated as "not None" -> enters both-have-arguments branch
# Without fix: not {} == True -> first-arguments branch (wrong)
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn", "arguments": {"k": "v',
'<tool_call>{"name": "fn", "arguments": {"k": "val',
"al",
[1, 2, 3],
[1, 2, 3, 4],
[4],
self.dummy_request,
)
# both-have-arguments branch: delta_text="al" streamed as arguments
self.assertIsNotNone(result)
self.assertEqual(result.tool_calls[0].function.arguments, "al")
# --- Lines 253-255: cur_arguments reset (impossible branch) ---
def test_streaming_arguments_reset_mid_call(self):
"""Cover lines 253-255: prev has arguments but cur doesn't (impossible case)"""
"""Cover lines 253-255: prev has arguments but cur doesn't (impossible case).
This is an edge case that shouldn't happen in normal flow, but tests
defensive handling when partial parser returns no arguments after
previously having them.
"""
parser = self._new_parser()
parser.current_tool_id = 0
parser.current_tool_name_sent = True
parser.streamed_args_for_tool = [""]
# Simulate state where prev already had arguments
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
# Feed content where cur has no arguments but prev does
# Mock parser to return no arguments (simulating the impossible reset)
with patch(
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
return_value={"name": "fn"},
@@ -591,9 +778,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
'<tool_call>{"name": "fn", "arguments": {"k": "v"',
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
'"}',
[1, 10],
[1, 10, 20],
[20],
[1, 2],
[1, 2, 3],
[3],
self.dummy_request,
)
self.assertIsNone(result)
@@ -603,110 +790,48 @@ class TestErnieX1ToolParser(unittest.TestCase):
def test_streaming_incremental_arguments_incomplete(self):
"""Cover lines 288-314: both prev and cur have arguments, JSON incomplete"""
parser = self._new_parser()
parser.extract_tool_calls_streaming(
"",
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn"',
[],
[1, 10],
[1, 10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": {"k": "v', # start + name + first args
"a", # establishes prev_args
"l", # incremental: both-have-args
],
)
# First arguments - delta must appear in regex-extracted arguments portion
parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn", "arguments": {"k": "v',
'{"k": "v',
[1, 10],
[1, 10, 20],
[20],
self.dummy_request,
)
# More argument tokens (both prev and cur have arguments now)
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn", "arguments": {"k": "v',
'<tool_call>{"name": "fn", "arguments": {"k": "val',
"al",
[1, 10, 20],
[1, 10, 20, 30],
[30],
self.dummy_request,
)
self.assertIsNotNone(result)
self.assertEqual(result.tool_calls[0].function.arguments, "al")
# Step 1: name sent
# Step 2: first-args branch
# Step 3: both-have-args branch, streams "l"
self.assertIsNotNone(results[2])
self.assertEqual(results[2].tool_calls[0].function.arguments, "l")
def test_streaming_incremental_arguments_complete_json(self):
"""Cover lines 289-305: complete JSON with trailing }"""
parser = self._new_parser()
parser.extract_tool_calls_streaming(
"",
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn"',
[],
[1, 10],
[1, 10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": {"k": "v', # start + name + first args
"a", # establishes prev_args
'"}}', # completes JSON
],
)
# First arguments - delta must appear in regex-extracted arguments portion
parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn", "arguments": {"k": "v',
'{"k": "v',
[1, 10],
[1, 10, 20],
[20],
self.dummy_request,
)
# Complete with closing braces - both prev and cur have arguments
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn", "arguments": {"k": "v',
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
'"}}',
[1, 10, 20],
[1, 10, 20, 30],
[30],
self.dummy_request,
)
# is_complete_json=True, delta ends with }, should strip trailing }
# After strip: '"' which is not empty, so returns DeltaMessage
self.assertIsNotNone(result)
self.assertIsInstance(result, DeltaMessage)
# Step 3: both-have-args, complete JSON, strips trailing } -> streams '"}'
self.assertIsNotNone(results[2])
self.assertIsInstance(results[2], DeltaMessage)
def test_streaming_incremental_arguments_complete_empty_delta(self):
"""Cover lines 304-305: complete JSON where delta becomes empty after strip"""
parser = self._new_parser()
parser.extract_tool_calls_streaming(
"",
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn"',
[],
[1, 10],
[1, 10],
self.dummy_request,
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": {"k": "v"', # start + name + first args
"}", # inner close (establishes prev_args)
"}", # outer close: both-have-args, complete, delta stripped to ""
],
)
# First arguments with proper delta
parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn"',
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
'{"k": "v"}',
[1, 10],
[1, 10, 20],
[20],
self.dummy_request,
)
# Send just the outer closing brace
# tool_call_portion becomes complete JSON, delta="}" stripped to "" -> return None
result = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
"}",
[1, 10, 20],
[1, 10, 20, 30],
[30],
self.dummy_request,
)
# is_complete_json=True, delta="}" -> stripped to "" -> return None
self.assertIsNone(result)
# Step 3: is_complete_json=True, delta="}" -> stripped to "" -> return None
self.assertIsNone(results[2])
# --- Lines 316-319: prev_tool_call_arr update branches ---
@@ -759,95 +884,71 @@ class TestErnieX1ToolParser(unittest.TestCase):
def test_streaming_full_flow(self):
"""Integration test: simulate a full streaming tool call flow"""
parser = self._new_parser()
req = self.dummy_request
# Step 1: text before tool call
r = parser.extract_tool_calls_streaming("", "thinking", "thinking", [], [], [], req)
self.assertEqual(r.content, "thinking")
# Step 2: tool_call start token
r = parser.extract_tool_calls_streaming("thinking", "thinking<tool_call>", "<tool_call>", [], [1], [1], req)
self.assertIsNone(r)
# Step 3: function name appears
r = parser.extract_tool_calls_streaming(
"thinking<tool_call>",
'thinking<tool_call>{"name": "search"',
'{"name": "search"',
[1],
[1, 10],
[10],
req,
results = self._simulate_streaming(
parser,
[
"thinking", # Step 1: text before tool call
"<tool_call>", # Step 2: tool_call start token
'{"name": "search", "arguments": {"query": "', # Step 3: name + args key
"test", # Step 4: args value
" data", # Step 5: more args
],
)
self.assertIsNotNone(r)
self.assertEqual(r.tool_calls[0].function.name, "search")
# Step 4: arguments start - delta must appear in regex-extracted arguments portion
r = parser.extract_tool_calls_streaming(
'thinking<tool_call>{"name": "search"',
'thinking<tool_call>{"name": "search", "arguments": {"query": "test',
'{"query": "test',
[1, 10],
[1, 10, 20],
[20],
req,
)
self.assertIsNotNone(r)
# Step 1: plain text
self.assertEqual(results[0].content, "thinking")
# Step 2: start token -> None
self.assertIsNone(results[1])
# Step 3: name sent
self.assertIsNotNone(results[2])
self.assertEqual(results[2].tool_calls[0].function.name, "search")
# Step 4: first arguments
self.assertIsNotNone(results[3])
self.assertEqual(results[3].tool_calls[0].function.arguments, '{"query": "test')
# Step 5: more arguments
r = parser.extract_tool_calls_streaming(
'thinking<tool_call>{"name": "search", "arguments": {"query": "test',
'thinking<tool_call>{"name": "search", "arguments": {"query": "test data',
" data",
[1, 10, 20],
[1, 10, 20, 30],
[30],
req,
self.assertIsNotNone(results[4])
self.assertEqual(results[4].tool_calls[0].function.arguments, " data")
def test_streaming_empty_arguments_full_flow(self):
"""Integration: streaming tool call with arguments={} must not lose arguments.
Simulates a complete streaming flow where the tool call has empty
arguments. Verifies the name is sent and arguments are streamed.
"""
parser = self._new_parser()
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn", "arguments": ', # Step 1: start + name + args key
"{}", # Step 2: empty dict value
"}", # Step 3: outer close
"</tool_call>", # Step 4: end token
],
)
self.assertIsNotNone(r)
self.assertEqual(r.tool_calls[0].function.arguments, " data")
# Step 1: name sent
self.assertIsNotNone(results[0])
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
# Step 2: first-args with cur_args={}, streams "{}"
self.assertIsNotNone(results[1])
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
# Step 4: close branch, delta_text="" after stripping </tool_call>
# diff={} is not None, but "}" not in "" -> return None
self.assertIsNone(results[2])
self.assertIsNone(results[3])
def test_streaming_multiple_tool_calls(self):
"""Integration test: two tool calls in one response"""
parser = self._new_parser()
req = self.dummy_request
# First tool call
parser.extract_tool_calls_streaming(
"",
'<tool_call>{"name": "fn1"',
'<tool_call>{"name": "fn1"',
[],
[1, 10],
[1, 10],
req,
)
self.assertEqual(parser.current_tool_id, 0)
# Close first tool
parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn1"',
'<tool_call>{"name": "fn1"}</tool_call>',
"}</tool_call>",
[1, 10],
[1, 10, 2],
[2],
req,
)
# Second tool call
r = parser.extract_tool_calls_streaming(
'<tool_call>{"name": "fn1"}</tool_call>',
'<tool_call>{"name": "fn1"}</tool_call><tool_call>{"name": "fn2"',
'<tool_call>{"name": "fn2"',
[1, 10, 2],
[1, 10, 2, 1, 20],
[1, 20],
req,
results = self._simulate_streaming(
parser,
[
'<tool_call>{"name": "fn1"', # First tool: start + name
"}</tool_call>", # Close first tool
'<tool_call>{"name": "fn2"', # Second tool: start + name
],
)
self.assertEqual(parser.current_tool_id, 1)
self.assertIsNotNone(r)
self.assertEqual(r.tool_calls[0].function.name, "fn2")
self.assertIsNotNone(results[2])
self.assertEqual(results[2].tool_calls[0].function.name, "fn2")
if __name__ == "__main__":
@@ -91,10 +91,10 @@ class TestModel1(paddle.nn.Layer):
return sublayer2_output
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
""" """
self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config)
self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config)
self.sublayer1.clear_graph_opt_backend(fd_config=self.fd_config)
self.sublayer2.clear_graph_opt_backend(fd_config=self.fd_config)
class TestCUDAGrpahRecapture(unittest.TestCase):
@@ -152,7 +152,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
# Destroy
print_gpu_memory_use("before destroy", 0)
self.test_model1.clear_grpah_opt_backend()
self.test_model1.clear_graph_opt_backend()
print_gpu_memory_use("after destroy", 0)
def recapture_and_replay(self, input_tensor1, forward_meta1):
@@ -168,7 +168,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
# Destroy
print_gpu_memory_use("before destroy", 0)
self.test_model1.clear_grpah_opt_backend()
self.test_model1.clear_graph_opt_backend()
print_gpu_memory_use("after destroy", 0)
+1 -1
View File
@@ -111,7 +111,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
self._fdconfig_patches = [
patch.object(FDConfig, "read_from_config", return_value=None),
patch.object(FDConfig, "postprocess", return_value=None),
patch.object(FDConfig, "init_cache_info", return_value=None),
patch.object(FDConfig, "init_pd_info", return_value=None),
patch.object(FDConfig, "check", return_value=None),
]
for patcher in self._fdconfig_patches:
@@ -42,7 +42,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
return paddle_inputs
def run_kernel(paddle_inputs, inputs):
def run_kernel(paddle_inputs):
"""Call the CUDA kernel."""
speculate_set_stop_value_multi_seqs(
paddle_inputs["accept_tokens"],
@@ -137,7 +137,18 @@ def gen_inputs(
def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Python reference — must match CUDA kernel logic exactly."""
"""Python reference — must match CUDA kernel logic exactly.
token_ids_all 布局 ( step_idx 语义):
pre_ids_now[k] = k output token (k >= 0, 0-indexed)
最后一个 output token pre_ids_now[step_idx - 1]
step_idx = 历史已生成的 token 数量
核心设计:
1. accept_idx -1 开始-1 表示检查 pre_ids 末尾上一轮延迟的情况
2. 主循环检查 accept_idx <= accept_num - 2
3. 匹配成功时: 保留 stop_seq 所有 token在其后追加 eos
"""
accept_tokens = inputs["accept_tokens"].copy()
accept_num = inputs["accept_num"].copy()
stop_flags = inputs["stop_flags"].copy()
@@ -166,27 +177,36 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
step_idx_now = int(step_idx[bid])
min_token_limit = int(min_tokens[bid])
can_stop = step_idx_now >= min_token_limit
can_stop = step_idx_now + an >= min_token_limit
if not can_stop:
continue
if stop_flags[bid]:
continue
accept_idx = 0
# CUDA kernel: accept_idx 从 -1 开始,检查 pre_ids 末尾
accept_idx = -1
is_end = False
while accept_idx <= an - 1 and not is_end:
# loop_end = accept_num > 0 ? accept_num - 2 : -1
loop_end = an - 2 if an > 0 else -1
while accept_idx <= loop_end and not is_end:
if step_idx_now + accept_idx + 1 < stop_seq_len:
accept_idx += 1
continue
# Check one stop_seq match
# 从后向前匹配 stop_seq 的每个 token
for i in range(stop_seq_len - 1, -1, -1):
offset = stop_seq_len - 1 - i
accept_tokens_idx = accept_idx - offset
cur_token_idx = -1
if stop_seq_len - 1 - i < accept_idx:
cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]
if accept_tokens_idx >= 0:
cur_token_idx = accept_tokens_now[accept_tokens_idx]
else:
pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i)
if pre_ids_idx <= 0:
# 新语义: pre_ids_idx = step_idx_now + accept_tokens_idx
# pre_ids_now[0] 是第 1 个 output token
pre_ids_idx = step_idx_now + accept_tokens_idx
if pre_ids_idx < 0:
break
cur_token_idx = pre_ids_now[pre_ids_idx]
@@ -199,9 +219,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
accept_idx += 1
if is_end:
accept_num[bid] = accept_idx
accept_tokens[bid, accept_idx - 1] = end_ids[0]
# stop_flags[bid] = True # kernel no longer sets stop_flags
# accept_idx 已递增,指向 stop_seq 最后 token 的下一个位置
# 保留 stop_seq 所有 token,在其后追加 eos
accept_num[bid] = accept_idx + 1
accept_tokens[bid, accept_idx] = end_ids[0]
return {
"accept_tokens": accept_tokens,
@@ -239,7 +260,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
def _run_and_get(self, inputs):
paddle_inputs = to_paddle_inputs(inputs)
run_kernel(paddle_inputs, inputs)
run_kernel(paddle_inputs)
return get_outputs(paddle_inputs)
def _check_all_outputs(self, inputs, outputs):
@@ -264,7 +285,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
self._run_full_test(test_cfg)
def test_match_in_accept_tokens_only(self):
"""Stop seq found entirely within accept_tokens."""
"""Stop seq found entirely within accept_tokens. Eos appended after stop_seq last token."""
inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10)
# Place stop seq [A, B, C] at accept_tokens positions [0,1,2]
inputs["accept_num"][:] = 4
@@ -276,9 +297,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["min_tokens"][:] = 0
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
# stop_seq [10, 20, 30] matches at accept_idx=2 (window ends at accept_tokens[2]=30)
# After loop increment, accept_idx=3, accept_num=4, eos appended at accept_tokens[3]
self.assertEqual(outputs["accept_num"][0], 4)
self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq
def test_match_spanning_pre_ids_and_accept(self):
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens."""
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens. Eos appended after stop_seq last token."""
inputs = gen_inputs(
real_bsz=1,
accept_tokens_len=5,
@@ -290,12 +315,15 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["prompt_lens"][:] = 0
inputs["step_idx"][:] = 6
inputs["accept_num"][:] = 3
# Kernel matching at accept_idx=2 (3rd token, 0-indexed):
# i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1]
# i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0]
# i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6]
# So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]]
inputs["token_ids_all"][0, 6] = 99
# stop_seq = [99, 11, 22] (len=3)
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
# pre_ids_now[k] = 第 k 个 output token (k >= 0)
# step_idx = 6 表示有 6 个历史 output token,在 pre_ids_now[0..5]
# At accept_idx=1 (window ends at accept_tokens[1]=22):
# i=2: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=22 vs stop_seq[2]=22 ✓
# i=1: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=11 vs stop_seq[1]=11 ✓
# i=0: offset=2, accept_tokens_idx=-1 -> pre_ids_idx=6+(-1)=5 -> pre_ids[5]=99 vs stop_seq[0]=99 ✓
inputs["token_ids_all"][0, 5] = 99 # pre_ids_now[5] = 第 6 个 output token (0-indexed)
inputs["accept_tokens"][0, :3] = [11, 22, 33]
inputs["stop_seqs"][0, 0, :3] = [99, 11, 22]
inputs["stop_seqs_len"][0, 0] = 3
@@ -303,12 +331,14 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["min_tokens"][:] = 0
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
# Match at accept_idx=2, loop increments to 3
# Match at accept_idx=1, loop increments to 2 -> accept_num=3, eos at accept_tokens[2]
self.assertEqual(outputs["accept_num"][0], 3)
self.assertEqual(outputs["accept_tokens"][0, 2], -1)
self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq
def test_match_in_pre_ids_only(self):
"""Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
def test_match_in_pre_ids_only_not_detected(self):
"""Stop seq ending purely in pre_ids history but NOT at the end position.
The kernel only detects stop_seq at the very end of pre_ids via accept_idx=-1 check.
Stop seq placed earlier in pre_ids should not be detected."""
inputs = gen_inputs(
real_bsz=1,
accept_tokens_len=5,
@@ -320,15 +350,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["prompt_lens"][:] = 0
inputs["step_idx"][:] = 8
inputs["accept_num"][:] = 3
# pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70
# stop_seq = [50, 60, 70], all 3 tokens are in pre_ids
# For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check
# i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70
# i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60
# i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50
inputs["token_ids_all"][0, 6] = 50
inputs["token_ids_all"][0, 7] = 60
inputs["token_ids_all"][0, 8] = 70
# 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0)
# step_idx = 8 表示有 8 个历史 output token,在 pre_ids_now[0..7]
# accept_idx=-1 会检查 pre_ids_now[7] 开始的 stop_seq
# 把 stop_seq 放在 pre_ids_now[2,3,4] - 不会被检测到
inputs["token_ids_all"][0, 2] = 50
inputs["token_ids_all"][0, 3] = 60
inputs["token_ids_all"][0, 4] = 70
inputs["accept_tokens"][0, :3] = [1, 2, 3]
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
inputs["stop_seqs_len"][0, 0] = 3
@@ -336,7 +364,8 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["min_tokens"][:] = 0
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
self.assertEqual(outputs["accept_num"][0], 1)
# No match: stop_seq is in pre_ids but not at the end, accept_num unchanged
self.assertEqual(outputs["accept_num"][0], 3)
def test_already_stopped(self):
"""Kernel skips sequences with stop_flags=True."""
@@ -351,7 +380,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"])
def test_min_tokens_blocks_stop(self):
"""Kernel skips stop check when step_idx < min_tokens."""
"""Kernel skips stop check when step_idx + accept_num < min_tokens."""
inputs = gen_inputs(
real_bsz=1,
accept_tokens_len=5,
@@ -363,20 +392,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["prompt_lens"][:] = 0
inputs["step_idx"][:] = 8
inputs["accept_num"][:] = 3
# Same setup that would match (like test_match_in_pre_ids_only)
inputs["token_ids_all"][0, 6] = 50
inputs["token_ids_all"][0, 7] = 60
inputs["token_ids_all"][0, 8] = 70
# Place stop_seq in pre_ids at end position (would be detected by accept_idx=-1)
# pre_ids_now[0..7] = 8 个历史 output token
# accept_idx=-1 检查 pre_ids_now[5,6,7] 对应 stop_seq[0,1,2]
inputs["token_ids_all"][0, 5] = 50
inputs["token_ids_all"][0, 6] = 60
inputs["token_ids_all"][0, 7] = 70
inputs["accept_tokens"][0, :3] = [1, 2, 3]
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
inputs["stop_seqs_len"][0, 0] = 3
inputs["stop_flags"][:] = False
inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop
inputs["min_tokens"][:] = 100 # step_idx+accept_num=11 < 100, should NOT stop
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
# min_tokens prevents stop, accept_num unchanged
self.assertEqual(outputs["accept_num"][0], 3)
def test_min_tokens_allows_stop(self):
"""Kernel allows stop when step_idx >= min_tokens."""
"""Kernel allows stop when step_idx + accept_num >= min_tokens."""
inputs = gen_inputs(
real_bsz=1,
accept_tokens_len=5,
@@ -388,15 +421,17 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["prompt_lens"][:] = 0
inputs["step_idx"][:] = 8
inputs["accept_num"][:] = 3
# Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only)
inputs["token_ids_all"][0, 6] = 50
inputs["token_ids_all"][0, 7] = 60
inputs["token_ids_all"][0, 8] = 70
inputs["accept_tokens"][0, :3] = [1, 2, 3]
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
inputs["stop_seqs_len"][0, 0] = 3
# stop_seq [X, 50] spans pre_ids and accept_tokens[0].
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
# At accept_idx=0 (window ends at accept_tokens[0]=50):
# i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=50 vs stop_seq[1]=50 ✓
# i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=8+(-1)=7 -> pre_ids[7]
pre_val = int(inputs["token_ids_all"][0, 7]) # pre_ids_now[7]
inputs["accept_tokens"][0, :3] = [50, 60, 70]
inputs["stop_seqs"][0, 0, :2] = [pre_val, 50]
inputs["stop_seqs_len"][0, 0] = 2
inputs["stop_flags"][:] = False
inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop
inputs["min_tokens"][:] = 5 # step_idx+accept_num=11 >= 5, should stop
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
@@ -413,20 +448,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["prompt_lens"][:] = 0
inputs["step_idx"][:] = 8
inputs["accept_num"][:] = 3
# accept_tokens: stop_seq[20,30] matches at accept_idx=2:
# i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK
# i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK
# accept_tokens: [20, 30, 40]
# Second stop seq [20, 30] matches at accept_idx=1 (window ends at accept_tokens[1]=30):
# i=1: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=30 vs stop_seq[1]=30
# i=0: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=20 vs stop_seq[0]=20 ✓
inputs["accept_tokens"][0, :3] = [20, 30, 40]
# First stop seq doesn't match
inputs["stop_seqs"][0, 0, :3] = [99, 98, 97]
inputs["stop_seqs_len"][0, 0] = 3
# Second stop seq matches
# Second stop seq [20, 30] matches
inputs["stop_seqs"][0, 1, :2] = [20, 30]
inputs["stop_seqs_len"][0, 1] = 2
inputs["stop_flags"][:] = False
inputs["min_tokens"][:] = 0
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
# Match at accept_idx=1 -> accept_num=3, eos at accept_tokens[2]
self.assertEqual(outputs["accept_num"][0], 3)
self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq
def test_nonzero_prompt_lens(self):
"""Verify prompt_lens offset is applied correctly."""
@@ -444,19 +483,104 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
inputs["accept_num"][:] = 2
inputs["accept_tokens"][0, :2] = [55, 66]
# pre_ids_now starts at token_ids_all[0, prompt_len:]
# stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx]
# For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4
# -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4]
# For accept_idx=1 (second token is accept_tokens[0,0]=55):
# i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55
# i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5]
target_val = int(inputs["token_ids_all"][0, prompt_len + 5])
# pre_ids_now[k] = 第 k 个 output token (k >= 0)
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
# stop_seq = [X, 55] where X = pre_ids_now[5 + (-1)] = pre_ids_now[4]
# At accept_idx=0 (window ends at accept_tokens[0]=55):
# i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=55 vs stop_seq[1]=55 ✓
# i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=5+(-1)=4 -> pre_ids[4]=token_ids_all[0, prompt_len+4]
target_val = int(inputs["token_ids_all"][0, prompt_len + 4])
inputs["stop_seqs"][0, 0, :2] = [target_val, 55]
inputs["stop_seqs_len"][0, 0] = 2
inputs["stop_flags"][:] = False
inputs["min_tokens"][:] = 0
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
# Match at accept_idx=0 -> accept_num=2, eos at accept_tokens[1]
self.assertEqual(outputs["accept_num"][0], 2)
self.assertEqual(outputs["accept_tokens"][0, 1], -1) # eos appended after stop_seq
def test_single_token_stop_seq_preserved(self):
"""Single token stop_seq (like <|im_end|>) with eos appended after it."""
inputs = gen_inputs(
real_bsz=1,
accept_tokens_len=5,
max_model_len=32,
stop_seqs_bs=1,
stop_seqs_max_len=1,
seed=90,
)
inputs["prompt_lens"][:] = 0
inputs["step_idx"][:] = 10
inputs["accept_num"][:] = 4
# accept_tokens: [a, b, <|im_end|>, d] where <|im_end|> has token id 999
inputs["accept_tokens"][0, :4] = [100, 200, 999, 300]
# stop_seq = [<|im_end|>] (single token)
inputs["stop_seqs"][0, 0, 0] = 999
inputs["stop_seqs_len"][0, 0] = 1
inputs["stop_flags"][:] = False
inputs["min_tokens"][:] = 0
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
# Match at accept_idx=2 (window ends at accept_tokens[2]=999)
# After loop increment, accept_idx=3, accept_num=4, eos at accept_tokens[3]
self.assertEqual(outputs["accept_num"][0], 4)
self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq
def test_stop_seq_at_last_position_not_detected(self):
"""Stop seq at the last position of accept_tokens is NOT detected (deferred to next round)."""
inputs = gen_inputs(
real_bsz=1,
accept_tokens_len=5,
max_model_len=32,
stop_seqs_bs=1,
stop_seqs_max_len=1,
seed=100,
)
inputs["prompt_lens"][:] = 0
inputs["step_idx"][:] = 10
inputs["accept_num"][:] = 4
# stop_seq [999] is at accept_tokens[3] (last valid position)
# Since we only check up to accept_num - 2 = 2, this won't be detected
inputs["accept_tokens"][0, :4] = [100, 200, 300, 999]
inputs["stop_seqs"][0, 0, 0] = 999
inputs["stop_seqs_len"][0, 0] = 1
inputs["stop_flags"][:] = False
inputs["min_tokens"][:] = 0
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
# No match because accept_idx only goes up to 2, and 999 is at position 3
# accept_num unchanged
self.assertEqual(outputs["accept_num"][0], 4)
def test_stop_seq_detected_from_previous_round(self):
"""Stop seq at the end of pre_ids (from previous round) is detected via accept_idx=-1."""
inputs = gen_inputs(
real_bsz=1,
accept_tokens_len=5,
max_model_len=32,
stop_seqs_bs=1,
stop_seqs_max_len=1,
seed=110,
)
inputs["prompt_lens"][:] = 0
# 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0)
# step_idx = 10 表示有 10 个历史 output token,在 pre_ids_now[0..9]
# accept_idx=-1 检查 pre_ids_now[9] (最后一个历史 token)
inputs["step_idx"][:] = 10
inputs["token_ids_all"][0, 9] = 999 # pre_ids_now[9] = 第 10 个 output token (0-indexed)
inputs["accept_num"][:] = 3
inputs["accept_tokens"][0, :3] = [100, 200, 300]
inputs["stop_seqs"][0, 0, 0] = 999
inputs["stop_seqs_len"][0, 0] = 1
inputs["stop_flags"][:] = False
inputs["min_tokens"][:] = 0
outputs = self._run_and_get(inputs)
self._check_all_outputs(inputs, outputs)
# stop_seq [999] was in pre_ids at end, accept_idx=-1 matches
# After loop increment, accept_idx=0, accept_num=1, eos at accept_tokens[0]
self.assertEqual(outputs["accept_num"][0], 1)
self.assertEqual(outputs["accept_tokens"][0, 0], -1) # replaced with eos
if __name__ == "__main__":
@@ -261,7 +261,9 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]:
# Write history to token_ids_all (forward loop, mirrors kernel step 5)
if output_len > 0:
base_addr = int(prompt_lens[batch_id])
base = cur_step_idx - output_len + 1
# 新语义: step_idx 入口 = 历史数量,处理后 cur_step_idx = 历史 + output_len
# 第一个 output token 写入位置 = cur_step_idx - output_len
base = cur_step_idx - output_len
for i in range(output_len):
write_idx = base_addr + base + i
if 0 <= write_idx < max_model_len:
+26
View File
@@ -411,6 +411,32 @@ class TestDPLocalScheduler(unittest.TestCase):
self.assertEqual(scheduler.ids, ["fresh_req"])
self.assertEqual(scheduler.ids_read_cursor, 1)
def test_get_requests_insufficient_resources(self):
"""Test getting requests when resources are insufficient."""
mock_logger.reset_mock()
# Test with insufficient blocks - mock the condition variable to avoid threading issues
with patch.object(self.scheduler, "requests_not_empty"):
requests = self.scheduler.get_requests(
available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
)
self.assertEqual(requests, [])
# The logger should have been called for insufficient resources
self.assertTrue(mock_logger.debug.called)
# Check the message contains expected content
call_args = mock_logger.debug.call_args[0][0]
self.assertIn("insufficient", call_args.lower())
def test_get_requests_insufficient_batch(self):
"""Test getting requests when batch size is insufficient."""
with patch.object(self.scheduler, "requests_not_empty"):
requests = self.scheduler.get_requests(
available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0
)
self.assertEqual(requests, [])
@patch("time.time")
@patch.object(dp_scheduler_module, "envs")
def test_get_requests_no_requests_available(self, mock_envs, mock_time):
@@ -25,9 +25,6 @@ class DummyEngine:
"""Dummy Engine class to simulate the actual Engine for testing."""
class ResourceManager:
def __init__(self):
self.waiting = []
def available_batch(self):
return 4
+1 -1
View File
@@ -138,7 +138,7 @@ class TestConfig(unittest.TestCase):
model_config=model_config,
test_mode=True,
)
fd_config.init_cache_info()
fd_config.init_pd_info()
assert fd_config.register_info is not None
def test_fdconfig_postprocess_ports(self):
+2 -2
View File
@@ -487,7 +487,7 @@ class TestSleepWakeupBehavior(unittest.TestCase):
runner.local_rank = 0
runner.device_id = 1
runner.num_gpu_blocks = 8
runner.model = Mock(clear_grpah_opt_backend=Mock())
runner.model = Mock(clear_graph_opt_backend=Mock())
runner.clear_cache = Mock()
runner.initialize_kv_cache = Mock()
runner.capture_model = Mock()
@@ -523,7 +523,7 @@ class TestSleepWakeupBehavior(unittest.TestCase):
runner.sleep("weight,kv_cache")
runner.model.clear_grpah_opt_backend.assert_called_once()
runner.model.clear_graph_opt_backend.assert_called_once()
runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once()
runner.dynamic_weight_manager.clear_model_weight.assert_called_once()
runner.dynamic_weight_manager.clear_communication_group.assert_called_once()