From a7f52c300d071ca4b57b1af7fd96dfa3eeef0b3f Mon Sep 17 00:00:00 2001 From: Yonghua Li <39643373+liyonghua0910@users.noreply.github.com> Date: Wed, 25 Mar 2026 19:18:46 +0800 Subject: [PATCH] [Feature] support v1 update/clear api for RL (#6761) * [Feature] support v1 update/clear api for RL * [fix] fix execute_model and add sleep/wakeup api * [fix] fix mtp and key_prefix * [chore] move _update_key_prefix to resume method * [fix] make the interface safe to call multiple times * [fix] fix some tiny bugs * [chore] make small changes against pr review * [docs] add docs for weight update * [test] add some tests and update docs * [style] fix code style check * [test] fix ci * [fix] fix stale control responses when control method timed out * [chore] remove unused code * [chore] fix code style * [chore] optimize tags and key_prefix * [test] fix ci * [chore] fix code style * [test] fix ci * [fix] fix ep control * [fix] fix ep control for engine cache queue --- docs/features/weight_update.md | 308 ++++++++++++++ docs/zh/features/weight_update.md | 307 ++++++++++++++ fastdeploy/cache_manager/cache_data.py | 1 + .../cache_manager/cache_transfer_manager.py | 383 +++++++++++------- .../cache_manager/prefix_cache_manager.py | 8 +- fastdeploy/config.py | 3 + fastdeploy/engine/common_engine.py | 299 +++++++++++--- fastdeploy/entrypoints/engine_client.py | 28 +- fastdeploy/entrypoints/openai/api_server.py | 59 ++- .../entrypoints/openai/response_processors.py | 21 +- fastdeploy/entrypoints/openai/utils.py | 7 +- fastdeploy/envs.py | 5 + .../inter_communicator/engine_cache_queue.py | 20 +- fastdeploy/inter_communicator/zmq_server.py | 3 + fastdeploy/rl/dynamic_weight_manager.py | 57 +++ fastdeploy/utils.py | 18 +- fastdeploy/worker/gpu_model_runner.py | 85 ++++ fastdeploy/worker/gpu_worker.py | 8 + fastdeploy/worker/worker_process.py | 64 ++- .../test_cache_transfer_manager.py | 100 +++-- .../test_prefix_cache_manager.py | 4 +- tests/ce/stable_cases/run.sh | 181 ++++++--- 
tests/e2e/utils/serving_utils.py | 1 + tests/engine/test_common_engine.py | 168 ++++++-- tests/entrypoints/openai/test_api_server.py | 95 ++++- .../test_cuda_graph_recapture.py | 16 +- 26 files changed, 1857 insertions(+), 392 deletions(-) create mode 100644 docs/features/weight_update.md create mode 100644 docs/zh/features/weight_update.md diff --git a/docs/features/weight_update.md b/docs/features/weight_update.md new file mode 100644 index 0000000000..adb61704c8 --- /dev/null +++ b/docs/features/weight_update.md @@ -0,0 +1,308 @@ +[简体中文](../zh/features/weight_update.md) + +# Weight Clear and Update + +FastDeploy supports dynamic weight clear and update for RL and RLHF rollout services. This capability is primarily intended to address the following two requirements: + +- release GPU memory when the rollout engine is idle; +- refresh inference weights after the trainer produces a new checkpoint, without restarting the whole service. + +This page describes the weight-control interfaces currently supported by FastDeploy, the semantics of each interface, and their typical usage in RLHF training. + +## Prerequisites + +In RLHF scenarios, FastDeploy mainly provides this capability through the online serving mode. Dynamic weight loading must be enabled when starting the service: + +```bash +python -m fastdeploy.entrypoints.openai.api_server \ + --model /path/to/model \ + --dynamic-load-weight \ + --load_strategy ipc_snapshot +``` + +`--dynamic-load-weight` enables dynamic weight control, and `--load_strategy` specifies the concrete weight update mechanism. The currently supported update modes are listed below: + +| Mode | `load_strategy` | Typical use | Notes | +| --- | --- | --- | --- | +| CUDA IPC | `ipc` | Training and inference processes on the same node share live tensors | Update source comes from IPC metadata produced by the training side. 
| +| IPC snapshot | `ipc_snapshot` | Rollout reloads a snapshot file produced by training | Used by current RL rollout examples. | +| RDMA / rsync | `rsync` | Trainer publishes a new version and rollout fetches it remotely | `POST /v1/update_weights` is the explicit API for this mode. | + +## API Overview + +### Compatibility APIs + +In FastDeploy <= 2.5, the following simplified APIs are provided for compatibility with the legacy RL control flow. + +| API | Method | Meaning | Availability | +| --- | --- | --- | --- | +| `/clear_load_weight` | `GET` | Clear or offload currently loaded weights | Requires `dynamic_load_weight=True` | +| `/update_model_weight` | `GET` | Reload weights after a clear/offload operation | Requires `dynamic_load_weight=True` | + +### V1 control APIs + +In FastDeploy >= 2.6, the underlying control-signal communication path is optimized and V1 control APIs are introduced. Compared with the legacy APIs, the V1 APIs provide a more stable execution path, clearer semantics, and more flexible control: + +| API | Method | Request params | Semantics | +| --- | --- | --- | --- | +| `/v1/pause` | `POST` | none | Pause request generation, abort running and inflight requests, reset scheduler state, and pause cache transfer if enabled. | +| `/v1/resume` | `POST` | none | Resume request generation and cache transfer. | +| `/v1/is_paused` | `GET` | none | Return `{"is_paused": bool}`. | +| `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | Offload selected GPU memory objects. Supported tags are `weight` and `kv_cache`. If omitted, both are used. | +| `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | Reload previously offloaded weights and/or KV cache. On success, the engine resumes automatically. | +| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. 
| + +### Compatibility Notes + +The optimized communication path also applies to the legacy APIs. By setting `FD_ENABLE_V1_UPDATE_WEIGHTS=1`, the legacy APIs can be switched to the new control path while keeping the original API form. + +- `FD_ENABLE_V1_UPDATE_WEIGHTS=0`: use the legacy shared-memory-based control path. +- `FD_ENABLE_V1_UPDATE_WEIGHTS=1`: `/clear_load_weight` is effectively handled through `/v1/sleep`, and `/update_model_weight` is effectively handled through `/v1/wakeup`. The corresponding pause/resume actions are handled internally by `sleep` and `wakeup`. + +**Note**: regardless of whether V1 is enabled, the legacy APIs are not the recommended standard interface for RLHF scenarios and may be gradually deprecated in future releases. The `/v1/*` control APIs are recommended. + +## Interface Semantics + +### `/v1/pause` + +`/v1/pause` is the safe boundary before changing model state. + +It does the following: + +- stops new request generation; +- aborts running and inflight requests; +- resets scheduler state; +- pauses cache transfer when multi-level cache or KV cache storage is enabled. + +When a clear boundary is required between one rollout round and the next training stage, this API should be called first. + +### `/v1/sleep` + +`/v1/sleep` offloads selected runtime state from GPU memory. + +Supported tags: + +- `weight`: clear model weights from device memory; if enabled, communication groups and DeepEP buffers may also be released. +- `kv_cache`: clear KV cache; MTP cache is also cleared when speculative decoding uses MTP. + +If the `tags` parameter is omitted, FastDeploy defaults to: + +```bash +/v1/sleep?tags=weight,kv_cache +``` + +In the current implementation, `sleep` automatically performs a `pause` first. New integrations should not rely on this implicit behavior. + +### `/v1/wakeup` + +`/v1/wakeup` restores the state offloaded by `/v1/sleep`. 
+ +Depending on tags and configuration, FastDeploy may: + +- restart communication groups; +- recreate DeepEP buffers; +- reload model weights from the configured source; +- rebuild KV cache; +- recapture CUDA Graph. + +After `wakeup` succeeds, FastDeploy automatically calls `resume`. + +### `/v1/update_weights` + +`/v1/update_weights` refreshes model parameters directly, without unloading the GPU memory occupied by model weights. + +Current request fields: + +- `version`: optional string. Used to choose a target checkpoint version. +- `rsync_config`: optional dictionary. Must contain `etcd_server` when provided. + +Important semantics: + +- the engine must already be paused, otherwise the request fails; +- the update is executed on workers only; +- this API is meant for explicit weight refresh, especially the `rsync` path; +- it does not implicitly call `resume`. + +Recommended sequence: + +1. `POST /v1/pause` +2. `POST /v1/update_weights` +3. `POST /v1/resume` + +If GPU memory also needs to be reclaimed between rollout rounds, the `sleep` / `wakeup` workflow is more appropriate. 
+ +## Example Requests + +### Basic APIs + +Pause the engine: + +```bash +curl -X POST http://127.0.0.1:8000/v1/pause +``` + +Resume the engine: + +```bash +curl -X POST http://127.0.0.1:8000/v1/resume +``` + +### Sleep / Wakeup APIs + +**Offload weights and KV cache** + +```bash +# Offload both weights and KV cache +curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache" + +# Offload only weights +curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight" + +# Omit parameter, defaults to both +curl -X POST "http://127.0.0.1:8000/v1/sleep" +``` + +**Restore weights and KV cache** + +```bash +# Restore both weights and KV cache +curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache" + +# Restore only weights +curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight" + +# Omit parameter, defaults to both +curl -X POST "http://127.0.0.1:8000/v1/wakeup" +``` + +**Note**: When `use_cudagraph=True`, KV cache must be restored before weights. This means `/v1/wakeup` with the `kv_cache` tag must be called before calling `/v1/wakeup` with the `weight` tag. If weights are restored without KV cache, an error will be raised. It is recommended to keep the `tags` parameter consistent between `/v1/sleep` and `/v1/wakeup`. + +### Update Weights API + +Refresh to a new remotely published version: + +```bash +curl -X POST http://127.0.0.1:8000/v1/update_weights \ + -H "Content-Type: application/json" \ + -d '{ + "version": "global_step_1200", + "rsync_config": { + "etcd_server": "127.0.0.1:2379" + } + }' +``` + +## RLHF Usage + +### Recommended Rollout Service Setup + +In RLHF scenarios, FastDeploy rollout services are typically configured as follows: + +- `dynamic_load_weight=True` +- `load_strategy=ipc_snapshot` for local snapshot-based refresh; +- or `load_strategy=rsync` for versioned remote refresh. + +The rollout utilities in the repository already follow this pattern. 
A typical example is: + +```python +from fastdeploy.rl.rollout_config import RolloutModelConfig +from fastdeploy.rl.rollout_model import RolloutModel + +rollout_config = RolloutModelConfig( + model_name_or_path=model_path, + tensor_parallel_size=ranks, + dynamic_load_weight=True, + load_strategy="ipc_snapshot", +) +rollout_model = RolloutModel(rollout_config) +``` + +### Training-Side Integration Support + +In addition to serving endpoints, FastDeploy provides the following training-side integration capabilities for RLHF: + +- `RolloutModel.state_dict()`: exposes the rollout-side inference parameters. +- `RolloutModel.get_name_mappings_to_training()`: exposes the mapping from inference parameter names to training parameter names. + +These interfaces can be used to align training checkpoints with rollout-side parameter layouts, especially when inference-side and training-side parameter names are not fully identical. + +### Common RLHF workflows + +The following examples assume the service endpoint is `http://127.0.0.1:8000`. + +**Workflow 1: clear and restore** + +This workflow is suitable when the rollout service stays resident, but GPU memory should be released before training and restored afterward. The recommended sequence is `(pause) -> sleep -> wakeup -> (resume)`, where the steps in parentheses are optional. 
+ +```bash +# Optional: explicitly pause the engine to establish a clear transition boundary +curl -X POST http://127.0.0.1:8000/v1/pause + +# Offload both weights and KV cache +curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache" + +# Restore both weights and KV cache after training completes +curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache" + +# Optional: explicitly resume if required by the integration +curl -X POST http://127.0.0.1:8000/v1/resume +``` + +**Workflow 2: in-place refresh to a new checkpoint** + +This workflow is suitable when the service remains resident and only needs to switch to a new checkpoint version. The recommended sequence is `pause -> update_weights -> resume`. + +```bash +# Pause the engine first +curl -X POST http://127.0.0.1:8000/v1/pause + +# Refresh to a new checkpoint version in place +curl -X POST http://127.0.0.1:8000/v1/update_weights \ + -H "Content-Type: application/json" \ + -d '{ + "version": "global_step_1200", + "rsync_config": { + "etcd_server": "127.0.0.1:2379" + } + }' + +# Resume the service after the update completes +curl -X POST http://127.0.0.1:8000/v1/resume +``` + +**Workflow 3: legacy compatibility APIs** + +Legacy RL clients can continue to use the compatibility flow `clear_load_weight -> update_model_weight`. + +```bash +# Clear or offload the current weights +curl -X GET http://127.0.0.1:8000/clear_load_weight + +# Reload weights after the trainer updates the checkpoint +curl -X GET http://127.0.0.1:8000/update_model_weight +``` + +For new integrations, the `/v1/*` APIs are recommended because their control path is more explicit and easier to trace. + +## Other Related Configuration + +### Communication Group Clear and Rebuild + +FastDeploy provides `--shutdown-comm-group-if-worker-idle` and `--no-shutdown-comm-group-if-worker-idle` to explicitly control whether communication groups should also be torn down when weights are offloaded. 
+ +Keeping communication groups alive generally improves the stability of weight clearing and reloading. The tradeoff is that more GPU memory remains allocated after weight offload, and the execution time of `sleep` / `wakeup` may also increase. + +By default: + +- in EP scenarios, communication groups are kept; +- in non-EP scenarios, communication groups are torn down. + +### CPU Cache Clear and Rebuild + +After `--swap-space` is enabled, the following environment variable can be used to control whether CPU-side cache should also be cleared when `/v1/sleep` is executed, in order to reduce memory pressure during training. + +By default, FastDeploy does not actively clear CPU cache. To clear it together with `sleep`, set: + +```bash +export FD_ENABLE_SWAP_SPACE_CLEARING=1 +``` diff --git a/docs/zh/features/weight_update.md b/docs/zh/features/weight_update.md new file mode 100644 index 0000000000..1b34a29b9b --- /dev/null +++ b/docs/zh/features/weight_update.md @@ -0,0 +1,307 @@ +[English](../../features/weight_update.md) + +# 权重清除与更新 + +FastDeploy 支持面向 RL / RLHF Rollout 服务的动态权重清除、显存卸载和权重更新,主要用于解决以下两类问题: + +- Rollout 引擎空闲时释放 GPU 显存; +- Trainer 产出新 checkpoint 后,推理服务在不重启进程的情况下切换到新权重。 + +本文档介绍 FastDeploy 当前支持的权重控制接口、各接口的语义,以及它们在 RLHF 训练中的典型用法。 + +## 前置条件 + +在 RLHF 场景下,FastDeploy 主要通过在线服务模式提供该能力。启动服务时,需要开启动态权重加载: + +```bash +python -m fastdeploy.entrypoints.openai.api_server \ + --model /path/to/model \ + --dynamic-load-weight \ + --load-strategy ipc_snapshot +``` + +`--dynamic-load-weight` 用于开启动态权重控制能力,`--load-strategy` 用于指定具体的权重更新方式。当前支持的更新模式如下: + +| 模式 | `load_strategy` | 典型场景 | 说明 | +| --- | --- | --- | --- | +| CUDA IPC | `ipc` | 训练进程与推理进程同机,直接共享实时张量 | 更新来源是训练侧产出的 IPC 元信息。 | +| IPC 快照 | `ipc_snapshot` | Rollout 从训练侧产出的权重快照文件重载 | 当前仓库里的 RL rollout 示例主要使用该模式。 | +| RDMA / rsync | `rsync` | Trainer 发布新版本,Rollout 远端拉取 | `POST /v1/update_weights` 是这一模式的显式接口。 | + +## 接口说明 + +### 旧版接口 + +在 FastDeploy <= 2.5 版本中,主要提供以下简化接口,保留给旧版 RL 控制流使用。 + +| 接口 | 方法 | 含义 | 可用条件 | +| --- | 
--- | --- | --- | +| `/clear_load_weight` | `GET` | 清除或卸载当前已加载权重 | 需要 `dynamic_load_weight=True` | +| `/update_model_weight` | `GET` | 在清除/卸载后重新加载权重 | 需要 `dynamic_load_weight=True` | + +### V1 新版接口 + +在 FastDeploy >= 2.6 版本中,底层控制信号通信链路经过优化,并引入了 V1 控制接口。相较于旧版接口,V1 接口在通信与执行链路上更稳定,语义更清晰,同时提供了更灵活的控制方式,包括以下接口: + +| 接口 | 方法 | 请求参数 | 语义 | +| --- | --- | --- | --- | +| `/v1/pause` | `POST` | 无 | 暂停请求生成,中断 running/inflight 请求,重置调度器,并在开启 cache transfer 时暂停 cache transfer。 | +| `/v1/resume` | `POST` | 无 | 恢复请求生成和 cache transfer。 | +| `/v1/is_paused` | `GET` | 无 | 返回 `{"is_paused": bool}`。 | +| `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | 卸载指定 GPU 内存对象。支持 `weight` 与 `kv_cache`;不传时默认同时处理两者。 | +| `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | 重新加载之前被卸载的权重和/或 KV Cache。成功后会自动 `resume`。 | +| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | 通过 worker 控制链路原地刷新模型权重。该接口主要面向 `load_strategy=rsync` 的远端版本更新。 | + +### 兼容性说明 + +底层通信链路的优化同样适用于旧版接口。通过设置环境变量 `FD_ENABLE_V1_UPDATE_WEIGHTS=1`,可以将旧版接口切换到新的控制链路,在保留兼容接口形式的同时,获得更明确的执行路径和更好的可观测性。 +- `FD_ENABLE_V1_UPDATE_WEIGHTS=0`:走旧版基于共享内存的控制链路。 +- `FD_ENABLE_V1_UPDATE_WEIGHTS=1`:`/clear_load_weight` 底层等价于执行 `/v1/sleep`,`/update_model_weight` 底层等价于执行 `/v1/wakeup`。对应的 pause/resume 动作分别由 `sleep` 和 `wakeup` 内部处理。 + +**注意**:无论是否设置 V1 环境变量,旧版接口都不是 RLHF 场景下推荐的标准使用方式,后续版本中也可能逐步废弃。建议优先使用 `/v1/*` 控制接口。 + +## 各接口语义 + +### `/v1/pause` + +`/v1/pause` 是变更模型状态前的安全边界。 + +它会执行以下动作: + +- 停止新请求生成; +- 中断当前 running 和 inflight 请求; +- 重置调度器状态; +- 在启用多级缓存或 KV Cache 存储时暂停 cache 传输。 + +如果需要在每一轮 rollout 与下一轮训练之间建立清晰的切换边界,建议先调用该接口。 + +### `/v1/sleep` + +`/v1/sleep` 用于从 GPU 显存中卸载指定的运行时状态。 + +当前支持的 `tags`: + +- `weight`:清除设备上的模型权重;如果开启了相关配置,还可能一并释放通信组和 DeepEP buffer。 +- `kv_cache`:清除 KV Cache;如果投机解码采用 MTP,还会同步清理 MTP cache。 + +如果不传 `tags` 参数,FastDeploy 默认等价于: + +```bash +/v1/sleep?tags=weight,kv_cache +``` + +当前实现中,`sleep` 会自动先执行一次 `pause`。但新的接入方不应长期依赖这一隐式行为。 + +### `/v1/wakeup` + +`/v1/wakeup` 用于恢复通过 `/v1/sleep` 卸载的状态。 + +根据 `tags` 
和运行配置,FastDeploy 可能执行: + +- 重建通信组; +- 重建 DeepEP buffer; +- 从当前配置的数据源重新加载模型权重; +- 重建 KV Cache; +- 重新捕获 CUDA Graph。 + +`wakeup` 成功后,FastDeploy 会自动调用一次 `resume`。 + +### `/v1/update_weights` + +`/v1/update_weights` 用于在不卸载模型权重显存占用的情况下,直接刷新模型参数。 + +当前支持的请求字段: + +- `version`:可选字符串,用于指定目标 checkpoint 版本。 +- `rsync_config`:可选字典;如果传入,必须包含 `etcd_server`。 + +关键语义: + +- 调用前引擎必须已经处于暂停状态,否则请求会失败; +- 实际更新动作只在 worker 侧执行; +- 该接口主要用于显式权重刷新,尤其是 `rsync` 路径; +- 它不会自动执行 `resume`。 + +推荐调用顺序: + +1. `POST /v1/pause` +2. `POST /v1/update_weights` +3. `POST /v1/resume` + +如果除更新权重外,还希望在 rollout 轮次之间回收 GPU 显存,则更适合使用 `sleep` / `wakeup` 组合。 + +## 请求示例 + +### 基础接口 + +暂停引擎: + +```bash +curl -X POST http://127.0.0.1:8000/v1/pause +``` + +恢复引擎: + +```bash +curl -X POST http://127.0.0.1:8000/v1/resume +``` + +### Sleep / Wakeup 接口 + +**卸载权重和 KV Cache** + +```bash +# 同时卸载权重和 KV Cache +curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache" + +# 只卸载权重 +curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight" + +# 不传参数,默认同时卸载两者 +curl -X POST "http://127.0.0.1:8000/v1/sleep" +``` + +**恢复权重和 KV Cache** + +```bash +# 恢复权重和 KV Cache +curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache" + +# 只恢复权重 +curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight" + +# 不传参数,默认同时恢复两者 +curl -X POST "http://127.0.0.1:8000/v1/wakeup" +``` + +**注意**:当 `use_cudagraph=True` 时,必须先恢复 KV Cache 再恢复权重。这意味着调用 `/v1/wakeup` 时如果只包含 `weight` tag 而不包含 `kv_cache`,会报错。建议保持 sleep 和 wakeup 的 tags 参数一致。 + +### Update Weights 接口 + +切换到远端发布的新版本权重: + +```bash +curl -X POST http://127.0.0.1:8000/v1/update_weights \ + -H "Content-Type: application/json" \ + -d '{ + "version": "global_step_1200", + "rsync_config": { + "etcd_server": "127.0.0.1:2379" + } + }' +``` + +## 在 RLHF 中如何使用 + +### 推荐的 Rollout 服务配置 + +在 RLHF 场景下,FastDeploy 的 Rollout 服务通常采用以下配置: + +- `dynamic_load_weight=True` +- `load_strategy=ipc_snapshot`,适合本地快照式刷新; +- 或 `load_strategy=rsync`,适合远端版本化刷新。 + +仓库中的 RL rollout 工具已经按该方式接入。典型写法如下: + +```python +from 
fastdeploy.rl.rollout_config import RolloutModelConfig +from fastdeploy.rl.rollout_model import RolloutModel + +rollout_config = RolloutModelConfig( + model_name_or_path=model_path, + tensor_parallel_size=ranks, + dynamic_load_weight=True, + load_strategy="ipc_snapshot", +) +rollout_model = RolloutModel(rollout_config) +``` + +### FastDeploy 对训练侧的支持 + +除服务接口外,FastDeploy 还提供以下两类 RLHF 训练侧对接能力: + +- `RolloutModel.state_dict()`:暴露 rollout 侧的推理参数; +- `RolloutModel.get_name_mappings_to_training()`:暴露推理参数名到训练参数名的映射关系。 + +这两个接口可用于训练侧 checkpoint 与 rollout 侧权重布局对齐,尤其适用于推理态和训练态参数命名不完全一致的场景。 + +### RLHF 常见工作流 + +以下给出 RLHF 场景下几种常见调用方式。示例中假设服务地址为 `http://127.0.0.1:8000`。 + +**工作流 1:显存卸载与恢复** + +适用于 rollout 服务常驻,但需要在训练阶段前后释放并恢复显存的场景。推荐流程为 `(pause) -> sleep -> wakeup -> (resume)`,其中括号内步骤为可选。 + +```bash +# 可选:先显式暂停引擎,建立清晰的切换边界 +curl -X POST http://127.0.0.1:8000/v1/pause + +# 卸载权重和 KV Cache +curl -X POST "http://127.0.0.1:8000/v1/sleep?tags=weight,kv_cache" + +# 训练完成后恢复权重和 KV Cache +curl -X POST "http://127.0.0.1:8000/v1/wakeup?tags=weight,kv_cache" + +# 可选:如业务侧需要显式恢复,可手动调用 +curl -X POST http://127.0.0.1:8000/v1/resume +``` + +**工作流 2:原地刷新到新 checkpoint** + +适用于服务常驻、仅需要切换到新版本权重的场景。推荐流程为 `pause -> update_weights -> resume`。 + +```bash +# 先暂停引擎 +curl -X POST http://127.0.0.1:8000/v1/pause + +# 原地刷新到新版本权重 +curl -X POST http://127.0.0.1:8000/v1/update_weights \ + -H "Content-Type: application/json" \ + -d '{ + "version": "global_step_1200", + "rsync_config": { + "etcd_server": "127.0.0.1:2379" + } + }' + +# 更新完成后恢复服务 +curl -X POST http://127.0.0.1:8000/v1/resume +``` + +**工作流 3:兼容旧版接口** + +旧版 RL 客户端仍可继续使用兼容接口,流程为 `clear_load_weight -> update_model_weight`。 + +```bash +# 清除或卸载当前权重 +curl -X GET http://127.0.0.1:8000/clear_load_weight + +# 训练侧完成 checkpoint 更新后,重新加载权重 +curl -X GET http://127.0.0.1:8000/update_model_weight +``` + +对于新的接入方,建议优先使用 `/v1/*` 接口,因为其控制链路更显式,日志排查和故障定位也更直接。 + +## 其他相关配置 + +### 通信组的销毁与重建 + +FastDeploy 支持通过 `--shutdown-comm-group-if-worker-idle` 和 
`--no-shutdown-comm-group-if-worker-idle`,显式控制在卸载权重时是否同时销毁通信组。 + +保留通信组通常有助于提升权重清除和重新加载过程中的稳定性;相应地,代价是卸载权重后仍会保留更多显存占用,同时 `sleep` / `wakeup` 的执行时间也可能更长。 + +默认情况下: + +- 在 EP 场景下,默认不销毁通信组; +- 在非 EP 场景下,默认销毁通信组。 + +### CPU 缓存的清除与重建 + +启用 `--swap-space` 后,可以通过以下环境变量控制在执行 `/v1/sleep` 时,是否同步清理 CPU 侧缓存,以降低训练阶段的内存压力。 + +默认情况下,FastDeploy 不会主动清理 CPU Cache。如需在 `sleep` 时一并清理,可设置: + +```bash +export FD_ENABLE_SWAP_SPACE_CLEARING=1 +``` diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py index 84e7d804c3..82911eccfa 100644 --- a/fastdeploy/cache_manager/cache_data.py +++ b/fastdeploy/cache_manager/cache_data.py @@ -32,6 +32,7 @@ class CacheStatus(Enum): CPU = 3 GPU2STORAGE = 4 STORAGE2GPU = 5 + CTRL = -1 class BlockNode: diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 74386f909a..b264f03b75 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -15,6 +15,7 @@ """ import argparse +import asyncio import concurrent.futures import gc import json @@ -35,7 +36,6 @@ from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTa from fastdeploy.cache_manager.ops import ( cuda_host_alloc, cuda_host_free, - memory_allocated, set_data_ipc, set_device, share_external_data_, @@ -49,7 +49,9 @@ from fastdeploy.cache_manager.transfer_factory import ( MooncakeStore, ) from fastdeploy.config import CacheConfig, SpeculativeConfig +from fastdeploy.engine.request import ControlRequest, ControlResponse from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus +from fastdeploy.inter_communicator.fmq import FMQ from fastdeploy.platforms import current_platform from fastdeploy.utils import console_logger, get_logger @@ -96,7 +98,6 @@ def parse_args(): help="engine worker queue port", ) parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number") - 
parser.add_argument("--ipc_suffix", type=str, default=None, help="engine pid") parser.add_argument( "--protocol", type=str, @@ -184,15 +185,17 @@ class CacheTransferManager: self.key_prefix = "" # extract other arg values + self.model_path = args.model_path self.model_id = os.path.basename(args.model_path.rstrip("/")) self.n_ranks = args.mp_num self.rank = args.rank self.device = args.device_id self.num_layers = args.num_layers - self.ipc_suffix = args.ipc_suffix + self.create_cache_tensor = args.create_cache_tensor self.local_data_parallel_id = args.local_data_parallel_id self.num_extra_layers = self.speculative_config.num_extra_cache_layer self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) + self.cache_queue_port = args.cache_queue_port paddle.set_default_dtype(args.default_dtype) self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) @@ -200,8 +203,10 @@ class CacheTransferManager: self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) + self.control_task_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) self.transfer_task_queue = queue.Queue() # 用来接收传输任务 self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕 + self.ctrl_output_queue = None address = (args.pod_ip, args.cache_queue_port) self.cache_task_queue = EngineCacheQueue( @@ -209,7 +214,7 @@ class CacheTransferManager: is_server=False, num_client=args.mp_num, client_id=self.rank, - local_data_parallel_id=args.local_data_parallel_id, + local_data_parallel_id=0, ) cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32) @@ -231,11 +236,12 @@ class CacheTransferManager: self.num_cpu_blocks = args.num_cpu_blocks - self._init_gpu_cache(args) + self._init_gpu_cache() if self.num_cpu_blocks > 0: - 
self._init_cpu_cache(args) + self._init_cpu_cache() if self.storage_backend_type is not None: self._init_storage(args) + self._init_control() cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) self.cache_task_broadcast_signal = IPCSignal( @@ -287,7 +293,9 @@ class CacheTransferManager: threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start() self.is_paused = False # transfer manager state + self.is_sleeping = False self.inflight = 0 # number of inflight transfer tasks + self.inflight_tasks = {} cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32) self.cache_transfer_inited_signal = IPCSignal( @@ -299,6 +307,15 @@ class CacheTransferManager: ) self.cache_transfer_inited_signal.value[self.rank] = 1 + def _init_control(self): + dp_rank = self.local_data_parallel_id + tp_rank = self.rank + tp_size = self.n_ranks + cq_port = self.cache_queue_port + name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_rank}_{cq_port}" + self.ctrl_output_queue = FMQ().queue(name, "producer") + logger.info(f"Init control output queue: {name} (producer)") + def _init_storage(self, args): try: # TODO: support cache scale for other backend @@ -354,13 +371,19 @@ class CacheTransferManager: raise ValueError(f"Invalid write policy: {args.write_policy}") self.write_policy = args.write_policy - version_file_path = os.path.join(args.model_path, "version.yaml") - if os.path.exists(version_file_path): - self.key_prefix = get_key_prefix_from_version(version_file_path) - logger.info(f"The key_prefix of cache storage is {self.key_prefix}") + self._update_key_prefix() logger.info("Initialize cache storage successfully") + def _update_key_prefix(self): + # use key_prefix to distinguish cache for different version of weight in rl + version_file_path = os.path.join(self.model_path, "version.yaml") + if os.path.exists(version_file_path): + self.key_prefix = get_key_prefix_from_version(version_file_path) + logger.info(f"Update key_prefix of cache 
storage to {self.key_prefix}") + else: + logger.error(f"version.yaml not found at {version_file_path}") + def _init_storage_buffer(self, args): """ Initialize pinned memory buffer that can hold the cache for a longest request @@ -409,20 +432,20 @@ class CacheTransferManager: self.storage_value_scale_write_buffer = write_buffer + scale_buffer_total_bytes // 2 self.storage_backend.register_buffer(write_buffer, scale_buffer_total_bytes) - def _init_gpu_cache(self, args): + def _init_gpu_cache(self): - if not args.create_cache_tensor: - logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners or messagers to create kv cache.") + if not self.create_cache_tensor: + logger.info("Waiting for runners or messagers to create kv cache.") while self.cache_ready_signal.value[self.rank] != 1: time.sleep(0.1) - logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.") + logger.info("OK! Stop waiting.") - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": cache_type = "uint8" else: - cache_type = args.cache_dtype + cache_type = self.cache_dtype - logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.") + logger.info("Initializing kv cache for all layers.") set_device(self.device) for i in range(self.num_layers + self.num_extra_layers): # NOTE: num_extra_layer_gpu_blocks is usually equal to num_gpu_blocks @@ -445,14 +468,12 @@ class CacheTransferManager: self.value_cache_shape[2], self.value_cache_shape[3], ] - if args.create_cache_tensor: - logger.info( - f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}" - ) + if self.create_cache_tensor: + logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(key_cache, key_name) - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": key_cache_scales 
= paddle.full( shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]], fill_value=0, @@ -463,7 +484,7 @@ class CacheTransferManager: val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(val_cache, val_name) - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": value_cache_scales = paddle.full( shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]], fill_value=0, @@ -471,13 +492,11 @@ class CacheTransferManager: ) set_data_ipc(value_cache_scales, value_cache_scales_name) else: - logger.info( - f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}" - ) + logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") key_cache = paddle.empty(shape=[], dtype=cache_type) val_cache = paddle.empty(shape=[], dtype=cache_type) key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True) - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) key_cache_scales = share_external_data_( key_cache_scales, @@ -487,7 +506,7 @@ class CacheTransferManager: ) if self.value_cache_shape: val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True) - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) value_cache_scales = share_external_data_( value_cache_scales, @@ -498,48 +517,65 @@ class CacheTransferManager: self.gpu_cache_kvs[key_name] = key_cache self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name]) - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": self.gpu_cache_kvs[key_cache_scales_name] = key_cache_scales self.gpu_cache_scales_k_tensors.append(self.gpu_cache_kvs[key_cache_scales_name]) - if 
args.value_cache_shape: + if self.value_cache_shape: self.gpu_cache_kvs[val_name] = val_cache self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name]) - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": self.gpu_cache_kvs[value_cache_scales_name] = value_cache_scales self.gpu_cache_scales_v_tensors.append(self.gpu_cache_kvs[value_cache_scales_name]) - if args.create_cache_tensor: - logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!") + if self.create_cache_tensor: self.cache_ready_signal.value[self.rank] = 1 + while np.sum(self.cache_ready_signal.value) != self.n_ranks: + time.sleep(0.1) - cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()]) - logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}") - logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}") - logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}") + logger.info("GPU KV cache is initialized") - def _init_cpu_cache(self, args): + def _clear_gpu_cache(self): + if self.create_cache_tensor: + logger.debug("Waiting for gpu runner to unlink cuda ipc") + while self.cache_ready_signal.value[self.rank] != 0: + time.sleep(0.1) + logger.debug("Stop waiting! 
gpu runner has unlinked cuda ipc") + self.gpu_cache_kvs.clear() + self.gpu_cache_k_tensors.clear() + self.gpu_cache_v_tensors.clear() + if hasattr(self, "gpu_cache_scales_k_tensors"): + self.gpu_cache_scales_k_tensors.clear() + if hasattr(self, "gpu_cache_scales_v_tensors"): + self.gpu_cache_scales_v_tensors.clear() + paddle.device.cuda.empty_cache() + else: + for name, tensor in self.gpu_cache_kvs.items(): + unset_data_ipc(tensor, name, True, False) + logger.debug("Successfully unlinked gpu caches cuda ipc") + self.cache_ready_signal.value[self.rank] = 0 + + while np.sum(self.cache_ready_signal.value) != 0: + time.sleep(0.1) + logger.info("All ranks cleared gpu caches") + + def _init_cpu_cache(self): + if self.num_cpu_blocks == 0: + return + paddle.set_device("cpu") key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3] - if args.value_cache_shape: + if self.value_cache_shape: value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3] else: value_cache_size = 0 cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype) - key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size - value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size - if args.cache_dtype == "block_wise_fp8": + key_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * key_cache_size + value_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * value_cache_size + logger.info("Initializing swap space (cpu cache) for all layers.") + if self.cache_dtype == "block_wise_fp8": cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2] - scales_key_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size - scales_value_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size - 
logger.info( - f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB" - ) - if args.num_cpu_blocks == 0: - logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.") - self.swap_space_ready_signal.value[self.rank] = 1 - return - logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.") - paddle.set_device("cpu") + scales_key_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size + scales_value_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size self.k_dst_ptrs = [] self.v_dst_ptrs = [] self.k_scales_ptrs = [] @@ -550,22 +586,43 @@ class CacheTransferManager: key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}" value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}" logger.info( - f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB" + f"..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB" ) self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes) self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name]) - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(scales_key_need_to_allocate_bytes) self.k_scales_ptrs.append(self.cpu_cache_kvs[key_cache_scales_name]) if value_need_to_allocate_bytes > 0: self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes) self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name]) - if args.cache_dtype == "block_wise_fp8": + if self.cache_dtype == "block_wise_fp8": self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(scales_value_need_to_allocate_bytes) 
self.v_scales_ptrs.append(self.cpu_cache_kvs[value_cache_scales_name]) - logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!") + logger.info("Swap space (cpu cache) is ready!") self.swap_space_ready_signal.value[self.rank] = 1 + while np.sum(self.swap_space_ready_signal.value) != self.n_ranks: + time.sleep(0.1) + logger.info("All ranks init cpu caches") + + def _clear_cpu_cache(self): + for ptrs in self.k_dst_ptrs + self.v_dst_ptrs: + cuda_host_free(ptrs) + self.cpu_cache_kvs.clear() + self.k_dst_ptrs.clear() + self.v_dst_ptrs.clear() + if hasattr(self, "k_scales_ptrs"): + self.k_scales_ptrs.clear() + if hasattr(self, "v_scales_ptrs"): + self.v_scales_ptrs.clear() + gc.collect() + self.swap_space_ready_signal.value[self.rank] = 0 + + while np.sum(self.swap_space_ready_signal.value) != 0: + time.sleep(0.1) + logger.info("All ranks cleared cpu caches") + def _run_read_storage( self, task_id: str, @@ -1023,6 +1080,79 @@ class CacheTransferManager: logger.debug(f"_do_swap_to_gpu_task: put_transfer_done_signal {result}") logger.info(f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}") + def _handle_pause(self): + if self.is_paused: + logger.info("💡 Cache transfer manager is already paused, no need to pause again!") + else: + self.pause() + logger.info("✅ Successfully paused transfer") + return True + + def _handle_resume(self): + if not self.is_paused: + logger.info("💡 Cache transfer manager is not paused, no need to resume!") + else: + self.resume() + if self.storage_backend_type is not None: + self._update_key_prefix() + logger.info("✅ Successfully resumed transfer") + return True + + def _handle_sleep(self): + if self.is_sleeping: + logger.info("💡 Cache transfer manager is already sleeping, no need to sleep again!") + else: + if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING: + self._clear_cpu_cache() + self._clear_gpu_cache() + self.is_sleeping = True + logger.info("✅ Successfully 
fell asleep (offloaded caches)") + return True + + def _handle_wakeup(self): + if not self.is_sleeping: + logger.info("💡 Cache transfer manager is not sleeping, no need to wakeup!") + else: + if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING: + self._init_cpu_cache() + self._init_gpu_cache() + self.is_sleeping = False + logger.info("✅ Successfully wakeup (reload caches)") + return True + + def control_task(self, task: ControlRequest): + method = task.get_method() + tags = task.args.get("tags", {}) + logger.info(f"Received control task: {method}, tags: {tags}") + + handlers = { + "pause": self._handle_pause, + "resume": self._handle_resume, + "sleep": self._handle_sleep, + "wakeup": self._handle_wakeup, + } + + handler = handlers.get(method) + error_code = 200 + error_message = "Success" + + if handler: + try: + handler() + except Exception as e: + error_code = 500 + error_message = f"Failed to execute {method}: {str(e)}" + logger.error(f"Error in control_task: {traceback.format_exc()}") + else: + error_code = 400 + error_message = f"Unknown control method: {method}" + logger.warning(error_message) + + self.cache_task_queue.barrier.wait() + resp = ControlResponse(task.request_id, error_code, error_message) + asyncio.run(self.ctrl_output_queue.put(resp)) + logger.info(f"Put response into output queue {self.ctrl_output_queue.name}: {resp}") + def check_work_status(self, time_interval_threashold=envs.FD_CACHE_PROC_EXIT_TIMEOUT): """ Check the health of the model server by checking whether all workers are alive. 
@@ -1042,8 +1172,13 @@ class CacheTransferManager: return fn(*args) finally: self.inflight -= 1 + logger.debug(f"submit_task: {fn.__name__} finished, args: {args}, current inflight: {self.inflight}") + self.inflight += 1 thread_pool.submit(inflight_task, task_fn, *args) + logger.debug( + f"submit_task: {task_fn.__name__} submitted to thread pool, args: {args}, current inflight: {self.inflight}" + ) def do_data_transfer(self): """ @@ -1054,6 +1189,7 @@ class CacheTransferManager: max_errors = ( envs.FD_CACHE_PROC_ERROR_COUNT ) # After this many consecutive errors, check if the worker process exists. + is_paused = False while True: try: @@ -1065,18 +1201,18 @@ class CacheTransferManager: self.cache_task_queue.barrier0.reset() # Ensure all ranks synchronically do one of the following things: - # (1) If rank#0 is paused, wait for a short time and check out rank#0 status again; - # (2) otherwise, all ranks are allowed to pull tasks from cache task queue + # (1) If rank#0 is paused, wait for inflight tasks to finish first, then only process control tasks; + # (2) otherwise, all ranks are allowed to pull all tasks from cache task queue if self.cache_task_is_paused_signal.value[0] == 1: # wait for inflight tasks to finish first while self.inflight != 0: time.sleep(0.1) # mark the current rank as not having inflight tasks self.cache_task_inflight_signal.value[self.rank] = 0 - time.sleep(1) - continue + is_paused = True else: self.cache_task_inflight_signal.value[self.rank] = 1 + is_paused = False if self.rank == 0: if not self.cache_task_queue.empty(): @@ -1087,12 +1223,16 @@ class CacheTransferManager: self.cache_task_queue.barrier1.reset() if self.cache_task_broadcast_signal.value[0] == 1: - self.inflight += 1 data, read_finish = self.cache_task_queue.get_transfer_task() logger.debug(f"do_data_transfer: {data}") if read_finish: self.cache_task_broadcast_signal.value[0] = 0 event_type, event_args = data[0], data[1:] + + # control task is the only task allowed to execute 
when loop is paused + if is_paused and event_type.value != CacheStatus.CTRL.value: + continue + if event_type.value == CacheStatus.SWAP2CPU.value: transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args self.submit_task( @@ -1129,6 +1269,13 @@ class CacheTransferManager: self.write_back_storage_task, write_storage_task, ) + elif event_type.value == CacheStatus.CTRL.value: + control_task = event_args[0] + self.control_task_thread_pool.submit( + self.control_task, + control_task, + ) + else: if self.n_ranks > 1: self.cache_task_queue.barrier2.wait() @@ -1291,114 +1438,46 @@ class CacheTransferManager: # TODO XPU support RL if unset_data_ipc is None: return - logger.info("[RL] Launch a thread to clear/restore kv cache when model weights are cleared/updated.") + logger.info( + "check_cache_status: Launch a thread to clear/restore kv cache when model weights are cleared/updated." + ) while True: # handle cache clearing/restoring if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING: assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache." 
try: - # wait for inflight transfer tasks to finish and pause transfer manager + # pause transfer self.pause() - # clear cpu caches - logger.info("[RL] start clearing caches") - logger.debug("[RL] start clearing cpu caches") + # clear caches + logger.info("check_cache_status: start clearing caches") if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING: - paddle.set_device("cpu") - for ptrs in self.k_dst_ptrs + self.v_dst_ptrs: - cuda_host_free(ptrs) - self.cpu_cache_kvs.clear() - self.k_dst_ptrs.clear() - self.v_dst_ptrs.clear() - if self.cache_dtype == "block_wise_fp8": - self.k_scales_ptrs.clear() - self.v_scales_ptrs.clear() - gc.collect() - logger.debug("[RL] successfully cleared cpu caches") - # reset swap_space_ready_signal - self.swap_space_ready_signal.value[self.rank] = 0 - while np.sum(self.swap_space_ready_signal.value) != 0: - time.sleep(0.1) - logger.debug("[RL] all ranks cleared cpu caches") - else: - logger.debug("[RL] skip clearing cpu caches") - - # clear gpu caches - logger.debug("[RL] start clearing gpu caches") - if args.create_cache_tensor: - logger.info("[RL] waiting for gpu runner to unlink cuda ipc") - while self.cache_ready_signal.value[self.rank] != 0: - time.sleep(0.1) - logger.info("[RL] stop waiting! 
gpu runner has unlinked cuda ipc") - paddle.set_device(f"gpu:{self.device}") - self.gpu_cache_kvs.clear() - self.gpu_cache_k_tensors.clear() - self.gpu_cache_v_tensors.clear() - if self.cache_dtype == "block_wise_fp8": - self.gpu_cache_scales_k_tensors.clear() - self.gpu_cache_scales_v_tensors.clear() - paddle.device.cuda.empty_cache() - logger.debug("[RL] successfully cleared gpu caches") - else: - for name, tensor in self.gpu_cache_kvs.items(): - unset_data_ipc(tensor, name, True, False) - logger.debug("[RL] successfully unlinked gpu caches cuda ipc") - self.cache_ready_signal.value[self.rank] = 0 - - while np.sum(self.cache_ready_signal.value) != 0: - time.sleep(0.1) - logger.info("[RL] all ranks cleared caches!") - - # reset kv_cache_status_signal + self._clear_cpu_cache() + self._clear_gpu_cache() self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED - - self._log_memory("after clearing caches") - + self._log_memory("check_cache_status: after clearing caches") except Exception as e: - logger.error(f"[RL] failed to clear caches: {e}") + logger.error(f"check_cache_status: failed to clear caches: {e}") elif self.kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING: assert args.splitwise_role == "mixed", "Only mixed mode supports updating cache." 
try: - # restore cpu cache - logger.info("[RL] start restoring caches") - logger.debug("[RL] start restoring cpu caches") + logger.info("check_cache_status: start restoring caches") if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING: - self._init_cpu_cache(args) - logger.debug("[RL] successfully restored cpu caches") - while np.sum(self.swap_space_ready_signal.value) != args.mp_num: - time.sleep(0.1) - logger.debug("[RL] all ranks restored cpu caches") - else: - logger.debug("[RL] skip restoring cpu caches") - - # restore gpu cache and set cache_ready_signal - logger.debug("[RL] start restoring gpu caches") - self._init_gpu_cache(args) - logger.debug("[RL] successfully restored gpu caches") + self._init_cpu_cache() + self._init_gpu_cache() + # update key prefix for kv cache backend if self.storage_backend_type is not None: - # use key_prefix to distinguish cache for different version of weight in rl - version_file_path = os.path.join(args.model_path, "version.yaml") - assert os.path.exists(version_file_path), f"version.yaml not found at {version_file_path}" - self.key_prefix = get_key_prefix_from_version(version_file_path) - logger.info(f"Update key_prefix of cache storage to {self.key_prefix}") - - # wait for all ranks caches to be ready - while np.sum(self.cache_ready_signal.value) != args.mp_num: - time.sleep(0.1) - logger.info("[RL] all ranks restored caches!") + self._update_key_prefix() # resume transfer self.resume() - - # set kv_cache_status_signal self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL - - self._log_memory("after restoring caches") + self._log_memory("check_cache_status: after restoring caches") except Exception as e: - logger.error(f"[RL] failed to restore caches: {e}") + logger.error(f"check_cache_status: failed to restore caches: {e}") time.sleep(0.1) @@ -1407,11 +1486,11 @@ class CacheTransferManager: self.cache_task_queue.pause_barrier.wait() if self.rank == 0: self.cache_task_queue.pause_barrier.reset() - 
logger.info("[RL] 🟠 wait for inflight transfer tasks to finish") + logger.info("pause: 🟠 wait for inflight transfer tasks to finish") self.is_paused = True while np.sum(self.cache_task_inflight_signal.value) != 0: time.sleep(0.1) - logger.info("[RL] 🔴 pause transfer manager and stop do transfer tasks") + logger.info("pause: 🔴 pause transfer manager and stop do transfer tasks") def resume(self): if self.n_ranks > 1: @@ -1421,7 +1500,7 @@ class CacheTransferManager: self.is_paused = False while np.sum(self.cache_task_inflight_signal.value) != self.n_ranks: time.sleep(0.1) - logger.info("[RL] 🟢 resume transfer manager and start to do transfer tasks") + logger.info("resume: 🟢 resume transfer manager and start to do transfer tasks") def _log_memory(self, context: str): """Log current GPU memory usage.""" @@ -1430,8 +1509,8 @@ class CacheTransferManager: curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3) curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3) - logger.warning( - f"GPU memory usage {context}:" + logger.info( + f"{context}: " f"max_allocated: {max_alloc:.2f}GB " f"max_reserved: {max_reserved:.2f}GB " f"current_allocated: {curr_alloc:.2f}GB " diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index fd3f11e14e..dd64d7fb71 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -217,7 +217,7 @@ class PrefixCacheManager: is_server=False, num_client=tensor_parallel_size, client_id=0, - local_data_parallel_id=self.local_data_parallel_id, + local_data_parallel_id=0, ) current_dir_path = os.path.split(os.path.abspath(__file__))[0] @@ -293,7 +293,7 @@ class PrefixCacheManager: else: storage_arg_str = " " - if self.cache_config.swap_space or self.cache_config.kvcache_storage_backend: + if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend: for i in range(tensor_parallel_size): launch_cmd = ( 
"FLAGS_allocator_strategy=auto_growth " @@ -314,7 +314,6 @@ class PrefixCacheManager: + f" --pod_ip {pod_ip}" + f" --engine_worker_queue_port {engine_worker_queue_port}" + f" --num_cpu_blocks {cache_config.num_cpu_blocks}" - + f" --ipc_suffix {ipc_suffix}" + f" --protocol {cache_config.cache_transfer_protocol}" + f" --local_data_parallel_id {self.local_data_parallel_id}" + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}" @@ -353,9 +352,8 @@ class PrefixCacheManager: # Start additional threads if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0: - logger.info("Enable hierarchical cache.") threading.Thread(target=self.recv_data_transfer_result, daemon=True).start() - if cache_config.enable_prefix_caching: + if cache_config.enable_prefix_caching and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: threading.Thread(target=self.clear_prefix_cache, daemon=True).start() all_cache_processes = cache_messager_processes + cache_manager_processes diff --git a/fastdeploy/config.py b/fastdeploy/config.py index a26e694a0c..2d115ebbd4 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -687,6 +687,9 @@ class ParallelConfig: if self.shutdown_comm_group_if_worker_idle is None: self.shutdown_comm_group_if_worker_idle = not self.use_ep + if self.shutdown_comm_group_if_worker_idle and envs.FD_ENABLE_V1_UPDATE_WEIGHTS: + raise RuntimeError("shutdown_comm_group_if_worker_idle cannot be True when FD_ENABLE_V1_UPDATE_WEIGHTS=1") + # pd_disaggregation use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0)) use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0)) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index aa2cdb29aa..05bee6edbe 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -39,6 +39,8 @@ import zmq from tqdm import tqdm import fastdeploy.metrics.trace as 
tracing +from fastdeploy.cache_manager.cache_data import CacheStatus +from fastdeploy.config import FDConfig from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( ControlRequest, @@ -84,7 +86,7 @@ class EngineService: Base class containing common engine functionality """ - def __init__(self, cfg, start_queue=True, use_async_llm=False): + def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): """ Initializes the LLMEngine with the provided configuration. @@ -104,14 +106,23 @@ class EngineService: self.is_paused = False # pause request generation self._pause_cond = threading.Condition() - self._ctrl_worker_output_queues = [] + self._ctrl_output_queues = {} + self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict) tp_size = cfg.parallel_config.tensor_parallel_size dp_index = cfg.parallel_config.local_data_parallel_id - for rank in range(tp_size): + for tp_rank in range(tp_size): + # create worker control response queue engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port - name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}" - self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)") - self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer")) + name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}" + self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)") + self._ctrl_output_queues[name] = FMQ().queue(name, "consumer") + + # create cache control response queue + if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: + engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port + name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}" + self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)") + self._ctrl_output_queues[name] = FMQ().queue(name, "consumer") 
self.scheduler = cfg.scheduler_config.scheduler() self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" @@ -1296,7 +1307,7 @@ class EngineService: worker_pid = self.request_worker_map.pop(request_id, None) try: - self.llm_logger.info(f"START run control method {request_id}: {method}") + self.llm_logger.info(f"Start to run control method {method}: {request_id}") handler_name = f"_control_{method}" handler = getattr(self, handler_name, None) @@ -1308,13 +1319,13 @@ class EngineService: return result = handler(control_req) - self.llm_logger.info(f"SUCCESS run control method {method}.") + self.llm_logger.info(f"Successfully run control method {method}: {request_id} {result}") succ_result = ControlResponse(request_id, 200, "Success", result) data = [[succ_result]] if envs.ZMQ_SEND_BATCH_DATA else [succ_result] self.send_response_server.send_response(request_id, data, worker_pid=worker_pid) except Exception as e: - error_msg = f"Failed run control method {method}: {str(e)}" + error_msg = f"Failed to run control method {method}: {request_id} {str(e)}" self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}") error_result = ControlResponse(request_id, 500, error_msg) data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result] @@ -1338,12 +1349,15 @@ class EngineService: if self.cfg.scheduler_config.name != "local": raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}") + self.llm_logger.info("Start to pause request generation.") + with self._pause_cond: if self.is_paused: - self.llm_logger.info("Pause Request Generation: already paused.") + self.llm_logger.info("Engine is already paused, no need to pause again.") + return self.is_paused = True - self.llm_logger.info("Start Abort Running Requests") + self.llm_logger.info("Abort running requests.") self.resource_manager.log_status() # preempted all running reqs. 
preempted reqs will be append to ResourceManager.waiting queue @@ -1354,7 +1368,7 @@ class EngineService: if count >= timeout * 1000: break if count >= timeout * 1000: - error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged" + error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!" self.llm_logger.error(error_msg) raise Exception(error_msg) running_reqs = self.resource_manager.preempted_all() @@ -1369,12 +1383,22 @@ class EngineService: # abort inflight requests to user inflight_requests = self.scheduler.get_inflight_requests() - self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests") + self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).") for req in inflight_requests: - self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.") + self._send_error_response(req.request_id, "Request is aborted since engine is paused.") self.scheduler.reset() + # pause cache transfer + if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: + self.llm_logger.info("Start to pause cache transfer.") + pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause") + self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request)) + # Wait for cache_transfer responses + asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"])) + self.llm_logger.info("Successfully paused cache transfer.") + self.resource_manager.cache_manager.reset() + self.llm_logger.info("Successfully paused request generation.") return None def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: @@ -1386,14 +1410,24 @@ class EngineService: Args: control_request: Control request object containing resume operation information """ - self.llm_logger.info("START Resume Request Generation") + 
self.llm_logger.info("Start to resume request generation.") with self._pause_cond: if not self.is_paused: - self.llm_logger.info("Resume Request Generation: not paused.") + self.llm_logger.info("Engine is not paused, no need to resume.") return None self.is_paused = False self._pause_cond.notify_all() - self.llm_logger.info("END Resume Request Generation") + + # resume cache transfer + if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: + self.llm_logger.info("Start to resume cache transfer.") + resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume") + self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request)) + # Wait for cache_transfer responses + asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"])) + self.llm_logger.info("Successfully resumed cache transfer.") + + self.llm_logger.info("Successfully resumed request generation.") return None def _control_is_paused(self, control_request: ControlRequest) -> bool: @@ -1447,49 +1481,218 @@ class EngineService: return responses - async def _wait_all_control_responses(self, request_id: str, timeout: int): - """Wait for control responses from all workers with a global timeout. - - This method concurrently waits for responses from all control workers - and enforces an overall timeout to avoid leaking pending tasks. + def _parse_tags(self, control_request: ControlRequest): """ - timeout_ms = timeout * 1000 - # Create one get() coroutine per worker output queue - tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues] - - try: - results = await asyncio.wait_for( - asyncio.gather(*tasks, return_exceptions=True), - timeout=timeout, + Parse tags from control request. 
+ """ + allowed_tags = ["weight", "kv_cache"] + tags = control_request.args.get("tags", None) + if tags is None: + tags = ",".join(allowed_tags) + control_request.args["tags"] = tags + self.llm_logger.info( + f"Detected empty tags of request {control_request.request_id}, defaulting to tags: {tags}" ) - except asyncio.TimeoutError: - # Keep the error message consistent with previous behavior - raise Exception("Worker Update Weights Timeouted after 600s") + elif isinstance(tags, list): + tags = ",".join(tags) + for tag in tags.split(","): + if tag not in allowed_tags: + raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}") + + return tags + + def _control_sleep(self, control_request: ControlRequest): + """ + Offload gpu memory occupation for certain parts, e.g. weight, cache. + + Args: + control_request: Control request object containing parameters for offloading memory + tags: list of tags to offload, supported values: ["weight", "cache"] + + TODO: support different level of offloading, to provide options for release memory forever + or merely offloading to cpu memory for now. + """ + # Args check + tags = self._parse_tags(control_request) + control_request.args["tags"] = tags + + # Make sure llm engine is paused. + self.llm_logger.warning( + "Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. " + "Please explicitly request to /pause the engine before /sleep." 
+ ) + self._control_pause(None) + + # Determine which executors are needed for the sleep command + executors = set() + if "weight" in tags: + executors.add("worker") + if "kv_cache" in tags: + executors.add("worker") + if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: + executors.add("cache_transfer") + if self.cfg.cache_config.enable_prefix_caching: + self.resource_manager.cache_manager.reset() + + # Dispatch sleep request to executors + self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}") + self._dispatch_control_request(control_request, executors) + return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors)) + + def _control_wakeup(self, control_request: ControlRequest): + """ + Reload offloaded gpu memory occupation for certain parts, e.g. weight, cache. + + Args: + control_request: Control request object containing parameters for reloading memory + tags: list of tags to reload, supported values: ["weight", "kv_cache"] + """ + # Args check + tags = self._parse_tags(control_request) + control_request.args["tags"] = tags + + # Determine which executors are needed for the wakeup command + executors = set() + if "weight" in tags: + executors.add("worker") + if "kv_cache" in tags: + executors.add("worker") + if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: + executors.add("cache_transfer") + + # Dispatch wakeup request to executors + self.llm_logger.info(f"Dispatch wakeup request to executors: {list(executors)}") + self._dispatch_control_request(control_request, executors) + result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 300, executors=executors)) + + # Resume the engine after wakeup + self._control_resume(None) + + return result + + def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]): + """ + Dispatch control requests to workers, 
cache managers or engine itself. + + Args: + control_request: ControlRequest + executors: List + """ + if "worker" in executors: + self.engine_worker_queue.put_tasks(([control_request], 1)) + if "cache_transfer" in executors: + if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: + self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request)) + return + + async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None): + """Wait for matching control responses from the selected executor queues. + + This helper selects the control-response queues that belong to the requested + executors, then waits for all of them concurrently. Each queue gets a local + waiter that keeps reading until it sees the target request ID and stashes stale + responses into that queue's mailbox. + + Args: + request_id: The control request ID that all returned responses must match. + timeout: Global timeout budget in seconds for the full multi-queue wait. + executors: Executor groups to wait for, for example `["worker"]` or + `["worker", "cache_transfer"]`. If `None`, waits for all control + response queues. + + Returns: + A list of `response.result` values collected from all matched + `ControlResponse` objects. If no queue is selected, returns `None`. + + Raises: + Exception: If the overall wait times out, or if any queue reports a non-200 + control response or fails while waiting. 
+ """ + + def select_control_queues(executors: List[str] = None): + """Select control response queues by executors.""" + if executors is None: + return self._ctrl_output_queues + else: + queues = {} + for k, v in self._ctrl_output_queues.items(): + if "w2e" in k and "worker" in executors: + queues[k] = v + elif "c2e" in k and "cache_transfer" in executors: + queues[k] = v + return queues + + async def wait_one(queue_name: str, queue): + """Wait until one queue returns a response for the current request_id.""" + mailbox = self._ctrl_response_mailboxes[queue_name] + # Reuse a previously stashed response for this request before touching FMQ again. + cached_response = mailbox.pop(request_id, None) + if cached_response is not None: + self.llm_logger.info(f"Returning cached control response from {queue_name}.") + return cached_response + + while True: + msg = await queue.get() + + # Return if the response matches the control request + response: ControlResponse = msg.payload + if response.request_id == request_id: + self.llm_logger.info(f"Returning new control response from {queue_name}.") + return response + + # Stash late responses from other control requests so they do not consume the + # current request's only read chance on this queue. + mailbox[response.request_id] = response + self.llm_logger.info( + f"Stashed old control response from {queue_name}. " + f"Expected request {request_id}, got request {response.request_id}" + ) + + # Select only the control response queues that belong to the requested executors. + queues = select_control_queues(executors) + if not queues: + self.llm_logger.info(f"No queues to wait for, executors: {executors}") + return + self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}") + + # Each queue gets its own waiter, which will stash stale responses until it finds the + # target request ID for this control request. 
+ tasks = {name: asyncio.create_task(wait_one(name, queue)) for name, queue in queues.items()} + done, pending = await asyncio.wait(tasks.values(), timeout=timeout) + if pending: + pending_names = [name for name, task in tasks.items() if task in pending] + done_names = [name for name, task in tasks.items() if task in done] + self.llm_logger.error( + f"Control request {request_id} execution timeout. " + f"Pending queues: {pending_names}, completed queues: {done_names}." + ) + # Stop unfinished queue waiters so they do not outlive the control request. + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + raise Exception(f"Control request {request_id} timed out after {timeout}s") + + # Collect the results from all completed queues. responses = [] - for output_queue, msg in zip(self._ctrl_worker_output_queues, results): - if isinstance(msg, Exception): - self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}") - raise Exception(f"Call Worker error: {repr(msg)}") - if msg is None: - # Preserve original semantics when no message is received - raise Exception("Worker Update Weights Timeouted after 600s") - response: ControlResponse = msg.payload - if response.request_id != request_id: - self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}") - continue + for name, task in tasks.items(): + try: + response = task.result() + except Exception as e: + self.llm_logger.error(f"Waiting for control response from {name} failed: {repr(e)}") + raise + if response.error_code != 200: - self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}") - raise Exception(f"Call Worker error: {response.error_message}") - self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}") + raise Exception(f"Error response from {name}: {response.error_message}") responses.append(response.result) + return responses def _call_worker(self, 
control_request: ControlRequest, timeout: int): request_id = control_request.request_id self.engine_worker_queue.put_tasks(([control_request], 1)) # Use a single asyncio.run() to concurrently wait for all worker responses. - return asyncio.run(self._wait_all_control_responses(request_id, timeout)) + return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"])) def _send_error_response(self, request_id, error_msg, error_code: int = 500, worker_pid=None): self.llm_logger.error( diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index a58b1acda9..3f311e743f 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -595,11 +595,14 @@ class EngineClient: return True, "" async def run_control_method(self, request: ControlRequest): - api_server_logger.info(f"Start Run Control Method: {request}") + api_server_logger.info(f"Received control request: {request}") req_dict = request.to_dict() if envs.ZMQ_SEND_BATCH_DATA: req_dict["zmq_worker_pid"] = self.worker_pid - self.zmq_client.send_json(req_dict) + if not self.enable_mm and not envs.ENABLE_V1_DATA_PROCESSOR: + self.zmq_client.send_json(req_dict) + else: + self.zmq_client.send_pyobj(req_dict) request_id = request.request_id dealer, response_queue = await self.connection_manager.get_connection(request_id) if not envs.ZMQ_SEND_BATCH_DATA: @@ -608,12 +611,29 @@ class EngineClient: # todo: support user specified timeout. 
default 600s is enough for most control cases response = await asyncio.wait_for(response_queue.get(), timeout=600) response = ControlResponse.from_dict(response[0]) - api_server_logger.info(f"End Run Control Method: {response}") + api_server_logger.info(f"Return control response: {response}") return response except asyncio.TimeoutError: error_response = ControlResponse(request_id, 500, "Timeout waiting for control method response") - api_server_logger.error(f"Error Run Control Method: {error_response}") + api_server_logger.error(f"Control request timed out: {error_response}") return error_response + except Exception as e: + import traceback + + api_server_logger.error(f"Unknown error in control method: {str(e)}\n{traceback.format_exc()}") + error_response = ControlResponse(request_id, 500, str(e)) + return error_response + + def run_control_method_sync(self, request: ControlRequest, event_loop): + """ + Support running control methods by a synchronous caller. + + NOTE: Since asyncio.Queue operations must occur in the same event loop, + this method bridges synchronous and asynchronous execution by running + the async run_control_method in the specified event loop. 
+ """ + future = asyncio.run_coroutine_threadsafe(self.run_control_method(request), event_loop) + return future.result() def is_workers_alive(self): """ diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 3d4d798083..a48850e295 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -284,6 +284,7 @@ async def lifespan(app: FastAPI): app.state.completion_handler = completion_handler app.state.embedding_handler = embedding_handler app.state.reward_handler = reward_handler + app.state.event_loop = asyncio.get_running_loop() if llm_engine is not None and not isinstance(llm_engine, AsyncLLM): llm_engine.engine.data_processor = engine_client.data_processor @@ -406,6 +407,44 @@ async def is_paused(request: Request) -> Response: return control_response.to_api_json_response() +@app.post("/v1/sleep") +async def sleep(request: Request) -> Response: + request_id = f"control-{uuid.uuid4()}" + # Support both JSON body and query parameter + if await request.body(): + request_data = await request.json() + else: + # Extract query params + request_data = dict(request.query_params) + + try: + control_request = ControlRequest(request_id, "sleep", request_data) + except TypeError as e: + return JSONResponse(status_code=400, content={"error": "Invalid parameter type", "message": str(e)}) + + control_response = await app.state.engine_client.run_control_method(control_request) + return control_response.to_api_json_response() + + +@app.post("/v1/wakeup") +async def wakeup(request: Request) -> Response: + request_id = f"control-{uuid.uuid4()}" + # Support both JSON body and query parameter + if await request.body(): + request_data = await request.json() + else: + # Extract query params + request_data = dict(request.query_params) + + try: + control_request = ControlRequest(request_id, "wakeup", request_data) + except TypeError as e: + return JSONResponse(status_code=400, 
content={"error": "Invalid parameter type", "message": str(e)}) + + control_response = await app.state.engine_client.run_control_method(control_request) + return control_response.to_api_json_response() + + @app.post("/v1/update_weights") async def update_weights(request: Request) -> Response: request_id = f"control-{uuid.uuid4()}" @@ -606,8 +645,14 @@ def update_model_weight(request: Request) -> Response: update model weight """ if app.state.dynamic_load_weight: - status_code, msg = app.state.engine_client.update_model_weight() - return JSONResponse(content=msg, status_code=status_code) + if envs.FD_ENABLE_V1_UPDATE_WEIGHTS: + request_id = f"control-{uuid.uuid4()}" + control_request = ControlRequest(request_id, "wakeup") + control_response = app.state.engine_client.run_control_method_sync(control_request, app.state.event_loop) + return control_response.to_api_json_response() + else: + status_code, msg = app.state.engine_client.update_model_weight() + return JSONResponse(content=msg, status_code=status_code) else: return JSONResponse(content={"error": "Dynamic Load Weight Disabled."}, status_code=404) @@ -619,8 +664,14 @@ def clear_load_weight(request: Request) -> Response: clear model weight """ if app.state.dynamic_load_weight: - status_code, msg = app.state.engine_client.clear_load_weight() - return JSONResponse(content=msg, status_code=status_code) + if envs.FD_ENABLE_V1_UPDATE_WEIGHTS: + request_id = f"control-{uuid.uuid4()}" + control_request = ControlRequest(request_id, "sleep") + control_response = app.state.engine_client.run_control_method_sync(control_request, app.state.event_loop) + return control_response.to_api_json_response() + else: + status_code, msg = app.state.engine_client.clear_load_weight() + return JSONResponse(content=msg, status_code=status_code) else: return JSONResponse(content={"error": "Dynamic Load Weight Disabled."}, status_code=404) diff --git a/fastdeploy/entrypoints/openai/response_processors.py 
b/fastdeploy/entrypoints/openai/response_processors.py index 9e63440a09..41761963be 100644 --- a/fastdeploy/entrypoints/openai/response_processors.py +++ b/fastdeploy/entrypoints/openai/response_processors.py @@ -122,12 +122,21 @@ class ChatResponseProcessor: else: self._audio_buffer[req_id] = [token_ids] else: - yield self.data_processor.process_response_dict( - response_dict=request_output, - stream=stream, - include_stop_str_in_output=include_stop_str_in_output, - request=request, - ) + if self._is_async_processor: + response = await self.data_processor.process_response_dict( + response_dict=request_output, + stream=stream, + include_stop_str_in_output=include_stop_str_in_output, + request=request, + ) + else: + response = self.data_processor.process_response_dict( + response_dict=request_output, + stream=stream, + include_stop_str_in_output=include_stop_str_in_output, + request=request, + ) + yield response elif stream: decode_type = request_output["outputs"].get("decode_type", 0) token_ids = request_output["outputs"]["token_ids"] diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index 29408506b8..baa428b500 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -184,9 +184,14 @@ class DealerConnectionManager: self.request_num[request_id] -= 1 if self.request_num[request_id] == 0: self._update_load(conn_index, -1) + else: + api_server_logger.warning( + f"request_id {request_id} not in request_map, available keys: {list(self.request_map.keys())}" + ) except Exception as e: - api_server_logger.error(f"Listener error: {str(e)}") + api_server_logger.error(f"Listener error: {str(e)}\n{traceback.format_exc()}") break + api_server_logger.info(f"Listener loop ended for conn_index {conn_index}") def _update_load(self, conn_index, delta): """Update connection load and maintain the heap""" diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 68695acec3..72cd6dc7c4 100644 --- 
a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -249,6 +249,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_MOE_PROB_IN_ADVANCE": lambda: bool(int(os.getenv("FD_MOE_PROB_IN_ADVANCE", "0"))), # Whether to use batch send data in zmq "ZMQ_SEND_BATCH_DATA": lambda: int(os.getenv("ZMQ_SEND_BATCH_DATA", "1")), + # Whether to enable v1 weight updating, which utilizes ZMQ/EngineWorkerQueue/EngineCacheQueue/FMQs + # to pass control requests and responses. + # When v1 is enabled, the legacy /clear_load_weight and /update_model_weight + # will adopt this new communication pattern. + "FD_ENABLE_V1_UPDATE_WEIGHTS": lambda: bool(int(os.getenv("FD_ENABLE_V1_UPDATE_WEIGHTS", "0"))), } diff --git a/fastdeploy/inter_communicator/engine_cache_queue.py b/fastdeploy/inter_communicator/engine_cache_queue.py index fed0bec89f..535dc1dc4c 100644 --- a/fastdeploy/inter_communicator/engine_cache_queue.py +++ b/fastdeploy/inter_communicator/engine_cache_queue.py @@ -58,6 +58,9 @@ class EngineCacheQueue: client_id: Unique identifier for client instances local_data_parallel_size: data parallel size local_data_parallel_id: local data parallel id + + TODO(liyonghua): Remove multi-DP initialization. Each DP will have its own cache queue. 
+ """ self.address: Tuple[str, int] = address self.authkey: bytes = authkey @@ -87,6 +90,7 @@ class EngineCacheQueue: ] # Initialize barriers + self.barrier = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)] self.barrier0_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)] self.barrier1_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)] self.barrier2_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)] @@ -142,6 +146,7 @@ class EngineCacheQueue: callable=lambda idx: self.transfer_task_done_lock_init[idx], proxytype=AcquirerProxy, ) + QueueManager.register("get_barrier", callable=lambda idx: self.barrier[idx]) QueueManager.register("get_barrier0", callable=lambda idx: self.barrier0_init[idx]) QueueManager.register("get_barrier1", callable=lambda idx: self.barrier1_init[idx]) QueueManager.register("get_barrier2", callable=lambda idx: self.barrier2_init[idx]) @@ -191,6 +196,7 @@ class EngineCacheQueue: QueueManager.register("get_cache_sync_value") QueueManager.register("get_transfer_task_lock") QueueManager.register("get_transfer_task_done_lock") + QueueManager.register("get_barrier") QueueManager.register("get_barrier0") QueueManager.register("get_barrier1") QueueManager.register("get_barrier2") @@ -215,6 +221,7 @@ class EngineCacheQueue: self.task_done_lock = self.manager.get_transfer_task_done_lock(self.local_data_parallel_id) # Get barrier proxies + self.barrier = self.manager.get_barrier(self.local_data_parallel_id) self.barrier0 = self.manager.get_barrier0(self.local_data_parallel_id) self.barrier1 = self.manager.get_barrier1(self.local_data_parallel_id) self.barrier2 = self.manager.get_barrier2(self.local_data_parallel_id) @@ -264,7 +271,12 @@ class EngineCacheQueue: def put_transfer_task(self, item): """ - put swap task + Enqueue a cache transfer task (cpu/gpu swap task, read/write storage task) + or a control task 
(cache clearing/restoring). + + The queue is shared by multiple clients. A task can be enqueued only after + the previous task has been read by all clients. + `task_sync_value` is used as a bitmask to track per-client read status. """ self.task_lock.acquire() if 0 < self.task_sync_value.get() < self.total_num: @@ -279,7 +291,11 @@ class EngineCacheQueue: def get_transfer_task(self): """ - get swap task + Get the current cache transfer task (cpu/gpu swap task, read/write storage task) + or control signal (cache clearing/restoring) from cache task queue. + + Each client reads the same task once. The task is removed from the queue + only after all clients have read it, tracked by `task_sync_value`. """ data = None read_finish = False diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index 3b8275ecdf..7073edb48a 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -283,6 +283,9 @@ class ZmqServerBase(ABC): has_result_handle = False with self.mutex: if req_id not in self.req_dict: + llm_logger.warning( + f"req_id '{req_id}' not in req_dict, caching response. 
Available req_ids: {list(self.req_dict.keys())}" + ) self.cached_results[req_id].append(data) else: has_result_handle = True diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py index fdb4d7b84f..ea26db28b1 100644 --- a/fastdeploy/rl/dynamic_weight_manager.py +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -201,6 +201,39 @@ class DynamicWeightManager: # step5: recapture cuda_graph # step6: update weight status signal + def restart_communication_group(self): + if not self.first_load: + start_time = time.perf_counter() + paddle.distributed.restart_process_group() + paddle.distributed.restart_process_group(self.parallel_config.tp_group) + if self.parallel_config.enable_expert_parallel: + paddle.distributed.restart_process_group(self.parallel_config.ep_group) + logger.info(f"finish restarting communication groups! time cost: {time.perf_counter()-start_time:.3f}s") + + def recreate_deepep_buffer(self): + if not self.first_load: + start_time = time.perf_counter() + from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager + + DeepEPBufferManager.recreate_buffer() + # ep barrier + paddle.distributed.barrier(self.parallel_config.ep_group) + logger.info(f"finish recreating deepep buffer! time cost: {time.perf_counter()-start_time:.3f}s") + + def reload_model_weights(self): + if not self.first_load: + start_time = time.perf_counter() + strategy_handlers = { + "ipc_snapshot": self._update_ipc_snapshot, + "ipc": self._update_ipc, + } + + if handler := strategy_handlers.get(self.load_config.load_strategy): + handler() + else: + raise ValueError(f"Unsupported strategy: {self.load_config.load_strategy}") + logger.info(f"finish reload model weights! time cost: {time.perf_counter()-start_time:.3f}s") + def _update_ipc_snapshot(self): """Update using IPC snapshot strategy for elastic recovery. 
@@ -329,6 +362,30 @@ class DynamicWeightManager: paddle.distributed.shutdown_process_group() self._update_shared_status(pid, ModelWeightsStatus.CLEARED) + def clear_deepep_buffer(self): + start_time = time.perf_counter() + from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager + + DeepEPBufferManager.clear_buffer() + logger.info(f"finish clearing deepep buffer! time cost: {time.perf_counter()-start_time:.3f}s") + + def clear_model_weight(self): + start_time = time.perf_counter() + for model in self.model_list: + for param in model.state_dict().values(): + param._clear_data() + logger.info(f"finish clearing model weight! time cost: {time.perf_counter()-start_time:.3f}s") + + def clear_communication_group(self): + start_time = time.perf_counter() + if self.parallel_config.enable_expert_parallel: + paddle.distributed.barrier(self.parallel_config.ep_group) + paddle.distributed.shutdown_process_group(self.parallel_config.ep_group) + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.barrier(self.parallel_config.tp_group) + paddle.distributed.shutdown_process_group(self.parallel_config.tp_group) + logger.info(f"finish clearing communication groups! 
time cost: {time.perf_counter()-start_time:.3f}s") + def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str): """Update model parameters from given state dictionary.""" if len(state_dict) == 0: diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 0ab54a1508..0a591dc277 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -702,12 +702,15 @@ def singleton(cls): return get_instance -def print_gpu_memory_use(gpu_id: int, title: str) -> None: +def print_gpu_memory_use(title: str, gpu_id: int, device_id: int | None = None) -> None: """Print memory usage""" import pynvml + if device_id is None: + device_id = gpu_id + pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) pynvml.nvmlShutdown() @@ -724,7 +727,7 @@ def print_gpu_memory_use(gpu_id: int, title: str) -> None: f"\n\tPaddle max memory Reserved(GiB): {paddle_max_reserved / 1024.0 / 1024.0 / 1024.0}", f"\n\tPaddle max memory Allocated(GiB): {paddle_max_allocated / 1024.0 / 1024.0 / 1024.0}", f"\n\tPaddle memory Reserved(GiB): {paddle_reserved / 1024.0 / 1024.0 / 1024.0}", - f"\n\tPaddle memory Allocated(GiB): {paddle_allocated / 1024.0 / 1024.0 / 1024.0}", + f"\n\tPaddle memory Allocated(GiB): {paddle_allocated / 1024.0 / 1024.0 / 1024.0}\n", ) @@ -1326,3 +1329,12 @@ else: register_op = do_nothing register_custom_python_op = register_op + + +def all_gather_values(value: int | float | bool, group: paddle.distributed.communication.group.Group) -> list: + _type = type(value) + _local = paddle.to_tensor([value], dtype="float32") + _global = [paddle.zeros_like(_local) for _ in range(group.world_size)] + paddle.distributed.all_gather(_global, _local, group) + _results = [_type(t.item()) for t in _global] + return _results diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1c91847d84..bed61bd5b1 100644 --- 
a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -54,6 +54,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.platforms import current_platform from fastdeploy.spec_decode import SpecMethod +from fastdeploy.utils import print_gpu_memory_use from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode if current_platform.is_iluvatar(): @@ -142,6 +143,9 @@ class GPUModelRunner(ModelRunnerBase): self.cache_kvs_map: dict = {} self.exist_prefill_flag = False + self.is_kvcache_sleeping = False + self.is_weight_sleeping = False + if self.speculative_decoding: self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory() self.output_token_num_event = paddle.device.cuda.Event() @@ -288,6 +292,10 @@ class GPUModelRunner(ModelRunnerBase): """ return self.exist_prefill_flag + @property + def is_sleeping(self): + return self.is_weight_sleeping or self.is_kvcache_sleeping + def exist_decode(self): """ check whether decode stage exist @@ -2673,6 +2681,83 @@ class GPUModelRunner(ModelRunnerBase): def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None): return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config) + def sleep(self, tags): + + logger.info(f">>> start offloading memory, tags: {tags}") + start_time = time.perf_counter() + + # Clear weights, deepep_buffer, cudagraph, etc. 
+ if "weight" in tags.split(","): + if self.is_weight_sleeping: + logger.info("GPU model runner's weight is already sleeping, no need to sleep again!") + return + if self.use_cudagraph: + self.model.clear_grpah_opt_backend() + if self.fd_config.parallel_config.enable_expert_parallel: + self.dynamic_weight_manager.clear_deepep_buffer() + self.dynamic_weight_manager.clear_model_weight() + if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle: + self.dynamic_weight_manager.clear_communication_group() + self.is_weight_sleeping = True + + # Clear KV cache + if "kv_cache" in tags.split(","): + if self.is_kvcache_sleeping: + logger.info("GPU model runner's kv cache is already sleeping, no need to sleep again!") + return + if self.spec_method == SpecMethod.MTP: + self.proposer.clear_mtp_cache() + self.clear_cache() + self.is_kvcache_sleeping = True + + paddle.device.cuda.empty_cache() + logger.info(f"<<< finish offloading memory! time cost: {time.perf_counter()-start_time:.3f}s") + print_gpu_memory_use(f"After offloading memory [{tags}]", self.local_rank, self.device_id) + + def wakeup(self, tags): + + if tags == "weight" and self.use_cudagraph and self.is_kvcache_sleeping: + raise RuntimeError( + "Waking up [weight] alone is not supported when CUDA Graph is enabled, " + "as recapturing the graph requires the KV cache to be rebuilt first. " + "Please wake up [kv_cache] first." 
+ ) + + logger.info(f">>> start reloading memory, tags: {tags}") + start_time = time.perf_counter() + + # Reset share_inputs to restore tensor shapes and values + if self.spec_method == SpecMethod.MTP: + self.proposer.model_inputs.reset_model_inputs() + self.share_inputs.reset_share_inputs() + + # Reinitialize KV cache + if "kv_cache" in tags.split(","): + if not self.is_kvcache_sleeping: + logger.info("GPU model runner's kv cache is not sleeping, no need to wakeup!") + return + if self.spec_method == SpecMethod.MTP: + self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) + self.initialize_kv_cache() + self.is_kvcache_sleeping = False + + # Reload weights, deepep_buffer, cudagraph, etc. + if "weight" in tags.split(","): + if not self.is_weight_sleeping: + logger.info("GPU model runner's weight is not sleeping, no need to wakeup!") + return + if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle: + self.dynamic_weight_manager.restart_communication_group() + if self.fd_config.parallel_config.enable_expert_parallel: + self.dynamic_weight_manager.recreate_deepep_buffer() + self.dynamic_weight_manager.reload_model_weights() + if self.use_cudagraph: + self.capture_model() + self.is_weight_sleeping = False + + logger.info(f"<<< finish reloading memory! time cost: {time.perf_counter()-start_time:.3f}s") + print_gpu_memory_use(f"After reloading memory [{tags}]", self.local_rank, self.device_id) + def padding_cudagraph_inputs(self) -> None: """ Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. 
diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index be065a9991..5025dc95c7 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -196,6 +196,14 @@ class GpuWorker(WorkerBase): """update weights in place""" return self.model_runner.update_weights(version, rsync_config) + def sleep(self, **kwargs) -> None: + """Offload memory from GPU""" + return self.model_runner.sleep(**kwargs) + + def wakeup(self, **kwargs) -> None: + """Reload memory into GPU""" + return self.model_runner.wakeup(**kwargs) + def execute_model( self, model_forward_batch: Optional[List[Request]] = None, diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 29deb2db4d..3d9e647239 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -69,7 +69,7 @@ from fastdeploy.model_executor.layers.quantization import parse_quant_config from fastdeploy.model_executor.utils import v1_loader_support from fastdeploy.platforms import current_platform from fastdeploy.scheduler import SchedulerConfig -from fastdeploy.utils import get_logger, optional_type +from fastdeploy.utils import all_gather_values, get_logger, optional_type from fastdeploy.worker.worker_base import WorkerBase logger = get_logger("worker_process", "worker_process.log") @@ -172,6 +172,7 @@ class PaddleDisWorkerProc: self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 self.enable_overlap_schedule = self.scheduler_config.enable_overlap_schedule + self.cached_control_reqs = [] def init_control(self): engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port @@ -482,7 +483,7 @@ class PaddleDisWorkerProc: # run eplb self._run_eplb(tp_rank) - if self.fd_config.load_config.dynamic_load_weight: + if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: self.model_weights_signal[0] = int(self.model_weights_status.value[0]) if self.ranks > 1: 
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None) @@ -504,7 +505,7 @@ class PaddleDisWorkerProc: # Synchronize the signal set by tp_rank0 visiable to other workers self._tp_barrier_wait() if tp_size > 1 else None - if self.fd_config.load_config.dynamic_load_weight: + if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: if self.ranks > 1: paddle.distributed.barrier() if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL: @@ -581,22 +582,34 @@ class PaddleDisWorkerProc: if len(control_reqs) > 0: logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.") for control_req in control_reqs: - self.run_control_method(control_req) - self._tp_barrier_wait() if tp_size > 1 else None + if self.parallel_config.use_ep: + self.cached_control_reqs.append(control_req) + logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}") + else: + self.run_control_method(control_req) + self._tp_barrier_wait() if tp_size > 1 else None - # Count prefill requests in current batch - num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL) - num_scheduled_requests = len(req_dicts) - scheduled_request_ids = [req.request_id for req in req_dicts] - logger.info( - f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, " - f"max_occupied_batch_index: {max_occupied_batch_index}, " - f"num_scheduled_requests: {num_scheduled_requests}, " - f"scheduled_request_ids: {scheduled_request_ids}" - ) + if len(req_dicts) > 0: + # Count prefill requests in current batch + num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL) + num_scheduled_requests = len(req_dicts) + scheduled_request_ids = [req.request_id for req in req_dicts] + logger.info( + f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, " + f"max_occupied_batch_index: {max_occupied_batch_index}, " + 
f"num_scheduled_requests: {num_scheduled_requests}, "
+                        f"scheduled_request_ids: {scheduled_request_ids}"
+                    )
 
