mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
PD deployment support without router (#7412)
This commit is contained in:
@@ -1,113 +0,0 @@
|
||||
#!/bin/bash
set -e

# Test splitwise deployment
# There are two methods for splitwise deployment:
# v0: using splitwise_scheduler or dp_scheduler (deprecated)
# v1: using local_scheduler + router

# prepare environment
export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=1
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1

SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
# get_rdma_nics.sh prints KVCACHE_RDMA_NICS=...; export it into this shell
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
    echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
    exit 1
fi

unset http_proxy && unset https_proxy
source ${SCRIPT_DIR}/utils.sh

P_PORT=52400
D_PORT=52500
REDIS_PORT="${REDIS_PORT:-6379}"
LOG_DATE=$(date +%Y%m%d_%H%M%S)

# Every port each instance needs (api, metrics, engine queue, cache queue,
# rdma, pd-comm) plus the redis port must be free before we start.
ports=(
    $P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5))
    $D_PORT $((D_PORT + 1)) $((D_PORT + 2)) $((D_PORT + 3)) $((D_PORT + 4)) $((D_PORT + 5))
    $REDIS_PORT
)
check_ports "${ports[@]}" || {
    echo "❌ Some ports are in use. Please release them."
    exit 1
}

# start redis (backing store for the splitwise scheduler)
if ! redis-cli -p ${REDIS_PORT} ping &>/dev/null; then
    echo "Redis is not running. Starting redis-server..."
    redis-server --daemonize yes --port ${REDIS_PORT}
    sleep 1
else
    echo "Redis is already running."
fi
sleep 1

# start prefill
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log/$LOG_DATE/prefill"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

# BUGFIX: redirect stdout to the log file FIRST, then dup stderr onto it.
# The previous `2>&1 >file` order pointed stderr at the terminal (the old
# stdout) before stdout was redirected, so errors never reached the log.
nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port ${P_PORT} \
    --metrics-port $((P_PORT + 1)) \
    --engine-worker-queue-port $((P_PORT + 2)) \
    --cache-queue-port $((P_PORT + 3)) \
    --max-model-len 32768 \
    --num-gpu-blocks-override 1000 \
    --splitwise-role "prefill" \
    --cache-transfer-protocol "rdma" \
    --rdma-comm-ports $((P_PORT + 4)) \
    --pd-comm-port $((P_PORT + 5)) \
    --scheduler-name "splitwise" \
    --scheduler-host "127.0.0.1" \
    --scheduler-port ${REDIS_PORT} \
    --scheduler-ttl 9000 \
    >${FD_LOG_DIR}/nohup 2>&1 &

wait_for_health ${P_PORT}

# start decode
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log/$LOG_DATE/decode"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port ${D_PORT} \
    --metrics-port $((D_PORT + 1)) \
    --engine-worker-queue-port $((D_PORT + 2)) \
    --cache-queue-port $((D_PORT + 3)) \
    --max-model-len 32768 \
    --splitwise-role "decode" \
    --cache-transfer-protocol "rdma" \
    --rdma-comm-ports $((D_PORT + 4)) \
    --pd-comm-port $((D_PORT + 5)) \
    --scheduler-name "splitwise" \
    --scheduler-host "127.0.0.1" \
    --scheduler-port ${REDIS_PORT} \
    --scheduler-ttl 9000 \
    >${FD_LOG_DIR}/nohup 2>&1 &

wait_for_health ${D_PORT}

# send request
sleep 10 # make sure server is registered to router
echo "send request..."
curl -X POST "http://0.0.0.0:${D_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d '{
        "messages": [
            {"role": "user", "content": "hello"}
        ],
        "max_tokens": 20,
        "stream": false
    }'
