diff --git a/custom_ops/gpu_ops/get_attn_mask_q.cu b/custom_ops/gpu_ops/get_attn_mask_q.cu index 4ee814178b..a485d04f6b 100644 --- a/custom_ops/gpu_ops/get_attn_mask_q.cu +++ b/custom_ops/gpu_ops/get_attn_mask_q.cu @@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel( const int max_batch_size) { constexpr int VecSize = 4; const uint32_t tid = threadIdx.x, bid = blockIdx.x; - int startend_row_vec[4]; + int startend_row_vec[2]; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif @@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel( const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start; startend_row_vec[0] = this_batch_q_end; - startend_row_vec[1] = cu_seqlens_q[max_batch_size]; - startend_row_vec[2] = 0; - startend_row_vec[3] = this_batch_q_end; + // startend_row_vec[1] = cu_seqlens_q[max_batch_size]; + // startend_row_vec[2] = 0; + startend_row_vec[1] = this_batch_q_end; for (int this_batch_q_idx = this_batch_q_start; this_batch_q_idx < this_batch_q_end; ++this_batch_q_idx) { @@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel( : this_batch_q_idx - this_batch_q_start + kv_len - (this_batch_q_len); if (cache_k_idx <= append_mask_k_end) { - startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx); + startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx); // 可提前跳出循环 break; } } - reinterpret_cast(startend_row_indices_ptr + - cu_seqlens_k_idx * 4)[0] = - reinterpret_cast(startend_row_vec)[0]; + reinterpret_cast(startend_row_indices_ptr + + cu_seqlens_k_idx * 2)[0] = + reinterpret_cast(startend_row_vec)[0]; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); @@ -82,7 +82,7 @@ std::vector get_attn_mask_q( const paddle::optional& attn_mask_kv, const int kv_token_num) { paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor( - {1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place()); + {1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place()); const int max_batch_size = cu_seqlens_k.dims()[0] - 1; constexpr int block_size = 512; int grid_size = div_up(kv_token_num, block_size); @@ -123,7 +123,7 @@ std::vector> GetAttnMaskQInferShape( const std::vector& cu_seqlens_k_shape, const paddle::optional>& attn_mask_kv_shape, const int kv_token_num) { - return {{1, 1, kv_token_num, 4}}; + return {{1, 1, kv_token_num, 2}}; } PD_BUILD_STATIC_OP(get_attn_mask_q) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu index 18aa5d53d2..e620e914a2 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu @@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( int64_t* next_tokens, // [bs, tokens_per_step] const int* max_think_lens, // [bs] int* max_reply_lens, // [bs] - int64_t* step_idx, // [bs] + const int64_t* step_idx, // [bs] const int64_t* eos_token_ids, // [eos_len] int* limit_status, // [bs] int* accept_num, // [bs] @@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( int new_accept_num = original_accept_num; // 本 step 的 token offset 对应的绝对 step - const int64_t current_base_step = step_idx[bid] - original_accept_num + 1; + const int64_t current_base_step = step_idx[bid] + 1; for (int token_offset = 0; token_offset < original_accept_num; token_offset++) { @@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel( // inject_token_ids[0]) if (status == 0 && (current_step - 1) == - max_think_len) { // current_step - 1 是因为 speculate_verify 里 - // step_idx + 1 了 + max_think_len) { // current_step - 1 : 已输出 current_step-1 + // 个thinking token status = (inject_len > 0) ? 1 : done_status; } } else if (max_think_len == 0) { @@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel( } } - // 更新 step_idx / accept_num(被截断的 token 需要回退 - // step_idx) - const int discarded_tokens = original_accept_num - new_accept_num; - if (discarded_tokens > 0) { - step_idx[bid] -= discarded_tokens; - } - accept_num[bid] = new_accept_num; limit_status[bid] = status; max_reply_lens[bid] = max_reply_len; @@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength( const_cast(next_tokens.data()), max_think_lens.data(), const_cast(max_reply_lens.data()), - const_cast(step_idx.data()), + step_idx.data(), eos_token_ids.data(), const_cast(limit_status.data()), const_cast(accept_num.data()), diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu index ee364884e9..c6379387ef 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu @@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, const int64_t step_idx_now = step_idx[bid]; const int64_t min_token_limit = min_tokens[bid]; - const bool can_stop = (step_idx_now >= min_token_limit); + const bool can_stop = (step_idx_now + accept_num >= min_token_limit); if (!can_stop) return; if (!stop_flags[bid]) { - int accept_idx = 0; + /* + accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based) + accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾 + (pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。 + 为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1] + (当前轮最后一个 token),该 token 延迟到下一轮匹配。 + 循环范围:accept_num > 0 时为 [-1, accept_num-2]; + accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。 + */ + int accept_idx = -1; bool is_end = false; - // 遍历起始位置 - for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) { + + // 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾 + // 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens + // 中的匹配。两者共享同一套从后向前匹配逻辑。 + int loop_end = (accept_num > 0) ? accept_num - 2 : -1; + for (; accept_idx <= loop_end && !is_end; accept_idx++) { if (step_idx_now + accept_idx + 1 < stop_seq_len) { #ifdef DEBUG_SPEC_STOP_SEQS printf("num %d < stop_seq_len %d\n", - step_idx_now - accept_num + accept_idx + 1, + step_idx_now + accept_idx + 1, stop_seq_len); #endif continue; } - // 遍历一个 stop_seqs + // 从后向前匹配 stop_seq 的每个 token for (int i = stop_seq_len - 1; i >= 0; --i) { int64_t cur_token_idx = -1; - // 通过当前值判断 token 是在 pre_ids 还是 accept_token 里 - if (stop_seq_len - 1 - i < accept_idx) { + int offset = stop_seq_len - 1 - i; + int accept_tokens_idx = accept_idx - offset; + + if (accept_tokens_idx >= 0) { #ifdef DEBUG_SPEC_STOP_SEQS printf( "AcceptTokens bid:%d. tid:%d, accept_idx:%d, " - "accept_token_idx: " - "%d\n", + "accept_token_idx: %d\n", bid, tid, accept_idx, - accept_idx - (stop_seq_len - 1 - i) - 1); + accept_tokens_idx); #endif - cur_token_idx = - accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]; + cur_token_idx = accept_tokens_now[accept_tokens_idx]; } else { + int pre_ids_idx = step_idx_now + accept_tokens_idx; #ifdef DEBUG_SPEC_STOP_SEQS printf( "PreIds bid:%d. tid:%d, step_idx_now:%ld. " - "accept_idx:%d. " - "pre_id_idx: %ld\n", + "accept_idx:%d. pre_id_idx: %d\n", bid, tid, step_idx_now, accept_idx, - step_idx_now - accept_num + accept_idx - - (stop_seq_len - 1 - i)); + pre_ids_idx); #endif - int pre_ids_idx = - step_idx_now + accept_idx - (stop_seq_len - 1 - i); - // EC3 - // 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23, - // 导致异常结束 - if (pre_ids_idx <= 0) { - break; - } + if (pre_ids_idx < 0) break; cur_token_idx = pre_ids_now[pre_ids_idx]; } #ifdef DEBUG_SPEC_STOP_SEQS @@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, } if (is_end) { #ifdef DEBUG_SPEC_STOP_SEQS - printf("bid:%d end with accept_idx %d", bid, accept_idx); + printf("bid:%d end with accept_idx %d\n", bid, accept_idx); #endif - - accept_nums[bid] = accept_idx; - accept_tokens_now[accept_idx - 1] = end_ids[0]; - // stop_flags[bid] = true; + // accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置 + accept_nums[bid] = accept_idx + 1; + accept_tokens_now[accept_idx] = end_ids[0]; } } } diff --git a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu index 94f71d6fd0..f7ab5daece 100644 --- a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu +++ b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu @@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder, int64_t *token_ids_all_now = &token_ids_all[batch_id * max_model_len + prompt_len]; int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens]; - int64_t base = cur_step_idx - output_len + 1; + int64_t base = cur_step_idx - output_len; for (int i = 0; i < output_len; i++) { token_ids_all_now[base + i] = output_ids[i]; } diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md index 692ad8cd02..e54ec8f879 100644 --- a/docs/usage/environment_variables.md +++ b/docs/usage/environment_variables.md @@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # Whether to enable the decode caches requests for preallocating resource "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # Batched token timeout in EP + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), + # Max pre-fetch requests number in PD "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md index 0a4cfd389d..ab625bd4d2 100644 --- a/docs/zh/usage/environment_variables.md +++ b/docs/zh/usage/environment_variables.md @@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # 是否启用 decode 缓存请求以预分配资源 "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # EP 中批处理 token 的超时时间 + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), + # PD 中最大预取请求数量 "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), diff --git a/examples/splitwise/start_v0_tp1.sh b/examples/splitwise/start_v0_tp1.sh deleted file mode 100644 index 40c2030113..0000000000 --- a/examples/splitwise/start_v0_tp1.sh +++ /dev/null @@ -1,113 +0,0 @@ -#!/bin/bash -set -e - -# Test splitwise deployment -# There are two methods for splitwise deployment: -# v0: using splitwise_scheduler or dp_scheduler (deprecated) -# v1: using local_scheduler + router - -# prepare environment -export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" -export FD_DEBUG=1 -export ENABLE_V1_KVCACHE_SCHEDULER=1 -export KVCACHE_GDRCOPY_FLUSH_ENABLE=1 - -SCRIPT_PATH=$(readlink -f "$0") -SCRIPT_DIR=$(dirname "$SCRIPT_PATH") -export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu) -echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}" -if [ -z "${KVCACHE_RDMA_NICS}" ]; then - echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh" - exit 1 -fi - -unset http_proxy && unset https_proxy -source ${SCRIPT_DIR}/utils.sh - -P_PORT=52400 -D_PORT=52500 -REDIS_PORT="${REDIS_PORT:-6379}" -LOG_DATE=$(date +%Y%m%d_%H%M%S) - -ports=( - $P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5)) - $D_PORT $((D_PORT + 1)) $((D_PORT + 2)) $((D_PORT + 3)) $((D_PORT + 4)) $((D_PORT + 5)) - $REDIS_PORT -) -check_ports "${ports[@]}" || { - echo "❌ Some ports are in use. Please release them." - exit 1 -} - -# start redis -if ! redis-cli -p ${REDIS_PORT} ping &>/dev/null; then - echo "Redis is not running. Starting redis-server..." - redis-server --daemonize yes --port ${REDIS_PORT} - sleep 1 -else - echo "Redis is already running." -fi -sleep 1 - -# start prefill -export CUDA_VISIBLE_DEVICES=0 -export FD_LOG_DIR="log/$LOG_DATE/prefill" -rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR} - -nohup python -m fastdeploy.entrypoints.openai.api_server \ - --model ${MODEL_NAME} \ - --port ${P_PORT} \ - --metrics-port $((P_PORT + 1)) \ - --engine-worker-queue-port $((P_PORT + 2)) \ - --cache-queue-port $((P_PORT + 3)) \ - --max-model-len 32768 \ - --num-gpu-blocks-override 1000 \ - --splitwise-role "prefill" \ - --cache-transfer-protocol "rdma" \ - --rdma-comm-ports $((P_PORT + 4)) \ - --pd-comm-port $((P_PORT + 5)) \ - --scheduler-name "splitwise" \ - --scheduler-host "127.0.0.1" \ - --scheduler-port ${REDIS_PORT} \ - --scheduler-ttl 9000 \ - 2>&1 >${FD_LOG_DIR}/nohup & - -wait_for_health ${P_PORT} - -# start decode -export CUDA_VISIBLE_DEVICES=1 -export FD_LOG_DIR="log/$LOG_DATE/decode" -rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR} - -nohup python -m fastdeploy.entrypoints.openai.api_server \ - --model ${MODEL_NAME} \ - --port ${D_PORT} \ - --metrics-port $((D_PORT + 1)) \ - --engine-worker-queue-port $((D_PORT + 2)) \ - --cache-queue-port $((D_PORT + 3)) \ - --max-model-len 32768 \ - --splitwise-role "decode" \ - --cache-transfer-protocol "rdma" \ - --rdma-comm-ports $((D_PORT + 4)) \ - --pd-comm-port $((D_PORT + 5)) \ - --scheduler-name "splitwise" \ - --scheduler-host "127.0.0.1" \ - --scheduler-port ${REDIS_PORT} \ - --scheduler-ttl 9000 \ - 2>&1 >${FD_LOG_DIR}/nohup & - -wait_for_health ${D_PORT} - - -# send request -sleep 10 # make sure server is registered to router -echo "send request..." -curl -X POST "http://0.0.0.0:${D_PORT}/v1/chat/completions" \ --H "Content-Type: application/json" \ --d '{ - "messages": [ - {"role": "user", "content": "hello"} - ], - "max_tokens": 20, - "stream": false -}' diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 6e7001bc18..1b37db9611 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -2009,13 +2009,13 @@ class FDConfig: and self.router_config and self.router_config.router ): - # For RL scenario: version.yaml will be required for models in future releases. + # For RL scenario, version.yaml is required for models # Temporarily enforce use router to be enabled. self.model_config.read_model_version() self.read_from_config() self.postprocess() - self.init_cache_info() + self.init_pd_info() if test_mode: return self.check() @@ -2348,18 +2348,17 @@ class FDConfig: logger.info("{:<20}:{:<6}{}".format(k, "", v)) logger.info("=============================================================") - def init_cache_info(self): + def init_pd_info(self): """ - initialize cache info + initialize info for pd deployment """ - # TODO: group the splitiwse params # There are two methods for splitwise deployment: # 1. v0 splitwise_scheduler or dp_scheduler - # 2. v1 local_scheduler + router + # 2. v1 local_scheduler + router (optional) self.splitwise_version = None if self.scheduler_config.name in ("splitwise", "dp"): self.splitwise_version = "v0" - elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router: + elif self.scheduler_config.name == "local": self.splitwise_version = "v1" # the information for registering this server to router or splitwise_scheduler diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 848926f963..d350350f85 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -592,10 +592,15 @@ class EngineArgs: raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1") if self.splitwise_role != "mixed": - if self.scheduler_name == "local" and self.router is None: + if self.scheduler_name == "splitwise": raise ValueError( - f"When using {self.splitwise_role} role and the {self.scheduler_name} " - f"scheduler, please provide --router argument." + "Setting scheduler_name as splitwise is not supported in pd deployment, " + "please use router as scheduler." + ) + if self.scheduler_name == "local" and self.router is None: + console_logger.warning( + f"Running {self.splitwise_role} role with {self.scheduler_name} " + f"scheduler without --router. Router registration and request routing will be disabled." ) if not ( diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 9817b7438e..9331bc3f2d 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -367,15 +367,6 @@ class EngineService: create=True, ) - engine_forward_signal_data = np.zeros([1], dtype=np.int32) - self.engine_forward_signal = IPCSignal( - name="engine_forward_signal", - array=engine_forward_signal_data, - dtype=np.int32, - suffix=current_suffix, - create=True, - ) - # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间 worker_healthy_live_recorded_time_array = np.zeros( shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32 @@ -1037,29 +1028,26 @@ class EngineService: with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) try: - if not is_fetching: - # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. - try: + if self.engine_worker_queue.exist_tasks(): + time.sleep(0.001) + continue + if self.cfg.scheduler_config.splitwise_role != "mixed": + if not is_fetching: is_fetching = True get_request_pool.submit(_fetch_request) - except RuntimeError as e: - if "shutdown" in str(e): - self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") - break - else: - raise - if self.cfg.scheduler_config.splitwise_role != "mixed": - # Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished. - # Once the forward pass finishes, these accumulated requests can be scheduled in larger, - # more efficient batches. - if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0: - time.sleep(0.001) - continue + else: - # In mixed, todo: optimze cache swap, to decouple swap from scheduler - if self.engine_worker_queue.exist_tasks(): - time.sleep(0.001) - continue + if len(self.resource_manager.waiting) == 0 and (not is_fetching): + # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. + try: + is_fetching = True + get_request_pool.submit(_fetch_request) + except RuntimeError as e: + if "shutdown" in str(e): + self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") + break + else: + raise if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() @@ -1120,13 +1108,6 @@ class EngineService: elif not task.has_been_preempted_before: task.metrics.inference_start_time = time.time() self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) - else: - # When there are no actual tasks to schedule, send an empty task batch to EP workers. - # This helps EP workers barrier for syncing tasks not hang. - if self.cfg.parallel_config.enable_expert_parallel: - self.engine_worker_queue.put_tasks( - ([], self.resource_manager.real_bsz) - ) # Empty (as idle tasks for ep) # 4. Response error tasks if error_tasks: diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 81fe93e52a..5958b3d9bd 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -109,7 +109,7 @@ class ExpertService: if envs.FD_ENABLE_RETURN_TEXT: self.engine.create_data_processor() if self.cfg.scheduler_config.name == "dp": - self.cfg.init_cache_info() + self.cfg.init_pd_info() self.engine.scheduler.start(local_data_parallel_id) if ipc_signal_suffix is not None: @@ -122,7 +122,7 @@ class ExpertService: self.llm_logger.info(f"start expert service {local_data_parallel_id}") if self.cfg.scheduler_config.name == "splitwise": - self.cfg.init_cache_info() + self.cfg.init_pd_info() role = self.cfg.scheduler_config.splitwise_role host_ip = self.cfg.host_ip self.engine.scheduler.start(role, host_ip, self.cfg.register_info) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 45ec18aa1c..ae0e0c798b 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -927,6 +927,7 @@ class ResourceManagerV1(ResourceManager): if ( self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode" + and self.config.scheduler_config.splitwise_role != "prefill" ): self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens @@ -1374,6 +1375,11 @@ class ResourceManagerV1(ResourceManager): self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position + + self.cache_manager.update_cache_blocks( + request, self.config.cache_config.block_size, request.need_prefill_tokens + ) + return True else: self._free_blocks(request) diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py index f4556a3679..7435dbce49 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py @@ -111,7 +111,7 @@ class ErnieX1ToolParser(ToolParser): ) ) return ExtractedToolCallInformation( - tools_called=True, + tools_called=len(tool_calls) > 0, tool_calls=tool_calls, ) except Exception: @@ -182,11 +182,13 @@ class ErnieX1ToolParser(ToolParser): logger.debug("attempting to close tool call, but no tool call") return None diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") - if diff: - if '"}' not in delta_text: + if diff is not None: + if "}" not in delta_text: + return None + end_loc = delta_text.rindex("}") + diff = delta_text[:end_loc] + if not diff: return None - end_loc = delta_text.rindex('"}') - diff = delta_text[:end_loc] + '"}' logger.debug( "Finishing tool and found diff that had not " "been streamed yet: %s", diff, @@ -248,15 +250,15 @@ class ErnieX1ToolParser(ToolParser): prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments") cur_arguments = current_tool_call.get("arguments") - if not cur_arguments and not prev_arguments: + if cur_arguments is None and prev_arguments is None: logger.debug("Skipping text %s - no arguments", delta_text) delta = None - elif not cur_arguments and prev_arguments: + elif cur_arguments is None and prev_arguments is not None: logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.") delta = None - elif cur_arguments and not prev_arguments: + elif cur_arguments is not None and prev_arguments is None: function_name = current_tool_call.get("name") match = re.search( r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)', @@ -265,6 +267,19 @@ class ErnieX1ToolParser(ToolParser): ) if match: cur_arguments_json = match.group(1) + # When tool_call_portion is complete JSON, the regex + # (.*) over-captures the outer closing brace of the + # tool call object. Strip it from both + # cur_arguments_json and delta_text, consistent with + # the both-have-arguments branch handling. + try: + json.loads(tool_call_portion) + if cur_arguments_json.endswith("}"): + cur_arguments_json = cur_arguments_json[:-1] + if delta_text.rstrip().endswith("}"): + delta_text = delta_text.rstrip()[:-1] + except Exception: + pass else: cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) @@ -287,7 +302,7 @@ class ErnieX1ToolParser(ToolParser): ) self.streamed_args_for_tool[self.current_tool_id] += arguments_delta - elif cur_arguments and prev_arguments: + elif cur_arguments is not None and prev_arguments is not None: try: json.loads(tool_call_portion) is_complete_json = True diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index fa2472e3ef..ff0867fee2 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -145,6 +145,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), # Whether to enable the decode caches requests for preallocating resource "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # Batched token timeout in EP + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), # Max pre-fetch requests number in PD "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), # Enable or disable model caching. @@ -210,17 +212,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""), # Whether to enable low latency in mixed scenario "FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))), - # Whether to use phi FP8 quantization,if 1,use paddle default. - "FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))), - # Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc, - # intended for training alignment. Defaults to 0 (disabled). - "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))), - # Whether to use phi MOE permute,if 1,use paddle op. - "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))), - # Whether to use phi rms_norm,if 1,use paddle op. - "FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))), - # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function - "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), # Reserve output blocks for decoding requests when schedule new prefill requests "FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int( os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16") @@ -262,8 +253,22 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), + # train-infer consistency, used in RL # Whether to align RoPE and moe gate precision with training "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), + # Whether to use phi FP8 quantization,if 1,use paddle default. + "FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))), + # Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc, + # intended for training alignment. Defaults to 0 (disabled). + "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))), + # Whether to use phi MOE permute,if 1,use paddle op. + "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))), + # Whether to use phi rms_norm,if 1,use paddle op. + "FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))), + # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function + "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), + # Whether to enable FP8 quantization with pow2scale. + "FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))), } diff --git a/fastdeploy/model_executor/graph_optimization/decorator.py b/fastdeploy/model_executor/graph_optimization/decorator.py index 562164aae1..05ec79a495 100644 --- a/fastdeploy/model_executor/graph_optimization/decorator.py +++ b/fastdeploy/model_executor/graph_optimization/decorator.py @@ -92,7 +92,7 @@ class GraphOptWrapper: def __call__(self, **kwargs): return self.graph_opt_backend(**kwargs) - def clear_grpah_opt_backend(self, fd_config): + def clear_graph_opt_backend(self, fd_config): """ """ # TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights assert ( diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index bcffcd0bac..2549f9f5d8 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -95,7 +95,7 @@ def init_flash_attn_version(): logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.") if FLASH_ATTN_VERSION is None: - if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()): + if sm_version == 90 and 90 in paddle.version.cuda_archs(): FLASH_ATTN_VERSION = 3 logger.info("The current platform supports Flash Attention V3.") else: diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 53247e2912..a16e5ccbe9 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -188,7 +188,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op( else: ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( ffn_out, - using_pow2_scale=not disable_ue8m0_cast, + using_pow2_scale=not disable_ue8m0_cast or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, using_ue8m0_scale=not disable_ue8m0_cast, ) ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]] @@ -355,7 +355,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): else: x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) @@ -581,7 +581,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): else: ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( ffn_out, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 + or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]] @@ -773,7 +774,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): else: recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index d1db43a324..65d1d23b9b 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -1247,7 +1247,7 @@ def python_op_fused_moe_kernel_paddle( x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False) else: x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, using_pow2_scale=False, output_scale_transpose=False + x, using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=False ) x_scale = x_scale[: x.shape[0]] @@ -1305,7 +1305,9 @@ def python_op_fused_moe_kernel_paddle( ) else: x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False + intermediate_cache2, + using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, + output_scale_transpose=False, ) x_scale = x_scale[: x_q.shape[0]] diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index a86170e072..ae37ca4596 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -343,7 +343,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase): else: x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=True, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 4e75ba1d90..aa3f3af346 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -1306,9 +1306,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM): ) return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class DeepSeekV3PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 4cc4306de5..bf8b3d9348 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -701,9 +701,9 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM): return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config) + self.ernie.clear_graph_opt_backend(fd_config=self.fd_config) @ModelRegistry.register_model_class( diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index f4d70108e4..b6fa97ab0b 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -829,9 +829,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM): return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config) + self.ernie.clear_graph_opt_backend(fd_config=self.fd_config) class Ernie4_5_VLPretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index 7840107a04..fba36185a4 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -563,9 +563,9 @@ class Glm4MoeForCausalLM(ModelForCasualLM): return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Glm4MoePretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/glm4_mtp.py b/fastdeploy/model_executor/models/glm4_mtp.py index c28023202d..c700ea442c 100644 --- a/fastdeploy/model_executor/models/glm4_mtp.py +++ b/fastdeploy/model_executor/models/glm4_mtp.py @@ -369,3 +369,7 @@ class Glm4MTPForCausalLM(ModelForCasualLM): ) return hidden_states + + def clear_graph_opt_backend(self): + """Clear graph optimization backend, the captured cuda graph will be cleaned""" + self.model.clear_graph_opt_backend(fd_config=self.fd_config) diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index 1bca09265e..1d0ce349bf 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -417,9 +417,9 @@ class Qwen2ForCausalLM(ModelForCasualLM): return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config) + self.qwen2.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen2PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index ebbf4f5aed..b0bcf9d588 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -341,9 +341,9 @@ class Qwen3ForCausalLM(ModelForCasualLM): return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py index a4d3f1579c..3f2a690424 100644 --- a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py +++ b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py @@ -382,9 +382,9 @@ class Qwen3VLForConditionalGeneration(ModelForCasualLM): return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3VLPretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index 74ca37ab69..95adc7ad0e 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -453,9 +453,9 @@ class Qwen3MoeForCausalLM(ModelForCasualLM): return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3MoePretrainedModel(PretrainedModel): diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index cade135508..0caefd9ada 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -68,6 +68,7 @@ class RolloutModelConfig: routing_replay_config: str = None, load_choices: str = "default_v1", lm_head_fp32: bool = False, + moe_gate_fp32: bool = True, ): # Required parameters self.model = model_name_or_path @@ -121,6 +122,7 @@ class RolloutModelConfig: self.routing_replay_config = routing_replay_config self.load_choices = load_choices self.lm_head_fp32 = lm_head_fp32 + self.moe_gate_fp32 = moe_gate_fp32 def __str__(self): return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items()) diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py index 2339a077c9..f5b03eba30 100644 --- a/fastdeploy/scheduler/dp_scheduler.py +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -23,7 +23,7 @@ from typing import Dict, List, Optional from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledResponse from fastdeploy.scheduler.local_scheduler import LocalScheduler -from fastdeploy.utils import get_logger +from fastdeploy.utils import envs, get_logger class DPLocalScheduler(LocalScheduler): @@ -131,19 +131,52 @@ class DPLocalScheduler(LocalScheduler): Returns: List of Request objects ready for processing """ - # DP scheduler is used in V1, there is no need to manage request fetching in the scheduler, resource_manager_v1 will do that. + if available_blocks <= reserved_output_blocks or batch < 1: + self.scheduler_logger.debug( + f"Scheduler's resource are insufficient: available_blocks={available_blocks} " + f"reserved_output_blocks={reserved_output_blocks} batch={batch} " + f"max_num_batched_tokens={max_num_batched_tokens}" + ) + return [] + required_total_blocks = 0 + current_prefill_tokens = 0 + start_batch_time = time.time() requests: List[Request] = [] with self.requests_not_empty: - batch_ids = self.requests_not_empty.wait_for( - lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1], - 0.005, - ) - if batch_ids: - for request_id in batch_ids: - request = self.requests[request_id] - requests.append(request.raw) - self.ids_read_cursor += 1 + while True: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], + 0.005, + ) + if batch_ids: + for request_id in batch_ids: + request = self.requests[request_id] + required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks + if required_total_blocks > available_blocks: + break + + requests.append(request.raw) + self.ids_read_cursor += 1 + start_batch_time = time.time() + if current_prefill_tokens > max_num_batched_tokens: + break + if len(requests) >= batch: + break + if ( + (current_prefill_tokens > max_num_batched_tokens) + or (len(requests) >= batch) + or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT) + ): + break + + if batch_ids: + if len(batch_ids) > 0 and len(requests) == 0: + self.scheduler_logger.debug( + f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}" + ) if len(requests) > 0: self.scheduler_logger.info( diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index e64e468b18..5c2f793fdb 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -53,9 +53,6 @@ class InternalAdapter: available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch()) available_block_num = self.engine.resource_manager.available_block_num() - unhandled_request_num = self.engine.scheduler.get_unhandled_request_num() - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - unhandled_request_num = max(unhandled_request_num, len(self.engine.resource_manager.waiting)) server_info = { "splitwise_role": self.cfg.scheduler_config.splitwise_role, "block_size": int(self.cfg.cache_config.block_size), @@ -65,7 +62,7 @@ class InternalAdapter: "available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num), "max_batch_size": int(available_batch_size), "max_input_token_num": self.cfg.model_config.max_model_len, - "unhandled_request_num": unhandled_request_num, + "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(), "available_batch": int(self.engine.resource_manager.available_batch()), } return server_info diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 6218e58687..2bdbdb345b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -27,7 +27,7 @@ import paddle from paddle import nn from paddleformers.utils.log import logger -from fastdeploy.config import FDConfig +from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig from fastdeploy.engine.pooling_params import PoolingParams from fastdeploy.engine.request import ImagePosition, Request, RequestType from fastdeploy.model_executor.graph_optimization.utils import ( @@ -2110,6 +2110,12 @@ class GPUModelRunner(ModelRunnerBase): self._cached_sampler_output = sampler_output self._cached_post_process_event = post_process_event else: + if ( + self.fd_config.speculative_config.method == SpecMethod.MTP + and hasattr(self.proposer.model, "empty_input_forward") + and self.parallel_config.use_ep + ): + self._execute_empty_mtp_input(self.forward_meta) self._cached_model_output_data = None self._cached_sampler_output = None self._cached_post_process_event = None @@ -2403,6 +2409,16 @@ class GPUModelRunner(ModelRunnerBase): # 5.1. Async cpy post_process_event = paddle.device.cuda.create_event() + if envs.FD_USE_GET_SAVE_OUTPUT_V1: + # If one query is preempted, there is no sampled token for it, we use token_id PREEMPTED_TOKEN_ID to signal server, abort is finished. + paddle.assign( + paddle.where( + self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1, + PREEMPTED_TOKEN_ID, + sampler_output.sampled_token_ids, + ), + sampler_output.sampled_token_ids, + ) # if not self.speculative_decoding: self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False) if self.speculative_decoding: @@ -2676,13 +2692,13 @@ class GPUModelRunner(ModelRunnerBase): """Dynamic model loader use to clear parameters use for RL""" # Clear CUDAGraph if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) if self.spec_method == SpecMethod.MTP: - self.proposer.model.clear_grpah_opt_backend() + self.proposer.model.clear_graph_opt_backend() self.proposer.clear_mtp_cache() self.clear_cache() paddle.device.cuda.empty_cache() @@ -2736,7 +2752,7 @@ class GPUModelRunner(ModelRunnerBase): logger.info("GPU model runner's weight is already sleeping, no need to sleep again!") return if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() if self.fd_config.parallel_config.enable_expert_parallel: self.dynamic_weight_manager.clear_deepep_buffer() self.dynamic_weight_manager.clear_model_weight() diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 28c769e116..d72538ba8d 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -2511,7 +2511,7 @@ class MetaxModelRunner(ModelRunnerBase): """Dynamic model loader use to clear parameters use for RL""" # Clear CUDAGraph if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index a81a0ee7d5..bad84cc96f 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -457,6 +457,9 @@ class PaddleDisWorkerProc: # TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one self.model_weights_signal = np.zeros([1], dtype=np.int32) while True: + # run eplb + self._run_eplb(tp_rank) + if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: self.model_weights_signal[0] = int(self.model_weights_status.value[0]) if self.ranks > 1: @@ -534,7 +537,7 @@ class PaddleDisWorkerProc: if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1: logger.debug(f"Rank: {self.local_rank} Detected new requests.") - self.engine_forward_signal.value[0] = 1 + tasks, read_finish = self.task_queue.get_tasks() # Only one of all tp_size client will get read_finish == True. if read_finish: @@ -543,39 +546,25 @@ class PaddleDisWorkerProc: self.task_queue.read_finish_flag.set(0) else: self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY - # In EP parallel(corresponing to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival. - # Only EP + DP prefill should barrier for data arrival. - # In mixed mode and decoder in D, we should not barrier to influence decoding. - if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill": - paddle.distributed.barrier(self.parallel_config.ep_group) req_dicts, control_reqs = [], [] - assert ( - len(tasks) > 0 - ), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}" - # In EP + DP prefill, empty task ([]) is delived in worker to barrier. For empty task, just skip and continue. - # tasks[0] contains two part, ([req1, ...] ,real_bsz) - # tasks[0][0] is [req1, ...] - # if empty batch is delived, eval(tasks[0][0]) should be False ([]), - # if batch with requests is delived, eval(tasks[0][0]) should be True, then to be processed as below. - if tasks[0][0]: - for req_dict, bsz in tasks: - if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest): - control_reqs.append(req_dict[0]) - else: - max_occupied_batch_index = int(bsz) - req_dicts.extend(req_dict) + for req_dict, bsz in tasks: + if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest): + control_reqs.append(req_dict[0]) + else: + max_occupied_batch_index = int(bsz) + req_dicts.extend(req_dict) - # todo: run control request async - if len(control_reqs) > 0: - logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.") - for control_req in control_reqs: - if self.parallel_config.use_ep: - self.cached_control_reqs.append(control_req) - logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}") - else: - self.run_control_method(control_req) - self._tp_barrier_wait() if tp_size > 1 else None + # todo: run control request async + if len(control_reqs) > 0: + logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.") + for control_req in control_reqs: + if self.parallel_config.use_ep: + self.cached_control_reqs.append(control_req) + logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}") + else: + self.run_control_method(control_req) + self._tp_barrier_wait() if tp_size > 1 else None if len(req_dicts) > 0: # Count prefill requests in current batch @@ -591,12 +580,6 @@ class PaddleDisWorkerProc: # Process prefill inputs self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index) - else: - if self.scheduler_config.splitwise_role == "prefill": - if tp_size > 1: - # Synchronize the signal for other workers - self._tp_barrier_wait() - continue # Let the ep group run control method synchronically if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep: @@ -611,7 +594,6 @@ class PaddleDisWorkerProc: and not self.worker.model_runner.not_need_stop() ): self._tp_barrier_wait() if tp_size > 1 else None - self.engine_forward_signal.value[0] = 0 time.sleep(0.001) continue @@ -634,9 +616,6 @@ class PaddleDisWorkerProc: if not envs.ENABLE_V1_KVCACHE_SCHEDULER: self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill() logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s") - # run eplb - self._run_eplb(tp_rank) - self.engine_forward_signal.value[0] = 0 if ( not self.parallel_config.use_ep diff --git a/requirements.txt b/requirements.txt index e662f07e97..2edef89b85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,5 +47,5 @@ aistudio_sdk p2pstore py-cpuinfo flashinfer-python-paddle -flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl +flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl transformers>=4.55.1,<5.0.0 diff --git a/tests/ci_use/metrics/test_metrics.py b/tests/ci_use/metrics/test_metrics.py index 0d5353780f..a54504c29b 100644 --- a/tests/ci_use/metrics/test_metrics.py +++ b/tests/ci_use/metrics/test_metrics.py @@ -214,29 +214,28 @@ def test_metrics_with_clear_and_reset(): """ Test the metrics monitoring endpoint. """ - pass # not stable, uncomment after bug fix - # metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" - # async_concurrency(n=10) + async_concurrency(n=10) - # time.sleep(0.3) + time.sleep(0.3) # ===== clear_load_weight ===== - # clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight" - # print("Calling clear_load_weight...") - # r = requests.get(clear_url, timeout=30) - # assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}" + clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight" + print("Calling clear_load_weight...") + r = requests.get(clear_url, timeout=30) + assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}" - # metrics = get_metrics_dict(metrics_url) - # running = metrics["fastdeploy:num_requests_running"] - # waiting = metrics["fastdeploy:num_requests_waiting"] + metrics = get_metrics_dict(metrics_url) + running = metrics["fastdeploy:num_requests_running"] + waiting = metrics["fastdeploy:num_requests_waiting"] - # print( - # "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):", - # running, - # "waiting:", - # waiting, - # ) + print( + "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):", + running, + "waiting:", + waiting, + ) # assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight" diff --git a/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py b/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py new file mode 100644 index 0000000000..efe702240e --- /dev/null +++ b/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py @@ -0,0 +1,455 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test splitwise deployment WITHOUT Router: +# use local_scheduler, manually construct disaggregate_info, +# send requests to both Prefill and Decode concurrently. +# ENABLE_V1_KVCACHE_SCHEDULER=1, use rdma to transfer cache. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time +import uuid + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + check_service_health, + clean, +) + +# Ports for PD disaggregation (no router port needed) +FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433)) +FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623)) + +# Prefill uses base ports, Decode uses base+1 +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_CONNECTOR_PORT, + FD_RDMA_PORT, + FD_API_PORT + 1, + FD_ENGINE_QUEUE_PORT + 1, + FD_METRICS_PORT + 1, + FD_CACHE_QUEUE_PORT + 1, + FD_CONNECTOR_PORT + 1, + FD_RDMA_PORT + 1, +] + + +def _build_disaggregate_info() -> dict: + """Build disaggregate_info manually, replicating Router's handle_splitwise_request logic.""" + host_ip = os.getenv("FD_HOST_IP", "127.0.0.1") + return { + "prefill_ip": host_ip, + "decode_ip": host_ip, + "prefill_connector_port": FD_CONNECTOR_PORT, + "decode_connector_port": FD_CONNECTOR_PORT + 1, + "decode_device_ids": ["1"], + "decode_rdma_ports": [FD_RDMA_PORT + 1], + "transfer_protocol": "rdma", + "decode_tp_size": 1, + } + + +def _send_pd_request(payload: dict, timeout: int = 120): + """ + Send request to both Prefill and Decode concurrently, + replicate Router's fan-out forwarding behavior. + Returns the Decode response (same as Router's return_result_url_index=-1). + """ + disaggregate_info = _build_disaggregate_info() + + # Inject disaggregate_info and request_id (same as Router) + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions" + + headers = {"Content-Type": "application/json"} + + # For streaming, use requests with stream=True for decode response + if payload.get("stream", False): + # Send to both concurrently (same as Router's fan-out), stream from decode + import concurrent.futures + + def _post_stream(url): + return requests.post(url, headers=headers, json=payload, timeout=timeout, stream=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + prefill_future = executor.submit(_post_stream, prefill_url) + decode_future = executor.submit(_post_stream, decode_url) + # Return decode streaming response immediately + decode_resp = decode_future.result() + # Consume prefill response in background (don't block) + try: + prefill_future.result(timeout=timeout) + except Exception: + pass + return decode_resp + else: + # Non-streaming: send to both, return decode response + import concurrent.futures + + def _post(url): + return requests.post(url, headers=headers, json=payload, timeout=timeout) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + prefill_future = executor.submit(_post, prefill_url) + decode_future = executor.submit(_post, decode_url) + # Wait for both, return decode response + decode_resp = decode_future.result() + # Also check prefill didn't error (but don't block on it) + try: + prefill_future.result(timeout=5) + except Exception: + pass + return decode_resp + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts Prefill and Decode instances WITHOUT Router + - Waits for both to be healthy + - Tears down after all tests finish + """ + print("Pre-test port cleanup...") + clean(PORTS_TO_CLEAN) + + print("log dir clean") + if os.path.exists("log_prefill") and os.path.isdir("log_prefill"): + shutil.rmtree("log_prefill") + if os.path.exists("log_decode") and os.path.isdir("log_decode"): + shutil.rmtree("log_decode") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") + else: + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") + + base_log_dir = os.getenv("FD_LOG_DIR", "log") + + # Prefill instance + print("start prefill...") + env_prefill = os.environ.copy() + env_prefill["CUDA_VISIBLE_DEVICES"] = "0" + env_prefill["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_prefill") + + prefill_log_path = "prefill.log" + prefill_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "8192", + "--splitwise-role", + "prefill", + "--cache-transfer-protocol", + "rdma", + "--rdma-comm-ports", + str(FD_RDMA_PORT), + "--pd-comm-port", + str(FD_CONNECTOR_PORT), + # No --router flag + ] + + with open(prefill_log_path, "w") as logfile: + process_prefill = subprocess.Popen( + prefill_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, + env=env_prefill, + ) + time.sleep(1) + + # Decode instance + print("start decode...") + env_decode = os.environ.copy() + env_decode["CUDA_VISIBLE_DEVICES"] = "1" + env_decode["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_decode") + + decode_log_path = "decode.log" + decode_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT + 1), + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT + 1), + "--metrics-port", + str(FD_METRICS_PORT + 1), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT + 1), + "--max-model-len", + "8192", + "--splitwise-role", + "decode", + "--cache-transfer-protocol", + "rdma", + "--rdma-comm-ports", + str(FD_RDMA_PORT + 1), + "--pd-comm-port", + str(FD_CONNECTOR_PORT + 1), + # No --router flag + ] + + with open(decode_log_path, "w") as logfile: + process_decode = subprocess.Popen( + decode_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, + env=env_decode, + ) + + # Wait up to 300 seconds for both instances to be healthy + for _ in range(60): + prefill_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT}") + decode_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT + 1}") + if prefill_healthy and decode_healthy: + print("Prefill and decode servers are both online") + break + time.sleep(5) + else: + print("[TIMEOUT] Servers failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError("Prefill or decode server did not start") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + print(f"Prefill server (pid={process_prefill.pid}) terminated") + print(f"Decode server (pid={process_decode.pid}) terminated") + except Exception as e: + print(f"Failed to terminate server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the Decode API endpoint URL (where final responses come from). + """ + return f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions" + + +@pytest.fixture +def headers(): + return {"Content-Type": "application/json"} + + +def get_stream_chunks(response): + """Parse streaming response into chunk list.""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"Parse failed: {e}, line: {line}") + else: + print(f"Request failed, status: {response.status_code}") + print("Response:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """Test streaming chat with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = _send_pd_request(payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """Test non-streaming chat with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = _send_pd_request(payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """Test streaming completion (non-chat) with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + # Send to /v1/completions endpoints + disaggregate_info = _build_disaggregate_info() + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions" + headers = {"Content-Type": "application/json"} + + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120) + decode_future = executor.submit( + requests.post, decode_url, json=payload, headers=headers, timeout=120, stream=True + ) + response = decode_future.result() + + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """Test non-streaming completion (non-chat) with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + # Send to /v1/completions endpoints + disaggregate_info = _build_disaggregate_info() + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions" + headers = {"Content-Type": "application/json"} + + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120) + decode_future = executor.submit(requests.post, decode_url, json=payload, headers=headers, timeout=120) + response = decode_future.result().json() + + usage = response["usage"] + result = response["choices"][0]["text"] + print("Decode Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 551f93babd..8778e2013e 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -1457,9 +1457,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): task.metrics.scheduler_recv_req_time = time.time() eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")])) @@ -1493,9 +1491,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): task.metrics.scheduler_recv_req_time = time.time() eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [])) @@ -1526,9 +1522,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): task.metrics.scheduler_recv_req_time = time.time() eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)])) diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py index 01a68c2380..0dbda0c35e 100644 --- a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py @@ -60,6 +60,50 @@ class TestErnieX1ToolParser(unittest.TestCase): return ErnieX1ToolParser(tokenizer=DummyTokenizer()) + def _simulate_streaming(self, parser, deltas): + """Simulate a multi-step streaming flow. + + Args: + parser: ErnieX1ToolParser instance + deltas: list of delta text strings, each representing one streaming step + + Returns: + list of results from each extract_tool_calls_streaming call + """ + results = [] + previous_text = "" + token_id = 0 + previous_token_ids = [] + + for delta in deltas: + current_text = previous_text + delta + # When delta contains plus more content, use 2 tokens + # so that the parser extracts tool_call_portion (line 163-164) + if "" in delta and delta != "": + n_tokens = 2 + else: + n_tokens = 1 + + delta_token_ids = list(range(token_id + 1, token_id + 1 + n_tokens)) + token_id += n_tokens + current_token_ids = previous_token_ids + delta_token_ids + + result = parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta, + previous_token_ids, + current_token_ids, + delta_token_ids, + self.dummy_request, + ) + results.append(result) + + previous_text = current_text + previous_token_ids = list(current_token_ids) + + return results + # ==================== __init__ tests (lines 60-81) ==================== def test_init_sets_tokens_and_ids(self): @@ -116,6 +160,14 @@ class TestErnieX1ToolParser(unittest.TestCase): self.assertTrue(result.tools_called) self.assertEqual(result.tool_calls[0].function.arguments, "{}") + def test_extract_tool_calls_empty_arguments(self): + """Cover: tool call with explicit empty arguments {}""" + output = '{"name": "fn", "arguments": {}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(result.tool_calls[0].function.name, "fn") + self.assertEqual(result.tool_calls[0].function.arguments, "{}") + def test_extract_tool_calls_nested_arguments(self): """Cover regex with nested braces in arguments""" output = '{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}' @@ -182,38 +234,24 @@ class TestErnieX1ToolParser(unittest.TestCase): def test_streaming_end_token_in_delta(self): """Cover lines 149-156: appears in delta""" parser = self._new_parser() - # First, start a tool call - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "', # start + name + args key + "v", # args value + '"}}', # close with end token in delta + ], ) - # Now stream arguments - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - ', "arguments": {"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Close with end token in delta - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "v"}}', - '"}}', - [1, 10, 20], - [1, 10, 20, 2], - [2], - self.dummy_request, - ) - # Should handle end token - self.assertTrue(result is None or isinstance(result, DeltaMessage)) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args branch, regex extracts '{"k": "v' as arguments_delta + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, '{"k": "v') + # Step 3: end token in delta triggers close handling + # delta before is '"}}', close branch: rindex('}')=2, diff='"}' + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.arguments, '"}') # --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) --- @@ -255,37 +293,29 @@ class TestErnieX1ToolParser(unittest.TestCase): def test_streaming_continue_tool_call_no_name_yet(self): """Cover lines 174-176, 220-222: partial JSON without name yet""" parser = self._new_parser() - # Start tool call - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Continue with partial content, no name parseable yet - result = parser.extract_tool_calls_streaming( - "", - '{"na', - '{"na', - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"na', # partial content, no name yet + ], ) - self.assertIsNone(result) + self.assertIsNone(results[0]) + self.assertIsNone(results[1]) def test_streaming_continue_tool_call_with_name(self): """Cover lines 174-176, 223-235: name becomes available""" parser = self._new_parser() - # Start tool call - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Name appears - result = parser.extract_tool_calls_streaming( - "", - '{"name": "get_weather"', - '{"name": "get_weather"', - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"name": "get_weather"', # name appears + ], ) - self.assertIsNotNone(result) - self.assertEqual(result.tool_calls[0].function.name, "get_weather") + self.assertIsNone(results[0]) + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.name, "get_weather") self.assertTrue(parser.current_tool_name_sent) # --- Lines 236-237: name not sent and function_name is None --- @@ -293,18 +323,14 @@ class TestErnieX1ToolParser(unittest.TestCase): def test_streaming_no_function_name(self): """Cover lines 236-237: parsed JSON has no 'name' field""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Send JSON without name field - result = parser.extract_tool_calls_streaming( - "", - '{"arguments": {"k": "v"}}', - '{"arguments": {"k": "v"}}', - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"arguments": {"k": "v"}}', # JSON without name field + ], ) - self.assertIsNone(result) + self.assertIsNone(results[1]) # --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) --- @@ -333,9 +359,9 @@ class TestErnieX1ToolParser(unittest.TestCase): parser.streamed_args_for_tool = [""] parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] result = parser.extract_tool_calls_streaming( - '{"name":"fn","arguments":{"k":"v"}}', + '{"name":"fn","arguments":{"k":"v"', '{"name":"fn","arguments":{"k":"v"}}', - '"}}', + "}}", [1, 10], [1, 10, 2], [2], @@ -343,9 +369,14 @@ class TestErnieX1ToolParser(unittest.TestCase): ) self.assertIsNotNone(result) self.assertIsNotNone(result.tool_calls) + self.assertEqual(result.tool_calls[0].function.arguments, "}") - def test_streaming_close_with_diff_no_end_marker(self): - """Cover lines 184-185: close with arguments but no '"}' in delta_text""" + def test_streaming_text_after_completed_tool_call(self): + """Cover lines 143-147: text content after a completed tool call. + + When start==end counts, prev_end==cur_end, and end_token not in delta, + the parser treats delta as regular text content. + """ parser = self._new_parser() parser.current_tool_id = 0 parser.current_tool_name_sent = True @@ -353,7 +384,7 @@ class TestErnieX1ToolParser(unittest.TestCase): parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] # Simulate end token in delta but without '"}' pattern # We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta - # so that we enter the elif at 178 + # so that we enter the text-content branch at line 143-147 result = parser.extract_tool_calls_streaming( '{"name":"fn","arguments":{"k":"v"}}', '{"name":"fn","arguments":{"k":"v"}} text', @@ -363,8 +394,9 @@ class TestErnieX1ToolParser(unittest.TestCase): [30], self.dummy_request, ) - # balanced counts, prev_end==cur_end, end not in delta -> returns content (line 147) - self.assertIsInstance(result, DeltaMessage) + # balanced counts, prev_end==cur_end, end not in delta -> returns content (line 149) + self.assertIsNotNone(result) + self.assertEqual(result.content, " text") def test_streaming_close_no_arguments(self): """Cover lines 182-183: close branch where prev arguments is None/empty""" @@ -382,8 +414,126 @@ class TestErnieX1ToolParser(unittest.TestCase): [2], self.dummy_request, ) - # diff is None (no arguments), so falls through to partial_json_parser - self.assertTrue(result is None or isinstance(result, DeltaMessage)) + # diff is None (no arguments key in prev), falls through to partial_json_parser + # parses complete JSON, cur_args=None, prev_args=None -> no-args -> delta=None + self.assertIsNone(result) + + def test_streaming_close_with_empty_dict_arguments(self): + """Regression: close branch must handle arguments={} (empty dict). + + Before fix, `if diff:` was False for empty dict {}, so the close + logic was skipped. After fix, `if diff is not None:` correctly + enters the branch. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # start + name + args key + "{}", # empty dict value + "}", # outer close brace + "", # end token + ], + ) + # Step 1: name sent + # Step 2: first-args, cur_args={} is not None, prev_args=None + # Without fix: not {} == True -> no-args branch -> returns None + # With fix: enters first-args -> streams "{}" -> DeltaMessage + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[1].tool_calls) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + + def test_streaming_empty_arguments_with_outer_brace_in_same_token(self): + """Regression: when arguments={} and outer } arrive in the same token '{}}', + regex (.*) over-captures the outer brace, producing '{}}'. + + Real production data showed arguments='{}}}' for get_default_weather + with empty arguments. This test reproduces that exact scenario. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "get_default_weather", "arguments": ', # start + name + args key + "{}}", # empty args + outer close brace in same token + "", # end token + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "get_default_weather") + # Step 2: first-args branch, tool_call_portion is complete JSON + # regex (.*) captures '{}}' but fix strips outer '}' -> '{}' + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + # Step 3: end token, close branch + # diff = prev_arguments = {} (not None), delta_text = '' (empty after split) + # '}' not in '' -> returns None + self.assertIsNone(results[2]) + + def test_streaming_close_with_number_ending_arguments(self): + """Regression: close branch must flush remaining args ending with number. + + Before fix, '"}' not in delta was True for numbers, causing return None. + After fix, rindex('}') correctly finds the closing brace. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"count": ', # start + name + args key + "123", # number value + "}}", # close braces + end token + ], + ) + # Step 1: name sent + # Step 2: first-args, streams {"count": 123 + # Step 3: close branch flushes remaining "}" + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"count": 123}') + + def test_streaming_close_with_boolean_ending_arguments(self): + """Regression: close branch must flush remaining args ending with boolean.""" + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"flag": ', # start + args key + "true", # boolean value + "}}", # close + end token + ], + ) + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"flag": true}') + + def test_streaming_close_with_nested_object_ending(self): + """Regression: close branch must flush remaining args ending with nested '}'.""" + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"nested": {"a": ', # start + args key + "1", # nested value + "}}}", # close all + end token + ], + ) + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"nested": {"a": 1}}') # --- Lines 202-206: else branch (cur_start < cur_end, edge case) --- @@ -404,23 +554,21 @@ class TestErnieX1ToolParser(unittest.TestCase): def test_streaming_malformed_json(self): """Cover lines 213-215: MalformedJSON from partial parser""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Feed badly formed content - result = parser.extract_tool_calls_streaming( - "", - "{{{", - "{{{", - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + "{{{", # badly formed content + ], ) - self.assertIsNone(result) + self.assertIsNone(results[1]) def test_streaming_json_decode_error(self): """Cover lines 216-218: JSONDecodeError from partial parser""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) + # Step 1: start tool call normally + self._simulate_streaming(parser, [""]) + # Step 2: mock partial_json_parser to throw ValueError with patch( "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", side_effect=ValueError("bad json"), @@ -430,8 +578,8 @@ class TestErnieX1ToolParser(unittest.TestCase): "bad", "bad", [1], - [1, 10], - [10], + [1, 2], + [2], self.dummy_request, ) self.assertIsNone(result) @@ -469,30 +617,17 @@ class TestErnieX1ToolParser(unittest.TestCase): def test_streaming_first_arguments_with_regex_match(self): """Cover lines 243-244, 257-286: first arguments appear, regex matches""" parser = self._new_parser() - # Start tool call and send name - parser.extract_tool_calls_streaming( - "", - '{"name": "get_weather"', - '{"name": "get_weather"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "get_weather", "arguments": {"location": "', # start + name + args key + "bei", # args value + ], ) - # Now stream arguments (first time) - # Key must be complete (closing quote) so partial_json_parser returns truthy arguments. - # delta must be a substring of the regex-extracted arguments portion (after "arguments":). - result = parser.extract_tool_calls_streaming( - '{"name": "get_weather"', - '{"name": "get_weather", "arguments": {"location": "bei', - '"bei', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - self.assertIsNotNone(result) - self.assertIsNotNone(result.tool_calls) + # Step 1: name sent + # Step 2: first-args, regex finds "bei" in '{"location": "bei' + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, '{"location": "bei') def test_streaming_first_arguments_no_regex_match(self): """Cover lines 266-267: regex doesn't match, fallback to json.dumps""" @@ -522,67 +657,119 @@ class TestErnieX1ToolParser(unittest.TestCase): self.assertIsNotNone(result.tool_calls) def test_streaming_first_arguments_delta_not_in_json(self): - """Cover lines 271-272: delta_text not found in cur_arguments_json""" + """Cover lines 275-276: delta_text not found in cur_arguments_json, returns None. + When delta contains the arguments key itself (e.g. ', "arguments": {'), + regex extracts cur_arguments_json='{' but delta ', "arguments": {' is not in '{'. + """ parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn"', # start + partial name + ', "arguments": {', # delta introduces arguments key + open brace + ], ) - # Delta text that doesn't appear in the arguments JSON - result = parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v"}}', - "ZZZZZ", - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - self.assertIsNone(result) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args branch, regex extracts cur_arguments_json='{' + # delta_text=', "arguments": {' is NOT in '{' -> returns None + self.assertIsNone(results[1]) # --- Lines 249-251: no cur_arguments and no prev_arguments --- def test_streaming_no_arguments_at_all(self): """Cover lines 249-251: both cur and prev arguments are empty/None""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # Continue with name only, no arguments - result = parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn"}', - "}", - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn"', # start + name + "}", # close JSON, no arguments + ], ) # prev_arguments=None, cur_arguments=None -> delta=None - # then prev_tool_call_arr updated and returns delta (which is None) - self.assertIsNone(result) + self.assertIsNone(results[1]) + + def test_streaming_empty_dict_arguments_not_skipped(self): + """Regression: arguments={} (empty dict) must not be treated as no arguments. + + Empty dict is falsy in Python (`not {} == True`). Before the fix, + this caused empty arguments to enter the no-arguments branch, + silently dropping them during streaming. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # start + name + args key + "{}", # empty dict value + "}", # outer close brace + ], + ) + # Step 1: name sent + # Step 2: cur_arguments={} (not None), prev_arguments=None + # With fix: enters first-arguments branch -> streams "{}" + # Without fix: not {} == True -> no-arguments branch -> delta=None + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[1].tool_calls) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + + def test_streaming_empty_dict_prev_arguments_not_reset(self): + """Regression: prev_arguments={} must not be treated as no arguments. + + When prev has {} and cur has a non-empty dict, the code should enter + the both-have-arguments branch, not the first-arguments branch. + + This scenario (arguments growing from {} to non-empty) is hard to + produce naturally, so we build up state through a real flow then + verify the branch behavior with one additional call. + """ + parser = self._new_parser() + # Build up state naturally: prev_tool_call_arr gets arguments={} + self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # name + args key + "{}", # empty dict value + "}", # outer close + ], + ) + # Verify state is correct + self.assertEqual(parser.prev_tool_call_arr[0].get("arguments"), {}) + + # Now test: if more argument data arrives, prev_args={} should be + # treated as "not None" -> enters both-have-arguments branch + # Without fix: not {} == True -> first-arguments branch (wrong) + result = parser.extract_tool_calls_streaming( + '{"name": "fn", "arguments": {"k": "v', + '{"name": "fn", "arguments": {"k": "val', + "al", + [1, 2, 3], + [1, 2, 3, 4], + [4], + self.dummy_request, + ) + # both-have-arguments branch: delta_text="al" streamed as arguments + self.assertIsNotNone(result) + self.assertEqual(result.tool_calls[0].function.arguments, "al") # --- Lines 253-255: cur_arguments reset (impossible branch) --- def test_streaming_arguments_reset_mid_call(self): - """Cover lines 253-255: prev has arguments but cur doesn't (impossible case)""" + """Cover lines 253-255: prev has arguments but cur doesn't (impossible case). + + This is an edge case that shouldn't happen in normal flow, but tests + defensive handling when partial parser returns no arguments after + previously having them. + """ parser = self._new_parser() parser.current_tool_id = 0 parser.current_tool_name_sent = True parser.streamed_args_for_tool = [""] + # Simulate state where prev already had arguments parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] - # Feed content where cur has no arguments but prev does + # Mock parser to return no arguments (simulating the impossible reset) with patch( "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", return_value={"name": "fn"}, @@ -591,9 +778,9 @@ class TestErnieX1ToolParser(unittest.TestCase): '{"name": "fn", "arguments": {"k": "v"', '{"name": "fn", "arguments": {"k": "v"}', '"}', - [1, 10], - [1, 10, 20], - [20], + [1, 2], + [1, 2, 3], + [3], self.dummy_request, ) self.assertIsNone(result) @@ -603,110 +790,48 @@ class TestErnieX1ToolParser(unittest.TestCase): def test_streaming_incremental_arguments_incomplete(self): """Cover lines 288-314: both prev and cur have arguments, JSON incomplete""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v', # start + name + first args + "a", # establishes prev_args + "l", # incremental: both-have-args + ], ) - # First arguments - delta must appear in regex-extracted arguments portion - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - '{"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # More argument tokens (both prev and cur have arguments now) - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "val', - "al", - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - self.assertIsNotNone(result) - self.assertEqual(result.tool_calls[0].function.arguments, "al") + # Step 1: name sent + # Step 2: first-args branch + # Step 3: both-have-args branch, streams "l" + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.arguments, "l") def test_streaming_incremental_arguments_complete_json(self): """Cover lines 289-305: complete JSON with trailing }""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v', # start + name + first args + "a", # establishes prev_args + '"}}', # completes JSON + ], ) - # First arguments - delta must appear in regex-extracted arguments portion - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - '{"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Complete with closing braces - both prev and cur have arguments - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "v"}}', - '"}}', - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - # is_complete_json=True, delta ends with }, should strip trailing } - # After strip: '"' which is not empty, so returns DeltaMessage - self.assertIsNotNone(result) - self.assertIsInstance(result, DeltaMessage) + # Step 3: both-have-args, complete JSON, strips trailing } -> streams '"}' + self.assertIsNotNone(results[2]) + self.assertIsInstance(results[2], DeltaMessage) def test_streaming_incremental_arguments_complete_empty_delta(self): """Cover lines 304-305: complete JSON where delta becomes empty after strip""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v"', # start + name + first args + "}", # inner close (establishes prev_args) + "}", # outer close: both-have-args, complete, delta stripped to "" + ], ) - # First arguments with proper delta - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v"}', - '{"k": "v"}', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Send just the outer closing brace - # tool_call_portion becomes complete JSON, delta="}" stripped to "" -> return None - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v"}', - '{"name": "fn", "arguments": {"k": "v"}}', - "}", - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - # is_complete_json=True, delta="}" -> stripped to "" -> return None - self.assertIsNone(result) + # Step 3: is_complete_json=True, delta="}" -> stripped to "" -> return None + self.assertIsNone(results[2]) # --- Lines 316-319: prev_tool_call_arr update branches --- @@ -759,95 +884,71 @@ class TestErnieX1ToolParser(unittest.TestCase): def test_streaming_full_flow(self): """Integration test: simulate a full streaming tool call flow""" parser = self._new_parser() - req = self.dummy_request - - # Step 1: text before tool call - r = parser.extract_tool_calls_streaming("", "thinking", "thinking", [], [], [], req) - self.assertEqual(r.content, "thinking") - - # Step 2: tool_call start token - r = parser.extract_tool_calls_streaming("thinking", "thinking", "", [], [1], [1], req) - self.assertIsNone(r) - - # Step 3: function name appears - r = parser.extract_tool_calls_streaming( - "thinking", - 'thinking{"name": "search"', - '{"name": "search"', - [1], - [1, 10], - [10], - req, + results = self._simulate_streaming( + parser, + [ + "thinking", # Step 1: text before tool call + "", # Step 2: tool_call start token + '{"name": "search", "arguments": {"query": "', # Step 3: name + args key + "test", # Step 4: args value + " data", # Step 5: more args + ], ) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.name, "search") - - # Step 4: arguments start - delta must appear in regex-extracted arguments portion - r = parser.extract_tool_calls_streaming( - 'thinking{"name": "search"', - 'thinking{"name": "search", "arguments": {"query": "test', - '{"query": "test', - [1, 10], - [1, 10, 20], - [20], - req, - ) - self.assertIsNotNone(r) - + # Step 1: plain text + self.assertEqual(results[0].content, "thinking") + # Step 2: start token -> None + self.assertIsNone(results[1]) + # Step 3: name sent + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.name, "search") + # Step 4: first arguments + self.assertIsNotNone(results[3]) + self.assertEqual(results[3].tool_calls[0].function.arguments, '{"query": "test') # Step 5: more arguments - r = parser.extract_tool_calls_streaming( - 'thinking{"name": "search", "arguments": {"query": "test', - 'thinking{"name": "search", "arguments": {"query": "test data', - " data", - [1, 10, 20], - [1, 10, 20, 30], - [30], - req, + self.assertIsNotNone(results[4]) + self.assertEqual(results[4].tool_calls[0].function.arguments, " data") + + def test_streaming_empty_arguments_full_flow(self): + """Integration: streaming tool call with arguments={} must not lose arguments. + + Simulates a complete streaming flow where the tool call has empty + arguments. Verifies the name is sent and arguments are streamed. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # Step 1: start + name + args key + "{}", # Step 2: empty dict value + "}", # Step 3: outer close + "", # Step 4: end token + ], ) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.arguments, " data") + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args with cur_args={}, streams "{}" + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + # Step 4: close branch, delta_text="" after stripping + # diff={} is not None, but "}" not in "" -> return None + self.assertIsNone(results[2]) + self.assertIsNone(results[3]) def test_streaming_multiple_tool_calls(self): """Integration test: two tool calls in one response""" parser = self._new_parser() - req = self.dummy_request - - # First tool call - parser.extract_tool_calls_streaming( - "", - '{"name": "fn1"', - '{"name": "fn1"', - [], - [1, 10], - [1, 10], - req, - ) - self.assertEqual(parser.current_tool_id, 0) - - # Close first tool - parser.extract_tool_calls_streaming( - '{"name": "fn1"', - '{"name": "fn1"}', - "}", - [1, 10], - [1, 10, 2], - [2], - req, - ) - - # Second tool call - r = parser.extract_tool_calls_streaming( - '{"name": "fn1"}', - '{"name": "fn1"}{"name": "fn2"', - '{"name": "fn2"', - [1, 10, 2], - [1, 10, 2, 1, 20], - [1, 20], - req, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn1"', # First tool: start + name + "}", # Close first tool + '{"name": "fn2"', # Second tool: start + name + ], ) self.assertEqual(parser.current_tool_id, 1) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.name, "fn2") + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.name, "fn2") if __name__ == "__main__": diff --git a/tests/graph_optimization/test_cuda_graph_recapture.py b/tests/graph_optimization/test_cuda_graph_recapture.py index 1a28c0731b..902bcf182f 100644 --- a/tests/graph_optimization/test_cuda_graph_recapture.py +++ b/tests/graph_optimization/test_cuda_graph_recapture.py @@ -91,10 +91,10 @@ class TestModel1(paddle.nn.Layer): return sublayer2_output - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """ """ - self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config) - self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config) + self.sublayer1.clear_graph_opt_backend(fd_config=self.fd_config) + self.sublayer2.clear_graph_opt_backend(fd_config=self.fd_config) class TestCUDAGrpahRecapture(unittest.TestCase): @@ -152,7 +152,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase): # Destroy print_gpu_memory_use("before destroy", 0) - self.test_model1.clear_grpah_opt_backend() + self.test_model1.clear_graph_opt_backend() print_gpu_memory_use("after destroy", 0) def recapture_and_replay(self, input_tensor1, forward_meta1): @@ -168,7 +168,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase): # Destroy print_gpu_memory_use("before destroy", 0) - self.test_model1.clear_grpah_opt_backend() + self.test_model1.clear_graph_opt_backend() print_gpu_memory_use("after destroy", 0) diff --git a/tests/model_executor/test_thinking_budget.py b/tests/model_executor/test_thinking_budget.py index 139b685995..4cc5a1563b 100644 --- a/tests/model_executor/test_thinking_budget.py +++ b/tests/model_executor/test_thinking_budget.py @@ -111,7 +111,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase): self._fdconfig_patches = [ patch.object(FDConfig, "read_from_config", return_value=None), patch.object(FDConfig, "postprocess", return_value=None), - patch.object(FDConfig, "init_cache_info", return_value=None), + patch.object(FDConfig, "init_pd_info", return_value=None), patch.object(FDConfig, "check", return_value=None), ] for patcher in self._fdconfig_patches: diff --git a/tests/operators/test_speculate_set_stop_value_multi_seqs.py b/tests/operators/test_speculate_set_stop_value_multi_seqs.py index 45d8a0ef34..aa048560c3 100644 --- a/tests/operators/test_speculate_set_stop_value_multi_seqs.py +++ b/tests/operators/test_speculate_set_stop_value_multi_seqs.py @@ -42,7 +42,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: return paddle_inputs -def run_kernel(paddle_inputs, inputs): +def run_kernel(paddle_inputs): """Call the CUDA kernel.""" speculate_set_stop_value_multi_seqs( paddle_inputs["accept_tokens"], @@ -137,7 +137,18 @@ def gen_inputs( def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]: - """Python reference — must match CUDA kernel logic exactly.""" + """Python reference — must match CUDA kernel logic exactly. + + token_ids_all 布局 (新 step_idx 语义): + pre_ids_now[k] = 第 k 个 output token (k >= 0, 0-indexed) + 最后一个 output token 在 pre_ids_now[step_idx - 1] + step_idx = 历史已生成的 token 数量 + + 核心设计: + 1. accept_idx 从 -1 开始,-1 表示检查 pre_ids 末尾(上一轮延迟的情况) + 2. 主循环检查 accept_idx <= accept_num - 2 + 3. 匹配成功时: 保留 stop_seq 所有 token,在其后追加 eos + """ accept_tokens = inputs["accept_tokens"].copy() accept_num = inputs["accept_num"].copy() stop_flags = inputs["stop_flags"].copy() @@ -166,27 +177,36 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str step_idx_now = int(step_idx[bid]) min_token_limit = int(min_tokens[bid]) - can_stop = step_idx_now >= min_token_limit + can_stop = step_idx_now + an >= min_token_limit if not can_stop: continue if stop_flags[bid]: continue - accept_idx = 0 + # CUDA kernel: accept_idx 从 -1 开始,检查 pre_ids 末尾 + accept_idx = -1 is_end = False - while accept_idx <= an - 1 and not is_end: + + # loop_end = accept_num > 0 ? accept_num - 2 : -1 + loop_end = an - 2 if an > 0 else -1 + while accept_idx <= loop_end and not is_end: if step_idx_now + accept_idx + 1 < stop_seq_len: accept_idx += 1 continue - # Check one stop_seq match + # 从后向前匹配 stop_seq 的每个 token for i in range(stop_seq_len - 1, -1, -1): + offset = stop_seq_len - 1 - i + accept_tokens_idx = accept_idx - offset cur_token_idx = -1 - if stop_seq_len - 1 - i < accept_idx: - cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1] + + if accept_tokens_idx >= 0: + cur_token_idx = accept_tokens_now[accept_tokens_idx] else: - pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i) - if pre_ids_idx <= 0: + # 新语义: pre_ids_idx = step_idx_now + accept_tokens_idx + # pre_ids_now[0] 是第 1 个 output token + pre_ids_idx = step_idx_now + accept_tokens_idx + if pre_ids_idx < 0: break cur_token_idx = pre_ids_now[pre_ids_idx] @@ -199,9 +219,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str accept_idx += 1 if is_end: - accept_num[bid] = accept_idx - accept_tokens[bid, accept_idx - 1] = end_ids[0] - # stop_flags[bid] = True # kernel no longer sets stop_flags + # accept_idx 已递增,指向 stop_seq 最后 token 的下一个位置 + # 保留 stop_seq 所有 token,在其后追加 eos + accept_num[bid] = accept_idx + 1 + accept_tokens[bid, accept_idx] = end_ids[0] return { "accept_tokens": accept_tokens, @@ -239,7 +260,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): def _run_and_get(self, inputs): paddle_inputs = to_paddle_inputs(inputs) - run_kernel(paddle_inputs, inputs) + run_kernel(paddle_inputs) return get_outputs(paddle_inputs) def _check_all_outputs(self, inputs, outputs): @@ -264,7 +285,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): self._run_full_test(test_cfg) def test_match_in_accept_tokens_only(self): - """Stop seq found entirely within accept_tokens.""" + """Stop seq found entirely within accept_tokens. Eos appended after stop_seq last token.""" inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10) # Place stop seq [A, B, C] at accept_tokens positions [0,1,2] inputs["accept_num"][:] = 4 @@ -276,9 +297,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # stop_seq [10, 20, 30] matches at accept_idx=2 (window ends at accept_tokens[2]=30) + # After loop increment, accept_idx=3, accept_num=4, eos appended at accept_tokens[3] + self.assertEqual(outputs["accept_num"][0], 4) + self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq def test_match_spanning_pre_ids_and_accept(self): - """Stop seq spans token_ids_all (pre_ids) and accept_tokens.""" + """Stop seq spans token_ids_all (pre_ids) and accept_tokens. Eos appended after stop_seq last token.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -290,12 +315,15 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 6 inputs["accept_num"][:] = 3 - # Kernel matching at accept_idx=2 (3rd token, 0-indexed): - # i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1] - # i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0] - # i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6] - # So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]] - inputs["token_ids_all"][0, 6] = 99 + # stop_seq = [99, 11, 22] (len=3) + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 6 表示有 6 个历史 output token,在 pre_ids_now[0..5] + # At accept_idx=1 (window ends at accept_tokens[1]=22): + # i=2: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=22 vs stop_seq[2]=22 ✓ + # i=1: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=11 vs stop_seq[1]=11 ✓ + # i=0: offset=2, accept_tokens_idx=-1 -> pre_ids_idx=6+(-1)=5 -> pre_ids[5]=99 vs stop_seq[0]=99 ✓ + inputs["token_ids_all"][0, 5] = 99 # pre_ids_now[5] = 第 6 个 output token (0-indexed) inputs["accept_tokens"][0, :3] = [11, 22, 33] inputs["stop_seqs"][0, 0, :3] = [99, 11, 22] inputs["stop_seqs_len"][0, 0] = 3 @@ -303,12 +331,14 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - # Match at accept_idx=2, loop increments to 3 + # Match at accept_idx=1, loop increments to 2 -> accept_num=3, eos at accept_tokens[2] self.assertEqual(outputs["accept_num"][0], 3) - self.assertEqual(outputs["accept_tokens"][0, 2], -1) + self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq - def test_match_in_pre_ids_only(self): - """Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0.""" + def test_match_in_pre_ids_only_not_detected(self): + """Stop seq ending purely in pre_ids history but NOT at the end position. + The kernel only detects stop_seq at the very end of pre_ids via accept_idx=-1 check. + Stop seq placed earlier in pre_ids should not be detected.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -320,15 +350,13 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70 - # stop_seq = [50, 60, 70], all 3 tokens are in pre_ids - # For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check - # i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70 - # i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60 - # i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50 - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 + # 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 8 表示有 8 个历史 output token,在 pre_ids_now[0..7] + # accept_idx=-1 会检查 pre_ids_now[7] 开始的 stop_seq + # 把 stop_seq 放在 pre_ids_now[2,3,4] - 不会被检测到 + inputs["token_ids_all"][0, 2] = 50 + inputs["token_ids_all"][0, 3] = 60 + inputs["token_ids_all"][0, 4] = 70 inputs["accept_tokens"][0, :3] = [1, 2, 3] inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] inputs["stop_seqs_len"][0, 0] = 3 @@ -336,7 +364,8 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - self.assertEqual(outputs["accept_num"][0], 1) + # No match: stop_seq is in pre_ids but not at the end, accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 3) def test_already_stopped(self): """Kernel skips sequences with stop_flags=True.""" @@ -351,7 +380,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"]) def test_min_tokens_blocks_stop(self): - """Kernel skips stop check when step_idx < min_tokens.""" + """Kernel skips stop check when step_idx + accept_num < min_tokens.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -363,20 +392,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # Same setup that would match (like test_match_in_pre_ids_only) - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 + # Place stop_seq in pre_ids at end position (would be detected by accept_idx=-1) + # pre_ids_now[0..7] = 8 个历史 output token + # accept_idx=-1 检查 pre_ids_now[5,6,7] 对应 stop_seq[0,1,2] + inputs["token_ids_all"][0, 5] = 50 + inputs["token_ids_all"][0, 6] = 60 + inputs["token_ids_all"][0, 7] = 70 inputs["accept_tokens"][0, :3] = [1, 2, 3] inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] inputs["stop_seqs_len"][0, 0] = 3 inputs["stop_flags"][:] = False - inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop + inputs["min_tokens"][:] = 100 # step_idx+accept_num=11 < 100, should NOT stop outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # min_tokens prevents stop, accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 3) def test_min_tokens_allows_stop(self): - """Kernel allows stop when step_idx >= min_tokens.""" + """Kernel allows stop when step_idx + accept_num >= min_tokens.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -388,15 +421,17 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only) - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 - inputs["accept_tokens"][0, :3] = [1, 2, 3] - inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] - inputs["stop_seqs_len"][0, 0] = 3 + # stop_seq [X, 50] spans pre_ids and accept_tokens[0]. + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # At accept_idx=0 (window ends at accept_tokens[0]=50): + # i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=50 vs stop_seq[1]=50 ✓ + # i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=8+(-1)=7 -> pre_ids[7] + pre_val = int(inputs["token_ids_all"][0, 7]) # pre_ids_now[7] + inputs["accept_tokens"][0, :3] = [50, 60, 70] + inputs["stop_seqs"][0, 0, :2] = [pre_val, 50] + inputs["stop_seqs_len"][0, 0] = 2 inputs["stop_flags"][:] = False - inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop + inputs["min_tokens"][:] = 5 # step_idx+accept_num=11 >= 5, should stop outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) @@ -413,20 +448,24 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # accept_tokens: stop_seq[20,30] matches at accept_idx=2: - # i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK - # i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK + # accept_tokens: [20, 30, 40] + # Second stop seq [20, 30] matches at accept_idx=1 (window ends at accept_tokens[1]=30): + # i=1: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=30 vs stop_seq[1]=30 ✓ + # i=0: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=20 vs stop_seq[0]=20 ✓ inputs["accept_tokens"][0, :3] = [20, 30, 40] # First stop seq doesn't match inputs["stop_seqs"][0, 0, :3] = [99, 98, 97] inputs["stop_seqs_len"][0, 0] = 3 - # Second stop seq matches + # Second stop seq [20, 30] matches inputs["stop_seqs"][0, 1, :2] = [20, 30] inputs["stop_seqs_len"][0, 1] = 2 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # Match at accept_idx=1 -> accept_num=3, eos at accept_tokens[2] + self.assertEqual(outputs["accept_num"][0], 3) + self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq def test_nonzero_prompt_lens(self): """Verify prompt_lens offset is applied correctly.""" @@ -444,19 +483,104 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): inputs["accept_num"][:] = 2 inputs["accept_tokens"][0, :2] = [55, 66] # pre_ids_now starts at token_ids_all[0, prompt_len:] - # stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx] - # For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4 - # -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4] - # For accept_idx=1 (second token is accept_tokens[0,0]=55): - # i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55 - # i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5] - target_val = int(inputs["token_ids_all"][0, prompt_len + 5]) + # pre_ids_now[k] = 第 k 个 output token (k >= 0) + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # stop_seq = [X, 55] where X = pre_ids_now[5 + (-1)] = pre_ids_now[4] + # At accept_idx=0 (window ends at accept_tokens[0]=55): + # i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=55 vs stop_seq[1]=55 ✓ + # i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=5+(-1)=4 -> pre_ids[4]=token_ids_all[0, prompt_len+4] + target_val = int(inputs["token_ids_all"][0, prompt_len + 4]) inputs["stop_seqs"][0, 0, :2] = [target_val, 55] inputs["stop_seqs_len"][0, 0] = 2 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # Match at accept_idx=0 -> accept_num=2, eos at accept_tokens[1] + self.assertEqual(outputs["accept_num"][0], 2) + self.assertEqual(outputs["accept_tokens"][0, 1], -1) # eos appended after stop_seq + + def test_single_token_stop_seq_preserved(self): + """Single token stop_seq (like <|im_end|>) with eos appended after it.""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=90, + ) + inputs["prompt_lens"][:] = 0 + inputs["step_idx"][:] = 10 + inputs["accept_num"][:] = 4 + # accept_tokens: [a, b, <|im_end|>, d] where <|im_end|> has token id 999 + inputs["accept_tokens"][0, :4] = [100, 200, 999, 300] + # stop_seq = [<|im_end|>] (single token) + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # Match at accept_idx=2 (window ends at accept_tokens[2]=999) + # After loop increment, accept_idx=3, accept_num=4, eos at accept_tokens[3] + self.assertEqual(outputs["accept_num"][0], 4) + self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq + + def test_stop_seq_at_last_position_not_detected(self): + """Stop seq at the last position of accept_tokens is NOT detected (deferred to next round).""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=100, + ) + inputs["prompt_lens"][:] = 0 + inputs["step_idx"][:] = 10 + inputs["accept_num"][:] = 4 + # stop_seq [999] is at accept_tokens[3] (last valid position) + # Since we only check up to accept_num - 2 = 2, this won't be detected + inputs["accept_tokens"][0, :4] = [100, 200, 300, 999] + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # No match because accept_idx only goes up to 2, and 999 is at position 3 + # accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 4) + + def test_stop_seq_detected_from_previous_round(self): + """Stop seq at the end of pre_ids (from previous round) is detected via accept_idx=-1.""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=110, + ) + inputs["prompt_lens"][:] = 0 + # 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 10 表示有 10 个历史 output token,在 pre_ids_now[0..9] + # accept_idx=-1 检查 pre_ids_now[9] (最后一个历史 token) + inputs["step_idx"][:] = 10 + inputs["token_ids_all"][0, 9] = 999 # pre_ids_now[9] = 第 10 个 output token (0-indexed) + inputs["accept_num"][:] = 3 + inputs["accept_tokens"][0, :3] = [100, 200, 300] + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # stop_seq [999] was in pre_ids at end, accept_idx=-1 matches + # After loop increment, accept_idx=0, accept_num=1, eos at accept_tokens[0] + self.assertEqual(outputs["accept_num"][0], 1) + self.assertEqual(outputs["accept_tokens"][0, 0], -1) # replaced with eos if __name__ == "__main__": diff --git a/tests/operators/test_unified_update_model_status.py b/tests/operators/test_unified_update_model_status.py index 56656fdbe7..ed97aa8687 100644 --- a/tests/operators/test_unified_update_model_status.py +++ b/tests/operators/test_unified_update_model_status.py @@ -261,7 +261,9 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]: # Write history to token_ids_all (forward loop, mirrors kernel step 5) if output_len > 0: base_addr = int(prompt_lens[batch_id]) - base = cur_step_idx - output_len + 1 + # 新语义: step_idx 入口 = 历史数量,处理后 cur_step_idx = 历史 + output_len + # 第一个 output token 写入位置 = cur_step_idx - output_len + base = cur_step_idx - output_len for i in range(output_len): write_idx = base_addr + base + i if 0 <= write_idx < max_model_len: diff --git a/tests/scheduler/test_dp_scheduler.py b/tests/scheduler/test_dp_scheduler.py index 0e42c4491f..a5f9cfa838 100644 --- a/tests/scheduler/test_dp_scheduler.py +++ b/tests/scheduler/test_dp_scheduler.py @@ -411,6 +411,32 @@ class TestDPLocalScheduler(unittest.TestCase): self.assertEqual(scheduler.ids, ["fresh_req"]) self.assertEqual(scheduler.ids_read_cursor, 1) + def test_get_requests_insufficient_resources(self): + """Test getting requests when resources are insufficient.""" + mock_logger.reset_mock() + + # Test with insufficient blocks - mock the condition variable to avoid threading issues + with patch.object(self.scheduler, "requests_not_empty"): + requests = self.scheduler.get_requests( + available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1 + ) + + self.assertEqual(requests, []) + # The logger should have been called for insufficient resources + self.assertTrue(mock_logger.debug.called) + # Check the message contains expected content + call_args = mock_logger.debug.call_args[0][0] + self.assertIn("insufficient", call_args.lower()) + + def test_get_requests_insufficient_batch(self): + """Test getting requests when batch size is insufficient.""" + with patch.object(self.scheduler, "requests_not_empty"): + requests = self.scheduler.get_requests( + available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0 + ) + + self.assertEqual(requests, []) + @patch("time.time") @patch.object(dp_scheduler_module, "envs") def test_get_requests_no_requests_available(self, mock_envs, mock_time): diff --git a/tests/splitwise/test_internal_adapter_utils.py b/tests/splitwise/test_internal_adapter_utils.py index f8f22215c0..4d77278984 100644 --- a/tests/splitwise/test_internal_adapter_utils.py +++ b/tests/splitwise/test_internal_adapter_utils.py @@ -25,9 +25,6 @@ class DummyEngine: """Dummy Engine class to simulate the actual Engine for testing.""" class ResourceManager: - def __init__(self): - self.waiting = [] - def available_batch(self): return 4 diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index 240cf702ed..4f55ca4647 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -138,7 +138,7 @@ class TestConfig(unittest.TestCase): model_config=model_config, test_mode=True, ) - fd_config.init_cache_info() + fd_config.init_pd_info() assert fd_config.register_info is not None def test_fdconfig_postprocess_ports(self): diff --git a/tests/worker/test_gpu_model_runner.py b/tests/worker/test_gpu_model_runner.py index 3a02475b5a..43ab5130cd 100644 --- a/tests/worker/test_gpu_model_runner.py +++ b/tests/worker/test_gpu_model_runner.py @@ -487,7 +487,7 @@ class TestSleepWakeupBehavior(unittest.TestCase): runner.local_rank = 0 runner.device_id = 1 runner.num_gpu_blocks = 8 - runner.model = Mock(clear_grpah_opt_backend=Mock()) + runner.model = Mock(clear_graph_opt_backend=Mock()) runner.clear_cache = Mock() runner.initialize_kv_cache = Mock() runner.capture_model = Mock() @@ -523,7 +523,7 @@ class TestSleepWakeupBehavior(unittest.TestCase): runner.sleep("weight,kv_cache") - runner.model.clear_grpah_opt_backend.assert_called_once() + runner.model.clear_graph_opt_backend.assert_called_once() runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once() runner.dynamic_weight_manager.clear_model_weight.assert_called_once() runner.dynamic_weight_manager.clear_communication_group.assert_called_once()