[CI] Add ep4_mtp e2e test (#6153)

* [CI] Add ep4_mtp e2e test
This commit is contained in:
YuBaoku
2026-01-22 14:54:18 +08:00
committed by GitHub
parent 1e3c35496c
commit 1cfb042045
3 changed files with 604 additions and 20 deletions
+1 -20
View File
@@ -190,24 +190,5 @@ jobs:
export PYTHONPATH=/workspace/FastDeploy/
export CUDA_VISIBLE_DEVICES=0,1,2,3
echo "============================================================"
echo "Running pytest for 4-GPU end-to-end cases"
python -m pytest -sv --tb=short tests/e2e/4cards_cases/
exit_code=$?
if [ $exit_code -ne 0 ]; then
if [ -f "./log/log_0/workerlog.0" ]; then
echo "---------------- log/workerlog.0 -------------------"
cat "./log/log_0/workerlog.0"
echo "----------------------------------------------------"
fi
if [ -f "./server.log" ]; then
echo "---------------- server.log ----------------"
cat "./server.log"
echo "--------------------------------------------"
fi
exit 1
fi
bash scripts/run_gpu_4cards.sh
'
+54
View File
@@ -0,0 +1,54 @@
#!/bin/bash
# Run every 4-GPU e2e pytest file one at a time so a single failure does not
# abort the remaining files; failing files are recorded and reported at exit.
script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
repo_root="$(cd "${script_dir}/.." && pwd)"
cases_dir="${repo_root}/tests/e2e/4cards_cases"
failed_list="${repo_root}/failed_cases.txt"
failed_total=0

# Start from a clean failure record.
rm -f "${failed_list}"

# Print the last 200 lines of a log file (if it exists) between banners.
dump_log_tail() {
    local log_file="$1"
    local banner="$2"
    if [ -f "${log_file}" ]; then
        echo "${banner}"
        tail -n 200 "${log_file}"
        echo "------------------------------------------------------------"
    fi
}

# nullglob makes the pattern expand to nothing (not itself) when no file matches.
shopt -s nullglob
case_files=("${cases_dir}"/test_*.py)
if [ "${#case_files[@]}" -eq 0 ]; then
    echo "ERROR: No test files found under: ${cases_dir}"
    exit 1
fi

for case_file in "${case_files[@]}"; do
    echo "------------------------------------------------------------"
    echo "Running pytest: ${case_file}"
    echo "------------------------------------------------------------"
    if ! python -m pytest -sv --tb=short "${case_file}"; then
        echo "Pytest failed for: ${case_file}"
        echo "${case_file}" >> "${failed_list}"
        failed_total=$((failed_total + 1))
        # Surface server-side logs to aid CI debugging.
        dump_log_tail "${repo_root}/log/log_0/workerlog.0" "---------------- workerlog.0 (last 200 lines) -------------"
        dump_log_tail "${repo_root}/server.log" "---------------- server.log (last 200 lines) ---------------"
    fi
done
shopt -u nullglob

if [ "${failed_total}" -ne 0 ]; then
    echo "${failed_total} test file(s) failed:"
    cat "${failed_list}"
    exit 1
fi
echo "All 4-GPU end-to-end tests passed"
exit 0
@@ -0,0 +1,549 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import signal
import subprocess
import sys
import time
import pytest
import requests
tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, tests_dir)
from e2e.utils.serving_utils import (
FD_API_PORT,
FD_CACHE_QUEUE_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
PORTS_TO_CLEAN,
clean_ports,
is_port_open,
)
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server(api_url):
    """
    Pytest fixture that runs once per test session:
    - Cleans ports before tests
    - Starts the API server as a subprocess (in its own process group)
    - Waits for the server port to open (up to 300 seconds)
    - Tears down the server after all tests finish

    NOTE(review): the `api_url` parameter is never used in this body;
    presumably it is requested only for fixture ordering — confirm before
    removing it.
    """
    print("Pre-test port cleanup...")
    # Two API servers are launched (--num-servers 2), so the "+1" sibling of
    # each base port is also in use and must be cleaned as well.
    ports_to_add = [
        FD_API_PORT + 1,
        FD_METRICS_PORT + 1,
        FD_CACHE_QUEUE_PORT + 1,
        FD_ENGINE_QUEUE_PORT + 1,
    ]
    for port in ports_to_add:
        if port not in PORTS_TO_CLEAN:
            PORTS_TO_CLEAN.append(port)
    clean_ports(PORTS_TO_CLEAN)
    print("log dir clean ")
    # Remove stale worker logs from any previous run.
    if os.path.exists("log") and os.path.isdir("log"):
        shutil.rmtree("log")
    # Model weights are resolved under $MODEL_PATH when set, otherwise the
    # current working directory.
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle")
    else:
        model_path = "./ernie-4_5-21b-a3b-bf16-paddle"
    # The MTP draft model for speculative decoding lives inside the main
    # model directory.
    mtp_model_path = os.path.join(model_path, "mtp")
    speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path}
    log_path = "server.log"
    # Launch two servers (TP=2 x DP=2 across 4 GPUs) with block-wise FP8
    # quantization, logprob support, MTP speculative decoding and CUDA graphs.
    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.multi_api_server",
        "--num-servers",
        "2",
        "--ports",
        f"{FD_API_PORT},{FD_API_PORT + 1}",
        "--metrics-ports",
        f"{FD_METRICS_PORT},{FD_METRICS_PORT + 1}",
        "--args",
        "--model",
        model_path,
        "--engine-worker-queue-port",
        f"{FD_ENGINE_QUEUE_PORT},{FD_ENGINE_QUEUE_PORT + 1}",
        "--cache-queue-port",
        f"{FD_CACHE_QUEUE_PORT},{FD_CACHE_QUEUE_PORT + 1}",
        "--tensor-parallel-size",
        "2",
        "--data-parallel-size",
        "2",
        "--max-model-len",
        "65536",
        "--max-num-seqs",
        "32",
        "--quantization",
        "block_wise_fp8",
        "--enable-logprob",
        "--speculative-config",
        json.dumps(speculative_config),
        "--graph-optimization-config",
        '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}',
    ]
    # Start subprocess in new process group
    if os.path.exists("log"):
        shutil.rmtree("log")
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(
            cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
        )
    # Wait up to 300 seconds for API server to be ready
    for _ in range(300):
        if is_port_open("127.0.0.1", FD_API_PORT):
            print(f"Server is up on port {FD_API_PORT}")
            break
        time.sleep(1)
    else:
        print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process.pid, signal.SIGTERM)
            # clean()
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
    yield  # Run tests
    print("\n===== Post-test server cleanup... =====")
    try:
        # SIGTERM the whole process group, then give workers time to exit.
        os.killpg(process.pid, signal.SIGTERM)
        # clean()
        time.sleep(10)
        print(f"server (pid={process.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API server: {e}")
