mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
Merge origin/release/2.6 and resolve worker_process conflict
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel(
|
|||||||
const int max_batch_size) {
|
const int max_batch_size) {
|
||||||
constexpr int VecSize = 4;
|
constexpr int VecSize = 4;
|
||||||
const uint32_t tid = threadIdx.x, bid = blockIdx.x;
|
const uint32_t tid = threadIdx.x, bid = blockIdx.x;
|
||||||
int startend_row_vec[4];
|
int startend_row_vec[2];
|
||||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||||
cudaGridDependencySynchronize();
|
cudaGridDependencySynchronize();
|
||||||
#endif
|
#endif
|
||||||
@@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel(
|
|||||||
const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
|
const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start;
|
||||||
|
|
||||||
startend_row_vec[0] = this_batch_q_end;
|
startend_row_vec[0] = this_batch_q_end;
|
||||||
startend_row_vec[1] = cu_seqlens_q[max_batch_size];
|
// startend_row_vec[1] = cu_seqlens_q[max_batch_size];
|
||||||
startend_row_vec[2] = 0;
|
// startend_row_vec[2] = 0;
|
||||||
startend_row_vec[3] = this_batch_q_end;
|
startend_row_vec[1] = this_batch_q_end;
|
||||||
for (int this_batch_q_idx = this_batch_q_start;
|
for (int this_batch_q_idx = this_batch_q_start;
|
||||||
this_batch_q_idx < this_batch_q_end;
|
this_batch_q_idx < this_batch_q_end;
|
||||||
++this_batch_q_idx) {
|
++this_batch_q_idx) {
|
||||||
@@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel(
|
|||||||
: this_batch_q_idx - this_batch_q_start + kv_len -
|
: this_batch_q_idx - this_batch_q_start + kv_len -
|
||||||
(this_batch_q_len);
|
(this_batch_q_len);
|
||||||
if (cache_k_idx <= append_mask_k_end) {
|
if (cache_k_idx <= append_mask_k_end) {
|
||||||
startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx);
|
startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx);
|
||||||
// 可提前跳出循环
|
// 可提前跳出循环
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
reinterpret_cast<int4*>(startend_row_indices_ptr +
|
reinterpret_cast<int2*>(startend_row_indices_ptr +
|
||||||
cu_seqlens_k_idx * 4)[0] =
|
cu_seqlens_k_idx * 2)[0] =
|
||||||
reinterpret_cast<int4*>(startend_row_vec)[0];
|
reinterpret_cast<int2*>(startend_row_vec)[0];
|
||||||
}
|
}
|
||||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||||
cudaTriggerProgrammaticLaunchCompletion();
|
cudaTriggerProgrammaticLaunchCompletion();
|
||||||
@@ -82,7 +82,7 @@ std::vector<paddle::Tensor> get_attn_mask_q(
|
|||||||
const paddle::optional<paddle::Tensor>& attn_mask_kv,
|
const paddle::optional<paddle::Tensor>& attn_mask_kv,
|
||||||
const int kv_token_num) {
|
const int kv_token_num) {
|
||||||
paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
|
paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor(
|
||||||
{1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place());
|
{1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place());
|
||||||
const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
|
const int max_batch_size = cu_seqlens_k.dims()[0] - 1;
|
||||||
constexpr int block_size = 512;
|
constexpr int block_size = 512;
|
||||||
int grid_size = div_up(kv_token_num, block_size);
|
int grid_size = div_up(kv_token_num, block_size);
|
||||||
@@ -123,7 +123,7 @@ std::vector<std::vector<int64_t>> GetAttnMaskQInferShape(
|
|||||||
const std::vector<int64_t>& cu_seqlens_k_shape,
|
const std::vector<int64_t>& cu_seqlens_k_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
|
const paddle::optional<std::vector<int64_t>>& attn_mask_kv_shape,
|
||||||
const int kv_token_num) {
|
const int kv_token_num) {
|
||||||
return {{1, 1, kv_token_num, 4}};
|
return {{1, 1, kv_token_num, 2}};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(get_attn_mask_q)
|
PD_BUILD_STATIC_OP(get_attn_mask_q)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
|||||||
int64_t* next_tokens, // [bs, tokens_per_step]
|
int64_t* next_tokens, // [bs, tokens_per_step]
|
||||||
const int* max_think_lens, // [bs]
|
const int* max_think_lens, // [bs]
|
||||||
int* max_reply_lens, // [bs]
|
int* max_reply_lens, // [bs]
|
||||||
int64_t* step_idx, // [bs]
|
const int64_t* step_idx, // [bs]
|
||||||
const int64_t* eos_token_ids, // [eos_len]
|
const int64_t* eos_token_ids, // [eos_len]
|
||||||
int* limit_status, // [bs]
|
int* limit_status, // [bs]
|
||||||
int* accept_num, // [bs]
|
int* accept_num, // [bs]
|
||||||
@@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
|||||||
int new_accept_num = original_accept_num;
|
int new_accept_num = original_accept_num;
|
||||||
|
|
||||||
// 本 step 的 token offset 对应的绝对 step
|
// 本 step 的 token offset 对应的绝对 step
|
||||||
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
|
const int64_t current_base_step = step_idx[bid] + 1;
|
||||||
|
|
||||||
for (int token_offset = 0; token_offset < original_accept_num;
|
for (int token_offset = 0; token_offset < original_accept_num;
|
||||||
token_offset++) {
|
token_offset++) {
|
||||||
@@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
|||||||
// inject_token_ids[0])
|
// inject_token_ids[0])
|
||||||
if (status == 0 &&
|
if (status == 0 &&
|
||||||
(current_step - 1) ==
|
(current_step - 1) ==
|
||||||
max_think_len) { // current_step - 1 是因为 speculate_verify 里
|
max_think_len) { // current_step - 1 : 已输出 current_step-1
|
||||||
// step_idx + 1 了
|
// 个thinking token
|
||||||
status = (inject_len > 0) ? 1 : done_status;
|
status = (inject_len > 0) ? 1 : done_status;
|
||||||
}
|
}
|
||||||
} else if (max_think_len == 0) {
|
} else if (max_think_len == 0) {
|
||||||
@@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新 step_idx / accept_num(被截断的 token 需要回退
|
|
||||||
// step_idx)
|
|
||||||
const int discarded_tokens = original_accept_num - new_accept_num;
|
|
||||||
if (discarded_tokens > 0) {
|
|
||||||
step_idx[bid] -= discarded_tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
accept_num[bid] = new_accept_num;
|
accept_num[bid] = new_accept_num;
|
||||||
limit_status[bid] = status;
|
limit_status[bid] = status;
|
||||||
max_reply_lens[bid] = max_reply_len;
|
max_reply_lens[bid] = max_reply_len;
|
||||||
@@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength(
|
|||||||
const_cast<int64_t*>(next_tokens.data<int64_t>()),
|
const_cast<int64_t*>(next_tokens.data<int64_t>()),
|
||||||
max_think_lens.data<int>(),
|
max_think_lens.data<int>(),
|
||||||
const_cast<int*>(max_reply_lens.data<int>()),
|
const_cast<int*>(max_reply_lens.data<int>()),
|
||||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
step_idx.data<int64_t>(),
|
||||||
eos_token_ids.data<int64_t>(),
|
eos_token_ids.data<int64_t>(),
|
||||||
const_cast<int*>(limit_status.data<int>()),
|
const_cast<int*>(limit_status.data<int>()),
|
||||||
const_cast<int*>(accept_num.data<int>()),
|
const_cast<int*>(accept_num.data<int>()),
|
||||||
|
|||||||
@@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
|||||||
const int64_t step_idx_now = step_idx[bid];
|
const int64_t step_idx_now = step_idx[bid];
|
||||||
const int64_t min_token_limit = min_tokens[bid];
|
const int64_t min_token_limit = min_tokens[bid];
|
||||||
|
|
||||||
const bool can_stop = (step_idx_now >= min_token_limit);
|
const bool can_stop = (step_idx_now + accept_num >= min_token_limit);
|
||||||
if (!can_stop) return;
|
if (!can_stop) return;
|
||||||
if (!stop_flags[bid]) {
|
if (!stop_flags[bid]) {
|
||||||
int accept_idx = 0;
|
/*
|
||||||
|
accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based)
|
||||||
|
accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾
|
||||||
|
(pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。
|
||||||
|
为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1]
|
||||||
|
(当前轮最后一个 token),该 token 延迟到下一轮匹配。
|
||||||
|
循环范围:accept_num > 0 时为 [-1, accept_num-2];
|
||||||
|
accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。
|
||||||
|
*/
|
||||||
|
int accept_idx = -1;
|
||||||
bool is_end = false;
|
bool is_end = false;
|
||||||
// 遍历起始位置
|
|
||||||
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {
|
// 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾
|
||||||
|
// 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens
|
||||||
|
// 中的匹配。两者共享同一套从后向前匹配逻辑。
|
||||||
|
int loop_end = (accept_num > 0) ? accept_num - 2 : -1;
|
||||||
|
for (; accept_idx <= loop_end && !is_end; accept_idx++) {
|
||||||
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
|
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
|
||||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||||
printf("num %d < stop_seq_len %d\n",
|
printf("num %d < stop_seq_len %d\n",
|
||||||
step_idx_now - accept_num + accept_idx + 1,
|
step_idx_now + accept_idx + 1,
|
||||||
stop_seq_len);
|
stop_seq_len);
|
||||||
#endif
|
#endif
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// 遍历一个 stop_seqs
|
// 从后向前匹配 stop_seq 的每个 token
|
||||||
for (int i = stop_seq_len - 1; i >= 0; --i) {
|
for (int i = stop_seq_len - 1; i >= 0; --i) {
|
||||||
int64_t cur_token_idx = -1;
|
int64_t cur_token_idx = -1;
|
||||||
|
|
||||||
// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
|
int offset = stop_seq_len - 1 - i;
|
||||||
if (stop_seq_len - 1 - i < accept_idx) {
|
int accept_tokens_idx = accept_idx - offset;
|
||||||
|
|
||||||
|
if (accept_tokens_idx >= 0) {
|
||||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||||
printf(
|
printf(
|
||||||
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
|
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
|
||||||
"accept_token_idx: "
|
"accept_token_idx: %d\n",
|
||||||
"%d\n",
|
|
||||||
bid,
|
bid,
|
||||||
tid,
|
tid,
|
||||||
accept_idx,
|
accept_idx,
|
||||||
accept_idx - (stop_seq_len - 1 - i) - 1);
|
accept_tokens_idx);
|
||||||
#endif
|
#endif
|
||||||
cur_token_idx =
|
cur_token_idx = accept_tokens_now[accept_tokens_idx];
|
||||||
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
|
|
||||||
} else {
|
} else {
|
||||||
|
int pre_ids_idx = step_idx_now + accept_tokens_idx;
|
||||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||||
printf(
|
printf(
|
||||||
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
|
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
|
||||||
"accept_idx:%d. "
|
"accept_idx:%d. pre_id_idx: %d\n",
|
||||||
"pre_id_idx: %ld\n",
|
|
||||||
bid,
|
bid,
|
||||||
tid,
|
tid,
|
||||||
step_idx_now,
|
step_idx_now,
|
||||||
accept_idx,
|
accept_idx,
|
||||||
step_idx_now - accept_num + accept_idx -
|
pre_ids_idx);
|
||||||
(stop_seq_len - 1 - i));
|
|
||||||
#endif
|
#endif
|
||||||
int pre_ids_idx =
|
if (pre_ids_idx < 0) break;
|
||||||
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
|
|
||||||
// EC3
|
|
||||||
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
|
|
||||||
// 导致异常结束
|
|
||||||
if (pre_ids_idx <= 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
cur_token_idx = pre_ids_now[pre_ids_idx];
|
cur_token_idx = pre_ids_now[pre_ids_idx];
|
||||||
}
|
}
|
||||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||||
@@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
|||||||
}
|
}
|
||||||
if (is_end) {
|
if (is_end) {
|
||||||
#ifdef DEBUG_SPEC_STOP_SEQS
|
#ifdef DEBUG_SPEC_STOP_SEQS
|
||||||
printf("bid:%d end with accept_idx %d", bid, accept_idx);
|
printf("bid:%d end with accept_idx %d\n", bid, accept_idx);
|
||||||
#endif
|
#endif
|
||||||
|
// accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置
|
||||||
accept_nums[bid] = accept_idx;
|
accept_nums[bid] = accept_idx + 1;
|
||||||
accept_tokens_now[accept_idx - 1] = end_ids[0];
|
accept_tokens_now[accept_idx] = end_ids[0];
|
||||||
// stop_flags[bid] = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
|||||||
int64_t *token_ids_all_now =
|
int64_t *token_ids_all_now =
|
||||||
&token_ids_all[batch_id * max_model_len + prompt_len];
|
&token_ids_all[batch_id * max_model_len + prompt_len];
|
||||||
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||||
int64_t base = cur_step_idx - output_len + 1;
|
int64_t base = cur_step_idx - output_len;
|
||||||
for (int i = 0; i < output_len; i++) {
|
for (int i = 0; i < output_len; i++) {
|
||||||
token_ids_all_now[base + i] = output_ids[i];
|
token_ids_all_now[base + i] = output_ids[i];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# Whether to enable the decode caches requests for preallocating resource
|
# Whether to enable the decode caches requests for preallocating resource
|
||||||
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
||||||
|
|
||||||
|
# Batched token timeout in EP
|
||||||
|
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
|
||||||
|
|
||||||
# Max pre-fetch requests number in PD
|
# Max pre-fetch requests number in PD
|
||||||
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
||||||
|
|
||||||
|
|||||||
@@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# 是否启用 decode 缓存请求以预分配资源
|
# 是否启用 decode 缓存请求以预分配资源
|
||||||
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
||||||
|
|
||||||
|
# EP 中批处理 token 的超时时间
|
||||||
|
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
|
||||||
|
|
||||||
# PD 中最大预取请求数量
|
# PD 中最大预取请求数量
|
||||||
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
||||||
|
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -e
|
|
||||||
|
|
||||||
# Test splitwise deployment
|
|
||||||
# There are two methods for splitwise deployment:
|
|
||||||
# v0: using splitwise_scheduler or dp_scheduler (deprecated)
|
|
||||||
# v1: using local_scheduler + router
|
|
||||||
|
|
||||||
# prepare environment
|
|
||||||
export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
|
||||||
export FD_DEBUG=1
|
|
||||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
|
||||||
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
|
|
||||||
|
|
||||||
SCRIPT_PATH=$(readlink -f "$0")
|
|
||||||
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
|
||||||
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
|
|
||||||
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
|
|
||||||
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
|
|
||||||
echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
unset http_proxy && unset https_proxy
|
|
||||||
source ${SCRIPT_DIR}/utils.sh
|
|
||||||
|
|
||||||
P_PORT=52400
|
|
||||||
D_PORT=52500
|
|
||||||
REDIS_PORT="${REDIS_PORT:-6379}"
|
|
||||||
LOG_DATE=$(date +%Y%m%d_%H%M%S)
|
|
||||||
|
|
||||||
ports=(
|
|
||||||
$P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5))
|
|
||||||
$D_PORT $((D_PORT + 1)) $((D_PORT + 2)) $((D_PORT + 3)) $((D_PORT + 4)) $((D_PORT + 5))
|
|
||||||
$REDIS_PORT
|
|
||||||
)
|
|
||||||
check_ports "${ports[@]}" || {
|
|
||||||
echo "❌ Some ports are in use. Please release them."
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
# start redis
|
|
||||||
if ! redis-cli -p ${REDIS_PORT} ping &>/dev/null; then
|
|
||||||
echo "Redis is not running. Starting redis-server..."
|
|
||||||
redis-server --daemonize yes --port ${REDIS_PORT}
|
|
||||||
sleep 1
|
|
||||||
else
|
|
||||||
echo "Redis is already running."
|
|
||||||
fi
|
|
||||||
sleep 1
|
|
||||||
|
|
||||||
# start prefill
|
|
||||||
export CUDA_VISIBLE_DEVICES=0
|
|
||||||
export FD_LOG_DIR="log/$LOG_DATE/prefill"
|
|
||||||
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
|
|
||||||
|
|
||||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
|
||||||
--model ${MODEL_NAME} \
|
|
||||||
--port ${P_PORT} \
|
|
||||||
--metrics-port $((P_PORT + 1)) \
|
|
||||||
--engine-worker-queue-port $((P_PORT + 2)) \
|
|
||||||
--cache-queue-port $((P_PORT + 3)) \
|
|
||||||
--max-model-len 32768 \
|
|
||||||
--num-gpu-blocks-override 1000 \
|
|
||||||
--splitwise-role "prefill" \
|
|
||||||
--cache-transfer-protocol "rdma" \
|
|
||||||
--rdma-comm-ports $((P_PORT + 4)) \
|
|
||||||
--pd-comm-port $((P_PORT + 5)) \
|
|
||||||
--scheduler-name "splitwise" \
|
|
||||||
--scheduler-host "127.0.0.1" \
|
|
||||||
--scheduler-port ${REDIS_PORT} \
|
|
||||||
--scheduler-ttl 9000 \
|
|
||||||
2>&1 >${FD_LOG_DIR}/nohup &
|
|
||||||
|
|
||||||
wait_for_health ${P_PORT}
|
|
||||||
|
|
||||||
# start decode
|
|
||||||
export CUDA_VISIBLE_DEVICES=1
|
|
||||||
export FD_LOG_DIR="log/$LOG_DATE/decode"
|
|
||||||
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
|
|
||||||
|
|
||||||
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
|
||||||
--model ${MODEL_NAME} \
|
|
||||||
--port ${D_PORT} \
|
|
||||||
--metrics-port $((D_PORT + 1)) \
|
|
||||||
--engine-worker-queue-port $((D_PORT + 2)) \
|
|
||||||
--cache-queue-port $((D_PORT + 3)) \
|
|
||||||
--max-model-len 32768 \
|
|
||||||
--splitwise-role "decode" \
|
|
||||||
--cache-transfer-protocol "rdma" \
|
|
||||||
--rdma-comm-ports $((D_PORT + 4)) \
|
|
||||||
--pd-comm-port $((D_PORT + 5)) \
|
|
||||||
--scheduler-name "splitwise" \
|
|
||||||
--scheduler-host "127.0.0.1" \
|
|
||||||
--scheduler-port ${REDIS_PORT} \
|
|
||||||
--scheduler-ttl 9000 \
|
|
||||||
2>&1 >${FD_LOG_DIR}/nohup &
|
|
||||||
|
|
||||||
wait_for_health ${D_PORT}
|
|
||||||
|
|
||||||
|
|
||||||
# send request
|
|
||||||
sleep 10 # make sure server is registered to router
|
|
||||||
echo "send request..."
|
|
||||||
curl -X POST "http://0.0.0.0:${D_PORT}/v1/chat/completions" \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "hello"}
|
|
||||||
],
|
|
||||||
"max_tokens": 20,
|
|
||||||
"stream": false
|
|
||||||
}'
|
|
||||||
@@ -2009,13 +2009,13 @@ class FDConfig:
|
|||||||
and self.router_config
|
and self.router_config
|
||||||
and self.router_config.router
|
and self.router_config.router
|
||||||
):
|
):
|
||||||
# For RL scenario: version.yaml will be required for models in future releases.
|
# For RL scenario, version.yaml is required for models
|
||||||
# Temporarily enforce use router to be enabled.
|
# Temporarily enforce use router to be enabled.
|
||||||
self.model_config.read_model_version()
|
self.model_config.read_model_version()
|
||||||
|
|
||||||
self.read_from_config()
|
self.read_from_config()
|
||||||
self.postprocess()
|
self.postprocess()
|
||||||
self.init_cache_info()
|
self.init_pd_info()
|
||||||
if test_mode:
|
if test_mode:
|
||||||
return
|
return
|
||||||
self.check()
|
self.check()
|
||||||
@@ -2348,18 +2348,17 @@ class FDConfig:
|
|||||||
logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||||
logger.info("=============================================================")
|
logger.info("=============================================================")
|
||||||
|
|
||||||
def init_cache_info(self):
|
def init_pd_info(self):
|
||||||
"""
|
"""
|
||||||
initialize cache info
|
initialize info for pd deployment
|
||||||
"""
|
"""
|
||||||
# TODO: group the splitiwse params
|
|
||||||
# There are two methods for splitwise deployment:
|
# There are two methods for splitwise deployment:
|
||||||
# 1. v0 splitwise_scheduler or dp_scheduler
|
# 1. v0 splitwise_scheduler or dp_scheduler
|
||||||
# 2. v1 local_scheduler + router
|
# 2. v1 local_scheduler + router (optional)
|
||||||
self.splitwise_version = None
|
self.splitwise_version = None
|
||||||
if self.scheduler_config.name in ("splitwise", "dp"):
|
if self.scheduler_config.name in ("splitwise", "dp"):
|
||||||
self.splitwise_version = "v0"
|
self.splitwise_version = "v0"
|
||||||
elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
|
elif self.scheduler_config.name == "local":
|
||||||
self.splitwise_version = "v1"
|
self.splitwise_version = "v1"
|
||||||
|
|
||||||
# the information for registering this server to router or splitwise_scheduler
|
# the information for registering this server to router or splitwise_scheduler
|
||||||
|
|||||||
@@ -592,10 +592,15 @@ class EngineArgs:
|
|||||||
raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1")
|
raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1")
|
||||||
|
|
||||||
if self.splitwise_role != "mixed":
|
if self.splitwise_role != "mixed":
|
||||||
if self.scheduler_name == "local" and self.router is None:
|
if self.scheduler_name == "splitwise":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"When using {self.splitwise_role} role and the {self.scheduler_name} "
|
"Setting scheduler_name as splitwise is not supported in pd deployment, "
|
||||||
f"scheduler, please provide --router argument."
|
"please use router as scheduler."
|
||||||
|
)
|
||||||
|
if self.scheduler_name == "local" and self.router is None:
|
||||||
|
console_logger.warning(
|
||||||
|
f"Running {self.splitwise_role} role with {self.scheduler_name} "
|
||||||
|
f"scheduler without --router. Router registration and request routing will be disabled."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
|
|||||||
@@ -367,15 +367,6 @@ class EngineService:
|
|||||||
create=True,
|
create=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
engine_forward_signal_data = np.zeros([1], dtype=np.int32)
|
|
||||||
self.engine_forward_signal = IPCSignal(
|
|
||||||
name="engine_forward_signal",
|
|
||||||
array=engine_forward_signal_data,
|
|
||||||
dtype=np.int32,
|
|
||||||
suffix=current_suffix,
|
|
||||||
create=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
|
# worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
|
||||||
worker_healthy_live_recorded_time_array = np.zeros(
|
worker_healthy_live_recorded_time_array = np.zeros(
|
||||||
shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
|
shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
|
||||||
@@ -1037,29 +1028,26 @@ class EngineService:
|
|||||||
with self._pause_cond:
|
with self._pause_cond:
|
||||||
self._pause_cond.wait_for(lambda: not self.is_paused)
|
self._pause_cond.wait_for(lambda: not self.is_paused)
|
||||||
try:
|
try:
|
||||||
if not is_fetching:
|
if self.engine_worker_queue.exist_tasks():
|
||||||
# Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool.
|
time.sleep(0.001)
|
||||||
try:
|
continue
|
||||||
|
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
||||||
|
if not is_fetching:
|
||||||
is_fetching = True
|
is_fetching = True
|
||||||
get_request_pool.submit(_fetch_request)
|
get_request_pool.submit(_fetch_request)
|
||||||
except RuntimeError as e:
|
|
||||||
if "shutdown" in str(e):
|
|
||||||
self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
|
||||||
# Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished.
|
|
||||||
# Once the forward pass finishes, these accumulated requests can be scheduled in larger,
|
|
||||||
# more efficient batches.
|
|
||||||
if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0:
|
|
||||||
time.sleep(0.001)
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
# In mixed, todo: optimze cache swap, to decouple swap from scheduler
|
if len(self.resource_manager.waiting) == 0 and (not is_fetching):
|
||||||
if self.engine_worker_queue.exist_tasks():
|
# Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool.
|
||||||
time.sleep(0.001)
|
try:
|
||||||
continue
|
is_fetching = True
|
||||||
|
get_request_pool.submit(_fetch_request)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "shutdown" in str(e):
|
||||||
|
self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
|
if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
|
||||||
self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()
|
self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()
|
||||||
@@ -1120,13 +1108,6 @@ class EngineService:
|
|||||||
elif not task.has_been_preempted_before:
|
elif not task.has_been_preempted_before:
|
||||||
task.metrics.inference_start_time = time.time()
|
task.metrics.inference_start_time = time.time()
|
||||||
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
|
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
|
||||||
else:
|
|
||||||
# When there are no actual tasks to schedule, send an empty task batch to EP workers.
|
|
||||||
# This helps EP workers barrier for syncing tasks not hang.
|
|
||||||
if self.cfg.parallel_config.enable_expert_parallel:
|
|
||||||
self.engine_worker_queue.put_tasks(
|
|
||||||
([], self.resource_manager.real_bsz)
|
|
||||||
) # Empty (as idle tasks for ep)
|
|
||||||
|
|
||||||
# 4. Response error tasks
|
# 4. Response error tasks
|
||||||
if error_tasks:
|
if error_tasks:
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ class ExpertService:
|
|||||||
if envs.FD_ENABLE_RETURN_TEXT:
|
if envs.FD_ENABLE_RETURN_TEXT:
|
||||||
self.engine.create_data_processor()
|
self.engine.create_data_processor()
|
||||||
if self.cfg.scheduler_config.name == "dp":
|
if self.cfg.scheduler_config.name == "dp":
|
||||||
self.cfg.init_cache_info()
|
self.cfg.init_pd_info()
|
||||||
self.engine.scheduler.start(local_data_parallel_id)
|
self.engine.scheduler.start(local_data_parallel_id)
|
||||||
|
|
||||||
if ipc_signal_suffix is not None:
|
if ipc_signal_suffix is not None:
|
||||||
@@ -122,7 +122,7 @@ class ExpertService:
|
|||||||
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
|
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
|
||||||
|
|
||||||
if self.cfg.scheduler_config.name == "splitwise":
|
if self.cfg.scheduler_config.name == "splitwise":
|
||||||
self.cfg.init_cache_info()
|
self.cfg.init_pd_info()
|
||||||
role = self.cfg.scheduler_config.splitwise_role
|
role = self.cfg.scheduler_config.splitwise_role
|
||||||
host_ip = self.cfg.host_ip
|
host_ip = self.cfg.host_ip
|
||||||
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
|
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
|
||||||
|
|||||||
@@ -927,6 +927,7 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
if (
|
if (
|
||||||
self.config.cache_config.enable_prefix_caching
|
self.config.cache_config.enable_prefix_caching
|
||||||
and self.config.scheduler_config.splitwise_role != "decode"
|
and self.config.scheduler_config.splitwise_role != "decode"
|
||||||
|
and self.config.scheduler_config.splitwise_role != "prefill"
|
||||||
):
|
):
|
||||||
self.cache_manager.update_cache_blocks(
|
self.cache_manager.update_cache_blocks(
|
||||||
request, self.config.cache_config.block_size, request.num_computed_tokens
|
request, self.config.cache_config.block_size, request.num_computed_tokens
|
||||||
@@ -1374,6 +1375,11 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
self.stop_flags[request.idx] = False
|
self.stop_flags[request.idx] = False
|
||||||
self.requests[request.request_id] = request
|
self.requests[request.request_id] = request
|
||||||
self.req_dict[request.request_id] = allocated_position
|
self.req_dict[request.request_id] = allocated_position
|
||||||
|
|
||||||
|
self.cache_manager.update_cache_blocks(
|
||||||
|
request, self.config.cache_config.block_size, request.need_prefill_tokens
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
self._free_blocks(request)
|
self._free_blocks(request)
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class ErnieX1ToolParser(ToolParser):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
return ExtractedToolCallInformation(
|
return ExtractedToolCallInformation(
|
||||||
tools_called=True,
|
tools_called=len(tool_calls) > 0,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -182,11 +182,13 @@ class ErnieX1ToolParser(ToolParser):
|
|||||||
logger.debug("attempting to close tool call, but no tool call")
|
logger.debug("attempting to close tool call, but no tool call")
|
||||||
return None
|
return None
|
||||||
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
||||||
if diff:
|
if diff is not None:
|
||||||
if '"}' not in delta_text:
|
if "}" not in delta_text:
|
||||||
|
return None
|
||||||
|
end_loc = delta_text.rindex("}")
|
||||||
|
diff = delta_text[:end_loc]
|
||||||
|
if not diff:
|
||||||
return None
|
return None
|
||||||
end_loc = delta_text.rindex('"}')
|
|
||||||
diff = delta_text[:end_loc] + '"}'
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Finishing tool and found diff that had not " "been streamed yet: %s",
|
"Finishing tool and found diff that had not " "been streamed yet: %s",
|
||||||
diff,
|
diff,
|
||||||
@@ -248,15 +250,15 @@ class ErnieX1ToolParser(ToolParser):
|
|||||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
||||||
cur_arguments = current_tool_call.get("arguments")
|
cur_arguments = current_tool_call.get("arguments")
|
||||||
|
|
||||||
if not cur_arguments and not prev_arguments:
|
if cur_arguments is None and prev_arguments is None:
|
||||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||||
delta = None
|
delta = None
|
||||||
|
|
||||||
elif not cur_arguments and prev_arguments:
|
elif cur_arguments is None and prev_arguments is not None:
|
||||||
logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.")
|
logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.")
|
||||||
delta = None
|
delta = None
|
||||||
|
|
||||||
elif cur_arguments and not prev_arguments:
|
elif cur_arguments is not None and prev_arguments is None:
|
||||||
function_name = current_tool_call.get("name")
|
function_name = current_tool_call.get("name")
|
||||||
match = re.search(
|
match = re.search(
|
||||||
r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
|
r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
|
||||||
@@ -265,6 +267,19 @@ class ErnieX1ToolParser(ToolParser):
|
|||||||
)
|
)
|
||||||
if match:
|
if match:
|
||||||
cur_arguments_json = match.group(1)
|
cur_arguments_json = match.group(1)
|
||||||
|
# When tool_call_portion is complete JSON, the regex
|
||||||
|
# (.*) over-captures the outer closing brace of the
|
||||||
|
# tool call object. Strip it from both
|
||||||
|
# cur_arguments_json and delta_text, consistent with
|
||||||
|
# the both-have-arguments branch handling.
|
||||||
|
try:
|
||||||
|
json.loads(tool_call_portion)
|
||||||
|
if cur_arguments_json.endswith("}"):
|
||||||
|
cur_arguments_json = cur_arguments_json[:-1]
|
||||||
|
if delta_text.rstrip().endswith("}"):
|
||||||
|
delta_text = delta_text.rstrip()[:-1]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
|
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
|
||||||
|
|
||||||
@@ -287,7 +302,7 @@ class ErnieX1ToolParser(ToolParser):
|
|||||||
)
|
)
|
||||||
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
|
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
|
||||||
|
|
||||||
elif cur_arguments and prev_arguments:
|
elif cur_arguments is not None and prev_arguments is not None:
|
||||||
try:
|
try:
|
||||||
json.loads(tool_call_portion)
|
json.loads(tool_call_portion)
|
||||||
is_complete_json = True
|
is_complete_json = True
|
||||||
|
|||||||
+16
-11
@@ -145,6 +145,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
|
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
|
||||||
# Whether to enable the decode caches requests for preallocating resource
|
# Whether to enable the decode caches requests for preallocating resource
|
||||||
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"),
|
||||||
|
# Batched token timeout in EP
|
||||||
|
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
|
||||||
# Max pre-fetch requests number in PD
|
# Max pre-fetch requests number in PD
|
||||||
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
"FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")),
|
||||||
# Enable or disable model caching.
|
# Enable or disable model caching.
|
||||||
@@ -210,17 +212,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""),
|
"FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""),
|
||||||
# Whether to enable low latency in mixed scenario
|
# Whether to enable low latency in mixed scenario
|
||||||
"FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
|
"FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
|
||||||
# Whether to use phi FP8 quantization,if 1,use paddle default.
|
|
||||||
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
|
|
||||||
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
|
|
||||||
# intended for training alignment. Defaults to 0 (disabled).
|
|
||||||
"FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
|
|
||||||
# Whether to use phi MOE permute,if 1,use paddle op.
|
|
||||||
"FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
|
|
||||||
# Whether to use phi rms_norm,if 1,use paddle op.
|
|
||||||
"FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))),
|
|
||||||
# Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function
|
|
||||||
"FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
|
|
||||||
# Reserve output blocks for decoding requests when schedule new prefill requests
|
# Reserve output blocks for decoding requests when schedule new prefill requests
|
||||||
"FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
|
"FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int(
|
||||||
os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16")
|
os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16")
|
||||||
@@ -262,8 +253,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
|
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
|
||||||
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
|
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
|
||||||
),
|
),
|
||||||
|
# train-infer consistency, used in RL
|
||||||
# Whether to align RoPE and moe gate precision with training
|
# Whether to align RoPE and moe gate precision with training
|
||||||
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
|
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
|
||||||
|
# Whether to use phi FP8 quantization,if 1,use paddle default.
|
||||||
|
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
|
||||||
|
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
|
||||||
|
# intended for training alignment. Defaults to 0 (disabled).
|
||||||
|
"FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))),
|
||||||
|
# Whether to use phi MOE permute,if 1,use paddle op.
|
||||||
|
"FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))),
|
||||||
|
# Whether to use phi rms_norm,if 1,use paddle op.
|
||||||
|
"FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))),
|
||||||
|
# Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function
|
||||||
|
"FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
|
||||||
|
# Whether to enable FP8 quantization with pow2scale.
|
||||||
|
"FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class GraphOptWrapper:
|
|||||||
def __call__(self, **kwargs):
|
def __call__(self, **kwargs):
|
||||||
return self.graph_opt_backend(**kwargs)
|
return self.graph_opt_backend(**kwargs)
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self, fd_config):
|
def clear_graph_opt_backend(self, fd_config):
|
||||||
""" """
|
""" """
|
||||||
# TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights
|
# TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ def init_flash_attn_version():
|
|||||||
logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.")
|
logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.")
|
||||||
|
|
||||||
if FLASH_ATTN_VERSION is None:
|
if FLASH_ATTN_VERSION is None:
|
||||||
if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()):
|
if sm_version == 90 and 90 in paddle.version.cuda_archs():
|
||||||
FLASH_ATTN_VERSION = 3
|
FLASH_ATTN_VERSION = 3
|
||||||
logger.info("The current platform supports Flash Attention V3.")
|
logger.info("The current platform supports Flash Attention V3.")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
|
|||||||
else:
|
else:
|
||||||
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||||
ffn_out,
|
ffn_out,
|
||||||
using_pow2_scale=not disable_ue8m0_cast,
|
using_pow2_scale=not disable_ue8m0_cast or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||||
using_ue8m0_scale=not disable_ue8m0_cast,
|
using_ue8m0_scale=not disable_ue8m0_cast,
|
||||||
)
|
)
|
||||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
||||||
@@ -355,7 +355,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
else:
|
else:
|
||||||
x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||||
x,
|
x,
|
||||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||||
)
|
)
|
||||||
@@ -581,7 +581,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
else:
|
else:
|
||||||
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||||
ffn_out,
|
ffn_out,
|
||||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0
|
||||||
|
or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||||
)
|
)
|
||||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]]
|
||||||
@@ -773,7 +774,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
else:
|
else:
|
||||||
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||||
x,
|
x,
|
||||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||||
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0,
|
||||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1247,7 +1247,7 @@ def python_op_fused_moe_kernel_paddle(
|
|||||||
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
|
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
|
||||||
else:
|
else:
|
||||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||||
x, using_pow2_scale=False, output_scale_transpose=False
|
x, using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=False
|
||||||
)
|
)
|
||||||
x_scale = x_scale[: x.shape[0]]
|
x_scale = x_scale[: x.shape[0]]
|
||||||
|
|
||||||
@@ -1305,7 +1305,9 @@ def python_op_fused_moe_kernel_paddle(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||||
intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False
|
intermediate_cache2,
|
||||||
|
using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||||
|
output_scale_transpose=False,
|
||||||
)
|
)
|
||||||
x_scale = x_scale[: x_q.shape[0]]
|
x_scale = x_scale[: x_q.shape[0]]
|
||||||
|
|
||||||
|
|||||||
@@ -343,7 +343,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
|||||||
else:
|
else:
|
||||||
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise(
|
||||||
x,
|
x,
|
||||||
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
|
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE,
|
||||||
output_scale_transpose=True,
|
output_scale_transpose=True,
|
||||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1306,9 +1306,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekV3PretrainedModel(PretrainedModel):
|
class DeepSeekV3PretrainedModel(PretrainedModel):
|
||||||
|
|||||||
@@ -701,9 +701,9 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.ernie.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
@ModelRegistry.register_model_class(
|
@ModelRegistry.register_model_class(
|
||||||
|
|||||||
@@ -829,9 +829,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.ernie.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
class Ernie4_5_VLPretrainedModel(PretrainedModel):
|
||||||
|
|||||||
@@ -563,9 +563,9 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
class Glm4MoePretrainedModel(PretrainedModel):
|
class Glm4MoePretrainedModel(PretrainedModel):
|
||||||
|
|||||||
@@ -369,3 +369,7 @@ class Glm4MTPForCausalLM(ModelForCasualLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def clear_graph_opt_backend(self):
|
||||||
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
|
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|||||||
@@ -417,9 +417,9 @@ class Qwen2ForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.qwen2.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
class Qwen2PretrainedModel(PretrainedModel):
|
class Qwen2PretrainedModel(PretrainedModel):
|
||||||
|
|||||||
@@ -341,9 +341,9 @@ class Qwen3ForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
class Qwen3PretrainedModel(PretrainedModel):
|
class Qwen3PretrainedModel(PretrainedModel):
|
||||||
|
|||||||
@@ -382,9 +382,9 @@ class Qwen3VLForConditionalGeneration(ModelForCasualLM):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
class Qwen3VLPretrainedModel(PretrainedModel):
|
class Qwen3VLPretrainedModel(PretrainedModel):
|
||||||
|
|||||||
@@ -453,9 +453,9 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
class Qwen3MoePretrainedModel(PretrainedModel):
|
class Qwen3MoePretrainedModel(PretrainedModel):
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ class RolloutModelConfig:
|
|||||||
routing_replay_config: str = None,
|
routing_replay_config: str = None,
|
||||||
load_choices: str = "default_v1",
|
load_choices: str = "default_v1",
|
||||||
lm_head_fp32: bool = False,
|
lm_head_fp32: bool = False,
|
||||||
|
moe_gate_fp32: bool = True,
|
||||||
):
|
):
|
||||||
# Required parameters
|
# Required parameters
|
||||||
self.model = model_name_or_path
|
self.model = model_name_or_path
|
||||||
@@ -121,6 +122,7 @@ class RolloutModelConfig:
|
|||||||
self.routing_replay_config = routing_replay_config
|
self.routing_replay_config = routing_replay_config
|
||||||
self.load_choices = load_choices
|
self.load_choices = load_choices
|
||||||
self.lm_head_fp32 = lm_head_fp32
|
self.lm_head_fp32 = lm_head_fp32
|
||||||
|
self.moe_gate_fp32 = moe_gate_fp32
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from typing import Dict, List, Optional
|
|||||||
from fastdeploy.engine.request import Request, RequestOutput
|
from fastdeploy.engine.request import Request, RequestOutput
|
||||||
from fastdeploy.scheduler.data import ScheduledResponse
|
from fastdeploy.scheduler.data import ScheduledResponse
|
||||||
from fastdeploy.scheduler.local_scheduler import LocalScheduler
|
from fastdeploy.scheduler.local_scheduler import LocalScheduler
|
||||||
from fastdeploy.utils import get_logger
|
from fastdeploy.utils import envs, get_logger
|
||||||
|
|
||||||
|
|
||||||
class DPLocalScheduler(LocalScheduler):
|
class DPLocalScheduler(LocalScheduler):
|
||||||
@@ -131,19 +131,52 @@ class DPLocalScheduler(LocalScheduler):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Request objects ready for processing
|
List of Request objects ready for processing
|
||||||
"""
|
"""
|
||||||
# DP scheduler is used in V1, there is no need to manage request fetching in the scheduler, resource_manager_v1 will do that.
|
if available_blocks <= reserved_output_blocks or batch < 1:
|
||||||
|
self.scheduler_logger.debug(
|
||||||
|
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
|
||||||
|
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
|
||||||
|
f"max_num_batched_tokens={max_num_batched_tokens}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
required_total_blocks = 0
|
||||||
|
current_prefill_tokens = 0
|
||||||
|
start_batch_time = time.time()
|
||||||
requests: List[Request] = []
|
requests: List[Request] = []
|
||||||
|
|
||||||
with self.requests_not_empty:
|
with self.requests_not_empty:
|
||||||
batch_ids = self.requests_not_empty.wait_for(
|
while True:
|
||||||
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1],
|
batch_ids = self.requests_not_empty.wait_for(
|
||||||
0.005,
|
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
|
||||||
)
|
0.005,
|
||||||
if batch_ids:
|
)
|
||||||
for request_id in batch_ids:
|
if batch_ids:
|
||||||
request = self.requests[request_id]
|
for request_id in batch_ids:
|
||||||
requests.append(request.raw)
|
request = self.requests[request_id]
|
||||||
self.ids_read_cursor += 1
|
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
|
||||||
|
current_prefill_tokens += request.prompt_tokens_ids_len
|
||||||
|
required_total_blocks += required_input_blocks + reserved_output_blocks
|
||||||
|
if required_total_blocks > available_blocks:
|
||||||
|
break
|
||||||
|
|
||||||
|
requests.append(request.raw)
|
||||||
|
self.ids_read_cursor += 1
|
||||||
|
start_batch_time = time.time()
|
||||||
|
if current_prefill_tokens > max_num_batched_tokens:
|
||||||
|
break
|
||||||
|
if len(requests) >= batch:
|
||||||
|
break
|
||||||
|
if (
|
||||||
|
(current_prefill_tokens > max_num_batched_tokens)
|
||||||
|
or (len(requests) >= batch)
|
||||||
|
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
|
if batch_ids:
|
||||||
|
if len(batch_ids) > 0 and len(requests) == 0:
|
||||||
|
self.scheduler_logger.debug(
|
||||||
|
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
|
||||||
|
)
|
||||||
|
|
||||||
if len(requests) > 0:
|
if len(requests) > 0:
|
||||||
self.scheduler_logger.info(
|
self.scheduler_logger.info(
|
||||||
|
|||||||
@@ -53,9 +53,6 @@ class InternalAdapter:
|
|||||||
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
|
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
|
||||||
|
|
||||||
available_block_num = self.engine.resource_manager.available_block_num()
|
available_block_num = self.engine.resource_manager.available_block_num()
|
||||||
unhandled_request_num = self.engine.scheduler.get_unhandled_request_num()
|
|
||||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
|
||||||
unhandled_request_num = max(unhandled_request_num, len(self.engine.resource_manager.waiting))
|
|
||||||
server_info = {
|
server_info = {
|
||||||
"splitwise_role": self.cfg.scheduler_config.splitwise_role,
|
"splitwise_role": self.cfg.scheduler_config.splitwise_role,
|
||||||
"block_size": int(self.cfg.cache_config.block_size),
|
"block_size": int(self.cfg.cache_config.block_size),
|
||||||
@@ -65,7 +62,7 @@ class InternalAdapter:
|
|||||||
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
|
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
|
||||||
"max_batch_size": int(available_batch_size),
|
"max_batch_size": int(available_batch_size),
|
||||||
"max_input_token_num": self.cfg.model_config.max_model_len,
|
"max_input_token_num": self.cfg.model_config.max_model_len,
|
||||||
"unhandled_request_num": unhandled_request_num,
|
"unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
|
||||||
"available_batch": int(self.engine.resource_manager.available_batch()),
|
"available_batch": int(self.engine.resource_manager.available_batch()),
|
||||||
}
|
}
|
||||||
return server_info
|
return server_info
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import paddle
|
|||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddleformers.utils.log import logger
|
from paddleformers.utils.log import logger
|
||||||
|
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig
|
||||||
from fastdeploy.engine.pooling_params import PoolingParams
|
from fastdeploy.engine.pooling_params import PoolingParams
|
||||||
from fastdeploy.engine.request import ImagePosition, Request, RequestType
|
from fastdeploy.engine.request import ImagePosition, Request, RequestType
|
||||||
from fastdeploy.model_executor.graph_optimization.utils import (
|
from fastdeploy.model_executor.graph_optimization.utils import (
|
||||||
@@ -2110,6 +2110,12 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
self._cached_sampler_output = sampler_output
|
self._cached_sampler_output = sampler_output
|
||||||
self._cached_post_process_event = post_process_event
|
self._cached_post_process_event = post_process_event
|
||||||
else:
|
else:
|
||||||
|
if (
|
||||||
|
self.fd_config.speculative_config.method == SpecMethod.MTP
|
||||||
|
and hasattr(self.proposer.model, "empty_input_forward")
|
||||||
|
and self.parallel_config.use_ep
|
||||||
|
):
|
||||||
|
self._execute_empty_mtp_input(self.forward_meta)
|
||||||
self._cached_model_output_data = None
|
self._cached_model_output_data = None
|
||||||
self._cached_sampler_output = None
|
self._cached_sampler_output = None
|
||||||
self._cached_post_process_event = None
|
self._cached_post_process_event = None
|
||||||
@@ -2403,6 +2409,16 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
# 5.1. Async cpy
|
# 5.1. Async cpy
|
||||||
post_process_event = paddle.device.cuda.create_event()
|
post_process_event = paddle.device.cuda.create_event()
|
||||||
|
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||||
|
# If one query is preempted, there is no sampled token for it, we use token_id PREEMPTED_TOKEN_ID to signal server, abort is finished.
|
||||||
|
paddle.assign(
|
||||||
|
paddle.where(
|
||||||
|
self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1,
|
||||||
|
PREEMPTED_TOKEN_ID,
|
||||||
|
sampler_output.sampled_token_ids,
|
||||||
|
),
|
||||||
|
sampler_output.sampled_token_ids,
|
||||||
|
)
|
||||||
# if not self.speculative_decoding:
|
# if not self.speculative_decoding:
|
||||||
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
|
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
|
||||||
if self.speculative_decoding:
|
if self.speculative_decoding:
|
||||||
@@ -2676,13 +2692,13 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
"""Dynamic model loader use to clear parameters use for RL"""
|
"""Dynamic model loader use to clear parameters use for RL"""
|
||||||
# Clear CUDAGraph
|
# Clear CUDAGraph
|
||||||
if self.use_cudagraph:
|
if self.use_cudagraph:
|
||||||
self.model.clear_grpah_opt_backend()
|
self.model.clear_graph_opt_backend()
|
||||||
# Clear parameters and Send single
|
# Clear parameters and Send single
|
||||||
self.dynamic_weight_manager.clear_parameters(
|
self.dynamic_weight_manager.clear_parameters(
|
||||||
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
||||||
)
|
)
|
||||||
if self.spec_method == SpecMethod.MTP:
|
if self.spec_method == SpecMethod.MTP:
|
||||||
self.proposer.model.clear_grpah_opt_backend()
|
self.proposer.model.clear_graph_opt_backend()
|
||||||
self.proposer.clear_mtp_cache()
|
self.proposer.clear_mtp_cache()
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
paddle.device.cuda.empty_cache()
|
paddle.device.cuda.empty_cache()
|
||||||
@@ -2736,7 +2752,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
logger.info("GPU model runner's weight is already sleeping, no need to sleep again!")
|
logger.info("GPU model runner's weight is already sleeping, no need to sleep again!")
|
||||||
return
|
return
|
||||||
if self.use_cudagraph:
|
if self.use_cudagraph:
|
||||||
self.model.clear_grpah_opt_backend()
|
self.model.clear_graph_opt_backend()
|
||||||
if self.fd_config.parallel_config.enable_expert_parallel:
|
if self.fd_config.parallel_config.enable_expert_parallel:
|
||||||
self.dynamic_weight_manager.clear_deepep_buffer()
|
self.dynamic_weight_manager.clear_deepep_buffer()
|
||||||
self.dynamic_weight_manager.clear_model_weight()
|
self.dynamic_weight_manager.clear_model_weight()
|
||||||
|
|||||||
@@ -2511,7 +2511,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
|||||||
"""Dynamic model loader use to clear parameters use for RL"""
|
"""Dynamic model loader use to clear parameters use for RL"""
|
||||||
# Clear CUDAGraph
|
# Clear CUDAGraph
|
||||||
if self.use_cudagraph:
|
if self.use_cudagraph:
|
||||||
self.model.clear_grpah_opt_backend()
|
self.model.clear_graph_opt_backend()
|
||||||
# Clear parameters and Send single
|
# Clear parameters and Send single
|
||||||
self.dynamic_weight_manager.clear_parameters(
|
self.dynamic_weight_manager.clear_parameters(
|
||||||
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
|
||||||
|
|||||||
@@ -457,6 +457,9 @@ class PaddleDisWorkerProc:
|
|||||||
# TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
|
# TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one
|
||||||
self.model_weights_signal = np.zeros([1], dtype=np.int32)
|
self.model_weights_signal = np.zeros([1], dtype=np.int32)
|
||||||
while True:
|
while True:
|
||||||
|
# run eplb
|
||||||
|
self._run_eplb(tp_rank)
|
||||||
|
|
||||||
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
|
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
|
||||||
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
|
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
|
||||||
if self.ranks > 1:
|
if self.ranks > 1:
|
||||||
@@ -534,7 +537,7 @@ class PaddleDisWorkerProc:
|
|||||||
|
|
||||||
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
|
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
|
||||||
logger.debug(f"Rank: {self.local_rank} Detected new requests.")
|
logger.debug(f"Rank: {self.local_rank} Detected new requests.")
|
||||||
self.engine_forward_signal.value[0] = 1
|
|
||||||
tasks, read_finish = self.task_queue.get_tasks()
|
tasks, read_finish = self.task_queue.get_tasks()
|
||||||
# Only one of all tp_size client will get read_finish == True.
|
# Only one of all tp_size client will get read_finish == True.
|
||||||
if read_finish:
|
if read_finish:
|
||||||
@@ -543,39 +546,25 @@ class PaddleDisWorkerProc:
|
|||||||
self.task_queue.read_finish_flag.set(0)
|
self.task_queue.read_finish_flag.set(0)
|
||||||
else:
|
else:
|
||||||
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
|
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
|
||||||
# In EP parallel(corresponing to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival.
|
|
||||||
# Only EP + DP prefill should barrier for data arrival.
|
|
||||||
# In mixed mode and decoder in D, we should not barrier to influence decoding.
|
|
||||||
if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill":
|
|
||||||
paddle.distributed.barrier(self.parallel_config.ep_group)
|
|
||||||
|
|
||||||
req_dicts, control_reqs = [], []
|
req_dicts, control_reqs = [], []
|
||||||
assert (
|
for req_dict, bsz in tasks:
|
||||||
len(tasks) > 0
|
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
|
||||||
), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}"
|
control_reqs.append(req_dict[0])
|
||||||
# In EP + DP prefill, empty task ([]) is delived in worker to barrier. For empty task, just skip and continue.
|
else:
|
||||||
# tasks[0] contains two part, ([req1, ...] ,real_bsz)
|
max_occupied_batch_index = int(bsz)
|
||||||
# tasks[0][0] is [req1, ...]
|
req_dicts.extend(req_dict)
|
||||||
# if empty batch is delived, eval(tasks[0][0]) should be False ([]),
|
|
||||||
# if batch with requests is delived, eval(tasks[0][0]) should be True, then to be processed as below.
|
|
||||||
if tasks[0][0]:
|
|
||||||
for req_dict, bsz in tasks:
|
|
||||||
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
|
|
||||||
control_reqs.append(req_dict[0])
|
|
||||||
else:
|
|
||||||
max_occupied_batch_index = int(bsz)
|
|
||||||
req_dicts.extend(req_dict)
|
|
||||||
|
|
||||||
# todo: run control request async
|
# todo: run control request async
|
||||||
if len(control_reqs) > 0:
|
if len(control_reqs) > 0:
|
||||||
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
|
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
|
||||||
for control_req in control_reqs:
|
for control_req in control_reqs:
|
||||||
if self.parallel_config.use_ep:
|
if self.parallel_config.use_ep:
|
||||||
self.cached_control_reqs.append(control_req)
|
self.cached_control_reqs.append(control_req)
|
||||||
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
|
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
|
||||||
else:
|
else:
|
||||||
self.run_control_method(control_req)
|
self.run_control_method(control_req)
|
||||||
self._tp_barrier_wait() if tp_size > 1 else None
|
self._tp_barrier_wait() if tp_size > 1 else None
|
||||||
|
|
||||||
if len(req_dicts) > 0:
|
if len(req_dicts) > 0:
|
||||||
# Count prefill requests in current batch
|
# Count prefill requests in current batch
|
||||||
@@ -591,12 +580,6 @@ class PaddleDisWorkerProc:
|
|||||||
|
|
||||||
# Process prefill inputs
|
# Process prefill inputs
|
||||||
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
|
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
|
||||||
else:
|
|
||||||
if self.scheduler_config.splitwise_role == "prefill":
|
|
||||||
if tp_size > 1:
|
|
||||||
# Synchronize the signal for other workers
|
|
||||||
self._tp_barrier_wait()
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Let the ep group run control method synchronically
|
# Let the ep group run control method synchronically
|
||||||
if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep:
|
if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep:
|
||||||
@@ -611,7 +594,6 @@ class PaddleDisWorkerProc:
|
|||||||
and not self.worker.model_runner.not_need_stop()
|
and not self.worker.model_runner.not_need_stop()
|
||||||
):
|
):
|
||||||
self._tp_barrier_wait() if tp_size > 1 else None
|
self._tp_barrier_wait() if tp_size > 1 else None
|
||||||
self.engine_forward_signal.value[0] = 0
|
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -634,9 +616,6 @@ class PaddleDisWorkerProc:
|
|||||||
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||||
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
|
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
|
||||||
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
|
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")
|
||||||
# run eplb
|
|
||||||
self._run_eplb(tp_rank)
|
|
||||||
self.engine_forward_signal.value[0] = 0
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not self.parallel_config.use_ep
|
not self.parallel_config.use_ep
|
||||||
|
|||||||
+1
-1
@@ -47,5 +47,5 @@ aistudio_sdk
|
|||||||
p2pstore
|
p2pstore
|
||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
flashinfer-python-paddle
|
flashinfer-python-paddle
|
||||||
flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl
|
flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl
|
||||||
transformers>=4.55.1,<5.0.0
|
transformers>=4.55.1,<5.0.0
|
||||||
|
|||||||
@@ -214,29 +214,28 @@ def test_metrics_with_clear_and_reset():
|
|||||||
"""
|
"""
|
||||||
Test the metrics monitoring endpoint.
|
Test the metrics monitoring endpoint.
|
||||||
"""
|
"""
|
||||||
pass # not stable, uncomment after bug fix
|
metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
||||||
# metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
|
|
||||||
|
|
||||||
# async_concurrency(n=10)
|
async_concurrency(n=10)
|
||||||
|
|
||||||
# time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
|
|
||||||
# ===== clear_load_weight =====
|
# ===== clear_load_weight =====
|
||||||
# clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight"
|
clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight"
|
||||||
# print("Calling clear_load_weight...")
|
print("Calling clear_load_weight...")
|
||||||
# r = requests.get(clear_url, timeout=30)
|
r = requests.get(clear_url, timeout=30)
|
||||||
# assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}"
|
assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}"
|
||||||
|
|
||||||
# metrics = get_metrics_dict(metrics_url)
|
metrics = get_metrics_dict(metrics_url)
|
||||||
# running = metrics["fastdeploy:num_requests_running"]
|
running = metrics["fastdeploy:num_requests_running"]
|
||||||
# waiting = metrics["fastdeploy:num_requests_waiting"]
|
waiting = metrics["fastdeploy:num_requests_waiting"]
|
||||||
|
|
||||||
# print(
|
print(
|
||||||
# "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
|
"ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
|
||||||
# running,
|
running,
|
||||||
# "waiting:",
|
"waiting:",
|
||||||
# waiting,
|
waiting,
|
||||||
# )
|
)
|
||||||
# assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight"
|
# assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,455 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Test splitwise deployment WITHOUT Router:
|
||||||
|
# use local_scheduler, manually construct disaggregate_info,
|
||||||
|
# send requests to both Prefill and Decode concurrently.
|
||||||
|
# ENABLE_V1_KVCACHE_SCHEDULER=1, use rdma to transfer cache.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from utils.serving_utils import (
|
||||||
|
FD_API_PORT,
|
||||||
|
FD_CACHE_QUEUE_PORT,
|
||||||
|
FD_ENGINE_QUEUE_PORT,
|
||||||
|
FD_METRICS_PORT,
|
||||||
|
check_service_health,
|
||||||
|
clean,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ports for PD disaggregation (no router port needed)
|
||||||
|
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
|
||||||
|
FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623))
|
||||||
|
|
||||||
|
# Prefill uses base ports, Decode uses base+1
|
||||||
|
PORTS_TO_CLEAN = [
|
||||||
|
FD_API_PORT,
|
||||||
|
FD_ENGINE_QUEUE_PORT,
|
||||||
|
FD_METRICS_PORT,
|
||||||
|
FD_CACHE_QUEUE_PORT,
|
||||||
|
FD_CONNECTOR_PORT,
|
||||||
|
FD_RDMA_PORT,
|
||||||
|
FD_API_PORT + 1,
|
||||||
|
FD_ENGINE_QUEUE_PORT + 1,
|
||||||
|
FD_METRICS_PORT + 1,
|
||||||
|
FD_CACHE_QUEUE_PORT + 1,
|
||||||
|
FD_CONNECTOR_PORT + 1,
|
||||||
|
FD_RDMA_PORT + 1,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_disaggregate_info() -> dict:
|
||||||
|
"""Build disaggregate_info manually, replicating Router's handle_splitwise_request logic."""
|
||||||
|
host_ip = os.getenv("FD_HOST_IP", "127.0.0.1")
|
||||||
|
return {
|
||||||
|
"prefill_ip": host_ip,
|
||||||
|
"decode_ip": host_ip,
|
||||||
|
"prefill_connector_port": FD_CONNECTOR_PORT,
|
||||||
|
"decode_connector_port": FD_CONNECTOR_PORT + 1,
|
||||||
|
"decode_device_ids": ["1"],
|
||||||
|
"decode_rdma_ports": [FD_RDMA_PORT + 1],
|
||||||
|
"transfer_protocol": "rdma",
|
||||||
|
"decode_tp_size": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _send_pd_request(payload: dict, timeout: int = 120):
|
||||||
|
"""
|
||||||
|
Send request to both Prefill and Decode concurrently,
|
||||||
|
replicate Router's fan-out forwarding behavior.
|
||||||
|
Returns the Decode response (same as Router's return_result_url_index=-1).
|
||||||
|
"""
|
||||||
|
disaggregate_info = _build_disaggregate_info()
|
||||||
|
|
||||||
|
# Inject disaggregate_info and request_id (same as Router)
|
||||||
|
payload = payload.copy()
|
||||||
|
payload["disaggregate_info"] = disaggregate_info
|
||||||
|
if "request_id" not in payload:
|
||||||
|
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
|
||||||
|
|
||||||
|
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions"
|
||||||
|
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions"
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
# For streaming, use requests with stream=True for decode response
|
||||||
|
if payload.get("stream", False):
|
||||||
|
# Send to both concurrently (same as Router's fan-out), stream from decode
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def _post_stream(url):
|
||||||
|
return requests.post(url, headers=headers, json=payload, timeout=timeout, stream=True)
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
prefill_future = executor.submit(_post_stream, prefill_url)
|
||||||
|
decode_future = executor.submit(_post_stream, decode_url)
|
||||||
|
# Return decode streaming response immediately
|
||||||
|
decode_resp = decode_future.result()
|
||||||
|
# Consume prefill response in background (don't block)
|
||||||
|
try:
|
||||||
|
prefill_future.result(timeout=timeout)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return decode_resp
|
||||||
|
else:
|
||||||
|
# Non-streaming: send to both, return decode response
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def _post(url):
|
||||||
|
return requests.post(url, headers=headers, json=payload, timeout=timeout)
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
prefill_future = executor.submit(_post, prefill_url)
|
||||||
|
decode_future = executor.submit(_post, decode_url)
|
||||||
|
# Wait for both, return decode response
|
||||||
|
decode_resp = decode_future.result()
|
||||||
|
# Also check prefill didn't error (but don't block on it)
|
||||||
|
try:
|
||||||
|
prefill_future.result(timeout=5)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return decode_resp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def setup_and_run_server():
|
||||||
|
"""
|
||||||
|
Pytest fixture that runs once per test session:
|
||||||
|
- Cleans ports before tests
|
||||||
|
- Starts Prefill and Decode instances WITHOUT Router
|
||||||
|
- Waits for both to be healthy
|
||||||
|
- Tears down after all tests finish
|
||||||
|
"""
|
||||||
|
print("Pre-test port cleanup...")
|
||||||
|
clean(PORTS_TO_CLEAN)
|
||||||
|
|
||||||
|
print("log dir clean")
|
||||||
|
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
|
||||||
|
shutil.rmtree("log_prefill")
|
||||||
|
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
|
||||||
|
shutil.rmtree("log_decode")
|
||||||
|
|
||||||
|
base_path = os.getenv("MODEL_PATH")
|
||||||
|
if base_path:
|
||||||
|
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
|
||||||
|
else:
|
||||||
|
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
|
||||||
|
print(f"model_path: {model_path}")
|
||||||
|
|
||||||
|
base_log_dir = os.getenv("FD_LOG_DIR", "log")
|
||||||
|
|
||||||
|
# Prefill instance
|
||||||
|
print("start prefill...")
|
||||||
|
env_prefill = os.environ.copy()
|
||||||
|
env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
|
env_prefill["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_prefill")
|
||||||
|
|
||||||
|
prefill_log_path = "prefill.log"
|
||||||
|
prefill_cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"fastdeploy.entrypoints.openai.api_server",
|
||||||
|
"--model",
|
||||||
|
model_path,
|
||||||
|
"--port",
|
||||||
|
str(FD_API_PORT),
|
||||||
|
"--engine-worker-queue-port",
|
||||||
|
str(FD_ENGINE_QUEUE_PORT),
|
||||||
|
"--metrics-port",
|
||||||
|
str(FD_METRICS_PORT),
|
||||||
|
"--cache-queue-port",
|
||||||
|
str(FD_CACHE_QUEUE_PORT),
|
||||||
|
"--max-model-len",
|
||||||
|
"8192",
|
||||||
|
"--splitwise-role",
|
||||||
|
"prefill",
|
||||||
|
"--cache-transfer-protocol",
|
||||||
|
"rdma",
|
||||||
|
"--rdma-comm-ports",
|
||||||
|
str(FD_RDMA_PORT),
|
||||||
|
"--pd-comm-port",
|
||||||
|
str(FD_CONNECTOR_PORT),
|
||||||
|
# No --router flag
|
||||||
|
]
|
||||||
|
|
||||||
|
with open(prefill_log_path, "w") as logfile:
|
||||||
|
process_prefill = subprocess.Popen(
|
||||||
|
prefill_cmd,
|
||||||
|
stdout=logfile,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
start_new_session=True,
|
||||||
|
env=env_prefill,
|
||||||
|
)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Decode instance
|
||||||
|
print("start decode...")
|
||||||
|
env_decode = os.environ.copy()
|
||||||
|
env_decode["CUDA_VISIBLE_DEVICES"] = "1"
|
||||||
|
env_decode["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_decode")
|
||||||
|
|
||||||
|
decode_log_path = "decode.log"
|
||||||
|
decode_cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"fastdeploy.entrypoints.openai.api_server",
|
||||||
|
"--model",
|
||||||
|
model_path,
|
||||||
|
"--port",
|
||||||
|
str(FD_API_PORT + 1),
|
||||||
|
"--engine-worker-queue-port",
|
||||||
|
str(FD_ENGINE_QUEUE_PORT + 1),
|
||||||
|
"--metrics-port",
|
||||||
|
str(FD_METRICS_PORT + 1),
|
||||||
|
"--cache-queue-port",
|
||||||
|
str(FD_CACHE_QUEUE_PORT + 1),
|
||||||
|
"--max-model-len",
|
||||||
|
"8192",
|
||||||
|
"--splitwise-role",
|
||||||
|
"decode",
|
||||||
|
"--cache-transfer-protocol",
|
||||||
|
"rdma",
|
||||||
|
"--rdma-comm-ports",
|
||||||
|
str(FD_RDMA_PORT + 1),
|
||||||
|
"--pd-comm-port",
|
||||||
|
str(FD_CONNECTOR_PORT + 1),
|
||||||
|
# No --router flag
|
||||||
|
]
|
||||||
|
|
||||||
|
with open(decode_log_path, "w") as logfile:
|
||||||
|
process_decode = subprocess.Popen(
|
||||||
|
decode_cmd,
|
||||||
|
stdout=logfile,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
start_new_session=True,
|
||||||
|
env=env_decode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait up to 300 seconds for both instances to be healthy
|
||||||
|
for _ in range(60):
|
||||||
|
prefill_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT}")
|
||||||
|
decode_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT + 1}")
|
||||||
|
if prefill_healthy and decode_healthy:
|
||||||
|
print("Prefill and decode servers are both online")
|
||||||
|
break
|
||||||
|
time.sleep(5)
|
||||||
|
else:
|
||||||
|
print("[TIMEOUT] Servers failed to start in 5 minutes. Cleaning up...")
|
||||||
|
try:
|
||||||
|
os.killpg(process_prefill.pid, signal.SIGTERM)
|
||||||
|
os.killpg(process_decode.pid, signal.SIGTERM)
|
||||||
|
clean(PORTS_TO_CLEAN)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to kill process group: {e}")
|
||||||
|
raise RuntimeError("Prefill or decode server did not start")
|
||||||
|
|
||||||
|
yield # Run tests
|
||||||
|
|
||||||
|
print("\n===== Post-test server cleanup... =====")
|
||||||
|
try:
|
||||||
|
os.killpg(process_prefill.pid, signal.SIGTERM)
|
||||||
|
os.killpg(process_decode.pid, signal.SIGTERM)
|
||||||
|
clean(PORTS_TO_CLEAN)
|
||||||
|
print(f"Prefill server (pid={process_prefill.pid}) terminated")
|
||||||
|
print(f"Decode server (pid={process_decode.pid}) terminated")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to terminate server: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def api_url(request):
|
||||||
|
"""
|
||||||
|
Returns the Decode API endpoint URL (where final responses come from).
|
||||||
|
"""
|
||||||
|
return f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def headers():
|
||||||
|
return {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_chunks(response):
|
||||||
|
"""Parse streaming response into chunk list."""
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
for line in response.iter_lines(decode_unicode=True):
|
||||||
|
if line:
|
||||||
|
if line.startswith("data: "):
|
||||||
|
line = line[len("data: ") :]
|
||||||
|
|
||||||
|
if line.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = json.loads(line)
|
||||||
|
chunks.append(chunk)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Parse failed: {e}, line: {line}")
|
||||||
|
else:
|
||||||
|
print(f"Request failed, status: {response.status_code}")
|
||||||
|
print("Response:", response.text)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_usage_stream(api_url):
|
||||||
|
"""Test streaming chat with usage"""
|
||||||
|
payload = {
|
||||||
|
"model": "default",
|
||||||
|
"temperature": 0,
|
||||||
|
"top_p": 0,
|
||||||
|
"seed": 33,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
|
||||||
|
],
|
||||||
|
"max_tokens": 50,
|
||||||
|
"stream": True,
|
||||||
|
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
|
||||||
|
"metadata": {"min_tokens": 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = _send_pd_request(payload)
|
||||||
|
chunks = get_stream_chunks(response)
|
||||||
|
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
|
||||||
|
print("Decode Response:", result)
|
||||||
|
assert result != "", "结果为空"
|
||||||
|
usage = chunks[-1]["usage"]
|
||||||
|
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||||
|
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||||
|
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||||
|
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_usage_non_stream(api_url):
|
||||||
|
"""Test non-streaming chat with usage"""
|
||||||
|
payload = {
|
||||||
|
"model": "default",
|
||||||
|
"temperature": 0,
|
||||||
|
"top_p": 0,
|
||||||
|
"seed": 33,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
|
||||||
|
],
|
||||||
|
"max_tokens": 50,
|
||||||
|
"stream": False,
|
||||||
|
"metadata": {"min_tokens": 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = _send_pd_request(payload).json()
|
||||||
|
usage = response["usage"]
|
||||||
|
result = response["choices"][0]["message"]["content"]
|
||||||
|
assert result != "", "结果为空"
|
||||||
|
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||||
|
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||||
|
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||||
|
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_chat_usage_stream(api_url):
|
||||||
|
"""Test streaming completion (non-chat) with usage"""
|
||||||
|
payload = {
|
||||||
|
"model": "default",
|
||||||
|
"temperature": 0,
|
||||||
|
"top_p": 0,
|
||||||
|
"seed": 33,
|
||||||
|
"prompt": "牛顿的三大运动定律是什么?",
|
||||||
|
"max_tokens": 50,
|
||||||
|
"stream": True,
|
||||||
|
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
|
||||||
|
"metadata": {"min_tokens": 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send to /v1/completions endpoints
|
||||||
|
disaggregate_info = _build_disaggregate_info()
|
||||||
|
payload = payload.copy()
|
||||||
|
payload["disaggregate_info"] = disaggregate_info
|
||||||
|
if "request_id" not in payload:
|
||||||
|
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
|
||||||
|
|
||||||
|
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions"
|
||||||
|
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions"
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120)
|
||||||
|
decode_future = executor.submit(
|
||||||
|
requests.post, decode_url, json=payload, headers=headers, timeout=120, stream=True
|
||||||
|
)
|
||||||
|
response = decode_future.result()
|
||||||
|
|
||||||
|
chunks = get_stream_chunks(response)
|
||||||
|
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
|
||||||
|
print("Decode Response:", result)
|
||||||
|
assert result != "", "结果为空"
|
||||||
|
usage = chunks[-1]["usage"]
|
||||||
|
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||||
|
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||||
|
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||||
|
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_chat_usage_non_stream(api_url):
|
||||||
|
"""Test non-streaming completion (non-chat) with usage"""
|
||||||
|
payload = {
|
||||||
|
"model": "default",
|
||||||
|
"temperature": 0,
|
||||||
|
"top_p": 0,
|
||||||
|
"seed": 33,
|
||||||
|
"prompt": "牛顿的三大运动定律是什么?",
|
||||||
|
"max_tokens": 50,
|
||||||
|
"stream": False,
|
||||||
|
"metadata": {"min_tokens": 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send to /v1/completions endpoints
|
||||||
|
disaggregate_info = _build_disaggregate_info()
|
||||||
|
payload = payload.copy()
|
||||||
|
payload["disaggregate_info"] = disaggregate_info
|
||||||
|
if "request_id" not in payload:
|
||||||
|
payload["request_id"] = f"test-pd-{uuid.uuid4()}"
|
||||||
|
|
||||||
|
prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions"
|
||||||
|
decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions"
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120)
|
||||||
|
decode_future = executor.submit(requests.post, decode_url, json=payload, headers=headers, timeout=120)
|
||||||
|
response = decode_future.result().json()
|
||||||
|
|
||||||
|
usage = response["usage"]
|
||||||
|
result = response["choices"][0]["text"]
|
||||||
|
print("Decode Response:", result)
|
||||||
|
assert result != "", "结果为空"
|
||||||
|
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
|
||||||
|
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
|
||||||
|
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
|
||||||
|
assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||||
@@ -1457,9 +1457,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
|||||||
task.metrics.scheduler_recv_req_time = time.time()
|
task.metrics.scheduler_recv_req_time = time.time()
|
||||||
|
|
||||||
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
||||||
eng.engine_worker_queue = Mock(
|
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
|
||||||
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
|
|
||||||
)
|
|
||||||
eng._send_error_response = Mock()
|
eng._send_error_response = Mock()
|
||||||
|
|
||||||
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")]))
|
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")]))
|
||||||
@@ -1493,9 +1491,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
|||||||
task.metrics.scheduler_recv_req_time = time.time()
|
task.metrics.scheduler_recv_req_time = time.time()
|
||||||
|
|
||||||
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
||||||
eng.engine_worker_queue = Mock(
|
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
|
||||||
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], []))
|
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], []))
|
||||||
|
|
||||||
@@ -1526,9 +1522,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
|||||||
task.metrics.scheduler_recv_req_time = time.time()
|
task.metrics.scheduler_recv_req_time = time.time()
|
||||||
|
|
||||||
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock())
|
||||||
eng.engine_worker_queue = Mock(
|
eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock())
|
||||||
exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0)
|
|
||||||
)
|
|
||||||
eng._send_error_response = Mock()
|
eng._send_error_response = Mock()
|
||||||
|
|
||||||
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)]))
|
eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)]))
|
||||||
|
|||||||
@@ -60,6 +60,50 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
|
|
||||||
return ErnieX1ToolParser(tokenizer=DummyTokenizer())
|
return ErnieX1ToolParser(tokenizer=DummyTokenizer())
|
||||||
|
|
||||||
|
def _simulate_streaming(self, parser, deltas):
|
||||||
|
"""Simulate a multi-step streaming flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: ErnieX1ToolParser instance
|
||||||
|
deltas: list of delta text strings, each representing one streaming step
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of results from each extract_tool_calls_streaming call
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
previous_text = ""
|
||||||
|
token_id = 0
|
||||||
|
previous_token_ids = []
|
||||||
|
|
||||||
|
for delta in deltas:
|
||||||
|
current_text = previous_text + delta
|
||||||
|
# When delta contains <tool_call> plus more content, use 2 tokens
|
||||||
|
# so that the parser extracts tool_call_portion (line 163-164)
|
||||||
|
if "<tool_call>" in delta and delta != "<tool_call>":
|
||||||
|
n_tokens = 2
|
||||||
|
else:
|
||||||
|
n_tokens = 1
|
||||||
|
|
||||||
|
delta_token_ids = list(range(token_id + 1, token_id + 1 + n_tokens))
|
||||||
|
token_id += n_tokens
|
||||||
|
current_token_ids = previous_token_ids + delta_token_ids
|
||||||
|
|
||||||
|
result = parser.extract_tool_calls_streaming(
|
||||||
|
previous_text,
|
||||||
|
current_text,
|
||||||
|
delta,
|
||||||
|
previous_token_ids,
|
||||||
|
current_token_ids,
|
||||||
|
delta_token_ids,
|
||||||
|
self.dummy_request,
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
previous_text = current_text
|
||||||
|
previous_token_ids = list(current_token_ids)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
# ==================== __init__ tests (lines 60-81) ====================
|
# ==================== __init__ tests (lines 60-81) ====================
|
||||||
|
|
||||||
def test_init_sets_tokens_and_ids(self):
|
def test_init_sets_tokens_and_ids(self):
|
||||||
@@ -116,6 +160,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
self.assertTrue(result.tools_called)
|
self.assertTrue(result.tools_called)
|
||||||
self.assertEqual(result.tool_calls[0].function.arguments, "{}")
|
self.assertEqual(result.tool_calls[0].function.arguments, "{}")
|
||||||
|
|
||||||
|
def test_extract_tool_calls_empty_arguments(self):
|
||||||
|
"""Cover: tool call with explicit empty arguments {}"""
|
||||||
|
output = '<tool_call>{"name": "fn", "arguments": {}}</tool_call>'
|
||||||
|
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||||
|
self.assertTrue(result.tools_called)
|
||||||
|
self.assertEqual(result.tool_calls[0].function.name, "fn")
|
||||||
|
self.assertEqual(result.tool_calls[0].function.arguments, "{}")
|
||||||
|
|
||||||
def test_extract_tool_calls_nested_arguments(self):
|
def test_extract_tool_calls_nested_arguments(self):
|
||||||
"""Cover regex with nested braces in arguments"""
|
"""Cover regex with nested braces in arguments"""
|
||||||
output = '<tool_call>{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}</tool_call>'
|
output = '<tool_call>{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}</tool_call>'
|
||||||
@@ -182,38 +234,24 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
def test_streaming_end_token_in_delta(self):
|
def test_streaming_end_token_in_delta(self):
|
||||||
"""Cover lines 149-156: </tool_call> appears in delta"""
|
"""Cover lines 149-156: </tool_call> appears in delta"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
# First, start a tool call
|
results = self._simulate_streaming(
|
||||||
parser.extract_tool_calls_streaming(
|
parser,
|
||||||
"",
|
[
|
||||||
'<tool_call>{"name": "fn"',
|
'<tool_call>{"name": "fn", "arguments": {"k": "', # start + name + args key
|
||||||
'<tool_call>{"name": "fn"',
|
"v", # args value
|
||||||
[],
|
'"}}</tool_call>', # close with end token in delta
|
||||||
[1, 10],
|
],
|
||||||
[1, 10],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
# Now stream arguments
|
# Step 1: name sent
|
||||||
parser.extract_tool_calls_streaming(
|
self.assertIsNotNone(results[0])
|
||||||
'<tool_call>{"name": "fn"',
|
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
# Step 2: first-args branch, regex extracts '{"k": "v' as arguments_delta
|
||||||
', "arguments": {"k": "v',
|
self.assertIsNotNone(results[1])
|
||||||
[1, 10],
|
self.assertEqual(results[1].tool_calls[0].function.arguments, '{"k": "v')
|
||||||
[1, 10, 20],
|
# Step 3: end token in delta triggers close handling
|
||||||
[20],
|
# delta before </tool_call> is '"}}', close branch: rindex('}')=2, diff='"}'
|
||||||
self.dummy_request,
|
self.assertIsNotNone(results[2])
|
||||||
)
|
self.assertEqual(results[2].tool_calls[0].function.arguments, '"}')
|
||||||
# Close with end token in delta
|
|
||||||
result = parser.extract_tool_calls_streaming(
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}</tool_call>',
|
|
||||||
'"}}</tool_call>',
|
|
||||||
[1, 10, 20],
|
|
||||||
[1, 10, 20, 2],
|
|
||||||
[2],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
# Should handle end token
|
|
||||||
self.assertTrue(result is None or isinstance(result, DeltaMessage))
|
|
||||||
|
|
||||||
# --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) ---
|
# --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) ---
|
||||||
|
|
||||||
@@ -255,37 +293,29 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
def test_streaming_continue_tool_call_no_name_yet(self):
|
def test_streaming_continue_tool_call_no_name_yet(self):
|
||||||
"""Cover lines 174-176, 220-222: partial JSON without name yet"""
|
"""Cover lines 174-176, 220-222: partial JSON without name yet"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
# Start tool call
|
results = self._simulate_streaming(
|
||||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
parser,
|
||||||
# Continue with partial content, no name parseable yet
|
[
|
||||||
result = parser.extract_tool_calls_streaming(
|
"<tool_call>", # start tool call
|
||||||
"<tool_call>",
|
'{"na', # partial content, no name yet
|
||||||
'<tool_call>{"na',
|
],
|
||||||
'{"na',
|
|
||||||
[1],
|
|
||||||
[1, 10],
|
|
||||||
[10],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(results[0])
|
||||||
|
self.assertIsNone(results[1])
|
||||||
|
|
||||||
def test_streaming_continue_tool_call_with_name(self):
|
def test_streaming_continue_tool_call_with_name(self):
|
||||||
"""Cover lines 174-176, 223-235: name becomes available"""
|
"""Cover lines 174-176, 223-235: name becomes available"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
# Start tool call
|
results = self._simulate_streaming(
|
||||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
parser,
|
||||||
# Name appears
|
[
|
||||||
result = parser.extract_tool_calls_streaming(
|
"<tool_call>", # start tool call
|
||||||
"<tool_call>",
|
'{"name": "get_weather"', # name appears
|
||||||
'<tool_call>{"name": "get_weather"',
|
],
|
||||||
'{"name": "get_weather"',
|
|
||||||
[1],
|
|
||||||
[1, 10],
|
|
||||||
[10],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNone(results[0])
|
||||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
self.assertIsNotNone(results[1])
|
||||||
|
self.assertEqual(results[1].tool_calls[0].function.name, "get_weather")
|
||||||
self.assertTrue(parser.current_tool_name_sent)
|
self.assertTrue(parser.current_tool_name_sent)
|
||||||
|
|
||||||
# --- Lines 236-237: name not sent and function_name is None ---
|
# --- Lines 236-237: name not sent and function_name is None ---
|
||||||
@@ -293,18 +323,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
def test_streaming_no_function_name(self):
|
def test_streaming_no_function_name(self):
|
||||||
"""Cover lines 236-237: parsed JSON has no 'name' field"""
|
"""Cover lines 236-237: parsed JSON has no 'name' field"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
results = self._simulate_streaming(
|
||||||
# Send JSON without name field
|
parser,
|
||||||
result = parser.extract_tool_calls_streaming(
|
[
|
||||||
"<tool_call>",
|
"<tool_call>", # start tool call
|
||||||
'<tool_call>{"arguments": {"k": "v"}}',
|
'{"arguments": {"k": "v"}}', # JSON without name field
|
||||||
'{"arguments": {"k": "v"}}',
|
],
|
||||||
[1],
|
|
||||||
[1, 10],
|
|
||||||
[10],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(results[1])
|
||||||
|
|
||||||
# --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) ---
|
# --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) ---
|
||||||
|
|
||||||
@@ -333,9 +359,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
parser.streamed_args_for_tool = [""]
|
parser.streamed_args_for_tool = [""]
|
||||||
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
||||||
result = parser.extract_tool_calls_streaming(
|
result = parser.extract_tool_calls_streaming(
|
||||||
'<tool_call>{"name":"fn","arguments":{"k":"v"}}',
|
'<tool_call>{"name":"fn","arguments":{"k":"v"',
|
||||||
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
|
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
|
||||||
'"}}</tool_call>',
|
"}}</tool_call>",
|
||||||
[1, 10],
|
[1, 10],
|
||||||
[1, 10, 2],
|
[1, 10, 2],
|
||||||
[2],
|
[2],
|
||||||
@@ -343,9 +369,14 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
self.assertIsNotNone(result.tool_calls)
|
self.assertIsNotNone(result.tool_calls)
|
||||||
|
self.assertEqual(result.tool_calls[0].function.arguments, "}")
|
||||||
|
|
||||||
def test_streaming_close_with_diff_no_end_marker(self):
|
def test_streaming_text_after_completed_tool_call(self):
|
||||||
"""Cover lines 184-185: close with arguments but no '"}' in delta_text"""
|
"""Cover lines 143-147: text content after a completed tool call.
|
||||||
|
|
||||||
|
When start==end counts, prev_end==cur_end, and end_token not in delta,
|
||||||
|
the parser treats delta as regular text content.
|
||||||
|
"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.current_tool_id = 0
|
parser.current_tool_id = 0
|
||||||
parser.current_tool_name_sent = True
|
parser.current_tool_name_sent = True
|
||||||
@@ -353,7 +384,7 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
||||||
# Simulate end token in delta but without '"}' pattern
|
# Simulate end token in delta but without '"}' pattern
|
||||||
# We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta
|
# We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta
|
||||||
# so that we enter the elif at 178
|
# so that we enter the text-content branch at line 143-147
|
||||||
result = parser.extract_tool_calls_streaming(
|
result = parser.extract_tool_calls_streaming(
|
||||||
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
|
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call>',
|
||||||
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call> text',
|
'<tool_call>{"name":"fn","arguments":{"k":"v"}}</tool_call> text',
|
||||||
@@ -363,8 +394,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
[30],
|
[30],
|
||||||
self.dummy_request,
|
self.dummy_request,
|
||||||
)
|
)
|
||||||
# balanced counts, prev_end==cur_end, end not in delta -> returns content (line 147)
|
# balanced counts, prev_end==cur_end, end not in delta -> returns content (line 149)
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
self.assertIsNotNone(result)
|
||||||
|
self.assertEqual(result.content, " text")
|
||||||
|
|
||||||
def test_streaming_close_no_arguments(self):
|
def test_streaming_close_no_arguments(self):
|
||||||
"""Cover lines 182-183: close branch where prev arguments is None/empty"""
|
"""Cover lines 182-183: close branch where prev arguments is None/empty"""
|
||||||
@@ -382,8 +414,126 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
[2],
|
[2],
|
||||||
self.dummy_request,
|
self.dummy_request,
|
||||||
)
|
)
|
||||||
# diff is None (no arguments), so falls through to partial_json_parser
|
# diff is None (no arguments key in prev), falls through to partial_json_parser
|
||||||
self.assertTrue(result is None or isinstance(result, DeltaMessage))
|
# parses complete JSON, cur_args=None, prev_args=None -> no-args -> delta=None
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
def test_streaming_close_with_empty_dict_arguments(self):
|
||||||
|
"""Regression: close branch must handle arguments={} (empty dict).
|
||||||
|
|
||||||
|
Before fix, `if diff:` was False for empty dict {}, so the close
|
||||||
|
logic was skipped. After fix, `if diff is not None:` correctly
|
||||||
|
enters the branch.
|
||||||
|
"""
|
||||||
|
parser = self._new_parser()
|
||||||
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
|
[
|
||||||
|
'<tool_call>{"name": "fn", "arguments": ', # start + name + args key
|
||||||
|
"{}", # empty dict value
|
||||||
|
"}", # outer close brace
|
||||||
|
"</tool_call>", # end token
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Step 1: name sent
|
||||||
|
# Step 2: first-args, cur_args={} is not None, prev_args=None
|
||||||
|
# Without fix: not {} == True -> no-args branch -> returns None
|
||||||
|
# With fix: enters first-args -> streams "{}" -> DeltaMessage
|
||||||
|
self.assertIsNotNone(results[1])
|
||||||
|
self.assertIsNotNone(results[1].tool_calls)
|
||||||
|
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
|
||||||
|
|
||||||
|
def test_streaming_empty_arguments_with_outer_brace_in_same_token(self):
|
||||||
|
"""Regression: when arguments={} and outer } arrive in the same token '{}}',
|
||||||
|
regex (.*) over-captures the outer brace, producing '{}}'.
|
||||||
|
|
||||||
|
Real production data showed arguments='{}}}' for get_default_weather
|
||||||
|
with empty arguments. This test reproduces that exact scenario.
|
||||||
|
"""
|
||||||
|
parser = self._new_parser()
|
||||||
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
|
[
|
||||||
|
'<tool_call>{"name": "get_default_weather", "arguments": ', # start + name + args key
|
||||||
|
"{}}", # empty args + outer close brace in same token
|
||||||
|
"</tool_call>", # end token
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Step 1: name sent
|
||||||
|
self.assertIsNotNone(results[0])
|
||||||
|
self.assertEqual(results[0].tool_calls[0].function.name, "get_default_weather")
|
||||||
|
# Step 2: first-args branch, tool_call_portion is complete JSON
|
||||||
|
# regex (.*) captures '{}}' but fix strips outer '}' -> '{}'
|
||||||
|
self.assertIsNotNone(results[1])
|
||||||
|
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
|
||||||
|
# Step 3: end token, close branch
|
||||||
|
# diff = prev_arguments = {} (not None), delta_text = '' (empty after split)
|
||||||
|
# '}' not in '' -> returns None
|
||||||
|
self.assertIsNone(results[2])
|
||||||
|
|
||||||
|
def test_streaming_close_with_number_ending_arguments(self):
|
||||||
|
"""Regression: close branch must flush remaining args ending with number.
|
||||||
|
|
||||||
|
Before fix, '"}' not in delta was True for numbers, causing return None.
|
||||||
|
After fix, rindex('}') correctly finds the closing brace.
|
||||||
|
"""
|
||||||
|
parser = self._new_parser()
|
||||||
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
|
[
|
||||||
|
'<tool_call>{"name": "fn", "arguments": {"count": ', # start + name + args key
|
||||||
|
"123", # number value
|
||||||
|
"}}</tool_call>", # close braces + end token
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Step 1: name sent
|
||||||
|
# Step 2: first-args, streams {"count": 123
|
||||||
|
# Step 3: close branch flushes remaining "}"
|
||||||
|
streamed_args = [
|
||||||
|
r.tool_calls[0].function.arguments
|
||||||
|
for r in results
|
||||||
|
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
|
||||||
|
]
|
||||||
|
combined = "".join(streamed_args)
|
||||||
|
self.assertEqual(combined, '{"count": 123}')
|
||||||
|
|
||||||
|
def test_streaming_close_with_boolean_ending_arguments(self):
|
||||||
|
"""Regression: close branch must flush remaining args ending with boolean."""
|
||||||
|
parser = self._new_parser()
|
||||||
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
|
[
|
||||||
|
'<tool_call>{"name": "fn", "arguments": {"flag": ', # start + args key
|
||||||
|
"true", # boolean value
|
||||||
|
"}}</tool_call>", # close + end token
|
||||||
|
],
|
||||||
|
)
|
||||||
|
streamed_args = [
|
||||||
|
r.tool_calls[0].function.arguments
|
||||||
|
for r in results
|
||||||
|
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
|
||||||
|
]
|
||||||
|
combined = "".join(streamed_args)
|
||||||
|
self.assertEqual(combined, '{"flag": true}')
|
||||||
|
|
||||||
|
def test_streaming_close_with_nested_object_ending(self):
|
||||||
|
"""Regression: close branch must flush remaining args ending with nested '}'."""
|
||||||
|
parser = self._new_parser()
|
||||||
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
|
[
|
||||||
|
'<tool_call>{"name": "fn", "arguments": {"nested": {"a": ', # start + args key
|
||||||
|
"1", # nested value
|
||||||
|
"}}}</tool_call>", # close all + end token
|
||||||
|
],
|
||||||
|
)
|
||||||
|
streamed_args = [
|
||||||
|
r.tool_calls[0].function.arguments
|
||||||
|
for r in results
|
||||||
|
if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None
|
||||||
|
]
|
||||||
|
combined = "".join(streamed_args)
|
||||||
|
self.assertEqual(combined, '{"nested": {"a": 1}}')
|
||||||
|
|
||||||
# --- Lines 202-206: else branch (cur_start < cur_end, edge case) ---
|
# --- Lines 202-206: else branch (cur_start < cur_end, edge case) ---
|
||||||
|
|
||||||
@@ -404,23 +554,21 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
def test_streaming_malformed_json(self):
|
def test_streaming_malformed_json(self):
|
||||||
"""Cover lines 213-215: MalformedJSON from partial parser"""
|
"""Cover lines 213-215: MalformedJSON from partial parser"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
results = self._simulate_streaming(
|
||||||
# Feed badly formed content
|
parser,
|
||||||
result = parser.extract_tool_calls_streaming(
|
[
|
||||||
"<tool_call>",
|
"<tool_call>", # start tool call
|
||||||
"<tool_call>{{{",
|
"{{{", # badly formed content
|
||||||
"{{{",
|
],
|
||||||
[1],
|
|
||||||
[1, 10],
|
|
||||||
[10],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(results[1])
|
||||||
|
|
||||||
def test_streaming_json_decode_error(self):
|
def test_streaming_json_decode_error(self):
|
||||||
"""Cover lines 216-218: JSONDecodeError from partial parser"""
|
"""Cover lines 216-218: JSONDecodeError from partial parser"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.extract_tool_calls_streaming("", "<tool_call>", "<tool_call>", [], [1], [1], self.dummy_request)
|
# Step 1: start tool call normally
|
||||||
|
self._simulate_streaming(parser, ["<tool_call>"])
|
||||||
|
# Step 2: mock partial_json_parser to throw ValueError
|
||||||
with patch(
|
with patch(
|
||||||
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
|
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
|
||||||
side_effect=ValueError("bad json"),
|
side_effect=ValueError("bad json"),
|
||||||
@@ -430,8 +578,8 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
"<tool_call>bad",
|
"<tool_call>bad",
|
||||||
"bad",
|
"bad",
|
||||||
[1],
|
[1],
|
||||||
[1, 10],
|
[1, 2],
|
||||||
[10],
|
[2],
|
||||||
self.dummy_request,
|
self.dummy_request,
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(result)
|
||||||
@@ -469,30 +617,17 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
def test_streaming_first_arguments_with_regex_match(self):
|
def test_streaming_first_arguments_with_regex_match(self):
|
||||||
"""Cover lines 243-244, 257-286: first arguments appear, regex matches"""
|
"""Cover lines 243-244, 257-286: first arguments appear, regex matches"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
# Start tool call and send name
|
results = self._simulate_streaming(
|
||||||
parser.extract_tool_calls_streaming(
|
parser,
|
||||||
"",
|
[
|
||||||
'<tool_call>{"name": "get_weather"',
|
'<tool_call>{"name": "get_weather", "arguments": {"location": "', # start + name + args key
|
||||||
'<tool_call>{"name": "get_weather"',
|
"bei", # args value
|
||||||
[],
|
],
|
||||||
[1, 10],
|
|
||||||
[1, 10],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
# Now stream arguments (first time)
|
# Step 1: name sent
|
||||||
# Key must be complete (closing quote) so partial_json_parser returns truthy arguments.
|
# Step 2: first-args, regex finds "bei" in '{"location": "bei'
|
||||||
# delta must be a substring of the regex-extracted arguments portion (after "arguments":).
|
self.assertIsNotNone(results[1])
|
||||||
result = parser.extract_tool_calls_streaming(
|
self.assertEqual(results[1].tool_calls[0].function.arguments, '{"location": "bei')
|
||||||
'<tool_call>{"name": "get_weather"',
|
|
||||||
'<tool_call>{"name": "get_weather", "arguments": {"location": "bei',
|
|
||||||
'"bei',
|
|
||||||
[1, 10],
|
|
||||||
[1, 10, 20],
|
|
||||||
[20],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(result)
|
|
||||||
self.assertIsNotNone(result.tool_calls)
|
|
||||||
|
|
||||||
def test_streaming_first_arguments_no_regex_match(self):
|
def test_streaming_first_arguments_no_regex_match(self):
|
||||||
"""Cover lines 266-267: regex doesn't match, fallback to json.dumps"""
|
"""Cover lines 266-267: regex doesn't match, fallback to json.dumps"""
|
||||||
@@ -522,67 +657,119 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
self.assertIsNotNone(result.tool_calls)
|
self.assertIsNotNone(result.tool_calls)
|
||||||
|
|
||||||
def test_streaming_first_arguments_delta_not_in_json(self):
|
def test_streaming_first_arguments_delta_not_in_json(self):
|
||||||
"""Cover lines 271-272: delta_text not found in cur_arguments_json"""
|
"""Cover lines 275-276: delta_text not found in cur_arguments_json, returns None.
|
||||||
|
When delta contains the arguments key itself (e.g. ', "arguments": {'),
|
||||||
|
regex extracts cur_arguments_json='{' but delta ', "arguments": {' is not in '{'.
|
||||||
|
"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.extract_tool_calls_streaming(
|
results = self._simulate_streaming(
|
||||||
"",
|
parser,
|
||||||
'<tool_call>{"name": "fn"',
|
[
|
||||||
'<tool_call>{"name": "fn"',
|
'<tool_call>{"name": "fn"', # start + partial name
|
||||||
[],
|
', "arguments": {', # delta introduces arguments key + open brace
|
||||||
[1, 10],
|
],
|
||||||
[1, 10],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
# Delta text that doesn't appear in the arguments JSON
|
# Step 1: name sent
|
||||||
result = parser.extract_tool_calls_streaming(
|
self.assertIsNotNone(results[0])
|
||||||
'<tool_call>{"name": "fn"',
|
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
|
# Step 2: first-args branch, regex extracts cur_arguments_json='{'
|
||||||
"ZZZZZ",
|
# delta_text=', "arguments": {' is NOT in '{' -> returns None
|
||||||
[1, 10],
|
self.assertIsNone(results[1])
|
||||||
[1, 10, 20],
|
|
||||||
[20],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
self.assertIsNone(result)
|
|
||||||
|
|
||||||
# --- Lines 249-251: no cur_arguments and no prev_arguments ---
|
# --- Lines 249-251: no cur_arguments and no prev_arguments ---
|
||||||
|
|
||||||
def test_streaming_no_arguments_at_all(self):
|
def test_streaming_no_arguments_at_all(self):
|
||||||
"""Cover lines 249-251: both cur and prev arguments are empty/None"""
|
"""Cover lines 249-251: both cur and prev arguments are empty/None"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.extract_tool_calls_streaming(
|
results = self._simulate_streaming(
|
||||||
"",
|
parser,
|
||||||
'<tool_call>{"name": "fn"',
|
[
|
||||||
'<tool_call>{"name": "fn"',
|
'<tool_call>{"name": "fn"', # start + name
|
||||||
[],
|
"}", # close JSON, no arguments
|
||||||
[1, 10],
|
],
|
||||||
[1, 10],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
# Continue with name only, no arguments
|
|
||||||
result = parser.extract_tool_calls_streaming(
|
|
||||||
'<tool_call>{"name": "fn"',
|
|
||||||
'<tool_call>{"name": "fn"}',
|
|
||||||
"}",
|
|
||||||
[1, 10],
|
|
||||||
[1, 10, 20],
|
|
||||||
[20],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
# prev_arguments=None, cur_arguments=None -> delta=None
|
# prev_arguments=None, cur_arguments=None -> delta=None
|
||||||
# then prev_tool_call_arr updated and returns delta (which is None)
|
self.assertIsNone(results[1])
|
||||||
self.assertIsNone(result)
|
|
||||||
|
def test_streaming_empty_dict_arguments_not_skipped(self):
|
||||||
|
"""Regression: arguments={} (empty dict) must not be treated as no arguments.
|
||||||
|
|
||||||
|
Empty dict is falsy in Python (`not {} == True`). Before the fix,
|
||||||
|
this caused empty arguments to enter the no-arguments branch,
|
||||||
|
silently dropping them during streaming.
|
||||||
|
"""
|
||||||
|
parser = self._new_parser()
|
||||||
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
|
[
|
||||||
|
'<tool_call>{"name": "fn", "arguments": ', # start + name + args key
|
||||||
|
"{}", # empty dict value
|
||||||
|
"}", # outer close brace
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Step 1: name sent
|
||||||
|
# Step 2: cur_arguments={} (not None), prev_arguments=None
|
||||||
|
# With fix: enters first-arguments branch -> streams "{}"
|
||||||
|
# Without fix: not {} == True -> no-arguments branch -> delta=None
|
||||||
|
self.assertIsNotNone(results[1])
|
||||||
|
self.assertIsNotNone(results[1].tool_calls)
|
||||||
|
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
|
||||||
|
|
||||||
|
def test_streaming_empty_dict_prev_arguments_not_reset(self):
|
||||||
|
"""Regression: prev_arguments={} must not be treated as no arguments.
|
||||||
|
|
||||||
|
When prev has {} and cur has a non-empty dict, the code should enter
|
||||||
|
the both-have-arguments branch, not the first-arguments branch.
|
||||||
|
|
||||||
|
This scenario (arguments growing from {} to non-empty) is hard to
|
||||||
|
produce naturally, so we build up state through a real flow then
|
||||||
|
verify the branch behavior with one additional call.
|
||||||
|
"""
|
||||||
|
parser = self._new_parser()
|
||||||
|
# Build up state naturally: prev_tool_call_arr gets arguments={}
|
||||||
|
self._simulate_streaming(
|
||||||
|
parser,
|
||||||
|
[
|
||||||
|
'<tool_call>{"name": "fn", "arguments": ', # name + args key
|
||||||
|
"{}", # empty dict value
|
||||||
|
"}", # outer close
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Verify state is correct
|
||||||
|
self.assertEqual(parser.prev_tool_call_arr[0].get("arguments"), {})
|
||||||
|
|
||||||
|
# Now test: if more argument data arrives, prev_args={} should be
|
||||||
|
# treated as "not None" -> enters both-have-arguments branch
|
||||||
|
# Without fix: not {} == True -> first-arguments branch (wrong)
|
||||||
|
result = parser.extract_tool_calls_streaming(
|
||||||
|
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
||||||
|
'<tool_call>{"name": "fn", "arguments": {"k": "val',
|
||||||
|
"al",
|
||||||
|
[1, 2, 3],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[4],
|
||||||
|
self.dummy_request,
|
||||||
|
)
|
||||||
|
# both-have-arguments branch: delta_text="al" streamed as arguments
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertEqual(result.tool_calls[0].function.arguments, "al")
|
||||||
|
|
||||||
# --- Lines 253-255: cur_arguments reset (impossible branch) ---
|
# --- Lines 253-255: cur_arguments reset (impossible branch) ---
|
||||||
|
|
||||||
def test_streaming_arguments_reset_mid_call(self):
|
def test_streaming_arguments_reset_mid_call(self):
|
||||||
"""Cover lines 253-255: prev has arguments but cur doesn't (impossible case)"""
|
"""Cover lines 253-255: prev has arguments but cur doesn't (impossible case).
|
||||||
|
|
||||||
|
This is an edge case that shouldn't happen in normal flow, but tests
|
||||||
|
defensive handling when partial parser returns no arguments after
|
||||||
|
previously having them.
|
||||||
|
"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.current_tool_id = 0
|
parser.current_tool_id = 0
|
||||||
parser.current_tool_name_sent = True
|
parser.current_tool_name_sent = True
|
||||||
parser.streamed_args_for_tool = [""]
|
parser.streamed_args_for_tool = [""]
|
||||||
|
# Simulate state where prev already had arguments
|
||||||
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}]
|
||||||
# Feed content where cur has no arguments but prev does
|
# Mock parser to return no arguments (simulating the impossible reset)
|
||||||
with patch(
|
with patch(
|
||||||
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
|
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads",
|
||||||
return_value={"name": "fn"},
|
return_value={"name": "fn"},
|
||||||
@@ -591,9 +778,9 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"',
|
'<tool_call>{"name": "fn", "arguments": {"k": "v"',
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
|
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
|
||||||
'"}',
|
'"}',
|
||||||
[1, 10],
|
[1, 2],
|
||||||
[1, 10, 20],
|
[1, 2, 3],
|
||||||
[20],
|
[3],
|
||||||
self.dummy_request,
|
self.dummy_request,
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(result)
|
||||||
@@ -603,110 +790,48 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
def test_streaming_incremental_arguments_incomplete(self):
|
def test_streaming_incremental_arguments_incomplete(self):
|
||||||
"""Cover lines 288-314: both prev and cur have arguments, JSON incomplete"""
|
"""Cover lines 288-314: both prev and cur have arguments, JSON incomplete"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.extract_tool_calls_streaming(
|
results = self._simulate_streaming(
|
||||||
"",
|
parser,
|
||||||
'<tool_call>{"name": "fn"',
|
[
|
||||||
'<tool_call>{"name": "fn"',
|
'<tool_call>{"name": "fn", "arguments": {"k": "v', # start + name + first args
|
||||||
[],
|
"a", # establishes prev_args
|
||||||
[1, 10],
|
"l", # incremental: both-have-args
|
||||||
[1, 10],
|
],
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
# First arguments - delta must appear in regex-extracted arguments portion
|
# Step 1: name sent
|
||||||
parser.extract_tool_calls_streaming(
|
# Step 2: first-args branch
|
||||||
'<tool_call>{"name": "fn"',
|
# Step 3: both-have-args branch, streams "l"
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
self.assertIsNotNone(results[2])
|
||||||
'{"k": "v',
|
self.assertEqual(results[2].tool_calls[0].function.arguments, "l")
|
||||||
[1, 10],
|
|
||||||
[1, 10, 20],
|
|
||||||
[20],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
# More argument tokens (both prev and cur have arguments now)
|
|
||||||
result = parser.extract_tool_calls_streaming(
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "val',
|
|
||||||
"al",
|
|
||||||
[1, 10, 20],
|
|
||||||
[1, 10, 20, 30],
|
|
||||||
[30],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(result)
|
|
||||||
self.assertEqual(result.tool_calls[0].function.arguments, "al")
|
|
||||||
|
|
||||||
def test_streaming_incremental_arguments_complete_json(self):
|
def test_streaming_incremental_arguments_complete_json(self):
|
||||||
"""Cover lines 289-305: complete JSON with trailing }"""
|
"""Cover lines 289-305: complete JSON with trailing }"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.extract_tool_calls_streaming(
|
results = self._simulate_streaming(
|
||||||
"",
|
parser,
|
||||||
'<tool_call>{"name": "fn"',
|
[
|
||||||
'<tool_call>{"name": "fn"',
|
'<tool_call>{"name": "fn", "arguments": {"k": "v', # start + name + first args
|
||||||
[],
|
"a", # establishes prev_args
|
||||||
[1, 10],
|
'"}}', # completes JSON
|
||||||
[1, 10],
|
],
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
# First arguments - delta must appear in regex-extracted arguments portion
|
# Step 3: both-have-args, complete JSON, strips trailing } -> streams '"}'
|
||||||
parser.extract_tool_calls_streaming(
|
self.assertIsNotNone(results[2])
|
||||||
'<tool_call>{"name": "fn"',
|
self.assertIsInstance(results[2], DeltaMessage)
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
|
||||||
'{"k": "v',
|
|
||||||
[1, 10],
|
|
||||||
[1, 10, 20],
|
|
||||||
[20],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
# Complete with closing braces - both prev and cur have arguments
|
|
||||||
result = parser.extract_tool_calls_streaming(
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v',
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
|
|
||||||
'"}}',
|
|
||||||
[1, 10, 20],
|
|
||||||
[1, 10, 20, 30],
|
|
||||||
[30],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
# is_complete_json=True, delta ends with }, should strip trailing }
|
|
||||||
# After strip: '"' which is not empty, so returns DeltaMessage
|
|
||||||
self.assertIsNotNone(result)
|
|
||||||
self.assertIsInstance(result, DeltaMessage)
|
|
||||||
|
|
||||||
def test_streaming_incremental_arguments_complete_empty_delta(self):
|
def test_streaming_incremental_arguments_complete_empty_delta(self):
|
||||||
"""Cover lines 304-305: complete JSON where delta becomes empty after strip"""
|
"""Cover lines 304-305: complete JSON where delta becomes empty after strip"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
parser.extract_tool_calls_streaming(
|
results = self._simulate_streaming(
|
||||||
"",
|
parser,
|
||||||
'<tool_call>{"name": "fn"',
|
[
|
||||||
'<tool_call>{"name": "fn"',
|
'<tool_call>{"name": "fn", "arguments": {"k": "v"', # start + name + first args
|
||||||
[],
|
"}", # inner close (establishes prev_args)
|
||||||
[1, 10],
|
"}", # outer close: both-have-args, complete, delta stripped to ""
|
||||||
[1, 10],
|
],
|
||||||
self.dummy_request,
|
|
||||||
)
|
)
|
||||||
# First arguments with proper delta
|
# Step 3: is_complete_json=True, delta="}" -> stripped to "" -> return None
|
||||||
parser.extract_tool_calls_streaming(
|
self.assertIsNone(results[2])
|
||||||
'<tool_call>{"name": "fn"',
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
|
|
||||||
'{"k": "v"}',
|
|
||||||
[1, 10],
|
|
||||||
[1, 10, 20],
|
|
||||||
[20],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
# Send just the outer closing brace
|
|
||||||
# tool_call_portion becomes complete JSON, delta="}" stripped to "" -> return None
|
|
||||||
result = parser.extract_tool_calls_streaming(
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}',
|
|
||||||
'<tool_call>{"name": "fn", "arguments": {"k": "v"}}',
|
|
||||||
"}",
|
|
||||||
[1, 10, 20],
|
|
||||||
[1, 10, 20, 30],
|
|
||||||
[30],
|
|
||||||
self.dummy_request,
|
|
||||||
)
|
|
||||||
# is_complete_json=True, delta="}" -> stripped to "" -> return None
|
|
||||||
self.assertIsNone(result)
|
|
||||||
|
|
||||||
# --- Lines 316-319: prev_tool_call_arr update branches ---
|
# --- Lines 316-319: prev_tool_call_arr update branches ---
|
||||||
|
|
||||||
@@ -759,95 +884,71 @@ class TestErnieX1ToolParser(unittest.TestCase):
|
|||||||
def test_streaming_full_flow(self):
|
def test_streaming_full_flow(self):
|
||||||
"""Integration test: simulate a full streaming tool call flow"""
|
"""Integration test: simulate a full streaming tool call flow"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
req = self.dummy_request
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
# Step 1: text before tool call
|
[
|
||||||
r = parser.extract_tool_calls_streaming("", "thinking", "thinking", [], [], [], req)
|
"thinking", # Step 1: text before tool call
|
||||||
self.assertEqual(r.content, "thinking")
|
"<tool_call>", # Step 2: tool_call start token
|
||||||
|
'{"name": "search", "arguments": {"query": "', # Step 3: name + args key
|
||||||
# Step 2: tool_call start token
|
"test", # Step 4: args value
|
||||||
r = parser.extract_tool_calls_streaming("thinking", "thinking<tool_call>", "<tool_call>", [], [1], [1], req)
|
" data", # Step 5: more args
|
||||||
self.assertIsNone(r)
|
],
|
||||||
|
|
||||||
# Step 3: function name appears
|
|
||||||
r = parser.extract_tool_calls_streaming(
|
|
||||||
"thinking<tool_call>",
|
|
||||||
'thinking<tool_call>{"name": "search"',
|
|
||||||
'{"name": "search"',
|
|
||||||
[1],
|
|
||||||
[1, 10],
|
|
||||||
[10],
|
|
||||||
req,
|
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(r)
|
# Step 1: plain text
|
||||||
self.assertEqual(r.tool_calls[0].function.name, "search")
|
self.assertEqual(results[0].content, "thinking")
|
||||||
|
# Step 2: start token -> None
|
||||||
# Step 4: arguments start - delta must appear in regex-extracted arguments portion
|
self.assertIsNone(results[1])
|
||||||
r = parser.extract_tool_calls_streaming(
|
# Step 3: name sent
|
||||||
'thinking<tool_call>{"name": "search"',
|
self.assertIsNotNone(results[2])
|
||||||
'thinking<tool_call>{"name": "search", "arguments": {"query": "test',
|
self.assertEqual(results[2].tool_calls[0].function.name, "search")
|
||||||
'{"query": "test',
|
# Step 4: first arguments
|
||||||
[1, 10],
|
self.assertIsNotNone(results[3])
|
||||||
[1, 10, 20],
|
self.assertEqual(results[3].tool_calls[0].function.arguments, '{"query": "test')
|
||||||
[20],
|
|
||||||
req,
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(r)
|
|
||||||
|
|
||||||
# Step 5: more arguments
|
# Step 5: more arguments
|
||||||
r = parser.extract_tool_calls_streaming(
|
self.assertIsNotNone(results[4])
|
||||||
'thinking<tool_call>{"name": "search", "arguments": {"query": "test',
|
self.assertEqual(results[4].tool_calls[0].function.arguments, " data")
|
||||||
'thinking<tool_call>{"name": "search", "arguments": {"query": "test data',
|
|
||||||
" data",
|
def test_streaming_empty_arguments_full_flow(self):
|
||||||
[1, 10, 20],
|
"""Integration: streaming tool call with arguments={} must not lose arguments.
|
||||||
[1, 10, 20, 30],
|
|
||||||
[30],
|
Simulates a complete streaming flow where the tool call has empty
|
||||||
req,
|
arguments. Verifies the name is sent and arguments are streamed.
|
||||||
|
"""
|
||||||
|
parser = self._new_parser()
|
||||||
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
|
[
|
||||||
|
'<tool_call>{"name": "fn", "arguments": ', # Step 1: start + name + args key
|
||||||
|
"{}", # Step 2: empty dict value
|
||||||
|
"}", # Step 3: outer close
|
||||||
|
"</tool_call>", # Step 4: end token
|
||||||
|
],
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(r)
|
# Step 1: name sent
|
||||||
self.assertEqual(r.tool_calls[0].function.arguments, " data")
|
self.assertIsNotNone(results[0])
|
||||||
|
self.assertEqual(results[0].tool_calls[0].function.name, "fn")
|
||||||
|
# Step 2: first-args with cur_args={}, streams "{}"
|
||||||
|
self.assertIsNotNone(results[1])
|
||||||
|
self.assertEqual(results[1].tool_calls[0].function.arguments, "{}")
|
||||||
|
# Step 4: close branch, delta_text="" after stripping </tool_call>
|
||||||
|
# diff={} is not None, but "}" not in "" -> return None
|
||||||
|
self.assertIsNone(results[2])
|
||||||
|
self.assertIsNone(results[3])
|
||||||
|
|
||||||
def test_streaming_multiple_tool_calls(self):
|
def test_streaming_multiple_tool_calls(self):
|
||||||
"""Integration test: two tool calls in one response"""
|
"""Integration test: two tool calls in one response"""
|
||||||
parser = self._new_parser()
|
parser = self._new_parser()
|
||||||
req = self.dummy_request
|
results = self._simulate_streaming(
|
||||||
|
parser,
|
||||||
# First tool call
|
[
|
||||||
parser.extract_tool_calls_streaming(
|
'<tool_call>{"name": "fn1"', # First tool: start + name
|
||||||
"",
|
"}</tool_call>", # Close first tool
|
||||||
'<tool_call>{"name": "fn1"',
|
'<tool_call>{"name": "fn2"', # Second tool: start + name
|
||||||
'<tool_call>{"name": "fn1"',
|
],
|
||||||
[],
|
|
||||||
[1, 10],
|
|
||||||
[1, 10],
|
|
||||||
req,
|
|
||||||
)
|
|
||||||
self.assertEqual(parser.current_tool_id, 0)
|
|
||||||
|
|
||||||
# Close first tool
|
|
||||||
parser.extract_tool_calls_streaming(
|
|
||||||
'<tool_call>{"name": "fn1"',
|
|
||||||
'<tool_call>{"name": "fn1"}</tool_call>',
|
|
||||||
"}</tool_call>",
|
|
||||||
[1, 10],
|
|
||||||
[1, 10, 2],
|
|
||||||
[2],
|
|
||||||
req,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Second tool call
|
|
||||||
r = parser.extract_tool_calls_streaming(
|
|
||||||
'<tool_call>{"name": "fn1"}</tool_call>',
|
|
||||||
'<tool_call>{"name": "fn1"}</tool_call><tool_call>{"name": "fn2"',
|
|
||||||
'<tool_call>{"name": "fn2"',
|
|
||||||
[1, 10, 2],
|
|
||||||
[1, 10, 2, 1, 20],
|
|
||||||
[1, 20],
|
|
||||||
req,
|
|
||||||
)
|
)
|
||||||
self.assertEqual(parser.current_tool_id, 1)
|
self.assertEqual(parser.current_tool_id, 1)
|
||||||
self.assertIsNotNone(r)
|
self.assertIsNotNone(results[2])
|
||||||
self.assertEqual(r.tool_calls[0].function.name, "fn2")
|
self.assertEqual(results[2].tool_calls[0].function.name, "fn2")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -91,10 +91,10 @@ class TestModel1(paddle.nn.Layer):
|
|||||||
|
|
||||||
return sublayer2_output
|
return sublayer2_output
|
||||||
|
|
||||||
def clear_grpah_opt_backend(self):
|
def clear_graph_opt_backend(self):
|
||||||
""" """
|
""" """
|
||||||
self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.sublayer1.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.sublayer2.clear_graph_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
class TestCUDAGrpahRecapture(unittest.TestCase):
|
class TestCUDAGrpahRecapture(unittest.TestCase):
|
||||||
@@ -152,7 +152,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
|
|||||||
|
|
||||||
# Destroy
|
# Destroy
|
||||||
print_gpu_memory_use("before destroy", 0)
|
print_gpu_memory_use("before destroy", 0)
|
||||||
self.test_model1.clear_grpah_opt_backend()
|
self.test_model1.clear_graph_opt_backend()
|
||||||
print_gpu_memory_use("after destroy", 0)
|
print_gpu_memory_use("after destroy", 0)
|
||||||
|
|
||||||
def recapture_and_replay(self, input_tensor1, forward_meta1):
|
def recapture_and_replay(self, input_tensor1, forward_meta1):
|
||||||
@@ -168,7 +168,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
|
|||||||
|
|
||||||
# Destroy
|
# Destroy
|
||||||
print_gpu_memory_use("before destroy", 0)
|
print_gpu_memory_use("before destroy", 0)
|
||||||
self.test_model1.clear_grpah_opt_backend()
|
self.test_model1.clear_graph_opt_backend()
|
||||||
print_gpu_memory_use("after destroy", 0)
|
print_gpu_memory_use("after destroy", 0)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
|||||||
self._fdconfig_patches = [
|
self._fdconfig_patches = [
|
||||||
patch.object(FDConfig, "read_from_config", return_value=None),
|
patch.object(FDConfig, "read_from_config", return_value=None),
|
||||||
patch.object(FDConfig, "postprocess", return_value=None),
|
patch.object(FDConfig, "postprocess", return_value=None),
|
||||||
patch.object(FDConfig, "init_cache_info", return_value=None),
|
patch.object(FDConfig, "init_pd_info", return_value=None),
|
||||||
patch.object(FDConfig, "check", return_value=None),
|
patch.object(FDConfig, "check", return_value=None),
|
||||||
]
|
]
|
||||||
for patcher in self._fdconfig_patches:
|
for patcher in self._fdconfig_patches:
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
return paddle_inputs
|
return paddle_inputs
|
||||||
|
|
||||||
|
|
||||||
def run_kernel(paddle_inputs, inputs):
|
def run_kernel(paddle_inputs):
|
||||||
"""Call the CUDA kernel."""
|
"""Call the CUDA kernel."""
|
||||||
speculate_set_stop_value_multi_seqs(
|
speculate_set_stop_value_multi_seqs(
|
||||||
paddle_inputs["accept_tokens"],
|
paddle_inputs["accept_tokens"],
|
||||||
@@ -137,7 +137,18 @@ def gen_inputs(
|
|||||||
|
|
||||||
|
|
||||||
def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Python reference — must match CUDA kernel logic exactly."""
|
"""Python reference — must match CUDA kernel logic exactly.
|
||||||
|
|
||||||
|
token_ids_all 布局 (新 step_idx 语义):
|
||||||
|
pre_ids_now[k] = 第 k 个 output token (k >= 0, 0-indexed)
|
||||||
|
最后一个 output token 在 pre_ids_now[step_idx - 1]
|
||||||
|
step_idx = 历史已生成的 token 数量
|
||||||
|
|
||||||
|
核心设计:
|
||||||
|
1. accept_idx 从 -1 开始,-1 表示检查 pre_ids 末尾(上一轮延迟的情况)
|
||||||
|
2. 主循环检查 accept_idx <= accept_num - 2
|
||||||
|
3. 匹配成功时: 保留 stop_seq 所有 token,在其后追加 eos
|
||||||
|
"""
|
||||||
accept_tokens = inputs["accept_tokens"].copy()
|
accept_tokens = inputs["accept_tokens"].copy()
|
||||||
accept_num = inputs["accept_num"].copy()
|
accept_num = inputs["accept_num"].copy()
|
||||||
stop_flags = inputs["stop_flags"].copy()
|
stop_flags = inputs["stop_flags"].copy()
|
||||||
@@ -166,27 +177,36 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
|
|||||||
step_idx_now = int(step_idx[bid])
|
step_idx_now = int(step_idx[bid])
|
||||||
min_token_limit = int(min_tokens[bid])
|
min_token_limit = int(min_tokens[bid])
|
||||||
|
|
||||||
can_stop = step_idx_now >= min_token_limit
|
can_stop = step_idx_now + an >= min_token_limit
|
||||||
if not can_stop:
|
if not can_stop:
|
||||||
continue
|
continue
|
||||||
if stop_flags[bid]:
|
if stop_flags[bid]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
accept_idx = 0
|
# CUDA kernel: accept_idx 从 -1 开始,检查 pre_ids 末尾
|
||||||
|
accept_idx = -1
|
||||||
is_end = False
|
is_end = False
|
||||||
while accept_idx <= an - 1 and not is_end:
|
|
||||||
|
# loop_end = accept_num > 0 ? accept_num - 2 : -1
|
||||||
|
loop_end = an - 2 if an > 0 else -1
|
||||||
|
while accept_idx <= loop_end and not is_end:
|
||||||
if step_idx_now + accept_idx + 1 < stop_seq_len:
|
if step_idx_now + accept_idx + 1 < stop_seq_len:
|
||||||
accept_idx += 1
|
accept_idx += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check one stop_seq match
|
# 从后向前匹配 stop_seq 的每个 token
|
||||||
for i in range(stop_seq_len - 1, -1, -1):
|
for i in range(stop_seq_len - 1, -1, -1):
|
||||||
|
offset = stop_seq_len - 1 - i
|
||||||
|
accept_tokens_idx = accept_idx - offset
|
||||||
cur_token_idx = -1
|
cur_token_idx = -1
|
||||||
if stop_seq_len - 1 - i < accept_idx:
|
|
||||||
cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]
|
if accept_tokens_idx >= 0:
|
||||||
|
cur_token_idx = accept_tokens_now[accept_tokens_idx]
|
||||||
else:
|
else:
|
||||||
pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i)
|
# 新语义: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||||
if pre_ids_idx <= 0:
|
# pre_ids_now[0] 是第 1 个 output token
|
||||||
|
pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||||
|
if pre_ids_idx < 0:
|
||||||
break
|
break
|
||||||
cur_token_idx = pre_ids_now[pre_ids_idx]
|
cur_token_idx = pre_ids_now[pre_ids_idx]
|
||||||
|
|
||||||
@@ -199,9 +219,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str
|
|||||||
accept_idx += 1
|
accept_idx += 1
|
||||||
|
|
||||||
if is_end:
|
if is_end:
|
||||||
accept_num[bid] = accept_idx
|
# accept_idx 已递增,指向 stop_seq 最后 token 的下一个位置
|
||||||
accept_tokens[bid, accept_idx - 1] = end_ids[0]
|
# 保留 stop_seq 所有 token,在其后追加 eos
|
||||||
# stop_flags[bid] = True # kernel no longer sets stop_flags
|
accept_num[bid] = accept_idx + 1
|
||||||
|
accept_tokens[bid, accept_idx] = end_ids[0]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"accept_tokens": accept_tokens,
|
"accept_tokens": accept_tokens,
|
||||||
@@ -239,7 +260,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
|
|
||||||
def _run_and_get(self, inputs):
|
def _run_and_get(self, inputs):
|
||||||
paddle_inputs = to_paddle_inputs(inputs)
|
paddle_inputs = to_paddle_inputs(inputs)
|
||||||
run_kernel(paddle_inputs, inputs)
|
run_kernel(paddle_inputs)
|
||||||
return get_outputs(paddle_inputs)
|
return get_outputs(paddle_inputs)
|
||||||
|
|
||||||
def _check_all_outputs(self, inputs, outputs):
|
def _check_all_outputs(self, inputs, outputs):
|
||||||
@@ -264,7 +285,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
self._run_full_test(test_cfg)
|
self._run_full_test(test_cfg)
|
||||||
|
|
||||||
def test_match_in_accept_tokens_only(self):
|
def test_match_in_accept_tokens_only(self):
|
||||||
"""Stop seq found entirely within accept_tokens."""
|
"""Stop seq found entirely within accept_tokens. Eos appended after stop_seq last token."""
|
||||||
inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10)
|
inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10)
|
||||||
# Place stop seq [A, B, C] at accept_tokens positions [0,1,2]
|
# Place stop seq [A, B, C] at accept_tokens positions [0,1,2]
|
||||||
inputs["accept_num"][:] = 4
|
inputs["accept_num"][:] = 4
|
||||||
@@ -276,9 +297,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["min_tokens"][:] = 0
|
inputs["min_tokens"][:] = 0
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
|
# stop_seq [10, 20, 30] matches at accept_idx=2 (window ends at accept_tokens[2]=30)
|
||||||
|
# After loop increment, accept_idx=3, accept_num=4, eos appended at accept_tokens[3]
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 4)
|
||||||
|
self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq
|
||||||
|
|
||||||
def test_match_spanning_pre_ids_and_accept(self):
|
def test_match_spanning_pre_ids_and_accept(self):
|
||||||
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens."""
|
"""Stop seq spans token_ids_all (pre_ids) and accept_tokens. Eos appended after stop_seq last token."""
|
||||||
inputs = gen_inputs(
|
inputs = gen_inputs(
|
||||||
real_bsz=1,
|
real_bsz=1,
|
||||||
accept_tokens_len=5,
|
accept_tokens_len=5,
|
||||||
@@ -290,12 +315,15 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["prompt_lens"][:] = 0
|
inputs["prompt_lens"][:] = 0
|
||||||
inputs["step_idx"][:] = 6
|
inputs["step_idx"][:] = 6
|
||||||
inputs["accept_num"][:] = 3
|
inputs["accept_num"][:] = 3
|
||||||
# Kernel matching at accept_idx=2 (3rd token, 0-indexed):
|
# stop_seq = [99, 11, 22] (len=3)
|
||||||
# i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1]
|
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||||
# i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0]
|
# pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||||
# i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6]
|
# step_idx = 6 表示有 6 个历史 output token,在 pre_ids_now[0..5]
|
||||||
# So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]]
|
# At accept_idx=1 (window ends at accept_tokens[1]=22):
|
||||||
inputs["token_ids_all"][0, 6] = 99
|
# i=2: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=22 vs stop_seq[2]=22 ✓
|
||||||
|
# i=1: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=11 vs stop_seq[1]=11 ✓
|
||||||
|
# i=0: offset=2, accept_tokens_idx=-1 -> pre_ids_idx=6+(-1)=5 -> pre_ids[5]=99 vs stop_seq[0]=99 ✓
|
||||||
|
inputs["token_ids_all"][0, 5] = 99 # pre_ids_now[5] = 第 6 个 output token (0-indexed)
|
||||||
inputs["accept_tokens"][0, :3] = [11, 22, 33]
|
inputs["accept_tokens"][0, :3] = [11, 22, 33]
|
||||||
inputs["stop_seqs"][0, 0, :3] = [99, 11, 22]
|
inputs["stop_seqs"][0, 0, :3] = [99, 11, 22]
|
||||||
inputs["stop_seqs_len"][0, 0] = 3
|
inputs["stop_seqs_len"][0, 0] = 3
|
||||||
@@ -303,12 +331,14 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["min_tokens"][:] = 0
|
inputs["min_tokens"][:] = 0
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
# Match at accept_idx=2, loop increments to 3
|
# Match at accept_idx=1, loop increments to 2 -> accept_num=3, eos at accept_tokens[2]
|
||||||
self.assertEqual(outputs["accept_num"][0], 3)
|
self.assertEqual(outputs["accept_num"][0], 3)
|
||||||
self.assertEqual(outputs["accept_tokens"][0, 2], -1)
|
self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq
|
||||||
|
|
||||||
def test_match_in_pre_ids_only(self):
|
def test_match_in_pre_ids_only_not_detected(self):
|
||||||
"""Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0."""
|
"""Stop seq ending purely in pre_ids history but NOT at the end position.
|
||||||
|
The kernel only detects stop_seq at the very end of pre_ids via accept_idx=-1 check.
|
||||||
|
Stop seq placed earlier in pre_ids should not be detected."""
|
||||||
inputs = gen_inputs(
|
inputs = gen_inputs(
|
||||||
real_bsz=1,
|
real_bsz=1,
|
||||||
accept_tokens_len=5,
|
accept_tokens_len=5,
|
||||||
@@ -320,15 +350,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["prompt_lens"][:] = 0
|
inputs["prompt_lens"][:] = 0
|
||||||
inputs["step_idx"][:] = 8
|
inputs["step_idx"][:] = 8
|
||||||
inputs["accept_num"][:] = 3
|
inputs["accept_num"][:] = 3
|
||||||
# pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70
|
# 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||||
# stop_seq = [50, 60, 70], all 3 tokens are in pre_ids
|
# step_idx = 8 表示有 8 个历史 output token,在 pre_ids_now[0..7]
|
||||||
# For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check
|
# accept_idx=-1 会检查 pre_ids_now[7] 开始的 stop_seq
|
||||||
# i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70
|
# 把 stop_seq 放在 pre_ids_now[2,3,4] - 不会被检测到
|
||||||
# i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60
|
inputs["token_ids_all"][0, 2] = 50
|
||||||
# i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50
|
inputs["token_ids_all"][0, 3] = 60
|
||||||
inputs["token_ids_all"][0, 6] = 50
|
inputs["token_ids_all"][0, 4] = 70
|
||||||
inputs["token_ids_all"][0, 7] = 60
|
|
||||||
inputs["token_ids_all"][0, 8] = 70
|
|
||||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
||||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
||||||
inputs["stop_seqs_len"][0, 0] = 3
|
inputs["stop_seqs_len"][0, 0] = 3
|
||||||
@@ -336,7 +364,8 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["min_tokens"][:] = 0
|
inputs["min_tokens"][:] = 0
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
self.assertEqual(outputs["accept_num"][0], 1)
|
# No match: stop_seq is in pre_ids but not at the end, accept_num unchanged
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 3)
|
||||||
|
|
||||||
def test_already_stopped(self):
|
def test_already_stopped(self):
|
||||||
"""Kernel skips sequences with stop_flags=True."""
|
"""Kernel skips sequences with stop_flags=True."""
|
||||||
@@ -351,7 +380,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"])
|
np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"])
|
||||||
|
|
||||||
def test_min_tokens_blocks_stop(self):
|
def test_min_tokens_blocks_stop(self):
|
||||||
"""Kernel skips stop check when step_idx < min_tokens."""
|
"""Kernel skips stop check when step_idx + accept_num < min_tokens."""
|
||||||
inputs = gen_inputs(
|
inputs = gen_inputs(
|
||||||
real_bsz=1,
|
real_bsz=1,
|
||||||
accept_tokens_len=5,
|
accept_tokens_len=5,
|
||||||
@@ -363,20 +392,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["prompt_lens"][:] = 0
|
inputs["prompt_lens"][:] = 0
|
||||||
inputs["step_idx"][:] = 8
|
inputs["step_idx"][:] = 8
|
||||||
inputs["accept_num"][:] = 3
|
inputs["accept_num"][:] = 3
|
||||||
# Same setup that would match (like test_match_in_pre_ids_only)
|
# Place stop_seq in pre_ids at end position (would be detected by accept_idx=-1)
|
||||||
inputs["token_ids_all"][0, 6] = 50
|
# pre_ids_now[0..7] = 8 个历史 output token
|
||||||
inputs["token_ids_all"][0, 7] = 60
|
# accept_idx=-1 检查 pre_ids_now[5,6,7] 对应 stop_seq[0,1,2]
|
||||||
inputs["token_ids_all"][0, 8] = 70
|
inputs["token_ids_all"][0, 5] = 50
|
||||||
|
inputs["token_ids_all"][0, 6] = 60
|
||||||
|
inputs["token_ids_all"][0, 7] = 70
|
||||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
||||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
||||||
inputs["stop_seqs_len"][0, 0] = 3
|
inputs["stop_seqs_len"][0, 0] = 3
|
||||||
inputs["stop_flags"][:] = False
|
inputs["stop_flags"][:] = False
|
||||||
inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop
|
inputs["min_tokens"][:] = 100 # step_idx+accept_num=11 < 100, should NOT stop
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
|
# min_tokens prevents stop, accept_num unchanged
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 3)
|
||||||
|
|
||||||
def test_min_tokens_allows_stop(self):
|
def test_min_tokens_allows_stop(self):
|
||||||
"""Kernel allows stop when step_idx >= min_tokens."""
|
"""Kernel allows stop when step_idx + accept_num >= min_tokens."""
|
||||||
inputs = gen_inputs(
|
inputs = gen_inputs(
|
||||||
real_bsz=1,
|
real_bsz=1,
|
||||||
accept_tokens_len=5,
|
accept_tokens_len=5,
|
||||||
@@ -388,15 +421,17 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["prompt_lens"][:] = 0
|
inputs["prompt_lens"][:] = 0
|
||||||
inputs["step_idx"][:] = 8
|
inputs["step_idx"][:] = 8
|
||||||
inputs["accept_num"][:] = 3
|
inputs["accept_num"][:] = 3
|
||||||
# Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only)
|
# stop_seq [X, 50] spans pre_ids and accept_tokens[0].
|
||||||
inputs["token_ids_all"][0, 6] = 50
|
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||||
inputs["token_ids_all"][0, 7] = 60
|
# At accept_idx=0 (window ends at accept_tokens[0]=50):
|
||||||
inputs["token_ids_all"][0, 8] = 70
|
# i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=50 vs stop_seq[1]=50 ✓
|
||||||
inputs["accept_tokens"][0, :3] = [1, 2, 3]
|
# i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=8+(-1)=7 -> pre_ids[7]
|
||||||
inputs["stop_seqs"][0, 0, :3] = [50, 60, 70]
|
pre_val = int(inputs["token_ids_all"][0, 7]) # pre_ids_now[7]
|
||||||
inputs["stop_seqs_len"][0, 0] = 3
|
inputs["accept_tokens"][0, :3] = [50, 60, 70]
|
||||||
|
inputs["stop_seqs"][0, 0, :2] = [pre_val, 50]
|
||||||
|
inputs["stop_seqs_len"][0, 0] = 2
|
||||||
inputs["stop_flags"][:] = False
|
inputs["stop_flags"][:] = False
|
||||||
inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop
|
inputs["min_tokens"][:] = 5 # step_idx+accept_num=11 >= 5, should stop
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
|
|
||||||
@@ -413,20 +448,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["prompt_lens"][:] = 0
|
inputs["prompt_lens"][:] = 0
|
||||||
inputs["step_idx"][:] = 8
|
inputs["step_idx"][:] = 8
|
||||||
inputs["accept_num"][:] = 3
|
inputs["accept_num"][:] = 3
|
||||||
# accept_tokens: stop_seq[20,30] matches at accept_idx=2:
|
# accept_tokens: [20, 30, 40]
|
||||||
# i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK
|
# Second stop seq [20, 30] matches at accept_idx=1 (window ends at accept_tokens[1]=30):
|
||||||
# i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK
|
# i=1: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=30 vs stop_seq[1]=30 ✓
|
||||||
|
# i=0: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=20 vs stop_seq[0]=20 ✓
|
||||||
inputs["accept_tokens"][0, :3] = [20, 30, 40]
|
inputs["accept_tokens"][0, :3] = [20, 30, 40]
|
||||||
# First stop seq doesn't match
|
# First stop seq doesn't match
|
||||||
inputs["stop_seqs"][0, 0, :3] = [99, 98, 97]
|
inputs["stop_seqs"][0, 0, :3] = [99, 98, 97]
|
||||||
inputs["stop_seqs_len"][0, 0] = 3
|
inputs["stop_seqs_len"][0, 0] = 3
|
||||||
# Second stop seq matches
|
# Second stop seq [20, 30] matches
|
||||||
inputs["stop_seqs"][0, 1, :2] = [20, 30]
|
inputs["stop_seqs"][0, 1, :2] = [20, 30]
|
||||||
inputs["stop_seqs_len"][0, 1] = 2
|
inputs["stop_seqs_len"][0, 1] = 2
|
||||||
inputs["stop_flags"][:] = False
|
inputs["stop_flags"][:] = False
|
||||||
inputs["min_tokens"][:] = 0
|
inputs["min_tokens"][:] = 0
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
|
# Match at accept_idx=1 -> accept_num=3, eos at accept_tokens[2]
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 3)
|
||||||
|
self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq
|
||||||
|
|
||||||
def test_nonzero_prompt_lens(self):
|
def test_nonzero_prompt_lens(self):
|
||||||
"""Verify prompt_lens offset is applied correctly."""
|
"""Verify prompt_lens offset is applied correctly."""
|
||||||
@@ -444,19 +483,104 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase):
|
|||||||
inputs["accept_num"][:] = 2
|
inputs["accept_num"][:] = 2
|
||||||
inputs["accept_tokens"][0, :2] = [55, 66]
|
inputs["accept_tokens"][0, :2] = [55, 66]
|
||||||
# pre_ids_now starts at token_ids_all[0, prompt_len:]
|
# pre_ids_now starts at token_ids_all[0, prompt_len:]
|
||||||
# stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx]
|
# pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||||
# For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4
|
# 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx
|
||||||
# -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4]
|
# stop_seq = [X, 55] where X = pre_ids_now[5 + (-1)] = pre_ids_now[4]
|
||||||
# For accept_idx=1 (second token is accept_tokens[0,0]=55):
|
# At accept_idx=0 (window ends at accept_tokens[0]=55):
|
||||||
# i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55
|
# i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=55 vs stop_seq[1]=55 ✓
|
||||||
# i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5]
|
# i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=5+(-1)=4 -> pre_ids[4]=token_ids_all[0, prompt_len+4]
|
||||||
target_val = int(inputs["token_ids_all"][0, prompt_len + 5])
|
target_val = int(inputs["token_ids_all"][0, prompt_len + 4])
|
||||||
inputs["stop_seqs"][0, 0, :2] = [target_val, 55]
|
inputs["stop_seqs"][0, 0, :2] = [target_val, 55]
|
||||||
inputs["stop_seqs_len"][0, 0] = 2
|
inputs["stop_seqs_len"][0, 0] = 2
|
||||||
inputs["stop_flags"][:] = False
|
inputs["stop_flags"][:] = False
|
||||||
inputs["min_tokens"][:] = 0
|
inputs["min_tokens"][:] = 0
|
||||||
outputs = self._run_and_get(inputs)
|
outputs = self._run_and_get(inputs)
|
||||||
self._check_all_outputs(inputs, outputs)
|
self._check_all_outputs(inputs, outputs)
|
||||||
|
# Match at accept_idx=0 -> accept_num=2, eos at accept_tokens[1]
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 2)
|
||||||
|
self.assertEqual(outputs["accept_tokens"][0, 1], -1) # eos appended after stop_seq
|
||||||
|
|
||||||
|
def test_single_token_stop_seq_preserved(self):
|
||||||
|
"""Single token stop_seq (like <|im_end|>) with eos appended after it."""
|
||||||
|
inputs = gen_inputs(
|
||||||
|
real_bsz=1,
|
||||||
|
accept_tokens_len=5,
|
||||||
|
max_model_len=32,
|
||||||
|
stop_seqs_bs=1,
|
||||||
|
stop_seqs_max_len=1,
|
||||||
|
seed=90,
|
||||||
|
)
|
||||||
|
inputs["prompt_lens"][:] = 0
|
||||||
|
inputs["step_idx"][:] = 10
|
||||||
|
inputs["accept_num"][:] = 4
|
||||||
|
# accept_tokens: [a, b, <|im_end|>, d] where <|im_end|> has token id 999
|
||||||
|
inputs["accept_tokens"][0, :4] = [100, 200, 999, 300]
|
||||||
|
# stop_seq = [<|im_end|>] (single token)
|
||||||
|
inputs["stop_seqs"][0, 0, 0] = 999
|
||||||
|
inputs["stop_seqs_len"][0, 0] = 1
|
||||||
|
inputs["stop_flags"][:] = False
|
||||||
|
inputs["min_tokens"][:] = 0
|
||||||
|
outputs = self._run_and_get(inputs)
|
||||||
|
self._check_all_outputs(inputs, outputs)
|
||||||
|
# Match at accept_idx=2 (window ends at accept_tokens[2]=999)
|
||||||
|
# After loop increment, accept_idx=3, accept_num=4, eos at accept_tokens[3]
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 4)
|
||||||
|
self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq
|
||||||
|
|
||||||
|
def test_stop_seq_at_last_position_not_detected(self):
|
||||||
|
"""Stop seq at the last position of accept_tokens is NOT detected (deferred to next round)."""
|
||||||
|
inputs = gen_inputs(
|
||||||
|
real_bsz=1,
|
||||||
|
accept_tokens_len=5,
|
||||||
|
max_model_len=32,
|
||||||
|
stop_seqs_bs=1,
|
||||||
|
stop_seqs_max_len=1,
|
||||||
|
seed=100,
|
||||||
|
)
|
||||||
|
inputs["prompt_lens"][:] = 0
|
||||||
|
inputs["step_idx"][:] = 10
|
||||||
|
inputs["accept_num"][:] = 4
|
||||||
|
# stop_seq [999] is at accept_tokens[3] (last valid position)
|
||||||
|
# Since we only check up to accept_num - 2 = 2, this won't be detected
|
||||||
|
inputs["accept_tokens"][0, :4] = [100, 200, 300, 999]
|
||||||
|
inputs["stop_seqs"][0, 0, 0] = 999
|
||||||
|
inputs["stop_seqs_len"][0, 0] = 1
|
||||||
|
inputs["stop_flags"][:] = False
|
||||||
|
inputs["min_tokens"][:] = 0
|
||||||
|
outputs = self._run_and_get(inputs)
|
||||||
|
self._check_all_outputs(inputs, outputs)
|
||||||
|
# No match because accept_idx only goes up to 2, and 999 is at position 3
|
||||||
|
# accept_num unchanged
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 4)
|
||||||
|
|
||||||
|
def test_stop_seq_detected_from_previous_round(self):
|
||||||
|
"""Stop seq at the end of pre_ids (from previous round) is detected via accept_idx=-1."""
|
||||||
|
inputs = gen_inputs(
|
||||||
|
real_bsz=1,
|
||||||
|
accept_tokens_len=5,
|
||||||
|
max_model_len=32,
|
||||||
|
stop_seqs_bs=1,
|
||||||
|
stop_seqs_max_len=1,
|
||||||
|
seed=110,
|
||||||
|
)
|
||||||
|
inputs["prompt_lens"][:] = 0
|
||||||
|
# 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0)
|
||||||
|
# step_idx = 10 表示有 10 个历史 output token,在 pre_ids_now[0..9]
|
||||||
|
# accept_idx=-1 检查 pre_ids_now[9] (最后一个历史 token)
|
||||||
|
inputs["step_idx"][:] = 10
|
||||||
|
inputs["token_ids_all"][0, 9] = 999 # pre_ids_now[9] = 第 10 个 output token (0-indexed)
|
||||||
|
inputs["accept_num"][:] = 3
|
||||||
|
inputs["accept_tokens"][0, :3] = [100, 200, 300]
|
||||||
|
inputs["stop_seqs"][0, 0, 0] = 999
|
||||||
|
inputs["stop_seqs_len"][0, 0] = 1
|
||||||
|
inputs["stop_flags"][:] = False
|
||||||
|
inputs["min_tokens"][:] = 0
|
||||||
|
outputs = self._run_and_get(inputs)
|
||||||
|
self._check_all_outputs(inputs, outputs)
|
||||||
|
# stop_seq [999] was in pre_ids at end, accept_idx=-1 matches
|
||||||
|
# After loop increment, accept_idx=0, accept_num=1, eos at accept_tokens[0]
|
||||||
|
self.assertEqual(outputs["accept_num"][0], 1)
|
||||||
|
self.assertEqual(outputs["accept_tokens"][0, 0], -1) # replaced with eos
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -261,7 +261,9 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
# Write history to token_ids_all (forward loop, mirrors kernel step 5)
|
# Write history to token_ids_all (forward loop, mirrors kernel step 5)
|
||||||
if output_len > 0:
|
if output_len > 0:
|
||||||
base_addr = int(prompt_lens[batch_id])
|
base_addr = int(prompt_lens[batch_id])
|
||||||
base = cur_step_idx - output_len + 1
|
# 新语义: step_idx 入口 = 历史数量,处理后 cur_step_idx = 历史 + output_len
|
||||||
|
# 第一个 output token 写入位置 = cur_step_idx - output_len
|
||||||
|
base = cur_step_idx - output_len
|
||||||
for i in range(output_len):
|
for i in range(output_len):
|
||||||
write_idx = base_addr + base + i
|
write_idx = base_addr + base + i
|
||||||
if 0 <= write_idx < max_model_len:
|
if 0 <= write_idx < max_model_len:
|
||||||
|
|||||||
@@ -411,6 +411,32 @@ class TestDPLocalScheduler(unittest.TestCase):
|
|||||||
self.assertEqual(scheduler.ids, ["fresh_req"])
|
self.assertEqual(scheduler.ids, ["fresh_req"])
|
||||||
self.assertEqual(scheduler.ids_read_cursor, 1)
|
self.assertEqual(scheduler.ids_read_cursor, 1)
|
||||||
|
|
||||||
|
def test_get_requests_insufficient_resources(self):
|
||||||
|
"""Test getting requests when resources are insufficient."""
|
||||||
|
mock_logger.reset_mock()
|
||||||
|
|
||||||
|
# Test with insufficient blocks - mock the condition variable to avoid threading issues
|
||||||
|
with patch.object(self.scheduler, "requests_not_empty"):
|
||||||
|
requests = self.scheduler.get_requests(
|
||||||
|
available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(requests, [])
|
||||||
|
# The logger should have been called for insufficient resources
|
||||||
|
self.assertTrue(mock_logger.debug.called)
|
||||||
|
# Check the message contains expected content
|
||||||
|
call_args = mock_logger.debug.call_args[0][0]
|
||||||
|
self.assertIn("insufficient", call_args.lower())
|
||||||
|
|
||||||
|
def test_get_requests_insufficient_batch(self):
|
||||||
|
"""Test getting requests when batch size is insufficient."""
|
||||||
|
with patch.object(self.scheduler, "requests_not_empty"):
|
||||||
|
requests = self.scheduler.get_requests(
|
||||||
|
available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(requests, [])
|
||||||
|
|
||||||
@patch("time.time")
|
@patch("time.time")
|
||||||
@patch.object(dp_scheduler_module, "envs")
|
@patch.object(dp_scheduler_module, "envs")
|
||||||
def test_get_requests_no_requests_available(self, mock_envs, mock_time):
|
def test_get_requests_no_requests_available(self, mock_envs, mock_time):
|
||||||
|
|||||||
@@ -25,9 +25,6 @@ class DummyEngine:
|
|||||||
"""Dummy Engine class to simulate the actual Engine for testing."""
|
"""Dummy Engine class to simulate the actual Engine for testing."""
|
||||||
|
|
||||||
class ResourceManager:
|
class ResourceManager:
|
||||||
def __init__(self):
|
|
||||||
self.waiting = []
|
|
||||||
|
|
||||||
def available_batch(self):
|
def available_batch(self):
|
||||||
return 4
|
return 4
|
||||||
|
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
test_mode=True,
|
test_mode=True,
|
||||||
)
|
)
|
||||||
fd_config.init_cache_info()
|
fd_config.init_pd_info()
|
||||||
assert fd_config.register_info is not None
|
assert fd_config.register_info is not None
|
||||||
|
|
||||||
def test_fdconfig_postprocess_ports(self):
|
def test_fdconfig_postprocess_ports(self):
|
||||||
|
|||||||
@@ -487,7 +487,7 @@ class TestSleepWakeupBehavior(unittest.TestCase):
|
|||||||
runner.local_rank = 0
|
runner.local_rank = 0
|
||||||
runner.device_id = 1
|
runner.device_id = 1
|
||||||
runner.num_gpu_blocks = 8
|
runner.num_gpu_blocks = 8
|
||||||
runner.model = Mock(clear_grpah_opt_backend=Mock())
|
runner.model = Mock(clear_graph_opt_backend=Mock())
|
||||||
runner.clear_cache = Mock()
|
runner.clear_cache = Mock()
|
||||||
runner.initialize_kv_cache = Mock()
|
runner.initialize_kv_cache = Mock()
|
||||||
runner.capture_model = Mock()
|
runner.capture_model = Mock()
|
||||||
@@ -523,7 +523,7 @@ class TestSleepWakeupBehavior(unittest.TestCase):
|
|||||||
|
|
||||||
runner.sleep("weight,kv_cache")
|
runner.sleep("weight,kv_cache")
|
||||||
|
|
||||||
runner.model.clear_grpah_opt_backend.assert_called_once()
|
runner.model.clear_graph_opt_backend.assert_called_once()
|
||||||
runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once()
|
runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once()
|
||||||
runner.dynamic_weight_manager.clear_model_weight.assert_called_once()
|
runner.dynamic_weight_manager.clear_model_weight.assert_called_once()
|
||||||
runner.dynamic_weight_manager.clear_communication_group.assert_called_once()
|
runner.dynamic_weight_manager.clear_communication_group.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user