[PD Disaggregation] Fix cache messager performance problem & add kv transfer benchmark tool (#6434)

* fix cache messager performance problem

* dispatch param type
This commit is contained in:
RichardWooSJTU
2026-03-02 14:28:14 +08:00
committed by GitHub
parent 6674131b0b
commit fe0b3a90ee
2 changed files with 471 additions and 6 deletions
@@ -0,0 +1,459 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
KV Cache Transfer Benchmark:
Performance test for measuring KV cache transfer throughput (GB/s)
Usage:
# Start decode server first:
python benchmark.py --splitwise_role decode --decode_rdma_port 9881 --decode_zmq_port 9882
# Then start prefill client:
python benchmark.py --splitwise_role prefill --decode_ip <decode_ip> --decode_rdma_port 9881 --decode_zmq_port 9882
"""
import argparse
import time
from dataclasses import dataclass
from typing import List
import paddle
import rdma_comm
import zmq
if paddle.is_compiled_with_xpu():
from custom_setup_ops import get_peer_mem_addr
@dataclass
class BenchmarkConfig:
    """Benchmark configuration parameters.

    Shape fields mirror the KV cache layout used below:
    each block tensor is [max_block_num, kv_num_head, block_size_seq_len, head_dim].
    """

    num_layers: int = 61  # Number of transformer layers
    max_block_num: int = 2000  # Maximum number of blocks
    kv_num_head: int = 1  # Number of KV heads
    block_size_seq_len: int = 64  # Block size (sequence length dimension)
    head_dim: int = 128  # Hidden size per head
    cache_dtype: str = "bfloat16"  # Cache data type: bfloat16, float16, uint8
    warmup_iters: int = 5  # Warmup iterations (not timed)
    benchmark_iters: int = 20  # Benchmark iterations (timed)
    blocks_per_transfer: int = 50  # Number of blocks per transfer
    num_queries: int = 1  # Number of queries per layer (simulate multiple queries)
@dataclass
class BenchmarkResult:
    """Benchmark result produced by KVCacheBenchmark.run_benchmark."""

    total_bytes: int  # Total payload transferred across all timed iterations
    elapsed_time: float  # Wall-clock seconds for the timed loop
    throughput_gbps: float  # total_bytes / elapsed_time, in GiB/s
    iterations: int  # Number of timed iterations
    blocks_per_iter: int  # Blocks transferred per query per iteration
    layers: int  # Number of layers transferred each iteration
    queries_per_layer: int  # Number of queries simulated per layer
class KVCacheBenchmark:
    """Measures KV-cache RDMA transfer throughput between a prefill client
    and a decode server.

    On construction this allocates per-layer key/value cache tensors (plus
    float32 scale tensors) on the selected device and registers their
    addresses with an ``rdma_comm.RDMACommunicator``. The prefill role then
    pushes cache blocks to the remote decode role via ``write_cache`` and
    times the transfers in ``run_benchmark``.
    """

    def __init__(self, splitwise_role: str, config: BenchmarkConfig, port: int = None, device: str = "gpu"):
        """Allocate caches and create the RDMA communicator.

        Args:
            splitwise_role: "prefill" (client, initiates writes) or
                "decode" (server, receives writes).
            config: Benchmark shape/iteration parameters.
            port: RDMA listen port; required for the decode role.
            device: Paddle device string, e.g. "gpu", "xpu".
        """
        assert splitwise_role in ["prefill", "decode"], "splitwise_role must be prefill or decode"
        if splitwise_role == "decode":
            # NOTE(review): a falsy port (e.g. 0) also trips this assert.
            assert port, "port must be specified for decode server"
        self.splitwise_role = splitwise_role
        self.config = config
        self.gpu_cache_kvs = {}  # name -> tensor; keeps every cache tensor alive for RDMA
        paddle.device.set_device(device)
        print(f"[Benchmark] Role: {splitwise_role}, Port: {port}, Device: {device}")
        print(
            f"[Benchmark] Config: layers={config.num_layers}, max_blocks={config.max_block_num}, "
            f"kv_heads={config.kv_num_head}, block_size={config.block_size_seq_len}, "
            f"head_dim={config.head_dim}, dtype={config.cache_dtype}"
        )
        # Determine cache dtype: XPU builds always use float16 regardless of config
        if paddle.is_compiled_with_xpu():
            cache_type = "float16"
        else:
            cache_type = config.cache_dtype
        # Raw device addresses handed to the RDMA communicator, one entry per layer
        cache_k_ptr_list = []
        cache_v_ptr_list = []
        cache_k_scale_ptr_list = []
        cache_v_scale_ptr_list = []
        # Calculate block size
        block_shape = [
            config.max_block_num,
            config.kv_num_head,
            config.block_size_seq_len,
            config.head_dim,
        ]
        scale_shape = [
            config.max_block_num,
            config.kv_num_head,
            config.block_size_seq_len,
        ]
        for layer_idx in range(config.num_layers):
            # Create key cache
            key_cache = paddle.zeros(shape=block_shape, dtype=cache_type)
            self.gpu_cache_kvs[f"key_caches_{layer_idx}"] = key_cache
            # On XPU, translate the device pointer to a peer-accessible address
            if paddle.is_compiled_with_xpu():
                cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
            else:
                cache_k_ptr_list.append(key_cache.data_ptr())
            # Create value cache
            value_cache = paddle.zeros(shape=block_shape, dtype=cache_type)
            self.gpu_cache_kvs[f"value_caches_{layer_idx}"] = value_cache
            if paddle.is_compiled_with_xpu():
                cache_v_ptr_list.append(get_peer_mem_addr(value_cache.data_ptr()))
            else:
                cache_v_ptr_list.append(value_cache.data_ptr())
            # Create scale tensors (per-token dequantization scales, always float32)
            key_scale = paddle.zeros(shape=scale_shape, dtype="float32")
            self.gpu_cache_kvs[f"key_scale_{layer_idx}"] = key_scale
            if paddle.is_compiled_with_xpu():
                cache_k_scale_ptr_list.append(get_peer_mem_addr(key_scale.data_ptr()))
            else:
                cache_k_scale_ptr_list.append(key_scale.data_ptr())
            value_scale = paddle.zeros(shape=scale_shape, dtype="float32")
            self.gpu_cache_kvs[f"value_scale_{layer_idx}"] = value_scale
            if paddle.is_compiled_with_xpu():
                cache_v_scale_ptr_list.append(get_peer_mem_addr(value_scale.data_ptr()))
            else:
                cache_v_scale_ptr_list.append(value_scale.data_ptr())
            # Initialize prefill data with pattern for validation
            if self.splitwise_role == "prefill":
                key_cache.fill_(1.0)
                value_cache.fill_(1.0)
        # Calculate per-block byte sizes used for throughput accounting
        dtype_bytes = 2 if cache_type in ["bfloat16", "float16"] else 1
        block_bytes = config.kv_num_head * config.block_size_seq_len * config.head_dim * dtype_bytes
        scale_block_bytes = config.kv_num_head * config.block_size_seq_len * 4  # float32
        self.block_bytes = block_bytes
        self.scale_block_bytes = scale_block_bytes
        self.bytes_per_block = block_bytes * 2 + scale_block_bytes * 2  # K + V + K_scale + V_scale
        print(f"[Benchmark] Block bytes: {block_bytes}, Scale bytes: {scale_block_bytes}")
        print(f"[Benchmark] Total bytes per block (K+V+scales): {self.bytes_per_block}")
        print(
            f"[Benchmark] Total cache memory: {self._format_bytes(self.bytes_per_block * config.max_block_num * config.num_layers)}"
        )
        # Create RDMA communicator; decode listens on `port`, prefill passes "0".
        # NOTE(review): argument order is positional and must match the
        # rdma_comm.RDMACommunicator binding — verify against its signature.
        self.rdma_comm = rdma_comm.RDMACommunicator(
            splitwise_role,
            0,
            str(port) if self.splitwise_role == "decode" else "0",
            cache_k_ptr_list,
            cache_v_ptr_list,
            config.max_block_num,
            block_bytes,
            cache_k_scale_ptr_list,
            cache_v_scale_ptr_list,
            scale_block_bytes,
        )

    def _format_bytes(self, bytes_val: int) -> str:
        """Format bytes to a human readable string, e.g. "1.50 GB"."""
        for unit in ["B", "KB", "MB", "GB", "TB"]:
            if bytes_val < 1024:
                return f"{bytes_val:.2f} {unit}"
            bytes_val /= 1024
        return f"{bytes_val:.2f} PB"

    def is_connected(self, ip: str, port: int) -> bool:
        """Return True if an RDMA connection to ip:port is established (prefill only)."""
        assert self.splitwise_role == "prefill", "only prefill can call this method"
        return self.rdma_comm.is_connected(ip, str(port))

    def connect(self, ip: str, port: int) -> bool:
        """Initiate an RDMA connection to ip:port if not already connected (prefill only).

        Always returns True; use is_connected() to poll for actual readiness.
        """
        assert self.splitwise_role == "prefill", "only prefill can call this method"
        if self.is_connected(ip, port):
            return True
        self.rdma_comm.connect(ip, str(port))
        return True

    def write_cache(self, ip: str, port: int, query_block_ids_list: List[List[int]]):
        """Write cache to remote with all layers and queries

        Issues one rdma_comm.write_cache call per (layer, query) pair; the
        same block ids are used as both source and destination indices.

        Args:
            ip: Remote IP address
            port: Remote port
            query_block_ids_list: List of block_ids for each query, length should be num_queries
        """
        assert self.splitwise_role == "prefill", "only prefill can call this method"
        assert (
            len(query_block_ids_list) == self.config.num_queries
        ), f"Expected {self.config.num_queries} query block_ids, got {len(query_block_ids_list)}"
        for layer_idx in range(self.config.num_layers):
            for query_idx in range(self.config.num_queries):
                block_ids = query_block_ids_list[query_idx]
                self.rdma_comm.write_cache(ip, str(port), block_ids, block_ids, layer_idx)

    def _generate_query_block_ids(self, blocks_per_query: int) -> List[List[int]]:
        """Generate block_ids for each query with gaps between queries

        Each query has contiguous block_ids, but different queries are separated by gaps.
        For example, with 2 queries, 3 blocks each, and gap=10:
            query 0: [0, 1, 2]
            query 1: [13, 14, 15]  (starts at 0 + 3 + 10 = 13)

        Raises:
            ValueError: if max_block_num cannot fit all queries plus gaps.
        """
        num_queries = self.config.num_queries
        # Gap between query block ranges to simulate non-contiguous access
        gap_between_queries = blocks_per_query  # gap equals to blocks_per_query
        total_blocks_needed = blocks_per_query * num_queries + gap_between_queries * (num_queries - 1)
        if total_blocks_needed > self.config.max_block_num:
            raise ValueError(
                f"Not enough blocks: need {total_blocks_needed} "
                f"(blocks_per_query={blocks_per_query} x num_queries={num_queries} + gaps), "
                f"but max_block_num={self.config.max_block_num}"
            )
        # Generate contiguous blocks for each query with gaps between queries
        query_block_ids_list = []
        start_block = 0
        for query_idx in range(num_queries):
            # Contiguous block_ids within each query
            block_ids = list(range(start_block, start_block + blocks_per_query))
            query_block_ids_list.append(block_ids)
            # Next query starts after current blocks + gap
            start_block += blocks_per_query + gap_between_queries
        return query_block_ids_list

    def synchronize(self):
        """Synchronize device (no-op on builds without CUDA/XPU)."""
        if paddle.is_compiled_with_cuda():
            paddle.device.cuda.synchronize()
        elif paddle.is_compiled_with_xpu():
            paddle.device.xpu.synchronize()

    def run_benchmark(self, ip: str, port: int) -> BenchmarkResult:
        """Run the benchmark and return results

        Performs warmup iterations, then times benchmark_iters full
        transfers (all layers x all queries) and computes throughput.
        """
        assert self.splitwise_role == "prefill", "only prefill can run benchmark"
        # Cap blocks per query so all queries fit within max_block_num
        blocks_per_query = min(self.config.blocks_per_transfer, self.config.max_block_num // self.config.num_queries)
        # Generate contiguous-with-gaps block_ids for each query
        query_block_ids_list = self._generate_query_block_ids(blocks_per_query)
        print(f"[Benchmark] Blocks per query: {blocks_per_query}, Queries: {self.config.num_queries}")
        for i, block_ids in enumerate(query_block_ids_list):
            print(f"[Benchmark] Query {i} block_ids: {block_ids[:5]}...{block_ids[-3:] if len(block_ids) > 5 else ''}")
        # Warmup (not timed)
        print(f"\n[Benchmark] Running {self.config.warmup_iters} warmup iterations...")
        for i in range(self.config.warmup_iters):
            self.write_cache(ip, port, query_block_ids_list)
        self.synchronize()
        print("[Benchmark] Warmup complete")
        # Benchmark: synchronize before starting the clock and before stopping it
        print(f"[Benchmark] Running {self.config.benchmark_iters} benchmark iterations...")
        self.synchronize()
        start_time = time.perf_counter()
        for i in range(self.config.benchmark_iters):
            self.write_cache(ip, port, query_block_ids_list)
        self.synchronize()
        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # Calculate throughput
        # bytes per iteration = blocks_per_query * num_queries * layers * (K + V + K_scale + V_scale)
        bytes_per_iter = blocks_per_query * self.config.num_queries * self.config.num_layers * self.bytes_per_block
        total_bytes = bytes_per_iter * self.config.benchmark_iters
        throughput_gbps = (total_bytes / elapsed_time) / (1024**3)
        result = BenchmarkResult(
            total_bytes=total_bytes,
            elapsed_time=elapsed_time,
            throughput_gbps=throughput_gbps,
            iterations=self.config.benchmark_iters,
            blocks_per_iter=blocks_per_query,
            layers=self.config.num_layers,
            queries_per_layer=self.config.num_queries,
        )
        return result

    def print_results(self, result: BenchmarkResult):
        """Print benchmark results as a formatted report."""
        print("\n" + "=" * 60)
        print("KV Cache Transfer Benchmark Results")
        print("=" * 60)
        print("Configuration:")
        print(f"  - Layers: {result.layers}")
        print(f"  - Queries per layer: {result.queries_per_layer}")
        print(f"  - Blocks per transfer: {result.blocks_per_iter}")
        print(f"  - KV heads: {self.config.kv_num_head}")
        print(f"  - Block size: {self.config.block_size_seq_len}")
        print(f"  - Head dim: {self.config.head_dim}")
        print(f"  - Data type: {self.config.cache_dtype}")
        print("\nResults:")
        print(f"  - Iterations: {result.iterations}")
        print(f"  - Total time: {result.elapsed_time:.4f} s")
        print(f"  - Total data transferred: {self._format_bytes(result.total_bytes)}")
        print(f"  - Throughput: {result.throughput_gbps:.2f} GB/s")
        print(f"  - Avg latency per iter: {(result.elapsed_time / result.iterations) * 1000:.2f} ms")
        print("=" * 60)
def parse_args():
    """Parse the benchmark's command-line options into a Namespace."""
    parser = argparse.ArgumentParser(description="KV Cache Transfer Benchmark")
    add = parser.add_argument
    # Role selection
    add(
        "--splitwise_role",
        type=str,
        default="prefill",
        choices=["prefill", "decode"],
        help="Role: prefill (client) or decode (server)",
    )
    # Connection endpoints
    add("--decode_ip", type=str, default="127.0.0.1", help="Decode server IP")
    add("--decode_rdma_port", type=int, default=9881, help="RDMA port")
    add("--decode_zmq_port", type=int, default=9882, help="ZMQ port for control")
    # Cache shape parameters
    add("--num_layers", type=int, default=61, help="Number of transformer layers")
    add("--max_block_num", type=int, default=2000, help="Maximum number of blocks")
    add("--kv_num_head", type=int, default=1, help="Number of KV heads")
    add("--block_size", type=int, default=64, help="Block size (sequence length)")
    add("--head_dim", type=int, default=128, help="Head dimension")
    add(
        "--cache_dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "uint8"], help="Cache data type"
    )
    # Iteration control
    add("--warmup_iters", type=int, default=5, help="Warmup iterations")
    add("--benchmark_iters", type=int, default=20, help="Benchmark iterations")
    add("--blocks_per_transfer", type=int, default=64, help="Number of blocks per transfer")
    add("--num_queries", type=int, default=4, help="Number of queries per layer to simulate")
    add("--device", type=str, default="gpu", help="Device to use (gpu/cpu/xpu)")
    return parser.parse_args()
def run_decode_server(args, config: BenchmarkConfig):
    """Run the decode-side server.

    Allocates the decode caches (which registers them with the RDMA
    listener), then answers control messages over a ZMQ REP socket until
    the client reports completion or an error occurs.
    """
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.bind(f"tcp://0.0.0.0:{args.decode_zmq_port}")
    print(f"[Server] ZMQ server started on port {args.decode_zmq_port}")
    # The benchmark instance is only needed for its side effects:
    # cache allocation and RDMA registration.
    _ = KVCacheBenchmark("decode", config, port=args.decode_rdma_port, device=args.device)
    print("[Server] Waiting for benchmark client...")
    done = False
    while not done:
        try:
            msg = sock.recv_pyobj()
            kind = msg.get("msg_type")
            if kind == "ping":
                sock.send_pyobj({"status": "ready"})
            elif kind == "benchmark_start":
                print("[Server] Benchmark started by client")
                sock.send_pyobj({"status": "ok"})
            elif kind == "benchmark_done":
                print("[Server] Benchmark completed")
                sock.send_pyobj({"status": "ok"})
                done = True
            else:
                sock.send_pyobj({"status": "unknown"})
        except Exception as e:
            print(f"[Server] Error: {e}")
            done = True
    print("[Server] Shutting down")
def run_prefill_client(args, config: BenchmarkConfig):
    """Run the prefill-side client (benchmark driver).

    Establishes the RDMA connection to the decode server, handshakes over
    ZMQ, runs the timed benchmark, prints the results, and notifies the
    server that the run is finished.
    """
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://{args.decode_ip}:{args.decode_zmq_port}")
    benchmark = KVCacheBenchmark("prefill", config, device=args.device)
    # Kick off the RDMA connection, then poll until it is actually up.
    print(f"[Client] Connecting to decode server at {args.decode_ip}:{args.decode_rdma_port}...")
    benchmark.connect(args.decode_ip, args.decode_rdma_port)
    max_retries = 30
    attempts = 0
    while not benchmark.is_connected(args.decode_ip, args.decode_rdma_port):
        time.sleep(1)
        attempts += 1
        if attempts >= max_retries:
            print("[Client] Failed to connect to decode server")
            return
        print(f"[Client] Waiting for connection... ({attempts}/{max_retries})")
    print("[Client] Connected to decode server")
    # Verify the server's control channel is responsive before timing.
    sock.send_pyobj({"msg_type": "ping"})
    if sock.recv_pyobj().get("status") != "ready":
        print("[Client] Server not ready")
        return
    # Notify benchmark start
    sock.send_pyobj({"msg_type": "benchmark_start"})
    sock.recv_pyobj()
    # Run the timed transfer loop and report
    result = benchmark.run_benchmark(args.decode_ip, args.decode_rdma_port)
    benchmark.print_results(result)
    # Notify benchmark done
    sock.send_pyobj({"msg_type": "benchmark_done"})
    sock.recv_pyobj()
def main():
    """Entry point: parse CLI options and dispatch to the selected role."""
    args = parse_args()
    config = BenchmarkConfig(
        num_layers=args.num_layers,
        max_block_num=args.max_block_num,
        kv_num_head=args.kv_num_head,
        block_size_seq_len=args.block_size,
        head_dim=args.head_dim,
        cache_dtype=args.cache_dtype,
        warmup_iters=args.warmup_iters,
        benchmark_iters=args.benchmark_iters,
        blocks_per_transfer=args.blocks_per_transfer,
        num_queries=args.num_queries,
    )
    # Decode acts as the server; any other role runs the prefill client.
    runner = run_decode_server if args.splitwise_role == "decode" else run_prefill_client
    runner(args, config)


if __name__ == "__main__":
    main()