|
||||
@@ -2010,13 +2010,13 @@ class FDConfig:
|
||||
and self.router_config
|
||||
and self.router_config.router
|
||||
):
|
||||
# For RL scenario: version.yaml will be required for models in future releases.
|
||||
# For RL scenario, version.yaml is required for models
|
||||
# Temporarily enforce use router to be enabled.
|
||||
self.model_config.read_model_version()
|
||||
|
||||
self.read_from_config()
|
||||
self.postprocess()
|
||||
self.init_cache_info()
|
||||
self.init_pd_info()
|
||||
if test_mode:
|
||||
return
|
||||
self.check()
|
||||
@@ -2371,18 +2371,17 @@ class FDConfig:
|
||||
logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
logger.info("=============================================================")
|
||||
|
||||
def init_cache_info(self):
|
||||
def init_pd_info(self):
|
||||
"""
|
||||
initialize cache info
|
||||
initialize info for pd deployment
|
||||
"""
|
||||
# TODO: group the splitiwse params
|
||||
# There are two methods for splitwise deployment:
|
||||
# 1. v0 splitwise_scheduler or dp_scheduler
|
||||
# 2. v1 local_scheduler + router
|
||||
# 2. v1 local_scheduler + router (optional)
|
||||
self.splitwise_version = None
|
||||
if self.scheduler_config.name in ("splitwise", "dp"):
|
||||
self.splitwise_version = "v0"
|
||||
elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
|
||||
elif self.scheduler_config.name == "local":
|
||||
self.splitwise_version = "v1"
|
||||
|
||||
# the information for registering this server to router or splitwise_scheduler
|
||||
|
||||
@@ -600,10 +600,15 @@ class EngineArgs:
|
||||
raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1")
|
||||
|
||||
if self.splitwise_role != "mixed":
|
||||
if self.scheduler_name == "local" and self.router is None:
|
||||
if self.scheduler_name == "splitwise":
|
||||
raise ValueError(
|
||||
f"When using {self.splitwise_role} role and the {self.scheduler_name} "
|
||||
f"scheduler, please provide --router argument."
|
||||
"Setting scheduler_name as splitwise is not supported in pd deployment, "
|
||||
"please use router as scheduler."
|
||||
)
|
||||
if self.scheduler_name == "local" and self.router is None:
|
||||
console_logger.warning(
|
||||
f"Running {self.splitwise_role} role with {self.scheduler_name} "
|
||||
f"scheduler without --router. Router registration and request routing will be disabled."
|
||||
)
|
||||
|
||||
if not (
|
||||
|
||||
@@ -109,7 +109,7 @@ class ExpertService:
|
||||
if envs.FD_ENABLE_RETURN_TEXT:
|
||||
self.engine.create_data_processor()
|
||||
if self.cfg.scheduler_config.name == "dp":
|
||||
self.cfg.init_cache_info()
|
||||
self.cfg.init_pd_info()
|
||||
self.engine.scheduler.start(local_data_parallel_id)
|
||||
|
||||
if ipc_signal_suffix is not None:
|
||||
@@ -122,7 +122,7 @@ class ExpertService:
|
||||
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
|
||||
|
||||
if self.cfg.scheduler_config.name == "splitwise":
|
||||
self.cfg.init_cache_info()
|
||||
self.cfg.init_pd_info()
|
||||
role = self.cfg.scheduler_config.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
self.engine.scheduler.start(role, host_ip, self.cfg.register_info)
|
||||
|
||||
@@ -0,0 +1,455 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Test splitwise deployment WITHOUT Router:
|
||||
# use local_scheduler, manually construct disaggregate_info,
|
||||
# send requests to both Prefill and Decode concurrently.
|
||||
# ENABLE_V1_KVCACHE_SCHEDULER=1, use rdma to transfer cache.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from utils.serving_utils import (
|
||||
FD_API_PORT,
|
||||
FD_CACHE_QUEUE_PORT,
|
||||
FD_ENGINE_QUEUE_PORT,
|
||||
FD_METRICS_PORT,
|
||||
check_service_health,
|
||||
clean,
|
||||
)
|
||||
|
||||
# Ports for PD disaggregation (no router port needed)
# Connector (pd-comm) and RDMA base ports; overridable via environment.
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623))