-                # Process prefill inputs
-                self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
+                    # Process prefill inputs
+                    self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
+
+            # Let the ep group run control method synchronously
+            if self.parallel_config.use_ep:
+                pendings = all_gather_values(len(self.cached_control_reqs), self.parallel_config.ep_group)
+                if all([p > 0 for p in pendings]):
+                    logger.info(f"Rank: {self.local_rank} Detected all ep ranks have pending control tasks.")
+                    self.run_control_method(self.cached_control_reqs.pop(0))
 
             if (
                 not self.parallel_config.use_ep
@@ -607,6 +620,15 @@ class PaddleDisWorkerProc:
                     time.sleep(0.001)
                     continue
 
+            # Check if worker is paused (V1 update weights flow)
+            if (
+                self.fd_config.load_config.dynamic_load_weight
+                and hasattr(self.worker.model_runner, "is_sleeping")
+                and self.worker.model_runner.is_sleeping
+            ):
+                self._tp_barrier_wait() if tp_size > 1 else None
+                continue
+
             # Execute model to generate token. The generated token will be written to the buffer.
             # These generated tokens can be obtained through get_output op.
start_execute_time = time.time() @@ -737,14 +759,14 @@ class PaddleDisWorkerProc: self.loaded_model_signal.value[0] = 1 def run_control_method(self, control_request: ControlRequest) -> None: - logger.info(f"Start run control request: {control_request}") + logger.info(f"Rank: {self.local_rank} Start to run control request: {control_request}") request_id = control_request.request_id method = control_request.method kwargs = control_request.args handler = getattr(self.worker, method, None) if handler is None or not callable(handler): - error_msg = f"Rank-{self.local_rank}: Unknown control method {method}" + error_msg = f"Rank: {self.local_rank} Unknown control method {method}" error_result = ControlResponse(request_id, 400, error_msg) asyncio.run(self._ctrl_output.put(error_result)) return @@ -753,11 +775,11 @@ class PaddleDisWorkerProc: result = handler(**kwargs) succ_result = ControlResponse(request_id, 200, "Success", result) logger.info( - f"Rank-{self.local_rank} Success run control request: {control_request}, response: {succ_result}" + f"Rank: {self.local_rank} Successfully run control request: {control_request}, response: {succ_result}" ) asyncio.run(self._ctrl_output.put(succ_result, shm_threshold=100 * 1024 * 1024)) except Exception as e: - error_msg = f"Rank-{self.local_rank} Failed run control method {method}: {str(e)}" + error_msg = f"Rank: {self.local_rank} Failed to run control method {method}: {str(e)}" logger.error(f"{error_msg}\n{traceback.format_exc()}") error_result = ControlResponse(request_id, 500, error_msg) asyncio.run(self._ctrl_output.put(error_result)) diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 45ae8a1804..191d1ad36e 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -106,7 +106,11 @@ class TestCacheTransferManager(unittest.TestCase): # -------------------------- # mock IPCSignal # 
-------------------------- - patcher2 = patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=MagicMock()) + class DummyIPCSignal: + def __init__(self, name, array, dtype, suffix, create=False): + self.value = array + + patcher2 = patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=DummyIPCSignal) patcher2.start() self.addCleanup(patcher2.stop) @@ -122,8 +126,8 @@ class TestCacheTransferManager(unittest.TestCase): # -------------------------- self._orig_init_cpu_cache = CacheTransferManager._init_cpu_cache self._orig_init_gpu_cache = CacheTransferManager._init_gpu_cache - patcher3 = patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None) - patcher4 = patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None) + patcher3 = patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None) + patcher4 = patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None) patcher3.start() patcher4.start() self.addCleanup(patcher3.stop) @@ -193,8 +197,8 @@ class TestCacheTransferManager(unittest.TestCase): kvcache_storage_backend = "unknown" with ( - patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None), - patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None), + patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None), + patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None), patch("fastdeploy.cache_manager.cache_transfer_manager.console_logger") as mock_console, ): with self.assertRaises(NotImplementedError): @@ -209,8 +213,8 @@ class TestCacheTransferManager(unittest.TestCase): kvcache_storage_backend = "file" with ( - patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None), - patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None), + patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None), + patch.object(CacheTransferManager, 
"_init_gpu_cache", lambda self: None), ): with self.assertRaises(ValueError): CacheTransferManager(LocalArgs()) @@ -221,8 +225,8 @@ class TestCacheTransferManager(unittest.TestCase): version_path = os.path.join(tmpdir, "version.yaml") with open(version_path, "w", encoding="utf-8") as handle: handle.write("version: RL-STEP03-20250101-uuid\n") - args.model_path = tmpdir - args.kvcache_storage_backend = None + self.manager.model_path = tmpdir + self.manager.kvcache_storage_backend = None self.manager._init_storage(args) self.assertEqual(self.manager.key_prefix, "RL-STEP03") @@ -465,23 +469,21 @@ class TestCacheTransferManager(unittest.TestCase): def __init__(self): self.value = [0] - args = Args() - args.num_cpu_blocks = 0 + self.manager.num_cpu_blocks = 0 self.manager.swap_space_ready_signal = DummySignal() - self._orig_init_cpu_cache(self.manager, args) + self._orig_init_cpu_cache(self.manager) - self.assertEqual(self.manager.swap_space_ready_signal.value[0], 1) + self.assertEqual(self.manager.swap_space_ready_signal.value[0], 0) def test_init_cpu_cache_allocates_block_wise_fp8(self): class DummySignal: def __init__(self): self.value = [0] - args = Args() - args.num_cpu_blocks = 2 - args.cache_dtype = "block_wise_fp8" - args.value_cache_shape = "2,1,1,1" + self.manager.num_cpu_blocks = 2 + self.manager.cache_dtype = "block_wise_fp8" + self.manager.has_cache_scale = True self.manager.swap_space_ready_signal = DummySignal() self.manager.value_cache_shape = [2, 1, 1, 1] @@ -492,7 +494,7 @@ class TestCacheTransferManager(unittest.TestCase): ), patch("fastdeploy.cache_manager.cache_transfer_manager.paddle.set_device"), ): - self._orig_init_cpu_cache(self.manager, args) + self._orig_init_cpu_cache(self.manager) self.assertEqual(self.manager.swap_space_ready_signal.value[0], 1) @@ -510,8 +512,8 @@ class TestCacheTransferManager(unittest.TestCase): value_cache_shape = "2,1,1,1" with ( - patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None), - 
patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None), + patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None), + patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None), patch("fastdeploy.cache_manager.cache_transfer_manager.MooncakeStore"), patch.object(CacheTransferManager, "_init_storage_buffer"), ): @@ -527,9 +529,8 @@ class TestCacheTransferManager(unittest.TestCase): with ( patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"), patch("fastdeploy.cache_manager.cache_transfer_manager.set_data_ipc") as mock_set_ipc, - patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0), ): - self._orig_init_gpu_cache(manager, LocalArgs()) + self._orig_init_gpu_cache(manager) self.assertEqual(mock_set_ipc.call_count, 4) self.assertIn("key_caches_0_rank0.device0", manager.gpu_cache_kvs) @@ -548,8 +549,8 @@ class TestCacheTransferManager(unittest.TestCase): value_cache_shape = "2,1,1,1" with ( - patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None), - patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None), + patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None), + patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None), patch("fastdeploy.cache_manager.cache_transfer_manager.MooncakeStore"), patch.object(CacheTransferManager, "_init_storage_buffer"), ): @@ -568,9 +569,8 @@ class TestCacheTransferManager(unittest.TestCase): with ( patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"), patch("fastdeploy.cache_manager.cache_transfer_manager.share_external_data_", side_effect=fake_share), - patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0), ): - self._orig_init_gpu_cache(manager, LocalArgs()) + self._orig_init_gpu_cache(manager) self.assertIn("key_cache_scales_0_rank0.device0", manager.gpu_cache_kvs) @@ -583,8 +583,8 @@ class 
TestCacheTransferManager(unittest.TestCase): value_cache_shape = "1,1,1,1" with ( - patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None), - patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None), + patch.object(CacheTransferManager, "_init_cpu_cache", lambda self: None), + patch.object(CacheTransferManager, "_init_gpu_cache", lambda self: None), ): manager = CacheTransferManager(LocalArgs()) @@ -607,9 +607,8 @@ class TestCacheTransferManager(unittest.TestCase): patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep), patch("fastdeploy.cache_manager.cache_transfer_manager.set_device"), patch("fastdeploy.cache_manager.cache_transfer_manager.share_external_data_", side_effect=fake_share), - patch("fastdeploy.cache_manager.cache_transfer_manager.memory_allocated", return_value=0), ): - self._orig_init_gpu_cache(manager, LocalArgs()) + self._orig_init_gpu_cache(manager) self.assertIn("key_caches_0_rank0.device0", manager.gpu_cache_kvs) @@ -1160,6 +1159,14 @@ class TestCacheTransferManager(unittest.TestCase): self.manager.inflight = 1 self.manager.cache_task_is_paused_signal = DummySignal(0) self.manager.cache_task_inflight_signal = DummySignal(1) + self.manager.cache_task_broadcast_signal = DummySignal(0) + self.manager.cache_task_queue.empty.return_value = False + self.manager.cache_task_queue.get_transfer_task.return_value = ( + (cache_transfer_manager.CacheStatus.CTRL, MagicMock()), + True, + ) + self.manager.control_task_thread_pool = MagicMock() + self.manager.control_task_thread_pool.submit.side_effect = SystemExit call_count = {"count": 0} @@ -1168,7 +1175,6 @@ class TestCacheTransferManager(unittest.TestCase): if call_count["count"] == 1: self.manager.inflight = 0 return None - raise SystemExit with patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep): with self.assertRaises(SystemExit): @@ -1273,6 +1279,7 @@ class 
TestCacheTransferManager(unittest.TestCase): args = Args() args.splitwise_role = "mixed" args.create_cache_tensor = True + self.manager.create_cache_tensor = True self.manager.kv_cache_status_signal = DummySignal(cache_transfer_manager.KVCacheStatus.CLEARING) self.manager.cache_ready_signal = DummySignal(0) self.manager.swap_space_ready_signal = DummySignal(0) @@ -1293,7 +1300,10 @@ class TestCacheTransferManager(unittest.TestCase): patch("paddle.device.cuda.empty_cache") as mock_empty, patch("paddle.set_device"), patch.object(self.manager, "_log_memory"), - patch("time.sleep", side_effect=maybe_stop_cleared_with_tensor), + patch( + "fastdeploy.cache_manager.cache_transfer_manager.time.sleep", + side_effect=maybe_stop_cleared_with_tensor, + ), ): with self.assertRaises(StopIteration): self.manager.check_cache_status(args) @@ -1427,6 +1437,7 @@ class TestCacheTransferManager(unittest.TestCase): args = Args() args.splitwise_role = "mixed" args.mp_num = 2 + self.manager.n_ranks = 2 self.manager.kv_cache_status_signal = DummySignal([cache_transfer_manager.KVCacheStatus.UPDATING]) self.manager.cache_ready_signal = DummySignal([0, 1]) self.manager.swap_space_ready_signal = DummySignal([0, 1]) @@ -1455,6 +1466,7 @@ class TestCacheTransferManager(unittest.TestCase): patch.object(self.manager, "_init_cpu_cache"), patch.object(self.manager, "_init_gpu_cache"), patch.object(self.manager, "resume"), + patch.object(self.manager, "_update_key_prefix"), patch("fastdeploy.cache_manager.cache_transfer_manager.envs.FD_ENABLE_SWAP_SPACE_CLEARING", True), patch.object(self.manager, "_log_memory"), patch("fastdeploy.cache_manager.cache_transfer_manager.time.sleep", side_effect=fake_sleep), @@ -1465,7 +1477,7 @@ class TestCacheTransferManager(unittest.TestCase): self.assertEqual(self.manager.kv_cache_status_signal.value[0], cache_transfer_manager.KVCacheStatus.NORMAL) def test_log_memory_records_gpu_stats(self): - with patch.object(cache_transfer_manager.logger, "warning") as 
mock_warning: + with patch.object(cache_transfer_manager.logger, "info") as mock_info: with ( patch("paddle.device.cuda.max_memory_allocated", return_value=1024**3), patch("paddle.device.cuda.max_memory_reserved", return_value=2 * 1024**3), @@ -1474,7 +1486,7 @@ class TestCacheTransferManager(unittest.TestCase): ): self.manager._log_memory("test") - mock_warning.assert_called_once() + mock_info.assert_called_once() def test_pause_and_resume_wait_for_signals(self): class DummySignal: @@ -1503,7 +1515,7 @@ class TestCacheTransferManager(unittest.TestCase): self.assertFalse(self.manager.is_paused) - def test_submit_task_decrements_inflight(self): + def test_submit_task_decrements_inflight_on_task_error(self): class DummyPool: def submit(self, fn, *args): try: @@ -1514,10 +1526,26 @@ class TestCacheTransferManager(unittest.TestCase): def raise_task(): raise RuntimeError("boom") - self.manager.inflight = 1 + self.manager.inflight = 0 self.manager.submit_task(DummyPool(), raise_task) self.assertEqual(self.manager.inflight, 0) + def test_submit_task_decrements_inflight_on_success(self): + class DummyPool: + def submit(self, fn, *args): + return fn(*args) + + task_called = {"value": False} + + def ok_task(): + task_called["value"] = True + + self.manager.inflight = 0 + self.manager.submit_task(DummyPool(), ok_task) + + self.assertTrue(task_called["value"]) + self.assertEqual(self.manager.inflight, 0) + def test_main_invokes_manager(self): cache_transfer_manager.args = Args() with patch("fastdeploy.cache_manager.cache_transfer_manager.CacheTransferManager") as mock_manager: diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index 158a61b7d6..0a2a42d857 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -427,7 +427,7 @@ class PrefixCacheManagerTest(unittest.TestCase): self.assertEqual(manager.get_required_block_num(8, 4), 2) def 
test_launch_cache_manager_initializes_processes(self): - manager = _create_manager() + manager = _create_manager(num_cpu_blocks=1) manager.cache_config.enable_hierarchical_cache = False with ( @@ -637,7 +637,7 @@ class PrefixCacheManagerTest(unittest.TestCase): self.assertIsNone(processes) def test_launch_cache_manager_formats_value_cache_shape(self): - manager = _create_manager() + manager = _create_manager(num_cpu_blocks=1) captured = {} diff --git a/tests/ce/stable_cases/run.sh b/tests/ce/stable_cases/run.sh index d388b9d5ae..1877cfac91 100644 --- a/tests/ce/stable_cases/run.sh +++ b/tests/ce/stable_cases/run.sh @@ -11,7 +11,9 @@ HOST="0.0.0.0" PORT="${FD_API_PORT}" # 这里需要配合启动脚本那个URL PORT BASE_URL="http://$HOST:$PORT" -TOTAL_ROUNDS=6 +V0_ROUNDS=3 +V1_ROUNDS=3 +TOTAL_ROUNDS=$((V0_ROUNDS + V1_ROUNDS)) CHAT_REQUESTS_PER_ROUND=3 export CUDA_VISIBLE_DEVICES=0,1 MAX_MEMORY_MB=10240 # 10GB @@ -48,7 +50,7 @@ assert_success() { fi } -# curl_get_status(url, options...) → returns via global variables http_code and response_body +# curl_get_status(url, options) → returns via global variables http_code and response_body curl_get_status() { local result result=$(curl -s -w "%{http_code}" "$@") @@ -56,6 +58,23 @@ curl_get_status() { response_body="${result%???}" } +post_json_and_assert() { + local url="$1" + local payload="${2:-}" + if [ -n "$payload" ]; then + curl_get_status -X POST "$url" -H "Content-Type: application/json" -d "$payload" + else + curl_get_status -X POST "$url" + fi + assert_eq "$http_code" "200" "$url failed with HTTP $http_code, body: $response_body" +} + +get_and_assert() { + local url="$1" + curl_get_status "$url" + assert_eq "$http_code" "200" "$url failed with HTTP $http_code, body: $response_body" +} + # ==================================================== # Get visible GPU IDs from CUDA_VISIBLE_DEVICES # ==================================================== @@ -79,104 +98,69 @@ check_gpu_memory() { local gpu_ids gpu_ids=($(get_visible_gpu_ids)) - echo 
"========== GPU Memory Check ==========" - echo "CUDA_VISIBLE_DEVICES = $CUDA_VISIBLE_DEVICES" - echo "MAX_MEMORY_MB = $MAX_MEMORY_MB" - echo "======================================" + echo "----------------------------------------" + echo " GPU Memory Check (MAX: ${MAX_MEMORY_MB}MB)" + echo "----------------------------------------" if [ ${#gpu_ids[@]} -eq 0 ]; then - echo "Assertion failed: No valid GPU IDs in CUDA_VISIBLE_DEVICES='$CUDA_VISIBLE_DEVICES'" >&2 + echo "ERROR: No valid GPU IDs in CUDA_VISIBLE_DEVICES='$CUDA_VISIBLE_DEVICES'" >&2 exit 1 fi for gpu_id in "${gpu_ids[@]}"; do - echo - echo "---- GPU $gpu_id ----" - # Query summary local summary summary=$(nvidia-smi -i "$gpu_id" \ --query-gpu=index,name,memory.total,memory.used,memory.free,utilization.gpu \ --format=csv,noheader,nounits) || { - echo "Failed to query GPU $gpu_id summary" >&2 + echo "ERROR: Failed to query GPU $gpu_id" >&2 exit 1 } # Parse fields IFS=',' read -r idx name mem_total mem_used mem_free util <<< "$summary" + local used_ratio=$(( mem_used * 100 / mem_total )) - echo "GPU $idx: $name" - echo "Total Memory : ${mem_total} MB" - echo "Used Memory : ${mem_used} MB" - echo "Free Memory : ${mem_free} MB" - echo "GPU Util : ${util} %" + # Print GPU info (single line) + printf " GPU %s: %-35s Total:%5sMB Used:%5sMB (%s%%)\n" \ + "$idx" "$name" "$mem_total" "$mem_used" "$util" - # --- Hard assertions --- + # Hard assertion assert_true "$(( mem_used <= MAX_MEMORY_MB ))" \ "GPU $gpu_id memory.used ${mem_used} MB > MAX_MEMORY_MB ${MAX_MEMORY_MB} MB" - # --- Soft safety check: usage ratio --- - local used_ratio - used_ratio=$(( mem_used * 100 / mem_total )) - - echo "Used Ratio : ${used_ratio} %" - + # Usage ratio check if [ "$used_ratio" -gt 90 ]; then - echo "Assertion failed: GPU $gpu_id memory usage > 90% (${used_ratio}%)" >&2 + echo "ERROR: GPU $gpu_id memory usage > 90% (${used_ratio}%)" >&2 exit 1 fi - # --- Process-level attribution --- - echo "Processes on GPU $gpu_id:" + # Process 
info (compact format) local proc_info proc_info=$(nvidia-smi -i "$gpu_id" \ --query-compute-apps=pid,process_name,used_memory \ --format=csv,noheader,nounits) - if [ -z "$proc_info" ]; then - echo " (No active compute processes)" - else + if [ -n "$proc_info" ]; then echo "$proc_info" | while IFS=',' read -r pid pname pmem; do - echo " PID=$pid NAME=$pname MEM=${pmem}MB" + printf " └─ PID=%-8s %-30s MEM=%4sMB\n" \ + "$pid" "$pname" "$pmem" done fi - - echo "GPU $gpu_id memory check PASSED" done - echo "========== GPU Memory Check DONE ==========" + echo "----------------------------------------" } # ==================================================== -for round in $(seq 1 $TOTAL_ROUNDS); do - echo "=== Round $round / $TOTAL_ROUNDS ===" - - # Step 1: Clear loaded weights - echo "[Step 1] Clearing load weight..." - curl_get_status -i "$BASE_URL/clear_load_weight" - assert_eq "$http_code" "200" "/clear_load_weight failed with HTTP $http_code" - sleep 10 - - # Step 2: Check GPU memory usage - echo "[Step 2] Checking GPU memory..." - check_gpu_memory - - # Step 3: Update model weights - echo "[Step 3] Updating model weight..." - curl_get_status -i "$BASE_URL/update_model_weight" - assert_eq "$http_code" "200" "/update_model_weight failed with HTTP $http_code" - - # Step 4: Send chat completion requests - echo "[Step 4] Sending $CHAT_REQUESTS_PER_ROUND chat completions..." 
+send_chat_requests() { for i in $(seq 1 $CHAT_REQUESTS_PER_ROUND); do - echo " Request $i / $CHAT_REQUESTS_PER_ROUND" - # Send request and capture response + printf " └─ Sending chat request %d/%d...\n" "$i" "$CHAT_REQUESTS_PER_ROUND" response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{"messages": [{"role": "user", "content": "Hello!"}]}') - # Extract the 'content' field from the response content=$(echo "$response" | \ grep -o '"content":"[^"]*"' | \ head -1 | \ @@ -184,26 +168,93 @@ for round in $(seq 1 $TOTAL_ROUNDS); do sed 's/"$//') if [ -z "$content" ]; then - # Fallback: try extracting content using sed more robustly content=$(echo "$response" | \ sed -n 's/.*"content":"\([^"]*\)".*/\1/p' | \ head -1) fi - # Check if content is empty or null if [ -z "$content" ] || [ "$content" = "null" ]; then - echo "Failed: Empty or null 'content' in response" >&2 - echo "Raw response:" >&2 + echo " ERROR: Empty or null 'content' in response" >&2 + echo " Raw response:" >&2 echo "$response" >&2 exit 1 fi - echo "Received non-empty response" - echo -e "\n---\n" + printf " └─ Response received: %s\n" "$content" done +} - echo "Round $round completed." - echo "==================================\n" +run_v0_round() { + local round="$1" + echo "========================================" + printf "Round %d/%d (V0)\n" "$round" "$TOTAL_ROUNDS" + echo "========================================" + echo "" + + printf "[Step 1] %-30s " "Clearing load weight via v0..." + get_and_assert "$BASE_URL/clear_load_weight" + echo "[OK]" + sleep 10 + + printf "[Step 2] %-30s " "Checking GPU memory..." + echo "" + check_gpu_memory + echo "" + + printf "[Step 3] %-30s " "Updating model weight via v0..." 
+ get_and_assert "$BASE_URL/update_model_weight" + echo "[OK]" + + echo "[Step 4] Sending $CHAT_REQUESTS_PER_ROUND chat completions" + send_chat_requests + echo "" +} + +run_v1_round() { + local round="$1" + local sleep_payload='{"tags":"weight,kv_cache"}' + local wakeup_payload='{"tags":"weight,kv_cache"}' + + echo "========================================" + printf "Round %d/%d (V1)\n" "$round" "$TOTAL_ROUNDS" + echo "========================================" + echo "" + + printf "[Step 1] %-30s " "Pausing engine via v1..." + post_json_and_assert "$BASE_URL/v1/pause" "" + echo "[OK]" + + printf "[Step 2] %-30s " "Sleeping via v1..." + post_json_and_assert "$BASE_URL/v1/sleep" "$sleep_payload" + echo "[OK]" + sleep 10 + + printf "[Step 3] %-30s " "Checking GPU memory..." + echo "" + check_gpu_memory + echo "" + + printf "[Step 4] %-30s " "Waking up via v1..." + post_json_and_assert "$BASE_URL/v1/wakeup" "$wakeup_payload" + echo "[OK]" + + printf "[Step 5] %-30s " "Resuming engine via v1..." + post_json_and_assert "$BASE_URL/v1/resume" "" + echo "[OK]" + + echo "[Step 6] Sending $CHAT_REQUESTS_PER_ROUND chat completions" + send_chat_requests + echo "" +} + +for round in $(seq 1 $V0_ROUNDS); do + run_v0_round "$round" done -echo "All $TOTAL_ROUNDS rounds completed successfully." 
+for round in $(seq 1 $V1_ROUNDS); do + run_v1_round "$round" +done + +echo "========================================" +printf "All %d rounds completed successfully.\n" "$TOTAL_ROUNDS" +echo "========================================" diff --git a/tests/e2e/utils/serving_utils.py b/tests/e2e/utils/serving_utils.py index 07acb42a32..56966315cd 100644 --- a/tests/e2e/utils/serving_utils.py +++ b/tests/e2e/utils/serving_utils.py @@ -11,6 +11,7 @@ FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) +FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633)) # List of ports to clean before and after tests PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT] diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 01d325ce68..9c70097058 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -20,7 +20,7 @@ import threading import time import types import unittest -from unittest.mock import ANY, MagicMock, Mock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import numpy as np import paddle @@ -1091,6 +1091,19 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): self.assertEqual(result, {"ok": True}) self._detach_finalizer(eng) + def test_control_update_weights_updates_cfg_version(self): + eng = self._make_mixed_engine() + eng.is_paused = True + eng._pause_cond = threading.Condition() + eng.cfg.model_config.version = "old-version" + eng._call_worker = Mock(return_value=[{"version": "new-version"}, {"ok": True}]) + + result = eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights")) + + self.assertEqual(result, [{"version": "new-version"}, {"ok": True}]) + self.assertEqual(eng.cfg.model_config.version, 
"new-version") + self._detach_finalizer(eng) + def test_control_pause_and_resume_paths(self): eng = self._make_mixed_engine() eng.is_paused = False @@ -1155,12 +1168,50 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): async def get(self, timeout=None): return Mock(payload=ControlResponse(request_id="req", result={"ok": True}, error_code=200)) - eng._ctrl_worker_output_queues = [DummyQueue()] + eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": DummyQueue()} result = eng._call_worker(ControlRequest(request_id="req", method="noop"), timeout=1) self.assertEqual(result, [{"ok": True}]) eng.engine_worker_queue.put_tasks.assert_called_once() self._detach_finalizer(eng) + def test_control_sleep_defaults_tags_and_dispatches_cache_transfer(self): + cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=4) + eng = self._make_engine(cfg) + eng.cfg.cache_config.num_cpu_blocks = 1 + eng.engine_worker_queue = Mock() + eng.cache_task_queue = Mock() + eng.resource_manager.cache_manager.reset = Mock() + eng._control_pause = Mock() + eng._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}]) + + result = eng._control_sleep(ControlRequest(request_id="sleep", method="sleep", args={})) + + self.assertEqual(result, [{"ok": True}]) + eng._control_pause.assert_called_once_with(None) + eng.resource_manager.cache_manager.reset.assert_called_once() + eng.engine_worker_queue.put_tasks.assert_called_once() + eng.cache_task_queue.put_transfer_task.assert_called_once() + sleep_req = eng.engine_worker_queue.put_tasks.call_args.args[0][0][0] + self.assertEqual(sleep_req.args["tags"], "weight,kv_cache") + self._detach_finalizer(eng) + + def test_control_wakeup_resumes_after_wait(self): + cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=4) + eng = self._make_engine(cfg) + eng.cfg.cache_config.num_cpu_blocks = 1 + eng.engine_worker_queue = Mock() + eng.cache_task_queue = Mock() + eng._control_resume = Mock() + 
eng._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}]) + + result = eng._control_wakeup(ControlRequest(request_id="wakeup", method="wakeup", args={"tags": "kv_cache"})) + + self.assertEqual(result, [{"ok": True}]) + eng.engine_worker_queue.put_tasks.assert_called_once() + eng.cache_task_queue.put_transfer_task.assert_called_once() + eng._control_resume.assert_called_once_with(None) + self._detach_finalizer(eng) + def test_control_update_weights_requires_pause(self): eng = self._make_mixed_engine() eng.is_paused = False @@ -1940,68 +1991,121 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): mock_logger.warning.assert_called() self._detach_finalizer(eng) - def test_wait_all_control_responses_success(self): + def test_wait_for_control_responses_success(self): eng = self._make_mixed_engine() - eng._ctrl_worker_output_queues = [ - self._make_ctrl_queue("q0", Mock(request_id="req", error_code=200, result={"ok": True})), - self._make_ctrl_queue("q1", Mock(request_id="req", error_code=200, result={"ok": True})), - ] + eng._ctrl_output_queues = { + "ctrl_w2e_rank0_6778": self._make_ctrl_queue( + "q0", Mock(request_id="req", error_code=200, result={"ok": True}) + ), + "ctrl_w2e_rank1_6778": self._make_ctrl_queue( + "q1", Mock(request_id="req", error_code=200, result={"ok": True}) + ), + } - results = asyncio.run(eng._wait_all_control_responses("req", timeout=1)) + results = asyncio.run(eng._wait_for_control_responses("req", timeout=1)) self.assertEqual(results, [{"ok": True}, {"ok": True}]) self._detach_finalizer(eng) - def test_wait_all_control_responses_ignores_mismatch(self): + def test_wait_for_control_responses_filters_executors(self): eng = self._make_mixed_engine() - eng._ctrl_worker_output_queues = [ - self._make_ctrl_queue("q0", Mock(request_id="old", error_code=200, result={"ok": False})), - self._make_ctrl_queue("q1", Mock(request_id="req", error_code=200, result={"ok": True})), - ] + eng._ctrl_output_queues = { + 
"ctrl_w2e_rank0_6778": self._make_ctrl_queue( + "worker", Mock(request_id="req", error_code=200, result={"worker": True}) + ), + "ctrl_c2e_rank0_6779": self._make_ctrl_queue( + "cache", Mock(request_id="req", error_code=200, result={"cache": True}) + ), + } - results = asyncio.run(eng._wait_all_control_responses("req", timeout=1)) - self.assertEqual(results, [{"ok": True}]) + worker_results = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["worker"])) + cache_results = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["cache_transfer"])) + + self.assertEqual(worker_results, [{"worker": True}]) + self.assertEqual(cache_results, [{"cache": True}]) self._detach_finalizer(eng) - def test_wait_all_control_responses_error_paths(self): + def test_wait_for_control_responses_ignores_mismatch(self): eng = self._make_mixed_engine() - eng._ctrl_worker_output_queues = [ - self._make_ctrl_queue("q0", Exception("boom"), payload_wrapped=False), - ] + class DummyQueue: + def __init__(self, name, payloads): + self.name = name + self.payloads = list(payloads) + + async def get(self, timeout=None): + return Mock(payload=self.payloads.pop(0)) + + eng._ctrl_output_queues = { + "ctrl_w2e_rank0_6778": DummyQueue( + "q0", + [ + Mock(request_id="old", error_code=200, result={"ok": False}), + Mock(request_id="req", error_code=200, result={"ok": "from-q0"}), + ], + ), + "ctrl_w2e_rank1_6778": self._make_ctrl_queue( + "q1", Mock(request_id="req", error_code=200, result={"ok": True}) + ), + } + + results = asyncio.run(eng._wait_for_control_responses("req", timeout=1)) + self.assertEqual(results, [{"ok": "from-q0"}, {"ok": True}]) + self.assertEqual( + eng._ctrl_response_mailboxes["ctrl_w2e_rank0_6778"]["old"].result, + {"ok": False}, + ) + self._detach_finalizer(eng) + + def test_wait_for_control_responses_error_paths(self): + eng = self._make_mixed_engine() + + eng._ctrl_output_queues = { + "ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", 
Exception("boom"), payload_wrapped=False) + } with self.assertRaises(Exception): - asyncio.run(eng._wait_all_control_responses("req", timeout=1)) + asyncio.run(eng._wait_for_control_responses("req", timeout=1)) self._detach_finalizer(eng) - def test_wait_all_control_responses_none_message(self): + def test_wait_for_control_responses_none_message(self): eng = self._make_mixed_engine() - eng._ctrl_worker_output_queues = [self._make_ctrl_queue("q0", None, payload_wrapped=False)] + eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)} with self.assertRaises(Exception): - asyncio.run(eng._wait_all_control_responses("req", timeout=1)) + asyncio.run(eng._wait_for_control_responses("req", timeout=1)) self._detach_finalizer(eng) - def test_wait_all_control_responses_error_code(self): + def test_wait_for_control_responses_error_code(self): eng = self._make_mixed_engine() - eng._ctrl_worker_output_queues = [ - self._make_ctrl_queue("q0", ControlResponse(request_id="req", error_code=500, error_message="bad")), - ] + eng._ctrl_output_queues = { + "ctrl_w2e_rank0_6778": self._make_ctrl_queue( + "q0", ControlResponse(request_id="req", error_code=500, error_message="bad") + ) + } with self.assertRaises(Exception): - asyncio.run(eng._wait_all_control_responses("req", timeout=1)) + asyncio.run(eng._wait_for_control_responses("req", timeout=1)) self._detach_finalizer(eng) - def test_wait_all_control_responses_timeout(self): + def test_wait_for_control_responses_timeout(self): eng = self._make_mixed_engine() - eng._ctrl_worker_output_queues = [self._make_ctrl_queue("q0", None, payload_wrapped=False)] + eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)} with patch("fastdeploy.engine.common_engine.asyncio.wait_for", side_effect=asyncio.TimeoutError): with self.assertRaises(Exception): - asyncio.run(eng._wait_all_control_responses("req", timeout=1)) + 
asyncio.run(eng._wait_for_control_responses("req", timeout=1)) + self._detach_finalizer(eng) + + def test_wait_for_control_responses_without_matching_queues(self): + eng = self._make_mixed_engine() + eng._ctrl_output_queues = {"ctrl_w2e_rank0_6778": self._make_ctrl_queue("q0", None, payload_wrapped=False)} + + result = asyncio.run(eng._wait_for_control_responses("req", timeout=1, executors=["cache_transfer"])) + self.assertIsNone(result) self._detach_finalizer(eng) def test_insert_tasks_prefill_error_and_success(self): @@ -3254,7 +3358,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase): # Lines 1299-1300: try block start + info logging info_msgs = [str(c) for c in mock_logger.info.call_args_list] - self.assertTrue(any("START run control method" in m for m in info_msgs)) + self.assertTrue(any("Start to run control method" in m for m in info_msgs)) # worker_pid should be popped from the map self.assertNotIn("ctrl-log", eng.request_worker_map) self._detach_finalizer(eng) diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py index 4c0a12c033..0cd5742170 100644 --- a/tests/entrypoints/openai/test_api_server.py +++ b/tests/entrypoints/openai/test_api_server.py @@ -84,6 +84,7 @@ def _reload_api_server(args): fake_envs_mod.EXPORTER_OTLP_ENDPOINT = "" fake_envs_mod.EXPORTER_OTLP_HEADERS = "" fake_envs_mod.FD_SUPPORT_MAX_CONNECTIONS = 1024 + fake_envs_mod.FD_ENABLE_V1_UPDATE_WEIGHTS = 0 fake_envs_mod.environment_variables = _FakeEnvVars() # Save original sys.argv and replace with minimal valid args to avoid parse errors @@ -98,9 +99,8 @@ def _reload_api_server(args): patch.dict("sys.modules", {"fastdeploy.envs": fake_envs_mod}), patch("fastdeploy.envs", fake_envs_mod), ): - from fastdeploy.entrypoints.openai import api_server as api_server_mod - - return importlib.reload(api_server_mod) + sys.modules.pop("fastdeploy.entrypoints.openai.api_server", None) + return 
importlib.import_module("fastdeploy.entrypoints.openai.api_server") finally: sys.argv = original_argv @@ -536,6 +536,95 @@ async def test_reward_embedding_and_weights(): assert api_server.clear_load_weight(MagicMock()).status_code == 404 +@pytest.mark.asyncio +async def test_sleep_wakeup_and_v1_clear_load_weight_routes(): + args = _build_args(dynamic_load_weight=True) + api_server = _reload_api_server(args) + api_server.app.state.dynamic_load_weight = True + api_server.app.state.event_loop = asyncio.get_running_loop() + + mock_control_response = MagicMock() + mock_control_response.to_api_json_response.return_value = api_server.JSONResponse( + content={"ok": True}, status_code=200 + ) + + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response) + api_server.app.state.engine_client.run_control_method_sync.return_value = mock_control_response + + sleep_req = MagicMock() + sleep_req.body = AsyncMock(return_value=b'{"tags":"weight"}') + sleep_req.json = AsyncMock(return_value={"tags": "weight"}) + sleep_resp = await api_server.sleep(sleep_req) + assert sleep_resp.status_code == 200 + sleep_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[0].args[0] + assert sleep_control_request.method == "sleep" + assert sleep_control_request.args == {"tags": "weight"} + + wakeup_req = MagicMock() + wakeup_req.body = AsyncMock(return_value=b'{"tags":"weight,kv_cache"}') + wakeup_req.json = AsyncMock(return_value={"tags": "weight,kv_cache"}) + wakeup_resp = await api_server.wakeup(wakeup_req) + assert wakeup_resp.status_code == 200 + wakeup_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[1].args[0] + assert wakeup_control_request.method == "wakeup" + assert wakeup_control_request.args == {"tags": "weight,kv_cache"} + + passthrough_req = MagicMock() + passthrough_req.body = 
AsyncMock(return_value=b'{"tags":["weight"]}') + passthrough_req.json = AsyncMock(return_value={"tags": ["weight"]}) + passthrough_resp = await api_server.sleep(passthrough_req) + assert passthrough_resp.status_code == 200 + passthrough_control_request = api_server.app.state.engine_client.run_control_method.await_args_list[2].args[0] + assert passthrough_control_request.args == {"tags": ["weight"]} + + with patch.object(api_server.envs, "FD_ENABLE_V1_UPDATE_WEIGHTS", True): + clear_resp = api_server.clear_load_weight(MagicMock()) + assert clear_resp.status_code == 200 + sync_control_request = api_server.app.state.engine_client.run_control_method_sync.call_args.args[0] + assert sync_control_request.method == "sleep" + + api_server.app.state.engine_client.run_control_method_sync.reset_mock() + with patch.object(api_server.envs, "FD_ENABLE_V1_UPDATE_WEIGHTS", True): + update_resp = api_server.update_model_weight(MagicMock()) + assert update_resp.status_code == 200 + sync_control_request = api_server.app.state.engine_client.run_control_method_sync.call_args.args[0] + assert sync_control_request.method == "wakeup" + + +@pytest.mark.asyncio +async def test_update_weights_route_validation(): + args = _build_args(dynamic_load_weight=True) + api_server = _reload_api_server(args) + mock_control_response = MagicMock() + mock_control_response.to_api_json_response.return_value = api_server.JSONResponse( + content={"ok": True}, status_code=200 + ) + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response) + + valid_req = MagicMock() + valid_req.body = AsyncMock(return_value=b'{"version":"v2","rsync_config":{"etcd_server":"127.0.0.1"}}') + valid_req.json = AsyncMock(return_value={"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}}) + valid_resp = await api_server.update_weights(valid_req) + assert valid_resp.status_code == 200 + control_request = 
api_server.app.state.engine_client.run_control_method.await_args.args[0] + assert control_request.method == "update_weights" + assert control_request.args == {"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}} + + invalid_version_req = MagicMock() + invalid_version_req.body = AsyncMock(return_value=b'{"version":1}') + invalid_version_req.json = AsyncMock(return_value={"version": 1}) + invalid_version_resp = await api_server.update_weights(invalid_version_req) + assert invalid_version_resp.status_code == 400 + + invalid_rsync_req = MagicMock() + invalid_rsync_req.body = AsyncMock(return_value=b'{"rsync_config":{"user":"u"}}') + invalid_rsync_req.json = AsyncMock(return_value={"rsync_config": {"user": "u"}}) + invalid_rsync_resp = await api_server.update_weights(invalid_rsync_req) + assert invalid_rsync_resp.status_code == 400 + + @pytest.mark.asyncio async def test_expert_and_stats_routes(): args = _build_args() diff --git a/tests/graph_optimization/test_cuda_graph_recapture.py b/tests/graph_optimization/test_cuda_graph_recapture.py index a7640c5700..1a28c0731b 100644 --- a/tests/graph_optimization/test_cuda_graph_recapture.py +++ b/tests/graph_optimization/test_cuda_graph_recapture.py @@ -142,34 +142,34 @@ class TestCUDAGrpahRecapture(unittest.TestCase): def capture_and_replay(self, input_tensor1, forward_meta1): """ """ # Trigger Capture - print_gpu_memory_use(0, "before capture") + print_gpu_memory_use("before capture", 0) output1 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) - print_gpu_memory_use(0, "after capture") + print_gpu_memory_use("after capture", 0) # Replay output1 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) assert (output1 == self.output_correct).all() # Destroy - print_gpu_memory_use(0, "before destroy") + print_gpu_memory_use("before destroy", 0) self.test_model1.clear_grpah_opt_backend() - print_gpu_memory_use(0, "after destroy") + print_gpu_memory_use("after 
destroy", 0) def recapture_and_replay(self, input_tensor1, forward_meta1): """ """ # Trigger Capture - print_gpu_memory_use(0, "before recapture") + print_gpu_memory_use("before recapture", 0) output2 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) - print_gpu_memory_use(0, "after recapture") + print_gpu_memory_use("after recapture", 0) # Replay output2 = self.test_model1(ids_remove_padding=input_tensor1, forward_meta=forward_meta1) assert (output2 == self.output_correct).all() # Destroy - print_gpu_memory_use(0, "before destroy") + print_gpu_memory_use("before destroy", 0) self.test_model1.clear_grpah_opt_backend() - print_gpu_memory_use(0, "after destroy") + print_gpu_memory_use("after destroy", 0) if __name__ == "__main__":