
[简体中文](../zh/features/disaggregated.md)
[Best Practice](../best_practices/Disaggregated.md)
# Disaggregated Deployment
Large Language Model (LLM) inference is divided into two phases: **Prefill** and **Decode**, which are compute-intensive and memory-bound, respectively.
* **Prefill Phase:** Processes all input tokens, completes the model's forward pass, and generates the first token.
* **Decode Phase:** Generates subsequent tokens based on the first token and the cached KV Cache. Assuming a total output of N tokens, the Decode phase requires executing (N-1) forward passes.
Disaggregated deployment involves deploying Prefill and Decode on distinct computing resources, each using optimal configurations. This approach improves hardware utilization, increases throughput, and reduces end-to-end latency.
<p align="center">
<img src="../zh/features/images/mix_pd.png" width="50%">
</p>
Compared to mixed deployment, the core implementation differences of disaggregated deployment lie in **KV Cache transmission** and **request scheduling**.
## KV Cache Transmission
In disaggregated deployment, the KV Cache generated for a request on the Prefill instance must be transmitted to the Decode instance. FastDeploy provides two transmission methods, targeting intra-node and inter-node scenarios.
**Intra-node transmission:** Uses `cudaMemcpyPeer` for KV Cache transmission between two GPUs within a single node.
**Inter-node transmission:** Uses a self-developed [RDMA transmission library](https://github.com/PaddlePaddle/FastDeploy/tree/develop/fastdeploy/cache_manager/transfer_factory/kvcache_transfer) to transfer KV Cache between multiple nodes.
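The transport is selected per instance at launch via the `--cache-transfer-protocol` flag, documented in detail under Usage Instructions below. A minimal sketch (the model and Router address follow the Quick Start example later on this page):
```bash
# Allow both transports: ipc is prioritized when the paired Prefill/Decode
# instances share a machine; rdma is used across machines.
python -m fastdeploy.entrypoints.openai.api_server \
  --model "PaddlePaddle/ERNIE-4.5-0.3B-Paddle" \
  --splitwise-role prefill \
  --router "0.0.0.0:30000" \
  --cache-transfer-protocol "rdma,ipc"
```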
## PD Disaggregated Request Scheduling
For PD (Prefill-Decode) disaggregated deployment, FastDeploy provides a Python version of the [Router](https://github.com/PaddlePaddle/FastDeploy/tree/develop/fastdeploy/router) to implement request reception and scheduling. The usage method and scheduling flow are as follows:
* Start the Router.
* Start the PD instances; upon startup, they register with the Router.
* User requests are sent to the Router.
* The Router selects a suitable PD instance pair based on the load conditions of the PD instances.
* The Router forwards the request to the selected PD instance.
* The Router receives the generation results from the PD instance and returns them to the user.
A high-performance version of the Router is currently under development. Stay tuned.
## Usage Instructions
### Router-based Disaggregated Deployment
#### Environment Preparation
Please refer to the [documentation](https://github.com/PaddlePaddle/FastDeploy/tree/develop/docs/zh/get_started/installation) to prepare the environment. Using Docker is recommended.
If you are setting up the runtime environment manually, ensure that RDMA dependency packages (`librdmacm-dev`, `libibverbs-dev`, `iproute2`) and the [MLNX_OFED](https://network.nvidia.com/products/infiniband-drivers/linux/mlnx_ofed/) driver are installed.
```bash
apt update --fix-missing
apt-get install -y librdmacm-dev libibverbs-dev iproute2
# Download MLNX_OFED from the link above, then run its installer
./mlnxofedinstall --user-space-only --skip-distro-check --without-fw-update --force --without-ucx-cuda
```
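After installation, you can optionally verify that the RDMA stack detects your NICs using the `ibv_devices` and `ibv_devinfo` utilities shipped with MLNX_OFED:
```bash
# List RDMA-capable devices visible to the verbs stack
ibv_devices
# Inspect port status; usable ports report "state: PORT_ACTIVE"
ibv_devinfo
```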
Pull the latest FastDeploy code, build, and install.
```bash
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
bash build.sh
```
#### Deploy Services
**Quick Start**
Start the Router service. The `--splitwise` parameter specifies the scheduling mode as disaggregated deployment. Log information is output to `log_router/router.log`. `fd-router` installation instructions can be found in the [Router documentation](../online_serving/router.md).
```bash
export FD_LOG_DIR="log_router"
/usr/local/bin/fd-router \
  --port 30000 \
  --splitwise
```
Start the Prefill instance. Compared to mixed deployment, add the `--splitwise-role` parameter to specify the instance role as Prefill and the `--router` parameter to specify the Router endpoint; all other parameters are the same as in mixed deployment.
```bash
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log_prefill"
python -m fastdeploy.entrypoints.openai.api_server \
--model "PaddlePaddle/ERNIE-4.5-0.3B-Paddle" \
--port 31000 \
--splitwise-role prefill \
--router "0.0.0.0:30000"
```
Start the Decode instance.
```bash
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log_decode"
python -m fastdeploy.entrypoints.openai.api_server \
--model "PaddlePaddle/ERNIE-4.5-0.3B-Paddle" \
--port 32000 \
--splitwise-role decode \
--router "0.0.0.0:30000"
```
After the Prefill and Decode instances are successfully started and registered with the Router, you can send requests.
```bash
curl -X POST "http://0.0.0.0:30000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "hello"}
],
"max_tokens": 100,
"stream": false
}'
```
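Since the endpoint is OpenAI-compatible, streaming only requires changing the `stream` field; tokens are then returned incrementally:
```bash
curl -X POST "http://0.0.0.0:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "hello"}
    ],
    "max_tokens": 100,
    "stream": true
  }'
```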
**Detailed Description**
Parameter description for starting Prefill/Decode instances in disaggregated deployment:
* `--splitwise-role`: Specifies the instance role. Options are `prefill`, `decode`, and `mixed`. Default is `mixed`.
* `--cache-transfer-protocol`: Specifies the KV Cache transfer protocol. Options are `rdma` and `ipc`; by default both are enabled. If the PD instances are on the same machine, `ipc` transmission is prioritized.
* `--rdma-comm-ports`: Specifies RDMA communication ports, separated by commas. The number of ports must equal `dp_size * tp_size`. If unspecified, FD will internally find free ports.
* `--pd-comm-port`: Specifies the ports for interaction between PD instances, separated by commas. The number of ports must equal `dp_size`. If unspecified, FD will internally find free ports.
* `--router`: Specifies the Router interface.
If the Prefill and Decode instances are deployed on different machines, RDMA network connectivity between the machines must be ensured.
To manually specify RDMA network interfaces, you can set the `KVCACHE_RDMA_NICS` environment variable. Multiple NICs should be separated by commas. FastDeploy provides a script to detect RDMA NICs automatically:
`bash FastDeploy/scripts/get_rdma_nics.sh <device>`, where `<device>` can be either `cpu` or `gpu`.
If the `KVCACHE_RDMA_NICS` environment variable is not set, FastDeploy will automatically detect available RDMA NICs internally.
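Putting these pieces together, a Prefill launch that detects NICs and pins its communication ports explicitly might look as follows. This is a sketch: the port values are arbitrary examples, and with `dp_size = 1` and `tp_size = 2` the command needs two RDMA ports and one PD port.
```bash
export CUDA_VISIBLE_DEVICES=0,1
export FD_LOG_DIR="log_prefill"
# Detect the RDMA NICs bound to the GPUs and export KVCACHE_RDMA_NICS
export $(bash FastDeploy/scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS=${KVCACHE_RDMA_NICS}"
python -m fastdeploy.entrypoints.openai.api_server \
  --model "PaddlePaddle/ERNIE-4.5-0.3B-Paddle" \
  --port 31000 \
  --tensor-parallel-size 2 \
  --splitwise-role prefill \
  --cache-transfer-protocol "rdma,ipc" \
  --rdma-comm-ports "7671,7672" \
  --pd-comm-port "2334" \
  --router "0.0.0.0:30000"
```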
**Examples**
PD disaggregated deployment supports features such as prefix caching, Tensor Parallelism (TP), and Data Parallelism (DP). For specific examples, please refer to [examples/splitwise](https://github.com/PaddlePaddle/FastDeploy/tree/develop/examples/splitwise).
### SplitwiseScheduler-based Disaggregated Deployment
**Note: Using SplitwiseScheduler is not recommended. It is recommended to use the Router for request scheduling.**
#### Environment Preparation
* Install using `conda`
> **⚠️ Note**
> **Redis Version Requirement: 6.2.0 and above**
> Versions below this may not support required commands.
```bash
# Install
conda install redis
# Start
nohup redis-server > redis.log 2>&1 &
```
* Install using `apt`
```bash
# Install
sudo apt install redis-server -y
# Start
sudo systemctl start redis-server
```
* Install using `yum`
```bash
# Install
sudo yum install redis -y
# Start
sudo systemctl start redis
```
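Whichever installation method you use, confirm that Redis is reachable before starting the PD instances:
```bash
# A healthy server replies with PONG
redis-cli -h 127.0.0.1 -p 6379 ping
```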
#### Deploy Services
For multi-node deployment, ensure that the network interface cards support RDMA and that all nodes in the cluster have network connectivity.
**Note**:
* `KVCACHE_RDMA_NICS` specifies the RDMA NICs of the current machine; separate multiple NICs with commas.
* The repository provides a script to automatically detect RDMA NICs: `bash scripts/get_rdma_nics.sh <device>`, where `<device>` can be `cpu` or `gpu`.
**Prefill instance**
```bash
export FD_LOG_DIR="log_prefill"
export CUDA_VISIBLE_DEVICES=0,1,2,3
export ENABLE_V1_KVCACHE_SCHEDULER=0
echo "set RDMA NICS"
export $(bash scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
python -m fastdeploy.entrypoints.openai.api_server \
  --model ERNIE-4.5-300B-A47B-BF16 \
  --port 8180 --metrics-port 8181 \
  --engine-worker-queue-port 8182 \
  --cache-queue-port 8183 \
  --tensor-parallel-size 4 \
  --quantization wint4 \
  --cache-transfer-protocol "rdma,ipc" \
  --rdma-comm-ports "7671,7672,7673,7674" \
  --pd-comm-port "2334" \
  --splitwise-role "prefill" \
  --scheduler-name "splitwise" \
  --scheduler-host "127.0.0.1" \
  --scheduler-port 6379 \
  --scheduler-topic "test" \
  --scheduler-ttl 9000
```
**Decode instance**
```bash
export FD_LOG_DIR="log_decode"
export CUDA_VISIBLE_DEVICES=4,5,6,7
export ENABLE_V1_KVCACHE_SCHEDULER=0
echo "set RDMA NICS"
export $(bash scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
python -m fastdeploy.entrypoints.openai.api_server \
  --model ERNIE-4.5-300B-A47B-BF16 \
  --port 8184 --metrics-port 8185 \
  --engine-worker-queue-port 8186 \
  --cache-queue-port 8187 \
  --tensor-parallel-size 4 \
  --quantization wint4 \
  --scheduler-name "splitwise" \
  --cache-transfer-protocol "rdma,ipc" \
  --rdma-comm-ports "7671,7672,7673,7674" \
  --pd-comm-port "2334" \
  --scheduler-host "127.0.0.1" \
  --scheduler-port 6379 \
  --scheduler-ttl 9000 \
  --scheduler-topic "test" \
  --splitwise-role "decode"
```
Parameter Explanation:
* `--splitwise-role`: Specifies whether the current service is prefill or decode.
* `--cache-queue-port`: Specifies the cache service port used for communication between prefill and decode services.
Multi-node Parameter Explanation:
* `--cache-transfer-protocol`: Specifies the KV Cache transfer protocol; supports `ipc` and `rdma`. Defaults to `ipc`.
* `--scheduler-name`: Set to `splitwise` for PD disaggregation.
* `--scheduler-host`: The Redis address to connect to.
* `--scheduler-port`: The Redis port to connect to.
* `--scheduler-ttl`: Specifies the Redis TTL (Time To Live) in seconds.
* `--scheduler-topic`: Specifies the Redis topic.
* `--pd-comm-port`: Specifies the PD communication port.
* `--rdma-comm-ports`: Specifies the RDMA communication ports, separated by commas; the number of ports must match the number of GPUs used (`dp_size * tp_size`).
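Once both instances have registered through Redis, requests can be sent to an instance's API port. A sketch, assuming requests enter via the Prefill instance's `--port` (8180 in the example above) and the splitwise scheduler dispatches them across the PD pair:
```bash
curl -X POST "http://127.0.0.1:8180/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "hello"}
    ],
    "max_tokens": 100,
    "stream": false
  }'
```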