# Prefill uses base ports, Decode uses base+1
# All of these are force-freed before the session starts and again on teardown.
PORTS_TO_CLEAN = [
    FD_API_PORT,
    FD_ENGINE_QUEUE_PORT,
    FD_METRICS_PORT,
    FD_CACHE_QUEUE_PORT,
    FD_CONNECTOR_PORT,
    FD_RDMA_PORT,
    FD_API_PORT + 1,
    FD_ENGINE_QUEUE_PORT + 1,
    FD_METRICS_PORT + 1,
    FD_CACHE_QUEUE_PORT + 1,
    FD_CONNECTOR_PORT + 1,
    FD_RDMA_PORT + 1,
]
|
||||
|
||||
|
||||
def _build_disaggregate_info() -> dict:
    """Assemble the disaggregate_info dict by hand, mirroring what the
    Router's handle_splitwise_request would normally inject.

    Decode-side values use the base port + 1 convention; the host defaults
    to loopback unless FD_HOST_IP is set.
    """
    node_ip = os.getenv("FD_HOST_IP", "127.0.0.1")
    info = {
        "prefill_ip": node_ip,
        "decode_ip": node_ip,
    }
    info["prefill_connector_port"] = FD_CONNECTOR_PORT
    info["decode_connector_port"] = FD_CONNECTOR_PORT + 1
    info["decode_device_ids"] = ["1"]
    info["decode_rdma_ports"] = [FD_RDMA_PORT + 1]
    info["transfer_protocol"] = "rdma"
    info["decode_tp_size"] = 1
    return info