@pytest.fixture(scope="session")
def api_url(request):
    """Return the chat-completions endpoint URL of the server under test."""
    url = f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions"
    return url
@pytest.fixture(scope="session")
def metrics_url(request):
    """Return the metrics endpoint URL of the server under test."""
    url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
    return url
@pytest.fixture
def headers():
    """Return the common HTTP request headers shared by the tests."""
    common_headers = {"Content-Type": "application/json"}
    return common_headers
def send_request(url, payload, timeout=60):
    """
    POST *payload* as JSON to *url* and return the Response.

    Returns None (after printing a diagnostic) on timeout or any other
    request-level failure instead of raising.
    """
    request_headers = {
        "Content-Type": "application/json",
    }
    try:
        res = requests.post(url, headers=request_headers, json=payload, timeout=timeout)
    except requests.exceptions.Timeout:
        print(f"❌ Request timed out (>{timeout} seconds)")
        return None
    except requests.exceptions.RequestException as e:
        print(f"❌ Request failed: {e}")
        return None
    print("🟢 Receiving response...\n")
    return res
def get_stream_chunks(response):
    """
    Decode an SSE-style streaming response into a list of JSON chunks.

    Strips the 'data: ' prefix from each line, stops at the '[DONE]'
    sentinel, skips empty lines, and JSON-decodes everything else.

    Args:
        response: HTTP response returned by send_request().

    Returns:
        List[dict]: Parsed stream chunks in arrival order.
    """
    parsed_chunks = []
    if response.status_code != 200:
        print(f"Request failed, status code: {response.status_code}")
        print("Response body:", response.text)
        return parsed_chunks
    for raw_line in response.iter_lines(decode_unicode=True):
        if not raw_line:
            continue
        if raw_line.startswith("data: "):
            raw_line = raw_line[len("data: ") :]
        if raw_line.strip() == "[DONE]":
            break
        try:
            parsed_chunks.append(json.loads(raw_line))
        except Exception as e:
            print(f"Failed to parse chunk: {e}, raw line: {raw_line}")
    return parsed_chunks
