# Source: mirror of https://github.com/PaddlePaddle/FastDeploy.git
# Synced 2026-04-23 00:17:25 +08:00, commit edd31e8849
# ("[tests] Add Paddle attention determinism tests and refactor resource manager")
"""
|
|
Determinism Feature Verification Test
|
|
|
|
Reference: test_batch_invariant.py. Verifies whether determinism works correctly.
|
|
|
|
Usage:
|
|
# Step 1: Start server with determinism disabled
|
|
bash ./tests/ce/deterministic/start_fd.sh 0
|
|
|
|
# Step 2: Run non-deterministic test (expected: results differ)
|
|
python ./tests/ce/deterministic/test_determinism_verification.py --phase non-deterministic
|
|
|
|
# Step 3: Stop server
|
|
bash fastdeploy/stop.sh
|
|
|
|
# Step 4: Start server with determinism enabled and logging ON
|
|
bash ./tests/ce/deterministic/start_fd.sh 1 1
|
|
|
|
# Step 5: Run deterministic test (expected: results consistent)
|
|
python ./tests/ce/deterministic/test_determinism_verification.py --phase deterministic
|
|
|
|
Arguments:
|
|
--phase {deterministic,non-deterministic}
|
|
Test mode
|
|
- deterministic: determinism enabled with logging, expected MD5 consistency
|
|
- non-deterministic: determinism disabled, expected different outputs
|
|
--api-url API endpoint URL (default: http://localhost:8188/v1/chat/completions)
|
|
--model Model name (default: Qwen/Qwen2.5-7B)
|
|
--log-file Server log file path (default: log/workerlog.0)
|
|
--repeat Number of repeat rounds for non-deterministic phase (default: 3)
|
|
|
|
Note: The deterministic test requires FD_DETERMINISTIC_LOG_MODE=1 to extract MD5 values
|
|
from logs for verification.
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import sys
|
|
|
|
import aiohttp
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
|
|
|
# Defaults (overridable via CLI args or env vars)
|
|
DEFAULT_API_URL = "http://localhost:8188/v1/chat/completions"
|
|
DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-7B"
|
|
DEFAULT_LOG_FILE = "log/workerlog.0"
|
|
DEFAULT_NON_DET_REPEAT = 3
|
|
|
|
# Target prompt (we care about its determinism)
|
|
TARGET_PROMPT = "你好,请简单介绍一下自己。"
|
|
|
|
# Distractor prompts (different content, used to create batch interference)
|
|
DISTRACTOR_PROMPTS = [
|
|
"今天天气怎么样?",
|
|
"什么是人工智能?",
|
|
"如何学习编程?",
|
|
"什么是机器学习?",
|
|
"Python 是什么?",
|
|
]
|
|
|
|
# Generation length for target prompt (fixed, longer)
|
|
TARGET_MAX_TOKENS = 128
|
|
|
|
# Generation length range for distractor prompts
|
|
DISTRACTOR_MAX_TOKENS_RANGE = (8, 32)
|
|
|
|
# Health check settings
|
|
HEALTH_CHECK_INTERVAL = 5
|
|
HEALTH_CHECK_TIMEOUT = 300
|
|
|
|
|
|
def parse_args():
    """Parse command-line options for the verification script."""
    ap = argparse.ArgumentParser(description="Determinism feature verification test")
    ap.add_argument(
        "--phase",
        required=True,
        choices=["deterministic", "non-deterministic"],
        help="Test mode: deterministic (enabled) or non-deterministic (disabled)",
    )
    ap.add_argument(
        "--api-url",
        default=os.environ.get("FD_TEST_API_URL", DEFAULT_API_URL),
        help=f"API endpoint URL (default: {DEFAULT_API_URL})",
    )
    ap.add_argument(
        "--model",
        default=os.environ.get("FD_TEST_MODEL", DEFAULT_MODEL_NAME),
        help=f"Model name (default: {DEFAULT_MODEL_NAME})",
    )
    ap.add_argument(
        "--log-file",
        default=os.environ.get("FD_TEST_LOG_FILE", DEFAULT_LOG_FILE),
        help=f"Server log file path (default: {DEFAULT_LOG_FILE})",
    )
    ap.add_argument(
        "--repeat",
        type=int,
        default=int(os.environ.get("FD_TEST_REPEAT", DEFAULT_NON_DET_REPEAT)),
        help=f"Number of repeat rounds for non-deterministic phase (default: {DEFAULT_NON_DET_REPEAT})",
    )
    return ap.parse_args()
def extract_md5_from_log(log_file: str, request_id: str) -> list[str]:
    """Extract all decode step MD5 values for the specified request from log file.

    Scans ``log_file`` for lines of the form
    ``[DETERMINISM-MD5-REQ] <request_id> | decode ... hidden_states_md5=<hex>``
    and collects the hex digests in file order.

    Args:
        log_file: Path to the server worker log (e.g. log/workerlog.0).
        request_id: Request id whose per-decode-step MD5 values to collect.

    Returns:
        List of MD5 hex strings, one per matching decode-step line; empty if the
        log file is missing or contains no matching lines.
    """
    md5_values: list[str] = []
    # Compile both patterns once, outside the per-line loop, instead of paying
    # the re-module cache lookup on every log line.
    line_pattern = re.compile(rf"\[DETERMINISM-MD5-REQ\] {re.escape(request_id)} \| decode")
    md5_pattern = re.compile(r"hidden_states_md5=([a-f0-9]+)")
    try:
        with open(log_file, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                if line_pattern.search(line):
                    match = md5_pattern.search(line)
                    if match:
                        md5_values.append(match.group(1))
    except FileNotFoundError:
        logger.warning("Log file not found: %s", log_file)
    return md5_values
async def wait_for_server(api_url: str) -> None:
    """Block until the server answers GET /v1/models with HTTP 200.

    Probes every HEALTH_CHECK_INTERVAL seconds; raises RuntimeError once
    HEALTH_CHECK_TIMEOUT seconds have been spent waiting.
    """
    base_url = api_url.rsplit("/v1/", 1)[0]
    health_url = f"{base_url}/v1/models"
    probe_timeout = aiohttp.ClientTimeout(total=10)

    logger.info("Waiting for server to be ready at %s ...", base_url)
    waited = 0
    while waited < HEALTH_CHECK_TIMEOUT:
        ready = False
        try:
            async with aiohttp.ClientSession(timeout=probe_timeout) as session:
                async with session.get(health_url) as resp:
                    ready = resp.status == 200
        except (aiohttp.ClientError, asyncio.TimeoutError):
            # Server not up yet (connection refused / probe timed out); retry.
            ready = False
        if ready:
            logger.info("Server is ready.")
            return
        await asyncio.sleep(HEALTH_CHECK_INTERVAL)
        waited += HEALTH_CHECK_INTERVAL
        logger.info(" Still waiting... (%ds/%ds)", waited, HEALTH_CHECK_TIMEOUT)

    raise RuntimeError(
        f"Server not ready after {HEALTH_CHECK_TIMEOUT}s. "
        f"Check that the server is running and accessible at {base_url}"
    )
async def send_request(
    session: aiohttp.ClientSession, api_url: str, prompt: str, request_id: str, max_tokens: int, model: str
) -> str:
    """Send one chat-completion request and return the generated message content.

    Raises aiohttp.ClientResponseError for non-2xx responses.
    """
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.8,
        "top_p": 0.9,
        "max_tokens": max_tokens,
        "request_id": request_id,
    }
    # Generous per-request timeout: long generations can take a while.
    async with session.post(api_url, json=payload, timeout=aiohttp.ClientTimeout(total=300)) as response:
        response.raise_for_status()
        body = await response.json()
    return body["choices"][0]["message"]["content"]
async def run_test_case(
    session: aiohttp.ClientSession,
    api_url: str,
    test_name: str,
    test_plan: list[tuple[str, str, bool]],
    model: str,
) -> list[tuple[str, str]]:
    """
    Run a test case.

    Args:
        api_url: API endpoint URL.
        test_plan: List of (request_id, prompt, is_target) tuples.
        model: Model name to use for the request.

    Returns:
        List of (request_id, result) tuples for target requests only.
    """
    n_targets = sum(1 for _, _, flag in test_plan if flag)
    logger.info(
        "[Test %s] %d requests (target=%d, distractor=%d)",
        test_name,
        len(test_plan),
        n_targets,
        len(test_plan) - n_targets,
    )

    # Fire all requests concurrently so they land in the same batch window.
    # Targets use a fixed length; distractors get a random short length.
    coros = [
        send_request(
            session,
            api_url,
            prompt,
            req_id,
            TARGET_MAX_TOKENS if is_target else random.randint(*DISTRACTOR_MAX_TOKENS_RANGE),
            model,
        )
        for req_id, prompt, is_target in test_plan
    ]
    results = await asyncio.gather(*coros)

    targets: list[tuple[str, str]] = []
    for (req_id, _, is_target), output in zip(test_plan, results):
        marker = "[Target]" if is_target else "[Distractor]"
        logger.info(" %s %s: %s...", marker, req_id, output[:50])
        if is_target:
            targets.append((req_id, output))

    return targets
def _print_section(title: str) -> None:
|
|
"""Print a section banner."""
|
|
print("\n" + "=" * 80)
|
|
print(title)
|
|
print("=" * 80)
|
|
|
|
|
|
def _check_consistency(
    items: dict[str, list[str]],
    label: str,
    expect_consistent: bool,
    detail_formatter=None,
) -> bool:
    """
    Unified consistency check logic.

    Args:
        items: Dict mapping unique_key -> list of request_ids sharing that key.
        label: Description label (e.g. "Text", "MD5 Step 1").
        expect_consistent: True expects all keys identical, False expects differences.
        detail_formatter: Optional callable(key) -> str for displaying details on mismatch.

    Returns:
        True if result matches expectation, False otherwise.
    """
    expected_desc = "consistent" if expect_consistent else "inconsistent"
    _print_section(f"{label} Consistency Check (Expected: {expected_desc})")

    if not items:
        logger.warning("No %s values found!", label)
        return False

    # A single unique key means every request produced the same value.
    all_same = len(items) == 1

    print(f"\n Unique values: {len(items)}")
    if all_same:
        _, reqs = next(iter(items.items()))
        print(f" All {len(reqs)} requests share the same value")
    else:
        for idx, (key, reqs) in enumerate(items.items(), start=1):
            suffix = f" ({detail_formatter(key)})" if detail_formatter else ""
            print(f" Group {idx}: {', '.join(reqs)}{suffix}")

    print("-" * 80)

    outcome_ok = all_same == expect_consistent
    actual_desc = "consistent" if all_same else "inconsistent"
    print(f" {'PASS' if outcome_ok else 'FAIL'}: expected {expected_desc}, actual {actual_desc}")
    print("=" * 80)

    return outcome_ok
def compare_text_consistency(target_results: list[tuple[str, str]], expect_consistent: bool = True) -> bool:
    """Compare target request text content against expected consistency.

    Groups target outputs by the MD5 of their text and delegates the
    pass/fail decision to _check_consistency.
    """
    groups: dict[str, list[str]] = {}
    samples: dict[str, str] = {}
    for req_id, text in target_results:
        digest = hashlib.md5(text.encode("utf-8")).hexdigest()
        groups.setdefault(digest, []).append(req_id)
        # Keep the first text seen per digest for display on mismatch.
        samples.setdefault(digest, text)

    return _check_consistency(
        groups,
        label="Text",
        expect_consistent=expect_consistent,
        detail_formatter=lambda digest: repr(samples[digest][:50]),
    )
def compare_md5_consistency(all_md5: dict[str, list[str]], expect_consistent: bool = True) -> bool:
    """
    Compare MD5 results across ALL decode steps and verify against expected consistency.

    For each decode step, checks that all target requests produced identical hidden_states_md5.
    All steps must be consistent for the overall check to pass.

    Args:
        all_md5: Dict mapping request_id -> list of per-decode-step MD5 hex strings.
        expect_consistent: True expects every step identical across requests.

    Returns:
        True if the observed consistency matches the expectation, False otherwise.
    """
    if not all_md5:
        logger.warning("No MD5 values found!")
        return False

    # Find the minimum number of decode steps across all requests
    # (requests may produce different step counts; only the common prefix is comparable).
    min_steps = min(len(md5s) for md5s in all_md5.values())
    if min_steps == 0:
        logger.warning("Some requests have no decode step MD5 values!")
        return False

    req_ids = list(all_md5.keys())
    logger.info("Checking MD5 consistency across %d decode steps for %d requests", min_steps, len(req_ids))

    failed_steps = []

    for step in range(min_steps):
        # Group request ids by the MD5 value they produced at this step.
        step_md5s: dict[str, list[str]] = {}
        for req_id in req_ids:
            md5_val = all_md5[req_id][step]
            step_md5s.setdefault(md5_val, []).append(req_id)

        # One unique value means every request agreed at this step.
        step_consistent = len(step_md5s) == 1

        if not step_consistent:
            failed_steps.append(step)

        # Print per-step result
        if step_consistent:
            md5_val = next(iter(step_md5s))
            logger.info(" Decode step %d: CONSISTENT (md5=%s)", step + 1, md5_val)
        else:
            logger.warning(" Decode step %d: INCONSISTENT (%d different values)", step + 1, len(step_md5s))
            for md5_val, reqs in step_md5s.items():
                logger.warning(" md5=%s: %s", md5_val, ", ".join(reqs))

    is_consistent = len(failed_steps) == 0

    _print_section(f"MD5 Consistency Check (all {min_steps} decode steps)")
    if is_consistent:
        print(f" All {min_steps} decode steps are consistent across {len(req_ids)} requests")
    else:
        print(f" {len(failed_steps)}/{min_steps} decode steps are INCONSISTENT")
        # Report steps 1-based to match the per-step log lines above.
        print(f" Failed steps: {[s + 1 for s in failed_steps]}")
    print("-" * 80)

    passed = is_consistent == expect_consistent
    expected_desc = "consistent" if expect_consistent else "inconsistent"
    actual_desc = "consistent" if is_consistent else "inconsistent"
    status = "PASS" if passed else "FAIL"
    print(f" {status}: expected {expected_desc}, actual {actual_desc}")
    print("=" * 80)

    return passed
# Test cases: (name, plan) where plan is [(request_id, prompt, is_target)]
# Each case mixes one target request with a growing number of distractors so
# that the target is decoded under different batch compositions.
TEST_CASES = [
    (
        "case1: Single request (target only)",
        [
            ("case1-target", TARGET_PROMPT, True),
        ],
    ),
    (
        "case2: Two requests (1 target + 1 distractor)",
        [
            ("case2-distract-a", DISTRACTOR_PROMPTS[0], False),
            ("case2-target", TARGET_PROMPT, True),
        ],
    ),
    (
        "case3: Four requests (1 target + 3 distractors)",
        [
            ("case3-distract-a", DISTRACTOR_PROMPTS[0], False),
            ("case3-distract-b", DISTRACTOR_PROMPTS[1], False),
            ("case3-target", TARGET_PROMPT, True),
            ("case3-distract-c", DISTRACTOR_PROMPTS[2], False),
        ],
    ),
    (
        "case4: Six requests (1 target + 5 distractors)",
        [
            ("case4-distract-a", DISTRACTOR_PROMPTS[0], False),
            ("case4-distract-b", DISTRACTOR_PROMPTS[1], False),
            ("case4-distract-c", DISTRACTOR_PROMPTS[2], False),
            ("case4-distract-d", DISTRACTOR_PROMPTS[3], False),
            ("case4-target", TARGET_PROMPT, True),
            ("case4-distract-e", DISTRACTOR_PROMPTS[4], False),
        ],
    ),
]
def _build_test_plan(test_cases, repeat: int = 1):
|
|
"""
|
|
Build test plan with optional repetition.
|
|
|
|
For repeat > 1, each test case is duplicated with round-suffixed request_ids.
|
|
This increases sample size for non-deterministic testing.
|
|
"""
|
|
if repeat <= 1:
|
|
return test_cases
|
|
|
|
expanded = []
|
|
for case_name, plan in test_cases:
|
|
for r in range(repeat):
|
|
round_name = f"{case_name} (round {r + 1})"
|
|
round_plan = [(f"{req_id}-r{r + 1}", prompt, is_target) for req_id, prompt, is_target in plan]
|
|
expanded.append((round_name, round_plan))
|
|
return expanded
|
|
|
|
|
|
async def main() -> int:
    """Run the full verification flow; return 0 on pass, 1 on fail."""
    args = parse_args()
    is_deterministic = args.phase == "deterministic"

    _print_section("Determinism Feature Verification Test")
    print(f"\n Test mode: {args.phase}")
    print(f" API URL: {args.api_url}")
    print(f" Model: {args.model}")
    print(f" Log file: {args.log_file}")
    if is_deterministic:
        print(" Expected: All target requests have consistent MD5 values")
    else:
        print(f" Expected: Target requests produce different outputs (repeat={args.repeat})")
    print("=" * 80)

    # Wait for server to be ready
    await wait_for_server(args.api_url)

    # Build test plan (repeat for non-deterministic to reduce flaky probability)
    repeat = args.repeat if not is_deterministic else 1
    test_plan = _build_test_plan(TEST_CASES, repeat=repeat)

    async with aiohttp.ClientSession() as session:
        all_target_results: list[tuple[str, str]] = []
        for test_name, plan in test_plan:
            target_outputs = await run_test_case(session, args.api_url, test_name, plan, args.model)
            all_target_results.extend(target_outputs)
            # Brief pause between cases so batches from different cases don't overlap.
            await asyncio.sleep(1)

    target_request_ids = [req_id for req_id, _ in all_target_results]

    _print_section("All tests completed, starting verification...")

    if is_deterministic:
        # Deterministic mode: compare MD5 across all decode steps
        all_md5 = {}
        for req_id in target_request_ids:
            md5_values = extract_md5_from_log(args.log_file, req_id)
            if md5_values:
                all_md5[req_id] = md5_values
                logger.info("%s: %d decode steps found", req_id, len(md5_values))
            else:
                logger.warning("%s: No MD5 logs found", req_id)

        if all_md5:
            passed = compare_md5_consistency(all_md5, expect_consistent=True)
        else:
            # No MD5 logs at all (e.g. FD_DETERMINISTIC_LOG_MODE disabled):
            # fall back to comparing the generated text directly.
            logger.warning("No MD5 logs found, fallback to text consistency check")
            passed = compare_text_consistency(all_target_results, expect_consistent=True)
    else:
        # Non-deterministic mode: compare text content
        passed = compare_text_consistency(all_target_results, expect_consistent=False)

    _print_section("Final Result")
    if passed:
        print(f" PASS: {args.phase} mode verified successfully")
    else:
        print(f" FAIL: {args.phase} mode verification failed")
    print("=" * 80)

    return 0 if passed else 1
if __name__ == "__main__":
|
|
sys.exit(asyncio.run(main()))
|