|
||||
|
||||
|
||||
def _send_pd_request(payload: dict, timeout: int = 120):
    """
    Fan the request out to both the Prefill and Decode instances at once,
    replicating the Router's forwarding behavior, and return the Decode
    response (same as Router's return_result_url_index=-1).

    For streaming payloads the Decode response is returned un-consumed so
    the caller can iterate it; the Prefill future is drained best-effort.
    """
    # Inject disaggregate_info and request_id (same as Router)
    body = payload.copy()
    body["disaggregate_info"] = _build_disaggregate_info()
    if "request_id" not in body:
        body["request_id"] = f"test-pd-{uuid.uuid4()}"

    prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions"
    decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions"
    headers = {"Content-Type": "application/json"}

    import concurrent.futures

    streaming = bool(body.get("stream", False))
    # Streaming waits the full timeout on prefill; non-streaming only
    # peeks briefly at the prefill result without blocking the caller.
    prefill_wait = timeout if streaming else 5

    def _fire(url):
        return requests.post(url, headers=headers, json=body, timeout=timeout, stream=streaming)

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        prefill_future = pool.submit(_fire, prefill_url)
        decode_future = pool.submit(_fire, decode_url)
        # Decode carries the final answer; grab it as soon as it resolves.
        decode_resp = decode_future.result()
        # Prefill is checked best-effort — its failure must not mask decode.
        try:
            prefill_future.result(timeout=prefill_wait)
        except Exception:
            pass
    return decode_resp
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
    """
    Pytest fixture that runs once per test session:
    - Cleans ports before tests
    - Starts Prefill and Decode instances WITHOUT Router
    - Waits for both to be healthy
    - Tears down after all tests finish
    """
    print("Pre-test port cleanup...")
    clean(PORTS_TO_CLEAN)

    # Remove stale per-instance log directories from a previous run.
    print("log dir clean")
    if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
        shutil.rmtree("log_prefill")
    if os.path.exists("log_decode") and os.path.isdir("log_decode"):
        shutil.rmtree("log_decode")

    # Prefer a local checkout under MODEL_PATH; otherwise fall back to the
    # hub identifier and let the server download the model.
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
    else:
        model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
    print(f"model_path: {model_path}")

    base_log_dir = os.getenv("FD_LOG_DIR", "log")

    # Prefill instance
    print("start prefill...")
    env_prefill = os.environ.copy()
    env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
    env_prefill["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_prefill")

    prefill_log_path = "prefill.log"
    # Prefill takes the base ports for every channel (api/queue/metrics/
    # cache/rdma/pd-comm); Decode below takes each base + 1.
    prefill_cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT),
        "--max-model-len",
        "8192",
        "--splitwise-role",
        "prefill",
        "--cache-transfer-protocol",
        "rdma",
        "--rdma-comm-ports",
        str(FD_RDMA_PORT),
        "--pd-comm-port",
        str(FD_CONNECTOR_PORT),
        # No --router flag
    ]

    # start_new_session=True puts the server in its own process group so the
    # whole tree can be torn down with os.killpg later.
    with open(prefill_log_path, "w") as logfile:
        process_prefill = subprocess.Popen(
            prefill_cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,
            env=env_prefill,
        )
    time.sleep(1)

    # Decode instance
    print("start decode...")
    env_decode = os.environ.copy()
    env_decode["CUDA_VISIBLE_DEVICES"] = "1"
    env_decode["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_decode")

    decode_log_path = "decode.log"
    decode_cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT + 1),
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT + 1),
        "--metrics-port",
        str(FD_METRICS_PORT + 1),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT + 1),
        "--max-model-len",
        "8192",
        "--splitwise-role",
        "decode",
        "--cache-transfer-protocol",
        "rdma",
        "--rdma-comm-ports",
        str(FD_RDMA_PORT + 1),
        "--pd-comm-port",
        str(FD_CONNECTOR_PORT + 1),
        # No --router flag
    ]

    with open(decode_log_path, "w") as logfile:
        process_decode = subprocess.Popen(
            decode_cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,
            env=env_decode,
        )

    # Wait up to 300 seconds for both instances to be healthy
    # (60 polls x 5s). The for/else fires only when the loop never breaks.
    for _ in range(60):
        prefill_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT}")
        decode_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT + 1}")
        if prefill_healthy and decode_healthy:
            print("Prefill and decode servers are both online")
            break
        time.sleep(5)
    else:
        print("[TIMEOUT] Servers failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process_prefill.pid, signal.SIGTERM)
            os.killpg(process_decode.pid, signal.SIGTERM)
            clean(PORTS_TO_CLEAN)
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError("Prefill or decode server did not start")

    yield  # Run tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process_prefill.pid, signal.SIGTERM)
        os.killpg(process_decode.pid, signal.SIGTERM)
        clean(PORTS_TO_CLEAN)
        print(f"Prefill server (pid={process_prefill.pid}) terminated")
        print(f"Decode server (pid={process_decode.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate server: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def api_url(request):
    """
    Returns the Decode API endpoint URL (where final responses come from).
    """
    # Decode listens on FD_API_PORT + 1; Prefill holds the base port.
    return f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions"
|
||||
|
||||
|
||||
@pytest.fixture
def headers():
    """Default JSON request headers shared by the tests in this module."""
    return {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
def get_stream_chunks(response):
    """Parse an SSE-style streaming response into a list of JSON chunks.

    Skips blank keep-alive lines, strips the "data: " prefix, stops at the
    "[DONE]" sentinel, and logs (but tolerates) unparseable lines. A non-200
    response is logged and yields an empty list.
    """
    chunks = []

    # Guard clause: anything but 200 means there is no stream to parse.
    if response.status_code != 200:
        print(f"Request failed, status: {response.status_code}")
        print("Response:", response.text)
        return chunks

    for line in response.iter_lines(decode_unicode=True):
        if not line:
            continue
        if line.startswith("data: "):
            line = line[len("data: ") :]
        if line.strip() == "[DONE]":
            break
        try:
            chunks.append(json.loads(line))
        except Exception as e:
            print(f"Parse failed: {e}, line: {line}")

    return chunks
|
||||
|
||||
|
||||
def test_chat_usage_stream(api_url):
    """Streaming chat completion through PD fan-out: check content and usage."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }

    chunks = get_stream_chunks(_send_pd_request(payload))
    # Every chunk except the trailing usage chunk carries delta content.
    pieces = [chunk["choices"][0]["delta"]["content"] for chunk in chunks[:-1]]
    result = "".join(pieces)
    print("Decode Response:", result)
    assert result != "", "结果为空"

    usage = chunks[-1]["usage"]
    expected_total = usage["completion_tokens"] + usage["prompt_tokens"]
    assert usage["completion_tokens"] <= payload["max_tokens"], "completion_tokens大于max_tokens"
    assert usage["completion_tokens"] >= payload["metadata"]["min_tokens"], "completion_tokens小于min_tokens"
    assert usage["total_tokens"] == expected_total, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_chat_usage_non_stream(api_url):
    """Non-streaming chat completion through PD fan-out: check content and usage."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }

    body = _send_pd_request(payload).json()
    usage = body["usage"]
    result = body["choices"][0]["message"]["content"]
    assert result != "", "结果为空"

    expected_total = usage["completion_tokens"] + usage["prompt_tokens"]
    assert usage["completion_tokens"] <= payload["max_tokens"], "completion_tokens大于max_tokens"
    assert usage["completion_tokens"] >= payload["metadata"]["min_tokens"], "completion_tokens小于min_tokens"
    assert usage["total_tokens"] == expected_total, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_non_chat_usage_stream(api_url):
    """Test streaming completion (non-chat) with usage"""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "牛顿的三大运动定律是什么?",
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }

    # Send to /v1/completions endpoints
    # NOTE(review): this inlines the same fan-out that _send_pd_request does
    # for /v1/chat/completions, because that helper hardcodes the chat path.
    disaggregate_info = _build_disaggregate_info()
    payload = payload.copy()
    payload["disaggregate_info"] = disaggregate_info
    if "request_id" not in payload:
        payload["request_id"] = f"test-pd-{uuid.uuid4()}"

    prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions"
    decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions"
    headers = {"Content-Type": "application/json"}

    import concurrent.futures

    # Fire prefill and decode concurrently; only decode's (streaming)
    # response is consumed — the prefill future result is discarded.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120)
        decode_future = executor.submit(
            requests.post, decode_url, json=payload, headers=headers, timeout=120, stream=True
        )
        response = decode_future.result()

    chunks = get_stream_chunks(response)
    # Completion chunks carry "text"; the last chunk carries only usage.
    result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
    print("Decode Response:", result)
    assert result != "", "结果为空"
    usage = chunks[-1]["usage"]
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
|
||||
|
||||
def test_non_chat_usage_non_stream(api_url):
    """Test non-streaming completion (non-chat) with usage"""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "牛顿的三大运动定律是什么?",
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }

    # Send to /v1/completions endpoints
    # NOTE(review): inlines the fan-out from _send_pd_request because that
    # helper targets /v1/chat/completions only.
    disaggregate_info = _build_disaggregate_info()
    payload = payload.copy()
    payload["disaggregate_info"] = disaggregate_info
    if "request_id" not in payload:
        payload["request_id"] = f"test-pd-{uuid.uuid4()}"

    prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions"
    decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions"
    headers = {"Content-Type": "application/json"}

    import concurrent.futures

    # Fire prefill and decode concurrently; only decode's JSON body is used.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120)
        decode_future = executor.submit(requests.post, decode_url, json=payload, headers=headers, timeout=120)
        response = decode_future.result().json()

    usage = response["usage"]
    result = response["choices"][0]["text"]
    print("Decode Response:", result)
    assert result != "", "结果为空"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
|
||||
@@ -111,7 +111,7 @@ class TestThinkingBudgetLogitsProcessor(unittest.TestCase):
|
||||
self._fdconfig_patches = [
|
||||
patch.object(FDConfig, "read_from_config", return_value=None),
|
||||
patch.object(FDConfig, "postprocess", return_value=None),
|
||||
patch.object(FDConfig, "init_cache_info", return_value=None),
|
||||
patch.object(FDConfig, "init_pd_info", return_value=None),
|
||||
patch.object(FDConfig, "check", return_value=None),
|
||||
]
|
||||
for patcher in self._fdconfig_patches:
|
||||
|
||||
Reference in New Issue
Block a user