FastDeploy/tests/engine/test_request.py
kevin 7707be8384 [Feature][KVCache] Implement Cache Manager V1 with GPU + CPU Cache Support (1/n) (#7097)
* [Feature][KVCache] Support cache manager v1 architecture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.
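
A minimal sketch of the fix, using simplified stand-ins for the radix-tree types (the real BlockNode carries much more state than shown here):

```python
from enum import Enum, auto

class CacheStatus(Enum):
    SWAP_TO_DEVICE = auto()
    DEVICE = auto()

class BlockNode:
    """Simplified stand-in for the radix-tree node."""
    def __init__(self):
        self.cache_status = CacheStatus.SWAP_TO_DEVICE
        self.ref_count = 0

def complete_swap_to_device(node, evictable_device):
    """Finish a host-to-device swap: an unreferenced DEVICE node must be evictable."""
    node.cache_status = CacheStatus.DEVICE
    if node.ref_count == 0:
        evictable_device.add(node)  # the previously missing step
```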

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): fix src/dst block_ids selection for swap_cache H2D/D2H and clean up ForwardMeta

## Motivation

Fix swap_cache_optimized.cu using the wrong src/dst block_ids in the H2D direction,
and remove the deprecated cache_controller field from ForwardMeta.

## Modifications

- fix: in swap_cache_optimized.cu, select src/dst block_ids according to the D2H template
  parameter, fixing the swapped src/dst in the H2D direction (covers both SwapCachePerLayerImpl and SwapCacheAllLayersBatchImpl)
- refactor: cache_manager/v1/__init__.py now imports LayerSwapTimeoutError from
  cache_utils (its actual source) instead of cache_controller
- refactor: remove the deprecated cache_controller field from ForwardMeta
- refactor: remove the corresponding cache_controller assignment from gpu_model_runner.py
- test: add unit tests in tests/cache_manager/v1/test_swap_cache_ops.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

## Motivation

Refactor and optimize cache manager v1, streamlining the code structure and improving maintainability.

## Modifications

- Refactor transfer_manager.py, substantially simplifying its logic
- Optimize the swap_cache_optimized.cu GPU kernel implementation
- Adjust cache_manager.py and cache_controller.py logic; fix the missing free_device_blocks method
- Update block_pool.py, cache_utils.py, metadata.py, radix_tree.py
- Trim the related calls in gpu_model_runner.py, forward_meta.py, attention.py
- Update the corresponding unit tests (test_cache_controller, test_swap_cache_ops, test_transfer_manager)
- Adjust the related configuration entries in config.py

* [KVCache][MTP] support MTP KV Cache initialization and multimodal hashing under cache_manager_v1

## Motivation

On the enable_cache_manager_v1 path, the MTP (speculative decode) KV Cache should be managed
by CacheController so it can reuse the swap/transfer machinery. This change also fixes block
hashes not carrying multimodal extra_keys in multimodal scenarios.

## Modifications

- `cache_controller.py`
  - Add `initialize_mtp_kv_cache`: initialize the MTP KV Cache through CacheController
    and register it in cache_kvs_map, so transfer_manager automatically covers the MTP layers
  - Make num_layers in `initialize_host_cache` include the extra MTP cache layers, so the
    host cache also reserves enough space for MTP
  - Rename `_free_gpu_cache` to `free_gpu_cache` (callable externally)

- `cache_utils.py`
  - Add `get_block_hash_extra_keys`: extract the multimodal hash info within a single block,
    matching PrefixCacheManager's multimodal extra_keys logic
  - Pass extra_keys to hash_block_tokens inside `get_request_block_hasher`, fixing
    inaccurate prefix-cache hit rates in multimodal scenarios

- `spec_decode/mtp.py`
  - Add a `skip_cache_init` parameter to `update_mtp_block_num` to avoid re-initializing
    the MTP KV Cache on the v1 cache manager path

- `gpu_model_runner.py`
  - `initialize_kv_cache(v1)` path: after the main model cache is initialized, call
    `cache_controller.initialize_mtp_kv_cache` to create the MTP cache
  - `clear_cache` / `wakeup` / `reset` paths: respect the `enable_cache_manager_v1`
    flag and skip the duplicate proposer.initialize_kv_cache call

## Usage or Command

```bash
# Launch an inference service with MTP + cache_manager_v1 (example)
bash run.sh
```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (handled by cupy context now).

3. Remove swap_cache_all_layers_batch op: simplified the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with cupy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.
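
The boundary rule can be sketched as a half-open interval overlap check (function and parameter names here are illustrative, not the actual helper):

```python
def item_overlaps_block(item_start, item_end, block_start, block_end):
    """True if multimodal item [item_start, item_end) overlaps block [block_start, block_end)."""
    # <= rather than <: an item ending exactly at block_start touches but
    # does not overlap the block, so it must be excluded.
    if item_end <= block_start:
        return False
    return item_start < block_end
```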

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

## Motivation

Fix two issues:
1. `List` is not imported in `fastdeploy/engine/request.py`, causing a pre-commit F821 error
2. The `write_policy` normalization (`write_through` → `write_through_selective`) does not belong in `FDConfig`; move it into `CacheManager.__init__` so it only affects Cache Manager V1 internals

## Modifications

- `fastdeploy/engine/request.py`: add `List` to the `typing` import and remove the duplicate `CacheSwapMetadata` TYPE_CHECKING import, fixing F821/F811
- `fastdeploy/config.py`: remove the `write_policy` normalization logic
- `fastdeploy/cache_manager/v1/cache_manager.py`: move the normalization into `CacheManager.__init__`
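
The relocated normalization might look like this (a sketch only; the real `CacheManager.__init__` takes a full config object, and the alias-map name is hypothetical):

```python
class CacheManager:
    # Legacy alias map lives inside CacheManager so the mapping only affects
    # Cache Manager V1 internals, not global FDConfig state.
    _WRITE_POLICY_ALIASES = {"write_through": "write_through_selective"}

    def __init__(self, write_policy="write_through"):
        self.write_policy = self._WRITE_POLICY_ALIASES.get(write_policy, write_policy)
```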

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

## Motivation

Fix the CI pre-commit code style check failures.

## Modifications

- `fastdeploy/engine/common_engine.py`: black formatting
- `fastdeploy/worker/worker_process.py`: black formatting + isort fixes
- `fastdeploy/cache_manager/v1/storage/__init__.py`: isort fixes
- `fastdeploy/worker/gpu_worker.py`: isort fixes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

## Motivation

Update the Cache Manager V1 modules: complete the copyright headers and improve module structure and maintainability.

## Modifications

- `fastdeploy/cache_manager/v1/` modules: add copyright headers and tidy the code structure
- `fastdeploy/config.py`: configuration updates
- `fastdeploy/engine/sched/resource_manager_v1.py`: scheduling-related updates

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

## Motivation

Consolidate the duplicated task-parsing logic in worker_process into BatchRequest, reducing redundancy and improving maintainability.

## Modifications

- `fastdeploy/engine/request.py`: add the `BatchRequest.from_tasks()` classmethod, which uniformly classifies task_queue tasks into inference requests and control requests
- `fastdeploy/worker/worker_process.py`: replace the inline parsing logic with `BatchRequest.from_tasks()` and fix the duplicated control_reqs handling block

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

## Motivation

Improve NUMA affinity for host cache memory allocation to reduce cross-NUMA access latency;
also skip the swap cache ops tests (unsupported in the current environment).

## Modifications

- `fastdeploy/cache_manager/v1/cache_controller.py`:
  - Add a `_get_numa_node_for_gpu()` method that resolves the GPU's NUMA node via nvidia-smi or sysfs
  - Add a `_bind_to_closest_numa_node()` method that binds the current thread to the NUMA node closest to the GPU
  - Call the NUMA binding in `initialize_host_cache()` to improve H2D transfer performance
- `tests/cache_manager/v1/test_swap_cache_ops.py`: skip all test classes (`TestSwapCacheAllLayersCorrectness`, `TestSwapCacheAllLayersPerformance`, `TestSwapCacheRandomBlockIndices`)
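
The sysfs half of the lookup can be sketched as below (the path is parameterized for illustration; the real method also falls back to nvidia-smi, which is not shown):

```python
def read_numa_node(sysfs_path):
    """Return the NUMA node recorded in a sysfs numa_node file, or -1 if unavailable.

    On Linux the file is typically /sys/bus/pci/devices/<pci_addr>/numa_node;
    it contains -1 on systems without NUMA information.
    """
    try:
        with open(sysfs_path) as f:
            return int(f.read().strip())
    except (OSError, ValueError):
        return -1
```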

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

Three unit tests failed due to interface changes or mocking issues; fix them.

- tests/distributed/chunked_moe.py: `setup_model_runner` uses `__new__` to skip `__init__`; add `enable_cache_manager_v1 = False` to fix the `AttributeError`
- tests/engine/test_resource_manager.py: `PrefixCacheManager` is imported locally, so the `patch` path is changed to its definition site `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
- tests/v1/test_resource_manager_v1.py: the fourth parameter of `_trigger_preempt` changed from `list` to `BatchRequest`; update the test arguments and assertions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

## Modifications

- fastdeploy/engine/request.py: remove the debug logger and the debug logging in prompt_hashes
- fastdeploy/worker/worker_process.py: remove the debug imports and print statements from __main__

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

## Motivation

Fix two bugs:
1. `transfer_manager.py` calls `cp.cuda.runtime.getDevice()` on every use, which is fragile; cache it as an instance variable at initialization so subsequent operations use a consistent device ID.
2. `__getstate__` in `request.py` does not skip `_match_result`, which contains parent/child cycles from the BlockNode tree and triggers a `RecursionError` during pickling; also add `__setstate__` so the fields are restored to safe defaults after unpickling.

## Modifications

- `transfer_manager.py`: call `cp.cuda.runtime.getDevice()` at initialization and cache it in `self._cupy_device_id`; use the cached value in `with cp.cuda.Device(...)` blocks and in logging.
- `request.py`:
  - Add `_match_result` to the `_SKIP_KEYS` skip set in `__getstate__` to avoid pickling failures from the circular references.
  - Add `__setstate__` to restore `_block_hasher` and `_match_result` to `None` after unpickling.
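
A minimal sketch of the pickle-safety pattern (the class is a toy stand-in for Request; `_match_result = self` simulates the cyclic reference):

```python
import pickle

_SKIP_KEYS = {"_block_hasher", "_match_result"}

class Req:
    """Toy stand-in for Request with unpicklable / cyclic fields."""
    def __init__(self):
        self.request_id = "r1"
        self._block_hasher = lambda tokens: hash(tuple(tokens))  # lambdas don't pickle
        self._match_result = self  # simulates the BlockNode parent/child cycle

    def __getstate__(self):
        # Drop the fields that would break pickling
        return {k: v for k, v in self.__dict__.items() if k not in _SKIP_KEYS}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Restore skipped fields to safe defaults
        self._block_hasher = None
        self._match_result = None
```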


* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

- Use BatchRequest instead of list in test_trigger_preempt_records_tasks
- Add missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

## Motivation

The compatibility property `gpu_free_block_list` mistakenly used `list(range(N))`,
passing the return value of `available_blocks()` as an integer to `range()`,
so it returned a fake list `[0, 1, ..., N-1]` instead of the real free block IDs.

## Modifications

- `cache_manager/v1/cache_manager.py`: change `list(range(self._device_pool.available_blocks()))` to `list(self._device_pool.available_blocks())`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix TypeError caused by gpu_free_block_list returning an int

## Motivation

The gpu_free_block_list property calls BlockPool.available_blocks(), which
returns an int (the number of free blocks); wrapping an int with list()
raises TypeError: 'int' object is not iterable.

## Modifications

Change list(self._device_pool.available_blocks()) to
list(self._device_pool._free_blocks), returning the list of free block indices directly.
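
Both stages of the fix, sketched with a simplified BlockPool:

```python
class BlockPool:
    """Simplified stand-in: _free_blocks holds the free block indices."""
    def __init__(self, num_blocks):
        self._free_blocks = list(range(num_blocks))

    def available_blocks(self):
        return len(self._free_blocks)  # a count, not an iterable

pool = BlockPool(4)
pool._free_blocks = [7, 9]
# list(range(pool.available_blocks())) -> [0, 1]: fake IDs (first bug)
# list(pool.available_blocks())        -> TypeError: int not iterable (second bug)
gpu_free_block_list = list(pool._free_blocks)  # final fix: copy the index list
```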

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] adapt pause/sleep/free_cache to the V1 CacheManager

## Motivation

The V1 CacheManager introduces a new reset_cache() interface, so the pause and sleep
operations need to adapt; free_cache also needs an optional clear_storage parameter.

## Modifications

- cache_controller.py: add a clear_storage parameter to free_cache (default False);
  _clear_storage() is only called when clear_storage=True, avoiding unnecessary storage clearing
- common_engine.py: in the pause and sleep operations, use cache_manager.reset_cache()
  instead of the old reset() and pause_transfer logic when ENABLE_V1_KVCACHE_MANAGER is set
- gpu_model_runner.py: on sleep, only clear the MTP cache when not using the V1 cache manager

## Usage or Command

```bash
# Launch the service (V1 CacheManager)
python -m fastdeploy.entrypoints.openai.api_server \
  --enable-v1-kvcache-manager \
  ...
```

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

- Remove unused `select_blocks_for_backup` method from radix_tree.py
- Fix `match_prefix` default param `skip_storage=True` and log order in cache_manager.py
- Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
- Add `enable_cache_manager_v1=False` to all mock runners to fix AttributeError in CI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

Remove redundant prefix_caching branch in else path; always call
recycle_gpu_blocks with full block_tables for non-cache-manager-v1 case.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

## Motivation

Fix several code-quality issues in cache_manager v1, improving performance and eliminating a latent type-inconsistency bug.

## Modifications

1. **block_pool.py**: replace the per-element pop loop in `BlockPool.allocate` with a slice plus a batch `set.update`, eliminating Python loop overhead: O(n) → O(k) (batched C-level operations)
2. **cache_manager.py**: `match_prefix` writes an empty `MatchResult()` before the early return when prefix caching is disabled, so callers no longer crash dereferencing `_match_result=None`
3. **transfer_manager.py**: `_build_device_layer_indices` also resets the four layer index lists when `_cache_kvs_map` is empty, preventing stale tensors from being used by the swap kernels
4. **request.py**: `BatchRequest.append_swap_metadata` / `append_evict_metadata` now construct `CacheSwapMetadata` with `src_type`/`dst_type` as `CacheLevel` enums instead of strings, matching the declared field types; add the `CacheLevel` import; fix the `match_result` property's return annotation to `Optional[MatchResult]`
5. **resource_manager_v1.py**: downgrade the `_allocate_gpu_blocks` log from `INFO` to `DEBUG`, removing log noise on the hot scheduling path
6. **tests/engine/test_request.py**: update the `src_type`/`dst_type` assertions to `CacheLevel` enum values and add the `CacheLevel` import
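
The enum-typed metadata from item 4 can be sketched like this (a simplified field set; the member values `"device"`/`"host"` are assumptions, and the real `CacheSwapMetadata` lives in cache_manager/v1/metadata.py):

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List

class CacheLevel(Enum):
    DEVICE = "device"
    HOST = "host"

@dataclass
class CacheSwapMetadata:
    src_block_ids: List[int]
    dst_block_ids: List[int]
    src_type: CacheLevel  # an enum member, not a raw string, matching the declaration
    dst_type: CacheLevel
    hash_values: List[int] = field(default_factory=list)

meta = CacheSwapMetadata([0, 1], [2, 3], CacheLevel.HOST, CacheLevel.DEVICE)
```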

## Usage or Command

Unit tests:
```bash
source .venv/py310/bin/activate
cd baidu/FastDeploy
python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
python -m pytest tests/engine/test_request.py -v
```

* [BugFix][KVCache] Fix BlockPool.allocate returns all blocks when num_blocks=0

## Motivation

When `allocate(num_blocks=0)` is called, Python's negative-index trap causes a serious bug:
`-0 == 0`, so `self._free_blocks[-0:]` is equivalent to `self._free_blocks[0:]`,
which returns and clears the entire free-block list instead of returning an empty list.

## Modifications

Add an early check for `num_blocks == 0` in `BlockPool.allocate` that returns `[]` directly,
avoiding the negative-index trap.
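
The trap and the guard, in a simplified allocate:

```python
def allocate(free_blocks, num_blocks):
    """Take num_blocks from the tail of free_blocks (simplified BlockPool.allocate)."""
    # Guard against the negative-index trap: -0 == 0, so free_blocks[-0:]
    # would return (and clear) the whole list instead of nothing.
    if num_blocks == 0:
        return []
    taken = free_blocks[-num_blocks:]
    del free_blocks[-num_blocks:]
    return taken
```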

## Usage or Command

```bash
# Run the related unit tests to verify the fix
python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
```

* [KVCache][Test] add unit tests for cache_manager v1 modules

## Motivation

Complete unit-test coverage for the cache_manager/v1 modules so the core methods have solid test protection.

## Modifications

Add or extend the following test files; all 326 cases pass:

- tests/cache_manager/v1/test_block_pool.py (new)
  Covers BlockPool.get_metadata/set_metadata/resize and DeviceBlockPool/HostBlockPool
- tests/cache_manager/v1/test_metadata.py (new)
  Covers BlockNode, RadixTreeStats, MatchResult, CacheSwapMetadata, AsyncTaskHandler
- tests/cache_manager/v1/test_cache_utils.py (extended)
  Adds hash_block_tokens, get_request_block_hasher, LayerDoneCounter time tracking, and internal helpers
- tests/cache_manager/v1/test_radix_tree.py (extended)
  Adds a dedicated TestCompleteSwapToDevice test class (6 cases)
- tests/cache_manager/v1/test_cache_manager.py (extended)
  Adds offload_to_host, load_from_host, the pending-backup series, and prepare_prefetch_metadata
- tests/cache_manager/v1/test_transfer_manager.py (extended)
  Adds the _swap_single_layer validation path, sync_input/output_stream, and record_input_stream_event

## Usage or Command

```bash
# Run all the new unit tests
source .venv/py310/bin/activate
python -m pytest tests/cache_manager/v1/test_block_pool.py \
  tests/cache_manager/v1/test_metadata.py \
  tests/cache_manager/v1/test_cache_utils.py \
  tests/cache_manager/v1/test_radix_tree.py \
  tests/cache_manager/v1/test_cache_manager.py \
  tests/cache_manager/v1/test_transfer_manager.py -v
# Expected result: 326 passed
```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2026-04-21 14:39:00 +08:00

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import pickle
import unittest
from unittest.mock import Mock
import numpy as np
from fastdeploy.cache_manager.v1.metadata import CacheLevel, CacheSwapMetadata
from fastdeploy.engine.request import (
    BatchRequest,
    CompletionOutput,
    ImagePosition,
    PoolingParams,
    Request,
    RequestMetrics,
    RequestOutput,
    RequestStatus,
    RequestType,
    SamplingParams,
    StructuralTagResponseFormat,
)
from fastdeploy.entrypoints.openai.protocol import ResponseFormat, StructuralTag


def _make_swap_meta(src_ids, dst_ids, hash_values=None):
    """Helper: create a CacheSwapMetadata instance."""
    return CacheSwapMetadata(
        src_block_ids=list(src_ids),
        dst_block_ids=list(dst_ids),
        src_type="host",
        dst_type="device",
        hash_values=list(hash_values) if hash_values else [],
    )


class TestRequestInit(unittest.TestCase):
    """Test cases for Request initialization"""

    def test_init_default_values(self):
        """Test initialization with default values"""
        request = Request(request_id="test_123")
        # Test basic attributes
        self.assertEqual(request.request_id, "test_123")
        self.assertIsNone(request.prompt)
        self.assertIsNone(request.prompt_token_ids)
        self.assertIsNone(request.prompt_token_ids_len)
        self.assertIsNone(request.messages)
        self.assertIsNone(request.system)
        self.assertIsNone(request.sampling_params)
        self.assertIsNone(request.pooling_params)
        self.assertIsNone(request.history)
        self.assertIsNone(request.tools)
        self.assertIsNone(request.eos_token_ids)
        # Test default values
        self.assertEqual(request.num_cached_tokens, 0)
        self.assertEqual(request.num_cached_blocks, 0)
        self.assertFalse(request.disable_chat_template)
        self.assertIsNone(request.disaggregate_info)
        # Test multi-modal defaults
        self.assertIsNone(request.multimodal_inputs)
        self.assertIsNone(request.multimodal_data)
        self.assertIsNone(request.multimodal_img_boundaries)
        # Test status and type
        self.assertEqual(request.status, RequestStatus.WAITING)
        self.assertEqual(request.task_type, RequestType.PREFILL)
        self.assertIsNone(request.idx)
        self.assertEqual(request.need_prefill_tokens, None)  # prompt_token_ids_len is None
        # Test internal structures
        self.assertEqual(request.block_tables, [])
        self.assertEqual(request.output_token_ids, [])
        self.assertEqual(request.num_computed_tokens, 0)
        self.assertEqual(request.prefill_start_index, 0)
        self.assertEqual(request.prefill_end_index, 0)
        self.assertEqual(request.async_process_futures, [])
        self.assertIsNone(request.error_message)
        self.assertIsNone(request.error_code)

    def test_init_with_parameters(self):
        """Test initialization with various parameters"""
        sampling_params = SamplingParams()
        pooling_params = PoolingParams()
        metrics = RequestMetrics()
        request = Request(
            request_id="test_full",
            prompt="Hello world",
            prompt_token_ids=[1, 2, 3],
            prompt_token_ids_len=3,
            messages=[{"role": "user", "content": "Hello"}],
            system="You are helpful",
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            history=[["user", "hello"]],
            tools=[{"name": "test_tool"}],
            eos_token_ids=[0],
            disable_chat_template=True,
            disaggregate_info={"key": "value"},
            draft_token_ids=[4, 5],
            guided_json={"schema": "test"},
            guided_regex="test.*",
            guided_choice=["option1", "option2"],
            guided_grammar="grammar",
            structural_tag="tag",
            guided_json_object=True,
            enable_thinking=True,
            reasoning_max_tokens=100,
            trace_carrier={"trace": "carrier"},
            dp_rank=0,
            chat_template="template",
            image_start=1,
            video_start=2,
            audio_start=3,
            image_end=4,
            video_end=5,
            audio_end=6,
            prefill_start_index=10,
            prefill_end_index=20,
            num_computed_tokens=5,
            metrics=metrics,
            user="test_user",
            metadata={"meta": "data"},
            completion_token_ids=[6, 7],
            chat_template_kwargs={"kwarg": "value"},
            prompt_tokens="tokens",
            add_generation_prompt=True,
            response_format={"type": "json_object"},
            mm_hashes=["hash1", "hash2"],
            suffix={"key": "suffix"},
            top_logprobs=5,
            add_special_tokens=True,
        )
        # Test parameter assignment
        self.assertEqual(request.request_id, "test_full")
        self.assertEqual(request.prompt, "Hello world")
        self.assertEqual(request.prompt_token_ids, [1, 2, 3])
        self.assertEqual(request.prompt_token_ids_len, 3)
        self.assertEqual(request.messages, [{"role": "user", "content": "Hello"}])
        self.assertEqual(request.system, "You are helpful")
        self.assertEqual(request.sampling_params, sampling_params)
        self.assertEqual(request.pooling_params, pooling_params)
        self.assertEqual(request.history, [["user", "hello"]])
        self.assertEqual(request.tools, [{"name": "test_tool"}])
        self.assertEqual(request.eos_token_ids, [0])
        # Test boolean parameters
        self.assertTrue(request.disable_chat_template)
        self.assertTrue(request.guided_json_object)
        self.assertTrue(request.enable_thinking)
        self.assertTrue(request.add_generation_prompt)
        self.assertTrue(request.add_special_tokens)
        # Test numerical parameters
        self.assertEqual(request.reasoning_max_tokens, 100)
        self.assertEqual(request.dp_rank, 0)
        self.assertEqual(request.image_start, 1)
        self.assertEqual(request.video_start, 2)
        # Test string parameters
        self.assertEqual(request.trace_carrier, {"trace": "carrier"})
        self.assertEqual(request.chat_template, "template")
        self.assertEqual(request.user, "test_user")

    def test_init_with_multimodal_inputs(self):
        """Test initialization with multimodal inputs"""
        multimodal_inputs = {
            "mm_positions": [ImagePosition(offset=0, length=10)],
            "input_ids": np.array([1, 2, 3]),
        }
        request = Request(
            request_id="test_mm",
            multimodal_inputs=multimodal_inputs,
            multimodal_data={"images": ["img1", "img2"]},
        )
        self.assertEqual(request.multimodal_inputs, multimodal_inputs)
        self.assertEqual(request.multimodal_data, {"images": ["img1", "img2"]})
        self.assertIsNone(request.multimodal_img_boundaries)

    def test_init_default_metrics(self):
        """Test that metrics are created when not provided"""
        request = Request(request_id="test_metrics")
        self.assertIsInstance(request.metrics, RequestMetrics)
        self.assertIsNotNone(request.metrics.arrival_time)

    def test_init_existing_metrics(self):
        """Test initialization with existing metrics"""
        metrics = RequestMetrics()
        metrics.arrival_time = 1000.0
        request = Request(request_id="test_existing_metrics", metrics=metrics)
        self.assertEqual(request.metrics, metrics)
        self.assertEqual(request.metrics.arrival_time, 1000.0)


class TestRequestProperties(unittest.TestCase):
    """Test cases for Request properties"""

    def test_num_total_tokens(self):
        """Test num_total_tokens property"""
        # Test with no tokens
        request = Request(request_id="test1")
        request.prompt_token_ids_len = 0
        self.assertEqual(request.num_total_tokens, 0)
        # Test with prompt tokens only
        request = Request(request_id="test2")
        request.prompt_token_ids_len = 5
        request.output_token_ids = []
        self.assertEqual(request.num_total_tokens, 5)
        # Test with output tokens only
        request = Request(request_id="test3")
        request.prompt_token_ids_len = 0
        request.output_token_ids = [1, 2, 3]
        self.assertEqual(request.num_total_tokens, 3)
        # Test with both prompt and output tokens
        request = Request(request_id="test4")
        request.prompt_token_ids_len = 5
        request.output_token_ids = [1, 2, 3]
        self.assertEqual(request.num_total_tokens, 8)


class TestRequestClassMethods(unittest.TestCase):
    """Test cases for Request class methods"""

    def test_process_guided_json(self):
        """Test _process_guided_json class method"""
        # Test with response_format type json_object
        mock_request = Request(request_id="pickle_test")
        mock_request.response_format = ResponseFormat(type="json_object")
        result = Request._process_guided_json(mock_request)
        self.assertTrue(result)
        self.assertIsNone(getattr(mock_request, "guided_json", None))
        # Test with response_format type json_schema
        mock_request = Mock()
        mock_request.response_format = Mock()
        mock_request.response_format.type = "json_schema"
        mock_request.response_format.json_schema = Mock()
        mock_request.response_format.json_schema.json_schema = {"type": "object"}
        Request._process_guided_json(mock_request)
        self.assertEqual(mock_request.guided_json, {"type": "object"})
        # Test with response_format type structural_tag
        mock_request = Mock()
        mock_request.response_format = StructuralTagResponseFormat(
            type="structural_tag",
            structures=[StructuralTag(begin="<user>", end="</user>")],
            triggers=["<user>", "</user>"],
        )
        Request._process_guided_json(mock_request)
        expected_json = json.dumps(
            {
                "type": "structural_tag",
                "structures": [{"begin": "<user>", "schema": None, "end": "</user>"}],
                "triggers": ["<user>", "</user>"],
            }
        )
        self.assertEqual(mock_request.structural_tag, expected_json)

    def test_from_generic_request(self):
        """Test from_generic_request class method"""
        mock_generic_request = Mock()
        mock_generic_request.request_id = "generic_test"
        mock_generic_request.prompt_token_ids = [1, 2, 3]
        mock_generic_request.messages = [{"role": "user", "content": "Hello"}]
        mock_generic_request.disable_chat_template = True
        mock_generic_request.tools = [Mock()]
        mock_generic_request.tools[0].model_dump.return_value = {"name": "test_tool"}
        mock_generic_request.suffix = {"test": "value"}
        mock_generic_request.metadata = {"key": "value"}
        # Mock sampling params creation
        original_from_generic = SamplingParams.from_generic_request
        SamplingParams.from_generic_request = Mock(return_value=SamplingParams())
        try:
            request = Request.from_generic_request(
                req=mock_generic_request,
                request_id="override_test",
                prompt="Test prompt",
            )
            self.assertEqual(request.request_id, "override_test")
            self.assertEqual(request.prompt, "Test prompt")
            self.assertEqual(request.prompt_token_ids, [1, 2, 3])
            self.assertEqual(request.messages, [{"role": "user", "content": "Hello"}])
            self.assertTrue(request.disable_chat_template)
            self.assertEqual(request.tools, [{"name": "test_tool"}])
            self.assertIsInstance(request.metrics, RequestMetrics)
        finally:
            SamplingParams.from_generic_request = original_from_generic

    def test_from_dict(self):
        """Test from_dict class method"""
        test_dict = {
            "request_id": "dict_test",
            "prompt": "Test prompt",
            "prompt_token_ids": [1, 2, 3],
            "prompt_token_ids_len": 3,
            "messages": [{"role": "user", "content": "Hello"}],
            "system": "Test system",
            "history": [["user", "hi"]],
            "tools": [{"name": "test_tool"}],
            "eos_token_ids": [0],
            "multimodal_inputs": {"mm_positions": [{"offset": 0, "length": 10}]},
            "multimodal_data": {"images": ["img1"]},
            "disable_chat_template": True,
            "disaggregate_info": {"key": "value"},
            "draft_token_ids": [4, 5],
            "guided_json": {"schema": "test"},
            "guided_regex": "test.*",
            "guided_choice": ["opt1"],
            "guided_grammar": "grammar",
            "structural_tag": "tag",
            "guided_json_object": True,
            "enable_thinking": True,
            "reasoning_max_tokens": 100,
            "trace_carrier": {"trace": "carrier"},
            "chat_template": "template",
            "num_computed_tokens": 5,
            "prefill_start_index": 10,
            "prefill_end_index": 20,
            "image_start": 1,
            "video_start": 2,
            "audio_start": 3,
            "image_end": 4,
            "video_end": 5,
            "audio_end": 6,
            "dp_rank": 0,
            "ic_req_data": {"internal": "data"},
            "metrics": {"arrival_time": 1000.0},
            "max_tokens": 100,
        }
        request = Request.from_dict(test_dict)
        # Test basic fields
        self.assertEqual(request.request_id, "dict_test")
        self.assertEqual(request.prompt, "Test prompt")
        self.assertEqual(request.prompt_token_ids, [1, 2, 3])
        self.assertEqual(request.prompt_token_ids_len, 3)
        # Test multimodal inputs conversion
        self.assertIsInstance(request.multimodal_inputs["mm_positions"][0], ImagePosition)
        # Test sampling params creation
        self.assertIsNotNone(request.sampling_params)
        # Test metrics creation
        self.assertIsInstance(request.metrics, RequestMetrics)
        self.assertEqual(request.metrics.arrival_time, 1000.0)


class TestRequestInstanceMethods(unittest.TestCase):
    """Test cases for Request instance methods"""

    def test_getstate(self):
        """Test __getstate__ method for pickle support"""
        request = Request(request_id="pickle_test")
        request.async_process_futures = [Mock(), Mock()]  # These should be filtered
        state = request.__getstate__()
        # async_process_futures should be empty list after filtering
        self.assertEqual(state["async_process_futures"], [])
        # Other attributes should be preserved
        self.assertEqual(state["request_id"], "pickle_test")

    def test_eq(self):
        """Test __eq__ method"""
        request1 = Request(request_id="same_id")
        request2 = Request(request_id="same_id")
        request3 = Request(request_id="different_id")
        self.assertEqual(request1, request2)
        self.assertNotEqual(request1, request3)
        self.assertNotEqual(request1, "not_a_request")

    def test_to_dict_basic(self):
        """Test to_dict method with basic request"""
        request = Request(request_id="dict_basic")
        request.prompt = "Hello"
        request.prompt_token_ids = [1, 2, 3]
        request.prompt_token_ids_len = 3
        request.sampling_params = SamplingParams()
        request.metrics = RequestMetrics()
        data = request.to_dict()
        self.assertEqual(data["request_id"], "dict_basic")
        self.assertEqual(data["prompt"], "Hello")
        self.assertEqual(data["prompt_token_ids"], [1, 2, 3])
        self.assertEqual(data["prompt_token_ids_len"], 3)

    def test_to_dict_with_multimodal(self):
        """Test to_dict with multimodal inputs"""
        request = Request(request_id="dict_mm")
        request.multimodal_inputs = {
            "position_ids": [1, 2, 3],
            "input_ids": np.array([4, 5, 6]),
            "other_field": "should_be_filtered",
        }
        request.sampling_params = SamplingParams()
        request.metrics = RequestMetrics()
        # Test with V1 scheduler (should only allow position_ids)
        data = request.to_dict()
        self.assertEqual(list(data["multimodal_inputs"].keys()), ["position_ids"])
        self.assertEqual(data["multimodal_inputs"]["position_ids"], [1, 2, 3])

    def test_get_method(self):
        """Test get method for attribute access"""
        request = Request(request_id="get_test")
        request.sampling_params = SamplingParams()
        request.sampling_params.temperature = 0.7
        # Test getting request attribute
        self.assertEqual(request.get("request_id"), "get_test")
        # Test getting sampling_params attribute
        self.assertEqual(request.get("temperature"), 0.7)
        # Test getting non-existent attribute with default
        self.assertIsNone(request.get("non_existent"))
        self.assertEqual(request.get("non_existent", "default"), "default")

    def test_set_method(self):
        """Test set method for attribute modification"""
        request = Request(request_id="set_test")
        request.sampling_params = SamplingParams()
        # Test setting request attribute
        request.set("prompt", "New prompt")
        self.assertEqual(request.prompt, "New prompt")
        # Test setting sampling_params attribute
        request.set("temperature", 1.0)
        self.assertEqual(request.sampling_params.temperature, 1.0)

    def test_repr_debug_disabled(self):
        """Test __repr__ when debug is disabled"""
        request = Request(request_id="repr_test")
        repr_str = request.__repr__()
        self.assertEqual(repr_str, "Request(request_id=repr_test)")

    def test_repr_debug_enabled(self):
        """Test __repr__ when debug is enabled"""
        request = Request(request_id="repr_debug")
        request.prompt = "Hello"
        request.prompt_token_ids = [1, 2, 3]
        # Mock envs.FD_DEBUG to True
        import fastdeploy.engine.request as request_module

        original_value = getattr(request_module.envs, "FD_DEBUG", False)
        request_module.envs.FD_DEBUG = True
        try:
            repr_str = request.__repr__()
            self.assertIn("request_id='repr_debug'", repr_str)
            self.assertIn("prompt='Hello'", repr_str)
            self.assertIn("prompt_token_ids=[1, 2, 3]", repr_str)
        finally:
            request_module.envs.FD_DEBUG = original_value

    def test_getitem_setitem_delitem(self):
        """Test dictionary-like access methods"""
        request = Request(request_id="dict_access")
        request.sampling_params = SamplingParams()
        request.sampling_params.temperature = 0.7
        # Test __getitem__
        self.assertEqual(request["request_id"], "dict_access")
        self.assertEqual(request["temperature"], 0.7)
        # Test __setitem__
        request["prompt"] = "New prompt"
        self.assertEqual(request.prompt, "New prompt")
        request["temperature"] = 1.0
        self.assertEqual(request.sampling_params.temperature, 1.0)
        # Test __delitem__
        request.sampling_params.top_k = 10
        del request["top_k"]
        self.assertNotIn("top_k", request.sampling_params.__dict__)

    def test_contains(self):
        """Test __contains__ method"""
        request = Request(request_id="contains_test")
        request.sampling_params = SamplingParams()
        request.sampling_params.temperature = 0.7
        self.assertTrue("request_id" in request)
        self.assertTrue("temperature" in request)
        self.assertFalse("non_existent" in request)
class TestRequestEdgeCases(unittest.TestCase):
"""Test edge cases and error scenarios"""
def test_init_with_none_request_id(self):
"""Test initialization with None request_id"""
request = Request(request_id=None)
self.assertIsNone(request.request_id)
def test_getitem_key_error(self):
"""Test __getitem__ with non-existent key raises KeyError"""
request = Request(request_id="key_error_test")
with self.assertRaises(KeyError):
_ = request["non_existent_key"]
def test_delitem_key_error(self):
"""Test __delitem__ with non-existent key raises KeyError"""
request = Request(request_id="del_key_error_test")
with self.assertRaises(KeyError):
del request["non_existent_key"]
def test_repr_exception_handling(self):
"""Test __repr__ handles exceptions gracefully"""
request = Request(request_id="repr_exception")
# Create an attribute that will cause an exception during repr
class ProblematicAttribute:
def __repr__(self):
raise Exception("Repr failed")
request.problematic = ProblematicAttribute()
# Mock envs.FD_DEBUG to True to trigger detailed repr
import fastdeploy.engine.request as request_module
original_value = getattr(request_module.envs, "FD_DEBUG", False)
request_module.envs.FD_DEBUG = True
try:
repr_str = request.__repr__()
self.assertTrue(repr_str.startswith("<Request repr failed:"))
finally:
request_module.envs.FD_DEBUG = original_value
def test_from_dict_error_handling(self):
"""Test from_dict handles errors in multimodal conversion"""
test_dict = {
"request_id": "error_test",
"multimodal_inputs": {"mm_positions": [{"not_valid": "data"}]}, # Missing required fields
}
# Should not raise an exception but log error
request = Request.from_dict(test_dict)
self.assertEqual(request.request_id, "error_test")
class TestRequestOutputDictAccess(unittest.TestCase):
"""Test cases for RequestOutput dictionary-style access methods"""
def setUp(self):
self.metrics = RequestMetrics()
self.metrics.arrival_time = 1000.0
self.metrics.model_forward_time = 1.5
self.outputs = CompletionOutput(
index=0, send_idx=0, token_ids=[1, 2, 3], text="test output", reasoning_content="test reasoning"
)
self.request_output = RequestOutput(
request_id="test_dict_access",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
outputs=self.outputs,
metrics=self.metrics,
)
def test_get_method(self):
"""Test get() method"""
# Test getting request_output attribute
self.assertEqual(self.request_output.get("request_id"), "test_dict_access")
# Test getting outputs attribute
self.assertEqual(self.request_output.get("text"), "test output")
# Test getting metrics attribute
self.assertEqual(self.request_output.get("arrival_time"), 1000.0)
# Test getting non-existent attribute with default
self.assertIsNone(self.request_output.get("non_existent"))
self.assertEqual(self.request_output.get("non_existent", "default"), "default")
def test_set_method(self):
"""Test set() method"""
# Test setting request_output attribute
self.request_output.set("prompt", "new prompt")
self.assertEqual(self.request_output.prompt, "new prompt")
# Test setting outputs attribute
self.request_output.set("text", "new text")
self.assertEqual(self.outputs.text, "new text")
# Test setting metrics attribute
self.request_output.set("model_forward_time", 2.0)
self.assertEqual(self.metrics.model_forward_time, 2.0)
def test_getitem_method(self):
"""Test __getitem__ method"""
# Test getting request_output attribute
self.assertEqual(self.request_output["request_id"], "test_dict_access")
# Test getting outputs attribute
self.assertEqual(self.request_output["text"], "test output")
# Test getting metrics attribute
self.assertEqual(self.request_output["arrival_time"], 1000.0)
# Test KeyError for non-existent attribute
with self.assertRaises(KeyError):
_ = self.request_output["non_existent"]
def test_setitem_method(self):
"""Test __setitem__ method"""
# Test setting request_output attribute
self.request_output["prompt"] = "new prompt"
self.assertEqual(self.request_output.prompt, "new prompt")
# Test setting outputs attribute
self.request_output["text"] = "new text"
self.assertEqual(self.outputs.text, "new text")
# Test setting metrics attribute
self.request_output["model_forward_time"] = 2.0
self.assertEqual(self.metrics.model_forward_time, 2.0)
def test_delitem_method(self):
"""Test __delitem__ method"""
# Test deleting request_output attribute (using existing attribute)
original_prompt = self.request_output.prompt
del self.request_output["prompt"]
self.assertFalse(hasattr(self.request_output, "prompt"))
# Restore for other tests
self.request_output.prompt = original_prompt
# Test deleting outputs attribute (using existing attribute)
original_text = self.outputs.text
del self.request_output["text"]
self.assertFalse(hasattr(self.outputs, "text"))
# Restore for other tests
self.outputs.text = original_text
# Test deleting metrics attribute (using existing attribute)
original_arrival_time = self.metrics.arrival_time
del self.request_output["arrival_time"]
self.assertFalse(hasattr(self.metrics, "arrival_time"))
# Restore for other tests
self.metrics.arrival_time = original_arrival_time
# Test KeyError for non-existent attribute
        with self.assertRaises(KeyError):
            del self.request_output["non_existent"]
def test_contains_method(self):
"""Test __contains__ method"""
# Test request_output attributes
self.assertTrue("request_id" in self.request_output)
self.assertTrue("prompt" in self.request_output)
# Test outputs attributes
self.assertTrue("text" in self.request_output)
self.assertTrue("reasoning_content" in self.request_output)
# Test metrics attributes
self.assertTrue("arrival_time" in self.request_output)
self.assertTrue("model_forward_time" in self.request_output)
# Test non-existent attribute
self.assertFalse("non_existent" in self.request_output)
class TestRequestCacheFields(unittest.TestCase):
"""Tests for _block_hasher, _prompt_hashes, cache_swap_metadata, cache_evict_metadata."""
# ------------------------------------------------------------------
# _block_hasher / _prompt_hashes initialization
# ------------------------------------------------------------------
def test_default_block_hasher_and_prompt_hashes(self):
"""Default values: _block_hasher is None, _prompt_hashes is empty list."""
req = Request(request_id="cache_defaults")
self.assertIsNone(req._block_hasher)
self.assertEqual(req._prompt_hashes, [])
def test_block_hasher_init_via_constructor(self):
"""block_hasher passed to constructor is stored in _block_hasher."""
hasher = Mock(return_value=[])
req = Request(request_id="bh_init", block_hasher=hasher)
self.assertIs(req._block_hasher, hasher)
def test_set_block_hasher(self):
"""set_block_hasher replaces _block_hasher."""
req = Request(request_id="set_bh")
self.assertIsNone(req._block_hasher)
hasher = Mock(return_value=[])
req.set_block_hasher(hasher)
self.assertIs(req._block_hasher, hasher)
# ------------------------------------------------------------------
# prompt_hashes property
# ------------------------------------------------------------------
def test_prompt_hashes_no_hasher(self):
"""prompt_hashes returns _prompt_hashes as-is when no hasher is set."""
req = Request(request_id="ph_no_hasher")
req._prompt_hashes = ["h1", "h2"]
self.assertEqual(req.prompt_hashes, ["h1", "h2"])
def test_prompt_hashes_hasher_returns_new_hashes(self):
"""prompt_hashes appends new hashes returned by _block_hasher."""
req = Request(request_id="ph_new_hashes")
req._prompt_hashes = ["h1"]
req._block_hasher = Mock(return_value=["h2", "h3"])
result = req.prompt_hashes
# hasher is called with req
req._block_hasher.assert_called_once_with(req)
self.assertEqual(result, ["h1", "h2", "h3"])
# underlying list is mutated
self.assertEqual(req._prompt_hashes, ["h1", "h2", "h3"])
def test_prompt_hashes_hasher_returns_empty(self):
"""When hasher returns empty list, _prompt_hashes is unchanged."""
req = Request(request_id="ph_empty")
req._prompt_hashes = ["h1"]
req._block_hasher = Mock(return_value=[])
result = req.prompt_hashes
self.assertEqual(result, ["h1"])
self.assertEqual(req._prompt_hashes, ["h1"])
def test_prompt_hashes_hasher_returns_none(self):
"""When hasher returns None (falsy), _prompt_hashes is unchanged."""
req = Request(request_id="ph_none")
req._prompt_hashes = ["h1"]
req._block_hasher = Mock(return_value=None)
result = req.prompt_hashes
self.assertEqual(result, ["h1"])
def test_prompt_hashes_accumulates_across_multiple_accesses(self):
"""Each access may add more hashes (simulates incremental computation)."""
call_count = {"n": 0}
def incremental_hasher(r):
call_count["n"] += 1
return [f"h{call_count['n']}"]
req = Request(request_id="ph_incremental")
req._block_hasher = incremental_hasher
_ = req.prompt_hashes # first access → adds "h1"
_ = req.prompt_hashes # second access → adds "h2"
self.assertEqual(req._prompt_hashes, ["h1", "h2"])
# ------------------------------------------------------------------
# cache_swap_metadata / cache_evict_metadata initialization
# ------------------------------------------------------------------
def test_default_cache_metadata_are_empty_lists(self):
"""cache_swap_metadata and cache_evict_metadata default to empty lists."""
req = Request(request_id="meta_defaults")
self.assertEqual(req.cache_swap_metadata, [])
self.assertEqual(req.cache_evict_metadata, [])
# ------------------------------------------------------------------
# pop_cache_swap_metadata / pop_cache_evict_metadata
# ------------------------------------------------------------------
def test_pop_cache_swap_metadata_returns_and_clears(self):
"""pop_cache_swap_metadata returns current list and resets to []."""
req = Request(request_id="pop_swap")
meta = _make_swap_meta([1], [2], ["hash_a"])
req.cache_swap_metadata = [meta]
result = req.pop_cache_swap_metadata()
self.assertEqual(result, [meta])
self.assertEqual(req.cache_swap_metadata, [])
def test_pop_cache_evict_metadata_returns_and_clears(self):
"""pop_cache_evict_metadata returns current list and resets to []."""
req = Request(request_id="pop_evict")
meta = _make_swap_meta([3], [4], ["hash_b"])
req.cache_evict_metadata = [meta]
result = req.pop_cache_evict_metadata()
self.assertEqual(result, [meta])
self.assertEqual(req.cache_evict_metadata, [])
def test_pop_empty_cache_metadata(self):
"""pop on empty list returns [] and leaves field as []."""
req = Request(request_id="pop_empty")
self.assertEqual(req.pop_cache_swap_metadata(), [])
self.assertEqual(req.pop_cache_evict_metadata(), [])
# ------------------------------------------------------------------
# __getstate__ skips _block_hasher
# ------------------------------------------------------------------
def test_getstate_excludes_block_hasher(self):
"""__getstate__ must not include _block_hasher (cannot be pickled)."""
req = Request(request_id="getstate_bh", block_hasher=lambda r: [])
state = req.__getstate__()
self.assertNotIn("_block_hasher", state)
def test_getstate_preserves_prompt_hashes(self):
"""__getstate__ preserves _prompt_hashes."""
req = Request(request_id="getstate_ph")
req._prompt_hashes = ["h1", "h2"]
state = req.__getstate__()
self.assertEqual(state["_prompt_hashes"], ["h1", "h2"])
class TestBatchRequestInit(unittest.TestCase):
"""Tests for BatchRequest initialization."""
def test_default_init(self):
"""BatchRequest starts with empty requests and no metadata."""
br = BatchRequest()
self.assertEqual(br.requests, [])
self.assertIsNone(br.cache_swap_metadata)
self.assertIsNone(br.cache_evict_metadata)
def test_len_empty(self):
self.assertEqual(len(BatchRequest()), 0)
class TestBatchRequestAddRequest(unittest.TestCase):
"""Tests for BatchRequest.add_request."""
def _make_request(self, rid):
return Request(request_id=rid)
def test_add_request_appends_to_requests(self):
"""add_request stores request in .requests list."""
br = BatchRequest()
req = self._make_request("r1")
br.add_request(req)
self.assertIn(req, br.requests)
self.assertEqual(len(br), 1)
def test_add_request_without_metadata(self):
"""When request has no pending metadata, batch metadata stays None."""
br = BatchRequest()
req = self._make_request("r_no_meta")
br.add_request(req)
self.assertIsNone(br.cache_swap_metadata)
self.assertIsNone(br.cache_evict_metadata)
def test_add_request_with_swap_metadata(self):
"""add_request moves swap metadata from request to batch."""
br = BatchRequest()
req = self._make_request("r_swap")
meta = _make_swap_meta([10, 11], [20, 21], ["hA", "hB"])
req.cache_swap_metadata = [meta]
br.add_request(req)
# Request's swap list should be cleared
self.assertEqual(req.cache_swap_metadata, [])
# Batch should aggregate the metadata
self.assertIsNotNone(br.cache_swap_metadata)
self.assertEqual(br.cache_swap_metadata.src_block_ids, [10, 11])
self.assertEqual(br.cache_swap_metadata.dst_block_ids, [20, 21])
self.assertEqual(br.cache_swap_metadata.hash_values, ["hA", "hB"])
def test_add_request_with_evict_metadata(self):
"""add_request moves evict metadata from request to batch."""
br = BatchRequest()
req = self._make_request("r_evict")
meta = _make_swap_meta([5], [6], ["hE"])
req.cache_evict_metadata = [meta]
br.add_request(req)
self.assertEqual(req.cache_evict_metadata, [])
self.assertIsNotNone(br.cache_evict_metadata)
self.assertEqual(br.cache_evict_metadata.src_block_ids, [5])
self.assertEqual(br.cache_evict_metadata.dst_block_ids, [6])
def test_add_multiple_requests_merges_swap_metadata(self):
"""Swap metadata from multiple requests is merged into one."""
br = BatchRequest()
for i, (src, dst, h) in enumerate([([1], [2], ["h1"]), ([3], [4], ["h2"])]):
req = self._make_request(f"r{i}")
req.cache_swap_metadata = [_make_swap_meta(src, dst, h)]
br.add_request(req)
self.assertEqual(br.cache_swap_metadata.src_block_ids, [1, 3])
self.assertEqual(br.cache_swap_metadata.dst_block_ids, [2, 4])
self.assertEqual(br.cache_swap_metadata.hash_values, ["h1", "h2"])
def test_add_multiple_requests_merges_evict_metadata(self):
"""Evict metadata from multiple requests is merged into one."""
br = BatchRequest()
for i, (src, dst, h) in enumerate([([7], [8], ["e1"]), ([9], [10], ["e2"])]):
req = self._make_request(f"re{i}")
req.cache_evict_metadata = [_make_swap_meta(src, dst, h)]
br.add_request(req)
self.assertEqual(br.cache_evict_metadata.src_block_ids, [7, 9])
self.assertEqual(br.cache_evict_metadata.dst_block_ids, [8, 10])
self.assertEqual(br.cache_evict_metadata.hash_values, ["e1", "e2"])
class TestBatchRequestAppendSwapEvictMetadata(unittest.TestCase):
"""Unit tests for append_swap_metadata and append_evict_metadata."""
def test_append_swap_metadata_first_time(self):
"""append_swap_metadata creates CacheSwapMetadata when None."""
br = BatchRequest()
meta = _make_swap_meta([1, 2], [3, 4], ["h1", "h2"])
br.append_swap_metadata([meta])
self.assertIsNotNone(br.cache_swap_metadata)
self.assertEqual(br.cache_swap_metadata.src_block_ids, [1, 2])
self.assertEqual(br.cache_swap_metadata.dst_block_ids, [3, 4])
self.assertEqual(br.cache_swap_metadata.hash_values, ["h1", "h2"])
self.assertEqual(br.cache_swap_metadata.src_type, CacheLevel.HOST)
self.assertEqual(br.cache_swap_metadata.dst_type, CacheLevel.DEVICE)
def test_append_swap_metadata_merges(self):
"""Subsequent append_swap_metadata extends existing lists."""
br = BatchRequest()
br.append_swap_metadata([_make_swap_meta([1], [2], ["hA"])])
br.append_swap_metadata([_make_swap_meta([3], [4], ["hB"])])
self.assertEqual(br.cache_swap_metadata.src_block_ids, [1, 3])
self.assertEqual(br.cache_swap_metadata.dst_block_ids, [2, 4])
self.assertEqual(br.cache_swap_metadata.hash_values, ["hA", "hB"])
def test_append_evict_metadata_first_time(self):
"""append_evict_metadata creates CacheSwapMetadata when None."""
br = BatchRequest()
meta = _make_swap_meta([5], [6], ["he"])
br.append_evict_metadata([meta])
self.assertIsNotNone(br.cache_evict_metadata)
self.assertEqual(br.cache_evict_metadata.src_block_ids, [5])
self.assertEqual(br.cache_evict_metadata.dst_block_ids, [6])
self.assertEqual(br.cache_evict_metadata.dst_type, CacheLevel.HOST)
def test_append_evict_metadata_merges(self):
"""Subsequent append_evict_metadata extends existing lists."""
br = BatchRequest()
br.append_evict_metadata([_make_swap_meta([1], [2], ["e1"])])
br.append_evict_metadata([_make_swap_meta([3], [4], ["e2"])])
self.assertEqual(br.cache_evict_metadata.src_block_ids, [1, 3])
self.assertEqual(br.cache_evict_metadata.dst_block_ids, [2, 4])
self.assertEqual(br.cache_evict_metadata.hash_values, ["e1", "e2"])
def test_append_empty_list_is_noop(self):
"""append_swap_metadata / append_evict_metadata with empty list is a no-op."""
br = BatchRequest()
br.append_swap_metadata([])
br.append_evict_metadata([])
self.assertIsNone(br.cache_swap_metadata)
self.assertIsNone(br.cache_evict_metadata)
class TestBatchRequestAppendAndExtend(unittest.TestCase):
"""Tests for BatchRequest.append and BatchRequest.extend."""
def _br_with_swap(self, src, dst, hashes=None):
br = BatchRequest()
br.append_swap_metadata([_make_swap_meta(src, dst, hashes or [])])
return br
def _br_with_evict(self, src, dst, hashes=None):
br = BatchRequest()
br.append_evict_metadata([_make_swap_meta(src, dst, hashes or [])])
return br
def test_append_merges_requests(self):
br1 = BatchRequest()
br1.add_request(Request(request_id="a"))
br2 = BatchRequest()
br2.add_request(Request(request_id="b"))
br1.append(br2)
self.assertEqual(len(br1), 2)
def test_append_merges_swap_metadata(self):
br1 = self._br_with_swap([1], [2], ["h1"])
br2 = self._br_with_swap([3], [4], ["h2"])
br1.append(br2)
self.assertEqual(br1.cache_swap_metadata.src_block_ids, [1, 3])
self.assertEqual(br1.cache_swap_metadata.hash_values, ["h1", "h2"])
def test_append_merges_evict_metadata(self):
br1 = self._br_with_evict([5], [6], ["e1"])
br2 = self._br_with_evict([7], [8], ["e2"])
br1.append(br2)
self.assertEqual(br1.cache_evict_metadata.src_block_ids, [5, 7])
def test_append_batch_without_metadata_does_not_create_metadata(self):
br1 = BatchRequest()
br1.add_request(Request(request_id="x"))
br2 = BatchRequest()
br2.add_request(Request(request_id="y"))
br1.append(br2)
self.assertIsNone(br1.cache_swap_metadata)
self.assertIsNone(br1.cache_evict_metadata)
def test_extend_multiple_batches(self):
br_main = BatchRequest()
sub1 = self._br_with_swap([1], [2], ["h1"])
sub1.add_request(Request(request_id="s1"))
sub2 = self._br_with_swap([3], [4], ["h2"])
sub2.add_request(Request(request_id="s2"))
br_main.extend([sub1, sub2])
self.assertEqual(len(br_main), 2)
self.assertEqual(br_main.cache_swap_metadata.src_block_ids, [1, 3])
class TestBatchRequestIterAndAccess(unittest.TestCase):
"""Tests for __iter__, __getitem__, __len__, __repr__."""
def _populated_br(self):
br = BatchRequest()
for i in range(3):
br.add_request(Request(request_id=f"r{i}"))
return br
def test_iter(self):
br = self._populated_br()
ids = [req.request_id for req in br]
self.assertEqual(ids, ["r0", "r1", "r2"])
def test_getitem(self):
br = self._populated_br()
self.assertEqual(br[0].request_id, "r0")
self.assertEqual(br[2].request_id, "r2")
def test_len(self):
br = self._populated_br()
self.assertEqual(len(br), 3)
def test_repr_contains_swap_and_evict(self):
br = BatchRequest()
br.append_swap_metadata([_make_swap_meta([1], [2], ["hR"])])
r = repr(br)
self.assertIn("BatchRequest", r)
self.assertIn("swap_metadata", r)
self.assertIn("evict_metadata", r)
class TestBatchRequestPickle(unittest.TestCase):
"""Ensure BatchRequest can be serialized / deserialized via pickle."""
def test_pickle_without_block_hasher(self):
"""BatchRequest with plain Requests (no block_hasher) round-trips via pickle."""
br = BatchRequest()
req = Request(request_id="pk1", prompt="hello")
req._prompt_hashes = ["h1"]
br.add_request(req)
br.append_swap_metadata([_make_swap_meta([10], [20], ["hP"])])
data = pickle.dumps(br)
br2 = pickle.loads(data)
self.assertEqual(len(br2), 1)
self.assertEqual(br2[0].request_id, "pk1")
self.assertEqual(br2.cache_swap_metadata.src_block_ids, [10])
def test_getstate_skips_block_hasher_in_requests(self):
"""__getstate__ of BatchRequest serializes requests without _block_hasher."""
br = BatchRequest()
req = Request(request_id="gs1", block_hasher=lambda r: ["h_new"])
br.add_request(req)
state = br.__getstate__()
# Each request dict must not contain _block_hasher
for req_state in state["requests"]:
self.assertNotIn("_block_hasher", req_state)
from fastdeploy.cache_manager.v1.cache_utils import (
    get_block_hash_extra_keys as _get_block_hash_extra_keys,
    get_request_block_hasher as _get_request_block_hasher,
    hash_block_tokens as _hash_block_tokens,
)
class TestPromptHashesWithRealHasher(unittest.TestCase):
"""
Test Request.prompt_hashes together with the real get_request_block_hasher
and get_block_hash_extra_keys implementations.
These tests do NOT use mock hashers, so they exercise the full hash
computation path (hash_block_tokens → SHA-256 chained hash).
"""
BLOCK_SIZE = 4 # small block size makes tests easy to reason about
get_request_block_hasher = staticmethod(_get_request_block_hasher)
get_block_hash_extra_keys = staticmethod(_get_block_hash_extra_keys)
hash_block_tokens = staticmethod(_hash_block_tokens)
def _hasher(self):
return _get_request_block_hasher(self.BLOCK_SIZE)
# ------------------------------------------------------------------
# Basic hash computation
# ------------------------------------------------------------------
def test_no_complete_block_returns_empty(self):
"""Fewer tokens than one block → prompt_hashes returns []."""
        req = Request(
            request_id="real_partial",
            prompt_token_ids=[1, 2, 3],  # fewer than BLOCK_SIZE=4 tokens
            block_hasher=self._hasher(),
        )
self.assertEqual(req.prompt_hashes, [])
def test_exactly_one_block(self):
"""Exactly block_size tokens → one hash produced."""
tokens = [10, 20, 30, 40] # 4 tokens == BLOCK_SIZE
req = Request(request_id="real_one_block", prompt_token_ids=tokens, block_hasher=self._hasher())
hashes = req.prompt_hashes
self.assertEqual(len(hashes), 1)
# Verify hash value matches hash_block_tokens directly
expected = self.hash_block_tokens(tokens, None, None)
self.assertEqual(hashes[0], expected)
def test_two_complete_blocks(self):
"""Two full blocks → two chained hashes."""
tokens = list(range(8)) # 8 tokens = 2 blocks of 4
req = Request(request_id="real_two_blocks", prompt_token_ids=tokens, block_hasher=self._hasher())
hashes = req.prompt_hashes
self.assertEqual(len(hashes), 2)
h0 = self.hash_block_tokens(tokens[:4], None, None)
h1 = self.hash_block_tokens(tokens[4:8], h0, None)
self.assertEqual(hashes[0], h0)
self.assertEqual(hashes[1], h1)
def test_partial_tail_not_hashed(self):
"""9 tokens with block_size=4 → only 2 complete blocks hashed."""
tokens = list(range(9))
req = Request(request_id="real_tail", prompt_token_ids=tokens, block_hasher=self._hasher())
self.assertEqual(len(req.prompt_hashes), 2)
def test_hash_is_deterministic(self):
"""Same tokens always produce the same hash."""
tokens = [1, 2, 3, 4]
req1 = Request(request_id="det1", prompt_token_ids=tokens, block_hasher=self._hasher())
req2 = Request(request_id="det2", prompt_token_ids=tokens, block_hasher=self._hasher())
self.assertEqual(req1.prompt_hashes, req2.prompt_hashes)
def test_different_tokens_different_hash(self):
"""Different token sequences yield different hashes."""
req1 = Request(request_id="diff1", prompt_token_ids=[1, 2, 3, 4], block_hasher=self._hasher())
req2 = Request(request_id="diff2", prompt_token_ids=[5, 6, 7, 8], block_hasher=self._hasher())
self.assertNotEqual(req1.prompt_hashes, req2.prompt_hashes)
# ------------------------------------------------------------------
# Incremental (multi-access) behaviour
# ------------------------------------------------------------------
def test_incremental_hashing_does_not_recompute(self):
"""
If existing hashes already cover N blocks, prompt_hashes only computes
the next block not all blocks from scratch.
"""
tokens = list(range(12)) # 3 blocks of 4
req = Request(request_id="incremental", prompt_token_ids=tokens, block_hasher=self._hasher())
# First access: all three blocks computed
h_all = req.prompt_hashes[:] # copy
self.assertEqual(len(h_all), 3)
        # On a second access, the hasher sees the existing 3 hashes and returns []
        # because start_token_idx = 3 * 4 = 12 == num_tokens → no new complete block
result2 = req.prompt_hashes
self.assertEqual(len(result2), 3) # no duplicates
def test_new_output_tokens_trigger_additional_hashes(self):
"""
After output tokens are appended, a second call to prompt_hashes
produces more hashes (because the combined token sequence now has
more complete blocks).
"""
# Start with exactly 1 block of prompt tokens
tokens = list(range(4))
req = Request(request_id="out_tokens", prompt_token_ids=tokens, block_hasher=self._hasher())
req.output_token_ids = []
first = req.prompt_hashes[:]
self.assertEqual(len(first), 1)
# Append 4 output tokens → now 2 complete blocks total
req.output_token_ids = list(range(4, 8))
second = req.prompt_hashes[:]
self.assertEqual(len(second), 2)
self.assertEqual(second[0], first[0]) # first hash unchanged
# ------------------------------------------------------------------
# get_block_hash_extra_keys via prompt_hashes (multimodal path)
# ------------------------------------------------------------------
def test_prompt_hashes_no_multimodal_inputs(self):
"""
With no multimodal_inputs, get_block_hash_extra_keys returns empty
extra_keys → hash equals plain hash_block_tokens with extra_keys=None.
"""
tokens = [1, 2, 3, 4]
req = Request(request_id="mm_none", prompt_token_ids=tokens, block_hasher=self._hasher())
req.multimodal_inputs = None
hashes = req.prompt_hashes
expected = self.hash_block_tokens(tokens, None, None)
self.assertEqual(hashes[0], expected)
def test_prompt_hashes_with_multimodal_fully_within_block(self):
"""
A multimodal item fully within the block contributes its hash as
extra_keys, changing the computed block hash.
"""
tokens = [1, 2, 3, 4]
mm_hash = "img_hash_abc"
# Image fully within block [0, 4)
req = Request(request_id="mm_within", prompt_token_ids=tokens, block_hasher=self._hasher())
req.multimodal_inputs = {
"mm_positions": [ImagePosition(offset=1, length=2)],
"mm_hashes": [mm_hash],
}
hashes = req.prompt_hashes
# Expected: extra_keys = (mm_hash,)
expected = self.hash_block_tokens(tokens, None, (mm_hash,))
self.assertEqual(hashes[0], expected)
def test_prompt_hashes_multimodal_outside_block_not_included(self):
"""
A multimodal item that starts after the block end must NOT be included
in extra_keys for that block.
"""
tokens = list(range(8)) # 2 blocks: [0,4) and [4,8)
mm_hash = "img_hash_xyz"
# Image sits in the second block [4, 8)
req = Request(request_id="mm_outside", prompt_token_ids=tokens, block_hasher=self._hasher())
req.multimodal_inputs = {
"mm_positions": [ImagePosition(offset=4, length=2)],
"mm_hashes": [mm_hash],
}
hashes = req.prompt_hashes
# First block has no multimodal item → extra_keys = None
h0_expected = self.hash_block_tokens(list(range(4)), None, None)
self.assertEqual(hashes[0], h0_expected)
# Second block contains the image
h1_expected = self.hash_block_tokens(list(range(4, 8)), h0_expected, (mm_hash,))
self.assertEqual(hashes[1], h1_expected)
def test_prompt_hashes_multimodal_spanning_two_blocks(self):
"""
A multimodal item spanning two blocks contributes its hash to each block.
"""
tokens = list(range(8))
mm_hash = "span_hash"
# Image [2, 6) spans both block [0,4) and [4,8)
req = Request(request_id="mm_span", prompt_token_ids=tokens, block_hasher=self._hasher())
req.multimodal_inputs = {
"mm_positions": [ImagePosition(offset=2, length=4)],
"mm_hashes": [mm_hash],
}
hashes = req.prompt_hashes
self.assertEqual(len(hashes), 2)
# Both blocks include the mm hash as extra_keys
h0_expected = self.hash_block_tokens(list(range(4)), None, (mm_hash,))
self.assertEqual(hashes[0], h0_expected)
h1_expected = self.hash_block_tokens(list(range(4, 8)), h0_expected, (mm_hash,))
self.assertEqual(hashes[1], h1_expected)
# ------------------------------------------------------------------
# get_block_hash_extra_keys direct unit tests
# ------------------------------------------------------------------
def test_extra_keys_no_multimodal(self):
"""No multimodal_inputs → empty extra keys."""
req = Request(request_id="ek_none")
req.multimodal_inputs = None
next_idx, keys = self.get_block_hash_extra_keys(req, 0, 4, 0)
self.assertEqual(keys, [])
self.assertEqual(next_idx, 0)
def test_extra_keys_item_fully_inside_block(self):
"""Multimodal item fully inside [start, end) → its hash is collected."""
req = Request(request_id="ek_inside")
req.multimodal_inputs = {
"mm_positions": [ImagePosition(offset=1, length=2)], # [1, 3)
"mm_hashes": ["hash_inside"],
}
next_idx, keys = self.get_block_hash_extra_keys(req, 0, 4, 0)
self.assertIn("hash_inside", keys)
def test_extra_keys_item_starts_after_block(self):
"""Multimodal item starts after block end → not included."""
req = Request(request_id="ek_after")
req.multimodal_inputs = {
"mm_positions": [ImagePosition(offset=5, length=2)], # after block [0,4)
"mm_hashes": ["hash_after"],
}
_, keys = self.get_block_hash_extra_keys(req, 0, 4, 0)
self.assertEqual(keys, [])
def test_extra_keys_item_ends_before_block(self):
"""Multimodal item ends before block start → fast-exit, not included."""
req = Request(request_id="ek_before")
req.multimodal_inputs = {
"mm_positions": [ImagePosition(offset=0, length=1)], # [0,1) ends before block [2,6)
"mm_hashes": ["hash_before"],
}
_, keys = self.get_block_hash_extra_keys(req, 2, 6, 0)
self.assertEqual(keys, [])
def test_extra_keys_item_spans_beyond_block(self):
"""Multimodal item spanning beyond block end → included, and mm_idx points to it."""
req = Request(request_id="ek_span")
req.multimodal_inputs = {
"mm_positions": [ImagePosition(offset=2, length=4)], # [2, 6) spans [0,4) end
"mm_hashes": ["hash_span"],
}
next_idx, keys = self.get_block_hash_extra_keys(req, 0, 4, 0)
self.assertIn("hash_span", keys)
self.assertEqual(next_idx, 0) # mm_idx points back at the spanning item
def test_extra_keys_multiple_items_only_overlapping_included(self):
"""Only multimodal items that overlap [start, end) are included."""
req = Request(request_id="ek_multi")
req.multimodal_inputs = {
"mm_positions": [
ImagePosition(offset=0, length=2), # [0,2) → in block [0,4): YES
ImagePosition(offset=2, length=2), # [2,4) → in block [0,4): YES
ImagePosition(offset=5, length=2), # [5,7) → after block [0,4): NO
],
"mm_hashes": ["hA", "hB", "hC"],
}
_, keys = self.get_block_hash_extra_keys(req, 0, 4, 0)
self.assertIn("hA", keys)
self.assertIn("hB", keys)
self.assertNotIn("hC", keys)
if __name__ == "__main__":
unittest.main()