mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[Feature] support v1 update/clear api for RL (#6761)
* [Feature] support v1 update/clear api for RL
* [fix] fix execute_model and add sleep/wakeup api
* [fix] fix mtp and key_prefix
* [chore] move _update_key_prefix to resume method
* [fix] make the interface safe to call multiple times
* [fix] fix some tiny bugs
* [chore] make small changes against pr review
* [docs] add docs for weight update
* [test] add some tests and update docs
* [style] fix code style check
* [test] fix ci
* [fix] fix stale control responses when control method timed out
* [chore] remove unused code
* [chore] fix code style
* [chore] optimize tags and key_prefix
* [test] fix ci
* [chore] fix code style
* [test] fix ci
* [fix] fix ep control
* [fix] fix ep control for engine cache queue
@@ -0,0 +1,308 @@
[简体中文](../zh/features/weight_update.md)

# Weight Clear and Update
FastDeploy supports dynamic weight clear and update for RL and RLHF rollout services. This capability is primarily intended to address the following two requirements:

- release GPU memory when the rollout engine is idle;
- refresh inference weights after the trainer produces a new checkpoint, without restarting the whole service.

This page describes the weight-control interfaces currently supported by FastDeploy, the semantics of each interface, and their typical usage in RLHF training.
## Prerequisites

In RLHF scenarios, FastDeploy provides this capability mainly through the online serving mode. Dynamic weight loading must be enabled when starting the service:

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --dynamic-load-weight \
    --load-strategy ipc_snapshot
```
`--dynamic-load-weight` enables dynamic weight control, and `--load-strategy` specifies the concrete weight update mechanism. The currently supported update modes are listed below:

| Mode | `load_strategy` | Typical use | Notes |
| --- | --- | --- | --- |
| CUDA IPC | `ipc` | Training and inference processes on the same node share live tensors | The update source is IPC metadata produced by the training side. |
| IPC snapshot | `ipc_snapshot` | Rollout reloads a snapshot file produced by training | Used by the current RL rollout examples. |
| RDMA / rsync | `rsync` | Trainer publishes a new version and rollout fetches it remotely | `POST /v1/update_weights` is the explicit API for this mode. |
## API Overview

### Compatibility APIs

In FastDeploy <= 2.5, the following simplified APIs are provided for compatibility with the legacy RL control flow.

| API | Method | Meaning | Availability |
| --- | --- | --- | --- |
| `/clear_load_weight` | `GET` | Clear or offload currently loaded weights | Requires `dynamic_load_weight=True` |
| `/update_model_weight` | `GET` | Reload weights after a clear/offload operation | Requires `dynamic_load_weight=True` |
### V1 control APIs

In FastDeploy >= 2.6, the underlying control-signal communication path is optimized and V1 control APIs are introduced. Compared with the legacy APIs, the V1 APIs provide a more stable execution path, clearer semantics, and more flexible control:

| API | Method | Request params | Semantics |
| --- | --- | --- | --- |
| `/v1/pause` | `POST` | none | Pause request generation, abort running and inflight requests, reset scheduler state, and pause cache transfer if enabled. |
| `/v1/resume` | `POST` | none | Resume request generation and cache transfer. |
| `/v1/is_paused` | `GET` | none | Return `{"is_paused": bool}`. |
| `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | Offload the selected GPU memory objects. Supported tags are `weight` and `kv_cache`; if omitted, both are used. |
| `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | Reload previously offloaded weights and/or KV cache. On success, the engine resumes automatically. |
| `/v1/update_weights` | `POST` | JSON `{"version": "...", "rsync_config": {...}}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. |
### Compatibility Notes

The optimized communication path also applies to the legacy APIs. Setting `FD_ENABLE_V1_UPDATE_WEIGHTS=1` switches the legacy APIs to the new control path while keeping the original API form.

- `FD_ENABLE_V1_UPDATE_WEIGHTS=0`: use the legacy shared-memory-based control path.
- `FD_ENABLE_V1_UPDATE_WEIGHTS=1`: `/clear_load_weight` is effectively handled through `/v1/sleep`, and `/update_model_weight` is effectively handled through `/v1/wakeup`. The corresponding pause/resume actions are handled internally by `sleep` and `wakeup`.

**Note**: regardless of whether V1 is enabled, the legacy APIs are not the recommended standard interface for RLHF scenarios and may be gradually deprecated in future releases. The `/v1/*` control APIs are recommended.
## Interface Semantics

### `/v1/pause`

`/v1/pause` is the safe boundary before changing model state.

It does the following:

- stops new request generation;
- aborts running and inflight requests;
- resets scheduler state;
- pauses cache transfer when multi-level cache or KV cache storage is enabled.

When a clear boundary is required between one rollout round and the next training stage, this API should be called first.
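Because `pause` aborts in-flight work asynchronously, a caller typically polls `/v1/is_paused` before proceeding to the next step. A generic polling loop might look like the following sketch, where the `is_paused` callable is an assumption standing in for a `GET /v1/is_paused` request that reads the `is_paused` field of the response:

```python
import time

def wait_until_paused(is_paused, timeout_s: float = 30.0, interval_s: float = 0.5) -> bool:
    """Poll is_paused() until it returns True or the timeout expires.

    In a real integration, is_paused would issue GET /v1/is_paused and
    return the boolean from the {"is_paused": bool} response body.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        if is_paused():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)

# Example with a stand-in checker that reports paused on the third poll.
calls = {"n": 0}

def fake_checker() -> bool:
    calls["n"] += 1
    return calls["n"] >= 3

print(wait_until_paused(fake_checker, timeout_s=5.0, interval_s=0.01))  # True
```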
### `/v1/sleep`

`/v1/sleep` offloads selected runtime state from GPU memory.

Supported tags:

- `weight`: clear model weights from device memory; depending on configuration, communication groups and DeepEP buffers may also be released.
- `kv_cache`: clear the KV cache; the MTP cache is also cleared when speculative decoding uses MTP.

If the `tags` parameter is omitted, FastDeploy defaults to:

```bash
/v1/sleep?tags=weight,kv_cache
```

In the current implementation, `sleep` automatically performs a `pause` first. New integrations should not rely on this implicit behavior.
### `/v1/wakeup`

`/v1/wakeup` restores the state offloaded by `/v1/sleep`.

Depending on tags and configuration, FastDeploy may:

- restart communication groups;
- recreate DeepEP buffers;
- reload model weights from the configured source;
- rebuild the KV cache;
- recapture CUDA Graph.

After `wakeup` succeeds, FastDeploy automatically calls `resume`.
### `/v1/update_weights`

`/v1/update_weights` refreshes model parameters directly, without unloading the GPU memory occupied by model weights.

Current request fields:

- `version`: optional string used to choose a target checkpoint version.
- `rsync_config`: optional dictionary; must contain `etcd_server` when provided.

Important semantics:

- the engine must already be paused, otherwise the request fails;
- the update is executed on workers only;
- this API is meant for explicit weight refresh, especially the `rsync` path;
- it does not implicitly call `resume`.

Recommended sequence:

1. `POST /v1/pause`
2. `POST /v1/update_weights`
3. `POST /v1/resume`

If GPU memory also needs to be reclaimed between rollout rounds, the `sleep` / `wakeup` workflow is more appropriate.
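The request body can be assembled and sanity-checked before posting. The sketch below encodes the constraint described above, that `rsync_config`, when present, must contain `etcd_server`; the field names come from this page, while the helper itself is illustrative:

```python
def build_update_weights_payload(version=None, rsync_config=None) -> dict:
    """Build and validate a /v1/update_weights JSON body.

    Both fields are optional; rsync_config, if given, must carry the
    etcd_server entry required by the rsync update path.
    """
    payload = {}
    if version is not None:
        payload["version"] = version
    if rsync_config is not None:
        if "etcd_server" not in rsync_config:
            raise ValueError("rsync_config must contain 'etcd_server'")
        payload["rsync_config"] = rsync_config
    return payload

print(build_update_weights_payload(
    version="global_step_1200",
    rsync_config={"etcd_server": "127.0.0.1:2379"},
))
```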
## Example Requests

### Basic APIs

Pause the engine:

```bash
curl -X POST http://127.0.0.1:8000/v1/pause
```

Resume the engine:

```bash
curl -X POST http://127.0.0.1:8000/v1/resume
```
### Sleep / Wakeup APIs

**Offload weights and KV cache**

```bash
# Offload both weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache"

# Offload only weights
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight"

# Omit the parameter; defaults to both
curl -X POST "http://127.0.0.1:8000/v1/sleep"
```

**Restore weights and KV cache**

```bash
# Restore both weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache"

# Restore only weights
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight"

# Omit the parameter; defaults to both
curl -X POST "http://127.0.0.1:8000/v1/wakeup"
```

**Note**: When `use_cudagraph=True`, the KV cache must be restored before the weights, i.e. `/v1/wakeup` with the `kv_cache` tag must be called before `/v1/wakeup` with the `weight` tag. Restoring weights without the KV cache raises an error. It is recommended to keep the `tags` parameter consistent between `/v1/sleep` and `/v1/wakeup`.
### Update Weights API

Refresh to a new remotely published version:

```bash
curl -X POST http://127.0.0.1:8000/v1/update_weights \
  -H "Content-Type: application/json" \
  -d '{
    "version": "global_step_1200",
    "rsync_config": {
      "etcd_server": "127.0.0.1:2379"
    }
  }'
```
## RLHF Usage

### Recommended Rollout Service Setup

In RLHF scenarios, FastDeploy rollout services are typically configured as follows:

- `dynamic_load_weight=True`
- `load_strategy=ipc_snapshot` for local snapshot-based refresh;
- or `load_strategy=rsync` for versioned remote refresh.

The rollout utilities in the repository already follow this pattern. A typical example:

```python
from fastdeploy.rl.rollout_config import RolloutModelConfig
from fastdeploy.rl.rollout_model import RolloutModel

rollout_config = RolloutModelConfig(
    model_name_or_path=model_path,
    tensor_parallel_size=ranks,
    dynamic_load_weight=True,
    load_strategy="ipc_snapshot",
)
rollout_model = RolloutModel(rollout_config)
```
### Training-Side Integration Support

In addition to the serving endpoints, FastDeploy provides the following training-side integration capabilities for RLHF:

- `RolloutModel.state_dict()`: exposes the rollout-side inference parameters.
- `RolloutModel.get_name_mappings_to_training()`: exposes the mapping from inference parameter names to training parameter names.

These interfaces can be used to align training checkpoints with rollout-side parameter layouts, especially when inference-side and training-side parameter names are not fully identical.
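As an illustration, a mapping of the kind returned by `get_name_mappings_to_training()` can be applied to a rollout-side state dict as sketched below. The parameter names, the mapping, and the tensor placeholders here are hypothetical, not real FastDeploy output:

```python
def remap_to_training_names(infer_state_dict: dict, infer_to_train: dict) -> dict:
    """Rename inference-side parameters to their training-side names.

    Names missing from the mapping are kept unchanged, matching the common
    case where only a subset of parameters differs between the two layouts.
    """
    return {infer_to_train.get(name, name): tensor
            for name, tensor in infer_state_dict.items()}

# Hypothetical example data (strings stand in for tensors).
infer_sd = {"embed_tokens.weight": "t0", "lm_head.weight": "t1"}
mapping = {"embed_tokens.weight": "model.embed_tokens.weight"}
print(remap_to_training_names(infer_sd, mapping))
# {'model.embed_tokens.weight': 't0', 'lm_head.weight': 't1'}
```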
### Common RLHF workflows

The following examples assume the service endpoint is `http://127.0.0.1:8000`.

**Workflow 1: clear and restore**

This workflow is suitable when the rollout service stays resident, but GPU memory should be released before training and restored afterward. The recommended sequence is `(pause) -> sleep -> wakeup -> (resume)`, where the steps in parentheses are optional.

```bash
# Optional: explicitly pause the engine to establish a clear transition boundary
curl -X POST http://127.0.0.1:8000/v1/pause

# Offload both weights and KV cache
curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache"

# Restore both weights and KV cache after training completes
curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache"

# Optional: explicitly resume if required by the integration
curl -X POST http://127.0.0.1:8000/v1/resume
```
**Workflow 2: in-place refresh to a new checkpoint**

This workflow is suitable when the service remains resident and only needs to switch to a new checkpoint version. The recommended sequence is `pause -> update_weights -> resume`.

```bash
# Pause the engine first
curl -X POST http://127.0.0.1:8000/v1/pause

# Refresh to a new checkpoint version in place
curl -X POST http://127.0.0.1:8000/v1/update_weights \
  -H "Content-Type: application/json" \
  -d '{
    "version": "global_step_1200",
    "rsync_config": {
      "etcd_server": "127.0.0.1:2379"
    }
  }'

# Resume the service after the update completes
curl -X POST http://127.0.0.1:8000/v1/resume
```
**Workflow 3: legacy compatibility APIs**

Legacy RL clients can continue to use the compatibility flow `clear_load_weight -> update_model_weight`.

```bash
# Clear or offload the current weights
curl -X GET http://127.0.0.1:8000/clear_load_weight

# Reload weights after the trainer updates the checkpoint
curl -X GET http://127.0.0.1:8000/update_model_weight
```

For new integrations, the `/v1/*` APIs are recommended because their control path is more explicit and easier to trace.
## Other Related Configuration

### Communication Group Clear and Rebuild

FastDeploy provides `--shutdown-comm-group-if-worker-idle` and `--no-shutdown-comm-group-if-worker-idle` to explicitly control whether communication groups are also torn down when weights are offloaded.

Keeping communication groups alive generally improves the stability of weight clearing and reloading. The tradeoff is that more GPU memory remains allocated after weight offload, and the execution time of `sleep` / `wakeup` may also increase.

By default:

- in EP (expert parallel) scenarios, communication groups are kept;
- in non-EP scenarios, communication groups are torn down.
### CPU Cache Clear and Rebuild

When `--swap-space` is enabled, the following environment variable controls whether the CPU-side cache is also cleared when `/v1/sleep` executes, reducing host memory pressure during training.

By default, FastDeploy does not actively clear the CPU cache. To clear it together with `sleep`, set:

```bash
export FD_ENABLE_SWAP_SPACE_CLEARING=1
```
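Putting the pieces together, a service that also clears CPU cache during `sleep` might be launched as follows. The flags and the environment variable come from this page; the `--swap-space` size is an illustrative placeholder, so check the serving documentation for the exact unit and default:

```shell
# Enable CPU-cache clearing during /v1/sleep, then start the server
# with dynamic weight loading and swap space enabled.
export FD_ENABLE_SWAP_SPACE_CLEARING=1

python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/model \
    --dynamic-load-weight \
    --load-strategy ipc_snapshot \
    --swap-space 16
```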

@@ -0,0 +1,307 @@
[English](../../features/weight_update.md)

*(The file added in this hunk is the Simplified Chinese translation of the English document above; its content mirrors the English page section for section.)*
@@ -32,6 +32,7 @@ class CacheStatus(Enum):
     CPU = 3
     GPU2STORAGE = 4
     STORAGE2GPU = 5
+    CTRL = -1


 class BlockNode:
@@ -15,6 +15,7 @@
"""

import argparse
import asyncio
import concurrent.futures
import gc
import json
@@ -35,7 +36,6 @@ from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTa
from fastdeploy.cache_manager.ops import (
    cuda_host_alloc,
    cuda_host_free,
    memory_allocated,
    set_data_ipc,
    set_device,
    share_external_data_,
@@ -49,7 +49,9 @@ from fastdeploy.cache_manager.transfer_factory import (
     MooncakeStore,
 )
 from fastdeploy.config import CacheConfig, SpeculativeConfig
+from fastdeploy.engine.request import ControlRequest, ControlResponse
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
+from fastdeploy.inter_communicator.fmq import FMQ
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import console_logger, get_logger
@@ -96,7 +98,6 @@ def parse_args():
        help="engine worker queue port",
    )
    parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
    parser.add_argument("--ipc_suffix", type=str, default=None, help="engine pid")
    parser.add_argument(
        "--protocol",
        type=str,
@@ -184,15 +185,17 @@ class CacheTransferManager:
        self.key_prefix = ""

        # extract other arg values
        self.model_path = args.model_path
        self.model_id = os.path.basename(args.model_path.rstrip("/"))
        self.n_ranks = args.mp_num
        self.rank = args.rank
        self.device = args.device_id
        self.num_layers = args.num_layers
        self.ipc_suffix = args.ipc_suffix
        self.create_cache_tensor = args.create_cache_tensor
        self.local_data_parallel_id = args.local_data_parallel_id
        self.num_extra_layers = self.speculative_config.num_extra_cache_layer
        self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
        self.cache_queue_port = args.cache_queue_port
        paddle.set_default_dtype(args.default_dtype)

        self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -200,8 +203,10 @@ class CacheTransferManager:
         self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
+        self.control_task_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.transfer_task_queue = queue.Queue()  # queue for receiving transfer tasks
         self.tansfer_done_queue = queue.Queue()  # queue for reporting finished tasks
+        self.ctrl_output_queue = None

         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
@@ -209,7 +214,7 @@ class CacheTransferManager:
             is_server=False,
             num_client=args.mp_num,
             client_id=self.rank,
-            local_data_parallel_id=args.local_data_parallel_id,
+            local_data_parallel_id=0,
         )

         cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
@@ -231,11 +236,12 @@ class CacheTransferManager:

         self.num_cpu_blocks = args.num_cpu_blocks

-        self._init_gpu_cache(args)
+        self._init_gpu_cache()
         if self.num_cpu_blocks > 0:
-            self._init_cpu_cache(args)
+            self._init_cpu_cache()
         if self.storage_backend_type is not None:
             self._init_storage(args)
+        self._init_control()

         cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
         self.cache_task_broadcast_signal = IPCSignal(
@@ -287,7 +293,9 @@ class CacheTransferManager:
         threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()

         self.is_paused = False  # transfer manager state
+        self.is_sleeping = False
         self.inflight = 0  # number of inflight transfer tasks
+        self.inflight_tasks = {}

         cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
         self.cache_transfer_inited_signal = IPCSignal(
@@ -299,6 +307,15 @@ class CacheTransferManager:
         )
         self.cache_transfer_inited_signal.value[self.rank] = 1

+    def _init_control(self):
+        dp_rank = self.local_data_parallel_id
+        tp_rank = self.rank
+        tp_size = self.n_ranks
+        cq_port = self.cache_queue_port
+        name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_rank}_{cq_port}"
+        self.ctrl_output_queue = FMQ().queue(name, "producer")
+        logger.info(f"Init control output queue: {name} (producer)")
+
     def _init_storage(self, args):
         try:
             # TODO: support cache scale for other backend
@@ -354,13 +371,19 @@ class CacheTransferManager:
             raise ValueError(f"Invalid write policy: {args.write_policy}")
         self.write_policy = args.write_policy

-        version_file_path = os.path.join(args.model_path, "version.yaml")
-        if os.path.exists(version_file_path):
-            self.key_prefix = get_key_prefix_from_version(version_file_path)
-            logger.info(f"The key_prefix of cache storage is {self.key_prefix}")
+        self._update_key_prefix()

         logger.info("Initialize cache storage successfully")

+    def _update_key_prefix(self):
+        # use key_prefix to distinguish cache for different version of weight in rl
+        version_file_path = os.path.join(self.model_path, "version.yaml")
+        if os.path.exists(version_file_path):
+            self.key_prefix = get_key_prefix_from_version(version_file_path)
+            logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
+        else:
+            logger.error(f"version.yaml not found at {version_file_path}")
+
     def _init_storage_buffer(self, args):
         """
         Initialize pinned memory buffer that can hold the cache for a longest request
@@ -409,20 +432,20 @@ class CacheTransferManager:
         self.storage_value_scale_write_buffer = write_buffer + scale_buffer_total_bytes // 2
         self.storage_backend.register_buffer(write_buffer, scale_buffer_total_bytes)

-    def _init_gpu_cache(self, args):
+    def _init_gpu_cache(self):

-        if not args.create_cache_tensor:
-            logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners or messagers to create kv cache.")
+        if not self.create_cache_tensor:
+            logger.info("Waiting for runners or messagers to create kv cache.")
             while self.cache_ready_signal.value[self.rank] != 1:
                 time.sleep(0.1)
-            logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
+            logger.info("OK! Stop waiting.")

-        if args.cache_dtype == "block_wise_fp8":
+        if self.cache_dtype == "block_wise_fp8":
             cache_type = "uint8"
         else:
-            cache_type = args.cache_dtype
+            cache_type = self.cache_dtype

-        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
+        logger.info("Initializing kv cache for all layers.")
         set_device(self.device)
         for i in range(self.num_layers + self.num_extra_layers):
             # NOTE: num_extra_layer_gpu_blocks is usually equal to num_gpu_blocks
@@ -445,14 +468,12 @@ class CacheTransferManager:
                 self.value_cache_shape[2],
                 self.value_cache_shape[3],
             ]
-            if args.create_cache_tensor:
-                logger.info(
-                    f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
-                )
+            if self.create_cache_tensor:
+                logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
                 key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
                 set_data_ipc(key_cache, key_name)

-                if args.cache_dtype == "block_wise_fp8":
+                if self.cache_dtype == "block_wise_fp8":
                     key_cache_scales = paddle.full(
                         shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
                         fill_value=0,
@@ -463,7 +484,7 @@ class CacheTransferManager:
|
||||
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
|
||||
set_data_ipc(val_cache, val_name)
|
||||
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
if self.cache_dtype == "block_wise_fp8":
|
||||
value_cache_scales = paddle.full(
|
||||
shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
|
||||
fill_value=0,
|
||||
@@ -471,13 +492,11 @@ class CacheTransferManager:
|
||||
)
|
||||
set_data_ipc(value_cache_scales, value_cache_scales_name)
|
||||
else:
|
||||
logger.info(
|
||||
f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
|
||||
)
|
||||
logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
val_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
if self.cache_dtype == "block_wise_fp8":
|
||||
key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
|
||||
key_cache_scales = share_external_data_(
|
||||
key_cache_scales,
|
||||
@@ -487,7 +506,7 @@ class CacheTransferManager:
|
||||
)
|
||||
if self.value_cache_shape:
|
||||
val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
if self.cache_dtype == "block_wise_fp8":
|
||||
value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
|
||||
value_cache_scales = share_external_data_(
|
||||
value_cache_scales,
|
||||
@@ -498,48 +517,65 @@ class CacheTransferManager:

self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_kvs[key_cache_scales_name] = key_cache_scales
self.gpu_cache_scales_k_tensors.append(self.gpu_cache_kvs[key_cache_scales_name])
if args.value_cache_shape:
if self.value_cache_shape:
self.gpu_cache_kvs[val_name] = val_cache
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_kvs[value_cache_scales_name] = value_cache_scales
self.gpu_cache_scales_v_tensors.append(self.gpu_cache_kvs[value_cache_scales_name])

if args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
if self.create_cache_tensor:
self.cache_ready_signal.value[self.rank] = 1
while np.sum(self.cache_ready_signal.value) != self.n_ranks:
time.sleep(0.1)

cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")
logger.info("GPU KV cache is initialized")

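The create-or-attach flow above gates on a shared per-rank ready signal: each rank flips its own slot once its cache tensors exist, then spins until every slot is set. A minimal single-process sketch of that barrier pattern, using a plain Python list and threads in place of FastDeploy's shared-memory signal arrays (class and function names here are illustrative, not the real API):

```python
import threading
import time


class ReadySignalBarrier:
    """Toy stand-in for the shared cache_ready_signal array."""

    def __init__(self, n_ranks: int):
        self.value = [0] * n_ranks
        self.n_ranks = n_ranks

    def mark_ready(self, rank: int):
        # The rank flips its own slot, then spins until all ranks are ready.
        self.value[rank] = 1
        while sum(self.value) != self.n_ranks:
            time.sleep(0.01)


def worker(barrier: ReadySignalBarrier, rank: int):
    # ... allocate or attach this rank's cache here ...
    barrier.mark_ready(rank)


barrier = ReadySignalBarrier(n_ranks=4)
threads = [threading.Thread(target=worker, args=(barrier, r)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Clearing uses the same signal in reverse: each rank zeroes its slot and waits for the sum to reach zero, which is why init and clear can never race across ranks.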
def _init_cpu_cache(self, args):
def _clear_gpu_cache(self):
if self.create_cache_tensor:
logger.debug("Waiting for gpu runner to unlink cuda ipc")
while self.cache_ready_signal.value[self.rank] != 0:
time.sleep(0.1)
logger.debug("Stop waiting! gpu runner has unlinked cuda ipc")
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
if hasattr(self, "gpu_cache_scales_k_tensors"):
self.gpu_cache_scales_k_tensors.clear()
if hasattr(self, "gpu_cache_scales_v_tensors"):
self.gpu_cache_scales_v_tensors.clear()
paddle.device.cuda.empty_cache()
else:
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
logger.debug("Successfully unlinked gpu caches cuda ipc")
self.cache_ready_signal.value[self.rank] = 0

while np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("All ranks cleared gpu caches")

def _init_cpu_cache(self):
if self.num_cpu_blocks == 0:
return
paddle.set_device("cpu")
key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3]
if args.value_cache_shape:
if self.value_cache_shape:
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
else:
value_cache_size = 0
cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype)
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size
if args.cache_dtype == "block_wise_fp8":
key_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * key_cache_size
value_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * value_cache_size
logger.info("Initializing swap space (cpu cache) for all layers.")
if self.cache_dtype == "block_wise_fp8":
cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
scales_key_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
scales_value_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
)
if args.num_cpu_blocks == 0:
logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
self.swap_space_ready_signal.value[self.rank] = 1
return
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.")
paddle.set_device("cpu")
scales_key_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
scales_value_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
self.k_dst_ptrs = []
self.v_dst_ptrs = []
self.k_scales_ptrs = []
@@ -550,22 +586,43 @@ class CacheTransferManager:
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}"
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
f"..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
)
self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(scales_key_need_to_allocate_bytes)
self.k_scales_ptrs.append(self.cpu_cache_kvs[key_cache_scales_name])
if value_need_to_allocate_bytes > 0:
self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(scales_value_need_to_allocate_bytes)
self.v_scales_ptrs.append(self.cpu_cache_kvs[value_cache_scales_name])
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
logger.info("Swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1

while np.sum(self.swap_space_ready_signal.value) != self.n_ranks:
time.sleep(0.1)
logger.info("All ranks init cpu caches")

def _clear_cpu_cache(self):
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
if hasattr(self, "k_scales_ptrs"):
self.k_scales_ptrs.clear()
if hasattr(self, "v_scales_ptrs"):
self.v_scales_ptrs.clear()
gc.collect()
self.swap_space_ready_signal.value[self.rank] = 0

while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("All ranks cleared cpu caches")

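`_init_cpu_cache` sizes each layer's pinned host allocation as `num_cpu_blocks * bytes_per_item * per_block_elements`. Assuming the same `[num_blocks, heads, block_size, head_dim]` cache layout the diff indexes with `shape[1] * shape[2] * shape[3]`, the arithmetic can be sketched as (function name and example shape are illustrative):

```python
def swap_space_bytes(num_cpu_blocks: int, cache_shape, item_bytes: int) -> int:
    """Bytes of pinned host memory for one layer's key (or value) cache.

    cache_shape follows [num_blocks, heads, block_size, head_dim]; only the
    per-block dims matter, so the first entry is replaced by num_cpu_blocks.
    """
    per_block_elems = cache_shape[1] * cache_shape[2] * cache_shape[3]
    return num_cpu_blocks * item_bytes * per_block_elems


# e.g. 1024 CPU blocks, 8 heads x 64 block_size x 128 head_dim, 2-byte dtype
key_bytes = swap_space_bytes(1024, [0, 8, 64, 128], 2)
```

The block-wise FP8 path adds a separate scales allocation sized from `shape[1] * shape[2]` only, since scales drop the head_dim axis.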
def _run_read_storage(
self,
task_id: str,
@@ -1023,6 +1080,79 @@ class CacheTransferManager:
logger.debug(f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
logger.info(f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}")

def _handle_pause(self):
if self.is_paused:
logger.info("💡 Cache transfer manager is already paused, no need to pause again!")
else:
self.pause()
logger.info("✅ Successfully paused transfer")
return True

def _handle_resume(self):
if not self.is_paused:
logger.info("💡 Cache transfer manager is not paused, no need to resume!")
else:
self.resume()
if self.storage_backend_type is not None:
self._update_key_prefix()
logger.info("✅ Successfully resumed transfer")
return True

def _handle_sleep(self):
if self.is_sleeping:
logger.info("💡 Cache transfer manager is already sleeping, no need to sleep again!")
else:
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._clear_cpu_cache()
self._clear_gpu_cache()
self.is_sleeping = True
logger.info("✅ Successfully fell asleep (offloaded caches)")
return True

def _handle_wakeup(self):
if not self.is_sleeping:
logger.info("💡 Cache transfer manager is not sleeping, no need to wakeup!")
else:
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache()
self._init_gpu_cache()
self.is_sleeping = False
logger.info("✅ Successfully wakeup (reload caches)")
return True
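Each handler above guards on its state flag so the interface is safe to call multiple times (one of the stated goals of this PR). The sleep/wakeup pair reduces to a small idempotent state machine, sketched here with callables standing in for the `_clear_*_cache` / `_init_*_cache` methods:

```python
class SleepController:
    """Toy sleep/wakeup state machine mirroring the idempotent handlers."""

    def __init__(self, clear_fn, init_fn):
        self.is_sleeping = False
        self.clear_fn = clear_fn  # stand-in for _clear_cpu_cache/_clear_gpu_cache
        self.init_fn = init_fn    # stand-in for _init_cpu_cache/_init_gpu_cache

    def sleep(self) -> bool:
        if not self.is_sleeping:  # repeated calls are no-ops
            self.clear_fn()
            self.is_sleeping = True
        return True

    def wakeup(self) -> bool:
        if self.is_sleeping:
            self.init_fn()
            self.is_sleeping = False
        return True


calls = []
ctl = SleepController(lambda: calls.append("clear"), lambda: calls.append("init"))
ctl.sleep()
ctl.sleep()   # second sleep does not clear again
ctl.wakeup()
```

Because the state flag is checked before any side effect, a retried RPC (e.g. after a control-method timeout) cannot double-free or double-allocate the caches.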

def control_task(self, task: ControlRequest):
method = task.get_method()
tags = task.args.get("tags", {})
logger.info(f"Received control task: {method}, tags: {tags}")

handlers = {
"pause": self._handle_pause,
"resume": self._handle_resume,
"sleep": self._handle_sleep,
"wakeup": self._handle_wakeup,
}

handler = handlers.get(method)
error_code = 200
error_message = "Success"

if handler:
try:
handler()
except Exception as e:
error_code = 500
error_message = f"Failed to execute {method}: {str(e)}"
logger.error(f"Error in control_task: {traceback.format_exc()}")
else:
error_code = 400
error_message = f"Unknown control method: {method}"
logger.warning(error_message)

self.cache_task_queue.barrier.wait()
resp = ControlResponse(task.request_id, error_code, error_message)
asyncio.run(self.ctrl_output_queue.put(resp))
logger.info(f"Put response into output queue {self.ctrl_output_queue.name}: {resp}")
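The dispatch in `control_task` maps every outcome to an HTTP-style status: 200 on success, 400 for an unknown method, 500 when a handler raises. A self-contained sketch of that mapping (standalone function rather than the real method):

```python
def dispatch_control(method: str, handlers: dict) -> tuple:
    """Look up and run a control handler, mapping outcomes to status codes."""
    handler = handlers.get(method)
    if handler is None:
        return 400, f"Unknown control method: {method}"
    try:
        handler()  # handlers are idempotent, so retries are safe
        return 200, "Success"
    except Exception as e:
        # real code also logs traceback.format_exc() here
        return 500, f"Failed to execute {method}: {e}"


def boom():
    raise RuntimeError("simulated failure")


handlers = {"pause": lambda: None, "sleep": boom}
ok = dispatch_control("pause", handlers)
failed = dispatch_control("sleep", handlers)
unknown = dispatch_control("nope", handlers)
```

Note that the barrier wait in the real method runs regardless of the code, so all ranks re-synchronize before the response is published.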

def check_work_status(self, time_interval_threashold=envs.FD_CACHE_PROC_EXIT_TIMEOUT):
"""
Check the health of the model server by checking whether all workers are alive.
@@ -1042,8 +1172,13 @@ class CacheTransferManager:
return fn(*args)
finally:
self.inflight -= 1
logger.debug(f"submit_task: {fn.__name__} finished, args: {args}, current inflight: {self.inflight}")

self.inflight += 1
thread_pool.submit(inflight_task, task_fn, *args)
logger.debug(
f"submit_task: {task_fn.__name__} submitted to thread pool, args: {args}, current inflight: {self.inflight}"
)

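The inflight bookkeeping in `submit_task` (increment before submission, decrement in a `finally` so failures still count down) is what lets `pause` know when all transfer work has drained. A simplified stand-in, with a lock in place of FastDeploy's shared signal arrays (class name is illustrative):

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class InflightTracker:
    """Counts tasks submitted to a pool, decrementing as each one finishes."""

    def __init__(self, max_workers: int = 4):
        self.inflight = 0
        self._lock = threading.Lock()
        self.pool = ThreadPoolExecutor(max_workers=max_workers)

    def submit(self, fn, *args):
        def wrapped(*a):
            try:
                return fn(*a)
            finally:
                # runs even if fn raises, so the counter cannot leak
                with self._lock:
                    self.inflight -= 1

        with self._lock:
            self.inflight += 1
        return self.pool.submit(wrapped, *args)


tracker = InflightTracker()
futures = [tracker.submit(time.sleep, 0.01) for _ in range(8)]
for f in futures:
    f.result()
```

A pause loop then only needs `while tracker.inflight != 0: time.sleep(0.1)` to wait for quiescence.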
def do_data_transfer(self):
"""
@@ -1054,6 +1189,7 @@ class CacheTransferManager:
max_errors = (
envs.FD_CACHE_PROC_ERROR_COUNT
) # After this many consecutive errors, check if the worker process exists.
is_paused = False

while True:
try:
@@ -1065,18 +1201,18 @@ class CacheTransferManager:
self.cache_task_queue.barrier0.reset()

# Ensure all ranks synchronically do one of the following things:
# (1) If rank#0 is paused, wait for a short time and check out rank#0 status again;
# (2) otherwise, all ranks are allowed to pull tasks from cache task queue
# (1) If rank#0 is paused, wait for inflight tasks to finish first, then only process control tasks;
# (2) otherwise, all ranks are allowed to pull all tasks from cache task queue
if self.cache_task_is_paused_signal.value[0] == 1:
# wait for inflight tasks to finish first
while self.inflight != 0:
time.sleep(0.1)
# mark the current rank as not having inflight tasks
self.cache_task_inflight_signal.value[self.rank] = 0
time.sleep(1)
continue
is_paused = True
else:
self.cache_task_inflight_signal.value[self.rank] = 1
is_paused = False

if self.rank == 0:
if not self.cache_task_queue.empty():
@@ -1087,12 +1223,16 @@ class CacheTransferManager:
self.cache_task_queue.barrier1.reset()

if self.cache_task_broadcast_signal.value[0] == 1:
self.inflight += 1
data, read_finish = self.cache_task_queue.get_transfer_task()
logger.debug(f"do_data_transfer: {data}")
if read_finish:
self.cache_task_broadcast_signal.value[0] = 0
event_type, event_args = data[0], data[1:]

# control task is the only task allowed to execute when loop is paused
if is_paused and event_type.value != CacheStatus.CTRL.value:
continue

if event_type.value == CacheStatus.SWAP2CPU.value:
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.submit_task(
@@ -1129,6 +1269,13 @@ class CacheTransferManager:
self.write_back_storage_task,
write_storage_task,
)
elif event_type.value == CacheStatus.CTRL.value:
control_task = event_args[0]
self.control_task_thread_pool.submit(
self.control_task,
control_task,
)

else:
if self.n_ranks > 1:
self.cache_task_queue.barrier2.wait()
@@ -1291,114 +1438,46 @@ class CacheTransferManager:
# TODO XPU support RL
if unset_data_ipc is None:
return
logger.info("[RL] Launch a thread to clear/restore kv cache when model weights are cleared/updated.")
logger.info(
"check_cache_status: Launch a thread to clear/restore kv cache when model weights are cleared/updated."
)
while True:
# handle cache clearing/restoring
if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
try:
# wait for inflight transfer tasks to finish and pause transfer manager
# pause transfer
self.pause()

# clear cpu caches
logger.info("[RL] start clearing caches")
logger.debug("[RL] start clearing cpu caches")
# clear caches
logger.info("check_cache_status: start clearing caches")
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
paddle.set_device("cpu")
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
if self.cache_dtype == "block_wise_fp8":
self.k_scales_ptrs.clear()
self.v_scales_ptrs.clear()
gc.collect()
logger.debug("[RL] successfully cleared cpu caches")
# reset swap_space_ready_signal
self.swap_space_ready_signal.value[self.rank] = 0
while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
logger.debug("[RL] all ranks cleared cpu caches")
else:
logger.debug("[RL] skip clearing cpu caches")

# clear gpu caches
logger.debug("[RL] start clearing gpu caches")
if args.create_cache_tensor:
logger.info("[RL] waiting for gpu runner to unlink cuda ipc")
while self.cache_ready_signal.value[self.rank] != 0:
time.sleep(0.1)
logger.info("[RL] stop waiting! gpu runner has unlinked cuda ipc")
paddle.set_device(f"gpu:{self.device}")
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_scales_k_tensors.clear()
self.gpu_cache_scales_v_tensors.clear()
paddle.device.cuda.empty_cache()
logger.debug("[RL] successfully cleared gpu caches")
else:
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
logger.debug("[RL] successfully unlinked gpu caches cuda ipc")
self.cache_ready_signal.value[self.rank] = 0

while np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("[RL] all ranks cleared caches!")

# reset kv_cache_status_signal
self._clear_cpu_cache()
self._clear_gpu_cache()
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED

self._log_memory("after clearing caches")

self._log_memory("check_cache_status: after clearing caches")
except Exception as e:
logger.error(f"[RL] failed to clear caches: {e}")
logger.error(f"check_cache_status: failed to clear caches: {e}")

elif self.kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
assert args.splitwise_role == "mixed", "Only mixed mode supports updating cache."
try:
# restore cpu cache
logger.info("[RL] start restoring caches")
logger.debug("[RL] start restoring cpu caches")
logger.info("check_cache_status: start restoring caches")
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache(args)
logger.debug("[RL] successfully restored cpu caches")
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
time.sleep(0.1)
logger.debug("[RL] all ranks restored cpu caches")
else:
logger.debug("[RL] skip restoring cpu caches")

# restore gpu cache and set cache_ready_signal
logger.debug("[RL] start restoring gpu caches")
self._init_gpu_cache(args)
logger.debug("[RL] successfully restored gpu caches")
self._init_cpu_cache()
self._init_gpu_cache()

# update key prefix for kv cache backend
if self.storage_backend_type is not None:
# use key_prefix to distinguish cache for different version of weight in rl
version_file_path = os.path.join(args.model_path, "version.yaml")
assert os.path.exists(version_file_path), f"version.yaml not found at {version_file_path}"
self.key_prefix = get_key_prefix_from_version(version_file_path)
logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")

# wait for all ranks caches to be ready
while np.sum(self.cache_ready_signal.value) != args.mp_num:
time.sleep(0.1)
logger.info("[RL] all ranks restored caches!")
self._update_key_prefix()

# resume transfer
self.resume()

# set kv_cache_status_signal
self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL

self._log_memory("after restoring caches")
self._log_memory("check_cache_status: after restoring caches")

except Exception as e:
logger.error(f"[RL] failed to restore caches: {e}")
logger.error(f"check_cache_status: failed to restore caches: {e}")

time.sleep(0.1)

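After a weight update, `_update_key_prefix` re-derives the storage key prefix from the checkpoint's `version.yaml`, so KV entries written under the old weights can never be served for the new ones. The diff does not show how `get_key_prefix_from_version` computes the prefix; the sketch below is a hypothetical analog that hashes the file contents, purely to illustrate the versioned-namespace idea:

```python
import hashlib
import os
import tempfile


def key_prefix_from_version(version_file_path: str) -> str:
    """Hypothetical analog of get_key_prefix_from_version.

    The real derivation scheme is not shown in the diff; a content hash is
    used here so any change to version.yaml yields a fresh prefix.
    """
    if not os.path.exists(version_file_path):
        raise FileNotFoundError(f"version.yaml not found at {version_file_path}")
    with open(version_file_path, "rb") as f:
        digest = hashlib.sha1(f.read()).hexdigest()[:8]
    return f"kv_{digest}"


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "version.yaml")
    with open(path, "w") as f:
        f.write("version: 1\n")
    p1 = key_prefix_from_version(path)
    with open(path, "w") as f:
        f.write("version: 2\n")
    p2 = key_prefix_from_version(path)
```

Prefixing every storage key this way makes stale entries unreachable without needing to eagerly delete them from the backend.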
@@ -1407,11 +1486,11 @@ class CacheTransferManager:
self.cache_task_queue.pause_barrier.wait()
if self.rank == 0:
self.cache_task_queue.pause_barrier.reset()
logger.info("[RL] 🟠 wait for inflight transfer tasks to finish")
logger.info("pause: 🟠 wait for inflight transfer tasks to finish")
self.is_paused = True
while np.sum(self.cache_task_inflight_signal.value) != 0:
time.sleep(0.1)
logger.info("[RL] 🔴 pause transfer manager and stop do transfer tasks")
logger.info("pause: 🔴 pause transfer manager and stop do transfer tasks")

def resume(self):
if self.n_ranks > 1:
@@ -1421,7 +1500,7 @@ class CacheTransferManager:
self.is_paused = False
while np.sum(self.cache_task_inflight_signal.value) != self.n_ranks:
time.sleep(0.1)
logger.info("[RL] 🟢 resume transfer manager and start to do transfer tasks")
logger.info("resume: 🟢 resume transfer manager and start to do transfer tasks")

def _log_memory(self, context: str):
"""Log current GPU memory usage."""
@@ -1430,8 +1509,8 @@ class CacheTransferManager:
curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3)
curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3)

logger.warning(
f"GPU memory usage {context}:"
logger.info(
f"{context}: "
f"max_allocated: {max_alloc:.2f}GB "
f"max_reserved: {max_reserved:.2f}GB "
f"current_allocated: {curr_alloc:.2f}GB "

@@ -217,7 +217,7 @@ class PrefixCacheManager:
is_server=False,
num_client=tensor_parallel_size,
client_id=0,
local_data_parallel_id=self.local_data_parallel_id,
local_data_parallel_id=0,
)

current_dir_path = os.path.split(os.path.abspath(__file__))[0]
@@ -293,7 +293,7 @@ class PrefixCacheManager:
else:
storage_arg_str = " "

if self.cache_config.swap_space or self.cache_config.kvcache_storage_backend:
if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend:
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
@@ -314,7 +314,6 @@ class PrefixCacheManager:
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --ipc_suffix {ipc_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
@@ -353,9 +352,8 @@ class PrefixCacheManager:

# Start additional threads
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
threading.Thread(target=self.recv_data_transfer_result, daemon=True).start()
if cache_config.enable_prefix_caching:
if cache_config.enable_prefix_caching and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
threading.Thread(target=self.clear_prefix_cache, daemon=True).start()

all_cache_processes = cache_messager_processes + cache_manager_processes

@@ -687,6 +687,9 @@ class ParallelConfig:
if self.shutdown_comm_group_if_worker_idle is None:
self.shutdown_comm_group_if_worker_idle = not self.use_ep

if self.shutdown_comm_group_if_worker_idle and envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
raise RuntimeError("shutdown_comm_group_if_worker_idle cannot be True when FD_ENABLE_V1_UPDATE_WEIGHTS=1")

# pd_disaggregation
use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))

@@ -39,6 +39,8 @@ import zmq
from tqdm import tqdm

import fastdeploy.metrics.trace as tracing
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
ControlRequest,
@@ -84,7 +86,7 @@ class EngineService:
Base class containing common engine functionality
"""

def __init__(self, cfg, start_queue=True, use_async_llm=False):
def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False):
"""
Initializes the LLMEngine with the provided configuration.

@@ -104,14 +106,23 @@ class EngineService:
self.is_paused = False # pause request generation
self._pause_cond = threading.Condition()

self._ctrl_worker_output_queues = []
self._ctrl_output_queues = {}
self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict)
tp_size = cfg.parallel_config.tensor_parallel_size
dp_index = cfg.parallel_config.local_data_parallel_id
for rank in range(tp_size):
for tp_rank in range(tp_size):
# create worker control response queue
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")

# create cache control response queue
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port
name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}"
self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")

self.scheduler = cfg.scheduler_config.scheduler()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
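With one control response queue per worker (and per cache process), the engine collects replies by awaiting one `get()` per queue under a single overall deadline, as `_wait_all_control_responses` does with `asyncio.wait_for`. A simplified sketch, with `asyncio.Queue` standing in for the FMQ consumer queues (names are illustrative):

```python
import asyncio


async def wait_all_responses(queues, timeout_s: float):
    """Await one response per control queue, bounded by one overall timeout.

    A single wait_for caps total waiting time across every queue, so a stuck
    worker surfaces as asyncio.TimeoutError instead of a silent hang.
    """
    tasks = [q.get() for q in queues]
    return await asyncio.wait_for(asyncio.gather(*tasks), timeout=timeout_s)


async def demo():
    queues = [asyncio.Queue() for _ in range(3)]
    for i, q in enumerate(queues):
        q.put_nowait({"rank": i, "code": 200})
    return await wait_all_responses(queues, timeout_s=1.0)


responses = asyncio.run(demo())
```

`asyncio.gather` preserves input order, so responses line up with ranks; on timeout the pending `get()` tasks are cancelled rather than leaked, which is the stale-response problem this PR's "fix stale control responses" commit addresses.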
@@ -1296,7 +1307,7 @@ class EngineService:
worker_pid = self.request_worker_map.pop(request_id, None)

try:
self.llm_logger.info(f"START run control method {request_id}: {method}")
self.llm_logger.info(f"Start to run control method {method}: {request_id}")

handler_name = f"_control_{method}"
handler = getattr(self, handler_name, None)
@@ -1308,13 +1319,13 @@ class EngineService:
return

result = handler(control_req)
self.llm_logger.info(f"SUCCESS run control method {method}.")
self.llm_logger.info(f"Successfully run control method {method}: {request_id} {result}")
succ_result = ControlResponse(request_id, 200, "Success", result)
data = [[succ_result]] if envs.ZMQ_SEND_BATCH_DATA else [succ_result]
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)

except Exception as e:
error_msg = f"Failed run control method {method}: {str(e)}"
error_msg = f"Failed to run control method {method}: {request_id} {str(e)}"
self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
error_result = ControlResponse(request_id, 500, error_msg)
data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
@@ -1338,12 +1349,15 @@ class EngineService:
if self.cfg.scheduler_config.name != "local":
raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")

self.llm_logger.info("Start to pause request generation.")

with self._pause_cond:
if self.is_paused:
self.llm_logger.info("Pause Request Generation: already paused.")
self.llm_logger.info("Engine is already paused, no need to pause again.")
return
self.is_paused = True

self.llm_logger.info("Start Abort Running Requests")
self.llm_logger.info("Abort running requests.")

self.resource_manager.log_status()
# preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue
@@ -1354,7 +1368,7 @@ class EngineService:
if count >= timeout * 1000:
break
if count >= timeout * 1000:
error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged"
error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!"
self.llm_logger.error(error_msg)
raise Exception(error_msg)
running_reqs = self.resource_manager.preempted_all()
@@ -1369,12 +1383,22 @@ class EngineService:

# abort inflight requests to user
inflight_requests = self.scheduler.get_inflight_requests()
self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests")
self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).")
for req in inflight_requests:
self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
self._send_error_response(req.request_id, "Request is aborted since engine is paused.")
self.scheduler.reset()

# pause cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to pause cache transfer.")
pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause")
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"]))
self.llm_logger.info("Successfully paused cache transfer.")

self.resource_manager.cache_manager.reset()
self.llm_logger.info("Successfully paused request generation.")
return None

def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
@@ -1386,14 +1410,24 @@ class EngineService:
Args:
control_request: Control request object containing resume operation information
"""
self.llm_logger.info("START Resume Request Generation")
self.llm_logger.info("Start to resume request generation.")
with self._pause_cond:
if not self.is_paused:
self.llm_logger.info("Resume Request Generation: not paused.")
self.llm_logger.info("Engine is not paused, no need to resume.")
return None
self.is_paused = False
self._pause_cond.notify_all()
self.llm_logger.info("END Resume Request Generation")

# resume cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to resume cache transfer.")
resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume")
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"]))
self.llm_logger.info("Successfully resumed cache transfer.")

self.llm_logger.info("Successfully resumed request generation.")
return None

def _control_is_paused(self, control_request: ControlRequest) -> bool:
@@ -1447,49 +1481,218 @@ class EngineService:

return responses

async def _wait_all_control_responses(self, request_id: str, timeout: int):
"""Wait for control responses from all workers with a global timeout.

This method concurrently waits for responses from all control workers
and enforces an overall timeout to avoid leaking pending tasks.
def _parse_tags(self, control_request: ControlRequest):
"""
timeout_ms = timeout * 1000
# Create one get() coroutine per worker output queue
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]

try:
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True),
|
||||
timeout=timeout,
|
||||
Parse tags from control request.
|
||||
"""
|
||||
allowed_tags = ["weight", "kv_cache"]
|
||||
tags = control_request.args.get("tags", None)
|
||||
if tags is None:
|
||||
tags = ",".join(allowed_tags)
|
||||
control_request.args["tags"] = tags
|
||||
self.llm_logger.info(
|
||||
f"Detected empty tags of request {control_request.request_id}, defaulting to tags: {tags}"
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Keep the error message consistent with previous behavior
|
||||
raise Exception("Worker Update Weights Timeouted after 600s")
|
||||
elif isinstance(tags, list):
|
||||
tags = ",".join(tags)
|
||||
|
||||
for tag in tags.split(","):
|
||||
if tag not in allowed_tags:
|
||||
raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}")
|
||||
|
||||
return tags
|
||||
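The tag-normalization logic in `_parse_tags` above can be sketched standalone. The `parse_tags` helper below is a hypothetical stand-in that takes the request's `args` dict directly instead of a `ControlRequest` object:

```python
ALLOWED_TAGS = ["weight", "kv_cache"]

def parse_tags(args: dict) -> str:
    """Normalize the 'tags' argument to a comma-separated string and validate it."""
    tags = args.get("tags")
    if tags is None:
        # Default to all supported tags when none are given.
        tags = ",".join(ALLOWED_TAGS)
    elif isinstance(tags, list):
        tags = ",".join(tags)
    for tag in tags.split(","):
        if tag not in ALLOWED_TAGS:
            raise ValueError(f"Unsupported tag [{tag}], expected one of {ALLOWED_TAGS}")
    return tags
```

So `parse_tags({})` yields `"weight,kv_cache"`, a list is flattened to a comma-separated string, and any unknown tag raises `ValueError` before anything is dispatched to workers.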

    def _control_sleep(self, control_request: ControlRequest):
        """
        Offload gpu memory occupation for certain parts, e.g. weight, cache.

        Args:
            control_request: Control request object containing parameters for offloading memory
                tags: list of tags to offload, supported values: ["weight", "cache"]

        TODO: support different level of offloading, to provide options for release memory forever
        or merely offloading to cpu memory for now.
        """
        # Args check
        tags = self._parse_tags(control_request)
        control_request.args["tags"] = tags

        # Make sure llm engine is paused.
        self.llm_logger.warning(
            "Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. "
            "Please explicitly request to /pause the engine before /sleep."
        )
        self._control_pause(None)

        # Determine which executors are needed for the sleep command
        executors = set()
        if "weight" in tags:
            executors.add("worker")
        if "kv_cache" in tags:
            executors.add("worker")
            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
                executors.add("cache_transfer")
            if self.cfg.cache_config.enable_prefix_caching:
                self.resource_manager.cache_manager.reset()

        # Dispatch sleep request to executors
        self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}")
        self._dispatch_control_request(control_request, executors)
        return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors))
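The executor-selection step above reduces to a small mapping from tags to executor groups. A standalone sketch, where `has_cpu_cache` is a hypothetical stand-in for the `cfg.cache_config.num_cpu_blocks` / `kvcache_storage_backend` checks:

```python
def select_executors(tags: str, has_cpu_cache: bool) -> set:
    """Map control-request tags to the executor groups that must handle them."""
    executors = set()
    if "weight" in tags.split(","):
        executors.add("worker")          # weights live in the worker processes
    if "kv_cache" in tags.split(","):
        executors.add("worker")          # GPU KV blocks also live in workers
        if has_cpu_cache:
            executors.add("cache_transfer")  # CPU/storage cache has its own process
    return executors
```

This is why a weight-only sleep touches only the workers, while a KV-cache sleep additionally signals the cache-transfer process when CPU or storage caching is configured.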

    def _control_wakeup(self, control_request: ControlRequest):
        """
        Reload offloaded gpu memory occupation for certain parts, e.g. weight, cache.

        Args:
            control_request: Control request object containing parameters for reloading memory
                tags: list of tags to reload, supported values: ["weight", "kv_cache"]
        """
        # Args check
        tags = self._parse_tags(control_request)
        control_request.args["tags"] = tags

        # Determine which executors are needed for the wakeup command
        executors = set()
        if "weight" in tags:
            executors.add("worker")
        if "kv_cache" in tags:
            executors.add("worker")
            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
                executors.add("cache_transfer")

        # Dispatch wakeup request to executors
        self.llm_logger.info(f"Dispatch wakeup request to executors: {list(executors)}")
        self._dispatch_control_request(control_request, executors)
        result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 300, executors=executors))

        # Resume the engine after wakeup
        self._control_resume(None)

        return result

    def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]):
        """
        Dispatch control requests to workers, cache managers or engine itself.

        Args:
            control_request: ControlRequest
            executors: List
        """
        if "worker" in executors:
            self.engine_worker_queue.put_tasks(([control_request], 1))
        if "cache_transfer" in executors:
            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
                self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request))
        return

    async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None):
        """Wait for matching control responses from the selected executor queues.

        This helper selects the control-response queues that belong to the requested
        executors, then waits for all of them concurrently. Each queue gets a local
        waiter that keeps reading until it sees the target request ID and stashes stale
        responses into that queue's mailbox.

        Args:
            request_id: The control request ID that all returned responses must match.
            timeout: Global timeout budget in seconds for the full multi-queue wait.
            executors: Executor groups to wait for, for example `["worker"]` or
                `["worker", "cache_transfer"]`. If `None`, waits for all control
                response queues.

        Returns:
            A list of `response.result` values collected from all matched
            `ControlResponse` objects. If no queue is selected, returns `None`.

        Raises:
            Exception: If the overall wait times out, or if any queue reports a non-200
                control response or fails while waiting.
        """

        def select_control_queues(executors: List[str] = None):
            """Select control response queues by executors."""
            if executors is None:
                return self._ctrl_output_queues
            else:
                queues = {}
                for k, v in self._ctrl_output_queues.items():
                    if "w2e" in k and "worker" in executors:
                        queues[k] = v
                    elif "c2e" in k and "cache_transfer" in executors:
                        queues[k] = v
                return queues

        async def wait_one(queue_name: str, queue):
            """Wait until one queue returns a response for the current request_id."""
            mailbox = self._ctrl_response_mailboxes[queue_name]
            # Reuse a previously stashed response for this request before touching FMQ again.
            cached_response = mailbox.pop(request_id, None)
            if cached_response is not None:
                self.llm_logger.info(f"Returning cached control response from {queue_name}.")
                return cached_response

            while True:
                msg = await queue.get()

                # Return if the response matches the control request
                response: ControlResponse = msg.payload
                if response.request_id == request_id:
                    self.llm_logger.info(f"Returning new control response from {queue_name}.")
                    return response

                # Stash late responses from other control requests so they do not consume the
                # current request's only read chance on this queue.
                mailbox[response.request_id] = response
                self.llm_logger.info(
                    f"Stashed old control response from {queue_name}. "
                    f"Expected request {request_id}, got request {response.request_id}"
                )

        # Select only the control response queues that belong to the requested executors.
        queues = select_control_queues(executors)
        if not queues:
            self.llm_logger.info(f"No queues to wait for, executors: {executors}")
            return
        self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}")

        # Each queue gets its own waiter, which will stash stale responses until it finds the
        # target request ID for this control request.
        tasks = {name: asyncio.create_task(wait_one(name, queue)) for name, queue in queues.items()}
        done, pending = await asyncio.wait(tasks.values(), timeout=timeout)
        if pending:
            pending_names = [name for name, task in tasks.items() if task in pending]
            done_names = [name for name, task in tasks.items() if task in done]
            self.llm_logger.error(
                f"Control request {request_id} execution timeout. "
                f"Pending queues: {pending_names}, completed queues: {done_names}."
            )
            # Stop unfinished queue waiters so they do not outlive the control request.
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            raise Exception(f"Control request {request_id} timed out after {timeout}s")

        # Collect the results from all completed queues.
        responses = []
        for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
            if isinstance(msg, Exception):
                self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
                raise Exception(f"Call Worker error: {repr(msg)}")
            if msg is None:
                # Preserve original semantics when no message is received
                raise Exception("Worker Update Weights Timeouted after 600s")
            response: ControlResponse = msg.payload
            if response.request_id != request_id:
                self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
                continue
        for name, task in tasks.items():
            try:
                response = task.result()
            except Exception as e:
                self.llm_logger.error(f"Waiting for control response from {name} failed: {repr(e)}")
                raise

            if response.error_code != 200:
                self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
                raise Exception(f"Call Worker error: {response.error_message}")
            self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
                raise Exception(f"Error response from {name}: {response.error_message}")
            responses.append(response.result)

        return responses
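The per-queue mailbox pattern above (read until the target request ID appears, stash anything stale) is what fixes the "stale control responses after a timed-out control method" bug this PR mentions. A minimal self-contained sketch of that pattern, with queue items simplified to `(request_id, result)` tuples:

```python
import asyncio

async def wait_one(queue: asyncio.Queue, mailbox: dict, request_id: str):
    """Read one queue until the target request_id appears.

    Responses left over from earlier (e.g. timed-out) control requests are kept
    in a per-queue mailbox so they cannot consume a later request's read."""
    if request_id in mailbox:
        # Reuse a previously stashed response for this request.
        return mailbox.pop(request_id)
    while True:
        rid, result = await queue.get()
        if rid == request_id:
            return result
        mailbox[rid] = result  # stash the stale response for a later caller

async def demo():
    queue, mailbox = asyncio.Queue(), {}
    queue.put_nowait(("old-req", "stale"))  # left over from a timed-out request
    queue.put_nowait(("new-req", "ok"))
    result = await wait_one(queue, mailbox, "new-req")
    return result, mailbox

result, mailbox = asyncio.run(demo())
```

The stale `old-req` response is skipped over and parked in the mailbox instead of being mistaken for the current request's reply.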

    def _call_worker(self, control_request: ControlRequest, timeout: int):
        request_id = control_request.request_id
        self.engine_worker_queue.put_tasks(([control_request], 1))
        # Use a single asyncio.run() to concurrently wait for all worker responses.
        return asyncio.run(self._wait_all_control_responses(request_id, timeout))
        return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"]))

    def _send_error_response(self, request_id, error_msg, error_code: int = 500, worker_pid=None):
        self.llm_logger.error(

@@ -595,11 +595,14 @@ class EngineClient:
        return True, ""

    async def run_control_method(self, request: ControlRequest):
        api_server_logger.info(f"Start Run Control Method: {request}")
        api_server_logger.info(f"Received control request: {request}")
        req_dict = request.to_dict()
        if envs.ZMQ_SEND_BATCH_DATA:
            req_dict["zmq_worker_pid"] = self.worker_pid
        self.zmq_client.send_json(req_dict)
        if not self.enable_mm and not envs.ENABLE_V1_DATA_PROCESSOR:
            self.zmq_client.send_json(req_dict)
        else:
            self.zmq_client.send_pyobj(req_dict)
        request_id = request.request_id
        dealer, response_queue = await self.connection_manager.get_connection(request_id)
        if not envs.ZMQ_SEND_BATCH_DATA:
@@ -608,12 +611,29 @@
            # todo: support user specified timeout. default 600s is enough for most control cases
            response = await asyncio.wait_for(response_queue.get(), timeout=600)
            response = ControlResponse.from_dict(response[0])
            api_server_logger.info(f"End Run Control Method: {response}")
            api_server_logger.info(f"Return control response: {response}")
            return response
        except asyncio.TimeoutError:
            error_response = ControlResponse(request_id, 500, "Timeout waiting for control method response")
            api_server_logger.error(f"Error Run Control Method: {error_response}")
            api_server_logger.error(f"Control request timed out: {error_response}")
            return error_response
        except Exception as e:
            import traceback

            api_server_logger.error(f"Unknown error in control method: {str(e)}\n{traceback.format_exc()}")
            error_response = ControlResponse(request_id, 500, str(e))
            return error_response

    def run_control_method_sync(self, request: ControlRequest, event_loop):
        """
        Support running control methods by a synchronous caller.

        NOTE: Since asyncio.Queue operations must occur in the same event loop,
        this method bridges synchronous and asynchronous execution by running
        the async run_control_method in the specified event loop.
        """
        future = asyncio.run_coroutine_threadsafe(self.run_control_method(request), event_loop)
        return future.result()
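The sync/async bridge in `run_control_method_sync` relies on `asyncio.run_coroutine_threadsafe`: the coroutine is submitted to the event loop that owns the response queues, and the synchronous caller blocks on the returned `concurrent.futures.Future`. A minimal self-contained sketch of the same pattern (`control` is a hypothetical stand-in for the real async control call):

```python
import asyncio
import threading

async def control(method: str) -> str:
    await asyncio.sleep(0)  # stand-in for the real async control round-trip
    return f"{method}: done"

def run_control_sync(method: str, loop: asyncio.AbstractEventLoop) -> str:
    """Submit the coroutine to an event loop running in another thread and block on it."""
    future = asyncio.run_coroutine_threadsafe(control(method), loop)
    return future.result(timeout=5)

# The loop runs in a background thread, as the API server's loop does relative
# to a synchronous endpoint handler.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()
result = run_control_sync("sleep", loop)
loop.call_soon_threadsafe(loop.stop)
```

This is why `lifespan` stores `app.state.event_loop`: synchronous endpoints need a handle to the one loop whose `asyncio.Queue` objects hold the responses.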

    def is_workers_alive(self):
        """

@@ -284,6 +284,7 @@ async def lifespan(app: FastAPI):
    app.state.completion_handler = completion_handler
    app.state.embedding_handler = embedding_handler
    app.state.reward_handler = reward_handler
    app.state.event_loop = asyncio.get_running_loop()

    if llm_engine is not None and not isinstance(llm_engine, AsyncLLM):
        llm_engine.engine.data_processor = engine_client.data_processor
@@ -406,6 +407,44 @@ async def is_paused(request: Request) -> Response:
    return control_response.to_api_json_response()


@app.post("/v1/sleep")
async def sleep(request: Request) -> Response:
    request_id = f"control-{uuid.uuid4()}"
    # Support both JSON body and query parameter
    if await request.body():
        request_data = await request.json()
    else:
        # Extract query params
        request_data = dict(request.query_params)

    try:
        control_request = ControlRequest(request_id, "sleep", request_data)
    except TypeError as e:
        return JSONResponse(status_code=400, content={"error": "Invalid parameter type", "message": str(e)})

    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()


@app.post("/v1/wakeup")
async def wakeup(request: Request) -> Response:
    request_id = f"control-{uuid.uuid4()}"
    # Support both JSON body and query parameter
    if await request.body():
        request_data = await request.json()
    else:
        # Extract query params
        request_data = dict(request.query_params)

    try:
        control_request = ControlRequest(request_id, "wakeup", request_data)
    except TypeError as e:
        return JSONResponse(status_code=400, content={"error": "Invalid parameter type", "message": str(e)})

    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
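Both endpoints above accept their parameters either as a JSON body or as query parameters. That normalization step can be isolated as a small sketch (`extract_request_data` is a hypothetical helper, not part of FastDeploy):

```python
import json

def extract_request_data(body: bytes, query_params: dict) -> dict:
    """JSON body takes precedence; otherwise fall back to query parameters."""
    if body:
        return json.loads(body)
    return dict(query_params)
```

So `POST /v1/sleep` with body `{"tags": "weight"}` and `POST /v1/sleep?tags=weight` should both reach `_control_sleep` with the same arguments.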


@app.post("/v1/update_weights")
async def update_weights(request: Request) -> Response:
    request_id = f"control-{uuid.uuid4()}"
@@ -606,8 +645,14 @@ def update_model_weight(request: Request) -> Response:
    update model weight
    """
    if app.state.dynamic_load_weight:
        status_code, msg = app.state.engine_client.update_model_weight()
        return JSONResponse(content=msg, status_code=status_code)
        if envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
            request_id = f"control-{uuid.uuid4()}"
            control_request = ControlRequest(request_id, "wakeup")
            control_response = app.state.engine_client.run_control_method_sync(control_request, app.state.event_loop)
            return control_response.to_api_json_response()
        else:
            status_code, msg = app.state.engine_client.update_model_weight()
            return JSONResponse(content=msg, status_code=status_code)
    else:
        return JSONResponse(content={"error": "Dynamic Load Weight Disabled."}, status_code=404)

@@ -619,8 +664,14 @@ def clear_load_weight(request: Request) -> Response:
    clear model weight
    """
    if app.state.dynamic_load_weight:
        status_code, msg = app.state.engine_client.clear_load_weight()
        return JSONResponse(content=msg, status_code=status_code)
        if envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
            request_id = f"control-{uuid.uuid4()}"
            control_request = ControlRequest(request_id, "sleep")
            control_response = app.state.engine_client.run_control_method_sync(control_request, app.state.event_loop)
            return control_response.to_api_json_response()
        else:
            status_code, msg = app.state.engine_client.clear_load_weight()
            return JSONResponse(content=msg, status_code=status_code)
    else:
        return JSONResponse(content={"error": "Dynamic Load Weight Disabled."}, status_code=404)


@@ -122,12 +122,21 @@ class ChatResponseProcessor:
            else:
                self._audio_buffer[req_id] = [token_ids]
        else:
            yield self.data_processor.process_response_dict(
                response_dict=request_output,
                stream=stream,
                include_stop_str_in_output=include_stop_str_in_output,
                request=request,
            )
            if self._is_async_processor:
                response = await self.data_processor.process_response_dict(
                    response_dict=request_output,
                    stream=stream,
                    include_stop_str_in_output=include_stop_str_in_output,
                    request=request,
                )
            else:
                response = self.data_processor.process_response_dict(
                    response_dict=request_output,
                    stream=stream,
                    include_stop_str_in_output=include_stop_str_in_output,
                    request=request,
                )
            yield response
        elif stream:
            decode_type = request_output["outputs"].get("decode_type", 0)
            token_ids = request_output["outputs"]["token_ids"]

@@ -184,9 +184,14 @@ class DealerConnectionManager:
                self.request_num[request_id] -= 1
                if self.request_num[request_id] == 0:
                    self._update_load(conn_index, -1)
            else:
                api_server_logger.warning(
                    f"request_id {request_id} not in request_map, available keys: {list(self.request_map.keys())}"
                )
        except Exception as e:
            api_server_logger.error(f"Listener error: {str(e)}")
            api_server_logger.error(f"Listener error: {str(e)}\n{traceback.format_exc()}")
            break
        api_server_logger.info(f"Listener loop ended for conn_index {conn_index}")

    def _update_load(self, conn_index, delta):
        """Update connection load and maintain the heap"""

@@ -249,6 +249,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "FD_MOE_PROB_IN_ADVANCE": lambda: bool(int(os.getenv("FD_MOE_PROB_IN_ADVANCE", "0"))),
    # Whether to use batch send data in zmq
    "ZMQ_SEND_BATCH_DATA": lambda: int(os.getenv("ZMQ_SEND_BATCH_DATA", "1")),
    # Whether to enable v1 weight updating, which utilizes ZMQ/EngineWorkerQueue/EngineCacheQueue/FMQs
    # to pass control requests and responses.
    # When v1 is enabled, the legacy /clear_load_weight and /update_model_weight
    # will adopt this new communication pattern.
    "FD_ENABLE_V1_UPDATE_WEIGHTS": lambda: bool(int(os.getenv("FD_ENABLE_V1_UPDATE_WEIGHTS", "0"))),
}
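The `environment_variables` registry above stores zero-argument callables rather than values, so each variable is read from the environment at access time instead of being frozen at import time. A minimal sketch of the same pattern, reduced to the one new entry:

```python
import os
from typing import Any, Callable

# Lazy env-var registry: each entry is a zero-argument callable, so the value
# reflects the environment at the moment of access, not at module import.
environment_variables: dict[str, Callable[[], Any]] = {
    "FD_ENABLE_V1_UPDATE_WEIGHTS": lambda: bool(int(os.getenv("FD_ENABLE_V1_UPDATE_WEIGHTS", "0"))),
}

def get_env(name: str) -> Any:
    return environment_variables[name]()
```

With the variable unset, `get_env("FD_ENABLE_V1_UPDATE_WEIGHTS")` is `False`; exporting `FD_ENABLE_V1_UPDATE_WEIGHTS=1` flips later reads to `True` without re-importing anything.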


@@ -58,6 +58,7 @@ class EngineCacheQueue:
            client_id: Unique identifier for client instances
            local_data_parallel_size: data parallel size
            local_data_parallel_id: local data parallel id

        TODO(liyonghua): Remove multi-DP initialization. Each DP will have its own cache queue.

        """
        self.address: Tuple[str, int] = address
        self.authkey: bytes = authkey
@@ -87,6 +90,7 @@ class EngineCacheQueue:
            ]

            # Initialize barriers
            self.barrier = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
            self.barrier0_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
            self.barrier1_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
            self.barrier2_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
@@ -142,6 +146,7 @@ class EngineCacheQueue:
                callable=lambda idx: self.transfer_task_done_lock_init[idx],
                proxytype=AcquirerProxy,
            )
            QueueManager.register("get_barrier", callable=lambda idx: self.barrier[idx])
            QueueManager.register("get_barrier0", callable=lambda idx: self.barrier0_init[idx])
            QueueManager.register("get_barrier1", callable=lambda idx: self.barrier1_init[idx])
            QueueManager.register("get_barrier2", callable=lambda idx: self.barrier2_init[idx])
@@ -191,6 +196,7 @@ class EngineCacheQueue:
            QueueManager.register("get_cache_sync_value")
            QueueManager.register("get_transfer_task_lock")
            QueueManager.register("get_transfer_task_done_lock")
            QueueManager.register("get_barrier")
            QueueManager.register("get_barrier0")
            QueueManager.register("get_barrier1")
            QueueManager.register("get_barrier2")
@@ -215,6 +221,7 @@ class EngineCacheQueue:
            self.task_done_lock = self.manager.get_transfer_task_done_lock(self.local_data_parallel_id)

            # Get barrier proxies
            self.barrier = self.manager.get_barrier(self.local_data_parallel_id)
            self.barrier0 = self.manager.get_barrier0(self.local_data_parallel_id)
            self.barrier1 = self.manager.get_barrier1(self.local_data_parallel_id)
            self.barrier2 = self.manager.get_barrier2(self.local_data_parallel_id)
@@ -264,7 +271,12 @@ class EngineCacheQueue:

    def put_transfer_task(self, item):
        """
        put swap task
        Enqueue a cache transfer task (cpu/gpu swap task, read/write storage task)
        or a control task (cache clearing/restoring).

        The queue is shared by multiple clients. A task can be enqueued only after
        the previous task has been read by all clients.
        `task_sync_value` is used as a bitmask to track per-client read status.
        """
        self.task_lock.acquire()
        if 0 < self.task_sync_value.get() < self.total_num:
@@ -279,7 +291,11 @@ class EngineCacheQueue:

    def get_transfer_task(self):
        """
        get swap task
        Get the current cache transfer task (cpu/gpu swap task, read/write storage task)
        or control signal (cache clearing/restoring) from cache task queue.

        Each client reads the same task once. The task is removed from the queue
        only after all clients have read it, tracked by `task_sync_value`.
        """
        data = None
        read_finish = False

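The `task_sync_value` bitmask described in the docstrings above implements a broadcast slot: every client must read the current task before the next one may be enqueued. A simplified single-process sketch of that protocol (`BroadcastSlot` is hypothetical; the real queue also takes locks and busy-waits instead of returning `False`):

```python
class BroadcastSlot:
    """One-slot broadcast queue: a task is retired only after all clients read it."""

    def __init__(self, num_clients: int):
        self.num_clients = num_clients
        self.all_read = (1 << num_clients) - 1  # bitmask with one bit per client
        self.sync_value = 0                     # bits of clients that have read
        self.task = None

    def put(self, task) -> bool:
        # Refuse to enqueue while a previous task is only partially read.
        if 0 < self.sync_value < self.all_read:
            return False
        self.sync_value = 0
        self.task = task
        return True

    def get(self, client_id: int):
        task = self.task
        self.sync_value |= 1 << client_id      # mark this client as having read
        if self.sync_value == self.all_read:
            self.task = None                   # all clients done: retire the task
        return task
```

This is why a `(CacheStatus.CTRL, control_request)` pause/resume signal is guaranteed to be observed by every cache-transfer client exactly once before the next task goes in.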
@@ -283,6 +283,9 @@ class ZmqServerBase(ABC):
        has_result_handle = False
        with self.mutex:
            if req_id not in self.req_dict:
                llm_logger.warning(
                    f"req_id '{req_id}' not in req_dict, caching response. Available req_ids: {list(self.req_dict.keys())}"
                )
                self.cached_results[req_id].append(data)
            else:
                has_result_handle = True

@@ -201,6 +201,39 @@ class DynamicWeightManager:
        # step5: recapture cuda_graph
        # step6: update weight status signal

    def restart_communication_group(self):
        if not self.first_load:
            start_time = time.perf_counter()
            paddle.distributed.restart_process_group()
            paddle.distributed.restart_process_group(self.parallel_config.tp_group)
            if self.parallel_config.enable_expert_parallel:
                paddle.distributed.restart_process_group(self.parallel_config.ep_group)
            logger.info(f"finish restarting communication groups! time cost: {time.perf_counter()-start_time:.3f}s")

    def recreate_deepep_buffer(self):
        if not self.first_load:
            start_time = time.perf_counter()
            from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager

            DeepEPBufferManager.recreate_buffer()
            # ep barrier
            paddle.distributed.barrier(self.parallel_config.ep_group)
            logger.info(f"finish recreating deepep buffer! time cost: {time.perf_counter()-start_time:.3f}s")

    def reload_model_weights(self):
        if not self.first_load:
            start_time = time.perf_counter()
            strategy_handlers = {
                "ipc_snapshot": self._update_ipc_snapshot,
                "ipc": self._update_ipc,
            }

            if handler := strategy_handlers.get(self.load_config.load_strategy):
                handler()
            else:
                raise ValueError(f"Unsupported strategy: {self.load_config.load_strategy}")
            logger.info(f"finish reload model weights! time cost: {time.perf_counter()-start_time:.3f}s")

    def _update_ipc_snapshot(self):
        """Update using IPC snapshot strategy for elastic recovery.

@@ -329,6 +362,30 @@ class DynamicWeightManager:
        paddle.distributed.shutdown_process_group()
        self._update_shared_status(pid, ModelWeightsStatus.CLEARED)

    def clear_deepep_buffer(self):
        start_time = time.perf_counter()
        from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager

        DeepEPBufferManager.clear_buffer()
        logger.info(f"finish clearing deepep buffer! time cost: {time.perf_counter()-start_time:.3f}s")

    def clear_model_weight(self):
        start_time = time.perf_counter()
        for model in self.model_list:
            for param in model.state_dict().values():
                param._clear_data()
        logger.info(f"finish clearing model weight! time cost: {time.perf_counter()-start_time:.3f}s")

    def clear_communication_group(self):
        start_time = time.perf_counter()
        if self.parallel_config.enable_expert_parallel:
            paddle.distributed.barrier(self.parallel_config.ep_group)
            paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
        if self.parallel_config.tensor_parallel_size > 1:
            paddle.distributed.barrier(self.parallel_config.tp_group)
            paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)
        logger.info(f"finish clearing communication groups! time cost: {time.perf_counter()-start_time:.3f}s")

    def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
        """Update model parameters from given state dictionary."""
        if len(state_dict) == 0:

@@ -702,12 +702,15 @@ def singleton(cls):
    return get_instance


def print_gpu_memory_use(gpu_id: int, title: str) -> None:
def print_gpu_memory_use(title: str, gpu_id: int, device_id: int | None = None) -> None:
    """Print memory usage"""
    import pynvml

    if device_id is None:
        device_id = gpu_id

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
    pynvml.nvmlShutdown()

@@ -724,7 +727,7 @@ def print_gpu_memory_use(gpu_id: int, title: str) -> None:
        f"\n\tPaddle max memory Reserved(GiB): {paddle_max_reserved / 1024.0 / 1024.0 / 1024.0}",
        f"\n\tPaddle max memory Allocated(GiB): {paddle_max_allocated / 1024.0 / 1024.0 / 1024.0}",
        f"\n\tPaddle memory Reserved(GiB): {paddle_reserved / 1024.0 / 1024.0 / 1024.0}",
        f"\n\tPaddle memory Allocated(GiB): {paddle_allocated / 1024.0 / 1024.0 / 1024.0}",
        f"\n\tPaddle memory Allocated(GiB): {paddle_allocated / 1024.0 / 1024.0 / 1024.0}\n",
    )


@@ -1326,3 +1329,12 @@ else:
    register_op = do_nothing

register_custom_python_op = register_op


def all_gather_values(value: int | float | bool, group: paddle.distributed.communication.group.Group) -> list:
    _type = type(value)
    _local = paddle.to_tensor([value], dtype="float32")
    _global = [paddle.zeros_like(_local) for _ in range(group.world_size)]
    paddle.distributed.all_gather(_global, _local, group)
    _results = [_type(t.item()) for t in _global]
    return _results
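`all_gather_values` round-trips every scalar through a float32 tensor and casts the gathered values back to the type of the caller's local value. A single-process sketch of that cast logic (`gather_as_type` is a hypothetical simulation, not the distributed helper itself):

```python
def gather_as_type(local_value, rank_values):
    """Simulate all_gather: values travel as floats, then cast back to the caller's type."""
    _type = type(local_value)
    return [_type(float(v)) for v in rank_values]
```

Bools and small ints survive the float32 transport because their values are exact in float32; integers above 2**24 would lose precision under this scheme.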
|
||||
@@ -54,6 +54,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative
|
||||
from fastdeploy.model_executor.model_loader import get_model_loader
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.spec_decode import SpecMethod
|
||||
from fastdeploy.utils import print_gpu_memory_use
|
||||
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
|
||||
|
||||
if current_platform.is_iluvatar():
|
||||
@@ -142,6 +143,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.cache_kvs_map: dict = {}
|
||||
self.exist_prefill_flag = False
|
||||
|
||||
self.is_kvcache_sleeping = False
|
||||
self.is_weight_sleeping = False
|
||||
|
||||
if self.speculative_decoding:
|
||||
self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory()
|
||||
self.output_token_num_event = paddle.device.cuda.Event()
|
||||
@@ -288,6 +292,10 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
"""
|
||||
return self.exist_prefill_flag
|
||||
|
||||
@property
|
||||
def is_sleeping(self):
|
||||
return self.is_weight_sleeping or self.is_kvcache_sleeping
|
||||
|
||||
def exist_decode(self):
|
||||
"""
|
||||
check whether decode stage exist
|
||||
@@ -2673,6 +2681,83 @@ class GPUModelRunner(ModelRunnerBase):
    def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
        return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)

+    def sleep(self, tags):
+        logger.info(f">>> start offloading memory, tags: {tags}")
+        start_time = time.perf_counter()
+
+        # Clear weights, deepep_buffer, cudagraph, etc.
+        if "weight" in tags.split(","):
+            if self.is_weight_sleeping:
+                logger.info("GPU model runner's weight is already sleeping, no need to sleep again!")
+                return
+            if self.use_cudagraph:
+                self.model.clear_grpah_opt_backend()
+            if self.fd_config.parallel_config.enable_expert_parallel:
+                self.dynamic_weight_manager.clear_deepep_buffer()
+            self.dynamic_weight_manager.clear_model_weight()
+            if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
+                self.dynamic_weight_manager.clear_communication_group()
+            self.is_weight_sleeping = True
+
+        # Clear KV cache
+        if "kv_cache" in tags.split(","):
+            if self.is_kvcache_sleeping:
+                logger.info("GPU model runner's kv cache is already sleeping, no need to sleep again!")
+                return
+            if self.spec_method == SpecMethod.MTP:
+                self.proposer.clear_mtp_cache()
+            self.clear_cache()
+            self.is_kvcache_sleeping = True
+
+        paddle.device.cuda.empty_cache()
+        logger.info(f"<<< finish offloading memory! time cost: {time.perf_counter()-start_time:.3f}s")
+        print_gpu_memory_use(f"After offloading memory [{tags}]", self.local_rank, self.device_id)
+
+    def wakeup(self, tags):
+        if tags == "weight" and self.use_cudagraph and self.is_kvcache_sleeping:
+            raise RuntimeError(
+                "Waking up [weight] alone is not supported when CUDA Graph is enabled, "
+                "as recapturing the graph requires the KV cache to be rebuilt first. "
+                "Please wake up [kv_cache] first."
+            )
+
+        logger.info(f">>> start reloading memory, tags: {tags}")
+        start_time = time.perf_counter()
+
+        # Reset share_inputs to restore tensor shapes and values
+        if self.spec_method == SpecMethod.MTP:
+            self.proposer.model_inputs.reset_model_inputs()
+        self.share_inputs.reset_share_inputs()
+
+        # Reinitialize KV cache
+        if "kv_cache" in tags.split(","):
+            if not self.is_kvcache_sleeping:
+                logger.info("GPU model runner's kv cache is not sleeping, no need to wakeup!")
+                return
+            if self.spec_method == SpecMethod.MTP:
+                self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
+            self.initialize_kv_cache()
+            self.is_kvcache_sleeping = False
+
+        # Reload weights, deepep_buffer, cudagraph, etc.
+        if "weight" in tags.split(","):
+            if not self.is_weight_sleeping:
+                logger.info("GPU model runner's weight is not sleeping, no need to wakeup!")
+                return
+            if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
+                self.dynamic_weight_manager.restart_communication_group()
+            if self.fd_config.parallel_config.enable_expert_parallel:
+                self.dynamic_weight_manager.recreate_deepep_buffer()
+            self.dynamic_weight_manager.reload_model_weights()
+            if self.use_cudagraph:
+                self.capture_model()
+            self.is_weight_sleeping = False
+
+        logger.info(f"<<< finish reloading memory! time cost: {time.perf_counter()-start_time:.3f}s")
+        print_gpu_memory_use(f"After reloading memory [{tags}]", self.local_rank, self.device_id)
+
    def padding_cudagraph_inputs(self) -> None:
        """
        Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.

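The sleep/wakeup pair above is guarded by two idempotency flags and one ordering rule: under CUDA Graph, `kv_cache` must be woken before `weight` can be, since recapturing the graph needs the KV cache in place. A minimal pure-Python model of that state machine (illustrative only, not the FastDeploy implementation):

```python
# Pure-Python sketch of the sleep/wakeup flag logic above. The `tags`
# string is comma-separated ("weight", "kv_cache", or both), exactly as
# in the GPU model runner.
class SleepState:
    def __init__(self, use_cudagraph=False):
        self.use_cudagraph = use_cudagraph
        self.is_weight_sleeping = False
        self.is_kvcache_sleeping = False

    def sleep(self, tags):
        if "weight" in tags.split(","):
            self.is_weight_sleeping = True
        if "kv_cache" in tags.split(","):
            self.is_kvcache_sleeping = True

    def wakeup(self, tags):
        # mirrors the RuntimeError guard in GPUModelRunner.wakeup
        if tags == "weight" and self.use_cudagraph and self.is_kvcache_sleeping:
            raise RuntimeError("wake up kv_cache first")
        if "kv_cache" in tags.split(","):
            self.is_kvcache_sleeping = False
        if "weight" in tags.split(","):
            self.is_weight_sleeping = False

s = SleepState(use_cudagraph=True)
s.sleep("weight,kv_cache")
try:
    s.wakeup("weight")       # rejected: kv_cache is still asleep
except RuntimeError as e:
    print(e)                 # wake up kv_cache first
s.wakeup("kv_cache")
s.wakeup("weight")           # now allowed
print(s.is_weight_sleeping, s.is_kvcache_sleeping)  # False False
```

Waking both tags at once (`"weight,kv_cache"`) also works, because the real method rebuilds the KV cache before reloading weights.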
@@ -196,6 +196,14 @@ class GpuWorker(WorkerBase):
        """update weights in place"""
        return self.model_runner.update_weights(version, rsync_config)

+    def sleep(self, **kwargs) -> None:
+        """Offload memory from GPU"""
+        return self.model_runner.sleep(**kwargs)
+
+    def wakeup(self, **kwargs) -> None:
+        """Reload memory into GPU"""
+        return self.model_runner.wakeup(**kwargs)
+
    def execute_model(
        self,
        model_forward_batch: Optional[List[Request]] = None,

@@ -69,7 +69,7 @@ from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.model_executor.utils import v1_loader_support
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
-from fastdeploy.utils import get_logger, optional_type
+from fastdeploy.utils import all_gather_values, get_logger, optional_type
from fastdeploy.worker.worker_base import WorkerBase

logger = get_logger("worker_process", "worker_process.log")
@@ -172,6 +172,7 @@ class PaddleDisWorkerProc:

        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
        self.enable_overlap_schedule = self.scheduler_config.enable_overlap_schedule
+        self.cached_control_reqs = []

    def init_control(self):
        engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
@@ -482,7 +483,7 @@ class PaddleDisWorkerProc:
            # run eplb
            self._run_eplb(tp_rank)

-            if self.fd_config.load_config.dynamic_load_weight:
+            if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
                self.model_weights_signal[0] = int(self.model_weights_status.value[0])
                if self.ranks > 1:
                    self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
@@ -504,7 +505,7 @@ class PaddleDisWorkerProc:
            # Synchronize the signal set by tp_rank0 visiable to other workers
            self._tp_barrier_wait() if tp_size > 1 else None

-            if self.fd_config.load_config.dynamic_load_weight:
+            if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
                if self.ranks > 1:
                    paddle.distributed.barrier()
                if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
@@ -581,22 +582,34 @@ class PaddleDisWorkerProc:
            if len(control_reqs) > 0:
                logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
                for control_req in control_reqs:
-                    self.run_control_method(control_req)
-                    self._tp_barrier_wait() if tp_size > 1 else None
+                    if self.parallel_config.use_ep:
+                        self.cached_control_reqs.append(control_req)
+                        logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
+                    else:
+                        self.run_control_method(control_req)
+                        self._tp_barrier_wait() if tp_size > 1 else None

-            # Count prefill requests in current batch
-            num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
-            num_scheduled_requests = len(req_dicts)
-            scheduled_request_ids = [req.request_id for req in req_dicts]
-            logger.info(
-                f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
-                f"max_occupied_batch_index: {max_occupied_batch_index}, "
-                f"num_scheduled_requests: {num_scheduled_requests}, "
-                f"scheduled_request_ids: {scheduled_request_ids}"
-            )
+            if len(req_dicts) > 0:
+                # Count prefill requests in current batch
+                num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
+                num_scheduled_requests = len(req_dicts)
+                scheduled_request_ids = [req.request_id for req in req_dicts]
+                logger.info(
+                    f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
+                    f"max_occupied_batch_index: {max_occupied_batch_index}, "
+                    f"num_scheduled_requests: {num_scheduled_requests}, "
+                    f"scheduled_request_ids: {scheduled_request_ids}"
+                )

-            # Process prefill inputs
-            self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
+                # Process prefill inputs
+                self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)

+            # Let the ep group run control method synchronically
+            if self.parallel_config.use_ep:
+                pendings = all_gather_values(len(self.cached_control_reqs), self.parallel_config.ep_group)
+                if all([p > 0 for p in pendings]):
+                    logger.info(f"Rank: {self.local_rank} Detected all ep ranks have pending control tasks.")
+                    self.run_control_method(self.cached_control_reqs.pop(0))

            if (
                not self.parallel_config.use_ep
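Under expert parallelism, control requests are cached per rank and only executed once every EP rank reports at least one pending request, so all ranks enter the (collective) control method together. The gating predicate can be restated in isolation (the `all_gather` over ranks is simulated here with a plain list of per-rank queue depths):

```python
# Sketch of the EP control gating above: run one cached control request
# only when EVERY expert-parallel rank has a non-empty cache, otherwise
# keep stepping the normal execution loop and retry next iteration.
def should_run_cached_control(pending_per_rank):
    return all(p > 0 for p in pending_per_rank)

print(should_run_cached_control([1, 1, 1]))  # True  -> pop and run one request
print(should_run_cached_control([1, 0, 1]))  # False -> keep waiting
```

Popping exactly one request per all-ranks-ready iteration keeps the ranks in lockstep even when their queues hold different numbers of requests.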
@@ -607,6 +620,15 @@ class PaddleDisWorkerProc:
                time.sleep(0.001)
                continue

+            # Check if worker is paused (V1 update weights flow)
+            if (
+                self.fd_config.load_config.dynamic_load_weight
+                and hasattr(self.worker.model_runner, "is_sleeping")
+                and self.worker.model_runner.is_sleeping
+            ):
+                self._tp_barrier_wait() if tp_size > 1 else None
+                continue
+
            # Execute model to generate token. The generated token will be written to the buffer.
            # These generated tokens can be obtained through get_output op.
            start_execute_time = time.time()
@@ -737,14 +759,14 @@ class PaddleDisWorkerProc:
        self.loaded_model_signal.value[0] = 1

    def run_control_method(self, control_request: ControlRequest) -> None:
-        logger.info(f"Start run control request: {control_request}")
+        logger.info(f"Rank: {self.local_rank} Start to run control request: {control_request}")
        request_id = control_request.request_id
        method = control_request.method
        kwargs = control_request.args

        handler = getattr(self.worker, method, None)
        if handler is None or not callable(handler):
-            error_msg = f"Rank-{self.local_rank}: Unknown control method {method}"
+            error_msg = f"Rank: {self.local_rank} Unknown control method {method}"
            error_result = ControlResponse(request_id, 400, error_msg)
            asyncio.run(self._ctrl_output.put(error_result))
            return
@@ -753,11 +775,11 @@ class PaddleDisWorkerProc:
            result = handler(**kwargs)
            succ_result = ControlResponse(request_id, 200, "Success", result)
            logger.info(
-                f"Rank-{self.local_rank} Success run control request: {control_request}, response: {succ_result}"
+                f"Rank: {self.local_rank} Successfully run control request: {control_request}, response: {succ_result}"
            )
            asyncio.run(self._ctrl_output.put(succ_result, shm_threshold=100 * 1024 * 1024))
        except Exception as e:
-            error_msg = f"Rank-{self.local_rank} Failed run control method {method}: {str(e)}"
+            error_msg = f"Rank: {self.local_rank} Failed to run control method {method}: {str(e)}"
            logger.error(f"{error_msg}\n{traceback.format_exc()}")
            error_result = ControlResponse(request_id, 500, error_msg)
            asyncio.run(self._ctrl_output.put(error_result))

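`run_control_method` dispatches by method name via `getattr` on the worker and maps the three outcomes to HTTP-like codes: unknown method → 400, success → 200, handler exception → 500. A self-contained restatement of that dispatch (the `Worker` class and its methods here are hypothetical stand-ins):

```python
# Sketch of the getattr-based control dispatch above, with the same
# 400/200/500 status mapping. Worker is a toy stand-in, not the real
# GpuWorker.
class Worker:
    def sleep(self, tags):
        return f"offloaded: {tags}"

    def wakeup(self, tags):
        raise RuntimeError("reload failed")

def run_control(worker, method, kwargs):
    handler = getattr(worker, method, None)
    if handler is None or not callable(handler):
        return (400, f"Unknown control method {method}")
    try:
        return (200, handler(**kwargs))
    except Exception as e:
        return (500, str(e))

print(run_control(Worker(), "sleep", {"tags": "weight"}))   # (200, 'offloaded: weight')
print(run_control(Worker(), "no_such_method", {}))          # (400, ...)
print(run_control(Worker(), "wakeup", {"tags": "weight"}))  # (500, 'reload failed')
```

Because dispatch is name-based, adding a new control verb only requires adding a method on the worker; the error paths already cover typos and handler failures.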
@@ -106,7 +106,11 @@ class TestCacheTransferManager(unittest.TestCase):
        # --------------------------
        # mock IPCSignal
        # --------------------------
-        patcher2 = patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=MagicMock())
+        class DummyIPCSignal:
+            def __init__(self, name, array, dtype, suffix, create=False):
+                self.value = array
+
+        patcher2 = patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=DummyIPCSignal)
        patcher2.start()
        self.addCleanup(patcher2.stop)

@@ -122,8 +126,8 @@ class TestCacheTransferManager(unittest.TestCase):
        # --------------------------
        self._orig_init_cpu_cache = CacheTransferManager._init_cpu_cache
        self._orig_init_gpu_cache = CacheTransferManager._init_gpu_cache
-        patcher3 = patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None)
-        patcher4 = patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None)
+        patcher3 = patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None)
+        patcher4 = patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None)
        patcher3.start()
        patcher4.start()
        self.addCleanup(patcher3.stop)
@@ -193,8 +197,8 @@ class TestCacheTransferManager(unittest.TestCase):
        kvcache_storage_backend = "unknown"

        with (
-            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
-            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
+            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
+            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
            patch("fastdeploy.cache_manager.cache_transfer_manager.console_logger") as mock_console,
        ):
            with self.assertRaises(NotImplementedError):
@@ -209,8 +213,8 @@ class TestCacheTransferManager(unittest.TestCase):
        kvcache_storage_backend = "file"

        with (
-            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
-            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
+            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
+            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
        ):
            with self.assertRaises(ValueError):
                CacheTransferManager(LocalArgs())
@@ -221,8 +225,8 @@ class TestCacheTransferManager(unittest.TestCase):
        version_path = os.path.join(tmpdir, "version.yaml")
        with open(version_path, "w", encoding="utf-8") as handle:
            handle.write("version: RL-STEP03-20250101-uuid\n")
-        args.model_path = tmpdir
-        args.kvcache_storage_backend = None
+        self.manager.model_path = tmpdir
+        self.manager.kvcache_storage_backend = None
        self.manager._init_storage(args)

        self.assertEqual(self.manager.key_prefix, "RL-STEP03")
@@ -465,23 +469,21 @@ class TestCacheTransferManager(unittest.TestCase):
            def __init__(self):
                self.value = [0]

-        args = Args()
-        args.num_cpu_blocks = 0
+        self.manager.num_cpu_blocks = 0
        self.manager.swap_space_ready_signal = DummySignal()

-        self._orig_init_cpu_cache(self.manager, args)
+        self._orig_init_cpu_cache(self.manager)

-        self.assertEqual(self.manager.swap_space_ready_signal.value[0], 1)
+        self.assertEqual(self.manager.swap_space_ready_signal.value[0], 0)

    def test_init_cpu_cache_allocates_block_wise_fp8(self):
        class DummySignal:
            def __init__(self):
                self.value = [0]

-        args = Args()
-        args.num_cpu_blocks = 2
-        args.cache_dtype = "block_wise_fp8"
-        args.value_cache_shape = "2,1,1,1"
+        self.manager.num_cpu_blocks = 2
+        self.manager.cache_dtype = "block_wise_fp8"
+        self.manager.has_cache_scale = True
        self.manager.swap_space_ready_signal = DummySignal()
+        self.manager.value_cache_shape = [2, 1, 1, 1]

@@ -492,7 +494,7 @@ class TestCacheTransferManager(unittest.TestCase):
            ),
            patch("fastdeploy.cache_manager.cache_transfer_manager.paddle.set_device"),
        ):
-            self._orig_init_cpu_cache(self.manager, args)
+            self._orig_init_cpu_cache(self.manager)

        self.assertEqual(self.manager.swap_space_ready_signal.value[0], 1)

@@ -510,8 +512,8 @@ class TestCacheTransferManager(unittest.TestCase):
        value_cache_shape = "2,1,1,1"

        with (
-            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
-            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
+            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
+            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
            patch("fastdeploy.cache_manager.cache_transfer_manager.MooncakeStore"),
            patch.object(CacheTransferManager, "_init_storage_buffer"),
        ):
@@ -527,9 +529,8 @@ class TestCacheTransferManager(unittest.TestCase):
        with (
            patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"),
            patch("fastdeploy.cache_manager.cache_transfer_manager.set_data_ipc") as mock_set_ipc,
            patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0),
        ):
-            self._orig_init_gpu_cache(manager, LocalArgs())
+            self._orig_init_gpu_cache(manager)

        self.assertEqual(mock_set_ipc.call_count, 4)
        self.assertIn("key_caches_0_rank0.device0", manager.gpu_cache_kvs)
@@ -548,8 +549,8 @@ class TestCacheTransferManager(unittest.TestCase):
        value_cache_shape = "2,1,1,1"

        with (
-            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
-            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
+            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
+            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
            patch("fastdeploy.cache_manager.cache_transfer_manager.MooncakeStore"),
            patch.object(CacheTransferManager, "_init_storage_buffer"),
        ):
@@ -568,9 +569,8 @@ class TestCacheTransferManager(unittest.TestCase):
        with (
            patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"),
            patch("fastdeploy.cache_manager.cache_transfer_manager.share_external_data_", side_effect=fake_share),
            patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0),
        ):
-            self._orig_init_gpu_cache(manager, LocalArgs())
+            self._orig_init_gpu_cache(manager)

        self.assertIn("key_cache_scales_0_rank0.device0", manager.gpu_cache_kvs)

@@ -583,8 +583,8 @@ class TestCacheTransferManager(unittest.TestCase):
        value_cache_shape = "1,1,1,1"

        with (
-            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None),
-            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None),
+            patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None),
+            patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None),
        ):
            manager = CacheTransferManager(LocalArgs())

@@ -607,9 +607,8 @@ class TestCacheTransferManager(unittest.TestCase):
            patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep),
            patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"),
            patch("fastdeploy.cache_manager.cache_transfer_manager.share_external_data_", side_effect=fake_share),
            patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0),
        ):
-            self._orig_init_gpu_cache(manager, LocalArgs())
+            self._orig_init_gpu_cache(manager)

        self.assertIn("key_caches_0_rank0.device0", manager.gpu_cache_kvs)

@@ -1160,6 +1159,14 @@ class TestCacheTransferManager(unittest.TestCase):
        self.manager.inflight = 1
        self.manager.cache_task_is_paused_signal = DummySignal(0)
        self.manager.cache_task_inflight_signal = DummySignal(1)
        self.manager.cache_task_broadcast_signal = DummySignal(0)
        self.manager.cache_task_queue.empty.return_value = False
        self.manager.cache_task_queue.get_transfer_task.return_value = (
            (cache_transfer_manager.CacheStatus.CTRL, MagicMock()),
            True,
        )
        self.manager.control_task_thread_pool = MagicMock()
        self.manager.control_task_thread_pool.submit.side_effect = SystemExit

        call_count = {"count": 0}

@@ -1168,7 +1175,6 @@ class TestCacheTransferManager(unittest.TestCase):
            if call_count["count"] == 1:
                self.manager.inflight = 0
                return None
            raise SystemExit

        with patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep):
            with self.assertRaises(SystemExit):
@@ -1273,6 +1279,7 @@ class TestCacheTransferManager(unittest.TestCase):
        args = Args()
        args.splitwise_role = "mixed"
-        args.create_cache_tensor = True
+        self.manager.create_cache_tensor = True
        self.manager.kv_cache_status_signal = DummySignal(cache_transfer_manager.KVCacheStatus.CLEARING)
        self.manager.cache_ready_signal = DummySignal(0)
        self.manager.swap_space_ready_signal = DummySignal(0)
@@ -1293,7 +1300,10 @@ class TestCacheTransferManager(unittest.TestCase):
            patch("paddle.device.cuda.empty_cache") as mock_empty,
            patch("paddle.set_device"),
            patch.object(self.manager, "_log_memory"),
-            patch("time.sleep", side_effect=maybe_stop_cleared_with_tensor),
+            patch(
+                "fastdeploy.cache_manager.cache_transfer_manager.time.sleep",
+                side_effect=maybe_stop_cleared_with_tensor,
+            ),
        ):
            with self.assertRaises(StopIteration):
                self.manager.check_cache_status(args)
@@ -1427,6 +1437,7 @@ class TestCacheTransferManager(unittest.TestCase):
        args = Args()
        args.splitwise_role = "mixed"
-        args.mp_num = 2
+        self.manager.n_ranks = 2
        self.manager.kv_cache_status_signal = DummySignal([cache_transfer_manager.KVCacheStatus.UPDATING])
        self.manager.cache_ready_signal = DummySignal([0, 1])
        self.manager.swap_space_ready_signal = DummySignal([0, 1])
@@ -1455,6 +1466,7 @@ class TestCacheTransferManager(unittest.TestCase):
            patch.object(self.manager, "_init_cpu_cache"),
            patch.object(self.manager, "_init_gpu_cache"),
            patch.object(self.manager, "resume"),
+            patch.object(self.manager, "_update_key_prefix"),
            patch("fastdeploy.cache_manager.cache_transfer_manager.envs.FD_ENABLE_SWAP_SPACE_CLEARING", True),
            patch.object(self.manager, "_log_memory"),
            patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep),
@@ -1465,7 +1477,7 @@ class TestCacheTransferManager(unittest.TestCase):
        self.assertEqual(self.manager.kv_cache_status_signal.value[0], cache_transfer_manager.KVCacheStatus.NORMAL)

    def test_log_memory_records_gpu_stats(self):
-        with patch.object(cache_transfer_manager.logger, "warning") as mock_warning:
+        with patch.object(cache_transfer_manager.logger, "info") as mock_info:
            with (
                patch("paddle.device.cuda.max_memory_allocated", return_value=1024**3),
                patch("paddle.device.cuda.max_memory_reserved", return_value=2 * 1024**3),
@@ -1474,7 +1486,7 @@ class TestCacheTransferManager(unittest.TestCase):
            ):
                self.manager._log_memory("test")

-            mock_warning.assert_called_once()
+            mock_info.assert_called_once()

    def test_pause_and_resume_wait_for_signals(self):
        class DummySignal:
@@ -1503,7 +1515,7 @@ class TestCacheTransferManager(unittest.TestCase):

        self.assertFalse(self.manager.is_paused)

-    def test_submit_task_decrements_inflight(self):
+    def test_submit_task_decrements_inflight_on_task_error(self):
        class DummyPool:
            def submit(self, fn, *args):
                try:
@@ -1514,10 +1526,26 @@ class TestCacheTransferManager(unittest.TestCase):
        def raise_task():
            raise RuntimeError("boom")

-        self.manager.inflight = 1
+        self.manager.inflight = 0
        self.manager.submit_task(DummyPool(), raise_task)
        self.assertEqual(self.manager.inflight, 0)

+    def test_submit_task_decrements_inflight_on_success(self):
+        class DummyPool:
+            def submit(self, fn, *args):
+                return fn(*args)
+
+        task_called = {"value": False}
+
+        def ok_task():
+            task_called["value"] = True
+
+        self.manager.inflight = 0
+        self.manager.submit_task(DummyPool(), ok_task)
+
+        self.assertTrue(task_called["value"])
+        self.assertEqual(self.manager.inflight, 0)
+
    def test_main_invokes_manager(self):
        cache_transfer_manager.args = Args()
        with patch("fastdeploy.cache_manager.cache_transfer_manager.CacheTransferManager") as mock_manager:

@@ -427,7 +427,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
        self.assertEqual(manager.get_required_block_num(8, 4), 2)

    def test_launch_cache_manager_initializes_processes(self):
-        manager = _create_manager()
+        manager = _create_manager(num_cpu_blocks=1)
        manager.cache_config.enable_hierarchical_cache = False

        with (
@@ -637,7 +637,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
        self.assertIsNone(processes)

    def test_launch_cache_manager_formats_value_cache_shape(self):
-        manager = _create_manager()
+        manager = _create_manager(num_cpu_blocks=1)

        captured = {}

@@ -11,7 +11,9 @@ HOST="0.0.0.0"
PORT="${FD_API_PORT}"  # must match the URL port used by the launch script
BASE_URL="http://$HOST:$PORT"

-TOTAL_ROUNDS=6
+V0_ROUNDS=3
+V1_ROUNDS=3
+TOTAL_ROUNDS=$((V0_ROUNDS + V1_ROUNDS))
CHAT_REQUESTS_PER_ROUND=3
export CUDA_VISIBLE_DEVICES=0,1
MAX_MEMORY_MB=10240  # 10GB
@@ -48,7 +50,7 @@ assert_success() {
    fi
}

-# curl_get_status(url, options...) → returns via global variables http_code and response_body
+# curl_get_status(url, options) → returns via global variables http_code and response_body
curl_get_status() {
    local result
    result=$(curl -s -w "%{http_code}" "$@")
@@ -56,6 +58,23 @@ curl_get_status() {
    response_body="${result%???}"
}

+post_json_and_assert() {
+    local url="$1"
+    local payload="${2:-}"
+    if [ -n "$payload" ]; then
+        curl_get_status -X POST "$url" -H "Content-Type: application/json" -d "$payload"
+    else
+        curl_get_status -X POST "$url"
+    fi
+    assert_eq "$http_code" "200" "$url failed with HTTP $http_code, body: $response_body"
+}
+
+get_and_assert() {
+    local url="$1"
+    curl_get_status "$url"
+    assert_eq "$http_code" "200" "$url failed with HTTP $http_code, body: $response_body"
+}
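`curl_get_status` relies on `curl -w "%{http_code}"` appending the three-digit HTTP status to the captured output; `${result%???}` then strips those last three characters off to recover the body. The same split, restated in Python for clarity (illustrative only):

```python
# Python restatement of the curl_get_status trick above: the HTTP status
# code is the final three characters of `result`, the body is the rest.
def split_status(result):
    return result[:-3], result[-3:]

body, code = split_status('{"object":"chat.completion"}200')
print(code)  # 200
print(body)  # {"object":"chat.completion"}
```

This works because HTTP status codes are always exactly three digits, so no delimiter is needed between body and code.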

# ====================================================
# Get visible GPU IDs from CUDA_VISIBLE_DEVICES
# ====================================================
@@ -79,104 +98,69 @@ check_gpu_memory() {
    local gpu_ids
    gpu_ids=($(get_visible_gpu_ids))

-    echo "========== GPU Memory Check =========="
-    echo "CUDA_VISIBLE_DEVICES = $CUDA_VISIBLE_DEVICES"
-    echo "MAX_MEMORY_MB = $MAX_MEMORY_MB"
-    echo "======================================"
+    echo "----------------------------------------"
+    echo " GPU Memory Check (MAX: ${MAX_MEMORY_MB}MB)"
+    echo "----------------------------------------"

    if [ ${#gpu_ids[@]} -eq 0 ]; then
-        echo "Assertion failed: No valid GPU IDs in CUDA_VISIBLE_DEVICES='$CUDA_VISIBLE_DEVICES'" >&2
+        echo "ERROR: No valid GPU IDs in CUDA_VISIBLE_DEVICES='$CUDA_VISIBLE_DEVICES'" >&2
        exit 1
    fi

    for gpu_id in "${gpu_ids[@]}"; do
-        echo
-        echo "---- GPU $gpu_id ----"

        # Query summary
        local summary
        summary=$(nvidia-smi -i "$gpu_id" \
            --query-gpu=index,name,memory.total,memory.used,memory.free,utilization.gpu \
            --format=csv,noheader,nounits) || {
-            echo "Failed to query GPU $gpu_id summary" >&2
+            echo "ERROR: Failed to query GPU $gpu_id" >&2
            exit 1
        }

        # Parse fields
        IFS=',' read -r idx name mem_total mem_used mem_free util <<< "$summary"
+        local used_ratio=$(( mem_used * 100 / mem_total ))

-        echo "GPU $idx: $name"
-        echo "Total Memory : ${mem_total} MB"
-        echo "Used Memory : ${mem_used} MB"
-        echo "Free Memory : ${mem_free} MB"
-        echo "GPU Util : ${util} %"
+        # Print GPU info (single line)
+        printf "  GPU %s: %-35s Total:%5sMB Used:%5sMB (%s%%)\n" \
+            "$idx" "$name" "$mem_total" "$mem_used" "$util"

-        # --- Hard assertions ---
+        # Hard assertion
        assert_true "$(( mem_used <= MAX_MEMORY_MB ))" \
            "GPU $gpu_id memory.used ${mem_used} MB > MAX_MEMORY_MB ${MAX_MEMORY_MB} MB"

-        # --- Soft safety check: usage ratio ---
-        local used_ratio
-        used_ratio=$(( mem_used * 100 / mem_total ))
-
-        echo "Used Ratio : ${used_ratio} %"
-
+        # Usage ratio check
        if [ "$used_ratio" -gt 90 ]; then
-            echo "Assertion failed: GPU $gpu_id memory usage > 90% (${used_ratio}%)" >&2
+            echo "ERROR: GPU $gpu_id memory usage > 90% (${used_ratio}%)" >&2
            exit 1
        fi

-        # --- Process-level attribution ---
-        echo "Processes on GPU $gpu_id:"
+        # Process info (compact format)
        local proc_info
        proc_info=$(nvidia-smi -i "$gpu_id" \
            --query-compute-apps=pid,process_name,used_memory \
            --format=csv,noheader,nounits)

-        if [ -z "$proc_info" ]; then
-            echo "  (No active compute processes)"
-        else
+        if [ -n "$proc_info" ]; then
            echo "$proc_info" | while IFS=',' read -r pid pname pmem; do
-                echo "  PID=$pid NAME=$pname MEM=${pmem}MB"
+                printf "    └─ PID=%-8s %-30s MEM=%4sMB\n" \
+                    "$pid" "$pname" "$pmem"
            done
        fi

-        echo "GPU $gpu_id memory check PASSED"
    done

-    echo "========== GPU Memory Check DONE =========="
+    echo "----------------------------------------"
}
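`check_gpu_memory` enforces two guards per GPU: a hard cap on used memory (the script's `MAX_MEMORY_MB=10240`) and a soft 90% usage-ratio limit. The same logic, restated in Python with illustrative values:

```python
# Pure-Python restatement of the two GPU-memory guards above: a hard cap
# in MB and a 90% usage-ratio check, using the script's 10240 MB default.
def check_memory(mem_used, mem_total, max_memory_mb=10240):
    if mem_used > max_memory_mb:
        return "FAIL: over hard cap"
    if mem_used * 100 // mem_total > 90:
        return "FAIL: over 90% usage"
    return "PASS"

print(check_memory(512, 81920))    # PASS
print(check_memory(20000, 81920))  # FAIL: over hard cap
print(check_memory(9500, 10000))   # FAIL: over 90% usage
```

The low hard cap is the point of the test: after `sleep`/`clear_load_weight`, a rollout worker on an 80 GB card should be holding almost nothing.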
|
||||
# ====================================================
|
||||
|
||||
for round in $(seq 1 $TOTAL_ROUNDS); do
|
||||
echo "=== Round $round / $TOTAL_ROUNDS ==="
|
||||
|
||||
# Step 1: Clear loaded weights
|
||||
echo "[Step 1] Clearing load weight..."
|
||||
curl_get_status -i "$BASE_URL/clear_load_weight"
|
||||
assert_eq "$http_code" "200" "/clear_load_weight failed with HTTP $http_code"
|
||||
sleep 10
|
||||
|
||||
# Step 2: Check GPU memory usage
|
||||
echo "[Step 2] Checking GPU memory..."
|
||||
check_gpu_memory
|
||||
|
||||
# Step 3: Update model weights
|
||||
echo "[Step 3] Updating model weight..."
|
||||
curl_get_status -i "$BASE_URL/update_model_weight"
|
||||
assert_eq "$http_code" "200" "/update_model_weight failed with HTTP $http_code"
|
||||
|
||||
# Step 4: Send chat completion requests
|
||||
echo "[Step 4] Sending $CHAT_REQUESTS_PER_ROUND chat completions..."
|
||||
send_chat_requests() {
|
||||
for i in $(seq 1 $CHAT_REQUESTS_PER_ROUND); do
|
||||
echo " Request $i / $CHAT_REQUESTS_PER_ROUND"
|
||||
# Send request and capture response
|
||||
printf " └─ Sending chat request %d/%d...\n" "$i" "$CHAT_REQUESTS_PER_ROUND"
|
||||
response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"messages": [{"role": "user", "content": "Hello!"}]}')
|
||||
|
||||
# Extract the 'content' field from the response
|
||||
content=$(echo "$response" | \
|
||||
grep -o '"content":"[^"]*"' | \
|
||||
head -1 | \
|
||||
@@ -184,26 +168,93 @@ for round in $(seq 1 $TOTAL_ROUNDS); do
            sed 's/"$//')

        if [ -z "$content" ]; then
            # Fallback: try extracting content using sed more robustly
            content=$(echo "$response" | \
                sed -n 's/.*"content":"\([^"]*\)".*/\1/p' | \
                head -1)
        fi

        # Check if content is empty or null
        if [ -z "$content" ] || [ "$content" = "null" ]; then
            echo "  ERROR: Empty or null 'content' in response" >&2
            echo "  Raw response:" >&2
            echo "$response" >&2
            exit 1
        fi

        printf "  └─ Response received: %s\n" "$content"
    done
}

run_v0_round() {
    local round="$1"
    echo "========================================"
    printf "Round %d/%d (V0)\n" "$round" "$TOTAL_ROUNDS"
    echo "========================================"
    echo ""

    printf "[Step 1] %-30s " "Clearing load weight via v0..."
    get_and_assert "$BASE_URL/clear_load_weight"
    echo "[OK]"
    sleep 10

    printf "[Step 2] %-30s " "Checking GPU memory..."
    echo ""
    check_gpu_memory
    echo ""

    printf "[Step 3] %-30s " "Updating model weight via v0..."
    get_and_assert "$BASE_URL/update_model_weight"
    echo "[OK]"

    echo "[Step 4] Sending $CHAT_REQUESTS_PER_ROUND chat completions"
    send_chat_requests
    echo ""
}

run_v1_round() {
    local round="$1"
    local sleep_payload='{"tags":"weight,kv_cache"}'
    local wakeup_payload='{"tags":"weight,kv_cache"}'

    echo "========================================"
    printf "Round %d/%d (V1)\n" "$round" "$TOTAL_ROUNDS"
    echo "========================================"
    echo ""

    printf "[Step 1] %-30s " "Pausing engine via v1..."
    post_json_and_assert "$BASE_URL/v1/pause" ""
    echo "[OK]"

    printf "[Step 2] %-30s " "Sleeping via v1..."
    post_json_and_assert "$BASE_URL/v1/sleep" "$sleep_payload"
    echo "[OK]"
    sleep 10

    printf "[Step 3] %-30s " "Checking GPU memory..."
    echo ""
    check_gpu_memory
    echo ""

    printf "[Step 4] %-30s " "Waking up via v1..."
    post_json_and_assert "$BASE_URL/v1/wakeup" "$wakeup_payload"
    echo "[OK]"

    printf "[Step 5] %-30s " "Resuming engine via v1..."
    post_json_and_assert "$BASE_URL/v1/resume" ""
    echo "[OK]"

    echo "[Step 6] Sending $CHAT_REQUESTS_PER_ROUND chat completions"
    send_chat_requests
    echo ""
}

for round in $(seq 1 $V0_ROUNDS); do
    run_v0_round "$round"
done

for round in $(seq 1 $V1_ROUNDS); do
    run_v1_round "$round"
done

echo "========================================"
printf "All %d rounds completed successfully.\n" "$TOTAL_ROUNDS"
echo "========================================"
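The v0 endpoints bundle pause/release/reload into single GET calls, while the v1 flow splits the lifecycle into explicit pause → sleep → wakeup → resume steps with a `tags` payload. A minimal sketch of the two request plans the script drives (the endpoint paths come from the script above; the helper names are illustrative, not a FastDeploy client API):

```python
def v0_round_plan():
    """v0 flow: one GET per action, no payload."""
    return [
        ("GET", "/clear_load_weight", None),
        ("GET", "/update_model_weight", None),
    ]


def v1_round_plan(tags="weight,kv_cache"):
    """v1 flow: explicit lifecycle steps; sleep/wakeup carry a tags payload."""
    payload = {"tags": tags}
    return [
        ("POST", "/v1/pause", None),
        ("POST", "/v1/sleep", payload),
        ("POST", "/v1/wakeup", payload),
        ("POST", "/v1/resume", None),
    ]
```

Each round in the script walks one of these plans in order, checking GPU memory after the release step and sending chat completions after the reload step.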
@@ -11,6 +11,7 @@ FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633))

# List of ports to clean before and after tests
PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT]
@@ -20,7 +20,7 @@ import threading
import time
import types
import unittest
from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch

import numpy as np
import paddle
@@ -1091,6 +1091,19 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
        self.assertEqual(result, {"ok": True})
        self._detach_finalizer(eng)

    def test_control_update_weights_updates_cfg_version(self):
        eng = self._make_mixed_engine()
        eng.is_paused = True
        eng._pause_cond = threading.Condition()
        eng.cfg.model_config.version = "old-version"
        eng._call_worker = Mock(return_value=[{"version": "new-version"}, {"ok": True}])

        result = eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights"))

        self.assertEqual(result, [{"version": "new-version"}, {"ok": True}])
        self.assertEqual(eng.cfg.model_config.version, "new-version")
        self._detach_finalizer(eng)

    def test_control_pause_and_resume_paths(self):
        eng = self._make_mixed_engine()
        eng.is_paused = False
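The test above pins down one contract of the new `update_weights` control path: after a successful call, the version reported by the workers is propagated back into `cfg.model_config.version`. A hypothetical sketch of that propagation step (the helper name and shapes are illustrative, not FastDeploy's actual implementation):

```python
def apply_worker_version(model_config, worker_results):
    """Record the first version reported by any worker on the engine config.

    `model_config` is any object with a mutable `version` attribute;
    `worker_results` is the per-rank result list the control call returned.
    """
    for result in worker_results:
        if isinstance(result, dict) and "version" in result:
            # First worker that reports a version wins; ranks load the
            # same checkpoint, so the values should agree anyway.
            model_config.version = result["version"]
            break
    return model_config
```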
@@ -1155,12 +1168,50 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
            async def get(self, timeout=None):
                return Mock(payload=ControlResponse(request_id="req", result={"ok": True}, error_code=200))

        eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": DummyQueue()}
        result = eng._call_worker(ControlRequest(request_id="req", method="noop"), timeout=1)
        self.assertEqual(result, [{"ok": True}])
        eng.engine_worker_queue.put_tasks.assert_called_once()
        self._detach_finalizer(eng)

    def test_control_sleep_defaults_tags_and_dispatches_cache_transfer(self):
        cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=4)
        eng = self._make_engine(cfg)
        eng.cfg.cache_config.num_cpu_blocks = 1
        eng.engine_worker_queue = Mock()
        eng.cache_task_queue = Mock()
        eng.resource_manager.cache_manager.reset = Mock()
        eng._control_pause = Mock()
        eng._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}])

        result = eng._control_sleep(ControlRequest(request_id="sleep", method="sleep", args={}))

        self.assertEqual(result, [{"ok": True}])
        eng._control_pause.assert_called_once_with(None)
        eng.resource_manager.cache_manager.reset.assert_called_once()
        eng.engine_worker_queue.put_tasks.assert_called_once()
        eng.cache_task_queue.put_transfer_task.assert_called_once()
        sleep_req = eng.engine_worker_queue.put_tasks.call_args.args[0][0][0]
        self.assertEqual(sleep_req.args["tags"], "weight,kv_cache")
        self._detach_finalizer(eng)

    def test_control_wakeup_resumes_after_wait(self):
        cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=4)
        eng = self._make_engine(cfg)
        eng.cfg.cache_config.num_cpu_blocks = 1
        eng.engine_worker_queue = Mock()
        eng.cache_task_queue = Mock()
        eng._control_resume = Mock()
        eng._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}])

        result = eng._control_wakeup(ControlRequest(request_id="wakeup", method="wakeup", args={"tags": "kv_cache"}))

        self.assertEqual(result, [{"ok": True}])
        eng.engine_worker_queue.put_tasks.assert_called_once()
        eng.cache_task_queue.put_transfer_task.assert_called_once()
        eng._control_resume.assert_called_once_with(None)
        self._detach_finalizer(eng)

    def test_control_update_weights_requires_pause(self):
        eng = self._make_mixed_engine()
        eng.is_paused = False
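`test_control_sleep_defaults_tags_and_dispatches_cache_transfer` encodes the defaulting rule for the sleep/wakeup payload: when `args` carries no `tags`, the request falls back to `"weight,kv_cache"`, i.e. both the model weights and the KV cache are released. A sketch of that rule (the helper name is illustrative):

```python
DEFAULT_SLEEP_TAGS = "weight,kv_cache"


def resolve_sleep_tags(args):
    """Return the tags string for a sleep/wakeup control request.

    A missing, None, or empty "tags" entry falls back to releasing both
    the weights and the KV cache, matching the default the test asserts.
    """
    tags = (args or {}).get("tags")
    return tags if tags else DEFAULT_SLEEP_TAGS
```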
@@ -1940,68 +1991,121 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
        mock_logger.warning.assert_called()
        self._detach_finalizer(eng)

    def test_wait_for_control_responses_success(self):
        eng = self._make_mixed_engine()

        eng._ctrl_output_queues = {
            "ctrl_w2e_rank0_6778": self._make_ctrl_queue(
                "q0", Mock(request_id="req", error_code=200, result={"ok": True})
            ),
            "ctrl_w2e_rank1_6778": self._make_ctrl_queue(
                "q1", Mock(request_id="req", error_code=200, result={"ok": True})
            ),
        }

        results = asyncio.run(eng._wait_for_control_responses("req", timeout=1))
        self.assertEqual(results, [{"ok": True}, {"ok": True}])
        self._detach_finalizer(eng)

    def test_wait_for_control_responses_filters_executors(self):
        eng = self._make_mixed_engine()

        eng._ctrl_output_queues = {
            "ctrl_w2e_rank0_6778": self._make_ctrl_queue(
                "worker", Mock(request_id="req", error_code=200, result={"worker": True})
            ),
            "ctrl_c2e_rank0_6779": self._make_ctrl_queue(
                "cache", Mock(request_id="req", error_code=200, result={"cache": True})
            ),
        }

        worker_results = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["worker"]))
        cache_results = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["cache_transfer"]))

        self.assertEqual(worker_results, [{"worker": True}])
        self.assertEqual(cache_results, [{"cache": True}])
        self._detach_finalizer(eng)

    def test_wait_for_control_responses_ignores_mismatch(self):
        eng = self._make_mixed_engine()

        class DummyQueue:
            def __init__(self, name, payloads):
                self.name = name
                self.payloads = list(payloads)

            async def get(self, timeout=None):
                return Mock(payload=self.payloads.pop(0))

        eng._ctrl_output_queues = {
            "ctrl_w2e_rank0_6778": DummyQueue(
                "q0",
                [
                    Mock(request_id="old", error_code=200, result={"ok": False}),
                    Mock(request_id="req", error_code=200, result={"ok": "from-q0"}),
                ],
            ),
            "ctrl_w2e_rank1_6778": self._make_ctrl_queue(
                "q1", Mock(request_id="req", error_code=200, result={"ok": True})
            ),
        }

        results = asyncio.run(eng._wait_for_control_responses("req", timeout=1))
        self.assertEqual(results, [{"ok": "from-q0"}, {"ok": True}])
        self.assertEqual(
            eng._ctrl_response_mailboxes["ctrl_w2e_rank0_6778"]["old"].result,
            {"ok": False},
        )
        self._detach_finalizer(eng)

    def test_wait_for_control_responses_error_paths(self):
        eng = self._make_mixed_engine()

        eng._ctrl_output_queues = {
            "ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", Exception("boom"), payload_wrapped=False)
        }

        with self.assertRaises(Exception):
            asyncio.run(eng._wait_for_control_responses("req", timeout=1))
        self._detach_finalizer(eng)

    def test_wait_for_control_responses_none_message(self):
        eng = self._make_mixed_engine()

        eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)}

        with self.assertRaises(Exception):
            asyncio.run(eng._wait_for_control_responses("req", timeout=1))
        self._detach_finalizer(eng)

    def test_wait_for_control_responses_error_code(self):
        eng = self._make_mixed_engine()

        eng._ctrl_output_queues = {
            "ctrl_w2e_rank0_6778": self._make_ctrl_queue(
                "q0", ControlResponse(request_id="req", error_code=500, error_message="bad")
            )
        }

        with self.assertRaises(Exception):
            asyncio.run(eng._wait_for_control_responses("req", timeout=1))
        self._detach_finalizer(eng)

    def test_wait_for_control_responses_timeout(self):
        eng = self._make_mixed_engine()
        eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)}

        with patch("fastdeploy.engine.common_engine.asyncio.wait_for", side_effect=asyncio.TimeoutError):
            with self.assertRaises(Exception):
                asyncio.run(eng._wait_for_control_responses("req", timeout=1))
        self._detach_finalizer(eng)

    def test_wait_for_control_responses_without_matching_queues(self):
        eng = self._make_mixed_engine()
        eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)}

        result = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["cache_transfer"]))
        self.assertIsNone(result)
        self._detach_finalizer(eng)

    def test_insert_tasks_prefill_error_and_success(self):
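The mismatch test encodes the fix for stale control responses: a reply whose `request_id` does not match the in-flight request is parked in a per-queue mailbox instead of being discarded, so a late reply from a previously timed-out call cannot be mistaken for the current one. A self-contained sketch of that drain loop (names are illustrative, not the engine's internal API):

```python
import asyncio


async def drain_until(queue, request_id, mailbox, timeout=1.0):
    # Keep reading until the reply for `request_id` shows up; park any
    # stale reply (from an earlier, timed-out call) in the mailbox.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        msg = await asyncio.wait_for(queue.get(), timeout=deadline - loop.time())
        if msg["request_id"] == request_id:
            return msg["result"]
        mailbox[msg["request_id"]] = msg


async def demo():
    queue = asyncio.Queue()
    # A stale reply arrives first, then the one we are actually waiting for.
    await queue.put({"request_id": "old", "result": {"ok": False}})
    await queue.put({"request_id": "req", "result": {"ok": True}})
    mailbox = {}
    result = await drain_until(queue, "req", mailbox)
    return result, mailbox


result, mailbox = asyncio.run(demo())
```

The stale reply ends up keyed by its own request id in `mailbox`, mirroring what the test checks via `_ctrl_response_mailboxes`.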
@@ -3254,7 +3358,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):

        # Lines 1299-1300: try block start + info logging
        info_msgs = [str(c) for c in mock_logger.info.call_args_list]
        self.assertTrue(any("Start to run control method" in m for m in info_msgs))
        # worker_pid should be popped from the map
        self.assertNotIn("ctrl-log", eng.request_worker_map)
        self._detach_finalizer(eng)
@@ -84,6 +84,7 @@ def _reload_api_server(args):
    fake_envs_mod.EXPORTER_OTLP_ENDPOINT = ""
    fake_envs_mod.EXPORTER_OTLP_HEADERS = ""
    fake_envs_mod.FD_SUPPORT_MAX_CONNECTIONS = 1024
    fake_envs_mod.FD_ENABLE_V1_UPDATE_WEIGHTS = 0
    fake_envs_mod.environment_variables = _FakeEnvVars()

    # Save original sys.argv and replace with minimal valid args to avoid parse errors
@@ -98,9 +99,8 @@ def _reload_api_server(args):
        patch.dict("sys.modules", {"fastdeploy.envs": fake_envs_mod}),
        patch("fastdeploy.envs", fake_envs_mod),
    ):
        sys.modules.pop("fastdeploy.entrypoints.openai.api_server", None)
        return importlib.import_module("fastdeploy.entrypoints.openai.api_server")
    finally:
        sys.argv = original_argv
@@ -536,6 +536,95 @@ async def test_reward_embedding_and_weights():
    assert api_server.clear_load_weight(MagicMock()).status_code == 404


@pytest.mark.asyncio
async def test_sleep_wakeup_and_v1_clear_load_weight_routes():
    args = _build_args(dynamic_load_weight=True)
    api_server = _reload_api_server(args)
    api_server.app.state.dynamic_load_weight = True
    api_server.app.state.event_loop = asyncio.get_running_loop()

    mock_control_response = MagicMock()
    mock_control_response.to_api_json_response.return_value = api_server.JSONResponse(
        content={"ok": True}, status_code=200
    )

    api_server.app.state.engine_client = MagicMock()
    api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response)
    api_server.app.state.engine_client.run_control_method_sync.return_value = mock_control_response

    sleep_req = MagicMock()
    sleep_req.body = AsyncMock(return_value=b'{"tags":"weight"}')
    sleep_req.json = AsyncMock(return_value={"tags": "weight"})
    sleep_resp = await api_server.sleep(sleep_req)
    assert sleep_resp.status_code == 200
    sleep_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[0].args[0]
    assert sleep_control_request.method == "sleep"
    assert sleep_control_request.args == {"tags": "weight"}

    wakeup_req = MagicMock()
    wakeup_req.body = AsyncMock(return_value=b'{"tags":"weight,kv_cache"}')
    wakeup_req.json = AsyncMock(return_value={"tags": "weight,kv_cache"})
    wakeup_resp = await api_server.wakeup(wakeup_req)
    assert wakeup_resp.status_code == 200
    wakeup_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[1].args[0]
    assert wakeup_control_request.method == "wakeup"
    assert wakeup_control_request.args == {"tags": "weight,kv_cache"}

    passthrough_req = MagicMock()
    passthrough_req.body = AsyncMock(return_value=b'{"tags":["weight"]}')
    passthrough_req.json = AsyncMock(return_value={"tags": ["weight"]})
    passthrough_resp = await api_server.sleep(passthrough_req)
    assert passthrough_resp.status_code == 200
    passthrough_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[2].args[0]
    assert passthrough_control_request.args == {"tags": ["weight"]}

    with patch.object(api_server.envs, "FD_ENABLE_V1_UPDATE_WEIGHTS", True):
        clear_resp = api_server.clear_load_weight(MagicMock())
        assert clear_resp.status_code == 200
        sync_control_request = api_server.app.state.engine_client.run_control_method_sync.call_args.args[0]
        assert sync_control_request.method == "sleep"

    api_server.app.state.engine_client.run_control_method_sync.reset_mock()
    with patch.object(api_server.envs, "FD_ENABLE_V1_UPDATE_WEIGHTS", True):
        update_resp = api_server.update_model_weight(MagicMock())
        assert update_resp.status_code == 200
        sync_control_request = api_server.app.state.engine_client.run_control_method_sync.call_args.args[0]
        assert sync_control_request.method == "wakeup"


@pytest.mark.asyncio
async def test_update_weights_route_validation():
    args = _build_args(dynamic_load_weight=True)
    api_server = _reload_api_server(args)
    mock_control_response = MagicMock()
    mock_control_response.to_api_json_response.return_value = api_server.JSONResponse(
        content={"ok": True}, status_code=200
    )
    api_server.app.state.engine_client = MagicMock()
    api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response)

    valid_req = MagicMock()
    valid_req.body = AsyncMock(return_value=b'{"version":"v2","rsync_config":{"etcd_server":"127.0.0.1"}}')
    valid_req.json = AsyncMock(return_value={"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}})
    valid_resp = await api_server.update_weights(valid_req)
    assert valid_resp.status_code == 200
    control_request = api_server.app.state.engine_client.run_control_method.await_args.args[0]
    assert control_request.method == "update_weights"
    assert control_request.args == {"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}}

    invalid_version_req = MagicMock()
    invalid_version_req.body = AsyncMock(return_value=b'{"version":1}')
    invalid_version_req.json = AsyncMock(return_value={"version": 1})
    invalid_version_resp = await api_server.update_weights(invalid_version_req)
    assert invalid_version_resp.status_code == 400

    invalid_rsync_req = MagicMock()
    invalid_rsync_req.body = AsyncMock(return_value=b'{"rsync_config":{"user":"u"}}')
    invalid_rsync_req.json = AsyncMock(return_value={"rsync_config": {"user": "u"}})
    invalid_rsync_resp = await api_server.update_weights(invalid_rsync_req)
    assert invalid_rsync_resp.status_code == 400


@pytest.mark.asyncio
async def test_expert_and_stats_routes():
    args = _build_args()
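`test_update_weights_route_validation` implies two checks on the `/update_weights` payload: `version`, when present, must be a string, and `rsync_config`, when present, must name an `etcd_server`. A minimal validator sketch inferred from those assertions (the real route may check more; the function name is illustrative):

```python
def validate_update_weights_payload(payload):
    """Return an HTTP status code for an update_weights request body.

    Rules inferred from the route tests above: a non-string "version"
    is rejected, and an "rsync_config" without "etcd_server" is rejected.
    """
    if "version" in payload and not isinstance(payload["version"], str):
        return 400
    rsync = payload.get("rsync_config")
    if rsync is not None and "etcd_server" not in rsync:
        return 400
    return 200
```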
@@ -142,34 +142,34 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
    def capture_and_replay(self, input_tensor1, forward_meta1):
        """ """
        # Trigger Capture
        print_gpu_memory_use("before capture", 0)
        output1 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
        print_gpu_memory_use("after capture", 0)

        # Replay
        output1 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
        assert (output1 == self.output_correct).all()

        # Destroy
        print_gpu_memory_use("before destroy", 0)
        self.test_model1.clear_grpah_opt_backend()
        print_gpu_memory_use("after destroy", 0)

    def recapture_and_replay(self, input_tensor1, forward_meta1):
        """ """
        # Trigger Capture
        print_gpu_memory_use("before recapture", 0)
        output2 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
        print_gpu_memory_use("after recapture", 0)

        # Replay
        output2 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1)
        assert (output2 == self.output_correct).all()

        # Destroy
        print_gpu_memory_use("before destroy", 0)
        self.test_model1.clear_grpah_opt_backend()
        print_gpu_memory_use("after destroy", 0)


if __name__ == "__main__":
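The recapture test exercises the CUDA-graph lifecycle: the first call after initialization captures the graph, later calls replay it, clearing the backend destroys it, and a fresh call recaptures. A toy sketch of that capture/replay state machine in plain Python (illustrative only, assuming a fixed input; no relation to Paddle's actual graph backend):

```python
class ToyGraphBackend:
    """Capture on first call, replay afterwards, recapture after clear()."""

    def __init__(self, fn):
        self.fn = fn
        self.captured = None
        self.captures = 0

    def __call__(self, x):
        if self.captured is None:
            # First call after (re)initialization: run and capture.
            self.captures += 1
            self.captured = self.fn(x)
        # Subsequent calls replay the captured result (valid because the
        # toy assumes the same input shape/value every call).
        return self.captured

    def clear(self):
        # Destroy the captured graph; the next call recaptures.
        self.captured = None


backend = ToyGraphBackend(lambda x: x * 2)
first = backend(3)   # capture
second = backend(3)  # replay
backend.clear()      # destroy
third = backend(3)   # recapture
```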