FastDeploy/fastdeploy/engine/request.py
kevin 7707be8384 [Feature][KVCache] Implement Cache Manager V1 with GPU + CPU Cache Support (1/n) (#7097)
* [Feature][KVCache] Support cache manager v1 architecture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.
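The invariant being restored can be sketched minimally (hypothetical `Node`/`RadixTree` stand-ins; the real BlockNode and radix tree carry much more state):

```python
from enum import Enum, auto

class CacheStatus(Enum):
    SWAP_TO_DEVICE = auto()
    DEVICE = auto()

class Node:
    def __init__(self) -> None:
        self.ref_count = 0
        self.cache_status = CacheStatus.SWAP_TO_DEVICE

class RadixTree:
    def __init__(self) -> None:
        self._evictable_device: set = set()

    def complete_swap_to_device(self, node: Node) -> None:
        node.cache_status = CacheStatus.DEVICE
        # The fix: once the node lands on the device, ref_count == 0 means no
        # request holds it, so it must be registered as evictable here;
        # otherwise it never appears in any evictable set again.
        if node.ref_count == 0:
            self._evictable_device.add(node)
```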

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): fix block_ids direction logic for swap_cache H2D/D2H and clean up ForwardMeta

## Motivation

Fix swap_cache_optimized.cu using the wrong src/dst block_ids in the H2D direction,
and remove the deprecated cache_controller field from ForwardMeta.

## Modifications

- fix: in swap_cache_optimized.cu, select src/dst block_ids according to the D2H template
  parameter, fixing the swapped src/dst bug in the H2D direction (applies to both
  SwapCachePerLayerImpl and SwapCacheAllLayersBatchImpl)
- refactor: in cache_manager/v1/__init__.py, import LayerSwapTimeoutError from
  cache_utils instead of cache_controller (its actual home)
- refactor: remove the deprecated cache_controller field from ForwardMeta
- refactor: remove the corresponding cache_controller assignment from gpu_model_runner.py
- test: add unit tests in tests/cache_manager/v1/test_swap_cache_ops.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

## Motivation

Refactor and optimize cache manager v1 to simplify the code structure and improve maintainability.

## Modifications

- Refactor transfer_manager.py, substantially simplifying its logic
- Optimize the swap_cache_optimized.cu GPU kernel implementation
- Adjust cache_manager.py and cache_controller.py logic; fix the missing free_device_blocks method
- Update block_pool.py, cache_utils.py, metadata.py, and radix_tree.py
- Simplify the related call sites in gpu_model_runner.py, forward_meta.py, and attention.py
- Update the corresponding unit tests (test_cache_controller, test_swap_cache_ops, test_transfer_manager)
- Adjust the related configuration options in config.py

* [KVCache][MTP] support MTP KV Cache initialization and multimodal hashing under cache_manager_v1

## Motivation

On the enable_cache_manager_v1 path, the KV Cache for MTP (speculative decode) should be
managed centrally by CacheController so it can reuse the swap/transfer machinery. This also
fixes block hashes not carrying multimodal extra_keys in multimodal scenarios.

## Modifications

- `cache_controller.py`
  - Add `initialize_mtp_kv_cache`: initialize the MTP KV Cache through CacheController
    and register it in cache_kvs_map so transfer_manager automatically covers the MTP layers
  - `initialize_host_cache` now includes the extra MTP cache layers in num_layers, so the
    host cache also reserves enough space for MTP
  - Rename `_free_gpu_cache` to `free_gpu_cache` (now part of the public interface)

- `cache_utils.py`
  - Add `get_block_hash_extra_keys`: extract the multimodal hash information within a single
    block, matching PrefixCacheManager's multimodal extra_keys logic
  - `get_request_block_hasher` now passes extra_keys into hash_block_tokens, fixing
    inaccurate prefix cache hit rates in multimodal scenarios

- `spec_decode/mtp.py`
  - `update_mtp_block_num` gains a `skip_cache_init` parameter to avoid initializing the
    MTP KV Cache twice on the v1 cache manager path

- `gpu_model_runner.py`
  - `initialize_kv_cache(v1)` path: after the main model cache is initialized, call
    `cache_controller.initialize_mtp_kv_cache` to create the MTP cache
  - `clear_cache` / `wakeup` / `reset` paths: respect the `enable_cache_manager_v1`
    flag and skip the redundant proposer.initialize_kv_cache call

## Usage or Command

```bash
# Start an inference service with MTP + cache_manager_v1 (example)
bash run.sh
```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (handled by cupy context now).

3. Remove swap_cache_all_layers_batch op: simplified the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with cupy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.
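The boundary change in item 4 can be illustrated with a hypothetical half-open-interval helper (the real `get_block_hash_extra_keys` has a different signature and works on multimodal position metadata):

```python
def item_overlaps_block(item_start: int, item_end: int,
                        block_start: int, block_end: int) -> bool:
    """Half-open intervals [start, end): True iff the item shares tokens with the block."""
    # <= rather than <: an item whose end coincides with block_start touches
    # the boundary but contributes no tokens to the block, so it is excluded.
    if item_end <= block_start:
        return False
    return item_start < block_end

# An item ending exactly at the block start is excluded after the fix.
assert not item_overlaps_block(0, 4, block_start=4, block_end=8)
# An item crossing into the block is still included.
assert item_overlaps_block(0, 5, block_start=4, block_end=8)
```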

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

## Motivation

Fix two issues:
1. `List` was not imported in `fastdeploy/engine/request.py`, causing a pre-commit F821 error
2. The `write_policy` normalization (`write_through` → `write_through_selective`) does not belong in `FDConfig`; move it into `CacheManager.__init__` so it only affects Cache Manager V1 internals

## Modifications

- `fastdeploy/engine/request.py`: add `List` to the `typing` import and remove the duplicate `CacheSwapMetadata` TYPE_CHECKING import, fixing F821/F811
- `fastdeploy/config.py`: remove the `write_policy` normalization logic
- `fastdeploy/cache_manager/v1/cache_manager.py`: move the normalization into `CacheManager.__init__`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

## Motivation

Fix CI pre-commit code style check failures.

## Modifications

- `fastdeploy/engine/common_engine.py`: black formatting
- `fastdeploy/worker/worker_process.py`: black formatting + isort fixes
- `fastdeploy/cache_manager/v1/storage/__init__.py`: isort fixes
- `fastdeploy/worker/gpu_worker.py`: isort fixes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

## Motivation

Update the Cache Manager V1 modules: complete the copyright headers and improve module structure and maintainability.

## Modifications

- `fastdeploy/cache_manager/v1/` modules: add copyright headers and tidy code structure
- `fastdeploy/config.py`: configuration updates
- `fastdeploy/engine/sched/resource_manager_v1.py`: scheduling-related updates

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

## Motivation

Consolidate the duplicated task-parsing logic in worker_process into BatchRequest to reduce redundancy and improve maintainability.

## Modifications

- `fastdeploy/engine/request.py`: add the `BatchRequest.from_tasks()` classmethod, which classifies task_queue items into inference requests and control requests
- `fastdeploy/worker/worker_process.py`: replace the inline parsing logic with `BatchRequest.from_tasks()` and fix the duplicated control_reqs handling block

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

## Motivation

Improve NUMA affinity for host cache memory allocation to reduce cross-NUMA access latency; also skip the swap cache ops tests (unsupported in the current environment).

## Modifications

- `fastdeploy/cache_manager/v1/cache_controller.py`:
  - Add `_get_numa_node_for_gpu()`: look up the GPU's NUMA node via nvidia-smi or sysfs
  - Add `_bind_to_closest_numa_node()`: bind the current thread to the NUMA node closest to the GPU
  - Call the NUMA binding in `initialize_host_cache()` to improve H2D transfer performance
- `tests/cache_manager/v1/test_swap_cache_ops.py`: skip all test classes (`TestSwapCacheAllLayersCorrectness`, `TestSwapCacheAllLayersPerformance`, `TestSwapCacheRandomBlockIndices`)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

Three unit tests failed due to interface changes or mocking issues and needed fixes.

- tests/distributed/chunked_moe.py: `setup_model_runner` uses `__new__` to skip `__init__`; add `enable_cache_manager_v1 = False` to fix the resulting `AttributeError`
- tests/engine/test_resource_manager.py: `PrefixCacheManager` is imported locally, so change the `patch` path to its definition site, `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
- tests/v1/test_resource_manager_v1.py: the fourth parameter of `_trigger_preempt` changed from `list` to `BatchRequest`; update the test arguments and assertions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

## Modifications

- fastdeploy/engine/request.py: remove the debugging logger and the debug log in prompt_hashes
- fastdeploy/worker/worker_process.py: remove the debugging import and print statements from __main__

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

## Motivation

Fix two bugs:
1. In `transfer_manager.py`, calling `cp.cuda.runtime.getDevice()` on every operation is fragile; cache it as an instance variable at initialization so subsequent operations use a consistent device ID.
2. `Request.__getstate__` in `request.py` did not skip `_match_result`, whose BlockNode tree contains parent/child circular references and triggers a `RecursionError` during pickling; also add `__setstate__` so the fields are restored to safe defaults after unpickling.

## Modifications

- `transfer_manager.py`: call `cp.cuda.runtime.getDevice()` once at initialization and cache it in `self._cupy_device_id`; subsequent `with cp.cuda.Device(...)` blocks and log messages use the cached value.
- `request.py`:
  - Add `_match_result` to the `_SKIP_KEYS` set in `__getstate__` to avoid pickle failures from circular references.
  - Add `__setstate__`, restoring `_block_hasher` and `_match_result` to `None` after unpickling.
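A minimal sketch of the `__getstate__`/`__setstate__` pattern (simplified, hypothetical classes; the real `Request` skips more fields):

```python
import pickle

class _TreeNode:
    """Stand-in for BlockNode: parent and children reference each other."""
    def __init__(self, parent=None):
        self.parent = parent
        self.children = []
        if parent is not None:
            parent.children.append(self)

class MiniRequest:
    _SKIP_KEYS = {"_block_hasher", "_match_result"}

    def __init__(self):
        self.request_id = "req-1"
        self._match_result = _TreeNode()      # holds a linked node tree
        self._block_hasher = lambda req: []   # closures are not picklable

    def __getstate__(self):
        # Drop the unpicklable / cyclic fields before serialization.
        return {k: v for k, v in self.__dict__.items() if k not in self._SKIP_KEYS}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Restore the skipped fields with safe defaults after unpickling.
        self._block_hasher = None
        self._match_result = None

restored = pickle.loads(pickle.dumps(MiniRequest()))
```

Without the two hooks, `pickle.dumps` would fail on the lambda alone; skipping the fields and restoring `None` on the receiving side keeps cross-process transfer safe.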


* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

- Use BatchRequest instead of list in test_trigger_preempt_records_tasks
- Add missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

## Motivation

The compatibility property `gpu_free_block_list` mistakenly used `list(range(N))`,
treating the return value of `available_blocks()` as an integer passed to `range()`.
This produced a fake list `[0, 1, ..., N-1]` instead of the real free block IDs.

## Modifications

- `cache_manager/v1/cache_manager.py`: change `list(range(self._device_pool.available_blocks()))` to `list(self._device_pool.available_blocks())`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning an int, causing a TypeError

## Motivation

The gpu_free_block_list property calls BlockPool.available_blocks(), which
returns an int (the number of free blocks); wrapping an int in list() raises
TypeError: 'int' object is not iterable.

## Modifications

Change list(self._device_pool.available_blocks()) to
list(self._device_pool._free_blocks), returning the free block index list directly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] adapt pause/sleep/free_cache to the V1 CacheManager

## Motivation

The V1 CacheManager introduces a new reset_cache() interface, so the pause and sleep
operations need to be adapted; free_cache also needs an optional clear_storage parameter.

## Modifications

- cache_controller.py: free_cache gains a clear_storage parameter (default False);
  _clear_storage() is only called when clear_storage=True, avoiding unnecessary storage wipes
- common_engine.py: in the pause and sleep paths, when ENABLE_V1_KVCACHE_MANAGER is set,
  use cache_manager.reset_cache() instead of the old reset() and pause_transfer logic
- gpu_model_runner.py: on sleep, only clear the MTP cache when not using the V1 cache manager

## Usage or Command

```bash
# Start the service (V1 CacheManager)
python -m fastdeploy.entrypoints.openai.api_server \
  --enable-v1-kvcache-manager \
  ...
```

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

- Remove unused `select_blocks_for_backup` method from radix_tree.py
- Fix `match_prefix` default param `skip_storage=True` and log order in cache_manager.py
- Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
- Add `enable_cache_manager_v1=False` to all mock runners to fix AttributeError in CI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

Remove redundant prefix_caching branch in else path; always call
recycle_gpu_blocks with full block_tables for non-cache-manager-v1 case.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

## Motivation

Fix several code quality issues in cache_manager v1, improving performance and eliminating a latent type-mismatch bug.

## Modifications

1. **block_pool.py**: `BlockPool.allocate` replaces the per-element pop loop with a slice plus a batched set.update, removing Python-loop overhead: O(n) Python iterations become O(k) C-level batch operations
2. **cache_manager.py**: `match_prefix` writes an empty `MatchResult()` before the early return when prefix caching is disabled, so callers never dereference `_match_result=None`
3. **transfer_manager.py**: `_build_device_layer_indices` resets all four layer index lists even when `_cache_kvs_map` is empty, preventing stale tensors from being used by the swap kernels
4. **request.py**: `BatchRequest.append_swap_metadata` / `append_evict_metadata` now construct `CacheSwapMetadata` with `src_type`/`dst_type` as `CacheLevel` enums instead of strings, matching the declared field types; add the `CacheLevel` import; fix the `match_result` property return annotation to `Optional[MatchResult]`
5. **resource_manager_v1.py**: demote the `_allocate_gpu_blocks` log from `INFO` to `DEBUG` to cut log noise on the hot scheduling path
6. **tests/engine/test_request.py**: update the `src_type`/`dst_type` assertions to `CacheLevel` enum values and add the `CacheLevel` import
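The batched allocation in point 1 can be sketched as follows (hypothetical simplified pool; the real BlockPool also updates per-block metadata):

```python
def allocate_batched(free_blocks: list, used_blocks: set, num_blocks: int) -> list:
    if num_blocks == 0:
        return []
    # One slice and one del instead of num_blocks pop() calls, plus a single
    # C-level set.update instead of num_blocks Python-level set.add calls.
    allocated = free_blocks[-num_blocks:]
    del free_blocks[-num_blocks:]
    used_blocks.update(allocated)
    return allocated

free, used = list(range(8)), set()
assert allocate_batched(free, used, 3) == [5, 6, 7]
assert free == [0, 1, 2, 3, 4]
assert used == {5, 6, 7}
```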

## Usage or Command

Unit tests:
```bash
source .venv/py310/bin/activate
cd baidu/FastDeploy
python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
python -m pytest tests/engine/test_request.py -v
```

* [BugFix][KVCache] Fix BlockPool.allocate returns all blocks when num_blocks=0

## Motivation

When `allocate(num_blocks=0)` is called, Python's negative-index trap causes a severe bug:
`-0 == 0`, so `self._free_blocks[-0:]` is equivalent to `self._free_blocks[0:]`,
which returns and clears the entire free block list instead of returning an empty list.

## Modifications

Add an early check for `num_blocks == 0` in `BlockPool.allocate` that returns `[]` directly,
avoiding the negative-index trap.
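The trap is easy to reproduce with a minimal free-list (hypothetical `allocate`; the real BlockPool tracks more bookkeeping):

```python
def allocate(free_blocks: list, num_blocks: int) -> list:
    # Guard first: -0 == 0, so free_blocks[-0:] would alias free_blocks[0:]
    # and the del below would drain the entire pool.
    if num_blocks == 0:
        return []
    allocated = free_blocks[-num_blocks:]
    del free_blocks[-num_blocks:]
    return allocated

pool = [10, 11, 12, 13]
assert allocate(pool, 0) == []       # without the guard this would return all four blocks
assert allocate(pool, 2) == [12, 13]
assert pool == [10, 11]
```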

## Usage or Command

```bash
# Run the related unit tests to verify the fix
python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
```

* [KVCache][Test] add unit tests for cache_manager v1 modules

## Motivation

Complete unit test coverage for the cache_manager/v1 modules so the core methods are fully tested.

## Modifications

Add or extend the following test files; all 326 cases pass:

- tests/cache_manager/v1/test_block_pool.py (new)
  Covers BlockPool.get_metadata/set_metadata/resize, DeviceBlockPool/HostBlockPool
- tests/cache_manager/v1/test_metadata.py (new)
  Covers BlockNode, RadixTreeStats, MatchResult, CacheSwapMetadata, AsyncTaskHandler
- tests/cache_manager/v1/test_cache_utils.py (extended)
  Adds hash_block_tokens, get_request_block_hasher, LayerDoneCounter time tracking, and internal helpers
- tests/cache_manager/v1/test_radix_tree.py (extended)
  Adds a dedicated TestCompleteSwapToDevice test class (6 cases)
- tests/cache_manager/v1/test_cache_manager.py (extended)
  Adds offload_to_host, load_from_host, the pending backup series, prepare_prefetch_metadata
- tests/cache_manager/v1/test_transfer_manager.py (extended)
  Adds _swap_single_layer validation paths, sync_input/output_stream, record_input_stream_event

## Usage or Command

```bash
# Run all of the new unit tests
source .venv/py310/bin/activate
python -m pytest tests/cache_manager/v1/test_block_pool.py \
  tests/cache_manager/v1/test_metadata.py \
  tests/cache_manager/v1/test_cache_utils.py \
  tests/cache_manager/v1/test_radix_tree.py \
  tests/cache_manager/v1/test_cache_manager.py \
  tests/cache_manager/v1/test_transfer_manager.py -v
# Expected: 326 passed
```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2026-04-21 14:39:00 +08:00


"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations

import json
import time
import traceback
from dataclasses import asdict, dataclass, fields
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional
from typing import TypeVar as TypingTypeVar
from typing import Union

if TYPE_CHECKING:
    from fastdeploy.cache_manager.v1.metadata import MatchResult

import numpy as np
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing_extensions import TypeVar

from fastdeploy import envs
from fastdeploy.cache_manager.v1.metadata import CacheLevel, CacheSwapMetadata
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import (
    AnyResponseFormat,
    DeltaMessage,
    StructuralTagResponseFormat,
    ToolCall,
)
from fastdeploy.logger.request_logger import (
    RequestLogLevel,
    log_request,
    log_request_error,
)
from fastdeploy.worker.output import (
    LogprobsLists,
    PromptLogprobs,
    SampleLogprobs,
    SpeculateMetrics,
)
class RequestStatus(Enum):
    WAITING = 0
    RUNNING = 1
    PREEMPTED = 2
    FINISHED = 3
    ABORT = 4


class RequestType(Enum):
    PREFILL = 0
    DECODE = 1
    PREEMPTED = 2
    EXTEND = 3
    ABORT = 4


@dataclass
class ImagePosition:
    offset: int = 0
    length: int = 0


T = TypingTypeVar("T")
@dataclass
class Request:
    def __init__(
        self,
        request_id: Optional[str],
        prompt: Optional[Union[str, list[str], list[list[int]], list[int]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        prompt_token_ids_len: Optional[int] = None,
        messages: Optional[list[Any]] = None,
        tools: Optional[list[Dict]] = None,
        system: Optional[Union[str, list[str]]] = None,
        history: Optional[list[list[str]]] = None,
        eos_token_ids: Optional[list[int]] = None,
        sampling_params: Optional[SamplingParams] = None,
        pooling_params: Optional[PoolingParams] = None,
        multimodal_inputs: Optional[dict] = None,
        multimodal_data: Optional[dict] = None,
        disable_chat_template: bool = False,
        disaggregate_info: Optional[dict] = None,
        draft_token_ids: Optional[list[int]] = None,
        guided_json: Optional[Any] = None,
        guided_regex: Optional[Any] = None,
        guided_choice: Optional[Any] = None,
        guided_grammar: Optional[Any] = None,
        structural_tag: Optional[Any] = None,
        guided_json_object: Optional[bool] = None,
        enable_thinking: Optional[bool] = None,
        reasoning_max_tokens: Optional[int] = None,
        trace_carrier: Optional[Dict[str, Any]] = None,
        dp_rank: Optional[int] = None,
        chat_template: Optional[str] = None,
        image_start: int = 0,
        video_start: int = 0,
        audio_start: int = 0,
        image_end: int = 0,
        video_end: int = 0,
        audio_end: int = 0,
        prefill_start_index: int = 0,
        prefill_end_index: int = 0,
        num_computed_tokens: int = 0,
        # for internal adapter
        ic_req_data: Optional[dict] = None,
        metrics: Optional[RequestMetrics] = None,
        # from ChatCompletionRequest or CompletionRequest
        user: Optional[str] = None,
        metadata: Optional[dict] = None,
        completion_token_ids: Optional[list[int]] = None,
        chat_template_kwargs: Optional[dict] = None,
        prompt_tokens: Optional[str] = None,
        add_generation_prompt: Optional[bool] = None,
        response_format: Optional[AnyResponseFormat] = None,
        mm_hashes: Optional[list] = None,
        suffix: Optional[dict] = None,
        top_logprobs: Optional[int] = None,
        # from PoolingRequest
        add_special_tokens: Optional[bool] = False,
        zmq_worker_pid: Optional[int] = None,
        # block hasher for dynamic hash computation
        block_hasher: Optional[callable] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_token_ids_len = prompt_token_ids_len
        self.messages = messages
        self.system = system
        self.sampling_params = sampling_params
        self.pooling_params = pooling_params
        self.history = history
        self.tools = tools
        # model specific token ids: end of sentence token ids
        self.eos_token_ids = eos_token_ids
        self.disable_chat_template = disable_chat_template
        self.disaggregate_info = disaggregate_info
        # prefix caching related
        self.num_cached_tokens = 0
        self.num_cached_blocks = 0
        self._prompt_hashes: list[str] = []
        self._block_hasher = block_hasher
        self._match_result: Optional[MatchResult] = None
        self.cache_swap_metadata: list[CacheSwapMetadata] = []
        self.cache_evict_metadata: list[CacheSwapMetadata] = []
        # speculative method in disaggregate-mode
        self.draft_token_ids = draft_token_ids
        # guided decoding related
        self.guided_json = guided_json
        self.guided_regex = guided_regex
        self.guided_choice = guided_choice
        self.guided_grammar = guided_grammar
        self.structural_tag = structural_tag
        self.guided_json_object = guided_json_object
        # Multi-modal related
        self.multimodal_inputs = multimodal_inputs
        self.multimodal_data = multimodal_data
        self.multimodal_img_boundaries = None
        self.enable_thinking = enable_thinking
        self.reasoning_max_tokens = reasoning_max_tokens
        self.trace_carrier = trace_carrier
        self.chat_template = chat_template
        # token num
        self.block_tables = []
        self.output_token_ids = []
        self.num_computed_tokens = num_computed_tokens
        self.prefill_start_index = prefill_start_index
        self.prefill_end_index = prefill_end_index
        self.image_start = image_start
        self.video_start = video_start
        self.audio_start = audio_start
        self.image_end = image_end
        self.video_end = video_end
        self.audio_end = audio_end
        # status
        self.status = RequestStatus.WAITING
        self.task_type = RequestType.PREFILL
        self.has_been_preempted_before = False
        self.idx = None
        self.need_prefill_tokens = self.prompt_token_ids_len
        self.audio_output_token_ids = []
        # extend block tables
        self.use_extend_tables = False
        self.extend_block_tables = []
        # dp
        self.dp_rank = dp_rank
        self.ic_req_data = ic_req_data
        self.async_process_futures = []
        self.error_message = None
        self.error_code = None
        if metrics is None:
            self.metrics = RequestMetrics()
        else:
            self.metrics = metrics
        # from ChatCompletionRequest or CompletionRequest
        self.user = user
        self.metadata = metadata
        self.completion_token_ids = completion_token_ids
        self.chat_template_kwargs = chat_template_kwargs
        self.prompt_tokens = prompt_tokens
        self.add_generation_prompt = add_generation_prompt
        self.response_format = response_format
        self.mm_hashes = mm_hashes
        self.suffix = suffix
        self.top_logprobs = top_logprobs
        # from PoolingRequest
        self.add_special_tokens = add_special_tokens
        self.zmq_worker_pid = zmq_worker_pid
    @property
    def prompt_hashes(self) -> list[str]:
        """
        Dynamically get prompt_hashes, automatically computing new block hashes.
        When accessing this property, it checks if there are new complete blocks
        that need hash computation, and if so, computes and appends them.
        """
        if self._block_hasher is not None:
            new_hashes = self._block_hasher(self)
            if new_hashes:
                self._prompt_hashes.extend(new_hashes)
        return self._prompt_hashes

    @property
    def match_result(self) -> Optional[MatchResult]:
        return self._match_result

    def set_block_hasher(self, block_hasher: callable):
        """Set the block hasher for dynamic hash computation."""
        self._block_hasher = block_hasher

    def pop_cache_swap_metadata(self) -> list[CacheSwapMetadata]:
        result = self.cache_swap_metadata
        self.cache_swap_metadata = []
        return result

    def pop_cache_evict_metadata(self) -> list[CacheSwapMetadata]:
        result = self.cache_evict_metadata
        self.cache_evict_metadata = []
        return result
    @classmethod
    def _process_guided_json(cls, r: T):
        guided_json_object = None
        if hasattr(r, "response_format") and r.response_format is not None:
            if r.response_format.type == "json_object":
                guided_json_object = True
            elif r.response_format.type == "json_schema":
                json_schema = r.response_format.json_schema.json_schema
                assert json_schema is not None, "response_format.json_schema can not be None"
                if isinstance(json_schema, (BaseModel, type(BaseModel))):
                    r.guided_json = json_schema.model_json_schema()
                else:
                    r.guided_json = json_schema
            elif r.response_format.type == "structural_tag":
                structural_tag = r.response_format
                assert structural_tag is not None and isinstance(structural_tag, StructuralTagResponseFormat)
                r.structural_tag = json.dumps(structural_tag.model_dump(by_alias=True))
        return guided_json_object
    @classmethod
    def from_generic_request(
        cls,
        req: T,
        request_id: Optional[str] = None,
        prompt: Optional[Union[str, list[int]]] = None,
        pooling_params: Optional[PoolingParams] = None,
    ):
        if request_id is not None:
            setattr(req, "request_id", request_id)
        if pooling_params is None:
            sampling_params = SamplingParams.from_generic_request(req)
        else:
            sampling_params = SamplingParams()
        guided_json_object = cls._process_guided_json(req)
        metrics = RequestMetrics()
        request = cls(
            request_id=getattr(req, "request_id", None),
            prompt_token_ids=getattr(req, "prompt_token_ids", None),
            prompt=prompt,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            metrics=metrics,
            guided_json_object=guided_json_object,
            disaggregate_info=getattr(req, "disaggregate_info", None),
            guided_json=getattr(req, "guided_json", None),
            guided_regex=getattr(req, "guided_regex", None),
            guided_choice=getattr(req, "guided_choice", None),
            guided_grammar=getattr(req, "guided_grammar", None),
            user=getattr(req, "user", None),
            response_format=(
                getattr(req, "response_format", None).model_dump()
                if (hasattr(getattr(req, "response_format", None), "model_dump"))
                else None
            ),
            mm_hashes=getattr(req, "mm_hashes", None),
            add_special_tokens=getattr(req, "add_special_tokens", False),
        )
        if hasattr(req, "messages"):
            if hasattr(req, "prompt_token_ids") and not req.prompt_token_ids:
                # If disable_chat_template is set, then the first message in messages will be used as the prompt.
                assert len(req.messages) > 0, "messages can not be an empty list, unless prompt_token_ids is passed"
                if req.disable_chat_template:
                    request.prompt = req.messages[0]["content"]
                    request.messages = []
            request.messages = getattr(req, "messages", None)
            request.tools = (
                [tool.model_dump() for tool in getattr(req, "tools", [])] if getattr(req, "tools", None) else None
            )
            request.reasoning_max_tokens = getattr(req, "reasoning_max_tokens", None)
            request.disable_chat_template = getattr(req, "disable_chat_template", None)
            request.top_logprobs = getattr(req, "top_logprobs", None)
            request.structural_tag = getattr(req, "structural_tag", None)
            request.chat_template = getattr(req, "chat_template", None)
        request.ic_req_data = getattr(req, "ic_req_data", None)
        request.metadata = getattr(req, "metadata", None)
        request.completion_token_ids = getattr(req, "completion_token_ids", None)
        request.chat_template_kwargs = getattr(req, "chat_template_kwargs", None)
        if getattr(req, "suffix", None):
            request.suffix = getattr(req, "suffix", None)
            for key, value in req.suffix.items():
                setattr(request, key, value)
        if getattr(req, "metadata", None):
            assert (
                "raw_request" not in req.metadata
            ), "The parameter `raw_request` is not supported now, please use completion api instead."
            for key, value in req.metadata.items():
                setattr(request, key, value)
            log_request(RequestLogLevel.STAGES, message="The parameter metadata is obsolete.")
        return request
    @classmethod
    def from_dict(cls, d: dict):
        log_request(RequestLogLevel.FULL, message="{request}", request=d)
        sampling_params: SamplingParams = None
        pooling_params: PoolingParams = None
        metrics: RequestMetrics = None
        if "pooling_params" in d and d["pooling_params"] is not None:
            pooling_params = PoolingParams.from_dict(d["pooling_params"])
        else:
            sampling_params = SamplingParams.from_dict(d)
            logprobs = d.get("logprobs", None)
            if logprobs is not None:
                if logprobs is True:
                    sampling_params.logprobs = d.get("top_logprobs", None)
                elif logprobs is False:
                    sampling_params.logprobs = None
        if "metrics" in d and d["metrics"] is not None:
            metrics = RequestMetrics.from_dict(d["metrics"])
        else:
            metrics = RequestMetrics.from_dict(d)
        if (
            isinstance(d.get("multimodal_inputs"), dict)
            and isinstance(d["multimodal_inputs"].get("mm_positions"), list)
            and len(d["multimodal_inputs"]["mm_positions"]) > 0
        ):
            # if mm_positions is not of type ImagePosition, convert to ImagePosition
            try:
                for i, mm_pos in enumerate(d["multimodal_inputs"]["mm_positions"]):
                    d["multimodal_inputs"]["mm_positions"][i] = (
                        ImagePosition(**mm_pos) if not isinstance(mm_pos, ImagePosition) else mm_pos
                    )
            except Exception as e:
                log_request_error(
                    message="request[{request_id}] Convert mm_positions to ImagePosition error: {error}, {traceback}",
                    request_id=d.get("request_id"),
                    error=str(e),
                    traceback=traceback.format_exc(),
                )
        return cls(
            request_id=d["request_id"],
            prompt=d.get("prompt"),
            prompt_token_ids=d.get("prompt_token_ids"),
            prompt_token_ids_len=d.get("prompt_token_ids_len"),
            messages=d.get("messages"),
            system=d.get("system"),
            history=d.get("history"),
            tools=d.get("tools"),
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_ids=d.get("eos_token_ids"),
            multimodal_inputs=d.get("multimodal_inputs"),
            multimodal_data=d.get("multimodal_data"),
            disable_chat_template=d.get("disable_chat_template"),
            disaggregate_info=d.get("disaggregate_info"),
            draft_token_ids=d.get("draft_token_ids"),
            guided_json=d.get("guided_json", None),
            guided_regex=d.get("guided_regex", None),
            guided_choice=d.get("guided_choice", None),
            guided_grammar=d.get("guided_grammar", None),
            structural_tag=d.get("structural_tag", None),
            guided_json_object=d.get("guided_json_object", None),
            enable_thinking=d.get("enable_thinking", None),
            reasoning_max_tokens=d.get("reasoning_max_tokens", None),
            trace_carrier=d.get("trace_carrier", {}),
            chat_template=d.get("chat_template", None),
            num_computed_tokens=d.get("num_computed_tokens", 0),
            prefill_start_index=d.get("prefill_start_index", 0),
            prefill_end_index=d.get("prefill_end_index", 0),
            image_start=d.get("image_start", 0),
            video_start=d.get("video_start", 0),
            audio_start=d.get("audio_start", 0),
            image_end=d.get("image_end", 0),
            video_end=d.get("video_end", 0),
            audio_end=d.get("audio_end", 0),
            dp_rank=d.get("dp_rank", None),
            ic_req_data=d.get("ic_req_data", None),
            metrics=metrics,
        )
    @property
    def num_total_tokens(self):
        """
        Total tokens of the request, include prompt tokens and generated tokens.
        """
        return self.prompt_token_ids_len + len(self.output_token_ids)

    def __getstate__(self):
        """
        Custom getstate method for pickle support.
        Handles unpicklable attributes by filtering them from __dict__.
        """
        # Attributes that cannot or need not be pickled for cross-process transfer.
        # _block_hasher: closure/callable, not picklable.
        # _match_result: contains BlockNode tree with parent<->children circular
        #   references, which causes RecursionError during pickling.
        # async_process_futures: asyncio futures, not picklable.
        _SKIP_KEYS = {"_block_hasher", "_match_result"}
        filtered_dict = {}
        for key, value in self.__dict__.items():
            if key in _SKIP_KEYS:
                continue
            elif key == "async_process_futures":
                filtered_dict[key] = []
            else:
                filtered_dict[key] = value
        return filtered_dict

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Restore fields that were excluded from pickling with safe defaults.
        if "_block_hasher" not in self.__dict__:
            self._block_hasher = None
        if "_match_result" not in self.__dict__:
            self._match_result = None

    def __eq__(self, other):
        """
        EQ operator.
        """
        if not isinstance(other, Request):
            return False
        return self.request_id == other.request_id
    def to_dict(self) -> dict:
        """convert Request into a serializable dict"""
        data = {
            "request_id": self.request_id,
            "prompt": self.prompt,
            "prompt_token_ids": self.prompt_token_ids,
            "prompt_token_ids_len": self.prompt_token_ids_len,
            "messages": self.messages,
            "system": self.system,
            "history": self.history,
            "tools": self.tools,
            "eos_token_ids": self.eos_token_ids,
            "multimodal_data": self.multimodal_data,
            "disable_chat_template": self.disable_chat_template,
            "disaggregate_info": self.disaggregate_info,
            "draft_token_ids": self.draft_token_ids,
            "enable_thinking": self.enable_thinking,
            "reasoning_max_tokens": self.reasoning_max_tokens,
            "trace_carrier": self.trace_carrier,
            "chat_template": self.chat_template,
            "num_computed_tokens": self.num_computed_tokens,
            "prefill_start_index": self.prefill_start_index,
            "prefill_end_index": self.prefill_end_index,
            "image_start": self.image_start,
            "video_start": self.video_start,
            "audio_start": self.audio_start,
            "image_end": self.image_end,
            "video_end": self.video_end,
            "audio_end": self.audio_end,
            "ic_req_data": self.ic_req_data,
        }
        if isinstance(self.multimodal_inputs, dict):
            # Optimize multimodal data transfer during PD separation:
            # - V1 mode (ENABLE_V1_KVCACHE_SCHEDULER=1): position_ids, mm_positions and mm_hashes needed for decode nodes
            # - V0 mode (ENABLE_V1_KVCACHE_SCHEDULER=0): Full field set required for compatibility
            # This filtering significantly reduces serialized data size for large numpy arrays
            allowed_keys = {"position_ids", "mm_positions", "mm_hashes"}
            if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
                allowed_keys.update(["input_ids", "token_type_ids", "images", "image_type_ids", "grid_thw"])
            data["multimodal_inputs"] = {
                key: value for key, value in self.multimodal_inputs.items() if key in allowed_keys
            }
        add_params = [
            "guided_json",
            "guided_regex",
            "guided_choice",
            "guided_grammar",
            "structural_tag",
            "guided_json_object",
        ]
        for param in add_params:
            if getattr(self, param, None) is not None:
                data[param] = getattr(self, param)
        if self.sampling_params is not None:
            data.update(asdict(self.sampling_params))
        data.update(asdict(self.metrics))
        return data
def get(self, key: str, default_value=None):
if hasattr(self, key):
return getattr(self, key)
elif hasattr(self.sampling_params, key):
return getattr(self.sampling_params, key)
else:
return default_value
def set(self, key, value):
if hasattr(self.sampling_params, key):
setattr(self.sampling_params, key, value)
else:
setattr(self, key, value)
def __repr__(self) -> str:
"""Sanitized repr without private or None fields."""
try:
if not envs.FD_DEBUG:
return f"Request(request_id={self.request_id})"
else:
attrs_snapshot = dict(vars(self))
non_none_fields = [
f"{attr}={value!r}"
for attr, value in attrs_snapshot.items()
if value is not None and not attr.startswith("_")
]
return f"Request({', '.join(non_none_fields)})"
except Exception as e:
return f"<Request repr failed: {e}>"
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
elif hasattr(self.sampling_params, key):
return getattr(self.sampling_params, key)
else:
raise KeyError(key) from None
def __setitem__(self, key, value):
if hasattr(self.sampling_params, key):
setattr(self.sampling_params, key, value)
else:
setattr(self, key, value)
def __delitem__(self, key):
try:
if hasattr(self.sampling_params, key):
delattr(self.sampling_params, key)
else:
delattr(self, key)
except AttributeError:
raise KeyError(key) from None
def __contains__(self, key: str) -> bool:
if hasattr(self.sampling_params, key):
return True
return hasattr(self, key)
class BatchRequest:
def __init__(self):
self.requests: list[Request] = []
self.cache_swap_metadata: Optional[CacheSwapMetadata] = None
self.cache_evict_metadata: Optional[CacheSwapMetadata] = None
def add_request(self, request):
if hasattr(request, "cache_swap_metadata") and request.cache_swap_metadata:
self.append_swap_metadata(request.pop_cache_swap_metadata())
request.cache_swap_metadata = []
if hasattr(request, "cache_evict_metadata") and request.cache_evict_metadata:
self.append_evict_metadata(request.pop_cache_evict_metadata())
request.cache_evict_metadata = []
self.requests.append(request)
def append_swap_metadata(self, metadata: List[CacheSwapMetadata]):
for meta in metadata:
if self.cache_swap_metadata:
self.cache_swap_metadata.src_block_ids.extend(meta.src_block_ids)
self.cache_swap_metadata.dst_block_ids.extend(meta.dst_block_ids)
self.cache_swap_metadata.hash_values.extend(meta.hash_values)
else:
self.cache_swap_metadata = CacheSwapMetadata(
src_block_ids=meta.src_block_ids,
dst_block_ids=meta.dst_block_ids,
src_type=CacheLevel.HOST,
dst_type=CacheLevel.DEVICE,
hash_values=meta.hash_values,
)
def append_evict_metadata(self, metadata: List[CacheSwapMetadata]):
for meta in metadata:
if self.cache_evict_metadata:
self.cache_evict_metadata.src_block_ids.extend(meta.src_block_ids)
self.cache_evict_metadata.dst_block_ids.extend(meta.dst_block_ids)
self.cache_evict_metadata.hash_values.extend(meta.hash_values)
else:
self.cache_evict_metadata = CacheSwapMetadata(
src_block_ids=meta.src_block_ids,
dst_block_ids=meta.dst_block_ids,
src_type=CacheLevel.DEVICE,
dst_type=CacheLevel.HOST,
hash_values=meta.hash_values,
)
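The two append_*_metadata helpers above implement the same merge rule: the first metadata object becomes the accumulator, and later ones extend its block-id and hash lists. A stripped-down sketch of that rule (the CacheSwapMetadata stand-in and the string cache levels below are illustrative; the real class and the CacheLevel enum live in the cache manager module):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class CacheSwapMetadata:  # illustrative stand-in for the real class
    src_block_ids: List[int]
    dst_block_ids: List[int]
    src_type: str
    dst_type: str
    hash_values: List[int]

def merge_swap_metadata(
    acc: Optional[CacheSwapMetadata], metas: List[CacheSwapMetadata]
) -> Optional[CacheSwapMetadata]:
    for meta in metas:
        if acc:
            # later entries extend the accumulator's lists in order
            acc.src_block_ids.extend(meta.src_block_ids)
            acc.dst_block_ids.extend(meta.dst_block_ids)
            acc.hash_values.extend(meta.hash_values)
        else:
            # first entry seeds the accumulator
            acc = CacheSwapMetadata(
                meta.src_block_ids, meta.dst_block_ids, "HOST", "DEVICE", meta.hash_values
            )
    return acc

m1 = CacheSwapMetadata([1], [10], "HOST", "DEVICE", [111])
m2 = CacheSwapMetadata([2, 3], [20, 30], "HOST", "DEVICE", [222, 333])
merged = merge_swap_metadata(None, [m1, m2])
assert merged.src_block_ids == [1, 2, 3]
assert merged.dst_block_ids == [10, 20, 30]
assert merged.hash_values == [111, 222, 333]
```

Note that, as in the real code, the accumulator aliases the first entry's lists rather than copying them, so the first metadata object is mutated in place by subsequent merges.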
def __repr__(self):
requests_repr = repr(self.requests)
return f"BatchRequest(requests={requests_repr}, swap_metadata={self.cache_swap_metadata}, evict_metadata={self.cache_evict_metadata})"
def __getstate__(self):
state = self.__dict__.copy()
state["requests"] = [req.__getstate__() if hasattr(req, "__getstate__") else req for req in state["requests"]]
return state
def __setstate__(self, state):
self.__dict__.update(state)
restored_requests = []
for req_data in self.requests:
if isinstance(req_data, dict):
req = Request.__new__(Request)
req.__dict__.update(req_data)
restored_requests.append(req)
else:
restored_requests.append(req_data)
self.requests = restored_requests
def __iter__(self):
for req in self.requests:
yield req
def __getitem__(self, index):
return self.requests[index]
def __len__(self):
return len(self.requests)
def append(self, batch_request: "BatchRequest"):
self.requests.extend(batch_request.requests)
if batch_request.cache_swap_metadata:
self.append_swap_metadata([batch_request.cache_swap_metadata])
if batch_request.cache_evict_metadata:
self.append_evict_metadata([batch_request.cache_evict_metadata])
def extend(self, batch_requests: list["BatchRequest"]):
for br in batch_requests:
self.append(br)
@classmethod
def from_tasks(cls, tasks: list) -> tuple["BatchRequest", list, int]:
"""Classify tasks from the engine worker queue into inference requests and control requests.
Args:
tasks: List of (payload, real_bsz) tuples from task_queue.get_tasks().
payload is one of: BatchRequest, List[Request], or [ControlRequest].
Returns:
(batch_request, control_reqs, max_occupied_batch_index)
- batch_request: merged BatchRequest containing all inference requests
- control_reqs: list of ControlRequest objects
- max_occupied_batch_index: real_bsz of the last inference task batch
"""
batch_request = cls()
control_reqs = []
max_occupied_batch_index = 0
for payload, bsz in tasks:
if len(payload) > 0 and isinstance(payload[0], ControlRequest):
control_reqs.append(payload[0])
else:
max_occupied_batch_index = int(bsz)
if isinstance(payload, cls):
batch_request.append(payload)
else:
for req in payload:
batch_request.add_request(req)
return batch_request, control_reqs, max_occupied_batch_index
class ControlRequest:
"""A generic control request that supports method and args for control operations.
This request type is used for system-level control operations rather than
typical inference requests. It enables dynamic control of engine behavior,
resource management, and system configuration via a flexible method-args interface.
"""
def __init__(
self,
request_id: str,
method: str,
args: Optional[Dict[str, Any]] = None,
) -> None:
"""
Args:
request_id: Unique identifier for the control request.
method: The control method to execute (e.g., "reset_scheduler", "get_metrics").
args: Optional arguments for the control method.
"""
self.request_id = request_id
self.method = method
self.args = args or {}
@classmethod
def from_dict(cls, d: dict):
"""Create ControlRequest instance from dictionary."""
return cls(request_id=d["request_id"], method=d["method"], args=d.get("args", {}))
def to_dict(self) -> dict:
"""Convert ControlRequest into a serializable dict."""
return {"request_id": self.request_id, "method": self.method, "args": self.args}
def __repr__(self) -> str:
"""Provide a clean representation of the control request."""
try:
if not envs.FD_DEBUG:
return f"ControlRequest(request_id={self.request_id}, method={self.method})"
else:
return (
f"ControlRequest("
f"request_id={self.request_id}, "
f"method={self.method}, "
f"args={self.args}"
f")"
)
except Exception as e:
return f"<ControlRequest repr failed: {e}>"
def get_method(self) -> str:
"""Get the control method name."""
return self.method
def get_args(self) -> Dict[str, Any]:
"""Get the control method arguments."""
return self.args.copy()
@staticmethod
def is_control_request(d: dict) -> bool:
"""
Check if a dictionary represents a valid ControlRequest.
Args:
d: Dictionary to check
Returns:
bool: True if the dictionary contains the required fields for a ControlRequest
"""
# Check if all required fields are present and have correct types
if not isinstance(d, dict):
return False
# Check field types
if "request_id" not in d or not isinstance(d.get("request_id"), str):
return False
if "method" not in d or not isinstance(d.get("method"), str):
return False
# Args is optional, but if present should be a dict
if "args" in d and not isinstance(d["args"], dict):
return False
return True
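The dict protocol above round-trips cleanly, and is_control_request rejects malformed payloads before from_dict is called. A self-contained check (the class body below is a trimmed copy of ControlRequest for illustration; the real class is defined in this module):

```python
from typing import Any, Dict, Optional

class ControlRequest:  # trimmed, illustrative copy of the class above
    def __init__(self, request_id: str, method: str, args: Optional[Dict[str, Any]] = None) -> None:
        self.request_id = request_id
        self.method = method
        self.args = args or {}

    def to_dict(self) -> dict:
        return {"request_id": self.request_id, "method": self.method, "args": self.args}

    @classmethod
    def from_dict(cls, d: dict):
        return cls(request_id=d["request_id"], method=d["method"], args=d.get("args", {}))

    @staticmethod
    def is_control_request(d: dict) -> bool:
        if not isinstance(d, dict):
            return False
        if not isinstance(d.get("request_id"), str):
            return False
        if not isinstance(d.get("method"), str):
            return False
        if "args" in d and not isinstance(d["args"], dict):
            return False
        return True

req = ControlRequest("ctrl-1", "reset_scheduler", {"force": True})
payload = req.to_dict()
assert ControlRequest.is_control_request(payload)
restored = ControlRequest.from_dict(payload)
assert restored.method == "reset_scheduler" and restored.args == {"force": True}
# Missing "method" or a non-dict "args" fails validation.
assert not ControlRequest.is_control_request({"request_id": "x"})
assert not ControlRequest.is_control_request({"request_id": "x", "method": "m", "args": []})
```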
class ControlResponse:
"""
Response for control operations
"""
def __init__(
self,
request_id: str,
error_code: int = 200,
error_message: Optional[str] = None,
result: Optional[dict] = None,
finished: bool = True,
) -> None:
self.request_id = request_id
self.finished = finished
self.error_message = error_message
self.result = result
self.error_code = error_code
def to_dict(self) -> dict:
"""Convert ControlResponse into a serializable dict."""
return {
"request_id": self.request_id,
"finished": self.finished,
"error_code": self.error_code,
"error_message": self.error_message,
"result": self.result,
}
@classmethod
def from_dict(cls, d: dict):
"""Create ControlResponse instance from dictionary."""
return cls(
request_id=d["request_id"],
finished=d.get("finished", True),
error_code=d.get("error_code", 200),
error_message=d.get("error_message"),
result=d.get("result"),
)
def to_api_json_response(self) -> JSONResponse:
"""Convert ControlResponse into a JSONResponse."""
status = "success" if self.error_code == 200 else "error"
content = {
"request_id": self.request_id,
"status": status,
"error_message": self.error_message,
"result": self.result,
}
return JSONResponse(status_code=self.error_code, content=content)
def __repr__(self) -> str:
"""Provide a clean representation of the control response."""
return (
f"ControlResponse("
f"request_id={self.request_id}, "
f"finished={self.finished}, "
f"error_code={self.error_code}, "
f"error_message={self.error_message}, "
f"result={self.result}"
f")"
)
@dataclass(slots=True)
class CompletionOutput:
"""The output data of one completion output of a request.
Args:
index: The index of the output in the request.
text: The generated output text.
token_ids: The token IDs of the generated output text.
"""
index: int
send_idx: int
token_ids: list[Any]
decode_type: int = 0
logprob: Optional[float] = None
top_logprobs: Optional[LogprobsLists] = None
draft_top_logprobs: Optional[LogprobsLists] = None
logprobs: Optional[SampleLogprobs] = None
draft_token_ids: Optional[list[int]] = None
text: Optional[str] = None
reasoning_content: Optional[str] = None
reasoning_token_num: Optional[int] = 0
tool_calls: Optional[ToolCall] = None
speculate_metrics: Optional[SpeculateMetrics] = None
completion_tokens: Optional[str] = None
delta_message: Optional[DeltaMessage] = None
multipart: Optional[list[Any]] = None
num_image_tokens: Optional[int] = None
def to_dict(self):
"""
Convert CompletionOutput into a serializable dict.
"""
return {
"index": self.index,
"send_idx": self.send_idx,
"token_ids": self.token_ids,
"decode_type": self.decode_type,
"logprob": self.logprob,
"top_logprobs": self.top_logprobs,
"draft_top_logprobs": self.draft_top_logprobs,
"logprobs": self.logprobs,
"draft_token_ids": self.draft_token_ids,
"text": self.text,
"reasoning_content": self.reasoning_content,
"reasoning_token_num": self.reasoning_token_num,
}
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> CompletionOutput:
"""Create instance from dict arguments"""
return cls(
**{
field.name: (req_dict[field.name] if field.name in req_dict else field.default)
for field in fields(cls)
}
)
def __repr__(self) -> str:
return (
f"CompletionOutput(index={self.index}, "
f"send_idx={self.send_idx}, "
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"decode_type={self.decode_type}, "
f"draft_token_ids={self.draft_token_ids}, "
f"reasoning_content={self.reasoning_content!r}, "
f"reasoning_token_num={self.reasoning_token_num}, "
f"logprobs={self.logprobs}, "
f"top_logprobs={self.top_logprobs}, "
f"draft_top_logprobs={self.draft_top_logprobs})"
)
def get(self, key: str, default_value=None):
if hasattr(self, key):
return getattr(self, key)
else:
return default_value
def set(self, key: str, value):
if hasattr(self, key):
setattr(self, key, value)
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
else:
raise KeyError(key) from None
def __setitem__(self, key, value):
if hasattr(self, key):
setattr(self, key, value)
@dataclass(slots=True)
class RequestMetrics:
"""Metrics associated with a request.
Attributes:
arrival_time: The time when the request arrived.
preprocess_start_time: The time when the preprocess started.
preprocess_end_time: The time when the preprocess ended.
scheduler_recv_req_time: The time when the scheduler received the request.
engine_get_req_time: The time when the engine got the request.
ask_decode_resource_start_time: The time when the engine asks for decode resource.
ask_decode_resource_finish_time: The time when the engine has asked for decode resource.
inference_start_time: The time when engine adds request to the running queue in resource manager.
wait_for_sending_cache_time: The time when the engine waited for sending cache.
send_request_output_to_decode_time: The time when the engine sent request_output to decode.
decode_recv_req_time: The time when the decode received the request.
decode_preallocate_req_time: The time when the decode has preallocated resource for the request.
decode_recv_first_token_time: The time when the decode received the first token.
decode_inference_start_time: The time when the decode sent the request to worker.
decode_recv_second_token_time: The time when the decode received the second token.
first_token_time: The elapsed time from inference_start_time to engine_recv_first_token_time.
time_in_queue: The time the request spent in the queue.
model_forward_time: The time spent in the model forward pass when this
request was in the batch.
model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
request_start_time: The time when the request was accepted.
"""
arrival_time: Optional[float] = None # api server receives request
preprocess_start_time: Optional[float] = None # preprocess start time in api server
preprocess_end_time: Optional[float] = None # preprocess end time in api server
scheduler_recv_req_time: Optional[float] = None # scheduler receives request and add to scheduler
engine_get_req_time: Optional[float] = None # engine gets request from scheduler
ask_decode_resource_start_time: Optional[float] = None # engine asks decode resource (only valid for prefill)
ask_decode_resource_finish_time: Optional[float] = None # engine has got decode resource (only valid for prefill)
add_req_to_resource_manager_time: Optional[float] = None # engine adds request to resource manager
inference_start_time: Optional[float] = None # requests are added into the engine work queue
engine_recv_latest_token_time: Optional[float] = None # receive the latest token from worker
engine_recv_first_token_time: Optional[float] = None # receive first token from worker
wait_for_sending_cache_time: Optional[float] = None # wait for sending cache (only valid for prefill)
send_request_output_to_decode_time: Optional[float] = (
None # send request_output to decode (only valid for prefill)
)
decode_recv_req_time: Optional[float] = None # decode receive request from prefill (only valid for decode)
decode_preallocate_req_time: Optional[float] = (
None # decode has preallocated resource for the request (only valid for decode)
)
decode_recv_first_token_time: Optional[float] = (
None # decode receive request_output with first token from prefill (only valid for decode)
)
decode_inference_start_time: Optional[float] = (
None # decode adds request to the engine work queue (only valid for decode)
)
decode_recv_second_token_time: Optional[float] = (
None # decode receives the second token from worker (only valid for decode)
)
first_token_time: Optional[float] = None
time_in_queue: Optional[float] = None
preprocess_cost_time: Optional[float] = None
model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None
request_start_time: Optional[float] = None
llm_engine_recv_req_timestamp: Optional[float] = None
llm_engine_send_req_to_engine_timestamp: Optional[float] = None
llm_engine_send_req_to_decoder_engine_timestamp: Optional[float] = None
llm_engine_recv_latest_token_timestamp: Optional[float] = None
llm_engine_recv_token_timestamp: Optional[float] = None
speculate_metrics: Optional[SpeculateMetrics] = None
# cache related
gpu_cache_token_num: Optional[int] = 0
cpu_cache_token_num: Optional[int] = 0
storage_cache_token_num: Optional[int] = 0
cpu_cache_prepare_time: Optional[float] = None
storage_cache_prepare_time: Optional[float] = None
preempted_count: int = 0
def __post_init__(self):
if self.arrival_time is None:
self.arrival_time = time.time()
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> RequestMetrics:
"""Create instance from dict arguments"""
return cls(
**{
field.name: (req_dict[field.name] if field.name in req_dict else field.default)
for field in fields(cls)
}
)
def to_dict(self):
"""
Convert the RequestMetrics object to a dictionary.
"""
return {k: v for k, v in asdict(self).items()}
def record_recv_first_token(self):
cur_time = time.time()
self.record_recv_token(cur_time)
self.engine_recv_first_token_time = cur_time
def record_recv_token(self, cur_time: Optional[float] = None):
cur_time = time.time() if cur_time is None else cur_time
self.engine_recv_latest_token_time = cur_time
self.llm_engine_recv_latest_token_timestamp = cur_time
self.model_execute_time = cur_time - self.arrival_time
if self.inference_start_time:
self.model_forward_time = cur_time - self.inference_start_time
def record_decode_recv_second_token(self):
cur_time = time.time()
self.record_recv_token(cur_time)
self.decode_recv_second_token_time = cur_time
def get_inference_start_time(self, is_decode: bool):
if is_decode:
return self.decode_inference_start_time
else:
return self.inference_start_time
def cal_cost_time(self):
"""Calculates various timing metrics based on the recorded times"""
if self.engine_recv_first_token_time and self.inference_start_time:
self.first_token_time = self.engine_recv_first_token_time - self.inference_start_time
if self.inference_start_time and self.preprocess_end_time:
self.time_in_queue = self.inference_start_time - self.preprocess_end_time
if self.preprocess_end_time and self.preprocess_start_time:
self.preprocess_cost_time = self.preprocess_end_time - self.preprocess_start_time
self.request_start_time = self.arrival_time
# for compatibility with old metrics
self.llm_engine_recv_req_timestamp = self.engine_get_req_time
self.llm_engine_send_req_to_engine_timestamp = self.inference_start_time
self.llm_engine_recv_token_timestamp = self.engine_recv_first_token_time
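cal_cost_time() reduces to simple timestamp differences; a toy trace with hypothetical timestamps makes the three derived metrics concrete:

```python
# Hypothetical timestamps (seconds) for one request's lifecycle.
arrival_time = 100.0
preprocess_start_time = 100.1
preprocess_end_time = 100.3
inference_start_time = 100.5          # request enters the engine work queue
engine_recv_first_token_time = 101.2  # first token arrives from the worker

# The same differences cal_cost_time() computes:
first_token_time = engine_recv_first_token_time - inference_start_time
time_in_queue = inference_start_time - preprocess_end_time
preprocess_cost_time = preprocess_end_time - preprocess_start_time

assert round(first_token_time, 6) == 0.7
assert round(time_in_queue, 6) == 0.2
assert round(preprocess_cost_time, 6) == 0.2
```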
def get(self, key: str, default_value=None):
if hasattr(self, key):
return getattr(self, key)
else:
return default_value
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
else:
raise KeyError(key) from None
def __setitem__(self, key, value):
setattr(self, key, value)
# Set engine time for decoder-node
def update_decoder_start_time(self):
self.llm_engine_send_req_to_decoder_engine_timestamp = self.decode_inference_start_time
class RequestOutput:
"""The output data of a completion request to the LLM.
Args:
request_id: The unique ID of the request.
prompt: The prompt string of the request.
For encoder/decoder models, this is the
decoder input prompt.
prompt_token_ids: The token IDs of the prompt.
For encoder/decoder models, this is the
decoder input prompt token ids.
prompt_logprobs: The log probabilities to return per prompt token.
outputs: The output sequences of the request.
finished: Whether the whole request is finished.
metrics: Metrics associated with the request.
lora_request: The LoRA request that was used to generate the output.
encoder_prompt: The encoder prompt string of the request.
None if decoder-only.
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
num_input_image_tokens: The number of input image tokens.
num_input_video_tokens: The number of input video tokens.
"""
def __init__(
self,
request_id: str,
prompt: Optional[str] = None,
prompt_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[PromptLogprobs] = None,
output_type: Optional[int] = 3,
outputs: CompletionOutput = None,
finished: bool = False,
metrics: Optional[RequestMetrics] = None,
num_cached_tokens: Optional[int] = 0,
num_input_image_tokens: Optional[int] = 0,
num_input_video_tokens: Optional[int] = 0,
error_code: Optional[int] = 200,
error_msg: Optional[str] = None,
# for internal adapter
ic_req_data: Optional[dict] = None,
prompt_token_ids_len: Optional[int] = 0,
trace_carrier: dict = dict(),
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_logprobs = prompt_logprobs
self.output_type = output_type
self.outputs = outputs
self.finished = finished
self.metrics = metrics
self.num_cached_tokens = num_cached_tokens
self.num_input_image_tokens = num_input_image_tokens
self.num_input_video_tokens = num_input_video_tokens
self.error_code = error_code
self.error_msg = error_msg
self.ic_req_data = ic_req_data
self.prompt_token_ids_len = prompt_token_ids_len
self.trace_carrier = trace_carrier if trace_carrier else {} # avoid sharing the mutable default dict across instances
if prompt_token_ids is None:
self.prompt_token_ids = []
elif isinstance(self.prompt_token_ids, np.ndarray):
self.prompt_token_ids = self.prompt_token_ids.tolist()
if self.outputs and self.outputs.tool_calls:
self.accumulate_tool_calls: Optional[list[ToolCall]] = [self.outputs.tool_calls]
else:
self.accumulate_tool_calls = None
def add(self, next_output: RequestOutput) -> None:
"""Merge RequestOutput into this one"""
if next_output.prompt is not None:
self.prompt = next_output.prompt
if next_output.prompt_token_ids is not None:
self.prompt_token_ids = next_output.prompt_token_ids
self.finished |= next_output.finished
self.outputs.index = next_output.outputs.index
self.outputs.token_ids.extend(next_output.outputs.token_ids)
if next_output.metrics.model_forward_time is not None:
self.metrics.model_forward_time = next_output.metrics.model_forward_time
if next_output.metrics.model_execute_time is not None:
self.metrics.model_execute_time = next_output.metrics.model_execute_time
if next_output.metrics.engine_recv_latest_token_time is not None:
self.metrics.engine_recv_latest_token_time = next_output.metrics.engine_recv_latest_token_time
if next_output.outputs.top_logprobs is not None:
self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids)
self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs)
self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks)
if next_output.outputs.draft_top_logprobs is not None:
self.outputs.draft_top_logprobs.logprob_token_ids.extend(
next_output.outputs.draft_top_logprobs.logprob_token_ids
)
self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs)
self.outputs.draft_top_logprobs.sampled_token_ranks.extend(
next_output.outputs.draft_top_logprobs.sampled_token_ranks
)
if next_output.metrics.speculate_metrics is not None:
self.outputs.speculate_metrics = next_output.metrics.speculate_metrics
def accumulate(self, next_output: RequestOutput) -> None:
"""Accumulate RequestOutput"""
if self.outputs.text is None:
self.outputs.text = next_output.outputs.text
elif next_output.outputs.text:
self.outputs.text += next_output.outputs.text
if self.outputs.reasoning_content is None:
self.outputs.reasoning_content = next_output.outputs.reasoning_content
elif next_output.outputs.reasoning_content:
self.outputs.reasoning_content += next_output.outputs.reasoning_content
if self.outputs.completion_tokens is None:
self.outputs.completion_tokens = next_output.outputs.completion_tokens
elif next_output.outputs.completion_tokens:
self.outputs.completion_tokens += next_output.outputs.completion_tokens
if next_output.outputs.tool_calls:
if self.accumulate_tool_calls is None:
self.accumulate_tool_calls = []
self.accumulate_tool_calls.append(next_output.outputs.tool_calls)
self.add(next_output)
def __repr__(self) -> str:
return (
f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"output_type={self.output_type}, "
f"outputs={self.outputs}, "
f"finished={self.finished}, "
f"num_cached_tokens={self.num_cached_tokens}, "
f"num_input_image_tokens={self.num_input_image_tokens}, "
f"num_input_video_tokens={self.num_input_video_tokens}, "
f"metrics={self.metrics}, "
f"error_code={self.error_code}, "
f"error_msg={self.error_msg},"
f"trace_carrier={self.trace_carrier}"
)
@classmethod
def from_dict(cls, d: dict):
"""Create instance from dict arguments"""
if "outputs" in d and isinstance(d["outputs"], dict):
completion_output = CompletionOutput.from_dict(d.pop("outputs"))
else:
d.pop("outputs", None)
completion_output = None
if "metrics" in d and isinstance(d["metrics"], dict):
metrics = RequestMetrics.from_dict(d.pop("metrics"))
else:
d.pop("metrics", None)
metrics = None
trace_carrier = d.pop("trace_carrier", {})
return RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier)
def to_dict(self):
"""convert RequestOutput into a serializable dict"""
return {
"request_id": self.request_id,
"prompt": self.prompt,
"prompt_token_ids": self.prompt_token_ids,
"prompt_logprobs": self.prompt_logprobs,
"output_type": self.output_type,
"outputs": None if self.outputs is None else self.outputs.to_dict(),
"metrics": None if self.metrics is None else self.metrics.to_dict(),
"finished": self.finished,
"num_cached_tokens": self.num_cached_tokens,
"num_input_image_tokens": self.num_input_image_tokens,
"num_input_video_tokens": self.num_input_video_tokens,
"error_code": self.error_code,
"error_msg": self.error_msg,
"ic_req_data": self.ic_req_data,
"prompt_token_ids_len": self.prompt_token_ids_len,
"trace_carrier": self.trace_carrier,
}
def get(self, key: str, default_value=None):
if hasattr(self, key):
return getattr(self, key)
elif hasattr(self.outputs, key):
return getattr(self.outputs, key)
elif hasattr(self.metrics, key):
return getattr(self.metrics, key)
else:
return default_value
def set(self, key: str, value):
if hasattr(self.outputs, key):
setattr(self.outputs, key, value)
elif hasattr(self.metrics, key):
setattr(self.metrics, key, value)
else:
setattr(self, key, value)
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
elif hasattr(self.outputs, key):
return getattr(self.outputs, key)
elif hasattr(self.metrics, key):
return getattr(self.metrics, key)
else:
raise KeyError(key) from None
def __setitem__(self, key, value):
if hasattr(self.outputs, key):
setattr(self.outputs, key, value)
elif hasattr(self.metrics, key):
setattr(self.metrics, key, value)
else:
setattr(self, key, value)
def __delitem__(self, key):
if hasattr(self, key):
delattr(self, key)
elif hasattr(self.outputs, key):
delattr(self.outputs, key)
elif hasattr(self.metrics, key):
delattr(self.metrics, key)
else:
raise KeyError(key)
def __contains__(self, key: str) -> bool:
if hasattr(self, key):
return True
elif hasattr(self.outputs, key):
return True
elif hasattr(self.metrics, key):
return True
else:
return False
@dataclass
class PoolingOutput:
"""The output data of one pooling output of a request.
Args:
data: The extracted hidden states.
"""
data: list[Any]
def __repr__(self) -> str:
return f"PoolingOutput(data={self.data})"
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
eq = self.data == other.data
# numpy arrays compare element-wise and need .all(); plain lists compare to a bool directly
return bool(eq.all()) if hasattr(eq, "all") else bool(eq)
def to_dict(self):
return {"data": self.data}
_O = TypeVar("_O", default=PoolingOutput)
@dataclass
class PoolingRequestOutput(Generic[_O]):
"""
The output data of a pooling request to the LLM.
Args:
request_id (str): A unique identifier for the pooling request.
outputs (PoolingOutput): The pooling results for the given input.
prompt_token_ids (list[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the pooling is completed.
"""
request_id: str
outputs: _O
prompt_token_ids: list[int]
finished: bool
metrics: Optional[RequestMetrics] = None
error_code: Optional[int] = 200
error_msg: Optional[str] = None
def __repr__(self):
return (
f"{type(self).__name__}(request_id={self.request_id!r}, "
f"outputs={self.outputs!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"finished={self.finished}, "
f"metrics={self.metrics}, "
f"error_code={self.error_code}, "
f"error_msg={self.error_msg})"
)
def to_dict(self):
return {
"request_id": self.request_id,
"outputs": None if self.outputs is None else self.outputs.to_dict(),
"prompt_token_ids": self.prompt_token_ids,
"finished": self.finished,
"metrics": None if self.metrics is None else self.metrics.to_dict(),
"error_code": self.error_code,
"error_msg": self.error_msg,
}
@classmethod
def from_dict(cls, req_dict: dict):
"""Create instance from dict arguments"""
outputs = PoolingOutput(req_dict["outputs"]["data"])
init_args = {
field.name: (outputs if field.name == "outputs" else req_dict.get(field.name, field.default))
for field in fields(cls)
}
return cls(**init_args)
@dataclass
class EmbeddingOutput:
"""The output data of one embedding output of a request.
Args:
embedding: The embedding vector, which is a list of floats.
Its length depends on the hidden dimension of the model.
"""
embedding: list[float]
@staticmethod
def from_base(pooling_output: PoolingOutput):
pooled_data = pooling_output.data
# if pooled_data.ndim != 1:
# raise ValueError("pooled_data should be a 1-D embedding vector")
if isinstance(pooled_data, list):
return EmbeddingOutput(pooled_data)
return EmbeddingOutput(pooled_data.tolist())
@property
def hidden_size(self) -> int:
return len(self.embedding)
def __repr__(self) -> str:
return f"EmbeddingOutput(hidden_size={self.hidden_size})"
class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
@staticmethod
def from_base(request_output: PoolingRequestOutput):
return EmbeddingRequestOutput(
request_id=request_output.request_id,
outputs=EmbeddingOutput.from_base(request_output.outputs),
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
@dataclass
class ClassificationOutput:
"""The output data of one classification output of a request.
Args:
probs: The probability vector, which is a list of floats.
Its length depends on the number of classes.
"""
probs: list[float]
@staticmethod
def from_base(pooling_output: PoolingOutput):
# pooling_output shape: (num_classes)
pooled_data = pooling_output.data
if pooled_data.ndim != 1:
raise ValueError("pooled_data should be a 1-D probability vector")
return ClassificationOutput(pooled_data.tolist())
@property
def num_classes(self) -> int:
return len(self.probs)
def __repr__(self) -> str:
return f"ClassificationOutput(num_classes={self.num_classes})"
class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
@staticmethod
def from_base(request_output: PoolingRequestOutput):
return ClassificationRequestOutput(
request_id=request_output.request_id,
outputs=ClassificationOutput.from_base(request_output.outputs),
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
@dataclass
class ScoringOutput:
"""The output data of one scoring output of a request.
Args:
score: The similarity score, which is a scalar value.
"""
score: float
@staticmethod
def from_base(pooling_output: PoolingOutput):
# pooling_output shape:
# classify task: (num_classes) num_classes == 1
# embed task: a scalar value
pooled_data = pooling_output.data.squeeze()
if pooled_data.ndim != 0:
raise ValueError("pooled_data should be a scalar score")
return ScoringOutput(pooled_data.item())
def __repr__(self) -> str:
return f"ScoringOutput(score={self.score})"
class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
@staticmethod
def from_base(request_output: PoolingRequestOutput):
return ScoringRequestOutput(
request_id=request_output.request_id,
outputs=ScoringOutput.from_base(request_output.outputs),
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)
@dataclass
class RewardOutput:
"""The output data of one reward output of a request.
Args:
score: The reward scores, which is a list of floats.
Its length depends on the hidden dimension of the model.
"""
score: list[float]
@staticmethod
def from_base(pooling_output: PoolingOutput):
pooled_data = pooling_output.data
# if pooled_data.ndim != 1:
# raise ValueError("pooled_data should be a 1-D embedding vector")
if isinstance(pooled_data, list):
return RewardOutput(pooled_data)
return RewardOutput(pooled_data.tolist())
@property
def hidden_size(self) -> int:
return len(self.score)
def __repr__(self) -> str:
return f"RewardOutput(hidden_size={self.hidden_size})"
class RewardRequestOutput(PoolingRequestOutput[RewardOutput]):
@staticmethod
def from_base(request_output: PoolingRequestOutput):
return RewardRequestOutput(
request_id=request_output.request_id,
outputs=RewardOutput.from_base(request_output.outputs),
prompt_token_ids=request_output.prompt_token_ids,
finished=request_output.finished,
)