def get_token_list(response):
    """
    Collect generated token strings from a non-streaming response.

    Reads token entries from `choices[0].logprobs.content` and returns
    their "token" values in generation order, skipping entries without
    one. Mainly used for stop-sequence validation.

    Args:
        response (dict): JSON-decoded inference response.

    Returns:
        List[str]: Generated token strings (empty on extraction failure).
    """
    try:
        entries = response["choices"][0]["logprobs"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        print(f"Failed to extract logprobs: {e}")
        return []
    token_list = [entry["token"] for entry in entries if entry.get("token") is not None]
    print(f"Token List: {token_list}")
    return token_list
def extract_logprobs(chunks):
    """
    Pull token-level logprob records out of streaming response chunks.

    Chunks without choices (e.g. usage-only chunks) and chunks without
    logprob content are skipped; the rest are aggregated in generation
    order.

    Args:
        chunks (List[dict]): Parsed streaming chunks.

    Returns:
        List[List[dict]]: Structured logprobs for each generated token.
    """
    collected = []
    for chunk in chunks:
        choice_list = chunk.get("choices")
        if not choice_list:
            continue
        chunk_logprobs = choice_list[0].get("logprobs")
        if not chunk_logprobs or not chunk_logprobs.get("content"):
            continue
        records = [
            {
                "token": entry["token"],
                "logprob": entry["logprob"],
                "top_logprobs": [
                    {
                        "token": candidate["token"],
                        "logprob": candidate["logprob"],
                    }
                    for candidate in entry.get("top_logprobs", [])
                ],
            }
            for entry in chunk_logprobs["content"]
        ]
        collected.append(records)
    return collected
def test_text_diff(api_url):
    """
    Check deterministic streaming output against a fixed text baseline.

    Uses fixed decoding parameters and a fixed seed, joins the streamed
    content deltas, and compares the full text byte-for-byte with the
    stored baseline file.
    """
    payload = {
        "stream": True,
        "seed": 21,
        "top_p": 0,
        "stop": ["</s>", "<eos>", "<|endoftext|>", "<|im_end|>"],
        "chat_template_kwargs": {
            "options": {"thinking_mode": "close"},
        },
        "bad_words_token_ids": [101031, 101032, 101027, 101028, 101023, 101024],
        "messages": [{"role": "user", "content": "解释一下温故而知新"}],
    }
    chunks = get_stream_chunks(send_request(url=api_url, payload=payload))
    generated = "".join(chunk["choices"][0]["delta"]["content"] for chunk in chunks)
    print(generated)
    # The baseline file sits next to the model weights when MODEL_PATH is set.
    model_dir = os.getenv("MODEL_PATH")
    baseline_file = os.path.join(model_dir, "21b_ep4_mtp_text_baseline.txt") if model_dir else "21b_ep4_mtp_text_baseline.txt"
    with open(baseline_file, "r", encoding="utf-8") as fh:
        expected = fh.read()
    assert generated == expected, f"Text mismatch with baseline\nresult: {generated}\nbaseline: {expected}"
def test_chat_usage_stream(api_url):
    """
    Validate token usage accounting for chat completions in streaming mode.

    Asserts that generated content is non-empty, completion_tokens respects
    the min/max constraints, and total_tokens equals
    prompt_tokens + completion_tokens.
    """
    payload = {
        "stream": True,
        "stream_options": {
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        "messages": [{"role": "user", "content": "解释一下温故而知新"}],
        "min_tokens": 10,
        "max_tokens": 50,
    }
    chunks = get_stream_chunks(send_request(url=api_url, payload=payload))
    # The final chunk carries the usage record; everything before it is content.
    content = "".join(chunk["choices"][0]["delta"]["content"] for chunk in chunks[:-1])
    assert content != "", "Empty generation result"
    usage = chunks[-1]["usage"]
    assert usage["completion_tokens"] <= payload["max_tokens"]
    assert usage["completion_tokens"] >= payload["min_tokens"]
    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
def test_chat_usage_non_stream(api_url):
    """
    Validate token usage accounting for chat completions in non-streaming mode.
    """
    payload = {
        "stream": False,
        "messages": [{"role": "user", "content": "解释一下温故而知新"}],
        "temperature": 1.0,
        "seed": 21,
        "top_p": 0,
        "stop": ["</s>", "<eos>", "<|endoftext|>", "<|im_end|>"],
        "min_tokens": 10,
        "max_tokens": 50,
        "chat_template_kwargs": {
            "options": {"thinking_mode": "close"},
        },
        "bad_words_token_ids": [101031, 101032, 101027, 101028, 101023, 101024],
    }
    body = send_request(url=api_url, payload=payload).json()
    usage = body["usage"]
    content = body["choices"][0]["message"]["content"]
    assert content != "", "Empty generation result"
    assert usage["completion_tokens"] <= payload["max_tokens"]
    assert usage["completion_tokens"] >= payload["min_tokens"]
    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
def test_non_chat_usage_stream(api_url):
    """
    Validate usage accounting for the completions API in streaming mode.
    """
    payload = {
        "model": "null",
        "prompt": "你好,你是谁?",
        "stream": True,
        "stream_options": {
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        "min_tokens": 10,
        "max_tokens": 50,
        "seed": 566,
        "chat_template_kwargs": {
            "options": {"thinking_mode": "close"},
        },
        "bad_words_token_ids": [101031, 101032, 101027, 101028, 101023, 101024],
    }
    # Switch the chat endpoint over to the plain completions endpoint.
    completions_url = api_url.replace("chat/completions", "completions")
    chunks = get_stream_chunks(send_request(url=completions_url, payload=payload))
    usage = chunks[-1]["usage"]
    assert usage["completion_tokens"] <= payload["max_tokens"]
    assert usage["completion_tokens"] >= payload["min_tokens"]
    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
def test_non_chat_usage_non_stream(api_url):
    """
    Validate usage accounting for the completions API in non-streaming mode.
    """
    payload = {
        "model": "null",
        "prompt": "你好,你是谁?",
        "stream": False,
        "min_tokens": 10,
        "max_tokens": 50,
        "seed": 566,
        "chat_template_kwargs": {
            "options": {"thinking_mode": "close"},
        },
        "bad_words_token_ids": [101031, 101032, 101027, 101028, 101023, 101024],
    }
    # Switch the chat endpoint over to the plain completions endpoint.
    completions_url = api_url.replace("chat/completions", "completions")
    body = send_request(url=completions_url, payload=payload).json()
    usage = body["usage"]
    generated_text = body["choices"][0]["text"]
    assert generated_text != "", "Empty generation result"
    assert usage["completion_tokens"] <= payload["max_tokens"]
    assert usage["completion_tokens"] >= payload["min_tokens"]
    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
def test_non_stream_with_logprobs(api_url):
    """
    Compare deterministic non-streaming logprobs output against a JSON baseline.
    """
    payload = {
        "stream": False,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 3,
        "logprobs": True,
        "top_logprobs": 5,
        "seed": 21,
        "min_tokens": 1,
        "chat_template_kwargs": {
            "options": {"thinking_mode": "close"},
        },
        "bad_words_token_ids": [101031, 101032, 101027, 101028, 101023, 101024],
    }
    body = send_request(url=api_url, payload=payload).json()
    actual_logprobs = body["choices"][0]["logprobs"]
    # The baseline file sits next to the model weights when MODEL_PATH is set.
    model_dir = os.getenv("MODEL_PATH")
    baseline_name = "21b_ep4_mtp_logprobs_non_stream_static_baseline.txt"
    baseline_file = os.path.join(model_dir, baseline_name) if model_dir else baseline_name
    with open(baseline_file, "r", encoding="utf-8") as fh:
        expected = json.load(fh)
    assert actual_logprobs == expected
def test_stream_with_logprobs(api_url):
    """
    Compare deterministic streaming logprobs output against a JSON baseline.
    """
    payload = {
        "stream": True,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 3,
        "logprobs": True,
        "top_logprobs": 5,
        "min_tokens": 1,
        "seed": 21,
    }
    chunks = get_stream_chunks(send_request(url=api_url, payload=payload))
    actual_logprobs = extract_logprobs(chunks)
    # The baseline file sits next to the model weights when MODEL_PATH is set.
    model_dir = os.getenv("MODEL_PATH")
    baseline_name = "21b_ep4_mtp_logprobs_stream_static_baseline.txt"
    baseline_file = os.path.join(model_dir, baseline_name) if model_dir else baseline_name
    with open(baseline_file, "r", encoding="utf-8") as fh:
        expected = json.load(fh)
    assert actual_logprobs == expected