[Feature] support v1 update/clear api for RL (#6761)

* [Feature] support v1 update/clear api for RL

* [fix] fix execute_model and add sleep/wakeup api

* [fix] fix mtp and key_prefix

* [chore] move _update_key_prefix to resume method

* [fix] make the interface safe to call multiple times

* [fix] fix some tiny bugs

* [chore] make small changes against pr review

* [docs] add docs for weight update

* [test] add some tests and update docs

* [style] fix code style check

* [test] fix ci

* [fix] fix stale control responses when control method timed out

* [chore] remove unused code

* [chore] fix code style

* [chore] optimize tags and key_prefix

* [test] fix ci

* [chore] fix code style

* [test] fix ci

* [fix] fix ep control

* [fix] fix ep control for engine cache queue
This commit is contained in:
Yonghua Li
2026-03-25 19:18:46 +08:00
committed by GitHub
parent 48cfb608aa
commit a7f52c300d
26 changed files with 1857 additions and 392 deletions
+308
@@ -0,0 +1,308 @@
[简体中文](../zh/features/weight_update.md)
# Weight Clear and Update
FastDeploy supports dynamic weight clearing and updating for RL and RLHF rollout services. This capability primarily addresses two requirements:
- release GPU memory when the rollout engine is idle;
- refresh inference weights after the trainer produces a new checkpoint, without restarting the whole service.
This page describes the weight-control interfaces currently supported by FastDeploy, the semantics of each interface, and their typical usage in RLHF training.
## Prerequisites
In RLHF scenarios, FastDeploy mainly provides this capability through the online serving mode. Dynamic weight loading must be enabled when starting the service:
```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --dynamic-load-weight \
    --load-strategy ipc_snapshot
```
`--dynamic-load-weight` enables dynamic weight control, and `--load-strategy` specifies the concrete weight update mechanism. The currently supported update modes are listed below:
| Mode | `load_strategy` | Typical use | Notes |
| --- | --- | --- | --- |
| CUDA IPC | `ipc` | Training and inference processes on the same node share live tensors | Update source comes from IPC metadata produced by the training side. |
| IPC snapshot | `ipc_snapshot` | Rollout reloads a snapshot file produced by training | Used by current RL rollout examples. |
| RDMA / rsync | `rsync` | Trainer publishes a new version and rollout fetches it remotely | `POST /v1/update_weights` is the explicit API for this mode. |
## API Overview
### Compatibility APIs
In FastDeploy <= 2.5, the following simplified APIs are provided for compatibility with the legacy RL control flow.
| API | Method | Meaning | Availability |
| --- | --- | --- | --- |
| `/clear_load_weight` | `GET` | Clear or offload currently loaded weights | Requires `dynamic_load_weight=True` |
| `/update_model_weight` | `GET` | Reload weights after a clear/offload operation | Requires `dynamic_load_weight=True` |
### V1 control APIs
In FastDeploy >= 2.6, the underlying control-signal communication path is optimized and V1 control APIs are introduced. Compared with the legacy APIs, the V1 APIs provide a more stable execution path, clearer semantics, and more flexible control:
| API | Method | Request params | Semantics |
| --- | --- | --- | --- |
| `/v1/pause` | `POST` | none | Pause request generation, abort running and inflight requests, reset scheduler state, and pause cache transfer if enabled. |
| `/v1/resume` | `POST` | none | Resume request generation and cache transfer. |
| `/v1/is_paused` | `GET` | none | Return `{"is_paused": bool}`. |
| `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | Offload selected GPU memory objects. Supported tags are `weight` and `kv_cache`. If omitted, both are used. |
| `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | Reload previously offloaded weights and/or KV cache. On success, the engine resumes automatically. |
| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. |
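The endpoints in the table can be wrapped in a thin client. The sketch below is illustrative rather than part of FastDeploy: the `control_url` helper and its validation are assumptions, but the paths and the `tags` query parameter follow the table above.

```python
from urllib.parse import urlencode

# Tags accepted by /v1/sleep and /v1/wakeup per the table above.
V1_TAGS = {"weight", "kv_cache"}

def control_url(base: str, action: str, tags=None) -> str:
    """Build the URL for a /v1 control call, validating sleep/wakeup tags."""
    if action in ("sleep", "wakeup") and tags:
        unknown = set(tags) - V1_TAGS
        if unknown:
            raise ValueError(f"unsupported tags: {sorted(unknown)}")
        query = urlencode({"tags": ",".join(tags)})
        return f"{base}/v1/{action}?{query}"
    return f"{base}/v1/{action}"
```

The returned URL can then be issued with any HTTP client using the method listed in the table (`POST` for all control actions except `/v1/is_paused`).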
### Compatibility Notes
The optimized communication path also applies to the legacy APIs. By setting `FD_ENABLE_V1_UPDATE_WEIGHTS=1`, the legacy APIs can be switched to the new control path while keeping the original API form.
- `FD_ENABLE_V1_UPDATE_WEIGHTS=0`: use the legacy shared-memory-based control path.
- `FD_ENABLE_V1_UPDATE_WEIGHTS=1`: `/clear_load_weight` is effectively handled through `/v1/sleep`, and `/update_model_weight` is effectively handled through `/v1/wakeup`. The corresponding pause/resume actions are handled internally by `sleep` and `wakeup`.
**Note**: regardless of whether V1 is enabled, the legacy APIs are not the recommended standard interface for RLHF scenarios and may be gradually deprecated in future releases. The `/v1/*` control APIs are recommended.
## Interface Semantics
### `/v1/pause`
`/v1/pause` is the safe boundary before changing model state.
It does the following:
- stops new request generation;
- aborts running and inflight requests;
- resets scheduler state;
- pauses cache transfer when multi-level cache or KV cache storage is enabled.
When a clear boundary is required between one rollout round and the next training stage, this API should be called first.
### `/v1/sleep`
`/v1/sleep` offloads selected runtime state from GPU memory.
Supported tags:
- `weight`: clear model weights from device memory; if enabled, communication groups and DeepEP buffers may also be released.
- `kv_cache`: clear KV cache; MTP cache is also cleared when speculative decoding uses MTP.
If the `tags` parameter is omitted, FastDeploy defaults to:
```bash
/v1/sleep?tags=weight,kv_cache
```
In the current implementation, `sleep` automatically performs a `pause` first. New integrations should not rely on this implicit behavior.
### `/v1/wakeup`
`/v1/wakeup` restores the state offloaded by `/v1/sleep`.
Depending on tags and configuration, FastDeploy may:
- restart communication groups;
- recreate DeepEP buffers;
- reload model weights from the configured source;
- rebuild KV cache;
- recapture CUDA Graph.
After `wakeup` succeeds, FastDeploy automatically calls `resume`.
### `/v1/update_weights`
`/v1/update_weights` refreshes model parameters directly, without unloading the GPU memory occupied by model weights.
Current request fields:
- `version`: optional string. Used to choose a target checkpoint version.
- `rsync_config`: optional dictionary. Must contain `etcd_server` when provided.
Important semantics:
- the engine must already be paused, otherwise the request fails;
- the update is executed on workers only;
- this API is meant for explicit weight refresh, especially the `rsync` path;
- it does not implicitly call `resume`.
Recommended sequence:
1. `POST /v1/pause`
2. `POST /v1/update_weights`
3. `POST /v1/resume`
If GPU memory also needs to be reclaimed between rollout rounds, the `sleep` / `wakeup` workflow is more appropriate.
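The request-field rules above can be checked client-side before issuing the call. A minimal sketch, where the helper name is hypothetical but the validation mirrors the stated semantics (both fields optional; a provided `rsync_config` must contain `etcd_server`):

```python
def build_update_payload(version=None, rsync_config=None):
    """Build and validate the JSON body for POST /v1/update_weights."""
    if rsync_config is not None and "etcd_server" not in rsync_config:
        raise ValueError("rsync_config must contain 'etcd_server'")
    body = {}
    if version is not None:
        body["version"] = version  # optional target checkpoint version
    if rsync_config is not None:
        body["rsync_config"] = rsync_config
    return body
```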
## Example Requests
### Basic APIs
Pause the engine:
```bash
curl -X POST http://127.0.0.1:8000/v1/pause
```
Resume the engine:
```bash
curl -X POST http://127.0.0.1:8000/v1/resume
```
### Sleep / Wakeup APIs
**Offload weights and KV cache**
```bash
# Offload both weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache"
# Offload only weights
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight"
# Omit parameter, defaults to both
curl -X POST "http://127.0.0.1:8000/v1/sleep"
```
**Restore weights and KV cache**
```bash
# Restore both weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache"
# Restore only weights
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight"
# Omit parameter, defaults to both
curl -X POST "http://127.0.0.1:8000/v1/wakeup"
```
**Note**: When `use_cudagraph=True`, the KV cache must be restored before the weights: `/v1/wakeup` with the `kv_cache` tag must be called before `/v1/wakeup` with the `weight` tag. Restoring weights without the KV cache raises an error. It is recommended to keep the `tags` parameter consistent between `/v1/sleep` and `/v1/wakeup`.
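A client can enforce this ordering mechanically. The sketch below is illustrative (the function name is an assumption): it splits a set of wakeup tags into separate calls, putting `kv_cache` first when CUDA Graph is enabled.

```python
def wakeup_sequence(tags, use_cudagraph: bool = True):
    """Return ordered /v1/wakeup request paths; restore kv_cache
    before weight when CUDA Graph is enabled."""
    seen = list(dict.fromkeys(tags))  # de-duplicate, keep caller order
    if use_cudagraph:
        # stable sort: kv_cache first, everything else after
        seen.sort(key=lambda t: 0 if t == "kv_cache" else 1)
    return [f"/v1/wakeup?tags={t}" for t in seen]
```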
### Update Weights API
Refresh to a new remotely published version:
```bash
curl -X POST http://127.0.0.1:8000/v1/update_weights \
  -H "Content-Type: application/json" \
  -d '{
        "version": "global_step_1200",
        "rsync_config": {
          "etcd_server": "127.0.0.1:2379"
        }
      }'
```
## RLHF Usage
### Recommended Rollout Service Setup
In RLHF scenarios, FastDeploy rollout services are typically configured as follows:
- `dynamic_load_weight=True`
- `load_strategy=ipc_snapshot` for local snapshot-based refresh;
- or `load_strategy=rsync` for versioned remote refresh.
The rollout utilities in the repository already follow this pattern. A typical example is:
```python
from fastdeploy.rl.rollout_config import RolloutModelConfig
from fastdeploy.rl.rollout_model import RolloutModel

rollout_config = RolloutModelConfig(
    model_name_or_path=model_path,
    tensor_parallel_size=ranks,
    dynamic_load_weight=True,
    load_strategy="ipc_snapshot",
)
rollout_model = RolloutModel(rollout_config)
```
### Training-Side Integration Support
In addition to serving endpoints, FastDeploy provides the following training-side integration capabilities for RLHF:
- `RolloutModel.state_dict()`: exposes the rollout-side inference parameters.
- `RolloutModel.get_name_mappings_to_training()`: exposes the mapping from inference parameter names to training parameter names.
These interfaces can be used to align training checkpoints with rollout-side parameter layouts, especially when inference-side and training-side parameter names are not fully identical.
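As a sketch of how the two interfaces compose, the helper below is illustrative (not part of FastDeploy); it assumes the mapping direction described above, inference parameter name to training parameter name:

```python
def align_training_checkpoint(training_state: dict, name_mapping: dict):
    """Remap a training-side state dict onto inference parameter names.

    name_mapping maps inference names to training names, as returned by
    RolloutModel.get_name_mappings_to_training(). Returns the aligned
    state dict plus any inference names whose training counterpart is
    missing from the checkpoint.
    """
    aligned, missing = {}, []
    for infer_name, train_name in name_mapping.items():
        if train_name in training_state:
            aligned[infer_name] = training_state[train_name]
        else:
            missing.append(infer_name)
    return aligned, missing
```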
### Common RLHF workflows
The following examples assume the service endpoint is `http://127.0.0.1:8000`.
**Workflow 1: clear and restore**
This workflow is suitable when the rollout service stays resident, but GPU memory should be released before training and restored afterward. The recommended sequence is `(pause) -> sleep -> wakeup -> (resume)`, where the steps in parentheses are optional.
```bash
# Optional: explicitly pause the engine to establish a clear transition boundary
curl -X POST http://127.0.0.1:8000/v1/pause
# Offload both weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache"
# Restore both weights and KV cache after training completes
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache"
# Optional: explicitly resume if required by the integration
curl -X POST http://127.0.0.1:8000/v1/resume
```
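When automating this workflow, `/v1/is_paused` can be polled between steps to confirm the boundary before offloading. A minimal polling sketch; the transport is injected (for example `lambda: requests.get(url).json()`), so the helper itself is an illustration rather than FastDeploy API:

```python
import time

def wait_until_paused(get_status, timeout_s: float = 30.0,
                      interval_s: float = 0.5) -> bool:
    """Poll a callable returning the parsed /v1/is_paused JSON until
    {'is_paused': True} or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while True:
        if get_status().get("is_paused"):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```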
**Workflow 2: in-place refresh to a new checkpoint**
This workflow is suitable when the service remains resident and only needs to switch to a new checkpoint version. The recommended sequence is `pause -> update_weights -> resume`.
```bash
# Pause the engine first
curl -X POST http://127.0.0.1:8000/v1/pause
# Refresh to a new checkpoint version in place
curl -X POST http://127.0.0.1:8000/v1/update_weights \
  -H "Content-Type: application/json" \
  -d '{
        "version": "global_step_1200",
        "rsync_config": {
          "etcd_server": "127.0.0.1:2379"
        }
      }'
# Resume the service after the update completes
curl -X POST http://127.0.0.1:8000/v1/resume
```
**Workflow 3: legacy compatibility APIs**
Legacy RL clients can continue to use the compatibility flow `clear_load_weight -> update_model_weight`.
```bash
# Clear or offload the current weights
curl -X GET http://127.0.0.1:8000/clear_load_weight
# Reload weights after the trainer updates the checkpoint
curl -X GET http://127.0.0.1:8000/update_model_weight
```
For new integrations, the `/v1/*` APIs are recommended because their control path is more explicit and easier to trace.
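The compatibility behavior described earlier (under `FD_ENABLE_V1_UPDATE_WEIGHTS=1`, the legacy endpoints are routed through `sleep`/`wakeup`) suggests a simple migration table for clients moving to the v1 APIs. The helper below is illustrative only:

```python
# Legacy GET endpoints and their v1 equivalents as (method, path),
# following the FD_ENABLE_V1_UPDATE_WEIGHTS=1 mapping described above.
LEGACY_TO_V1 = {
    "/clear_load_weight": ("POST", "/v1/sleep"),
    "/update_model_weight": ("POST", "/v1/wakeup"),
}

def migrate_endpoint(path: str):
    """Return the v1 (method, path) for a legacy endpoint, or the
    unchanged legacy GET call if no mapping exists."""
    return LEGACY_TO_V1.get(path, ("GET", path))
```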
## Other Related Configuration
### Communication Group Clear and Rebuild
FastDeploy provides `--shutdown-comm-group-if-worker-idle` and `--no-shutdown-comm-group-if-worker-idle` to explicitly control whether communication groups should also be torn down when weights are offloaded.
Keeping communication groups alive generally improves the stability of weight clearing and reloading. The tradeoff is that more GPU memory remains allocated after weight offload, and the execution time of `sleep` / `wakeup` may also increase.
By default:
- in EP scenarios, communication groups are kept;
- in non-EP scenarios, communication groups are torn down.
### CPU Cache Clear and Rebuild
After `--swap-space` is enabled, the following environment variable can be used to control whether CPU-side cache should also be cleared when `/v1/sleep` is executed, in order to reduce memory pressure during training.
By default, FastDeploy does not actively clear CPU cache. To clear it together with `sleep`, set:
```bash
export FD_ENABLE_SWAP_SPACE_CLEARING=1
```
+307
@@ -0,0 +1,307 @@
[English](../../features/weight_update.md)
# Weight Clear and Update
FastDeploy supports dynamic weight clearing, GPU memory offload, and weight update for RL / RLHF rollout services, primarily to address two problems:
- releasing GPU memory when the rollout engine is idle;
- switching the inference service to new weights after the trainer produces a new checkpoint, without restarting the process.
This document describes the weight-control interfaces currently supported by FastDeploy, the semantics of each interface, and their typical usage in RLHF training.
## Prerequisites
In RLHF scenarios, FastDeploy provides this capability mainly through the online serving mode. Dynamic weight loading must be enabled when starting the service:
```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --dynamic-load-weight \
    --load-strategy ipc_snapshot
```
`--dynamic-load-weight` enables dynamic weight control, and `--load-strategy` specifies the concrete weight update mechanism. The currently supported update modes are:
| Mode | `load_strategy` | Typical scenario | Notes |
| --- | --- | --- | --- |
| CUDA IPC | `ipc` | Training and inference processes on the same node directly share live tensors | The update source is IPC metadata produced by the training side. |
| IPC snapshot | `ipc_snapshot` | Rollout reloads from a weight snapshot file produced by training | The RL rollout examples in the current repository mainly use this mode. |
| RDMA / rsync | `rsync` | The trainer publishes a new version and rollout fetches it remotely | `POST /v1/update_weights` is the explicit API for this mode. |
## API Overview
### Legacy APIs
FastDeploy <= 2.5 mainly provides the following simplified APIs, kept for the legacy RL control flow.
| API | Method | Meaning | Availability |
| --- | --- | --- | --- |
| `/clear_load_weight` | `GET` | Clear or offload currently loaded weights | Requires `dynamic_load_weight=True` |
| `/update_model_weight` | `GET` | Reload weights after a clear/offload | Requires `dynamic_load_weight=True` |
### V1 control APIs
FastDeploy >= 2.6 optimizes the underlying control-signal communication path and introduces the V1 control APIs. Compared with the legacy APIs, the V1 APIs are more stable on the communication and execution path, have clearer semantics, and offer more flexible control, including the following:
| API | Method | Request params | Semantics |
| --- | --- | --- | --- |
| `/v1/pause` | `POST` | none | Pause request generation, abort running/inflight requests, reset the scheduler, and pause cache transfer if it is enabled. |
| `/v1/resume` | `POST` | none | Resume request generation and cache transfer. |
| `/v1/is_paused` | `GET` | none | Return `{"is_paused": bool}`. |
| `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | Offload the specified GPU memory objects. Supports `weight` and `kv_cache`; if omitted, both are processed. |
| `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | Reload previously offloaded weights and/or KV cache. On success, `resume` is triggered automatically. |
| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | Refresh model weights in place through the worker control path. Mainly intended for remote versioned updates with `load_strategy=rsync`. |
### Compatibility Notes
The optimized communication path also applies to the legacy APIs. Setting the environment variable `FD_ENABLE_V1_UPDATE_WEIGHTS=1` switches the legacy APIs to the new control path, keeping the compatible API form while gaining a more explicit execution path and better observability.
- `FD_ENABLE_V1_UPDATE_WEIGHTS=0`: use the legacy shared-memory-based control path.
- `FD_ENABLE_V1_UPDATE_WEIGHTS=1`: `/clear_load_weight` is effectively handled through `/v1/sleep`, and `/update_model_weight` through `/v1/wakeup`. The corresponding pause/resume actions are handled internally by `sleep` and `wakeup`.
**Note**: regardless of whether the V1 environment variable is set, the legacy APIs are not the recommended standard usage for RLHF scenarios and may be gradually deprecated in future releases. Prefer the `/v1/*` control APIs.
## Interface Semantics
### `/v1/pause`
`/v1/pause` is the safe boundary before changing model state.
It performs the following actions:
- stops new request generation;
- aborts current running and inflight requests;
- resets scheduler state;
- pauses cache transfer when multi-level cache or KV cache storage is enabled.
If a clear transition boundary is needed between one rollout round and the next training stage, call this API first.
### `/v1/sleep`
`/v1/sleep` offloads the specified runtime state from GPU memory.
Currently supported `tags`:
- `weight`: clear model weights on the device; depending on configuration, communication groups and DeepEP buffers may also be released.
- `kv_cache`: clear the KV cache; if speculative decoding uses MTP, the MTP cache is cleared as well.
If the `tags` parameter is omitted, FastDeploy defaults to the equivalent of:
```bash
/v1/sleep?tags=weight,kv_cache
```
In the current implementation, `sleep` automatically performs a `pause` first, but new integrations should not rely on this implicit behavior long-term.
### `/v1/wakeup`
`/v1/wakeup` restores the state offloaded by `/v1/sleep`.
Depending on `tags` and the runtime configuration, FastDeploy may:
- rebuild communication groups;
- rebuild DeepEP buffers;
- reload model weights from the currently configured source;
- rebuild the KV cache;
- recapture CUDA Graph.
After `wakeup` succeeds, FastDeploy automatically calls `resume` once.
### `/v1/update_weights`
`/v1/update_weights` refreshes model parameters directly, without releasing the GPU memory occupied by model weights.
Currently supported request fields:
- `version`: optional string specifying the target checkpoint version.
- `rsync_config`: optional dictionary; if provided, it must contain `etcd_server`.
Key semantics:
- the engine must already be paused before the call, otherwise the request fails;
- the actual update is executed only on the worker side;
- the API is mainly for explicit weight refresh, especially the `rsync` path;
- it does not automatically perform `resume`.
Recommended call sequence:
1. `POST /v1/pause`
2. `POST /v1/update_weights`
3. `POST /v1/resume`
If, besides updating weights, GPU memory should also be reclaimed between rollout rounds, the `sleep` / `wakeup` combination is more appropriate.
## Example Requests
### Basic APIs
Pause the engine:
```bash
curl -X POST http://127.0.0.1:8000/v1/pause
```
Resume the engine:
```bash
curl -X POST http://127.0.0.1:8000/v1/resume
```
### Sleep / Wakeup APIs
**Offload weights and KV cache**
```bash
# Offload both weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache"
# Offload only weights
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight"
# Omit the parameter to offload both by default
curl -X POST "http://127.0.0.1:8000/v1/sleep"
```
**Restore weights and KV cache**
```bash
# Restore both weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache"
# Restore only weights
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight"
# Omit the parameter to restore both by default
curl -X POST "http://127.0.0.1:8000/v1/wakeup"
```
**Note**: when `use_cudagraph=True`, the KV cache must be restored before the weights. Calling `/v1/wakeup` with only the `weight` tag and without `kv_cache` raises an error. Keep the `tags` parameter consistent between `sleep` and `wakeup`.
### Update Weights API
Switch to a newly published remote weight version:
```bash
curl -X POST http://127.0.0.1:8000/v1/update_weights \
  -H "Content-Type: application/json" \
  -d '{
        "version": "global_step_1200",
        "rsync_config": {
          "etcd_server": "127.0.0.1:2379"
        }
      }'
```
## Usage in RLHF
### Recommended Rollout Service Setup
In RLHF scenarios, FastDeploy rollout services are typically configured with:
- `dynamic_load_weight=True`;
- `load_strategy=ipc_snapshot` for local snapshot-based refresh;
- or `load_strategy=rsync` for versioned remote refresh.
The RL rollout utilities in the repository already integrate this way. A typical example:
```python
from fastdeploy.rl.rollout_config import RolloutModelConfig
from fastdeploy.rl.rollout_model import RolloutModel

rollout_config = RolloutModelConfig(
    model_name_or_path=model_path,
    tensor_parallel_size=ranks,
    dynamic_load_weight=True,
    load_strategy="ipc_snapshot",
)
rollout_model = RolloutModel(rollout_config)
```
### Training-Side Integration Support
Besides the serving endpoints, FastDeploy provides two training-side integration capabilities for RLHF:
- `RolloutModel.state_dict()`: exposes the rollout-side inference parameters;
- `RolloutModel.get_name_mappings_to_training()`: exposes the mapping from inference parameter names to training parameter names.
These interfaces can be used to align training-side checkpoints with the rollout-side weight layout, especially when inference-side and training-side parameter names are not fully identical.
### Common RLHF workflows
The following are common call patterns in RLHF scenarios. The examples assume the service endpoint is `http://127.0.0.1:8000`.
**Workflow 1: offload and restore GPU memory**
Suitable when the rollout service stays resident but GPU memory must be released before training and restored afterward. The recommended flow is `(pause) -> sleep -> wakeup -> (resume)`, where the steps in parentheses are optional.
```bash
# Optional: explicitly pause the engine first to establish a clear transition boundary
curl -X POST http://127.0.0.1:8000/v1/pause
# Offload weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache"
# Restore weights and KV cache after training completes
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache"
# Optional: call resume explicitly if the integration requires it
curl -X POST http://127.0.0.1:8000/v1/resume
```
**Workflow 2: in-place refresh to a new checkpoint**
Suitable when the service stays resident and only needs to switch to a new weight version. The recommended flow is `pause -> update_weights -> resume`.
```bash
# Pause the engine first
curl -X POST http://127.0.0.1:8000/v1/pause
# Refresh to the new weight version in place
curl -X POST http://127.0.0.1:8000/v1/update_weights \
  -H "Content-Type: application/json" \
  -d '{
        "version": "global_step_1200",
        "rsync_config": {
          "etcd_server": "127.0.0.1:2379"
        }
      }'
# Resume the service after the update completes
curl -X POST http://127.0.0.1:8000/v1/resume
```
**Workflow 3: legacy compatibility APIs**
Legacy RL clients can continue to use the compatibility flow `clear_load_weight -> update_model_weight`.
```bash
# Clear or offload the current weights
curl -X GET http://127.0.0.1:8000/clear_load_weight
# Reload weights after the trainer finishes updating the checkpoint
curl -X GET http://127.0.0.1:8000/update_model_weight
```
For new integrations, prefer the `/v1/*` APIs: their control path is more explicit, and log inspection and fault localization are more direct.
## Other Related Configuration
### Communication Group Teardown and Rebuild
FastDeploy supports `--shutdown-comm-group-if-worker-idle` and `--no-shutdown-comm-group-if-worker-idle` to explicitly control whether communication groups are also torn down when weights are offloaded.
Keeping communication groups alive generally improves the stability of weight clearing and reloading; the tradeoff is that more GPU memory stays allocated after weight offload, and the execution time of `sleep` / `wakeup` may also increase.
By default:
- in EP scenarios, communication groups are kept;
- in non-EP scenarios, communication groups are torn down.
### CPU Cache Clearing and Rebuild
When `--swap-space` is enabled, the following environment variable controls whether the CPU-side cache is also cleared when `/v1/sleep` executes, reducing memory pressure during training.
By default, FastDeploy does not actively clear the CPU cache. To clear it together with `sleep`, set:
```bash
export FD_ENABLE_SWAP_SPACE_CLEARING=1
```
+1
@@ -32,6 +32,7 @@ class CacheStatus(Enum):
CPU = 3
GPU2STORAGE = 4
STORAGE2GPU = 5
CTRL = -1
class BlockNode:
+231 -152
@@ -15,6 +15,7 @@
"""
import argparse
import asyncio
import concurrent.futures
import gc
import json
@@ -35,7 +36,6 @@ from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTa
from fastdeploy.cache_manager.ops import (
cuda_host_alloc,
cuda_host_free,
memory_allocated,
set_data_ipc,
set_device,
share_external_data_,
@@ -49,7 +49,9 @@ from fastdeploy.cache_manager.transfer_factory import (
MooncakeStore,
)
from fastdeploy.config import CacheConfig, SpeculativeConfig
from fastdeploy.engine.request import ControlRequest, ControlResponse
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.platforms import current_platform
from fastdeploy.utils import console_logger, get_logger
@@ -96,7 +98,6 @@ def parse_args():
help="engine worker queue port",
)
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
parser.add_argument("--ipc_suffix", type=str, default=None, help="engine pid")
parser.add_argument(
"--protocol",
type=str,
@@ -184,15 +185,17 @@ class CacheTransferManager:
self.key_prefix = ""
# extract other arg values
self.model_path = args.model_path
self.model_id = os.path.basename(args.model_path.rstrip("/"))
self.n_ranks = args.mp_num
self.rank = args.rank
self.device = args.device_id
self.num_layers = args.num_layers
self.ipc_suffix = args.ipc_suffix
self.create_cache_tensor = args.create_cache_tensor
self.local_data_parallel_id = args.local_data_parallel_id
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
self.cache_queue_port = args.cache_queue_port
paddle.set_default_dtype(args.default_dtype)
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -200,8 +203,10 @@ class CacheTransferManager:
self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self.control_task_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.transfer_task_queue = queue.Queue()  # receives transfer tasks
self.tansfer_done_queue = queue.Queue()  # signals that tasks have finished
self.ctrl_output_queue = None
address = (args.pod_ip, args.cache_queue_port)
self.cache_task_queue = EngineCacheQueue(
@@ -209,7 +214,7 @@ class CacheTransferManager:
is_server=False,
num_client=args.mp_num,
client_id=self.rank,
local_data_parallel_id=args.local_data_parallel_id,
local_data_parallel_id=0,
)
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
@@ -231,11 +236,12 @@ class CacheTransferManager:
self.num_cpu_blocks = args.num_cpu_blocks
self._init_gpu_cache(args)
self._init_gpu_cache()
if self.num_cpu_blocks > 0:
self._init_cpu_cache(args)
self._init_cpu_cache()
if self.storage_backend_type is not None:
self._init_storage(args)
self._init_control()
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
@@ -287,7 +293,9 @@ class CacheTransferManager:
threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()
self.is_paused = False # transfer manager state
self.is_sleeping = False
self.inflight = 0 # number of inflight transfer tasks
self.inflight_tasks = {}
cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_transfer_inited_signal = IPCSignal(
@@ -299,6 +307,15 @@ class CacheTransferManager:
)
self.cache_transfer_inited_signal.value[self.rank] = 1
def _init_control(self):
dp_rank = self.local_data_parallel_id
tp_rank = self.rank
tp_size = self.n_ranks
cq_port = self.cache_queue_port
name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_rank}_{cq_port}"
self.ctrl_output_queue = FMQ().queue(name, "producer")
logger.info(f"Init control output queue: {name} (producer)")
def _init_storage(self, args):
try:
# TODO: support cache scale for other backend
@@ -354,13 +371,19 @@ class CacheTransferManager:
raise ValueError(f"Invalid write policy: {args.write_policy}")
self.write_policy = args.write_policy
version_file_path = os.path.join(args.model_path, "version.yaml")
if os.path.exists(version_file_path):
self.key_prefix = get_key_prefix_from_version(version_file_path)
logger.info(f"The key_prefix of cache storage is {self.key_prefix}")
self._update_key_prefix()
logger.info("Initialize cache storage successfully")
def _update_key_prefix(self):
# use key_prefix to distinguish cache for different version of weight in rl
version_file_path = os.path.join(self.model_path, "version.yaml")
if os.path.exists(version_file_path):
self.key_prefix = get_key_prefix_from_version(version_file_path)
logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
else:
logger.error(f"version.yaml not found at {version_file_path}")
def _init_storage_buffer(self, args):
"""
Initialize pinned memory buffer that can hold the cache for a longest request
@@ -409,20 +432,20 @@ class CacheTransferManager:
self.storage_value_scale_write_buffer = write_buffer + scale_buffer_total_bytes // 2
self.storage_backend.register_buffer(write_buffer, scale_buffer_total_bytes)
def _init_gpu_cache(self, args):
def _init_gpu_cache(self):
if not args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners or messagers to create kv cache.")
if not self.create_cache_tensor:
logger.info("Waiting for runners or messagers to create kv cache.")
while self.cache_ready_signal.value[self.rank] != 1:
time.sleep(0.1)
logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
logger.info("OK! Stop waiting.")
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
cache_type = "uint8"
else:
cache_type = args.cache_dtype
cache_type = self.cache_dtype
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
logger.info("Initializing kv cache for all layers.")
set_device(self.device)
for i in range(self.num_layers + self.num_extra_layers):
# NOTE: num_extra_layer_gpu_blocks is usually equal to num_gpu_blocks
@@ -445,14 +468,12 @@ class CacheTransferManager:
self.value_cache_shape[2],
self.value_cache_shape[3],
]
if args.create_cache_tensor:
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
)
if self.create_cache_tensor:
logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
set_data_ipc(key_cache, key_name)
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
key_cache_scales = paddle.full(
shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
fill_value=0,
@@ -463,7 +484,7 @@ class CacheTransferManager:
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
set_data_ipc(val_cache, val_name)
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
value_cache_scales = paddle.full(
shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
fill_value=0,
@@ -471,13 +492,11 @@ class CacheTransferManager:
)
set_data_ipc(value_cache_scales, value_cache_scales_name)
else:
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
)
logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
key_cache = paddle.empty(shape=[], dtype=cache_type)
val_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
key_cache_scales = share_external_data_(
key_cache_scales,
@@ -487,7 +506,7 @@ class CacheTransferManager:
)
if self.value_cache_shape:
val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
value_cache_scales = share_external_data_(
value_cache_scales,
@@ -498,48 +517,65 @@ class CacheTransferManager:
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_kvs[key_cache_scales_name] = key_cache_scales
self.gpu_cache_scales_k_tensors.append(self.gpu_cache_kvs[key_cache_scales_name])
if args.value_cache_shape:
if self.value_cache_shape:
self.gpu_cache_kvs[val_name] = val_cache
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_kvs[value_cache_scales_name] = value_cache_scales
self.gpu_cache_scales_v_tensors.append(self.gpu_cache_kvs[value_cache_scales_name])
if args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
if self.create_cache_tensor:
self.cache_ready_signal.value[self.rank] = 1
while np.sum(self.cache_ready_signal.value) != self.n_ranks:
time.sleep(0.1)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")
logger.info("GPU KV cache is initialized")
def _init_cpu_cache(self, args):
def _clear_gpu_cache(self):
if self.create_cache_tensor:
logger.debug("Waiting for gpu runner to unlink cuda ipc")
while self.cache_ready_signal.value[self.rank] != 0:
time.sleep(0.1)
logger.debug("Stop waiting! gpu runner has unlinked cuda ipc")
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
if hasattr(self, "gpu_cache_scales_k_tensors"):
self.gpu_cache_scales_k_tensors.clear()
if hasattr(self, "gpu_cache_scales_v_tensors"):
self.gpu_cache_scales_v_tensors.clear()
paddle.device.cuda.empty_cache()
else:
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
logger.debug("Successfully unlinked gpu caches cuda ipc")
self.cache_ready_signal.value[self.rank] = 0
while np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("All ranks cleared gpu caches")
def _init_cpu_cache(self):
if self.num_cpu_blocks == 0:
return
paddle.set_device("cpu")
key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3]
if args.value_cache_shape:
if self.value_cache_shape:
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
else:
value_cache_size = 0
cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype)
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size
if args.cache_dtype == "block_wise_fp8":
key_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * key_cache_size
value_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * value_cache_size
logger.info("Initializing swap space (cpu cache) for all layers.")
if self.cache_dtype == "block_wise_fp8":
cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
scales_key_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
scales_value_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
)
if args.num_cpu_blocks == 0:
logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
self.swap_space_ready_signal.value[self.rank] = 1
return
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.")
paddle.set_device("cpu")
scales_key_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
scales_value_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
self.k_dst_ptrs = []
self.v_dst_ptrs = []
self.k_scales_ptrs = []
@@ -550,22 +586,43 @@ class CacheTransferManager:
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}"
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
f"..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
)
self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(scales_key_need_to_allocate_bytes)
self.k_scales_ptrs.append(self.cpu_cache_kvs[key_cache_scales_name])
if value_need_to_allocate_bytes > 0:
self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(scales_value_need_to_allocate_bytes)
self.v_scales_ptrs.append(self.cpu_cache_kvs[value_cache_scales_name])
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
logger.info("Swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1
while np.sum(self.swap_space_ready_signal.value) != self.n_ranks:
time.sleep(0.1)
logger.info("All ranks init cpu caches")
def _clear_cpu_cache(self):
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
if hasattr(self, "k_scales_ptrs"):
self.k_scales_ptrs.clear()
if hasattr(self, "v_scales_ptrs"):
self.v_scales_ptrs.clear()
gc.collect()
self.swap_space_ready_signal.value[self.rank] = 0
while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("All ranks cleared cpu caches")
def _run_read_storage(
self,
task_id: str,
@@ -1023,6 +1080,79 @@ class CacheTransferManager:
logger.debug(f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
logger.info(f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}")
def _handle_pause(self):
if self.is_paused:
logger.info("💡 Cache transfer manager is already paused, no need to pause again!")
else:
self.pause()
logger.info("✅ Successfully paused transfer")
return True
def _handle_resume(self):
if not self.is_paused:
logger.info("💡 Cache transfer manager is not paused, no need to resume!")
else:
self.resume()
if self.storage_backend_type is not None:
self._update_key_prefix()
logger.info("✅ Successfully resumed transfer")
return True
def _handle_sleep(self):
if self.is_sleeping:
logger.info("💡 Cache transfer manager is already sleeping, no need to sleep again!")
else:
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._clear_cpu_cache()
self._clear_gpu_cache()
self.is_sleeping = True
logger.info("✅ Successfully fell asleep (offloaded caches)")
return True
def _handle_wakeup(self):
if not self.is_sleeping:
logger.info("💡 Cache transfer manager is not sleeping, no need to wakeup!")
else:
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache()
self._init_gpu_cache()
self.is_sleeping = False
logger.info("✅ Successfully woke up (reloaded caches)")
return True
def control_task(self, task: ControlRequest):
method = task.get_method()
tags = task.args.get("tags", {})
logger.info(f"Received control task: {method}, tags: {tags}")
handlers = {
"pause": self._handle_pause,
"resume": self._handle_resume,
"sleep": self._handle_sleep,
"wakeup": self._handle_wakeup,
}
handler = handlers.get(method)
error_code = 200
error_message = "Success"
if handler:
try:
handler()
except Exception as e:
error_code = 500
error_message = f"Failed to execute {method}: {str(e)}"
logger.error(f"Error in control_task: {traceback.format_exc()}")
else:
error_code = 400
error_message = f"Unknown control method: {method}"
logger.warning(error_message)
self.cache_task_queue.barrier.wait()
resp = ControlResponse(task.request_id, error_code, error_message)
asyncio.run(self.ctrl_output_queue.put(resp))
logger.info(f"Put response into output queue {self.ctrl_output_queue.name}: {resp}")
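`control_task` maps a method name to a handler and folds the outcome into an HTTP-style `(error_code, error_message)` pair: 400 for unknown methods, 500 for handler failures, 200 otherwise. A standalone sketch of that dispatch pattern (simplified, without the barrier and response queue):

```python
def dispatch_control(method: str, handlers: dict) -> tuple[int, str]:
    # Unknown method -> 400; handler raised -> 500; success -> 200.
    handler = handlers.get(method)
    if handler is None:
        return 400, f"Unknown control method: {method}"
    try:
        handler()
        return 200, "Success"
    except Exception as e:
        return 500, f"Failed to execute {method}: {e}"

handlers = {"pause": lambda: None, "boom": lambda: 1 / 0}
assert dispatch_control("pause", handlers) == (200, "Success")
assert dispatch_control("resume", handlers)[0] == 400  # not registered
assert dispatch_control("boom", handlers)[0] == 500    # handler raised
```

Keeping the status pair separate from the handlers lets the real code attach it to a `ControlResponse` after the barrier without each handler knowing about the queue.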
def check_work_status(self, time_interval_threshold=envs.FD_CACHE_PROC_EXIT_TIMEOUT):
"""
Check the health of the model server by checking whether all workers are alive.
@@ -1042,8 +1172,13 @@ class CacheTransferManager:
return fn(*args)
finally:
self.inflight -= 1
logger.debug(f"submit_task: {fn.__name__} finished, args: {args}, current inflight: {self.inflight}")
self.inflight += 1
thread_pool.submit(inflight_task, task_fn, *args)
logger.debug(
f"submit_task: {task_fn.__name__} submitted to thread pool, args: {args}, current inflight: {self.inflight}"
)
def do_data_transfer(self):
"""
@@ -1054,6 +1189,7 @@ class CacheTransferManager:
max_errors = (
envs.FD_CACHE_PROC_ERROR_COUNT
) # After this many consecutive errors, check if the worker process exists.
is_paused = False
while True:
try:
@@ -1065,18 +1201,18 @@ class CacheTransferManager:
self.cache_task_queue.barrier0.reset()
# Ensure all ranks synchronically do one of the following things:
# (1) If rank#0 is paused, wait for a short time and check out rank#0 status again;
# (2) otherwise, all ranks are allowed to pull tasks from cache task queue
# (1) If rank#0 is paused, wait for inflight tasks to finish first, then only process control tasks;
# (2) otherwise, all ranks are allowed to pull all tasks from cache task queue
if self.cache_task_is_paused_signal.value[0] == 1:
# wait for inflight tasks to finish first
while self.inflight != 0:
time.sleep(0.1)
# mark the current rank as not having inflight tasks
self.cache_task_inflight_signal.value[self.rank] = 0
time.sleep(1)
continue
is_paused = True
else:
self.cache_task_inflight_signal.value[self.rank] = 1
is_paused = False
if self.rank == 0:
if not self.cache_task_queue.empty():
@@ -1087,12 +1223,16 @@ class CacheTransferManager:
self.cache_task_queue.barrier1.reset()
if self.cache_task_broadcast_signal.value[0] == 1:
self.inflight += 1
data, read_finish = self.cache_task_queue.get_transfer_task()
logger.debug(f"do_data_transfer: {data}")
if read_finish:
self.cache_task_broadcast_signal.value[0] = 0
event_type, event_args = data[0], data[1:]
# control task is the only task allowed to execute when loop is paused
if is_paused and event_type.value != CacheStatus.CTRL.value:
continue
if event_type.value == CacheStatus.SWAP2CPU.value:
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.submit_task(
@@ -1129,6 +1269,13 @@ class CacheTransferManager:
self.write_back_storage_task,
write_storage_task,
)
elif event_type.value == CacheStatus.CTRL.value:
control_task = event_args[0]
self.control_task_thread_pool.submit(
self.control_task,
control_task,
)
else:
if self.n_ranks > 1:
self.cache_task_queue.barrier2.wait()
@@ -1291,114 +1438,46 @@ class CacheTransferManager:
# TODO XPU support RL
if unset_data_ipc is None:
return
logger.info("[RL] Launch a thread to clear/restore kv cache when model weights are cleared/updated.")
logger.info(
"check_cache_status: Launch a thread to clear/restore kv cache when model weights are cleared/updated."
)
while True:
# handle cache clearing/restoring
if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
try:
# wait for inflight transfer tasks to finish and pause transfer manager
# pause transfer
self.pause()
# clear cpu caches
logger.info("[RL] start clearing caches")
logger.debug("[RL] start clearing cpu caches")
# clear caches
logger.info("check_cache_status: start clearing caches")
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
paddle.set_device("cpu")
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
if self.cache_dtype == "block_wise_fp8":
self.k_scales_ptrs.clear()
self.v_scales_ptrs.clear()
gc.collect()
logger.debug("[RL] successfully cleared cpu caches")
# reset swap_space_ready_signal
self.swap_space_ready_signal.value[self.rank] = 0
while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
logger.debug("[RL] all ranks cleared cpu caches")
else:
logger.debug("[RL] skip clearing cpu caches")
# clear gpu caches
logger.debug("[RL] start clearing gpu caches")
if args.create_cache_tensor:
logger.info("[RL] waiting for gpu runner to unlink cuda ipc")
while self.cache_ready_signal.value[self.rank] != 0:
time.sleep(0.1)
logger.info("[RL] stop waiting! gpu runner has unlinked cuda ipc")
paddle.set_device(f"gpu:{self.device}")
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_scales_k_tensors.clear()
self.gpu_cache_scales_v_tensors.clear()
paddle.device.cuda.empty_cache()
logger.debug("[RL] successfully cleared gpu caches")
else:
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
logger.debug("[RL] successfully unlinked gpu caches cuda ipc")
self.cache_ready_signal.value[self.rank] = 0
while np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("[RL] all ranks cleared caches!")
# reset kv_cache_status_signal
self._clear_cpu_cache()
self._clear_gpu_cache()
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
self._log_memory("after clearing caches")
self._log_memory("check_cache_status: after clearing caches")
except Exception as e:
logger.error(f"[RL] failed to clear caches: {e}")
logger.error(f"check_cache_status: failed to clear caches: {e}")
elif self.kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
assert args.splitwise_role == "mixed", "Only mixed mode supports updating cache."
try:
# restore cpu cache
logger.info("[RL] start restoring caches")
logger.debug("[RL] start restoring cpu caches")
logger.info("check_cache_status: start restoring caches")
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache(args)
logger.debug("[RL] successfully restored cpu caches")
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
time.sleep(0.1)
logger.debug("[RL] all ranks restored cpu caches")
else:
logger.debug("[RL] skip restoring cpu caches")
# restore gpu cache and set cache_ready_signal
logger.debug("[RL] start restoring gpu caches")
self._init_gpu_cache(args)
logger.debug("[RL] successfully restored gpu caches")
self._init_cpu_cache()
self._init_gpu_cache()
# update key prefix for kv cache backend
if self.storage_backend_type is not None:
# use key_prefix to distinguish cache for different version of weight in rl
version_file_path = os.path.join(args.model_path, "version.yaml")
assert os.path.exists(version_file_path), f"version.yaml not found at {version_file_path}"
self.key_prefix = get_key_prefix_from_version(version_file_path)
logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
# wait for all ranks caches to be ready
while np.sum(self.cache_ready_signal.value) != args.mp_num:
time.sleep(0.1)
logger.info("[RL] all ranks restored caches!")
self._update_key_prefix()
# resume transfer
self.resume()
# set kv_cache_status_signal
self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
self._log_memory("after restoring caches")
self._log_memory("check_cache_status: after restoring caches")
except Exception as e:
logger.error(f"[RL] failed to restore caches: {e}")
logger.error(f"check_cache_status: failed to restore caches: {e}")
time.sleep(0.1)
@@ -1407,11 +1486,11 @@ class CacheTransferManager:
self.cache_task_queue.pause_barrier.wait()
if self.rank == 0:
self.cache_task_queue.pause_barrier.reset()
logger.info("[RL] 🟠 wait for inflight transfer tasks to finish")
logger.info("pause: 🟠 wait for inflight transfer tasks to finish")
self.is_paused = True
while np.sum(self.cache_task_inflight_signal.value) != 0:
time.sleep(0.1)
logger.info("[RL] 🔴 pause transfer manager and stop do transfer tasks")
logger.info("pause: 🔴 pause transfer manager and stop do transfer tasks")
def resume(self):
if self.n_ranks > 1:
@@ -1421,7 +1500,7 @@ class CacheTransferManager:
self.is_paused = False
while np.sum(self.cache_task_inflight_signal.value) != self.n_ranks:
time.sleep(0.1)
logger.info("[RL] 🟢 resume transfer manager and start to do transfer tasks")
logger.info("resume: 🟢 resume transfer manager and start to do transfer tasks")
def _log_memory(self, context: str):
"""Log current GPU memory usage."""
@@ -1430,8 +1509,8 @@ class CacheTransferManager:
curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3)
curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3)
logger.warning(
f"GPU memory usage {context}:"
logger.info(
f"{context}: "
f"max_allocated: {max_alloc:.2f}GB "
f"max_reserved: {max_reserved:.2f}GB "
f"current_allocated: {curr_alloc:.2f}GB "
@@ -217,7 +217,7 @@ class PrefixCacheManager:
is_server=False,
num_client=tensor_parallel_size,
client_id=0,
local_data_parallel_id=self.local_data_parallel_id,
local_data_parallel_id=0,
)
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
@@ -293,7 +293,7 @@ class PrefixCacheManager:
else:
storage_arg_str = " "
if self.cache_config.swap_space or self.cache_config.kvcache_storage_backend:
if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend:
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
@@ -314,7 +314,6 @@ class PrefixCacheManager:
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --ipc_suffix {ipc_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
@@ -353,9 +352,8 @@ class PrefixCacheManager:
# Start additional threads
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
threading.Thread(target=self.recv_data_transfer_result, daemon=True).start()
if cache_config.enable_prefix_caching:
if cache_config.enable_prefix_caching and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
threading.Thread(target=self.clear_prefix_cache, daemon=True).start()
all_cache_processes = cache_messager_processes + cache_manager_processes
@@ -687,6 +687,9 @@ class ParallelConfig:
if self.shutdown_comm_group_if_worker_idle is None:
self.shutdown_comm_group_if_worker_idle = not self.use_ep
if self.shutdown_comm_group_if_worker_idle and envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
raise RuntimeError("shutdown_comm_group_if_worker_idle cannot be True when FD_ENABLE_V1_UPDATE_WEIGHTS=1")
# pd_disaggregation
use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
@@ -39,6 +39,8 @@ import zmq
from tqdm import tqdm
import fastdeploy.metrics.trace as tracing
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
ControlRequest,
@@ -84,7 +86,7 @@ class EngineService:
Base class containing common engine functionality
"""
def __init__(self, cfg, start_queue=True, use_async_llm=False):
def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False):
"""
Initializes the LLMEngine with the provided configuration.
@@ -104,14 +106,23 @@ class EngineService:
self.is_paused = False # pause request generation
self._pause_cond = threading.Condition()
self._ctrl_worker_output_queues = []
self._ctrl_output_queues = {}
self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict)
tp_size = cfg.parallel_config.tensor_parallel_size
dp_index = cfg.parallel_config.local_data_parallel_id
for rank in range(tp_size):
for tp_rank in range(tp_size):
# create worker control response queue
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
# create cache control response queue
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port
name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}"
self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
self.scheduler = cfg.scheduler_config.scheduler()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
@@ -1296,7 +1307,7 @@ class EngineService:
worker_pid = self.request_worker_map.pop(request_id, None)
try:
self.llm_logger.info(f"START run control method {request_id}: {method}")
self.llm_logger.info(f"Start to run control method {method}: {request_id}")
handler_name = f"_control_{method}"
handler = getattr(self, handler_name, None)
@@ -1308,13 +1319,13 @@ class EngineService:
return
result = handler(control_req)
self.llm_logger.info(f"SUCCESS run control method {method}.")
self.llm_logger.info(f"Successfully run control method {method}: {request_id} {result}")
succ_result = ControlResponse(request_id, 200, "Success", result)
data = [[succ_result]] if envs.ZMQ_SEND_BATCH_DATA else [succ_result]
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
except Exception as e:
error_msg = f"Failed run control method {method}: {str(e)}"
error_msg = f"Failed to run control method {method}: {request_id} {str(e)}"
self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
error_result = ControlResponse(request_id, 500, error_msg)
data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
@@ -1338,12 +1349,15 @@ class EngineService:
if self.cfg.scheduler_config.name != "local":
raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")
self.llm_logger.info("Start to pause request generation.")
with self._pause_cond:
if self.is_paused:
self.llm_logger.info("Pause Request Generation: already paused.")
self.llm_logger.info("Engine is already paused, no need to pause again.")
return
self.is_paused = True
self.llm_logger.info("Start Abort Running Requests")
self.llm_logger.info("Abort running requests.")
self.resource_manager.log_status()
# preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue
@@ -1354,7 +1368,7 @@ class EngineService:
if count >= timeout * 1000:
break
if count >= timeout * 1000:
error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged"
error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may have hung!"
self.llm_logger.error(error_msg)
raise Exception(error_msg)
running_reqs = self.resource_manager.preempted_all()
@@ -1369,12 +1383,22 @@ class EngineService:
# abort inflight requests to user
inflight_requests = self.scheduler.get_inflight_requests()
self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests")
self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).")
for req in inflight_requests:
self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
self._send_error_response(req.request_id, "Request is aborted since engine is paused.")
self.scheduler.reset()
# pause cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to pause cache transfer.")
pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause")
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"]))
self.llm_logger.info("Successfully paused cache transfer.")
self.resource_manager.cache_manager.reset()
self.llm_logger.info("Successfully paused request generation.")
return None
def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
@@ -1386,14 +1410,24 @@ class EngineService:
Args:
control_request: Control request object containing resume operation information
"""
self.llm_logger.info("START Resume Request Generation")
self.llm_logger.info("Start to resume request generation.")
with self._pause_cond:
if not self.is_paused:
self.llm_logger.info("Resume Request Generation: not paused.")
self.llm_logger.info("Engine is not paused, no need to resume.")
return None
self.is_paused = False
self._pause_cond.notify_all()
self.llm_logger.info("END Resume Request Generation")
# resume cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to resume cache transfer.")
resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume")
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"]))
self.llm_logger.info("Successfully resumed cache transfer.")
self.llm_logger.info("Successfully resumed request generation.")
return None
def _control_is_paused(self, control_request: ControlRequest) -> bool:
@@ -1447,49 +1481,218 @@ class EngineService:
return responses
async def _wait_all_control_responses(self, request_id: str, timeout: int):
"""Wait for control responses from all workers with a global timeout.
This method concurrently waits for responses from all control workers
and enforces an overall timeout to avoid leaking pending tasks.
def _parse_tags(self, control_request: ControlRequest):
"""
timeout_ms = timeout * 1000
# Create one get() coroutine per worker output queue
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]
try:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True),
timeout=timeout,
Parse tags from control request.
"""
allowed_tags = ["weight", "kv_cache"]
tags = control_request.args.get("tags", None)
if tags is None:
tags = ",".join(allowed_tags)
control_request.args["tags"] = tags
self.llm_logger.info(
f"Detected empty tags for request {control_request.request_id}, defaulting to: {tags}"
)
except asyncio.TimeoutError:
# Keep the error message consistent with previous behavior
raise Exception("Worker Update Weights Timeouted after 600s")
elif isinstance(tags, list):
tags = ",".join(tags)
for tag in tags.split(","):
if tag not in allowed_tags:
raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}")
return tags
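The normalization in `_parse_tags` accepts `None` (default to all allowed tags) or a list (joined into a comma-separated string), and rejects anything outside the allow-list. A minimal usage sketch of the same validation logic, extracted for illustration:

```python
ALLOWED_TAGS = ["weight", "kv_cache"]

def parse_tags(tags) -> str:
    # None -> all allowed tags; list -> comma-joined string; then validate.
    if tags is None:
        tags = ",".join(ALLOWED_TAGS)
    elif isinstance(tags, list):
        tags = ",".join(tags)
    for tag in tags.split(","):
        if tag not in ALLOWED_TAGS:
            raise ValueError(f"Unsupported tag [{tag}] in [{tags}]")
    return tags

assert parse_tags(None) == "weight,kv_cache"
assert parse_tags(["weight"]) == "weight"
```

Validating up front means a bad tag fails the whole control request before anything is dispatched to workers or the cache transfer manager.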
def _control_sleep(self, control_request: ControlRequest):
"""
Offload GPU memory for selected components, e.g. weights and KV cache.
Args:
control_request: Control request object containing parameters for offloading memory
tags: list of tags to offload, supported values: ["weight", "kv_cache"]
TODO: support different levels of offloading, to provide options for releasing memory
permanently or merely offloading it to CPU memory.
"""
# Args check
tags = self._parse_tags(control_request)
control_request.args["tags"] = tags
# Make sure llm engine is paused.
self.llm_logger.warning(
"Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. "
"Please explicitly request to /pause the engine before /sleep."
)
self._control_pause(None)
# Determine which executors are needed for the sleep command
executors = set()
if "weight" in tags:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
executors.add("cache_transfer")
if self.cfg.cache_config.enable_prefix_caching:
self.resource_manager.cache_manager.reset()
# Dispatch sleep request to executors
self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}")
self._dispatch_control_request(control_request, executors)
return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors))
def _control_wakeup(self, control_request: ControlRequest):
"""
Reload previously offloaded GPU memory for selected components, e.g. weights and KV cache.
Args:
control_request: Control request object containing parameters for reloading memory
tags: list of tags to reload, supported values: ["weight", "kv_cache"]
"""
# Args check
tags = self._parse_tags(control_request)
control_request.args["tags"] = tags
# Determine which executors are needed for the wakeup command
executors = set()
if "weight" in tags:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
executors.add("cache_transfer")
# Dispatch wakeup request to executors
self.llm_logger.info(f"Dispatch wakeup request to executors: {list(executors)}")
self._dispatch_control_request(control_request, executors)
result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 300, executors=executors))
# Resume the engine after wakeup
self._control_resume(None)
return result
def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]):
"""
Dispatch control requests to workers, cache managers or engine itself.
Args:
control_request: ControlRequest
executors: List of executor names, e.g. ["worker", "cache_transfer"]
"""
if "worker" in executors:
self.engine_worker_queue.put_tasks(([control_request], 1))
if "cache_transfer" in executors:
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request))
return
async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None):
"""Wait for matching control responses from the selected executor queues.
This helper selects the control-response queues that belong to the requested
executors, then waits for all of them concurrently. Each queue gets a local
waiter that keeps reading until it sees the target request ID and stashes stale
responses into that queue's mailbox.
Args:
request_id: The control request ID that all returned responses must match.
timeout: Global timeout budget in seconds for the full multi-queue wait.
executors: Executor groups to wait for, for example `["worker"]` or
`["worker", "cache_transfer"]`. If `None`, waits for all control
response queues.
Returns:
A list of `response.result` values collected from all matched
`ControlResponse` objects. If no queue is selected, returns `None`.
Raises:
Exception: If the overall wait times out, or if any queue reports a non-200
control response or fails while waiting.
"""
def select_control_queues(executors: List[str] = None):
"""Select control response queues by executors."""
if executors is None:
return self._ctrl_output_queues
else:
queues = {}
for k, v in self._ctrl_output_queues.items():
if "w2e" in k and "worker" in executors:
queues[k] = v
elif "c2e" in k and "cache_transfer" in executors:
queues[k] = v
return queues
async def wait_one(queue_name: str, queue):
"""Wait until one queue returns a response for the current request_id."""
mailbox = self._ctrl_response_mailboxes[queue_name]
# Reuse a previously stashed response for this request before touching FMQ again.
cached_response = mailbox.pop(request_id, None)
if cached_response is not None:
self.llm_logger.info(f"Returning cached control response from {queue_name}.")
return cached_response
while True:
msg = await queue.get()
# Return if the response matches the control request
response: ControlResponse = msg.payload
if response.request_id == request_id:
self.llm_logger.info(f"Returning new control response from {queue_name}.")
return response
# Stash late responses from other control requests so they do not consume the
# current request's only read chance on this queue.
mailbox[response.request_id] = response
self.llm_logger.info(
f"Stashed old control response from {queue_name}. "
f"Expected request {request_id}, got request {response.request_id}"
)
# Select only the control response queues that belong to the requested executors.
queues = select_control_queues(executors)
if not queues:
self.llm_logger.info(f"No queues to wait for, executors: {executors}")
return
self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}")
# Each queue gets its own waiter, which will stash stale responses until it finds the
# target request ID for this control request.
tasks = {name: asyncio.create_task(wait_one(name, queue)) for name, queue in queues.items()}
done, pending = await asyncio.wait(tasks.values(), timeout=timeout)
if pending:
pending_names = [name for name, task in tasks.items() if task in pending]
done_names = [name for name, task in tasks.items() if task in done]
self.llm_logger.error(
f"Control request {request_id} execution timeout. "
f"Pending queues: {pending_names}, completed queues: {done_names}."
)
# Stop unfinished queue waiters so they do not outlive the control request.
for task in pending:
task.cancel()
await asyncio.gather(*pending, return_exceptions=True)
raise Exception(f"Control request {request_id} timed out after {timeout}s")
# Collect the results from all completed queues.
responses = []
for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
if isinstance(msg, Exception):
self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
raise Exception(f"Call Worker error: {repr(msg)}")
if msg is None:
# Preserve original semantics when no message is received
raise Exception("Worker Update Weights Timeouted after 600s")
response: ControlResponse = msg.payload
if response.request_id != request_id:
self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
continue
for name, task in tasks.items():
try:
response = task.result()
except Exception as e:
self.llm_logger.error(f"Waiting for control response from {name} failed: {repr(e)}")
raise
if response.error_code != 200:
self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
raise Exception(f"Call Worker error: {response.error_message}")
self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
raise Exception(f"Error response from {name}: {response.error_message}")
responses.append(response.result)
return responses
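The per-queue waiter above solves a subtle problem: a queue may still hold a stale response from an earlier, timed-out control request, and a naive `get()` would consume the current request's only read. The mailbox pattern stashes mismatched responses instead of dropping them. A self-contained asyncio sketch of that pattern (with simplified `(request_id, payload)` tuples standing in for `ControlResponse` messages):

```python
import asyncio

async def wait_one(queue: asyncio.Queue, request_id: str, mailbox: dict):
    # Reuse a previously stashed response for this request before reading again.
    if request_id in mailbox:
        return mailbox.pop(request_id)
    while True:
        rid, payload = await queue.get()
        if rid == request_id:
            return payload
        mailbox[rid] = payload  # stash stale responses from earlier requests

async def demo():
    queue = asyncio.Queue()
    mailbox = {}
    # A stale response from a timed-out earlier request arrives first.
    await queue.put(("old-req", "stale"))
    await queue.put(("req-1", "ok"))
    result = await wait_one(queue, "req-1", mailbox)
    return result, mailbox

result, mailbox = asyncio.run(demo())
```

Here `result` is `"ok"` while the stale `"old-req"` response survives in the mailbox, where a later (or retried) request could still claim it, mirroring what `_ctrl_response_mailboxes` does per control queue.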
def _call_worker(self, control_request: ControlRequest, timeout: int):
request_id = control_request.request_id
self.engine_worker_queue.put_tasks(([control_request], 1))
# Use a single asyncio.run() to concurrently wait for all worker responses.
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"]))
def _send_error_response(self, request_id, error_msg, error_code: int = 500, worker_pid=None):
self.llm_logger.error(
+24 -4
@@ -595,11 +595,14 @@ class EngineClient:
return True, ""
async def run_control_method(self, request: ControlRequest):
api_server_logger.info(f"Start Run Control Method: {request}")
api_server_logger.info(f"Received control request: {request}")
req_dict = request.to_dict()
if envs.ZMQ_SEND_BATCH_DATA:
req_dict["zmq_worker_pid"] = self.worker_pid
self.zmq_client.send_json(req_dict)
if not self.enable_mm and not envs.ENABLE_V1_DATA_PROCESSOR:
self.zmq_client.send_json(req_dict)
else:
self.zmq_client.send_pyobj(req_dict)
request_id = request.request_id
dealer, response_queue = await self.connection_manager.get_connection(request_id)
if not envs.ZMQ_SEND_BATCH_DATA:
@@ -608,12 +611,29 @@ class EngineClient:
# TODO: support user-specified timeout; the default 600s is enough for most control cases
response = await asyncio.wait_for(response_queue.get(), timeout=600)
response = ControlResponse.from_dict(response[0])
api_server_logger.info(f"End Run Control Method: {response}")
api_server_logger.info(f"Return control response: {response}")
return response
except asyncio.TimeoutError:
error_response = ControlResponse(request_id, 500, "Timeout waiting for control method response")
api_server_logger.error(f"Error Run Control Method: {error_response}")
api_server_logger.error(f"Control request timed out: {error_response}")
return error_response
except Exception as e:
import traceback
api_server_logger.error(f"Unknown error in control method: {str(e)}\n{traceback.format_exc()}")
error_response = ControlResponse(request_id, 500, str(e))
return error_response
def run_control_method_sync(self, request: ControlRequest, event_loop):
"""
Support running control methods from a synchronous caller.
NOTE: Since asyncio.Queue operations must occur in the same event loop,
this method bridges synchronous and asynchronous execution by running
the async run_control_method in the specified event loop.
"""
future = asyncio.run_coroutine_threadsafe(self.run_control_method(request), event_loop)
return future.result()
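The sync-to-async bridge in `run_control_method_sync` can be sketched with plain asyncio; `control()` below is a placeholder for the real async control call, and the loop thread stands in for the server's running event loop:

```python
import asyncio
import threading


async def control(method: str) -> dict:
    # Placeholder for an async control call (e.g. sleep/wakeup).
    await asyncio.sleep(0)
    return {"method": method, "status": 200}


def run_control_sync(method: str, loop: asyncio.AbstractEventLoop) -> dict:
    # Schedule the coroutine on the running loop from a foreign thread
    # and block until its result is available. asyncio.Queue operations
    # must happen on the loop that created them, hence the bridge.
    future = asyncio.run_coroutine_threadsafe(control(method), loop)
    return future.result(timeout=5)


loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()
try:
    print(run_control_sync("sleep", loop))  # {'method': 'sleep', 'status': 200}
finally:
    loop.call_soon_threadsafe(loop.stop)
```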
def is_workers_alive(self):
"""
+55 -4
@@ -284,6 +284,7 @@ async def lifespan(app: FastAPI):
app.state.completion_handler = completion_handler
app.state.embedding_handler = embedding_handler
app.state.reward_handler = reward_handler
app.state.event_loop = asyncio.get_running_loop()
if llm_engine is not None and not isinstance(llm_engine, AsyncLLM):
llm_engine.engine.data_processor = engine_client.data_processor
@@ -406,6 +407,44 @@ async def is_paused(request: Request) -> Response:
return control_response.to_api_json_response()
@app.post("/v1/sleep")
async def sleep(request: Request) -> Response:
request_id = f"control-{uuid.uuid4()}"
# Support both a JSON body and query parameters
if await request.body():
request_data = await request.json()
else:
# Extract query params
request_data = dict(request.query_params)
try:
control_request = ControlRequest(request_id, "sleep", request_data)
except TypeError as e:
return JSONResponse(status_code=400, content={"error": "Invalid parameter type", "message": str(e)})
control_response = await app.state.engine_client.run_control_method(control_request)
return control_response.to_api_json_response()
@app.post("/v1/wakeup")
async def wakeup(request: Request) -> Response:
request_id = f"control-{uuid.uuid4()}"
# Support both a JSON body and query parameters
if await request.body():
request_data = await request.json()
else:
# Extract query params
request_data = dict(request.query_params)
try:
control_request = ControlRequest(request_id, "wakeup", request_data)
except TypeError as e:
return JSONResponse(status_code=400, content={"error": "Invalid parameter type", "message": str(e)})
control_response = await app.state.engine_client.run_control_method(control_request)
return control_response.to_api_json_response()
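Both endpoints repeat the same body-or-query-params extraction; a small helper capturing that logic, exercised against a stand-in for Starlette's `Request` (`FakeRequest` and `extract_params` are illustrative, not part of FastDeploy):

```python
import asyncio
import json


class FakeRequest:
    """Minimal stand-in for Starlette's Request for this sketch."""

    def __init__(self, body=b"", query_params=None):
        self._body = body
        self.query_params = query_params or {}

    async def body(self) -> bytes:
        return self._body

    async def json(self):
        return json.loads(self._body)


async def extract_params(request) -> dict:
    # Control endpoints accept either a JSON body or query parameters.
    if await request.body():
        return await request.json()
    return dict(request.query_params)


print(asyncio.run(extract_params(FakeRequest(b'{"tags": "weight"}'))))  # {'tags': 'weight'}
print(asyncio.run(extract_params(FakeRequest(query_params={"tags": "kv_cache"}))))  # {'tags': 'kv_cache'}
```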
@app.post("/v1/update_weights")
async def update_weights(request: Request) -> Response:
request_id = f"control-{uuid.uuid4()}"
@@ -606,8 +645,14 @@ def update_model_weight(request: Request) -> Response:
update model weight
"""
if app.state.dynamic_load_weight:
status_code, msg = app.state.engine_client.update_model_weight()
return JSONResponse(content=msg, status_code=status_code)
if envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
request_id = f"control-{uuid.uuid4()}"
control_request = ControlRequest(request_id, "wakeup")
control_response = app.state.engine_client.run_control_method_sync(control_request, app.state.event_loop)
return control_response.to_api_json_response()
else:
status_code, msg = app.state.engine_client.update_model_weight()
return JSONResponse(content=msg, status_code=status_code)
else:
return JSONResponse(content={"error": "Dynamic Load Weight Disabled."}, status_code=404)
@@ -619,8 +664,14 @@ def clear_load_weight(request: Request) -> Response:
clear model weight
"""
if app.state.dynamic_load_weight:
status_code, msg = app.state.engine_client.clear_load_weight()
return JSONResponse(content=msg, status_code=status_code)
if envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
request_id = f"control-{uuid.uuid4()}"
control_request = ControlRequest(request_id, "sleep")
control_response = app.state.engine_client.run_control_method_sync(control_request, app.state.event_loop)
return control_response.to_api_json_response()
else:
status_code, msg = app.state.engine_client.clear_load_weight()
return JSONResponse(content=msg, status_code=status_code)
else:
return JSONResponse(content={"error": "Dynamic Load Weight Disabled."}, status_code=404)
@@ -122,12 +122,21 @@ class ChatResponseProcessor:
else:
self._audio_buffer[req_id] = [token_ids]
else:
yield self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
)
if self._is_async_processor:
response = await self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
)
else:
response = self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
include_stop_str_in_output=include_stop_str_in_output,
request=request,
)
yield response
elif stream:
decode_type = request_output["outputs"].get("decode_type", 0)
token_ids = request_output["outputs"]["token_ids"]
+6 -1
@@ -184,9 +184,14 @@ class DealerConnectionManager:
self.request_num[request_id] -= 1
if self.request_num[request_id] == 0:
self._update_load(conn_index, -1)
else:
api_server_logger.warning(
f"request_id {request_id} not in request_map, available keys: {list(self.request_map.keys())}"
)
except Exception as e:
api_server_logger.error(f"Listener error: {str(e)}")
api_server_logger.error(f"Listener error: {str(e)}\n{traceback.format_exc()}")
break
api_server_logger.info(f"Listener loop ended for conn_index {conn_index}")
def _update_load(self, conn_index, delta):
"""Update connection load and maintain the heap"""
+5
@@ -249,6 +249,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_MOE_PROB_IN_ADVANCE": lambda: bool(int(os.getenv("FD_MOE_PROB_IN_ADVANCE", "0"))),
# Whether to use batch send data in zmq
"ZMQ_SEND_BATCH_DATA": lambda: int(os.getenv("ZMQ_SEND_BATCH_DATA", "1")),
# Whether to enable v1 weight updating, which uses ZMQ/EngineWorkerQueue/EngineCacheQueue/FMQs
# to pass control requests and responses.
# When v1 is enabled, the legacy /clear_load_weight and /update_model_weight endpoints
# adopt this new communication pattern.
"FD_ENABLE_V1_UPDATE_WEIGHTS": lambda: bool(int(os.getenv("FD_ENABLE_V1_UPDATE_WEIGHTS", "0"))),
}
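The flag follows the same `bool(int(os.getenv(...)))` idiom as the neighboring environment variables; a minimal sketch of why the `int()` step matters (the helper name is illustrative):

```python
import os


def env_flag(name: str, default: str = "0") -> bool:
    # int() first: "0" -> False, "1" -> True, and a value like "true"
    # raises ValueError loudly instead of being silently truthy.
    return bool(int(os.getenv(name, default)))


os.environ["FD_ENABLE_V1_UPDATE_WEIGHTS"] = "1"
print(env_flag("FD_ENABLE_V1_UPDATE_WEIGHTS"))  # True
```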
@@ -58,6 +58,9 @@ class EngineCacheQueue:
client_id: Unique identifier for client instances
local_data_parallel_size: data parallel size
local_data_parallel_id: local data parallel id
TODO(liyonghua): Remove multi-DP initialization. Each DP will have its own cache queue.
"""
self.address: Tuple[str, int] = address
self.authkey: bytes = authkey
@@ -87,6 +90,7 @@ class EngineCacheQueue:
]
# Initialize barriers
self.barrier = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
self.barrier0_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
self.barrier1_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
self.barrier2_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
@@ -142,6 +146,7 @@ class EngineCacheQueue:
callable=lambda idx: self.transfer_task_done_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register("get_barrier", callable=lambda idx: self.barrier[idx])
QueueManager.register("get_barrier0", callable=lambda idx: self.barrier0_init[idx])
QueueManager.register("get_barrier1", callable=lambda idx: self.barrier1_init[idx])
QueueManager.register("get_barrier2", callable=lambda idx: self.barrier2_init[idx])
@@ -191,6 +196,7 @@ class EngineCacheQueue:
QueueManager.register("get_cache_sync_value")
QueueManager.register("get_transfer_task_lock")
QueueManager.register("get_transfer_task_done_lock")
QueueManager.register("get_barrier")
QueueManager.register("get_barrier0")
QueueManager.register("get_barrier1")
QueueManager.register("get_barrier2")
@@ -215,6 +221,7 @@ class EngineCacheQueue:
self.task_done_lock = self.manager.get_transfer_task_done_lock(self.local_data_parallel_id)
# Get barrier proxies
self.barrier = self.manager.get_barrier(self.local_data_parallel_id)
self.barrier0 = self.manager.get_barrier0(self.local_data_parallel_id)
self.barrier1 = self.manager.get_barrier1(self.local_data_parallel_id)
self.barrier2 = self.manager.get_barrier2(self.local_data_parallel_id)
@@ -264,7 +271,12 @@ class EngineCacheQueue:
def put_transfer_task(self, item):
"""
put swap task
Enqueue a cache transfer task (cpu/gpu swap task, read/write storage task)
or a control task (cache clearing/restoring).
The queue is shared by multiple clients. A task can be enqueued only after
the previous task has been read by all clients.
`task_sync_value` is used as a bitmask to track per-client read status.
"""
self.task_lock.acquire()
if 0 < self.task_sync_value.get() < self.total_num:
@@ -279,7 +291,11 @@ class EngineCacheQueue:
def get_transfer_task(self):
"""
get swap task
Get the current cache transfer task (cpu/gpu swap task, read/write storage task)
or a control signal (cache clearing/restoring) from the cache task queue.
Each client reads the same task once. The task is removed from the queue
only after all clients have read it, tracked by `task_sync_value`.
"""
data = None
read_finish = False
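The two docstrings above describe one-slot semantics: a task may be enqueued only after the previous one has been read by every client, with `task_sync_value` acting as a per-client read bitmask. A single-process sketch of just that bookkeeping (no locks; the class name is illustrative):

```python
class SharedTaskSlot:
    """One-slot task queue that each of n clients reads exactly once."""

    def __init__(self, num_clients: int):
        self.num_clients = num_clients
        self.full_mask = (1 << num_clients) - 1
        self.sync_value = 0  # bit i set => client i has read the current task
        self.task = None

    def put(self, task) -> bool:
        # A new task is rejected while some (but not all) clients
        # are still reading the previous one.
        if 0 < self.sync_value < self.full_mask:
            return False
        self.task, self.sync_value = task, 0
        return True

    def get(self, client_id: int):
        bit = 1 << client_id
        if self.task is None or self.sync_value & bit:
            return None  # nothing new for this client
        task = self.task
        self.sync_value |= bit
        if self.sync_value == self.full_mask:
            # Last reader: retire the task and reset the mask.
            self.task, self.sync_value = None, 0
        return task
```

With two clients, a `put` issued after only one of them has read returns `False`, which mirrors the "enqueue only after all clients have read" rule.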
@@ -283,6 +283,9 @@ class ZmqServerBase(ABC):
has_result_handle = False
with self.mutex:
if req_id not in self.req_dict:
llm_logger.warning(
f"req_id '{req_id}' not in req_dict, caching response. Available req_ids: {list(self.req_dict.keys())}"
)
self.cached_results[req_id].append(data)
else:
has_result_handle = True
+57
@@ -201,6 +201,39 @@ class DynamicWeightManager:
# step5: recapture cuda_graph
# step6: update weight status signal
def restart_communication_group(self):
if not self.first_load:
start_time = time.perf_counter()
paddle.distributed.restart_process_group()
paddle.distributed.restart_process_group(self.parallel_config.tp_group)
if self.parallel_config.enable_expert_parallel:
paddle.distributed.restart_process_group(self.parallel_config.ep_group)
logger.info(f"finish restarting communication groups! time cost: {time.perf_counter()-start_time:.3f}s")
def recreate_deepep_buffer(self):
if not self.first_load:
start_time = time.perf_counter()
from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager
DeepEPBufferManager.recreate_buffer()
# ep barrier
paddle.distributed.barrier(self.parallel_config.ep_group)
logger.info(f"finish recreating deepep buffer! time cost: {time.perf_counter()-start_time:.3f}s")
def reload_model_weights(self):
if not self.first_load:
start_time = time.perf_counter()
strategy_handlers = {
"ipc_snapshot": self._update_ipc_snapshot,
"ipc": self._update_ipc,
}
if handler := strategy_handlers.get(self.load_config.load_strategy):
handler()
else:
raise ValueError(f"Unsupported strategy: {self.load_config.load_strategy}")
logger.info(f"finish reload model weights! time cost: {time.perf_counter()-start_time:.3f}s")
def _update_ipc_snapshot(self):
"""Update using IPC snapshot strategy for elastic recovery.
@@ -329,6 +362,30 @@ class DynamicWeightManager:
paddle.distributed.shutdown_process_group()
self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
def clear_deepep_buffer(self):
start_time = time.perf_counter()
from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager
DeepEPBufferManager.clear_buffer()
logger.info(f"finish clearing deepep buffer! time cost: {time.perf_counter()-start_time:.3f}s")
def clear_model_weight(self):
start_time = time.perf_counter()
for model in self.model_list:
for param in model.state_dict().values():
param._clear_data()
logger.info(f"finish clearing model weight! time cost: {time.perf_counter()-start_time:.3f}s")
def clear_communication_group(self):
start_time = time.perf_counter()
if self.parallel_config.enable_expert_parallel:
paddle.distributed.barrier(self.parallel_config.ep_group)
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.barrier(self.parallel_config.tp_group)
paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
logger.info(f"finish clearing communication groups! time cost: {time.perf_counter()-start_time:.3f}s")
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
"""Update model parameters from given state dictionary."""
if len(state_dict) == 0:
+15 -3
@@ -702,12 +702,15 @@ def singleton(cls):
return get_instance
def print_gpu_memory_use(gpu_id: int, title: str) -> None:
def print_gpu_memory_use(title: str, gpu_id: int, device_id: int | None = None) -> None:
"""Print memory usage"""
import pynvml
if device_id is None:
device_id = gpu_id
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
pynvml.nvmlShutdown()
@@ -724,7 +727,7 @@ def print_gpu_memory_use(gpu_id: int, title: str) -> None:
f"\n\tPaddle max memory Reserved(GiB): {paddle_max_reserved / 1024.0 / 1024.0 / 1024.0}",
f"\n\tPaddle max memory Allocated(GiB): {paddle_max_allocated / 1024.0 / 1024.0 / 1024.0}",
f"\n\tPaddle memory Reserved(GiB): {paddle_reserved / 1024.0 / 1024.0 / 1024.0}",
f"\n\tPaddle memory Allocated(GiB): {paddle_allocated / 1024.0 / 1024.0 / 1024.0}",
f"\n\tPaddle memory Allocated(GiB): {paddle_allocated / 1024.0 / 1024.0 / 1024.0}\n",
)
@@ -1326,3 +1329,12 @@ else:
register_op = do_nothing
register_custom_python_op = register_op
def all_gather_values(value: int | float | bool, group: paddle.distributed.communication.group.Group) -> list:
    """All-gather a scalar across `group`, casting through float32 and back to its original Python type."""
_type = type(value)
_local = paddle.to_tensor([value], dtype="float32")
_global = [paddle.zeros_like(_local) for _ in range(group.world_size)]
paddle.distributed.all_gather(_global, _local, group)
_results = [_type(t.item()) for t in _global]
return _results
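`all_gather_values` ships each rank's scalar through a float32 tensor and casts it back, and the worker loop later uses it to gate EP control: a cached control request runs only when every rank has one pending. A paddle-free sketch of those two semantics (function names are illustrative):

```python
def typed_round_trip(value):
    # all_gather_values casts each scalar through float32 and back,
    # so bool/int/float all survive with their original Python type.
    return type(value)(float(value))


def ranks_ready(pending_counts) -> bool:
    # EP gate from the worker loop: run a cached control request only
    # when every expert-parallel rank reports at least one pending.
    return all(p > 0 for p in pending_counts)


print([typed_round_trip(v) for v in (True, 0, 3.5)])  # [True, 0, 3.5]
print(ranks_ready([1, 2, 1]), ranks_ready([1, 0, 1]))  # True False
```

The gate keeps the collective path safe: control handlers such as deepep-buffer teardown involve group-wide barriers, so no EP rank may enter them alone.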
+85
@@ -54,6 +54,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.platforms import current_platform
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.utils import print_gpu_memory_use
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
if current_platform.is_iluvatar():
@@ -142,6 +143,9 @@ class GPUModelRunner(ModelRunnerBase):
self.cache_kvs_map: dict = {}
self.exist_prefill_flag = False
self.is_kvcache_sleeping = False
self.is_weight_sleeping = False
if self.speculative_decoding:
self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory()
self.output_token_num_event = paddle.device.cuda.Event()
@@ -288,6 +292,10 @@ class GPUModelRunner(ModelRunnerBase):
"""
return self.exist_prefill_flag
@property
def is_sleeping(self):
return self.is_weight_sleeping or self.is_kvcache_sleeping
def exist_decode(self):
"""
check whether a decode stage exists
@@ -2673,6 +2681,83 @@ class GPUModelRunner(ModelRunnerBase):
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)
def sleep(self, tags):
logger.info(f">>> start offloading memory, tags: {tags}")
start_time = time.perf_counter()
# Clear weights, deepep_buffer, cudagraph, etc.
if "weight" in tags.split(","):
if self.is_weight_sleeping:
logger.info("GPU model runner's weight is already sleeping, no need to sleep again!")
return
if self.use_cudagraph:
self.model.clear_grpah_opt_backend()
if self.fd_config.parallel_config.enable_expert_parallel:
self.dynamic_weight_manager.clear_deepep_buffer()
self.dynamic_weight_manager.clear_model_weight()
if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
self.dynamic_weight_manager.clear_communication_group()
self.is_weight_sleeping = True
# Clear KV cache
if "kv_cache" in tags.split(","):
if self.is_kvcache_sleeping:
logger.info("GPU model runner's kv cache is already sleeping, no need to sleep again!")
return
if self.spec_method == SpecMethod.MTP:
self.proposer.clear_mtp_cache()
self.clear_cache()
self.is_kvcache_sleeping = True
paddle.device.cuda.empty_cache()
logger.info(f"<<< finish offloading memory! time cost: {time.perf_counter()-start_time:.3f}s")
print_gpu_memory_use(f"After offloading memory [{tags}]", self.local_rank, self.device_id)
def wakeup(self, tags):
if tags == "weight" and self.use_cudagraph and self.is_kvcache_sleeping:
raise RuntimeError(
"Waking up [weight] alone is not supported when CUDA Graph is enabled, "
"as recapturing the graph requires the KV cache to be rebuilt first. "
"Please wake up [kv_cache] first."
)
logger.info(f">>> start reloading memory, tags: {tags}")
start_time = time.perf_counter()
# Reset share_inputs to restore tensor shapes and values
if self.spec_method == SpecMethod.MTP:
self.proposer.model_inputs.reset_model_inputs()
self.share_inputs.reset_share_inputs()
# Reinitialize KV cache
if "kv_cache" in tags.split(","):
if not self.is_kvcache_sleeping:
logger.info("GPU model runner's kv cache is not sleeping, no need to wakeup!")
return
if self.spec_method == SpecMethod.MTP:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
self.initialize_kv_cache()
self.is_kvcache_sleeping = False
# Reload weights, deepep_buffer, cudagraph, etc.
if "weight" in tags.split(","):
if not self.is_weight_sleeping:
logger.info("GPU model runner's weight is not sleeping, no need to wakeup!")
return
if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
self.dynamic_weight_manager.restart_communication_group()
if self.fd_config.parallel_config.enable_expert_parallel:
self.dynamic_weight_manager.recreate_deepep_buffer()
self.dynamic_weight_manager.reload_model_weights()
if self.use_cudagraph:
self.capture_model()
self.is_weight_sleeping = False
logger.info(f"<<< finish reloading memory! time cost: {time.perf_counter()-start_time:.3f}s")
print_gpu_memory_use(f"After reloading memory [{tags}]", self.local_rank, self.device_id)
def padding_cudagraph_inputs(self) -> None:
"""
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
+8
@@ -196,6 +196,14 @@ class GpuWorker(WorkerBase):
"""update weights in place"""
return self.model_runner.update_weights(version, rsync_config)
def sleep(self, **kwargs) -> None:
"""Offload memory from GPU"""
return self.model_runner.sleep(**kwargs)
def wakeup(self, **kwargs) -> None:
"""Reload memory into GPU"""
return self.model_runner.wakeup(**kwargs)
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
+43 -21
@@ -69,7 +69,7 @@ from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.model_executor.utils import v1_loader_support
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import get_logger, optional_type
from fastdeploy.utils import all_gather_values, get_logger, optional_type
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("worker_process", "worker_process.log")
@@ -172,6 +172,7 @@ class PaddleDisWorkerProc:
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
self.enable_overlap_schedule = self.scheduler_config.enable_overlap_schedule
self.cached_control_reqs = []
def init_control(self):
engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
@@ -482,7 +483,7 @@ class PaddleDisWorkerProc:
# run eplb
self._run_eplb(tp_rank)
if self.fd_config.load_config.dynamic_load_weight:
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.ranks > 1:
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
@@ -504,7 +505,7 @@ class PaddleDisWorkerProc:
# Synchronize the signal set by tp_rank0 visiable to other workers
self._tp_barrier_wait() if tp_size > 1 else None
if self.fd_config.load_config.dynamic_load_weight:
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
if self.ranks > 1:
paddle.distributed.barrier()
if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
@@ -581,22 +582,34 @@ class PaddleDisWorkerProc:
if len(control_reqs) > 0:
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
for control_req in control_reqs:
self.run_control_method(control_req)
self._tp_barrier_wait() if tp_size > 1 else None
if self.parallel_config.use_ep:
self.cached_control_reqs.append(control_req)
logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
else:
self.run_control_method(control_req)
self._tp_barrier_wait() if tp_size > 1 else None
# Count prefill requests in current batch
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
num_scheduled_requests = len(req_dicts)
scheduled_request_ids = [req.request_id for req in req_dicts]
logger.info(
f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
f"max_occupied_batch_index: {max_occupied_batch_index}, "
f"num_scheduled_requests: {num_scheduled_requests}, "
f"scheduled_request_ids: {scheduled_request_ids}"
)
if len(req_dicts) > 0:
# Count prefill requests in current batch
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
num_scheduled_requests = len(req_dicts)
scheduled_request_ids = [req.request_id for req in req_dicts]
logger.info(
f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
f"max_occupied_batch_index: {max_occupied_batch_index}, "
f"num_scheduled_requests: {num_scheduled_requests}, "
f"scheduled_request_ids: {scheduled_request_ids}"
)
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
# Let the ep group run control methods synchronously
if self.parallel_config.use_ep:
pendings = all_gather_values(len(self.cached_control_reqs), self.parallel_config.ep_group)
if all([p > 0 for p in pendings]):
logger.info(f"Rank: {self.local_rank} Detected all ep ranks have pending control tasks.")
self.run_control_method(self.cached_control_reqs.pop(0))
if (
not self.parallel_config.use_ep
@@ -607,6 +620,15 @@ class PaddleDisWorkerProc:
time.sleep(0.001)
continue
# Check if worker is paused (V1 update weights flow)
if (
self.fd_config.load_config.dynamic_load_weight
and hasattr(self.worker.model_runner, "is_sleeping")
and self.worker.model_runner.is_sleeping
):
self._tp_barrier_wait() if tp_size > 1 else None
continue
# Execute model to generate token. The generated token will be written to the buffer.
# These generated tokens can be obtained through get_output op.
start_execute_time = time.time()
@@ -737,14 +759,14 @@ class PaddleDisWorkerProc:
self.loaded_model_signal.value[0] = 1
def run_control_method(self, control_request: ControlRequest) -> None:
logger.info(f"Start run control request: {control_request}")
logger.info(f"Rank: {self.local_rank} Start to run control request: {control_request}")
request_id = control_request.request_id
method = control_request.method
kwargs = control_request.args
handler = getattr(self.worker, method, None)
if handler is None or not callable(handler):
error_msg = f"Rank-{self.local_rank}: Unknown control method {method}"
error_msg = f"Rank: {self.local_rank} Unknown control method {method}"
error_result = ControlResponse(request_id, 400, error_msg)
asyncio.run(self._ctrl_output.put(error_result))
return
@@ -753,11 +775,11 @@ class PaddleDisWorkerProc:
result = handler(**kwargs)
succ_result = ControlResponse(request_id, 200, "Success", result)
logger.info(
f"Rank-{self.local_rank} Success run control request: {control_request}, response: {succ_result}"
f"Rank: {self.local_rank} Successfully run control request: {control_request}, response: {succ_result}"
)
asyncio.run(self._ctrl_output.put(succ_result, shm_threshold=100 * 1024 * 1024))
except Exception as e:
error_msg = f"Rank-{self.local_rank} Failed run control method {method}: {str(e)}"
error_msg = f"Rank: {self.local_rank} Failed to run control method {method}: {str(e)}"
logger.error(f"{error_msg}\n{traceback.format_exc()}")
error_result = ControlResponse(request_id, 500, error_msg)
asyncio.run(self._ctrl_output.put(error_result))
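The worker-side dispatch above (resolve the method on the worker via `getattr`, map unknown methods to 400 and handler exceptions to 500) is sketched below with toy stand-ins for `ControlResponse` and the worker:

```python
class ControlResponse:
    """Toy stand-in for the real ControlResponse in this sketch."""

    def __init__(self, request_id, code, message, result=None):
        self.request_id = request_id
        self.code = code
        self.message = message
        self.result = result


def run_control(worker, request_id: str, method: str, kwargs: dict) -> ControlResponse:
    # Resolve the control method dynamically on the worker object.
    handler = getattr(worker, method, None)
    if not callable(handler):
        return ControlResponse(request_id, 400, f"Unknown control method {method}")
    try:
        return ControlResponse(request_id, 200, "Success", handler(**kwargs))
    except Exception as e:
        # Any handler failure is reported back, never raised to the loop.
        return ControlResponse(request_id, 500, f"Failed to run {method}: {e}")


class FakeWorker:
    def sleep(self, tags="weight"):
        return f"slept:{tags}"


print(run_control(FakeWorker(), "r1", "sleep", {"tags": "kv_cache"}).result)  # slept:kv_cache
print(run_control(FakeWorker(), "r2", "dance", {}).code)  # 400
```

Keeping the 400/500 split lets the caller distinguish a bad request (unknown method) from a genuine worker failure, exactly as the HTTP-style codes in the diff do.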
@@ -106,7 +106,11 @@ class TestCacheTransferManager(unittest.TestCase):
# --------------------------
# mock IPCSignal
# --------------------------
patcher2 = patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=MagicMock())
class DummyIPCSignal:
def __init__(self, name, array, dtype, suffix, create=False):
self.value = array
patcher2 = patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=DummyIPCSignal)
patcher2.start()
self.addCleanup(patcher2.stop)
@@ -122,8 +126,8 @@ class TestCacheTransferManager(unittest.TestCase):
# --------------------------
self._orig_init_cpu_cache = CacheTransferManager._init_cpu_cache
self._orig_init_gpu_cache = CacheTransferManager._init_gpu_cache
patcher3 = patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None)
patcher4 = patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None)
patcher3 = patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None)
patcher4 = patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None)
patcher3.start()
patcher4.start()
self.addCleanup(patcher3.stop)
@@ -193,8 +197,8 @@ class TestCacheTransferManager(unittest.TestCase):
kvcache_storage_backend = "unknown"
with (
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
patch("fastdeploy.cache_manager.cache_transfer_manager.console_logger") as mock_console,
):
with self.assertRaises(NotImplementedError):
@@ -209,8 +213,8 @@ class TestCacheTransferManager(unittest.TestCase):
kvcache_storage_backend = "file"
with (
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
):
with self.assertRaises(ValueError):
CacheTransferManager(LocalArgs())
@@ -221,8 +225,8 @@ class TestCacheTransferManager(unittest.TestCase):
version_path = os.path.join(tmpdir, "version.yaml")
with open(version_path, "w", encoding="utf-8") as handle:
handle.write("version: RL-STEP03-20250101-uuid\n")
args.model_path = tmpdir
args.kvcache_storage_backend = None
self.manager.model_path = tmpdir
self.manager.kvcache_storage_backend = None
self.manager._init_storage(args)
self.assertEqual(self.manager.key_prefix, "RL-STEP03")
@@ -465,23 +469,21 @@ class TestCacheTransferManager(unittest.TestCase):
def __init__(self):
self.value = [0]
args = Args()
args.num_cpu_blocks = 0
self.manager.num_cpu_blocks = 0
self.manager.swap_space_ready_signal = DummySignal()
self._orig_init_cpu_cache(self.manager, args)
self._orig_init_cpu_cache(self.manager)
self.assertEqual(self.manager.swap_space_ready_signal.value[0], 1)
self.assertEqual(self.manager.swap_space_ready_signal.value[0], 0)
def test_init_cpu_cache_allocates_block_wise_fp8(self):
class DummySignal:
def __init__(self):
self.value = [0]
args = Args()
args.num_cpu_blocks = 2
args.cache_dtype = "block_wise_fp8"
args.value_cache_shape = "2,1,1,1"
self.manager.num_cpu_blocks = 2
self.manager.cache_dtype = "block_wise_fp8"
self.manager.has_cache_scale = True
self.manager.swap_space_ready_signal = DummySignal()
self.manager.value_cache_shape = [2, 1, 1, 1]
@@ -492,7 +494,7 @@ class TestCacheTransferManager(unittest.TestCase):
),
patch("fastdeploy.cache_manager.cache_transfer_manager.paddle.set_device"),
):
self._orig_init_cpu_cache(self.manager, args)
self._orig_init_cpu_cache(self.manager)
self.assertEqual(self.manager.swap_space_ready_signal.value[0], 1)
@@ -510,8 +512,8 @@ class TestCacheTransferManager(unittest.TestCase):
value_cache_shape = "2,1,1,1"
with (
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
patch("fastdeploy.cache_manager.cache_transfer_manager.MooncakeStore"),
patch.object(CacheTransferManager, "_init_storage_buffer"),
):
@@ -527,9 +529,8 @@ class TestCacheTransferManager(unittest.TestCase):
with (
patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"),
patch("fastdeploy.cache_manager.cache_transfer_manager.set_data_ipc") as mock_set_ipc,
patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0),
):
self._orig_init_gpu_cache(manager, LocalArgs())
self._orig_init_gpu_cache(manager)
self.assertEqual(mock_set_ipc.call_count, 4)
self.assertIn("key_caches_0_rank0.device0", manager.gpu_cache_kvs)
@@ -548,8 +549,8 @@ class TestCacheTransferManager(unittest.TestCase):
value_cache_shape = "2,1,1,1"
with (
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
patch("fastdeploy.cache_manager.cache_transfer_manager.MooncakeStore"),
patch.object(CacheTransferManager, "_init_storage_buffer"),
):
@@ -568,9 +569,8 @@ class TestCacheTransferManager(unittest.TestCase):
with (
patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"),
patch("fastdeploy.cache_manager.cache_transfer_manager.share_external_data_", side_effect=fake_share),
patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0),
):
self._orig_init_gpu_cache(manager, LocalArgs())
self._orig_init_gpu_cache(manager)
self.assertIn("key_cache_scales_0_rank0.device0", manager.gpu_cache_kvs)
@@ -583,8 +583,8 @@ class TestCacheTransferManager(unittest.TestCase):
value_cache_shape = "1,1,1,1"
with (
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
):
manager = CacheTransferManager(LocalArgs())
@@ -607,9 +607,8 @@ class TestCacheTransferManager(unittest.TestCase):
patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep),
patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"),
patch("fastdeploy.cache_manager.cache_transfer_manager.share_external_data_", side_effect=fake_share),
patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0),
):
self._orig_init_gpu_cache(manager, LocalArgs())
self._orig_init_gpu_cache(manager)
self.assertIn("key_caches_0_rank0.device0", manager.gpu_cache_kvs)
@@ -1160,6 +1159,14 @@ class TestCacheTransferManager(unittest.TestCase):
self.manager.inflight = 1
self.manager.cache_task_is_paused_signal = DummySignal(0)
self.manager.cache_task_inflight_signal = DummySignal(1)
self.manager.cache_task_broadcast_signal = DummySignal(0)
self.manager.cache_task_queue.empty.return_value = False
self.manager.cache_task_queue.get_transfer_task.return_value = (
(cache_transfer_manager.CacheStatus.CTRL, MagicMock()),
True,
)
self.manager.control_task_thread_pool = MagicMock()
self.manager.control_task_thread_pool.submit.side_effect = SystemExit
call_count = {"count": 0}
@@ -1168,7 +1175,6 @@ class TestCacheTransferManager(unittest.TestCase):
if call_count["count"] == 1:
self.manager.inflight = 0
return None
raise SystemExit
with patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep):
with self.assertRaises(SystemExit):
@@ -1273,6 +1279,7 @@ class TestCacheTransferManager(unittest.TestCase):
args = Args()
args.splitwise_role = "mixed"
args.create_cache_tensor = True
self.manager.create_cache_tensor = True
self.manager.kv_cache_status_signal = DummySignal(cache_transfer_manager.KVCacheStatus.CLEARING)
self.manager.cache_ready_signal = DummySignal(0)
self.manager.swap_space_ready_signal = DummySignal(0)
@@ -1293,7 +1300,10 @@ class TestCacheTransferManager(unittest.TestCase):
patch("paddle.device.cuda.empty_cache") as mock_empty,
patch("paddle.set_device"),
patch.object(self.manager, "_log_memory"),
patch("time.sleep", side_effect=maybe_stop_cleared_with_tensor),
patch(
"fastdeploy.cache_manager.cache_transfer_manager.time.sleep",
side_effect=maybe_stop_cleared_with_tensor,
),
):
with self.assertRaises(StopIteration):
self.manager.check_cache_status(args)
@@ -1427,6 +1437,7 @@ class TestCacheTransferManager(unittest.TestCase):
args = Args()
args.splitwise_role = "mixed"
args.mp_num = 2
self.manager.n_ranks = 2
self.manager.kv_cache_status_signal = DummySignal([cache_transfer_manager.KVCacheStatus.UPDATING])
self.manager.cache_ready_signal = DummySignal([0, 1])
self.manager.swap_space_ready_signal = DummySignal([0, 1])
@@ -1455,6 +1466,7 @@ class TestCacheTransferManager(unittest.TestCase):
patch.object(self.manager, "_init_cpu_cache"),
patch.object(self.manager, "_init_gpu_cache"),
patch.object(self.manager, "resume"),
patch.object(self.manager, "_update_key_prefix"),
patch("fastdeploy.cache_manager.cache_transfer_manager.envs.FD_ENABLE_SWAP_SPACE_CLEARING", True),
patch.object(self.manager, "_log_memory"),
patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep),
@@ -1465,7 +1477,7 @@ class TestCacheTransferManager(unittest.TestCase):
self.assertEqual(self.manager.kv_cache_status_signal.value[0], cache_transfer_manager.KVCacheStatus.NORMAL)
def test_log_memory_records_gpu_stats(self):
with patch.object(cache_transfer_manager.logger, "warning") as mock_warning:
with patch.object(cache_transfer_manager.logger, "info") as mock_info:
with (
patch("paddle.device.cuda.max_memory_allocated", return_value=1024**3),
patch("paddle.device.cuda.max_memory_reserved", return_value=2 * 1024**3),
@@ -1474,7 +1486,7 @@ class TestCacheTransferManager(unittest.TestCase):
):
self.manager._log_memory("test")
mock_warning.assert_called_once()
mock_info.assert_called_once()
def test_pause_and_resume_wait_for_signals(self):
class DummySignal:
@@ -1503,7 +1515,7 @@ class TestCacheTransferManager(unittest.TestCase):
self.assertFalse(self.manager.is_paused)
def test_submit_task_decrements_inflight(self):
def test_submit_task_decrements_inflight_on_task_error(self):
class DummyPool:
def submit(self, fn, *args):
try:
@@ -1514,10 +1526,26 @@ class TestCacheTransferManager(unittest.TestCase):
def raise_task():
raise RuntimeError("boom")
self.manager.inflight = 1
self.manager.inflight = 0
self.manager.submit_task(DummyPool(), raise_task)
self.assertEqual(self.manager.inflight, 0)
def test_submit_task_decrements_inflight_on_success(self):
class DummyPool:
def submit(self, fn, *args):
return fn(*args)
task_called = {"value": False}
def ok_task():
task_called["value"] = True
self.manager.inflight = 0
self.manager.submit_task(DummyPool(), ok_task)
self.assertTrue(task_called["value"])
self.assertEqual(self.manager.inflight, 0)
def test_main_invokes_manager(self):
cache_transfer_manager.args = Args()
with patch("fastdeploy.cache_manager.cache_transfer_manager.CacheTransferManager") as mock_manager:
@@ -427,7 +427,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
self.assertEqual(manager.get_required_block_num(8, 4), 2)
def test_launch_cache_manager_initializes_processes(self):
manager = _create_manager()
manager = _create_manager(num_cpu_blocks=1)
manager.cache_config.enable_hierarchical_cache = False
with (
@@ -637,7 +637,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
self.assertIsNone(processes)
def test_launch_cache_manager_formats_value_cache_shape(self):
manager = _create_manager()
manager = _create_manager(num_cpu_blocks=1)
captured = {}
@@ -11,7 +11,9 @@ HOST="0.0.0.0"
PORT="${FD_API_PORT}" # must match the URL port used by the launch script
BASE_URL="http://$HOST:$PORT"
TOTAL_ROUNDS=6
V0_ROUNDS=3
V1_ROUNDS=3
TOTAL_ROUNDS=$((V0_ROUNDS + V1_ROUNDS))
CHAT_REQUESTS_PER_ROUND=3
export CUDA_VISIBLE_DEVICES=0,1
MAX_MEMORY_MB=10240 # 10GB
@@ -48,7 +50,7 @@ assert_success() {
fi
}
# curl_get_status(url, options...) → returns via global variables http_code and response_body
# curl_get_status(url, options) → returns via global variables http_code and response_body
curl_get_status() {
local result
result=$(curl -s -w "%{http_code}" "$@")
@@ -56,6 +58,23 @@ curl_get_status() {
response_body="${result%???}"
}
post_json_and_assert() {
local url="$1"
local payload="${2:-}"
if [ -n "$payload" ]; then
curl_get_status -X POST "$url" -H "Content-Type: application/json" -d "$payload"
else
curl_get_status -X POST "$url"
fi
assert_eq "$http_code" "200" "$url failed with HTTP $http_code, body: $response_body"
}
get_and_assert() {
local url="$1"
curl_get_status "$url"
assert_eq "$http_code" "200" "$url failed with HTTP $http_code, body: $response_body"
}
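The status/body split that `curl_get_status` relies on can be sketched offline: `curl -w "%{http_code}"` appends the three-digit status code to the response body, and two parameter expansions peel it back apart. The `result` value below is a stand-in for real curl output, not server data.

```shell
# Stand-in for: result=$(curl -s -w "%{http_code}" "$url")
result='{"ok":true}200'

response_body="${result%???}"              # drop the trailing 3-digit status code
http_code="${result#"${result%???}"}"      # keep only those 3 digits
echo "code=$http_code body=$response_body" # → code=200 body={"ok":true}
```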
# ====================================================
# Get visible GPU IDs from CUDA_VISIBLE_DEVICES
# ====================================================
@@ -79,104 +98,69 @@ check_gpu_memory() {
local gpu_ids
gpu_ids=($(get_visible_gpu_ids))
echo "========== GPU Memory Check =========="
echo "CUDA_VISIBLE_DEVICES = $CUDA_VISIBLE_DEVICES"
echo "MAX_MEMORY_MB = $MAX_MEMORY_MB"
echo "======================================"
echo "----------------------------------------"
echo " GPU Memory Check (MAX: ${MAX_MEMORY_MB}MB)"
echo "----------------------------------------"
if [ ${#gpu_ids[@]} -eq 0 ]; then
echo "Assertion failed: No valid GPU IDs in CUDA_VISIBLE_DEVICES='$CUDA_VISIBLE_DEVICES'" >&2
echo "ERROR: No valid GPU IDs in CUDA_VISIBLE_DEVICES='$CUDA_VISIBLE_DEVICES'" >&2
exit 1
fi
for gpu_id in "${gpu_ids[@]}"; do
echo
echo "---- GPU $gpu_id ----"
# Query summary
local summary
summary=$(nvidia-smi -i "$gpu_id" \
--query-gpu=index,name,memory.total,memory.used,memory.free,utilization.gpu \
--format=csv,noheader,nounits) || {
echo "Failed to query GPU $gpu_id summary" >&2
echo "ERROR: Failed to query GPU $gpu_id" >&2
exit 1
}
# Parse fields
IFS=',' read -r idx name mem_total mem_used mem_free util <<< "$summary"
local used_ratio=$(( mem_used * 100 / mem_total ))
echo "GPU $idx: $name"
echo "Total Memory : ${mem_total} MB"
echo "Used Memory : ${mem_used} MB"
echo "Free Memory : ${mem_free} MB"
echo "GPU Util : ${util} %"
# Print GPU info (single line)
printf " GPU %s: %-35s Total:%5sMB Used:%5sMB (%s%%)\n" \
"$idx" "$name" "$mem_total" "$mem_used" "$util"
# --- Hard assertions ---
# Hard assertion
assert_true "$(( mem_used <= MAX_MEMORY_MB ))" \
"GPU $gpu_id memory.used ${mem_used} MB > MAX_MEMORY_MB ${MAX_MEMORY_MB} MB"
# --- Soft safety check: usage ratio ---
local used_ratio
used_ratio=$(( mem_used * 100 / mem_total ))
echo "Used Ratio : ${used_ratio} %"
# Usage ratio check
if [ "$used_ratio" -gt 90 ]; then
echo "Assertion failed: GPU $gpu_id memory usage > 90% (${used_ratio}%)" >&2
echo "ERROR: GPU $gpu_id memory usage > 90% (${used_ratio}%)" >&2
exit 1
fi
# --- Process-level attribution ---
echo "Processes on GPU $gpu_id:"
# Process info (compact format)
local proc_info
proc_info=$(nvidia-smi -i "$gpu_id" \
--query-compute-apps=pid,process_name,used_memory \
--format=csv,noheader,nounits)
if [ -z "$proc_info" ]; then
echo " (No active compute processes)"
else
if [ -n "$proc_info" ]; then
echo "$proc_info" | while IFS=',' read -r pid pname pmem; do
echo " PID=$pid NAME=$pname MEM=${pmem}MB"
printf " └─ PID=%-8s %-30s MEM=%4sMB\n" \
"$pid" "$pname" "$pmem"
done
fi
echo "GPU $gpu_id memory check PASSED"
done
echo "========== GPU Memory Check DONE =========="
echo "----------------------------------------"
}
# ====================================================
for round in $(seq 1 $TOTAL_ROUNDS); do
echo "=== Round $round / $TOTAL_ROUNDS ==="
# Step 1: Clear loaded weights
echo "[Step 1] Clearing load weight..."
curl_get_status -i "$BASE_URL/clear_load_weight"
assert_eq "$http_code" "200" "/clear_load_weight failed with HTTP $http_code"
sleep 10
# Step 2: Check GPU memory usage
echo "[Step 2] Checking GPU memory..."
check_gpu_memory
# Step 3: Update model weights
echo "[Step 3] Updating model weight..."
curl_get_status -i "$BASE_URL/update_model_weight"
assert_eq "$http_code" "200" "/update_model_weight failed with HTTP $http_code"
# Step 4: Send chat completion requests
echo "[Step 4] Sending $CHAT_REQUESTS_PER_ROUND chat completions..."
send_chat_requests() {
for i in $(seq 1 $CHAT_REQUESTS_PER_ROUND); do
echo " Request $i / $CHAT_REQUESTS_PER_ROUND"
# Send request and capture response
printf " └─ Sending chat request %d/%d...\n" "$i" "$CHAT_REQUESTS_PER_ROUND"
response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{"messages": [{"role": "user", "content": "Hello!"}]}')
# Extract the 'content' field from the response
content=$(echo "$response" | \
grep -o '"content":"[^"]*"' | \
head -1 | \
@@ -184,26 +168,93 @@ for round in $(seq 1 $TOTAL_ROUNDS); do
sed 's/"$//')
if [ -z "$content" ]; then
# Fallback: try extracting content using sed more robustly
content=$(echo "$response" | \
sed -n 's/.*"content":"\([^"]*\)".*/\1/p' | \
head -1)
fi
# Check if content is empty or null
if [ -z "$content" ] || [ "$content" = "null" ]; then
echo "Failed: Empty or null 'content' in response" >&2
echo "Raw response:" >&2
echo " ERROR: Empty or null 'content' in response" >&2
echo " Raw response:" >&2
echo "$response" >&2
exit 1
fi
echo "Received non-empty response"
echo -e "\n---\n"
printf " └─ Response received: %s\n" "$content"
done
}
echo "Round $round completed."
echo "==================================\n"
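The fallback content extraction used in `send_chat_requests` can be exercised without a live server; the sample `response` below is a stand-in for an actual `/v1/chat/completions` reply, and the sed expression pulls out the first `"content"` field just as the script does.

```shell
# Stand-in for a real chat-completions response body.
response='{"choices":[{"message":{"role":"assistant","content":"Hello there"}}]}'

# Same extraction as the script's sed fallback: capture the first "content" value.
content=$(echo "$response" | sed -n 's/.*"content":"\([^"]*\)".*/\1/p' | head -1)
echo "content=$content" # → content=Hello there
```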
run_v0_round() {
local round="$1"
echo "========================================"
printf "Round %d/%d (V0)\n" "$round" "$TOTAL_ROUNDS"
echo "========================================"
echo ""
printf "[Step 1] %-30s " "Clearing load weight via v0..."
get_and_assert "$BASE_URL/clear_load_weight"
echo "[OK]"
sleep 10
printf "[Step 2] %-30s " "Checking GPU memory..."
echo ""
check_gpu_memory
echo ""
printf "[Step 3] %-30s " "Updating model weight via v0..."
get_and_assert "$BASE_URL/update_model_weight"
echo "[OK]"
echo "[Step 4] Sending $CHAT_REQUESTS_PER_ROUND chat completions"
send_chat_requests
echo ""
}
run_v1_round() {
local round="$1"
local sleep_payload='{"tags":"weight,kv_cache"}'
local wakeup_payload='{"tags":"weight,kv_cache"}'
echo "========================================"
printf "Round %d/%d (V1)\n" "$round" "$TOTAL_ROUNDS"
echo "========================================"
echo ""
printf "[Step 1] %-30s " "Pausing engine via v1..."
post_json_and_assert "$BASE_URL/v1/pause" ""
echo "[OK]"
printf "[Step 2] %-30s " "Sleeping via v1..."
post_json_and_assert "$BASE_URL/v1/sleep" "$sleep_payload"
echo "[OK]"
sleep 10
printf "[Step 3] %-30s " "Checking GPU memory..."
echo ""
check_gpu_memory
echo ""
printf "[Step 4] %-30s " "Waking up via v1..."
post_json_and_assert "$BASE_URL/v1/wakeup" "$wakeup_payload"
echo "[OK]"
printf "[Step 5] %-30s " "Resuming engine via v1..."
post_json_and_assert "$BASE_URL/v1/resume" ""
echo "[OK]"
echo "[Step 6] Sending $CHAT_REQUESTS_PER_ROUND chat completions"
send_chat_requests
echo ""
}
for round in $(seq 1 $V0_ROUNDS); do
run_v0_round "$round"
done
echo "All $TOTAL_ROUNDS rounds completed successfully."
for round in $(seq 1 $V1_ROUNDS); do
run_v1_round "$round"
done
echo "========================================"
printf "All %d rounds completed successfully.\n" "$TOTAL_ROUNDS"
echo "========================================"
@@ -11,6 +11,7 @@ FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633))
# List of ports to clean before and after tests
PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT]
@@ -20,7 +20,7 @@ import threading
import time
import types
import unittest
from unittest.mock import ANY, MagicMock, Mock, patch
from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch
import numpy as np
import paddle
@@ -1091,6 +1091,19 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
self.assertEqual(result, {"ok": True})
self._detach_finalizer(eng)
def test_control_update_weights_updates_cfg_version(self):
eng = self._make_mixed_engine()
eng.is_paused = True
eng._pause_cond = threading.Condition()
eng.cfg.model_config.version = "old-version"
eng._call_worker = Mock(return_value=[{"version": "new-version"}, {"ok": True}])
result = eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights"))
self.assertEqual(result, [{"version": "new-version"}, {"ok": True}])
self.assertEqual(eng.cfg.model_config.version, "new-version")
self._detach_finalizer(eng)
def test_control_pause_and_resume_paths(self):
eng = self._make_mixed_engine()
eng.is_paused = False
@@ -1155,12 +1168,50 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
async def get(self, timeout=None):
return Mock(payload=ControlResponse(request_id="req", result={"ok": True}, error_code=200))
eng._ctrl_worker_output_queues = [DummyQueue()]
eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": DummyQueue()}
result = eng._call_worker(ControlRequest(request_id="req", method="noop"), timeout=1)
self.assertEqual(result, [{"ok": True}])
eng.engine_worker_queue.put_tasks.assert_called_once()
self._detach_finalizer(eng)
def test_control_sleep_defaults_tags_and_dispatches_cache_transfer(self):
cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=4)
eng = self._make_engine(cfg)
eng.cfg.cache_config.num_cpu_blocks = 1
eng.engine_worker_queue = Mock()
eng.cache_task_queue = Mock()
eng.resource_manager.cache_manager.reset = Mock()
eng._control_pause = Mock()
eng._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}])
result = eng._control_sleep(ControlRequest(request_id="sleep", method="sleep", args={}))
self.assertEqual(result, [{"ok": True}])
eng._control_pause.assert_called_once_with(None)
eng.resource_manager.cache_manager.reset.assert_called_once()
eng.engine_worker_queue.put_tasks.assert_called_once()
eng.cache_task_queue.put_transfer_task.assert_called_once()
sleep_req = eng.engine_worker_queue.put_tasks.call_args.args[0][0][0]
self.assertEqual(sleep_req.args["tags"], "weight,kv_cache")
self._detach_finalizer(eng)
def test_control_wakeup_resumes_after_wait(self):
cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=4)
eng = self._make_engine(cfg)
eng.cfg.cache_config.num_cpu_blocks = 1
eng.engine_worker_queue = Mock()
eng.cache_task_queue = Mock()
eng._control_resume = Mock()
eng._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}])
result = eng._control_wakeup(ControlRequest(request_id="wakeup", method="wakeup", args={"tags": "kv_cache"}))
self.assertEqual(result, [{"ok": True}])
eng.engine_worker_queue.put_tasks.assert_called_once()
eng.cache_task_queue.put_transfer_task.assert_called_once()
eng._control_resume.assert_called_once_with(None)
self._detach_finalizer(eng)
def test_control_update_weights_requires_pause(self):
eng = self._make_mixed_engine()
eng.is_paused = False
@@ -1940,68 +1991,121 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
mock_logger.warning.assert_called()
self._detach_finalizer(eng)
def test_wait_all_control_responses_success(self):
def test_wait_for_control_responses_success(self):
eng = self._make_mixed_engine()
eng._ctrl_worker_output_queues = [
self._make_ctrl_queue("q0", Mock(request_id="req", error_code=200, result={"ok": True})),
self._make_ctrl_queue("q1", Mock(request_id="req", error_code=200, result={"ok": True})),
]
eng._ctrl_output_queues = {
"ctrl_w2e_rank0_6778": self._make_ctrl_queue(
"q0", Mock(request_id="req", error_code=200, result={"ok": True})
),
"ctrl_w2e_rank1_6778": self._make_ctrl_queue(
"q1", Mock(request_id="req", error_code=200, result={"ok": True})
),
}
results = asyncio.run(eng._wait_all_control_responses("req", timeout=1))
results = asyncio.run(eng._wait_for_control_responses("req", timeout=1))
self.assertEqual(results, [{"ok": True}, {"ok": True}])
self._detach_finalizer(eng)
def test_wait_all_control_responses_ignores_mismatch(self):
def test_wait_for_control_responses_filters_executors(self):
eng = self._make_mixed_engine()
eng._ctrl_worker_output_queues = [
self._make_ctrl_queue("q0", Mock(request_id="old", error_code=200, result={"ok": False})),
self._make_ctrl_queue("q1", Mock(request_id="req", error_code=200, result={"ok": True})),
]
eng._ctrl_output_queues = {
"ctrl_w2e_rank0_6778": self._make_ctrl_queue(
"worker", Mock(request_id="req", error_code=200, result={"worker": True})
),
"ctrl_c2e_rank0_6779": self._make_ctrl_queue(
"cache", Mock(request_id="req", error_code=200, result={"cache": True})
),
}
results = asyncio.run(eng._wait_all_control_responses("req", timeout=1))
self.assertEqual(results, [{"ok": True}])
worker_results = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["worker"]))
cache_results = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["cache_transfer"]))
self.assertEqual(worker_results, [{"worker": True}])
self.assertEqual(cache_results, [{"cache": True}])
self._detach_finalizer(eng)
def test_wait_all_control_responses_error_paths(self):
def test_wait_for_control_responses_ignores_mismatch(self):
eng = self._make_mixed_engine()
eng._ctrl_worker_output_queues = [
self._make_ctrl_queue("q0", Exception("boom"), payload_wrapped=False),
]
class DummyQueue:
def __init__(self, name, payloads):
self.name = name
self.payloads = list(payloads)
async def get(self, timeout=None):
return Mock(payload=self.payloads.pop(0))
eng._ctrl_output_queues = {
"ctrl_w2e_rank0_6778": DummyQueue(
"q0",
[
Mock(request_id="old", error_code=200, result={"ok": False}),
Mock(request_id="req", error_code=200, result={"ok": "from-q0"}),
],
),
"ctrl_w2e_rank1_6778": self._make_ctrl_queue(
"q1", Mock(request_id="req", error_code=200, result={"ok": True})
),
}
results = asyncio.run(eng._wait_for_control_responses("req", timeout=1))
self.assertEqual(results, [{"ok": "from-q0"}, {"ok": True}])
self.assertEqual(
eng._ctrl_response_mailboxes["ctrl_w2e_rank0_6778"]["old"].result,
{"ok": False},
)
self._detach_finalizer(eng)
def test_wait_for_control_responses_error_paths(self):
eng = self._make_mixed_engine()
eng._ctrl_output_queues = {
"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", Exception("boom"), payload_wrapped=False)
}
with self.assertRaises(Exception):
asyncio.run(eng._wait_all_control_responses("req", timeout=1))
asyncio.run(eng._wait_for_control_responses("req", timeout=1))
self._detach_finalizer(eng)
def test_wait_all_control_responses_none_message(self):
def test_wait_for_control_responses_none_message(self):
eng = self._make_mixed_engine()
eng._ctrl_worker_output_queues = [self._make_ctrl_queue("q0", None, payload_wrapped=False)]
eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)}
with self.assertRaises(Exception):
asyncio.run(eng._wait_all_control_responses("req", timeout=1))
asyncio.run(eng._wait_for_control_responses("req", timeout=1))
self._detach_finalizer(eng)
def test_wait_all_control_responses_error_code(self):
def test_wait_for_control_responses_error_code(self):
eng = self._make_mixed_engine()
eng._ctrl_worker_output_queues = [
self._make_ctrl_queue("q0", ControlResponse(request_id="req", error_code=500, error_message="bad")),
]
eng._ctrl_output_queues = {
"ctrl_w2e_rank0_6778": self._make_ctrl_queue(
"q0", ControlResponse(request_id="req", error_code=500, error_message="bad")
)
}
with self.assertRaises(Exception):
asyncio.run(eng._wait_all_control_responses("req", timeout=1))
asyncio.run(eng._wait_for_control_responses("req", timeout=1))
self._detach_finalizer(eng)
def test_wait_all_control_responses_timeout(self):
def test_wait_for_control_responses_timeout(self):
eng = self._make_mixed_engine()
eng._ctrl_worker_output_queues = [self._make_ctrl_queue("q0", None, payload_wrapped=False)]
eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)}
with patch("fastdeploy.engine.common_engine.asyncio.wait_for", side_effect=asyncio.TimeoutError):
with self.assertRaises(Exception):
asyncio.run(eng._wait_all_control_responses("req", timeout=1))
asyncio.run(eng._wait_for_control_responses("req", timeout=1))
self._detach_finalizer(eng)
def test_wait_for_control_responses_without_matching_queues(self):
eng = self._make_mixed_engine()
eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)}
result = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["cache_transfer"]))
self.assertIsNone(result)
self._detach_finalizer(eng)
def test_insert_tasks_prefill_error_and_success(self):
@@ -3254,7 +3358,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
# Lines 1299-1300: try block start + info logging
info_msgs = [str(c) for c in mock_logger.info.call_args_list]
self.assertTrue(any("START run control method" in m for m in info_msgs))
self.assertTrue(any("Start to run control method" in m for m in info_msgs))
# worker_pid should be popped from the map
self.assertNotIn("ctrl-log", eng.request_worker_map)
self._detach_finalizer(eng)
@@ -84,6 +84,7 @@ def _reload_api_server(args):
fake_envs_mod.EXPORTER_OTLP_ENDPOINT = ""
fake_envs_mod.EXPORTER_OTLP_HEADERS = ""
fake_envs_mod.FD_SUPPORT_MAX_CONNECTIONS = 1024
fake_envs_mod.FD_ENABLE_V1_UPDATE_WEIGHTS = 0
fake_envs_mod.environment_variables = _FakeEnvVars()
# Save original sys.argv and replace with minimal valid args to avoid parse errors
@@ -98,9 +99,8 @@ def _reload_api_server(args):
patch.dict("sys.modules", {"fastdeploy.envs": fake_envs_mod}),
patch("fastdeploy.envs", fake_envs_mod),
):
from fastdeploy.entrypoints.openai import api_server as api_server_mod
return importlib.reload(api_server_mod)
sys.modules.pop("fastdeploy.entrypoints.openai.api_server", None)
return importlib.import_module("fastdeploy.entrypoints.openai.api_server")
finally:
sys.argv = original_argv
@@ -536,6 +536,95 @@ async def test_reward_embedding_and_weights():
assert api_server.clear_load_weight(MagicMock()).status_code == 404
@pytest.mark.asyncio
async def test_sleep_wakeup_and_v1_clear_load_weight_routes():
args = _build_args(dynamic_load_weight=True)
api_server = _reload_api_server(args)
api_server.app.state.dynamic_load_weight = True
api_server.app.state.event_loop = asyncio.get_running_loop()
mock_control_response = MagicMock()
mock_control_response.to_api_json_response.return_value = api_server.JSONResponse(
content={"ok": True}, status_code=200
)
api_server.app.state.engine_client = MagicMock()
api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response)
api_server.app.state.engine_client.run_control_method_sync.return_value = mock_control_response
sleep_req = MagicMock()
sleep_req.body = AsyncMock(return_value=b'{"tags":"weight"}')
sleep_req.json = AsyncMock(return_value={"tags": "weight"})
sleep_resp = await api_server.sleep(sleep_req)
assert sleep_resp.status_code == 200
sleep_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[0].args[0]
assert sleep_control_request.method == "sleep"
assert sleep_control_request.args == {"tags": "weight"}
wakeup_req = MagicMock()
wakeup_req.body = AsyncMock(return_value=b'{"tags":"weight,kv_cache"}')
wakeup_req.json = AsyncMock(return_value={"tags": "weight,kv_cache"})
wakeup_resp = await api_server.wakeup(wakeup_req)
assert wakeup_resp.status_code == 200
wakeup_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[1].args[0]
assert wakeup_control_request.method == "wakeup"
assert wakeup_control_request.args == {"tags": "weight,kv_cache"}
passthrough_req = MagicMock()
passthrough_req.body = AsyncMock(return_value=b'{"tags":["weight"]}')
passthrough_req.json = AsyncMock(return_value={"tags": ["weight"]})
passthrough_resp = await api_server.sleep(passthrough_req)
assert passthrough_resp.status_code == 200
passthrough_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[2].args[0]
assert passthrough_control_request.args == {"tags": ["weight"]}
with patch.object(api_server.envs, "FD_ENABLE_V1_UPDATE_WEIGHTS", True):
clear_resp = api_server.clear_load_weight(MagicMock())
assert clear_resp.status_code == 200
sync_control_request = api_server.app.state.engine_client.run_control_method_sync.call_args.args[0]
assert sync_control_request.method == "sleep"
api_server.app.state.engine_client.run_control_method_sync.reset_mock()
with patch.object(api_server.envs, "FD_ENABLE_V1_UPDATE_WEIGHTS", True):
update_resp = api_server.update_model_weight(MagicMock())
assert update_resp.status_code == 200
sync_control_request = api_server.app.state.engine_client.run_control_method_sync.call_args.args[0]
assert sync_control_request.method == "wakeup"
@pytest.mark.asyncio
async def test_update_weights_route_validation():
args = _build_args(dynamic_load_weight=True)
api_server = _reload_api_server(args)
mock_control_response = MagicMock()
mock_control_response.to_api_json_response.return_value = api_server.JSONResponse(
content={"ok": True}, status_code=200
)
api_server.app.state.engine_client = MagicMock()
api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response)
valid_req = MagicMock()
valid_req.body = AsyncMock(return_value=b'{"version":"v2","rsync_config":{"etcd_server":"127.0.0.1"}}')
valid_req.json = AsyncMock(return_value={"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}})
valid_resp = await api_server.update_weights(valid_req)
assert valid_resp.status_code == 200
control_request = api_server.app.state.engine_client.run_control_method.await_args.args[0]
assert control_request.method == "update_weights"
assert control_request.args == {"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}}
invalid_version_req = MagicMock()
invalid_version_req.body = AsyncMock(return_value=b'{"version":1}')
invalid_version_req.json = AsyncMock(return_value={"version": 1})
invalid_version_resp = await api_server.update_weights(invalid_version_req)
assert invalid_version_resp.status_code == 400
invalid_rsync_req = MagicMock()
invalid_rsync_req.body = AsyncMock(return_value=b'{"rsync_config":{"user":"u"}}')
invalid_rsync_req.json = AsyncMock(return_value={"rsync_config": {"user": "u"}})
invalid_rsync_resp = await api_server.update_weights(invalid_rsync_req)
assert invalid_rsync_resp.status_code == 400
@pytest.mark.asyncio
async def test_expert_and_stats_routes():
args = _build_args()
@@ -142,34 +142,34 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
def capture_and_replay(self, input_tensor1, forward_meta1):
""" """
# Trigger Capture
print_gpu_memory_use(0, "before capture")
print_gpu_memory_use("before capture", 0)
output1 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
print_gpu_memory_use(0, "after capture")
print_gpu_memory_use("after capture", 0)
# Replay
output1 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
assert (output1 == self.output_correct).all()
# Destroy
print_gpu_memory_use(0, "before destroy")
print_gpu_memory_use("before destroy", 0)
self.test_model1.clear_grpah_opt_backend()
print_gpu_memory_use(0, "after destroy")
print_gpu_memory_use("after destroy", 0)
def recapture_and_replay(self, input_tensor1, forward_meta1):
""" """
# Trigger Capture
print_gpu_memory_use(0, "before recapture")
print_gpu_memory_use("before recapture", 0)
output2 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
print_gpu_memory_use(0, "after recapture")
print_gpu_memory_use("after recapture", 0)
# Replay
output2 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
assert (output2 == self.output_correct).all()
# Destroy
print_gpu_memory_use(0, "before destroy")
print_gpu_memory_use("before destroy", 0)
self.test_model1.clear_grpah_opt_backend()
print_gpu_memory_use(0, "after destroy")
print_gpu_memory_use("after destroy", 0)
if __name__ == "__main__":