From edd31e8849700cb2734f9dec5ba5ab9d69de5866 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Fri, 27 Feb 2026 11:31:51 +0800 Subject: [PATCH] [Feature] Add Deterministic Inference Support (#6476) * add * [tests] Add Paddle attention determinism tests and refactor resource manager Add comprehensive determinism tests for Paddle attention layer and refactor resource manager for deterministic mode support. Co-Authored-By: Claude Opus 4.6 * add * add * add * add * add more * add more * fixsome * fixsome * fix bugs * fix bugs * only in gpu * add docs * fix comments * fix some * fix some * fix comments * add more * fix potential problem * remove not need * remove not need * remove no need * fix bug * fix bugs * fix comments * fix comments * Update tests/ce/deterministic/test_determinism_verification.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/inter_communicator/test_ipc_signal.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/layers/test_paddle_attention_determinism.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/engine/test_sampling_params_determinism.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/layers/test_paddle_attention_determinism.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/layers/test_paddle_attention_determinism_standalone.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix comments * fix import error * fix a bug * fix bugs * fix bugs * fix coverage * refine codes * refine code * fix comments * fix comments * fix comments * rm not need * fix allreduce large tensor bug * mv log files * mv log files * add files --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../gpu_ops/custom_all_reduce/all_reduce.cu | 4 +- fastdeploy/distributed/communication.py | 98 +++- 
.../custom_all_reduce/custom_all_reduce.py | 15 +- fastdeploy/engine/sampling_params.py | 6 +- .../engine/sched/resource_manager_v1.py | 59 ++- fastdeploy/envs.py | 20 + fastdeploy/logger/deterministic_logger.py | 220 ++++++++ .../batch_invariant_ops.py | 28 +- fastdeploy/worker/gpu_model_runner.py | 46 ++ fastdeploy/worker/input_batch.py | 5 + fastdeploy/worker/worker_process.py | 15 + .../test_batch_invariance_op_addmm.py | 20 + tests/ce/deterministic/start_fd.sh | 29 ++ .../test_determinism_verification.py | 470 +++++++++++++++++ tests/conftest.py | 32 +- .../deterministic/test_determinism_offline.py | 357 +++++++++++++ .../test_determinism_standalone.py | 296 +++++++++++ tests/distributed/allreduce_deterministic.py | 212 ++++++++ .../test_allreduce_deterministic_launch.py | 36 ++ .../test_sampling_params_determinism.py | 84 ++++ tests/inter_communicator/test_ipc_signal.py | 345 +++++++++++++ ...st_flash_attention_versions_determinism.py | 177 +++++++ .../test_chunked_prefill_determinism.py | 472 ++++++++++++++++++ tests/worker/test_deterministic_logger.py | 345 +++++++++++++ 24 files changed, 3364 insertions(+), 27 deletions(-) create mode 100644 fastdeploy/logger/deterministic_logger.py create mode 100644 tests/ce/deterministic/start_fd.sh create mode 100644 tests/ce/deterministic/test_determinism_verification.py create mode 100644 tests/deterministic/test_determinism_offline.py create mode 100644 tests/deterministic/test_determinism_standalone.py create mode 100644 tests/distributed/allreduce_deterministic.py create mode 100644 tests/distributed/test_allreduce_deterministic_launch.py create mode 100644 tests/engine/test_sampling_params_determinism.py create mode 100644 tests/inter_communicator/test_ipc_signal.py create mode 100644 tests/layers/test_flash_attention_versions_determinism.py create mode 100644 tests/scheduler/test_chunked_prefill_determinism.py create mode 100644 tests/worker/test_deterministic_logger.py diff --git 
a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu index ac7006d2c6..ca659c0074 100644 --- a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu +++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu @@ -58,7 +58,7 @@ void decode_alltoall_transpose(paddle::Tensor& inp, auto fa = reinterpret_cast(_fa); auto stream = inp.stream(); - auto input_size = inp.numel() * 2; + auto input_size = inp.numel() * phi::SizeOf(inp.dtype()); auto token_num = inp.shape()[0]; auto hidden_size = inp.shape()[1]; auto reg_buffer = reinterpret_cast(_reg_buffer); @@ -121,7 +121,7 @@ void all_reduce(paddle::Tensor& inp, auto fa = reinterpret_cast(_fa); auto stream = inp.stream(); - auto input_size = inp.numel() * 2; + auto input_size = inp.numel() * phi::SizeOf(inp.dtype()); auto reg_buffer = reinterpret_cast(_reg_buffer); if (reg_buffer) { cudaMemcpyAsync( diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py index d5b72eebda..7a3f0b8aea 100644 --- a/fastdeploy/distributed/communication.py +++ b/fastdeploy/distributed/communication.py @@ -20,8 +20,24 @@ import paddle import paddle.distributed as dist from paddle.distributed import fleet +import fastdeploy.envs as envs from fastdeploy.utils import register_custom_python_op +# Constants +SUPPORTED_DTYPES = (paddle.float32, paddle.float16, paddle.bfloat16) + + +def tensor_byte_size(tensor: paddle.Tensor) -> int: + """Compute tensor size in bytes from .shape to avoid numel() which + triggers cudaErrorStreamCaptureImplicit during CUDA Graph capture.""" + size = 1 + for s in tensor.shape: + size *= s + size *= tensor.element_size() + return size + + +# Global custom all-reduce instance _TP_AR = None @@ -36,8 +52,11 @@ def capture_custom_allreduce(): def use_custom_allreduce( - tp_group: paddle.distributed.communication.group.Group = None, custom_all_reduce_max_bytes: int = 8192 * 1024 -): + tp_group: paddle.distributed.communication.group.Group = 
None, + custom_all_reduce_max_bytes: int = None, +) -> None: + if custom_all_reduce_max_bytes is None: + custom_all_reduce_max_bytes = envs.FD_CUSTOM_AR_MAX_SIZE_MB * 1024 * 1024 if tp_group is None: hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() @@ -53,17 +72,71 @@ def custom_ar_clear_ipc_handles(): _TP_AR.clear_ipc_handles() +def _ensure_deterministic_ready(input_: paddle.Tensor) -> None: + """Validate all preconditions for deterministic all-reduce.""" + global _TP_AR + # Lazy initialization of custom all-reduce + if _TP_AR is None: + try: + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + if tp_group is not None and tp_group.nranks > 1: + use_custom_allreduce(tp_group) + except Exception as e: + raise RuntimeError( + "DETERMINISTIC_MODE is enabled but cannot auto-initialize custom all-reduce. " + "TP all-reduce would use NCCL which may produce non-deterministic results " + "due to floating-point accumulation order. " + "Ensure fleet is initialized before any TP operations, " + "or explicitly call use_custom_allreduce() beforehand." + ) from e + + if _TP_AR is None: + raise RuntimeError( + "DETERMINISTIC_MODE is enabled but custom all-reduce is not available. " + "Falling back to NCCL would produce non-deterministic results. " + "Ensure custom all-reduce is properly initialized via use_custom_allreduce()." + ) + + if input_.dtype not in SUPPORTED_DTYPES: + raise AssertionError( + f"DETERMINISTIC_MODE is enabled but input tensor dtype={input_.dtype} is not supported. " + f"Custom all-reduce only supports: {', '.join(str(d) for d in SUPPORTED_DTYPES)}. " + f"Input tensor shape: {input_.shape}, dtype: {input_.dtype}." 
+ ) + + # Compute size from .shape to avoid numel() which triggers + # cudaErrorStreamCaptureImplicit during CUDA Graph capture + inp_size = tensor_byte_size(input_) + + if inp_size % 16 != 0: + raise RuntimeError( + f"DETERMINISTIC_MODE is enabled but input tensor size ({inp_size} bytes) " + f"is not a multiple of 16. Custom all-reduce requires 16-byte aligned tensors. " + f"Input tensor shape: {input_.shape}, element_size: {input_.element_size()} bytes, " + f"total size: {inp_size} bytes." + ) + + if inp_size > _TP_AR.max_size: + raise RuntimeError( + f"DETERMINISTIC_MODE: input tensor ({inp_size} bytes) exceeds " + f"custom all-reduce max_size ({_TP_AR.max_size} bytes). " + f"Increase buffer size via: export FD_CUSTOM_AR_MAX_SIZE_MB=" + f"{(inp_size // (1024 * 1024)) + 1}" + ) + + try: - def tensor_model_parallel_all_reduce_infer_meta(x: "paddle.static.MetaTensor", group_) -> paddle.static.MetaTensor: + def tensor_model_parallel_all_reduce_infer_meta( + x: "paddle.static.MetaTensor", group_: paddle.distributed.communication.group.Group + ) -> paddle.static.MetaTensor: return paddle.static.MetaTensor(shape=x.shape, dtype=x.dtype) @register_custom_python_op( name="tensor_model_parallel_all_reduce", infer_meta=tensor_model_parallel_all_reduce_infer_meta, - input_names=[ - "input_", - ], + input_names=["input_"], output_names=["out"], inplace_map={}, ) @@ -72,13 +145,20 @@ try: group_: paddle.distributed.communication.group.Group = None, ) -> paddle.Tensor: """All-reduce the input tensor across model parallel group.""" + global _TP_AR if input_.shape[0] == 0: return input_ - global _TP_AR + + if envs.FD_DETERMINISTIC_MODE: + _ensure_deterministic_ready(input_) + return _TP_AR.custom_all_reduce(input_) + + # for performance, use custom all-reduce if possible if _TP_AR is not None and _TP_AR.should_custom_ar(input_): # TODO: supports different_group custom allreduce - input_ = _TP_AR.custom_all_reduce(input_) - elif paddle.in_dynamic_mode(): + return 
_TP_AR.custom_all_reduce(input_) + + if paddle.in_dynamic_mode(): if group_ is not None: dist.all_reduce(input_, group=group_) else: diff --git a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py index 4c081271a1..9c8b65f43a 100644 --- a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py +++ b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py @@ -22,6 +22,7 @@ import paddle import paddle.distributed as dist from paddle.distributed.communication.group import Group +from fastdeploy.distributed.communication import tensor_byte_size from fastdeploy.distributed.custom_all_reduce import cuda_wrapper from fastdeploy.model_executor.ops.gpu import ( all_reduce, @@ -133,16 +134,22 @@ class CustomAllreduce: lib.cudaFree(ctypes.c_void_p(pointers[rank])) def should_custom_ar(self, inp: paddle.Tensor): - if self.capturing: - return True - inp_size = inp.shape[0] * inp.shape[1] * inp.element_size() + inp_size = tensor_byte_size(inp) + if inp_size > self.max_size: + return False + # custom allreduce requires input byte size to be multiples of 16 if inp_size % 16 != 0: return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. 
if self.world_size == 2 or self.full_nvlink: - return inp_size < self.max_size + return True + + if self.capturing: + return True + return False def all_reduce( diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index ad5031fb34..7fcfef642a 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -276,7 +276,11 @@ class SamplingParams: def __post_init__(self): if self.seed is None: - self.seed = random.randint(0, 922337203685477580) + # Deterministic mode: use fixed seed + if envs.FD_DETERMINISTIC_MODE: + self.seed = 42 + else: + self.seed = random.randint(0, 922337203685477580) self._verify_args() def _verify_args(self) -> None: diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index ae98fac5fb..5410f98cb7 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -453,7 +453,42 @@ class ResourceManagerV1(ResourceManager): def _get_num_new_tokens(self, request, token_budget): # TODO: set condition to new _get_num_new_tokens num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens + assert num_new_tokens > 0, ( + f"Request {request.request_id} has no remaining tokens: " + f"need_prefill={request.need_prefill_tokens}, computed={request.num_computed_tokens}" + ) num_new_tokens = min(num_new_tokens, token_budget) + + # Deterministic mode: align chunk boundaries to split_kv_size + # This ensures batch-invariant attention by making each chunk + # a multiple of the split-KV block size (default 16) + if envs.FD_DETERMINISTIC_MODE: + split_kv_size = envs.FD_DETERMINISTIC_SPLIT_KV_SIZE + current_pos = request.num_computed_tokens + remaining_tokens = request.need_prefill_tokens - current_pos + + # Case 1: Final chunk - no alignment needed + if remaining_tokens < split_kv_size: + aligned_end = current_pos + remaining_tokens + else: + # Case 2: Need to align to 
split_kv_size boundary + # Calculate next boundary position + next_boundary = ((current_pos + split_kv_size - 1) // split_kv_size) * split_kv_size + tokens_to_boundary = next_boundary - current_pos + + # Not enough budget to reach the next boundary: defer to next iteration + if token_budget < tokens_to_boundary: + return 0 + + # Align to as many full boundaries as budget allows + aligned_end = ((current_pos + token_budget) // split_kv_size) * split_kv_size + + num_new_tokens = aligned_end - current_pos + # Don't exceed the original budget or remaining tokens + num_new_tokens = min( + num_new_tokens, token_budget, request.need_prefill_tokens - request.num_computed_tokens + ) + if ( current_platform.is_intel_hpu() and request.need_prefill_tokens - request.num_computed_tokens > token_budget @@ -466,7 +501,11 @@ class ResourceManagerV1(ResourceManager): return num_new_tokens inputs = request.multimodal_inputs - if inputs.get("patch_idx", None) is not None and inputs.get("patch_map", None) is not None: + if ( + inputs is not None + and inputs.get("patch_idx", None) is not None + and inputs.get("patch_map", None) is not None + ): pre_end_idx = request.num_computed_tokens new_end_idx = pre_end_idx + num_new_tokens @@ -541,7 +580,8 @@ class ResourceManagerV1(ResourceManager): request.video_end = end_patch_map["video_num"] request.audio_end = _compute_audio_prefix_count(new_end_idx, end_patch_idx) elif ( - inputs.get("images", None) is not None + inputs is not None + and inputs.get("images", None) is not None and inputs.get("image_patch_id", None) is not None and inputs.get("grid_thw", None) is not None ): @@ -790,6 +830,9 @@ class ResourceManagerV1(ResourceManager): req_index += 1 continue num_new_tokens = self._get_num_new_tokens(request, token_budget) + if num_new_tokens == 0: + req_index += 1 + continue num_new_block = self.get_new_block_nums(request, num_new_tokens) # Allocate blocks to prefill if self.cache_manager.can_allocate_gpu_blocks(num_new_block): @@ -863,6 
+906,12 @@ class ResourceManagerV1(ResourceManager): continue # Allocate blocks for the tokens that does not hit cache num_new_tokens = self._get_num_new_tokens(request, token_budget) + if num_new_tokens == 0: + if self.config.cache_config.enable_prefix_caching: + self._free_blocks(request) + skip_requests.append(request) + self.waiting.popleft() + continue num_new_block = self.get_new_block_nums(request, num_new_tokens) can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( request, num_new_block @@ -916,6 +965,12 @@ class ResourceManagerV1(ResourceManager): # Allocate blocks for the tokens that does not hit cache num_new_tokens = self._get_num_new_tokens(request, token_budget) + if num_new_tokens == 0: + if self.config.cache_config.enable_prefix_caching: + self._free_blocks(request) + skip_requests.append(request) + self.waiting.popleft() + continue num_new_block = self.get_new_block_nums(request, num_new_tokens) can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( request, num_new_block diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index cdbdeb5225..c6b2bffa38 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -18,6 +18,14 @@ Environment variables used by FastDeploy. import os from typing import Any, Callable + +def _validate_split_kv_size(value: int) -> int: + """Validate FD_DETERMINISTIC_SPLIT_KV_SIZE is a positive power of 2.""" + if value <= 0 or (value & (value - 1)) != 0: + raise ValueError(f"FD_DETERMINISTIC_SPLIT_KV_SIZE must be a positive power of 2, got {value}.") + return value + + environment_variables: dict[str, Callable[[], Any]] = { # Whether to use BF16 on CPU. 
"FD_CPU_USE_BF16": lambda: os.getenv("FD_CPU_USE_BF16", "False"), @@ -206,6 +214,18 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), # File path for file storage backend "FILE_BACKEND_STORAGE_DIR": lambda: str(os.getenv("FILE_BACKEND_STORAGE_DIR", "/tmp/fastdeploy")), + # Custom all-reduce max buffer size in MB (default 8MB). + # Increase this to avoid NCCL fallback for large tensors in deterministic mode. + # E.g. FD_CUSTOM_AR_MAX_SIZE_MB=128 for 128MB. + "FD_CUSTOM_AR_MAX_SIZE_MB": lambda: int(os.getenv("FD_CUSTOM_AR_MAX_SIZE_MB", "8")), + # Enable deterministic inference mode for chunked prefill alignment + "FD_DETERMINISTIC_MODE": lambda: bool(int(os.getenv("FD_DETERMINISTIC_MODE", "0"))), + # Split KV block size for deterministic alignment (must be power of 2 and > 0, default 16) + "FD_DETERMINISTIC_SPLIT_KV_SIZE": lambda: _validate_split_kv_size( + int(os.getenv("FD_DETERMINISTIC_SPLIT_KV_SIZE", "16")) + ), + # Enable determinism logging (print MD5 hashes and debug info) + "FD_DETERMINISTIC_LOG_MODE": lambda: bool(int(os.getenv("FD_DETERMINISTIC_LOG_MODE", "0"))), # Whether to use PD REORDER, can set 0 or 1 "FD_PD_REORDER": lambda: int(os.getenv("FD_PD_REORDER", "0")), } diff --git a/fastdeploy/logger/deterministic_logger.py b/fastdeploy/logger/deterministic_logger.py new file mode 100644 index 0000000000..cedcdd71af --- /dev/null +++ b/fastdeploy/logger/deterministic_logger.py @@ -0,0 +1,220 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import hashlib +import logging +import time + +import numpy as np + +det_logger = logging.getLogger("fastdeploy.deterministic") + + +class DeterministicLogger: + """Helper for logging tensor MD5 hashes and input details to assist determinism debugging.""" + + def __init__(self, share_inputs): + self.share_inputs = share_inputs + self._current_run_id = None + self._batch_counter = 0 + + def log_batch_start(self, model_forward_batch): + """Log batch start with run_id tracking and batch counting.""" + current_run_id = None + for req in model_forward_batch or []: + if req is not None: + parts = req.request_id.split("_") + if len(parts) > 1: + current_run_id = parts[-1] + break + if current_run_id is not None and current_run_id != self._current_run_id: + self._current_run_id = current_run_id + self._batch_counter = 0 + + self._batch_counter += 1 + + det_logger.info(f"\n{'='*80}") + det_logger.info(f"[BATCH-START] Run_{self._current_run_id} Batch_{self._batch_counter}") + det_logger.info(f"{'='*80}\n") + + @staticmethod + def _compute_tensor_md5(tensor, name="tensor", prefix=""): + """Compute MD5 hash of tensor for comparison""" + if tensor is None: + return f"{name}_md5=None" + + # Copy tensor to CPU and convert to numpy array + try: + tensor_cpu = tensor.cpu().numpy().tobytes() + except Exception: + # For data types that don't support direct tobytes (e.g., float16), convert first + tensor_cpu = tensor.cpu().numpy().astype(np.float32).tobytes() + + md5_hash = hashlib.md5(tensor_cpu).hexdigest() + return f"{prefix}{name}_md5={md5_hash[:16]}" # Print only 
first 16 chars to reduce log length + + def log_tensor_md5s(self, tensor_dict, forward_batch_reqs_list=None, stage="forward"): + """Log MD5 hash values for multiple tensors, including per-request MD5 + + Args: + tensor_dict: {name: tensor} dictionary + forward_batch_reqs_list: forward_batch_reqs_list list (may contain None) + stage: Stage identifier (e.g., "prefill", "decode", "forward") + """ + # Get batch size from first valid tensor + batch_size = self._get_batch_size(tensor_dict) + if batch_size is None: + return + + # Get prefill/decode counts + prefill_count, decode_count, seq_lens_encoder = self._get_stage_counts(batch_size) + + # Build stage information + stage_info = stage + if prefill_count > 0 or decode_count > 0: + stage_info += f" (prefill={prefill_count}, decode={decode_count})" + + # Compute and log batch MD5 + batch_md5_info = [ + self._compute_tensor_md5(tensor, name, prefix="batch_") + for name, tensor in tensor_dict.items() + if tensor is not None + ] + + # Log overall batch MD5 + req_id_str = self._build_req_id_str(forward_batch_reqs_list) + det_logger.info( + f"[DETERMINISM-MD5] stage={stage_info} | batch_size={batch_size} | " + + (f"requests: {req_id_str} | " if req_id_str else "") + + " | ".join(batch_md5_info) + ) + + # Log per-request MD5 for decode requests + self._log_per_request_md5s( + tensor_dict, forward_batch_reqs_list, batch_size, prefill_count, decode_count, seq_lens_encoder + ) + + @staticmethod + def _get_batch_size(tensor_dict): + """Get batch size from first tensor with a shape.""" + for name, tensor in tensor_dict.items(): + if tensor is not None and hasattr(tensor, "shape"): + return tensor.shape[0] + return None + + def _get_stage_counts(self, batch_size): + """Get prefill/decode counts and seq_lens_encoder.""" + prefill_count = 0 + decode_count = 0 + seq_lens_encoder = None + + if self.share_inputs is not None and "seq_lens_encoder" in self.share_inputs: + seq_lens_encoder = 
self.share_inputs["seq_lens_encoder"].cpu().numpy() + prefill_count = int((seq_lens_encoder > 0).sum()) + decode_count = int(batch_size - prefill_count) + + return prefill_count, decode_count, seq_lens_encoder + + @staticmethod + def _build_req_id_str(forward_batch_reqs_list): + """Build request ID string from forward_batch_reqs_list.""" + if forward_batch_reqs_list is None: + return "" + req_info = [f"[{i}]{req.request_id}" for i, req in enumerate(forward_batch_reqs_list) if req is not None] + return ", ".join(req_info) + + def _log_per_request_md5s( + self, tensor_dict, forward_batch_reqs_list, batch_size, prefill_count, decode_count, seq_lens_encoder + ): + """Log per-request MD5 for decode requests. + + In decode phase, tensor shape is [batch_size, hidden_dim] or [batch_size, vocab_size]. + Can split by batch dimension directly. + """ + if decode_count == 0 or forward_batch_reqs_list is None: + return + + for i, req in enumerate(forward_batch_reqs_list): + if req is None or i >= batch_size: + continue + + # Check if this is a decode request + if seq_lens_encoder is not None: + if i >= len(seq_lens_encoder) or int(seq_lens_encoder[i]) != 0: + continue # Skip prefill requests + elif prefill_count > 0: + continue # Mixed batch without seq_lens_encoder, skip all + + req_id = req.request_id + req_md5_info = [ + self._compute_tensor_md5(tensor[i : i + 1], name) + for name, tensor in tensor_dict.items() + if tensor is not None and hasattr(tensor, "shape") and len(tensor.shape) >= 2 + ] + + if req_md5_info: + det_logger.info(f"[DETERMINISM-MD5-REQ] {req_id} | decode | " + " | ".join(req_md5_info)) + + def log_prefill_input(self, request_id, idx, prefill_start_index, prefill_end_index, input_ids): + """Log prefill input details for determinism verification.""" + det_logger.info( + f"[DETERMINISM] Prefill input - request_id: {request_id}, " + f"idx: {idx}, prefill_start_index: {prefill_start_index}, " + f"prefill_end_index: {prefill_end_index}, " + f"input_ids: 
{input_ids}" + ) + + def log_deterministic_input(self, forward_meta): + """Log determinism inference input information, supports multiple batch requests""" + ids = forward_meta.ids_remove_padding + req_ids = self.share_inputs.get("req_ids", None) + seq_lens_this_time = self.share_inputs.get("seq_lens_this_time", None) + seq_lens_encoder = self.share_inputs.get("seq_lens_encoder", None) + seq_lens_decoder = self.share_inputs.get("seq_lens_decoder", None) + + # Get batch size + num_requests = len(seq_lens_this_time) if seq_lens_this_time is not None else 0 + + det_logger.info(f"[DETERMINISM-INPUT] time={time.time():.6f} | batch_size={num_requests}") + + if num_requests == 0 or ids is None: + det_logger.info("[DETERMINISM-INPUT] No input data") + return + + # Split ids for each request + ids_list = ids.cpu().numpy().tolist() + offset = 0 + + for i in range(num_requests): + # Get current request information + req_id = req_ids[i] if req_ids is not None and i < len(req_ids) else f"idx_{i}" + seq_len = int(seq_lens_this_time[i]) + seq_len_enc = int(seq_lens_encoder[i]) if seq_lens_encoder is not None and i < len(seq_lens_encoder) else 0 + seq_len_dec = int(seq_lens_decoder[i]) if seq_lens_decoder is not None and i < len(seq_lens_decoder) else 0 + + # Get current request's tokens + if seq_len > 0: + request_tokens = ids_list[offset : offset + seq_len] + else: + request_tokens = [] + + offset += seq_len + + # Print one line log + det_logger.info( + f"[DETERMINISM-INPUT] req_id={req_id} | tokens={request_tokens} | " + f"len={seq_len} | seq_len_enc={seq_len_enc} | seq_len_dec={seq_len_dec}" + ) diff --git a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py index 309d92f764..426985d396 100644 --- a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py +++ b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py @@ -6,6 +6,10 @@ from 
collections import namedtuple from collections.abc import Callable from typing import Any, Dict +from fastdeploy.utils import get_logger + +logger = get_logger("worker_process", "worker_process.log") + import paddle import triton import triton.language as tl @@ -137,13 +141,13 @@ def get_compute_units(): device_properties = paddle.cuda.get_device_properties(0) NUM_SMS = device_properties.multi_processor_count except Exception: - print("Could not get CUDA device properties. Falling back to CPU threads.") + logger.warning("Could not get CUDA device properties. Falling back to CPU threads.") # TODO(liujundong): Paddle lacks a torch.get_num_threads() equivalent for the *configured* thread count. # Using os.cpu_count() (total logical cores) as a fallback, which may not be correct. # Must check downstream logic to determine if this impacts correctness. NUM_SMS = os.cpu_count() else: - print("No CUDA device available. Using CPU.") + logger.warning("No CUDA device available. Using CPU.") # For CPU, use the number of CPU cores NUM_SMS = os.cpu_count() @@ -153,7 +157,7 @@ def get_compute_units(): def matmul_persistent(a: paddle.Tensor, b: paddle.Tensor, bias: paddle.Tensor | None = None): # Check constraints. 
assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.dtype == b.dtype, "Incompatible dtypes" + assert a.dtype == b.dtype, f"Incompatible dtypes: a={a.dtype}, b={b.dtype}" assert bias is None or bias.dim() == 1, "Currently assuming bias is 1D, let Horace know if you run into this" NUM_SMS = get_compute_units() @@ -210,9 +214,11 @@ def matmul_persistent(a: paddle.Tensor, b: paddle.Tensor, bias: paddle.Tensor | c.stride(0), c.stride(1), # NUM_SMS=NUM_SMS, # - A_LARGE=int(a.numel() > 2**31), - B_LARGE=int(b.numel() > 2**31), - C_LARGE=int(c.numel() > 2**31), + # Use M*K, K*N, M*N instead of numel() to avoid cudaErrorStreamCaptureImplicit + # during CUDA Graph capture + A_LARGE=int(M * K > 2**31), + B_LARGE=int(K * N > 2**31), + C_LARGE=int(M * N > 2**31), HAS_BIAS=int(bias is not None), # The Triton compiler (when used with Paddle) cannot handle these variables as booleans. Explicitly cast to int so the compiler can process them. **configs[dtype], @@ -477,6 +483,8 @@ def addmm_batch_invariant( So we use `alpha * (x @ y) + beta * input = alpha * [ (x @ y) + (beta / alpha) * input ]` to minimize the effection on performance """ + if alpha == 0: + return paddle.broadcast_to(beta * input, [x.shape[0], y.shape[1]]) matmul_result = matmul_persistent(a=x, b=y, bias=input * beta / alpha) result = alpha * matmul_result return result @@ -490,7 +498,13 @@ def mean_batch_invariant( x: paddle.Tensor, axis: list[int] = [], keepdim: bool = False, dtype: paddle.dtype | None = None, out=None ) -> paddle.Tensor: assert dtype is None or dtype == paddle.float32, f"unsupported dtype: {dtype}" - if type(axis) is int: + if axis is None: # Global mean (no axis specified) + # Avoid x.numel() to prevent cudaErrorStreamCaptureImplicit during CUDA Graph capture + n_elems = 1 + for s in x.shape: + n_elems *= s + result = paddle.sum(x, keepdim=keepdim, dtype=paddle.float32) / n_elems + elif type(axis) is int: result = mean_dim(x, axis, keepdim=keepdim) elif len(axis) == 1: # axis: 
int | Sequence[int] result = mean_dim(x, axis[0], keepdim=keepdim) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a635f77840..dce35516f4 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -99,6 +99,7 @@ from fastdeploy import envs from fastdeploy.engine.tasks import PoolingTask from fastdeploy.input.ernie4_5_vl_processor import DataProcessor from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient +from fastdeploy.logger.deterministic_logger import DeterministicLogger from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp @@ -211,6 +212,13 @@ class GPUModelRunner(ModelRunnerBase): self.restore_chunked_prefill_request = dict() + # Initialize deterministic logger (only when deterministic debugging is enabled) + self.deterministic_logger = ( + DeterministicLogger(self.share_inputs) + if envs.FD_DETERMINISTIC_MODE and envs.FD_DETERMINISTIC_LOG_MODE + else None + ) + # Initialize attention Backend # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance. # In the future, we will expand it as a list. 
@@ -262,6 +270,7 @@ class GPUModelRunner(ModelRunnerBase): self.last_sampler_output = None self.last_post_process_event = None self.last_token_num = -1 + self.enable_overlap_schedule = fd_config.scheduler_config.enable_overlap_schedule and ( not self.speculative_decoding ) @@ -777,6 +786,14 @@ class GPUModelRunner(ModelRunnerBase): prompt_token_ids = request.prompt_token_ids input_ids = prompt_token_ids + request.output_token_ids prompt_len = len(prompt_token_ids) + + # Log complete input_ids for input determinism verification + # Note: Only current request info is logged here; batch info is logged during forward + if self.deterministic_logger is not None: + self.deterministic_logger.log_prefill_input( + request.request_id, idx, prefill_start_index, prefill_end_index, input_ids + ) + self.share_inputs["prompt_ids"][idx : idx + 1, :prompt_len] = np.array(prompt_token_ids, dtype="int64") logger.debug( f"Handle prefill request {request} at idx {idx}, " @@ -1653,6 +1670,10 @@ class GPUModelRunner(ModelRunnerBase): encoder_block_shape_q = 64 decoder_block_shape_q = 16 + # Deterministic mode: use deterministic_split_kv_size to ensure batch-invariant attention + if envs.FD_DETERMINISTIC_MODE: + decoder_block_shape_q = envs.FD_DETERMINISTIC_SPLIT_KV_SIZE + res_buffer = allocate_launch_related_buffer( max_batch_size=self.scheduler_config.max_num_seqs, max_model_len=self.model_config.max_model_len, @@ -2299,6 +2320,9 @@ class GPUModelRunner(ModelRunnerBase): num_running_requests: int = None, last_token_num: int = -1, ) -> None: + if self.deterministic_logger is not None: + self.deterministic_logger.log_batch_start(model_forward_batch) + # 1. Prepare inputs of model and sampler. p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests) @@ -2423,8 +2447,22 @@ class GPUModelRunner(ModelRunnerBase): ) # 4. 
Compute logits, Sample + if self.deterministic_logger is not None: + # Log MD5 of hidden_states (model output) + self.deterministic_logger.log_tensor_md5s( + {"hidden_states": hidden_states}, + forward_batch_reqs_list=self.forward_batch_reqs_list, + stage="hidden_states", + ) + logits = self.model.compute_logits(hidden_states) + if self.deterministic_logger is not None: + # Log MD5 of logits (before sampling) + self.deterministic_logger.log_tensor_md5s( + {"logits": logits}, forward_batch_reqs_list=self.forward_batch_reqs_list, stage="logits" + ) + if not self.speculative_decoding: set_value_by_flags_and_idx( self.share_inputs["pre_ids"], @@ -2441,6 +2479,14 @@ class GPUModelRunner(ModelRunnerBase): p_done_idxs, ) + if self.deterministic_logger is not None: + # Log MD5 of sampling results + self.deterministic_logger.log_tensor_md5s( + {"sampled_token_ids": sampler_output.sampled_token_ids}, + forward_batch_reqs_list=self.forward_batch_reqs_list, + stage="sampled_tokens", + ) + if ( self.enable_logprob and not envs.FD_USE_GET_SAVE_OUTPUT_V1 diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index eb600e28e6..caa5686b2d 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -43,6 +43,11 @@ class InputBatch: for key, value in values.items(): setattr(self, key, value) + def get(self, key, default=None): + if hasattr(self, key): + return getattr(self, key) + return default + def pop(self, key, default=None): """ Pop an attribute, similar to dict's pop method diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index ec28183c12..a41466ea71 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -1194,6 +1194,21 @@ def run_worker_proc() -> None: worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank) worker_proc.init_control() + # Enable batch-invariant mode for deterministic inference. 
+ # This must happen AFTER worker creation but BEFORE model loading, + # because enable_batch_invariant_mode() calls paddle.compat.enable_torch_proxy() + # which makes torch appear available via proxy. If called before worker creation, + # the gpu_model_runner import chain (ernie4_5_vl_processor → paddleformers → + # transformers) will fail when transformers tries to query torch metadata. + if envs.FD_DETERMINISTIC_MODE: + from fastdeploy.model_executor.layers.batch_invariant_ops import ( + enable_batch_invariant_mode, + is_batch_invariant_mode_enabled, + ) + + if not is_batch_invariant_mode_enabled(): + enable_batch_invariant_mode() + # Initialize device and create model runner worker_proc.init_device() diff --git a/tests/batch_invariant/test_batch_invariance_op_addmm.py b/tests/batch_invariant/test_batch_invariance_op_addmm.py index 45b5caac58..bca856d3b2 100644 --- a/tests/batch_invariant/test_batch_invariance_op_addmm.py +++ b/tests/batch_invariant/test_batch_invariance_op_addmm.py @@ -7,6 +7,9 @@ import paddle from fastdeploy.model_executor.layers.batch_invariant_ops import ( set_batch_invariant_mode, ) +from fastdeploy.model_executor.layers.batch_invariant_ops.batch_invariant_ops import ( + addmm_batch_invariant, +) class TestBatchInvariantForAddmm(unittest.TestCase): @@ -45,6 +48,23 @@ class TestBatchInvariantForAddmm(unittest.TestCase): if ass: assert max(difflist) == 0 + def test_alpha_zero(self): + """alpha == 0: result should be beta * input broadcast to [M, N]""" + M, N, K = 32, 64, 128 + for dtype in [paddle.float32, paddle.bfloat16]: + x = paddle.randn([M, K], dtype=dtype) + y = paddle.randn([K, N], dtype=dtype) + bias = paddle.randn([N], dtype=dtype) + + for beta in [0.0, 1.0, 2.5]: + out = addmm_batch_invariant(bias, x, y, beta=beta, alpha=0.0) + expected = (beta * bias).expand([M, N]) + # shape must be [M, N] + assert out.shape == [M, N], f"Expected shape [{M}, {N}], got {out.shape}" + # cast to float32 for comparison (bfloat16 not supported by 
isclose) + diff = (out.cast(paddle.float32) - expected.cast(paddle.float32)).abs().max() + assert diff.item() == 0, f"dtype={dtype}, beta={beta}, max diff={diff.item()}" + def test_case(self): # Test with standard Paddle (likely to show differences) print("Standard Paddle:") diff --git a/tests/ce/deterministic/start_fd.sh b/tests/ce/deterministic/start_fd.sh new file mode 100644 index 0000000000..c25050a80c --- /dev/null +++ b/tests/ce/deterministic/start_fd.sh @@ -0,0 +1,29 @@ +export FD_MODEL_SOURCE=HUGGINGFACE +export FD_MODEL_CACHE=./models + +export CUDA_VISIBLE_DEVICES=0 +export ENABLE_V1_KVCACHE_SCHEDULER=1 + +# FD_DETERMINISTIC_MODE: Toggle deterministic mode +# 0: Disable deterministic mode (non-deterministic) +# 1: Enable deterministic mode (default) +# FD_DETERMINISTIC_LOG_MODE: Toggle determinism logging +# 0: Disable logging (high performance, recommended for production) +# 1: Enable logging with MD5 hashes (debug mode) +# Usage: bash start_fd.sh [deterministic_mode] [log_mode] +# Example: +# bash start_fd.sh 1 0 # Deterministic mode without logging (fast) +# bash start_fd.sh 1 1 # Deterministic mode with logging (debug) +export FD_DETERMINISTIC_MODE=${1:-1} +export FD_DETERMINISTIC_LOG_MODE=${2:-0} + + +python -m fastdeploy.entrypoints.openai.api_server \ + --model ./models/Qwen/Qwen2.5-7B \ + --port 8188 \ + --tensor-parallel-size 1 \ + --max-model-len 32768 \ + --enable-logprob \ + --graph-optimization-config '{"use_cudagraph":true}' \ + --no-enable-prefix-caching \ + --no-enable-output-caching diff --git a/tests/ce/deterministic/test_determinism_verification.py b/tests/ce/deterministic/test_determinism_verification.py new file mode 100644 index 0000000000..b665d38d9d --- /dev/null +++ b/tests/ce/deterministic/test_determinism_verification.py @@ -0,0 +1,470 @@ +""" +Determinism Feature Verification Test + +Reference: test_batch_invariant.py. Verifies whether determinism works correctly. 
+ +Usage: + # Step 1: Start server with determinism disabled + bash ./tests/ce/deterministic/start_fd.sh 0 + + # Step 2: Run non-deterministic test (expected: results differ) + python ./tests/ce/deterministic/test_determinism_verification.py --phase non-deterministic + + # Step 3: Stop server + bash fastdeploy/stop.sh + + # Step 4: Start server with determinism enabled and logging ON + bash ./tests/ce/deterministic/start_fd.sh 1 1 + + # Step 5: Run deterministic test (expected: results consistent) + python ./tests/ce/deterministic/test_determinism_verification.py --phase deterministic + +Arguments: + --phase {deterministic,non-deterministic} + Test mode + - deterministic: determinism enabled with logging, expected MD5 consistency + - non-deterministic: determinism disabled, expected different outputs + --api-url API endpoint URL (default: http://localhost:8188/v1/chat/completions) + --model Model name (default: Qwen/Qwen2.5-7B) + --log-file Server log file path (default: log/workerlog.0) + --repeat Number of repeat rounds for non-deterministic phase (default: 3) + +Note: The deterministic test requires FD_DETERMINISTIC_LOG_MODE=1 to extract MD5 values + from logs for verification. 
+""" + +import argparse +import asyncio +import hashlib +import logging +import os +import random +import re +import sys + +import aiohttp + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") + +# Defaults (overridable via CLI args or env vars) +DEFAULT_API_URL = "http://localhost:8188/v1/chat/completions" +DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-7B" +DEFAULT_LOG_FILE = "log/workerlog.0" +DEFAULT_NON_DET_REPEAT = 3 + +# Target prompt (we care about its determinism) +TARGET_PROMPT = "你好,请简单介绍一下自己。" + +# Distractor prompts (different content, used to create batch interference) +DISTRACTOR_PROMPTS = [ + "今天天气怎么样?", + "什么是人工智能?", + "如何学习编程?", + "什么是机器学习?", + "Python 是什么?", +] + +# Generation length for target prompt (fixed, longer) +TARGET_MAX_TOKENS = 128 + +# Generation length range for distractor prompts +DISTRACTOR_MAX_TOKENS_RANGE = (8, 32) + +# Health check settings +HEALTH_CHECK_INTERVAL = 5 +HEALTH_CHECK_TIMEOUT = 300 + + +def parse_args(): + parser = argparse.ArgumentParser(description="Determinism feature verification test") + parser.add_argument( + "--phase", + choices=["deterministic", "non-deterministic"], + required=True, + help="Test mode: deterministic (enabled) or non-deterministic (disabled)", + ) + parser.add_argument( + "--api-url", + default=os.environ.get("FD_TEST_API_URL", DEFAULT_API_URL), + help=f"API endpoint URL (default: {DEFAULT_API_URL})", + ) + parser.add_argument( + "--model", + default=os.environ.get("FD_TEST_MODEL", DEFAULT_MODEL_NAME), + help=f"Model name (default: {DEFAULT_MODEL_NAME})", + ) + parser.add_argument( + "--log-file", + default=os.environ.get("FD_TEST_LOG_FILE", DEFAULT_LOG_FILE), + help=f"Server log file path (default: {DEFAULT_LOG_FILE})", + ) + parser.add_argument( + "--repeat", + type=int, + default=int(os.environ.get("FD_TEST_REPEAT", DEFAULT_NON_DET_REPEAT)), + help=f"Number of repeat rounds for non-deterministic phase (default: 
{DEFAULT_NON_DET_REPEAT})", + ) + return parser.parse_args() + + +def extract_md5_from_log(log_file: str, request_id: str) -> list[str]: + """Extract all decode step MD5 values for the specified request from log file.""" + md5_values = [] + try: + with open(log_file, "r", encoding="utf-8", errors="ignore") as f: + pattern = rf"\[DETERMINISM-MD5-REQ\] {re.escape(request_id)} \| decode" + for line in f: + if re.search(pattern, line): + match = re.search(r"hidden_states_md5=([a-f0-9]+)", line) + if match: + md5_values.append(match.group(1)) + except FileNotFoundError: + logger.warning("Log file not found: %s", log_file) + return md5_values + + +async def wait_for_server(api_url: str) -> None: + """Wait for the server to be ready by polling the API endpoint.""" + base_url = api_url.rsplit("/v1/", 1)[0] + health_url = f"{base_url}/v1/models" + timeout = aiohttp.ClientTimeout(total=10) + + logger.info("Waiting for server to be ready at %s ...", base_url) + elapsed = 0 + while elapsed < HEALTH_CHECK_TIMEOUT: + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(health_url) as resp: + if resp.status == 200: + logger.info("Server is ready.") + return + except (aiohttp.ClientError, asyncio.TimeoutError): + pass + await asyncio.sleep(HEALTH_CHECK_INTERVAL) + elapsed += HEALTH_CHECK_INTERVAL + logger.info(" Still waiting... (%ds/%ds)", elapsed, HEALTH_CHECK_TIMEOUT) + + raise RuntimeError( + f"Server not ready after {HEALTH_CHECK_TIMEOUT}s. 
" + f"Check that the server is running and accessible at {base_url}" + ) + + +async def send_request( + session: aiohttp.ClientSession, api_url: str, prompt: str, request_id: str, max_tokens: int, model: str +) -> str: + """Send request and return response content.""" + request = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.8, + "top_p": 0.9, + "max_tokens": max_tokens, + "request_id": request_id, + } + timeout = aiohttp.ClientTimeout(total=300) + async with session.post(api_url, json=request, timeout=timeout) as response: + response.raise_for_status() + result = await response.json() + return result["choices"][0]["message"]["content"] + + +async def run_test_case( + session: aiohttp.ClientSession, + api_url: str, + test_name: str, + test_plan: list[tuple[str, str, bool]], + model: str, +) -> list[tuple[str, str]]: + """ + Run a test case. + + Args: + api_url: API endpoint URL. + test_plan: List of (request_id, prompt, is_target) tuples. + model: Model name to use for the request. + + Returns: + List of (request_id, result) tuples for target requests only. 
+ """ + target_count = sum(1 for _, _, t in test_plan if t) + distractor_count = len(test_plan) - target_count + logger.info( + "[Test %s] %d requests (target=%d, distractor=%d)", test_name, len(test_plan), target_count, distractor_count + ) + + tasks = [] + for req_id, prompt, is_target in test_plan: + max_tokens = TARGET_MAX_TOKENS if is_target else random.randint(*DISTRACTOR_MAX_TOKENS_RANGE) + tasks.append(send_request(session, api_url, prompt, req_id, max_tokens, model)) + + results = await asyncio.gather(*tasks) + + target_outputs = [] + for (req_id, _, is_target), result in zip(test_plan, results): + marker = "[Target]" if is_target else "[Distractor]" + logger.info(" %s %s: %s...", marker, req_id, result[:50]) + if is_target: + target_outputs.append((req_id, result)) + + return target_outputs + + +def _print_section(title: str) -> None: + """Print a section banner.""" + print("\n" + "=" * 80) + print(title) + print("=" * 80) + + +def _check_consistency( + items: dict[str, list[str]], + label: str, + expect_consistent: bool, + detail_formatter=None, +) -> bool: + """ + Unified consistency check logic. + + Args: + items: Dict mapping unique_key -> list of request_ids sharing that key. + label: Description label (e.g. "Text", "MD5 Step 1"). + expect_consistent: True expects all keys identical, False expects differences. + detail_formatter: Optional callable(key) -> str for displaying details on mismatch. + + Returns: + True if result matches expectation, False otherwise. 
+ """ + expected_desc = "consistent" if expect_consistent else "inconsistent" + _print_section(f"{label} Consistency Check (Expected: {expected_desc})") + + if not items: + logger.warning("No %s values found!", label) + return False + + is_consistent = len(items) == 1 + + print(f"\n Unique values: {len(items)}") + if is_consistent: + key = next(iter(items)) + reqs = items[key] + print(f" All {len(reqs)} requests share the same value") + else: + for i, (key, reqs) in enumerate(items.items(), 1): + detail = f" ({detail_formatter(key)})" if detail_formatter else "" + print(f" Group {i}: {', '.join(reqs)}{detail}") + + print("-" * 80) + + passed = is_consistent == expect_consistent + actual_desc = "consistent" if is_consistent else "inconsistent" + status = "PASS" if passed else "FAIL" + print(f" {status}: expected {expected_desc}, actual {actual_desc}") + print("=" * 80) + + return passed + + +def compare_text_consistency(target_results: list[tuple[str, str]], expect_consistent: bool = True) -> bool: + """Compare target request text content against expected consistency.""" + unique_texts: dict[str, list[str]] = {} + text_map: dict[str, str] = {} + for req_id, text in target_results: + text_md5 = hashlib.md5(text.encode("utf-8")).hexdigest() + unique_texts.setdefault(text_md5, []).append(req_id) + if text_md5 not in text_map: + text_map[text_md5] = text + + return _check_consistency( + unique_texts, + label="Text", + expect_consistent=expect_consistent, + detail_formatter=lambda key: repr(text_map[key][:50]), + ) + + +def compare_md5_consistency(all_md5: dict[str, list[str]], expect_consistent: bool = True) -> bool: + """ + Compare MD5 results across ALL decode steps and verify against expected consistency. + + For each decode step, checks that all target requests produced identical hidden_states_md5. + All steps must be consistent for the overall check to pass. 
+ """ + if not all_md5: + logger.warning("No MD5 values found!") + return False + + # Find the minimum number of decode steps across all requests + min_steps = min(len(md5s) for md5s in all_md5.values()) + if min_steps == 0: + logger.warning("Some requests have no decode step MD5 values!") + return False + + req_ids = list(all_md5.keys()) + logger.info("Checking MD5 consistency across %d decode steps for %d requests", min_steps, len(req_ids)) + + failed_steps = [] + + for step in range(min_steps): + step_md5s: dict[str, list[str]] = {} + for req_id in req_ids: + md5_val = all_md5[req_id][step] + step_md5s.setdefault(md5_val, []).append(req_id) + + step_consistent = len(step_md5s) == 1 + + if not step_consistent: + failed_steps.append(step) + + # Print per-step result + if step_consistent: + md5_val = next(iter(step_md5s)) + logger.info(" Decode step %d: CONSISTENT (md5=%s)", step + 1, md5_val) + else: + logger.warning(" Decode step %d: INCONSISTENT (%d different values)", step + 1, len(step_md5s)) + for md5_val, reqs in step_md5s.items(): + logger.warning(" md5=%s: %s", md5_val, ", ".join(reqs)) + + is_consistent = len(failed_steps) == 0 + + _print_section(f"MD5 Consistency Check (all {min_steps} decode steps)") + if is_consistent: + print(f" All {min_steps} decode steps are consistent across {len(req_ids)} requests") + else: + print(f" {len(failed_steps)}/{min_steps} decode steps are INCONSISTENT") + print(f" Failed steps: {[s + 1 for s in failed_steps]}") + print("-" * 80) + + passed = is_consistent == expect_consistent + expected_desc = "consistent" if expect_consistent else "inconsistent" + actual_desc = "consistent" if is_consistent else "inconsistent" + status = "PASS" if passed else "FAIL" + print(f" {status}: expected {expected_desc}, actual {actual_desc}") + print("=" * 80) + + return passed + + +# Test cases: (name, plan) where plan is [(request_id, prompt, is_target)] +TEST_CASES = [ + ( + "case1: Single request (target only)", + [ + ("case1-target", 
TARGET_PROMPT, True), + ], + ), + ( + "case2: Two requests (1 target + 1 distractor)", + [ + ("case2-distract-a", DISTRACTOR_PROMPTS[0], False), + ("case2-target", TARGET_PROMPT, True), + ], + ), + ( + "case3: Four requests (1 target + 3 distractors)", + [ + ("case3-distract-a", DISTRACTOR_PROMPTS[0], False), + ("case3-distract-b", DISTRACTOR_PROMPTS[1], False), + ("case3-target", TARGET_PROMPT, True), + ("case3-distract-c", DISTRACTOR_PROMPTS[2], False), + ], + ), + ( + "case4: Six requests (1 target + 5 distractors)", + [ + ("case4-distract-a", DISTRACTOR_PROMPTS[0], False), + ("case4-distract-b", DISTRACTOR_PROMPTS[1], False), + ("case4-distract-c", DISTRACTOR_PROMPTS[2], False), + ("case4-distract-d", DISTRACTOR_PROMPTS[3], False), + ("case4-target", TARGET_PROMPT, True), + ("case4-distract-e", DISTRACTOR_PROMPTS[4], False), + ], + ), +] + + +def _build_test_plan(test_cases, repeat: int = 1): + """ + Build test plan with optional repetition. + + For repeat > 1, each test case is duplicated with round-suffixed request_ids. + This increases sample size for non-deterministic testing. 
+ """ + if repeat <= 1: + return test_cases + + expanded = [] + for case_name, plan in test_cases: + for r in range(repeat): + round_name = f"{case_name} (round {r + 1})" + round_plan = [(f"{req_id}-r{r + 1}", prompt, is_target) for req_id, prompt, is_target in plan] + expanded.append((round_name, round_plan)) + return expanded + + +async def main() -> int: + args = parse_args() + is_deterministic = args.phase == "deterministic" + + _print_section("Determinism Feature Verification Test") + print(f"\n Test mode: {args.phase}") + print(f" API URL: {args.api_url}") + print(f" Model: {args.model}") + print(f" Log file: {args.log_file}") + if is_deterministic: + print(" Expected: All target requests have consistent MD5 values") + else: + print(f" Expected: Target requests produce different outputs (repeat={args.repeat})") + print("=" * 80) + + # Wait for server to be ready + await wait_for_server(args.api_url) + + # Build test plan (repeat for non-deterministic to reduce flaky probability) + repeat = args.repeat if not is_deterministic else 1 + test_plan = _build_test_plan(TEST_CASES, repeat=repeat) + + async with aiohttp.ClientSession() as session: + all_target_results: list[tuple[str, str]] = [] + for test_name, plan in test_plan: + target_outputs = await run_test_case(session, args.api_url, test_name, plan, args.model) + all_target_results.extend(target_outputs) + await asyncio.sleep(1) + + target_request_ids = [req_id for req_id, _ in all_target_results] + + _print_section("All tests completed, starting verification...") + + if is_deterministic: + # Deterministic mode: compare MD5 across all decode steps + all_md5 = {} + for req_id in target_request_ids: + md5_values = extract_md5_from_log(args.log_file, req_id) + if md5_values: + all_md5[req_id] = md5_values + logger.info("%s: %d decode steps found", req_id, len(md5_values)) + else: + logger.warning("%s: No MD5 logs found", req_id) + + if all_md5: + passed = compare_md5_consistency(all_md5, expect_consistent=True) 
+ else: + logger.warning("No MD5 logs found, fallback to text consistency check") + passed = compare_text_consistency(all_target_results, expect_consistent=True) + else: + # Non-deterministic mode: compare text content + passed = compare_text_consistency(all_target_results, expect_consistent=False) + + _print_section("Final Result") + if passed: + print(f" PASS: {args.phase} mode verified successfully") + else: + print(f" FAIL: {args.phase} mode verification failed") + print("=" * 80) + + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(asyncio.run(main())) diff --git a/tests/conftest.py b/tests/conftest.py index 22c71ca862..057dc15aeb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,11 +11,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", "gpu: mark test as requiring GPU platform") + + +def pytest_collection_modifyitems(config, items): + """Skip GPU-marked tests when not on a GPU platform. + + IMPORTANT: Do NOT import paddle or fastdeploy here. This function runs + during pytest collection (before fork). Importing paddle initializes the + CUDA runtime, which makes forked child processes unable to re-initialize + CUDA (OSError: CUDA error(3), initialization error). 
+ """ + import glob + + has_gpu = len(glob.glob("/dev/nvidia[0-9]*")) > 0 + + if has_gpu: + return + + skip_marker = pytest.mark.skip(reason="Test requires GPU platform, skipping on non-GPU") + for item in items: + if "gpu" in item.keywords: + item.add_marker(skip_marker) + + import time from typing import Any, Union -import pytest -from e2e.utils.serving_utils import ( +from e2e.utils.serving_utils import ( # noqa: E402 FD_API_PORT, FD_CACHE_QUEUE_PORT, FD_ENGINE_QUEUE_PORT, diff --git a/tests/deterministic/test_determinism_offline.py b/tests/deterministic/test_determinism_offline.py new file mode 100644 index 0000000000..c34d34dec1 --- /dev/null +++ b/tests/deterministic/test_determinism_offline.py @@ -0,0 +1,357 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Determinism offline inference tests using LLM.generate + +Test scenarios: +1. Same-prompt repeatability (FD_DETERMINISTIC_MODE=1) +2. Batch invariance (single vs. batch, different positions) +3. Different batch sizes consistency +4. Sampling-parameter combinations (temperature x top_p, parametrized) +5. Long sequence generation (512-1024 tokens) +6. Long input prompt handling +7. Minimal output (max_tokens=1, early stop) +8. Special characters & multi-language prompts +9. Multi-turn conversation +10. State isolation (interleaved / interference prompts) +11. 
Non-deterministic validation (proves tests are effective) + +Usage: + CUDA_VISIBLE_DEVICES=0 pytest tests/deterministic/test_determinism_offline.py -v +""" + +import os + +import pytest + +pytestmark = pytest.mark.gpu + +DEFAULT_MODEL_DIR = "./models" +MODEL_NAME = "Qwen2-7B-Instruct" + +_ENV_CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" +_ENV_FD_DETERMINISTIC_MODE = "FD_DETERMINISTIC_MODE" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module", autouse=True) +def _module_env(): + """Set env vars before importing fastdeploy (must happen first).""" + old_cuda = os.environ.get(_ENV_CUDA_VISIBLE_DEVICES) + old_det = os.environ.get(_ENV_FD_DETERMINISTIC_MODE) + + os.environ[_ENV_CUDA_VISIBLE_DEVICES] = os.environ.get(_ENV_CUDA_VISIBLE_DEVICES, "0") + os.environ[_ENV_FD_DETERMINISTIC_MODE] = "1" + + global LLM, SamplingParams # noqa: PLW0603 + from fastdeploy import LLM, SamplingParams + + yield + + if old_cuda is None: + os.environ.pop(_ENV_CUDA_VISIBLE_DEVICES, None) + else: + os.environ[_ENV_CUDA_VISIBLE_DEVICES] = old_cuda + if old_det is None: + os.environ.pop(_ENV_FD_DETERMINISTIC_MODE, None) + else: + os.environ[_ENV_FD_DETERMINISTIC_MODE] = old_det + + +@pytest.fixture(autouse=True) +def _reset_deterministic_mode(): + """Ensure every test starts with deterministic mode ON.""" + os.environ[_ENV_FD_DETERMINISTIC_MODE] = "1" + yield + os.environ[_ENV_FD_DETERMINISTIC_MODE] = "1" + + +@pytest.fixture(scope="module") +def model_path(): + model_dir = os.getenv("MODEL_PATH", DEFAULT_MODEL_DIR) + return os.path.join(model_dir, MODEL_NAME) + + +@pytest.fixture(scope="module") +def llm(model_path, _module_env): + return LLM( + model=model_path, + tensor_parallel_size=1, + max_model_len=8192, + enable_prefix_caching=False, + ) + + +# --------------------------------------------------------------------------- +# Helpers 
+# --------------------------------------------------------------------------- + + +def _generate_text(llm, prompt, sp): + """Generate once, return (text, token_ids).""" + out = llm.generate([prompt], sp)[0] + return out.outputs.text, out.outputs.token_ids + + +def _assert_deterministic(llm, prompt, sp, runs=2): + """Run *runs* times and assert all outputs are identical.""" + results = [_generate_text(llm, prompt, sp) for _ in range(runs)] + texts = [r[0] for r in results] + token_ids = [r[1] for r in results] + assert all(t == texts[0] for t in texts), "Text outputs differ across runs" + assert all(t == token_ids[0] for t in token_ids), "Token IDs differ across runs" + return texts[0], token_ids[0] + + +# ===================== Core determinism tests ===================== + + +def test_deterministic_same_prompt(llm): + """Same prompt + same seed produces identical output across 5 runs.""" + sp = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50, seed=123) + _assert_deterministic(llm, "Please introduce artificial intelligence in one sentence.", sp, runs=5) + + +def test_deterministic_batch_invariance(llm): + """Target prompt produces identical output regardless of batch position.""" + prompt = "What kind of programming language is Python?" 
+ sp = SamplingParams(temperature=0.5, max_tokens=40, seed=456) + + baseline, _ = _generate_text(llm, prompt, sp) + + batch_configs = [ + [prompt, "Filler question 1"], + ["Filler question 2", prompt, "Filler question 3"], + ["Filler question 4", "Filler question 5", prompt], + ["Filler 6", "Filler 7", "Filler 8", prompt], + ] + + for i, batch in enumerate(batch_configs): + outputs = llm.generate(batch, sp) + idx = batch.index(prompt) + assert ( + outputs[idx].outputs.text == baseline + ), f"Batch config {i} (pos {idx}): result differs from single-request baseline" + + +def test_deterministic_different_batch_sizes(llm): + """Same prompt is consistent across batch sizes 1 / 2 / 4 / 8.""" + prompt = "What is machine learning?" + sp = SamplingParams(temperature=0.5, max_tokens=30, seed=789) + + baseline, _ = _generate_text(llm, prompt, sp) + + for bs in [2, 4, 8]: + outputs = llm.generate([prompt] * bs, sp) + assert outputs[0].outputs.text == baseline, f"Batch size {bs} differs from bs=1" + + +# ===================== Sampling-parameter combinations ===================== + + +@pytest.mark.parametrize( + "temp,top_p,seed", + [ + (0.0, 1.0, 300), # greedy, no top_p filter + (0.0, 0.0, 301), # double-greedy + (0.3, 0.9, 302), # low temp, moderate top_p + (0.8, 0.0, 303), # medium temp, greedy top_p + (0.8, 1.0, 304), # medium temp, no top_p filter + (0.8, 0.5, 305), # medium temp, strict top_p + (1.0, 0.95, 306), # high temp + (1.5, 0.9, 307), # very high temp + ], +) +def test_deterministic_param_combos(llm, temp, top_p, seed): + """Determinism holds across various (temperature, top_p) combinations.""" + sp = SamplingParams(temperature=temp, top_p=top_p, max_tokens=30, seed=seed) + _assert_deterministic(llm, "What is a neural network?", sp) + + +# ===================== Long sequence tests ===================== + + +@pytest.mark.parametrize( + "temp,seed", + [ + (0.0, 100), + (0.3, 130), + (0.5, 150), + (0.7, 170), + ], +) +def test_deterministic_long_sequence(llm, temp, 
seed): + """Long generation (512+ tokens) stays deterministic at various temperatures.""" + prompt = "Please describe the history of AI in detail, including major milestones and key technical breakthroughs." + sp = SamplingParams(temperature=temp, top_p=0.95, max_tokens=512, seed=seed) + + text, token_ids = _assert_deterministic(llm, prompt, sp) + assert len(token_ids) >= 100, f"Expected >= 100 tokens, got {len(token_ids)}" + + +def test_deterministic_long_prompt(llm): + """Long input prompt (prefill-heavy) stays deterministic.""" + base = "This is a description about natural language processing. " + long_prompt = (base * 50) + "Please summarize the above." + sp = SamplingParams(temperature=0.5, max_tokens=100, seed=2024) + + _assert_deterministic(llm, long_prompt, sp) + + +# ===================== Minimal / boundary output tests ===================== + + +def test_deterministic_max_tokens_one(llm): + """Single-token output is deterministic.""" + sp = SamplingParams(temperature=0.1, max_tokens=1, seed=700) + + text, token_ids = _assert_deterministic(llm, "What color is the sky?", sp) + assert len(token_ids) == 1, f"Expected 1 token, got {len(token_ids)}" + + +def test_deterministic_early_stop(llm): + """Early stopping via stop sequences is deterministic.""" + sp = SamplingParams(temperature=0.7, max_tokens=100, stop=["\u3002", "."], seed=800) + + text, token_ids = _assert_deterministic(llm, "Please list three colors:", sp) + assert len(token_ids) < 100, f"Expected early stop, got {len(token_ids)} tokens" + + +# ===================== Special input tests ===================== + + +@pytest.mark.parametrize( + "prompt,seed", + [ + ("What is AI? 
\U0001f52c\U0001f9e0", 900), # emoji + ("Math: E = mc\u00b2", 901), # superscript + ("Code: def hello(): return 'world'", 902), # code + ("Symbols: @#$%^&*()", 903), # special symbols + ], +) +def test_deterministic_special_chars(llm, prompt, seed): + sp = SamplingParams(temperature=0.5, max_tokens=30, seed=seed) + _assert_deterministic(llm, prompt, sp) + + +@pytest.mark.parametrize( + "lang,prompt,seed", + [ + ("Chinese", "Please introduce artificial intelligence in one sentence.", 1000), + ("English", "What is artificial intelligence in one sentence?", 1001), + ( + "Japanese", + "\u4eba\u5de5\u77e5\u80fd\u306b\u3064\u3044\u3066\u4e00\u8a00\u3067\u8aac\u660e\u3057\u3066\u304f\u3060\u3055\u3044\u3002", + 1002, + ), + ("Spanish", "\u00bfQu\u00e9 es la inteligencia artificial en una frase?", 1003), + ], +) +def test_deterministic_multi_language(llm, lang, prompt, seed): + sp = SamplingParams(temperature=0.5, max_tokens=30, seed=seed) + _assert_deterministic(llm, prompt, sp) + + +# ===================== Multi-turn conversation test ===================== + + +def test_deterministic_multi_turn(llm): + """Multi-turn chat maintains determinism.""" + sp = SamplingParams(temperature=0.5, max_tokens=50, seed=1100) + + messages1 = [ + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi! 
How can I help you?"}, + {"role": "user", "content": "Please introduce yourself."}, + ] + + # First full conversation + r1_turn1 = llm.chat(messages1, sp)[0].outputs.text + msgs2 = messages1 + [ + {"role": "assistant", "content": r1_turn1}, + {"role": "user", "content": "What can you do?"}, + ] + r1_turn2 = llm.chat(msgs2, sp)[0].outputs.text + + # Second full conversation (same seed) + r2_turn1 = llm.chat(messages1, sp)[0].outputs.text + msgs2_repeat = messages1 + [ + {"role": "assistant", "content": r2_turn1}, + {"role": "user", "content": "What can you do?"}, + ] + r2_turn2 = llm.chat(msgs2_repeat, sp)[0].outputs.text + + assert r1_turn1 == r2_turn1, "Multi-turn: turn-1 outputs differ" + assert r1_turn2 == r2_turn2, "Multi-turn: turn-2 outputs differ" + + +# ===================== State isolation test ===================== + + +def test_deterministic_state_isolation(llm): + """Interference prompts and interleaving do not break determinism.""" + prompt_a = "What is Python?" + prompt_b = "What is JavaScript?" 
+ sp_a = SamplingParams(temperature=0.5, max_tokens=30, seed=1200) + sp_b = SamplingParams(temperature=0.5, max_tokens=30, seed=1201) + + # Round 1 + a1, _ = _generate_text(llm, prompt_a, sp_a) + b1, _ = _generate_text(llm, prompt_b, sp_b) + + # Run unrelated interference + for p in ["Explain reinforcement learning.", "What is NLP?", "List 3 fruits."]: + llm.generate([p], SamplingParams(temperature=0.7, max_tokens=20, seed=999)) + + # Round 2 + a2, _ = _generate_text(llm, prompt_a, sp_a) + b2, _ = _generate_text(llm, prompt_b, sp_b) + + assert a1 == a2, "Prompt A: output changed after interference" + assert b1 == b2, "Prompt B: output changed after interference" + + +# ===================== Non-deterministic validation ===================== + + +def test_non_deterministic_validation(llm): + """ + Prove that tests are effective: + - Without seed + without mode: outputs vary + - With explicit seed: outputs are consistent + """ + prompt = "Please explain deep learning in one sentence." + + # Part 1: no mode, no seed -> outputs should differ + os.environ.pop("FD_DETERMINISTIC_MODE", None) + results_no_seed = [] + for _ in range(5): + sp = SamplingParams(temperature=0.7, max_tokens=30) + results_no_seed.append(llm.generate([prompt], sp)[0].outputs.text) + + assert len(set(results_no_seed)) > 1, "Without seed/mode: expected varied outputs, got all identical" + + # Part 2: explicit seed -> outputs must be consistent + sp_seeded = SamplingParams(temperature=0.7, max_tokens=30, seed=999) + results_seeded = [llm.generate([prompt], sp_seeded)[0].outputs.text for _ in range(5)] + assert len(set(results_seeded)) == 1, "With explicit seed: expected consistent outputs" + + +if __name__ == "__main__": + pytest.main(["-sv", __file__]) diff --git a/tests/deterministic/test_determinism_standalone.py b/tests/deterministic/test_determinism_standalone.py new file mode 100644 index 0000000000..c1d940d661 --- /dev/null +++ b/tests/deterministic/test_determinism_standalone.py @@ -0,0 
+1,296 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Determinism unit tests (lightweight, no model loading required) + +Test scenarios: +1. SamplingParams seed behavior in deterministic / non-deterministic mode +2. Environment variable handling (FD_DETERMINISTIC_MODE, SPLIT_KV_SIZE, LOG_MODE) +3. Token allocation alignment logic (_get_num_new_tokens) +4. Cross-mode behavior validation + +Usage: + pytest tests/deterministic/test_determinism_standalone.py -v +""" + +import importlib +import os +from dataclasses import dataclass +from typing import Optional + +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _reload_sp(): + """Reload envs + sampling_params so env-var changes take effect.""" + import fastdeploy.engine.sampling_params as sp_module + import fastdeploy.envs as envs_module + + importlib.reload(envs_module) + importlib.reload(sp_module) + return sp_module, envs_module + + +@dataclass +class _FakeRequest: + """Minimal stand-in for a scheduler request object.""" + + need_prefill_tokens: int + num_computed_tokens: int + request_id: str = "fake-0" + prompt_token_ids: Optional[list] = None + multimodal_inputs: Optional[dict] = None + with_image: bool = False + + +def _align_tokens(current_pos, remaining, budget, split_kv_size): + """ + 
Pure-function replica of the alignment logic in + ResourceManagerV1._get_num_new_tokens (deterministic branch). + + Returns the number of new tokens to allocate. + """ + if remaining < split_kv_size: + # Final chunk - no alignment needed + return min(remaining, budget) + + # Next split_kv_size boundary from current_pos + next_boundary = ((current_pos + split_kv_size - 1) // split_kv_size) * split_kv_size + tokens_to_boundary = next_boundary - current_pos + + if budget < tokens_to_boundary: + return 0 # defer + + aligned_end = ((current_pos + budget) // split_kv_size) * split_kv_size + num_new = aligned_end - current_pos + return min(num_new, budget, remaining) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _clean_env(): + """Save and restore determinism-related env vars around every test.""" + keys = [ + "FD_DETERMINISTIC_MODE", + "FD_DETERMINISTIC_SPLIT_KV_SIZE", + "FD_DETERMINISTIC_LOG_MODE", + ] + saved = {k: os.environ.get(k) for k in keys} + yield + for k, v in saved.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + +def _set_env(key, value): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + +# ===================== SamplingParams seed tests ===================== + + +class TestSamplingParamsSeed: + """Verify seed assignment in SamplingParams under different modes.""" + + def test_non_deterministic_uses_random_seed(self): + """Without FD_DETERMINISTIC_MODE, each SamplingParams gets a random seed.""" + _set_env("FD_DETERMINISTIC_MODE", None) + sp_mod, _ = _reload_sp() + + seeds = {sp_mod.SamplingParams().seed for _ in range(10)} + assert len(seeds) > 1, "Non-deterministic mode should produce different random seeds" + + def test_deterministic_uses_fixed_seed(self): + """With FD_DETERMINISTIC_MODE=1, default seed is always 42.""" + 
_set_env("FD_DETERMINISTIC_MODE", "1") + sp_mod, _ = _reload_sp() + + seeds = {sp_mod.SamplingParams().seed for _ in range(10)} + assert seeds == {42}, f"Deterministic mode should always use seed=42, got {seeds}" + + def test_explicit_seed_overrides_mode(self): + """User-supplied seed takes precedence over deterministic default.""" + _set_env("FD_DETERMINISTIC_MODE", "1") + sp_mod, _ = _reload_sp() + + assert sp_mod.SamplingParams(seed=123).seed == 123 + + def test_seed_zero_is_valid(self): + """seed=0 must not be confused with 'unset'.""" + _set_env("FD_DETERMINISTIC_MODE", "1") + sp_mod, _ = _reload_sp() + + assert sp_mod.SamplingParams(seed=0).seed == 0 + + def test_seed_max_value(self): + """Upper-bound seed accepted by _verify_args.""" + _set_env("FD_DETERMINISTIC_MODE", "1") + sp_mod, _ = _reload_sp() + + max_seed = 922337203685477580 + assert sp_mod.SamplingParams(seed=max_seed).seed == max_seed + + def test_explicit_seed_works_in_both_modes(self): + """Same explicit seed yields same value regardless of mode.""" + explicit_seed = 12345 + for mode in ("0", "1"): + _set_env("FD_DETERMINISTIC_MODE", mode) + sp_mod, _ = _reload_sp() + assert sp_mod.SamplingParams(seed=explicit_seed).seed == explicit_seed + + +# ===================== Environment variable tests ===================== + + +class TestDeterminismEnvVars: + """Verify env-var parsing in fastdeploy.envs.""" + + @pytest.mark.parametrize( + "raw,expected", + [ + (None, False), + ("0", False), + ("1", True), + ], + ) + def test_deterministic_mode(self, raw, expected): + _set_env("FD_DETERMINISTIC_MODE", raw) + _, envs_mod = _reload_sp() + assert envs_mod.FD_DETERMINISTIC_MODE is expected + + def test_split_kv_size_default(self): + _set_env("FD_DETERMINISTIC_SPLIT_KV_SIZE", None) + _, envs_mod = _reload_sp() + assert envs_mod.FD_DETERMINISTIC_SPLIT_KV_SIZE == 16 + + def test_split_kv_size_custom(self): + _set_env("FD_DETERMINISTIC_SPLIT_KV_SIZE", "32") + _, envs_mod = _reload_sp() + assert 
envs_mod.FD_DETERMINISTIC_SPLIT_KV_SIZE == 32 + + @pytest.mark.parametrize( + "raw,expected", + [ + (None, False), + ("1", True), + ], + ) + def test_log_mode(self, raw, expected): + _set_env("FD_DETERMINISTIC_LOG_MODE", raw) + _, envs_mod = _reload_sp() + assert envs_mod.FD_DETERMINISTIC_LOG_MODE is expected + + +# ===================== Token alignment logic tests ===================== + + +class TestTokenAlignment: + """ + Verify the deterministic token-alignment algorithm. + + The alignment logic ensures chunk boundaries fall on split_kv_size + multiples so that attention computation is batch-invariant. + """ + + @pytest.mark.parametrize( + "cur,remaining,budget,kv,expected", + [ + # --- basic cases (cur=0) --- + (0, 100, 5, 16, 0), # budget < kv_size, defer + (0, 100, 16, 16, 16), # budget == kv_size + (0, 100, 32, 16, 32), # budget == 2*kv_size + (0, 100, 50, 16, 48), # round-down to 48 + # --- non-zero current_pos --- + (10, 90, 20, 16, 6), # next boundary=16, then end=16, alloc=6 + (8, 92, 20, 16, 8), # next boundary=16, aligned_end=16, alloc=8 + (16, 84, 32, 16, 32), # already on boundary + (15, 85, 1, 16, 1), # exactly 1 token to next boundary + (17, 83, 2, 16, 0), # 15 tokens to boundary=32, budget=2 => defer + # --- final-chunk (remaining < kv_size) --- + (96, 4, 10, 16, 4), # final chunk, no alignment + (96, 4, 2, 16, 2), # final chunk, budget < remaining + # --- large kv_size --- + (0, 200, 100, 64, 64), # kv=64, 100//64*64=64 + (0, 200, 128, 64, 128), # kv=64, 128//64*64=128 + ], + ) + def test_align_tokens(self, cur, remaining, budget, kv, expected): + result = _align_tokens(cur, remaining, budget, kv) + assert result == expected, ( + f"align_tokens(cur={cur}, remaining={remaining}, budget={budget}, kv={kv}): " + f"expected {expected}, got {result}" + ) + + def test_alignment_vs_non_deterministic(self): + """Deterministic mode allocates fewer tokens due to alignment.""" + budget, kv = 50, 16 + det_result = _align_tokens(0, 100, budget, kv) # 48 + 
non_det_result = min(100, budget) # 50 + assert det_result < non_det_result + assert det_result == 48 + assert non_det_result == 50 + + def test_result_always_on_boundary_or_final_allocation(self): + """After allocation, (current_pos + result) sits on a kv boundary + unless this allocation exhausts all remaining tokens.""" + kv = 16 + for cur in range(0, 80, 7): + for remaining in [5, 10, 30, 60, 100]: + for budget in [1, 8, 16, 32, 64]: + result = _align_tokens(cur, remaining, budget, kv) + if result == 0: + continue + end = cur + result + is_final = result == remaining + if remaining >= kv and not is_final: + assert end % kv == 0, ( + f"cur={cur} remaining={remaining} budget={budget}: " f"end={end} is not aligned to {kv}" + ) + + +# ===================== Cross-mode behavior validation ===================== + + +class TestCrossModeBehavior: + """Prove that mode switch actually changes observable behavior.""" + + def test_deterministic_mode_consistent_seeds(self): + _set_env("FD_DETERMINISTIC_MODE", "1") + sp_mod, _ = _reload_sp() + seeds = [sp_mod.SamplingParams().seed for _ in range(10)] + assert len(set(seeds)) == 1 and seeds[0] == 42 + + def test_non_deterministic_mode_varied_seeds(self): + _set_env("FD_DETERMINISTIC_MODE", "0") + sp_mod, _ = _reload_sp() + seeds = [sp_mod.SamplingParams().seed for _ in range(10)] + assert len(set(seeds)) > 1 + + +if __name__ == "__main__": + pytest.main(["-sv", __file__]) diff --git a/tests/distributed/allreduce_deterministic.py b/tests/distributed/allreduce_deterministic.py new file mode 100644 index 0000000000..d498626540 --- /dev/null +++ b/tests/distributed/allreduce_deterministic.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
#!/usr/bin/env python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
All-Reduce Deterministic Test with Real Communication

Tests:
1. Custom All-Reduce is deterministic for supported dtypes (float32, float16, bfloat16)
2. Non-16 byte aligned tensors raise RuntimeError in deterministic mode
3. Unsupported dtypes (int32) raise AssertionError in deterministic mode

Run:
    python -m paddle.distributed.launch --gpus=0,1,2,3 tests/distributed/allreduce_deterministic.py
"""

import os

import paddle
import paddle.distributed as dist
import pytest

pytestmark = pytest.mark.gpu

from fastdeploy import envs
from fastdeploy.distributed import communication
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce

SUPPORTED_DTYPES = [paddle.float32, paddle.float16, paddle.bfloat16]
TENSOR_SIZE = 2048
NUM_RUNS = 20


def _create_tensor(size: int, dtype: paddle.dtype, rank: int) -> paddle.Tensor:
    """Create a [size, 1] test tensor scaled by (rank + 1) for the given dtype."""
    if dtype == paddle.int32:
        return paddle.randint(-100, 100, shape=[size, 1], dtype=dtype) * (rank + 1)
    return paddle.randn([size, 1], dtype=dtype) * (rank + 1)


def _check_results_identical(results: list) -> bool:
    """Return True when every result array is bitwise-equal to the first one."""
    if not results:
        return True
    return all((results[0] == r).all() for r in results[1:])


def _init_custom_allreduce(world_size: int):
    """Initialize custom all-reduce with the default 8MB buffer.

    Any previously registered instance is closed first so that the repeated
    calls in main() do not leak GPU buffers or IPC handles (same pattern as
    _init_large_custom_allreduce below).
    """
    if communication._TP_AR is not None:
        communication._TP_AR.close()
        communication._TP_AR = None
    mp_group = dist.new_group(ranks=list(range(world_size)))
    communication.use_custom_allreduce(mp_group, 8192 * 1024)
    return mp_group


def _enable_deterministic_mode():
    """Enable deterministic mode via environment variable."""
    os.environ["FD_DETERMINISTIC_MODE"] = "1"
    assert envs.FD_DETERMINISTIC_MODE, f"FD_DETERMINISTIC_MODE should be True but got {envs.FD_DETERMINISTIC_MODE}"


def test_custom_allreduce_deterministic(rank, world_size, dtype):
    """Custom all-reduce should be deterministic: NUM_RUNS identical inputs
    must produce bitwise-identical outputs."""
    _mp_group = _init_custom_allreduce(world_size)  # noqa: F841
    results = []

    for _ in range(NUM_RUNS):
        # Re-seed per run so every iteration feeds the exact same input tensor.
        paddle.seed(42 + rank)
        x = _create_tensor(TENSOR_SIZE, dtype, rank)
        result = tensor_model_parallel_all_reduce(x)
        # Cast to float32 first: bfloat16 has no native numpy representation.
        results.append(result.astype("float32").numpy().copy())
        dist.barrier()

    communication.custom_ar_clear_ipc_handles()
    return _check_results_identical(results)


def _init_large_custom_allreduce(world_size: int):
    """Initialize custom all-reduce with 128MB buffer for large tensor tests."""
    _enable_deterministic_mode()
    large_max_size = 128 * 1024 * 1024  # 128MB
    mp_group = dist.new_group(ranks=list(range(world_size)))
    # Properly close old instance to free GPU buffers and IPC handles
    if communication._TP_AR is not None:
        communication._TP_AR.close()
        communication._TP_AR = None
    communication.use_custom_allreduce(mp_group, large_max_size)


def test_large_tensor_correctness(rank, world_size, dtype):
    """Large tensor (> default 8MB) should produce correct results with the
    increased max_size buffer installed by _init_large_custom_allreduce."""
    # 2M elements * 2 bytes (bf16) = 4MB; 8M elements * 2 bytes = 16MB (> 8MB default)
    large_sizes = [2 * 1024 * 1024, 8 * 1024 * 1024]
    for large_size in large_sizes:
        # Sum over ranks of (rank + 1) = world_size * (world_size + 1) / 2.
        expected_val = float(world_size * (world_size + 1) // 2)
        x = paddle.full([large_size, 1], float(rank + 1), dtype=dtype)
        result = tensor_model_parallel_all_reduce(x)

        # Cast to float32 before numpy() since bfloat16 has no native numpy support
        result_np = result.astype("float32").numpy().flatten()
        max_diff = abs(result_np - expected_val).max()
        if max_diff > 0.01:
            raise AssertionError(
                f"Large tensor AR mismatch for {dtype}, size={large_size}: "
                f"expected={expected_val}, got_sample={result_np[:5]}, max_diff={max_diff}"
            )
        dist.barrier()


def test_large_tensor_deterministic(rank, world_size, dtype):
    """Multiple runs of large tensor all-reduce must produce bitwise-identical results."""
    # 8M elements * 2 bytes (bf16) = 16MB, exceeds default 8MB
    large_size = 8 * 1024 * 1024
    results = []
    for _ in range(NUM_RUNS):
        paddle.seed(42 + rank)
        x = _create_tensor(large_size, dtype, rank)
        result = tensor_model_parallel_all_reduce(x)
        results.append(result.astype("float32").numpy().copy())
        dist.barrier()

    return _check_results_identical(results)


def test_non_16_aligned_raises_error(rank, world_size):
    """Non-16 byte aligned tensors should raise RuntimeError in deterministic mode."""
    _enable_deterministic_mode()
    mp_group = _init_custom_allreduce(world_size)
    # 1026 * 4 = 4104 bytes (NOT multiple of 16)
    x = paddle.to_tensor([1.0] * 1026, dtype=paddle.float32).reshape([1026, 1])

    try:
        with pytest.raises(RuntimeError, match="DETERMINISTIC_MODE.*multiple of 16"):
            tensor_model_parallel_all_reduce(x, group_=mp_group)
    finally:
        communication.custom_ar_clear_ipc_handles()


def test_unsupported_dtype_raises_error(rank, world_size):
    """Unsupported dtypes should raise AssertionError in deterministic mode."""
    _enable_deterministic_mode()
    mp_group = _init_custom_allreduce(world_size)
    x = _create_tensor(TENSOR_SIZE, paddle.int32, rank)

    try:
        with pytest.raises(AssertionError, match="DETERMINISTIC_MODE.*not supported"):
            tensor_model_parallel_all_reduce(x, group_=mp_group)
    finally:
        communication.custom_ar_clear_ipc_handles()


def main():
    """Entry point: run every error-path and determinism scenario in sequence."""
    if not dist.is_initialized():
        paddle.distributed.init_parallel_env()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    assert world_size >= 2, f"Test requires at least 2 GPUs, got {world_size}"

    print(f"All-Reduce Deterministic Test (world_size={world_size}, runs={NUM_RUNS})")

    # Error path tests
    test_non_16_aligned_raises_error(rank, world_size)
    print("PASS: non-16 byte aligned tensor raises RuntimeError")
    dist.barrier()

    test_unsupported_dtype_raises_error(rank, world_size)
    print("PASS: unsupported dtype (int32) raises AssertionError")
    dist.barrier()

    # Determinism tests for supported dtypes (small tensors)
    for dtype in SUPPORTED_DTYPES:
        assert test_custom_allreduce_deterministic(
            rank, world_size, dtype
        ), f"Custom all-reduce is NOT deterministic for {dtype}"
        print(f"PASS: custom all-reduce deterministic for {dtype}")
        dist.barrier()

    # Large tensor tests (> default 8MB, using increased max_size)
    # Create one 128MB instance shared by all dtype tests to avoid IPC buffer leaks
    _init_large_custom_allreduce(world_size)

    for dtype in SUPPORTED_DTYPES:
        test_large_tensor_correctness(rank, world_size, dtype)
        print(f"PASS: large tensor all-reduce correctness for {dtype}")
        dist.barrier()

    for dtype in SUPPORTED_DTYPES:
        assert test_large_tensor_deterministic(
            rank, world_size, dtype
        ), f"Large tensor all-reduce is NOT deterministic for {dtype}"
        print(f"PASS: large tensor all-reduce deterministic for {dtype}")
        dist.barrier()

    communication.custom_ar_clear_ipc_handles()

    print("All tests passed.")


if __name__ == "__main__":
    main()


# --- tests/distributed/test_allreduce_deterministic_launch.py ---
import os
import subprocess
import sys

import pytest

pytestmark = pytest.mark.gpu


def test_rollout_model_with_distributed_launch():
    """Launch allreduce_deterministic.py on 2 GPUs via paddle.distributed.launch
    and assert the subprocess exits cleanly (return code 0)."""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    rollout_script = os.path.join(current_dir, "allreduce_deterministic.py")

    command = [sys.executable, "-m", "paddle.distributed.launch", "--gpus", "0,1", rollout_script]

    print(f"Executing command: {' '.join(command)}")

    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    try:
        stdout, stderr = process.communicate(timeout=300)
        return_code = process.returncode
    except subprocess.TimeoutExpired:
        # Treat a hang as a failure: kill the launcher and collect whatever
        # output is available for the diagnostics below.
        process.kill()
        stdout, stderr = process.communicate()
        return_code = -1

    print("\n" + "=" * 50 + " STDOUT " + "=" * 50)
    print(stdout)
    print("\n" + "=" * 50 + " STDERR " + "=" * 50)
    print(stderr)

    assert return_code == 0, f"Process exited with code {return_code}\nSTDERR: {stderr[-500:] if stderr else 'N/A'}"
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from fastdeploy.engine.sampling_params import SamplingParams

MAX_SEED = 922337203685477580


class TestSamplingParamsDeterminism(unittest.TestCase):
    """Seed-resolution behavior of SamplingParams under FD_DETERMINISTIC_MODE."""

    _ENV_KEYS = ("FD_DETERMINISTIC_MODE",)

    def setUp(self):
        # Stash (and remove) the deterministic env vars so every test starts clean.
        self._saved_env = {key: os.environ.pop(key, None) for key in self._ENV_KEYS}

    def tearDown(self):
        # Put the environment back exactly as it was before the test.
        for key, value in self._saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    def test_fixed_seed_in_deterministic_mode(self):
        """seed=None should always resolve to 42 when FD_DETERMINISTIC_MODE=1."""
        os.environ["FD_DETERMINISTIC_MODE"] = "1"

        for _ in range(5):
            self.assertEqual(SamplingParams(seed=None).seed, 42)

    def test_random_seed_in_non_deterministic_mode(self):
        """seed=None should produce varying seeds when FD_DETERMINISTIC_MODE=0."""
        os.environ["FD_DETERMINISTIC_MODE"] = "0"

        distinct_seeds = {SamplingParams(seed=None).seed for _ in range(10)}
        self.assertGreaterEqual(len(distinct_seeds), 2)

    def test_explicit_seed_respected_in_both_modes(self):
        """Explicit seed values should survive regardless of deterministic mode."""
        for mode in ("0", "1"):
            os.environ["FD_DETERMINISTIC_MODE"] = mode
            for seed in (0, 1, 100, MAX_SEED):
                self.assertEqual(SamplingParams(seed=seed).seed, seed)

    def test_seed_out_of_range_rejected(self):
        """Seeds outside [0, MAX_SEED] should raise ValueError."""
        with self.assertRaises(ValueError):
            SamplingParams(seed=-1)

        with self.assertRaises(ValueError):
            SamplingParams(seed=MAX_SEED + 1)

    def test_env_switch_changes_behavior(self):
        """Flipping FD_DETERMINISTIC_MODE at runtime should affect later SamplingParams."""
        os.environ["FD_DETERMINISTIC_MODE"] = "1"
        params_det = SamplingParams(seed=None)
        self.assertEqual(params_det.seed, 42)

        os.environ["FD_DETERMINISTIC_MODE"] = "0"
        distinct_seeds = {SamplingParams(seed=None).seed for _ in range(10)}
        # At least some seeds should differ from the fixed value
        self.assertGreaterEqual(len(distinct_seeds), 2)


if __name__ == "__main__":
    unittest.main(verbosity=2)
"""

import time
import unittest
from multiprocessing.shared_memory import SharedMemory

import numpy as np
import pytest

from fastdeploy.inter_communicator.ipc_signal import IPCSignal, shared_memory_exists


class TestSharedMemoryExists(unittest.TestCase):
    """Test cases for shared_memory_exists function."""

    def test_returns_false_for_nonexistent_memory(self):
        """Test that shared_memory_exists returns False for non-existent shared memory."""
        # time.time() in the name makes collisions with leftover segments unlikely.
        result = shared_memory_exists(f"nonexistent_shm_{time.time()}")
        self.assertFalse(result)

    def test_returns_true_for_existing_memory(self):
        """Test that shared_memory_exists returns True for existing shared memory."""
        name = f"test_shm_{time.time()}"
        shm = SharedMemory(name=name, create=True, size=1024)
        try:
            result = shared_memory_exists(name)
            self.assertTrue(result)
        finally:
            # Best-effort cleanup: close the mapping and unlink the segment.
            try:
                shm.close()
                shm.unlink()
            except Exception:
                pass


@pytest.mark.parametrize(
    "dtype,shape,initial_value",
    [
        (np.int32, (10,), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        (np.float32, (5,), [0.0, 1.5, 2.5, 3.5, 4.5]),
        (np.int64, (3, 3), [[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
        (np.uint8, (4,), [0, 127, 200, 255]),
    ],
)
def test_ipc_signal_create_with_array(dtype, shape, initial_value):
    """Test IPCSignal creation with numpy array."""
    name = f"test_ipc_signal_{time.time()}"
    array = np.array(initial_value, dtype=dtype)

    signal = IPCSignal(name=name, array=array, dtype=dtype, create=True)
    try:
        # Verify value is initialized correctly
        np.testing.assert_array_equal(signal.value, array)
        np.testing.assert_equal(signal.value.dtype, dtype)

        # Verify shared memory exists
        assert shared_memory_exists(name)
    finally:
        try:
            signal.clear()
        except Exception:
            pass


class TestIPCSignal(unittest.TestCase):
    """Test cases for IPCSignal class."""

    def setUp(self):
        """Set up test fixtures."""
        # Unique per-test base name; each test derives its own suffix from it.
        self.test_name_base = f"test_ipc_signal_{time.time()}"
        self._signals_to_clean = []

    def tearDown(self):
        """Clean up all tracked signals."""
        for signal in self._signals_to_clean:
            try:
                signal.clear()
            except Exception:
                pass

    def _track(self, signal):
        """Register a signal for automatic cleanup in tearDown."""
        self._signals_to_clean.append(signal)
        return signal

    def test_create_with_suffix(self):
        """Test IPCSignal creation with suffix."""
        name = self.test_name_base
        suffix = 123

        array = np.array([1, 2, 3], dtype=np.int32)
        signal = self._track(IPCSignal(name=name, array=array, dtype=np.int32, suffix=suffix, create=True))

        # The suffix is appended to the name with a "." separator.
        expected_name = f"{name}.{suffix}"
        self.assertTrue(shared_memory_exists(expected_name))
        np.testing.assert_array_equal(signal.value, array)

    def test_attach_to_existing(self):
        """Test IPCSignal attaching to existing shared memory."""
        name = f"{self.test_name_base}_attach"
        array = np.array([10, 20, 30], dtype=np.int64)

        # Create shared memory
        signal1 = self._track(IPCSignal(name=name, array=array, dtype=np.int64, create=True))
        signal1.value[0] = 99  # Modify value

        # Attach to existing
        # NOTE(review): signal2 is only an attach handle; unlinking is done by
        # the tracked creator signal1 in tearDown.
        signal2 = IPCSignal(name=name, array=array, dtype=np.int64, create=False)

        # Verify value is shared
        self.assertEqual(signal2.value[0], 99)
        np.testing.assert_array_equal(signal2.value, signal1.value)

    def test_dtype_mismatch_raises_assertion(self):
        """Test that dtype mismatch raises AssertionError."""
        name = f"{self.test_name_base}_mismatch"
        array = np.array([1, 2, 3], dtype=np.int32)

        with self.assertRaises(AssertionError):
            IPCSignal(name=name, array=array, dtype=np.float32, create=True)

    def test_non_numpy_array_raises_assertion(self):
        """Test that non-numpy array raises AssertionError."""
        name = f"{self.test_name_base}_non_array"

        with self.assertRaises(AssertionError):
            IPCSignal(name=name, array=[1, 2, 3], dtype=np.int32, create=True)

    def test_create_with_shm_size(self):
        """Test IPCSignal creation with shm_size (no array)."""
        name = f"{self.test_name_base}_size"

        signal = self._track(IPCSignal(name=name, shm_size=1024, create=True))

        # Verify signal is created but value is None (no array template)
        self.assertTrue(shared_memory_exists(name))
        self.assertIsNone(signal.value)

    def test_attach_with_shm_size(self):
        """Test IPCSignal attach with shm_size (no array)."""
        name = f"{self.test_name_base}_attach_size"

        # Create
        self._track(IPCSignal(name=name, shm_size=512, create=True))

        # Attach
        signal2 = IPCSignal(name=name, shm_size=512, create=False)

        self.assertTrue(shared_memory_exists(name))
        self.assertIsNone(signal2.value)

    def test_shm_size_required_without_array_and_dtype(self):
        """Test that shm_size is required when array and dtype are None."""
        name = f"{self.test_name_base}_no_size"

        with self.assertRaises(AssertionError):
            IPCSignal(name=name, create=True)

    def test_clear_removes_shared_memory(self):
        """Test that clear() properly removes shared memory."""
        name = f"{self.test_name_base}_clear"
        array = np.array([1, 2, 3], dtype=np.int32)

        signal = IPCSignal(name=name, array=array, dtype=np.int32, create=True)
        self.assertTrue(shared_memory_exists(name))

        signal.clear()
        self.assertFalse(shared_memory_exists(name))

    def test_clear_idempotent(self):
        """Test that clear() can be called multiple times safely."""
        name = f"{self.test_name_base}_idempotent"
        array = np.array([1, 2, 3], dtype=np.int32)

        signal = IPCSignal(name=name, array=array, dtype=np.int32, create=True)

        # Should not raise exception
        signal.clear()
        signal.clear()  # Call again

    def test_value_sharing_between_processes_mock(self):
        """Test that value is shared (mocked for unit test)."""
        name = f"{self.test_name_base}_shared"
        array = np.array([100, 200, 300], dtype=np.int64)

        signal1 = self._track(IPCSignal(name=name, array=array, dtype=np.int64, create=True))
        signal2 = IPCSignal(name=name, array=array, dtype=np.int64, create=False)

        # Modify through signal1
        signal1.value[0] = 999
        signal1.value[1] = 888
        signal1.value[2] = 777

        # Verify signal2 sees changes
        self.assertEqual(signal2.value[0], 999)
        self.assertEqual(signal2.value[1], 888)
        self.assertEqual(signal2.value[2], 777)

    def test_multiple_array_creation_replaces_existing(self):
        """Test that creating with same name replaces existing shared memory."""
        name = f"{self.test_name_base}_replace"
        array1 = np.array([1, 2, 3], dtype=np.int32)
        array2 = np.array([4, 5, 6], dtype=np.int32)

        signal1 = IPCSignal(name=name, array=array1, dtype=np.int32, create=True)
        signal1.clear()

        signal2 = self._track(IPCSignal(name=name, array=array2, dtype=np.int32, create=True))

        np.testing.assert_array_equal(signal2.value, array2)

    def test_clear_closes_and_unlinks(self):
        """Test that clear() both closes and unlinks the shared memory."""
        name = f"{self.test_name_base}_unlink"
        array = np.array([1, 2, 3], dtype=np.int32)

        signal = IPCSignal(name=name, array=array, dtype=np.int32, create=True)

        # After clear, the shared memory should be removed
        signal.clear()
        self.assertFalse(shared_memory_exists(name))

        # Attempting to attach should fail
        try:
            _ = SharedMemory(name=name, create=False)
            self.fail("Should have raised FileNotFoundError")
        except FileNotFoundError:
            pass

    def test_raw_buffer_read_write_with_shm_size(self):
        """Test raw buffer read/write in shm_size mode."""
        name = f"{self.test_name_base}_raw_buf"
        data = b"hello ipc signal"

        # In shm_size mode there is no .value array; access the raw buffer.
        signal1 = self._track(IPCSignal(name=name, shm_size=1024, create=True))
        signal1.shm.buf[: len(data)] = data

        signal2 = IPCSignal(name=name, shm_size=1024, create=False)
        self.assertEqual(bytes(signal2.shm.buf[: len(data)]), data)

    def test_create_overwrites_existing_without_clear(self):
        """Test that create=True on existing name auto-unlinks and recreates."""
        name = f"{self.test_name_base}_overwrite"
        array1 = np.array([1, 2, 3], dtype=np.int32)
        array2 = np.array([7, 8, 9], dtype=np.int32)

        # Create first signal, do NOT clear it
        IPCSignal(name=name, array=array1, dtype=np.int32, create=True)

        # Create again with same name — should auto-unlink old and recreate
        signal2 = self._track(IPCSignal(name=name, array=array2, dtype=np.int32, create=True))
        np.testing.assert_array_equal(signal2.value, array2)

    def test_attach_nonexistent_raises_error(self):
        """Test that create=False on non-existent shm raises FileNotFoundError."""
        name = f"nonexistent_signal_{time.time()}"
        array = np.array([1, 2, 3], dtype=np.int32)

        with self.assertRaises(FileNotFoundError):
            IPCSignal(name=name, array=array, dtype=np.int32, create=False)


class TestIPCSignalEdgeCases(unittest.TestCase):
    """Test edge cases for IPCSignal."""

    def test_empty_array_raises_error(self):
        """Test IPCSignal with empty array raises ValueError due to nbytes=0."""
        name = f"test_empty_array_{time.time()}"
        array = np.array([], dtype=np.int32)

        with self.assertRaises(ValueError):
            IPCSignal(name=name, array=array, dtype=np.int32, create=True)

    def test_large_array(self):
        """Test IPCSignal with large array."""
        name = f"test_large_array_{time.time()}"
        size = 10000
        array = np.arange(size, dtype=np.int64)

        signal = IPCSignal(name=name, array=array, dtype=np.int64, create=True)
        try:
            np.testing.assert_array_equal(signal.value, array)
        finally:
            try:
                signal.clear()
            except Exception:
                pass

    def test_multidimensional_array(self):
        """Test IPCSignal with multidimensional array."""
        name = f"test_multi_array_{time.time()}"
        array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)

        signal = IPCSignal(name=name, array=array, dtype=np.int32, create=True)
        try:
            self.assertEqual(signal.value.shape, (3, 3))
            np.testing.assert_array_equal(signal.value, array)
        finally:
            try:
                signal.clear()
            except Exception:
                pass

    def test_different_numeric_types(self):
        """Test IPCSignal with different numeric types."""
        name_base = f"test_types_{time.time()}"

        test_cases = [
            (np.int8, [1, 2, 3]),
            (np.int16, [1000, 2000, 3000]),
            (np.int32, [100000, 200000, 300000]),
            (np.int64, [1000000000, 2000000000, 3000000000]),
            (np.float32, [1.5, 2.5, 3.5]),
            (np.float64, [1.123456789, 2.987654321, 3.5]),
        ]

        for i, (dtype, values) in enumerate(test_cases):
            name = f"{name_base}_{i}"
            array = np.array(values, dtype=dtype)
            signal = IPCSignal(name=name, array=array, dtype=dtype, create=True)
            try:
                np.testing.assert_array_equal(signal.value, array)
            finally:
                try:
                    signal.clear()
                except Exception:
                    pass


if __name__ == "__main__":
    unittest.main()
diff --git a/tests/layers/test_flash_attention_versions_determinism.py b/tests/layers/test_flash_attention_versions_determinism.py
new file mode 100644
index 0000000000..158166593e
--- /dev/null
+++ b/tests/layers/test_flash_attention_versions_determinism.py
@@ -0,0 +1,177 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Flash Attention V2 / V3 determinism tests.

Verify bitwise determinism of flash-backend SDPA when explicitly
selecting FA version via FLAGS_flash_attn_version (2 or 3).
"""

import unittest

import pytest

# Mark the whole module as GPU-only for pytest collection.
pytestmark = pytest.mark.gpu

import paddle
import paddle.nn.functional as F

# --------------- constants ---------------
BATCH_SIZE = 2
NUM_HEADS = 32
HEAD_DIM = 64
SEQ_LEN = 2048
NUM_RUNS = 5


# --------------- helpers ---------------
def _make_qkv(batch_size, num_heads, seq_len, head_dim, dtype="float16", seed=42):
    """Create deterministic q/k/v tensors (fixed seed so inputs are reproducible)."""
    paddle.seed(seed)
    shape = [batch_size, num_heads, seq_len, head_dim]
    return (
        paddle.randn(shape, dtype=dtype),
        paddle.randn(shape, dtype=dtype),
        paddle.randn(shape, dtype=dtype),
    )


def _assert_deterministic(test_case, func, num_runs=NUM_RUNS):
    """Run *func* multiple times and assert all results are bitwise equal."""
    # .clone() detaches each result from any reused output buffer.
    results = [func().clone() for _ in range(num_runs)]
    for i in range(1, num_runs):
        test_case.assertTrue(
            paddle.equal(results[0], results[i]).all().item(),
            f"Run 0 vs Run {i} differ",
        )


# --------------- test class ---------------
class TestFlashAttentionVersionsDeterminism(unittest.TestCase):
    """Test determinism when switching between FA2 and FA3."""

    FA_VERSIONS = [2, 3]

    def setUp(self):
        if not paddle.is_compiled_with_cuda():
            self.skipTest("Flash Attention requires CUDA")
        paddle.set_device("gpu")
        # Save/restore flag to avoid cross-test pollution
        self._saved_version = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"]

    def tearDown(self):
        # Restore the FA version flag captured in setUp.
        paddle.set_flags({"FLAGS_flash_attn_version": self._saved_version})

    def _skip_if_fa3_unsupported(self):
        # Derive the SM number (e.g. 90 for Hopper) from the device properties.
        prop = paddle.device.cuda.get_device_properties()
        sm = prop.major * 10 + prop.minor
        if sm < 89 or sm >= 100:
            self.skipTest(f"FA3 requires SM89-SM99, current SM{sm}")

    def _set_fa_version(self, version):
        if version == 3:
            self._skip_if_fa3_unsupported()
        paddle.set_flags({"FLAGS_flash_attn_version": version})

    def _flash_sdpa(self, q, k, v, **kwargs):
        """Thin wrapper: synchronize then call flash-backend SDPA."""
        paddle.device.synchronize()
        return F.scaled_dot_product_attention(q, k, v, backend="flash", **kwargs)

    # ==================== tests ====================

    def test_determinism(self):
        """Multi-run determinism for FA2/FA3, causal and non-causal."""
        for version in self.FA_VERSIONS:
            for is_causal in [False, True]:
                with self.subTest(version=version, is_causal=is_causal):
                    self._set_fa_version(version)
                    q, k, v = _make_qkv(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
                    _assert_deterministic(
                        self,
                        lambda: self._flash_sdpa(q, k, v, is_causal=is_causal, enable_gqa=False),
                    )

    def test_batch_invariance(self):
        """First-sample result should be identical across batch sizes."""
        for version in self.FA_VERSIONS:
            with self.subTest(version=version):
                self._set_fa_version(version)
                max_bs = 8
                q, k, v = _make_qkv(max_bs, NUM_HEADS, SEQ_LEN, HEAD_DIM)

                # bs=1 result is the reference; larger batches must reproduce it.
                ref = self._flash_sdpa(q[:1], k[:1], v[:1], is_causal=False, enable_gqa=False)
                for bs in [2, 4, 8]:
                    result = self._flash_sdpa(q[:bs], k[:bs], v[:bs], is_causal=False, enable_gqa=False)
                    self.assertTrue(
                        paddle.equal(ref, result[0:1]).all().item(),
                        f"FA{version} batch invariance failed at bs={bs}",
                    )

    def test_seq_length_determinism(self):
        """Determinism across various sequence lengths (including boundaries)."""
        seq_lengths = [1, 2, 4, 8, 16, 64, 128, 256, 512, 1024, 2048, 4096]
        for version in self.FA_VERSIONS:
            for seq_len in seq_lengths:
                with self.subTest(version=version, seq_len=seq_len):
                    self._set_fa_version(version)
                    q, k, v = _make_qkv(BATCH_SIZE, NUM_HEADS, seq_len, HEAD_DIM)
                    _assert_deterministic(
                        self,
                        lambda: self._flash_sdpa(q, k, v, is_causal=False, enable_gqa=False),
                        num_runs=2,
                    )

    def test_dtype_determinism(self):
        """Determinism across float16 and float32."""
        for version in self.FA_VERSIONS:
            for dtype in ["float16", "float32"]:
                with self.subTest(version=version, dtype=dtype):
                    self._set_fa_version(version)
                    q, k, v = _make_qkv(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, dtype=dtype)
                    _assert_deterministic(
                        self,
                        lambda: self._flash_sdpa(q, k, v, is_causal=False, enable_gqa=False),
                        num_runs=3,
                    )

    def test_head_config_determinism(self):
        """Determinism across different head configurations."""
        for version in self.FA_VERSIONS:
            for num_heads, head_dim in [(1, 64), (7, 64), (32, 64)]:
                with self.subTest(version=version, num_heads=num_heads, head_dim=head_dim):
                    self._set_fa_version(version)
                    q, k, v = _make_qkv(BATCH_SIZE, num_heads, SEQ_LEN, head_dim)
                    _assert_deterministic(
                        self,
                        lambda: self._flash_sdpa(q, k, v, is_causal=False, enable_gqa=False),
                        num_runs=2,
                    )

    def test_gqa_determinism(self):
        """Determinism with GQA enabled."""
        for version in self.FA_VERSIONS:
            with self.subTest(version=version):
                self._set_fa_version(version)
                q, k, v = _make_qkv(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
                _assert_deterministic(
                    self,
                    lambda: self._flash_sdpa(q, k, v, is_causal=False, enable_gqa=True),
                    num_runs=3,
                )


if __name__ == "__main__":
    unittest.main(verbosity=2)
diff --git a/tests/scheduler/test_chunked_prefill_determinism.py b/tests/scheduler/test_chunked_prefill_determinism.py
new file mode 100644
index 0000000000..1a0f786f3d
--- /dev/null
+++ b/tests/scheduler/test_chunked_prefill_determinism.py
@@ -0,0 +1,472 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Chunked Prefill Determinism Tests

Test _get_num_new_tokens alignment behavior in ResourceManagerV1:
1. Deterministic disabled (no alignment)
2. Deterministic enabled (split_kv_size boundary alignment)
3. Boundary cases
4. Continuous chunk consistency
5. Multimodal inputs (image / video / audio)
6. Real batch scheduling scenarios
7. Corner cases (empty request, invalid state, large split, dynamic switch, etc.)
"""

import os
import unittest

from fastdeploy.engine.request import Request
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1

# ---------------------------------------------------------------------------
# Minimal config stubs -- only fields accessed by ResourceManagerV1.__init__
# and _get_num_new_tokens are kept.
# ---------------------------------------------------------------------------


class ModelConfig:
    def __init__(self):
        self.enable_mm = False
        self.causal = True


class CacheConfig:
    def __init__(self):
        self.block_size = 16
        self.enable_prefix_caching = False
        self.kvcache_storage_backend = None
        self.write_policy = None
        self.num_cpu_blocks = 0
        self.total_block_num = 10000
        self.prefill_kvcache_block_num = 10000
        self.max_encoder_cache = 0
        self.max_processor_cache = 0
        self.bytes_per_token_per_layer = 32 * 32 * 128 * 2


class ParallelConfig:
    def __init__(self):
        self.local_engine_worker_queue_port = None
        self.tensor_parallel_size = 1


class SpeculativeConfig:
    def __init__(self):
        self.method = None
        self.num_speculative_tokens = 0
        self.model_type = None


class StubConfig:
    """Assembles the minimal sub-configs needed by ResourceManagerV1."""

    def __init__(self):
        self.model_config = ModelConfig()
        self.cache_config = CacheConfig()
        self.parallel_config = ParallelConfig()
        self.speculative_config = SpeculativeConfig()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _create_request(request_id, prompt_token_ids, num_computed_tokens=0, multimodal_inputs=None):
    """Create a real Request object for testing."""
    return Request(
        request_id=request_id,
        prompt_token_ids=prompt_token_ids,
        prompt_token_ids_len=len(prompt_token_ids),
        num_computed_tokens=num_computed_tokens,
        multimodal_inputs=multimodal_inputs,
    )


def _build_mm_inputs(prompt_len, text_len, modal_id, extra=None):
    """Build a multimodal_inputs dict for a single-modality request."""
    mm_len = prompt_len - text_len
    patch_idx_val = modal_id  # 1=image, 2=video, 3=audio
    inputs = {
        # Special token ids are placed past the prompt range so they never
        # collide with real prompt tokens (which are 0..prompt_len-1 here).
        "image_patch_id": prompt_len + 1,
        "image_end_id": prompt_len + 2,
        "video_patch_id": prompt_len + 3,
        "video_end_id": prompt_len + 4,
        "audio_patch_id": prompt_len + 5,
        "audio_end_id": prompt_len + 6,
        "patch_idx": [0] * text_len + [patch_idx_val] * mm_len,
        "patch_map": [
            {"modal_id": 0, "end_idx": text_len, "image_num": 0, "video_num": 0},
            {
                "modal_id": modal_id,
                "end_idx": prompt_len,
                "image_num": 1 if modal_id == 1 else 0,
                "video_num": 1 if modal_id == 2 else 0,
            },
        ],
        "tts": False,
    }
    if extra:
        inputs.update(extra)
    return inputs


# ---------------------------------------------------------------------------
# Test class
# ---------------------------------------------------------------------------


class TestChunkedPrefillDeterminism(unittest.TestCase):
    """Test _get_num_new_tokens alignment in deterministic mode."""

    def setUp(self):
        # Snapshot the deterministic-mode env vars so tearDown can restore them.
        self._saved_env = {}
        for key in ("FD_DETERMINISTIC_MODE", "FD_DETERMINISTIC_SPLIT_KV_SIZE"):
            self._saved_env[key] = os.environ.get(key)
        self.config = StubConfig()
        self.rm = self._create_resource_manager(self.config)

    def tearDown(self):
        # Restore (or remove) each env var exactly as it was before the test.
        for key, value in self._saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    # -- env helpers --

    def _enable_deterministic(self, split_kv_size=16):
        os.environ["FD_DETERMINISTIC_MODE"] = "1"
        os.environ["FD_DETERMINISTIC_SPLIT_KV_SIZE"] = str(split_kv_size)

    def _disable_deterministic(self):
        os.environ.pop("FD_DETERMINISTIC_MODE", None)
        os.environ.pop("FD_DETERMINISTIC_SPLIT_KV_SIZE", None)

    def _create_resource_manager(self, config):
        return ResourceManagerV1(
            max_num_seqs=32,
            config=config,
            tensor_parallel_size=1,
            splitwise_role="mixed",
            local_data_parallel_id=0,
        )

    def _create_mm_resource_manager(self):
        config = StubConfig()
        config.model_config.enable_mm = True
        return self._create_resource_manager(config)

    # ==================== 1. Deterministic disabled ====================

    def test_get_num_new_tokens_deterministic_disabled(self):
        """No alignment when deterministic mode is off; budget=0 returns 0."""
        self._disable_deterministic()

        test_cases = [
            # (prompt_tokens, num_computed, token_budget, expected)
            (list(range(100)), 0, 50, 50),
            (list(range(100)), 50, 30, 30),
            (list(range(100)), 90, 20, 10),
            (list(range(32)), 0, 15, 15),
            # budget=0 -> 0
            (list(range(100)), 0, 0, 0),
        ]
        for prompt_ids, num_computed, budget, expected in test_cases:
            with self.subTest(prompt_len=len(prompt_ids), computed=num_computed, budget=budget):
                req = _create_request("req", prompt_ids, num_computed)
                result = self.rm._get_num_new_tokens(req, budget)
                self.assertEqual(result, expected)

    # ==================== 2. Deterministic enabled alignment ====================

    def test_get_num_new_tokens_deterministic_enabled_alignment(self):
        """Results must align to split_kv_size boundary."""
        split_kv_size = 16
        self._enable_deterministic(split_kv_size)

        test_cases = [
            # (prompt_tokens, num_computed, token_budget, expected)
            (list(range(100)), 0, 20, 16),
            (list(range(100)), 0, 32, 32),
            (list(range(100)), 0, 40, 32),
            (list(range(100)), 0, 50, 48),
            (list(range(100)), 8, 20, 8),
            (list(range(100)), 8, 30, 24),
            (list(range(100)), 16, 20, 16),
            (list(range(100)), 16, 25, 16),
        ]
        for prompt_ids, num_computed, budget, expected in test_cases:
            with self.subTest(computed=num_computed, budget=budget):
                req = _create_request("req", prompt_ids, num_computed)
                result = self.rm._get_num_new_tokens(req, budget)
                self.assertEqual(result, expected)
                # Verify alignment
                if result > 0:
                    final_pos = num_computed + result
                    self.assertEqual(final_pos % split_kv_size, 0)

    # ==================== 3. Boundary cases ====================

    def test_get_num_new_tokens_boundary_cases(self):
        """Boundary conditions including large budget."""
        split_kv_size = 16
        self._enable_deterministic(split_kv_size)

        test_cases = [
            (list(range(100)), 0, 5, "budget < split_kv_size, start at 0"),
            (list(range(100)), 0, 1, "budget = 1, start at 0"),
            (list(range(100)), 10, 5, "budget < split_kv_size, start at 10"),
            (list(range(100)), 15, 5, "budget < split_kv_size, near boundary"),
            (list(range(16)), 0, 16, "exactly split_kv_size tokens needed"),
            (list(range(16)), 0, 32, "budget > needed"),
            # Very large budget (overflow guard)
            (list(range(100)), 0, 1000000, "very large budget"),
        ]
        for prompt_ids, num_computed, budget, desc in test_cases:
            with self.subTest(desc=desc):
                req = _create_request("req", prompt_ids, num_computed)
                result = self.rm._get_num_new_tokens(req, budget)
                max_possible = min(len(prompt_ids) - num_computed, budget)
                self.assertGreaterEqual(result, 0)
                self.assertLessEqual(result, max_possible)

    # ==================== 4. Chunk consistency ====================

    def test_get_num_new_tokens_consistency_across_chunks(self):
        """All chunk boundaries must align to split_kv_size."""
        split_kv_size = 16
        self._enable_deterministic(split_kv_size)

        prompt_ids = list(range(112))
        budget = 50
        num_computed = 0
        chunk_sizes = []

        while num_computed < len(prompt_ids):
            req = _create_request("req", prompt_ids, num_computed)
            result = self.rm._get_num_new_tokens(req, budget)
            if result == 0:
                break
            chunk_sizes.append(result)
            num_computed += result

        # Every intermediate boundary must be aligned; final position may equal seq length
        position = 0
        for chunk_size in chunk_sizes:
            position += chunk_size
            is_ok = (position % split_kv_size == 0) or (position == len(prompt_ids))
            self.assertTrue(is_ok, f"position {position} not aligned to {split_kv_size}")

        self.assertEqual(num_computed, len(prompt_ids))

    # ==================== 5. Multimodal (parameterized) ====================

    _MULTIMODAL_CASES = [
        {"name": "image", "prompt_len": 150, "text_len": 50, "modal_id": 1, "budget": 60, "extra": {}},
        {
            "name": "video",
            "prompt_len": 200,
            "text_len": 80,
            "modal_id": 2,
            "budget": 50,
            "extra": {"can_split_idx_list": [96, 112, 128, 144, 160, 176, 192]},
        },
        {"name": "audio", "prompt_len": 120, "text_len": 60, "modal_id": 3, "budget": 40, "extra": {}},
    ]

    def test_multimodal_input_single_modality(self):
        """Token allocation for image / video / audio multimodal requests."""
        self._enable_deterministic(16)
        rm = self._create_mm_resource_manager()

        for case in self._MULTIMODAL_CASES:
            with self.subTest(modality=case["name"]):
                prompt_ids = list(range(case["prompt_len"]))
                mm_inputs = _build_mm_inputs(case["prompt_len"], case["text_len"], case["modal_id"], case["extra"])
                req = _create_request(f"mm_{case['name']}", prompt_ids, 0, mm_inputs)
                result = rm._get_num_new_tokens(req, case["budget"])
                self.assertGreaterEqual(result, 0)
                self.assertLessEqual(result, case["budget"])

    # ==================== 6. Real batch scheduling ====================

    def test_real_batch_scheduling_concurrent_requests(self):
        """Multiple requests competing for budget, all must respect alignment."""
        split_kv_size = 16
        self._enable_deterministic(split_kv_size)
        budget = 50

        batch = [
            ("req1", list(range(27)), 0),
            ("req2", list(range(63)), 0),
            ("req3", list(range(128)), 0),
            ("req4", list(range(60)), 10),
            ("req5", list(range(47)), 7),
        ]
        for rid, prompt_ids, computed in batch:
            with self.subTest(request=rid):
                req = _create_request(rid, prompt_ids, computed)
                result = self.rm._get_num_new_tokens(req, budget)
                final_pos = computed + result
                max_possible = min(len(prompt_ids) - computed, budget)
                self.assertLessEqual(result, max_possible)
                if result > 0:
                    is_ok = (final_pos % split_kv_size == 0) or (final_pos == len(prompt_ids))
                    self.assertTrue(is_ok, f"{rid}: final_pos={final_pos} not aligned")

    def test_real_batch_scheduling_continuous_prefill(self):
        """Continuous prefill: all chunks fully consume a 47-token prompt."""
        split_kv_size = 16
        self._enable_deterministic(split_kv_size)

        prompt_ids = list(range(47))
        budget = 50
        num_computed = 0
        iterations = 0

        # Iteration cap guards against an infinite loop if allocation gets stuck.
        while num_computed < len(prompt_ids) and iterations < 10:
            req = _create_request("cont", prompt_ids, num_computed)
            result = self.rm._get_num_new_tokens(req, budget)
            self.assertGreater(result, 0, f"stuck at {num_computed}")
            final_pos = num_computed + result
            is_ok = (final_pos % split_kv_size == 0) or (final_pos == len(prompt_ids))
            self.assertTrue(is_ok, f"chunk ending at {final_pos} not aligned")
            num_computed += result
            iterations += 1

        self.assertEqual(num_computed, len(prompt_ids))

    def test_real_batch_scheduling_with_multimodal_requests(self):
        """Mixed batch: text-only + image requests."""
        self._enable_deterministic(16)
        rm = self._create_mm_resource_manager()
        budget = 30

        # Text-only request
        req_text = _create_request("text_only", list(range(100)), 0)
        r1 = rm._get_num_new_tokens(req_text, budget)
        self.assertGreaterEqual(r1, 0)
        self.assertLessEqual(r1, budget)

        # Image request
        mm_inputs = _build_mm_inputs(80, 40, modal_id=1)
        req_img = _create_request("with_image", list(range(80)), 0, mm_inputs)
        r2 = rm._get_num_new_tokens(req_img, budget)
        self.assertGreaterEqual(r2, 0)
        self.assertLessEqual(r2, budget)

    # ==================== 7. Corner cases ====================

    def test_corner_case_invalid_request_states(self):
        """Empty prompt, completed prefill, and num_computed > need_prefill must assert."""
        self._enable_deterministic(16)

        # Empty prompt
        with self.subTest(case="empty prompt"):
            with self.assertRaises(AssertionError):
                self.rm._get_num_new_tokens(_create_request("e", [], 0), 50)

        # Already completed
        with self.subTest(case="completed prefill"):
            with self.assertRaises(AssertionError):
                self.rm._get_num_new_tokens(_create_request("c", list(range(100)), 100), 50)

        # Inconsistent state
        with self.subTest(case="num_computed > need_prefill"):
            with self.assertRaises(AssertionError):
                self.rm._get_num_new_tokens(_create_request("i", list(range(50)), 100), 50)

        # Zero budget (legitimate, returns 0)
        with self.subTest(case="zero budget"):
            result = self.rm._get_num_new_tokens(_create_request("z", list(range(100)), 0), 0)
            self.assertEqual(result, 0)

    def test_corner_case_minimum_split_size(self):
        """split_kv_size=1: every position is aligned, so max allocation is allowed."""
        self._enable_deterministic(1)

        for prompt_ids, computed, budget, expected in [
            (list(range(100)), 0, 20, 20),
            (list(range(100)), 10, 15, 15),
            (list(range(100)), 50, 10, 10),
        ]:
            with self.subTest(computed=computed, budget=budget):
                req = _create_request("min", prompt_ids, computed)
                result = self.rm._get_num_new_tokens(req, budget)
                self.assertEqual(result, expected)

    def test_corner_case_large_split_size(self):
        """split_kv_size >> budget or sequence length."""
        test_cases = [
            # (split_kv_size, prompt_ids, num_computed, budget, description)
            (128, list(range(100)), 0, 10, "split >> budget: budget=10"),
            (128, list(range(100)), 0, 1, "split >> budget: budget=1"),
            (128, list(range(100)), 64, 20, "split >> budget: near boundary"),
            (256, list(range(50)), 0, 100, "split >> seq_len"),
        ]
        for split_kv_size, prompt_ids, computed, budget, desc in test_cases:
            with self.subTest(desc=desc):
                self._enable_deterministic(split_kv_size)
                req = _create_request("lg", prompt_ids, computed)
                result = self.rm._get_num_new_tokens(req, budget)
                max_possible = min(len(prompt_ids) - computed, budget)
                self.assertGreaterEqual(result, 0)
                self.assertLessEqual(result, max_possible)

    def test_corner_case_dynamic_config_switch(self):
        """Switching from non-deterministic to deterministic mid-stream."""
        # Phase 1: non-deterministic
        self._disable_deterministic()
        req1 = _create_request("sw1", list(range(100)), 0)
        result1 = self.rm._get_num_new_tokens(req1, 30)

        # Phase 2: enable deterministic, continue from result1
        split_kv_size = 16
        self._enable_deterministic(split_kv_size)
        req2 = _create_request("sw2", list(range(100)), result1)
        result2 = self.rm._get_num_new_tokens(req2, 30)

        if result2 > 0:
            final_pos = result1 + result2
            is_aligned = (final_pos % split_kv_size == 0) or (final_pos == 100)
            self.assertTrue(is_aligned, f"final_pos={final_pos} not aligned after switch")

    def test_deterministic_return_zero_budget_below_boundary(self):
        """Returns 0 when budget cannot reach the next alignment boundary."""
        split_kv_size = 16
        self._enable_deterministic(split_kv_size)

        test_cases = [
            # (prompt_ids, num_computed, budget)
            # pos=10, next_boundary=16, need 6, budget=5
            (list(range(100)), 10, 5),
            # pos=1, next_boundary=16, need 15, budget=3
            (list(range(100)), 1, 3),
            # pos=17, next_boundary=32, need 15, budget=14
            (list(range(100)), 17, 14),
            # budget=0 (deterministic)
            (list(range(100)), 0, 0),
        ]
        for prompt_ids, computed, budget in test_cases:
            with self.subTest(computed=computed, budget=budget):
                req = _create_request("det0", prompt_ids, computed)
                result = self.rm._get_num_new_tokens(req, budget)
                self.assertEqual(result, 0)


if __name__ == "__main__":
    unittest.main(verbosity=2)
diff --git a/tests/worker/test_deterministic_logger.py b/tests/worker/test_deterministic_logger.py
new file mode 100644
index 0000000000..82e28820ad
--- /dev/null
+++ b/tests/worker/test_deterministic_logger.py
@@ -0,0 +1,345 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import logging
import os
import sys
import types
import unittest
from types import SimpleNamespace
from unittest.mock import Mock

import numpy as np

# Register fastdeploy as a bare namespace package so that
# ``from fastdeploy.logger.deterministic_logger import ...`` does NOT
# execute fastdeploy/__init__.py (which pulls in paddle, paddleformers, etc.).
# Absolute path of the repository root (two levels above tests/worker/).
_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
for _pkg, _rel_path in [
    ("fastdeploy", "fastdeploy"),
    ("fastdeploy.logger", "fastdeploy/logger"),
    ("fastdeploy.worker", "fastdeploy/worker"),
]:
    # Insert a bare namespace module so the submodule import below resolves
    # without executing the package's heavyweight __init__.py.
    if _pkg not in sys.modules:
        _mod = types.ModuleType(_pkg)
        _mod.__path__ = [os.path.join(_project_root, _rel_path)]
        _mod.__package__ = _pkg
        sys.modules[_pkg] = _mod

from fastdeploy.logger.deterministic_logger import DeterministicLogger  # noqa: E402


def _make_tensor(array):
    """Create a mock tensor that behaves like a paddle Tensor for testing.

    Uses MagicMock so that the magic-method protocol (``len(t)``, ``t[i]``)
    actually works: assigning ``__len__``/``__getitem__`` onto a plain Mock
    *instance* is ineffective, because Python resolves dunder methods on the
    type, not the instance.
    """
    from unittest.mock import MagicMock  # local import: file-level imports only pull in Mock

    arr = np.array(array)
    tensor = MagicMock()
    tensor.cpu.return_value = tensor
    tensor.numpy.return_value = arr
    tensor.shape = arr.shape
    tensor.__len__.return_value = arr.shape[0]
    tensor.__getitem__.side_effect = lambda idx: _make_tensor(arr[idx])
    return tensor


class TestComputeTensorMd5(unittest.TestCase):
    """Unit tests for DeterministicLogger._compute_tensor_md5."""

    def test_none_tensor(self):
        result = DeterministicLogger._compute_tensor_md5(None, name="x")
        self.assertEqual(result, "x_md5=None")

    def test_deterministic_hash(self):
        t = _make_tensor([1.0, 2.0, 3.0])
        r1 = DeterministicLogger._compute_tensor_md5(t, name="a")
        r2 = DeterministicLogger._compute_tensor_md5(t, name="a")
        self.assertEqual(r1, r2)
        self.assertIn("a_md5=", r1)

    def test_different_tensors_different_hash(self):
        t1 = _make_tensor([1.0, 2.0])
        t2 = _make_tensor([3.0, 4.0])
        r1 = DeterministicLogger._compute_tensor_md5(t1, name="x")
        r2 = DeterministicLogger._compute_tensor_md5(t2, name="x")
        self.assertNotEqual(r1, r2)

    def test_prefix(self):
        t = _make_tensor([1.0])
        result = DeterministicLogger._compute_tensor_md5(t, name="h", prefix="batch_")
        self.assertTrue(result.startswith("batch_h_md5="))

    def test_md5_truncated_to_16_chars(self):
        t = _make_tensor([1.0, 2.0, 3.0])
        result = DeterministicLogger._compute_tensor_md5(t, name="x")
        md5_value = result.split("=")[1]
        self.assertEqual(len(md5_value), 16)


class TestGetBatchSize(unittest.TestCase):
    """Unit tests for DeterministicLogger._get_batch_size."""

    def test_returns_first_tensor_batch_size(self):
        t = _make_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        result = DeterministicLogger._get_batch_size({"a": t})
        self.assertEqual(result, 3)

    def test_skips_none_tensors(self):
        t = _make_tensor([[1.0], [2.0]])
        result = DeterministicLogger._get_batch_size({"a": None, "b": t})
        self.assertEqual(result, 2)

    def test_returns_none_for_empty_dict(self):
        self.assertIsNone(DeterministicLogger._get_batch_size({}))

    def test_returns_none_for_all_none(self):
        self.assertIsNone(DeterministicLogger._get_batch_size({"a": None}))


class TestBuildReqIdStr(unittest.TestCase):
    """Unit tests for DeterministicLogger._build_req_id_str."""

    def test_none_list(self):
        self.assertEqual(DeterministicLogger._build_req_id_str(None), "")

    def test_single_request(self):
        req = Mock(request_id="req-001")
        result = DeterministicLogger._build_req_id_str([req])
        self.assertEqual(result, "[0]req-001")

    def test_multiple_requests_with_none(self):
        r1 = Mock(request_id="r1")
        r2 = Mock(request_id="r2")
        result = DeterministicLogger._build_req_id_str([r1, None, r2])
        self.assertEqual(result, "[0]r1, [2]r2")


class TestGetStageCounts(unittest.TestCase):
    """Unit tests for DeterministicLogger._get_stage_counts."""

    def test_no_seq_lens_encoder(self):
        logger = DeterministicLogger(share_inputs={})
        prefill, decode, enc = logger._get_stage_counts(batch_size=4)
        self.assertEqual(prefill, 0)
        self.assertEqual(decode, 0)
        self.assertIsNone(enc)

    def test_with_seq_lens_encoder(self):
        # seq_lens_encoder: [5, 0, 3, 0] -> 2 prefill, 2 decode
        enc_tensor = _make_tensor([5, 0, 3, 0])
        logger = DeterministicLogger(share_inputs={"seq_lens_encoder": enc_tensor})
        prefill, decode, enc = logger._get_stage_counts(batch_size=4)
        self.assertEqual(prefill, 2)
        self.assertEqual(decode, 2)
        np.testing.assert_array_equal(enc, np.array([5, 0, 3, 0]))

    def test_all_prefill(self):
        enc_tensor = _make_tensor([10, 20])
        logger = DeterministicLogger(share_inputs={"seq_lens_encoder": enc_tensor})
        prefill, decode, _ = logger._get_stage_counts(batch_size=2)
        self.assertEqual(prefill, 2)
        self.assertEqual(decode, 0)

    def test_none_share_inputs(self):
        logger = DeterministicLogger(share_inputs=None)
        prefill, decode, enc = logger._get_stage_counts(batch_size=4)
        self.assertEqual(prefill, 0)
        self.assertEqual(decode, 0)
        self.assertIsNone(enc)


class TestLogTensorMd5s(unittest.TestCase):
    """Unit tests for DeterministicLogger.log_tensor_md5s."""

    def test_logs_batch_md5(self):
        t = _make_tensor([[1.0, 2.0], [3.0, 4.0]])
        logger = DeterministicLogger(share_inputs={})
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_tensor_md5s({"hidden": t}, stage="test_stage")
        self.assertTrue(any("[DETERMINISM-MD5]" in msg for msg in cm.output))
        self.assertTrue(any("stage=test_stage" in msg for msg in cm.output))

    def test_skips_when_no_valid_tensor(self):
        logger = DeterministicLogger(share_inputs={})
        det_log = logging.getLogger("fastdeploy.deterministic")
        det_log.setLevel(logging.INFO)
        # Should not raise, just silently return
        logger.log_tensor_md5s({"a": None})

    def test_logs_with_request_ids(self):
        t = _make_tensor([[1.0], [2.0]])
        req = Mock(request_id="req-42")
        logger = DeterministicLogger(share_inputs={})
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_tensor_md5s({"x": t}, forward_batch_reqs_list=[req], stage="s")
        self.assertTrue(any("req-42" in msg for msg in cm.output))

    def test_logs_per_request_md5_for_decode(self):
        # 2 requests, both decode (seq_lens_encoder = [0, 0])
        t = _make_tensor([[1.0, 2.0], [3.0, 4.0]])
        enc_tensor = _make_tensor([0, 0])
        r1 = Mock(request_id="r1")
        r2 = Mock(request_id="r2")
        logger = DeterministicLogger(share_inputs={"seq_lens_encoder": enc_tensor})
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_tensor_md5s({"out": t}, forward_batch_reqs_list=[r1, r2], stage="decode")
        req_msgs = [msg for msg in cm.output if "[DETERMINISM-MD5-REQ]" in msg]
        self.assertEqual(len(req_msgs), 2)


class TestLogDeterministicInput(unittest.TestCase):
    """Unit tests for DeterministicLogger.log_deterministic_input."""

    def _make_forward_meta(self, ids_list):
        # Minimal stand-in for a ForwardMeta: only ids_remove_padding is read.
        ids_tensor = _make_tensor(ids_list)
        return SimpleNamespace(ids_remove_padding=ids_tensor)

    def test_logs_input_info(self):
        forward_meta = self._make_forward_meta([101, 102, 201])
        share_inputs = {
            "req_ids": ["req-a", "req-b"],
            "seq_lens_this_time": [2, 1],
            "seq_lens_encoder": [2, 0],
            "seq_lens_decoder": [0, 5],
        }
        logger = DeterministicLogger(share_inputs=share_inputs)
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_deterministic_input(forward_meta)
        output = "\n".join(cm.output)
        self.assertIn("batch_size=2", output)
        self.assertIn("req_id=req-a", output)
        self.assertIn("req_id=req-b", output)
        self.assertIn("tokens=[101, 102]", output)
        self.assertIn("tokens=[201]", output)

    def test_no_input_data(self):
        forward_meta = SimpleNamespace(ids_remove_padding=None)
        share_inputs = {
            "req_ids": None,
            "seq_lens_this_time": [],
            "seq_lens_encoder": None,
            "seq_lens_decoder": None,
        }
        logger = DeterministicLogger(share_inputs=share_inputs)
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_deterministic_input(forward_meta)
        self.assertTrue(any("No input data" in msg for msg in cm.output))

    def test_fallback_req_id(self):
        forward_meta = self._make_forward_meta([10, 20])
        share_inputs = {
            "req_ids": None,
            "seq_lens_this_time": [1, 1],
            "seq_lens_encoder": None,
            "seq_lens_decoder": None,
        }
        logger = DeterministicLogger(share_inputs=share_inputs)
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_deterministic_input(forward_meta)
        output = "\n".join(cm.output)
        self.assertIn("req_id=idx_0", output)
        self.assertIn("req_id=idx_1", output)


class TestLogBatchStart(unittest.TestCase):
    """Unit tests for DeterministicLogger.log_batch_start."""

    def _make_logger(self):
        return DeterministicLogger(share_inputs={})

    def _make_req(self, request_id):
        return Mock(request_id=request_id)

    def test_logs_batch_start(self):
        logger = self._make_logger()
        batch = [self._make_req("prompt_0")]
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_batch_start(batch)
        output = "\n".join(cm.output)
        self.assertIn("[BATCH-START]", output)
        self.assertIn("Run_0", output)
        self.assertIn("Batch_1", output)

    def test_batch_counter_increments(self):
        logger = self._make_logger()
        batch = [self._make_req("prompt_0")]
        with self.assertLogs("fastdeploy.deterministic", level="INFO"):
            logger.log_batch_start(batch)
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_batch_start(batch)
        output = "\n".join(cm.output)
        self.assertIn("Batch_2", output)

    def test_run_id_change_resets_counter(self):
        logger = self._make_logger()
        batch_0 = [self._make_req("prompt_0")]
        batch_1 = [self._make_req("prompt_1")]
        with self.assertLogs("fastdeploy.deterministic", level="INFO"):
            logger.log_batch_start(batch_0)
            logger.log_batch_start(batch_0)  # Batch_2
        # Switch to run_id 1 => counter resets
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_batch_start(batch_1)
        output = "\n".join(cm.output)
        self.assertIn("Run_1", output)
        self.assertIn("Batch_1", output)

    def test_skips_none_requests(self):
        logger = self._make_logger()
        batch = [None, self._make_req("req_5")]
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_batch_start(batch)
        output = "\n".join(cm.output)
        self.assertIn("Run_5", output)

    def test_empty_batch(self):
        logger = self._make_logger()
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_batch_start([])
        output = "\n".join(cm.output)
        self.assertIn("Run_None", output)
        self.assertIn("Batch_1", output)

    def test_none_batch(self):
        logger = self._make_logger()
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_batch_start(None)
        output = "\n".join(cm.output)
        self.assertIn("Batch_1", output)


class TestLogPrefillInput(unittest.TestCase):
    """Unit tests for DeterministicLogger.log_prefill_input."""

    def test_logs_prefill_input(self):
        logger = DeterministicLogger(share_inputs={})
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_prefill_input(
                request_id="req-001",
                idx=0,
                prefill_start_index=0,
                prefill_end_index=5,
                input_ids=[101, 102, 103, 104, 105],
            )
        output = "\n".join(cm.output)
        self.assertIn("[DETERMINISM] Prefill input", output)
        self.assertIn("request_id: req-001", output)
        self.assertIn("idx: 0", output)
        self.assertIn("prefill_start_index: 0", output)
        self.assertIn("prefill_end_index: 5", output)
        self.assertIn("[101, 102, 103, 104, 105]", output)

    def test_logs_with_nonzero_start_index(self):
        logger = DeterministicLogger(share_inputs={})
        with self.assertLogs("fastdeploy.deterministic", level="INFO") as cm:
            logger.log_prefill_input(
                request_id="req-002",
                idx=3,
                prefill_start_index=10,
                prefill_end_index=20,
                input_ids=list(range(20)),
            )
        output = "\n".join(cm.output)
        self.assertIn("request_id: req-002", output)
        self.assertIn("idx: 3", output)
        self.assertIn("prefill_start_index: 10", output)
        self.assertIn("prefill_end_index: 20", output)


if __name__ == "__main__":
    unittest.main()