[Cherry-Pick] [BugFix] fix mtp cache attaching for pd disaggregation (#5884) (#5885)

* [fix] fix mtp cache attaching for pd disaggregation

* [fix] fix port
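
In P/D disaggregation the MTP proposer must attach to the KV cache created by the cache manager instead of allocating its own tensors; the diff below gates allocation and waits on a readiness signal before attaching. The following is a minimal, standalone sketch of that wait-then-attach pattern for illustration only: the helper names (should_create_cache_tensor, wait_for_cache_manager, poll_interval) and the plain NumPy array standing in for IPCSignal are assumptions, not FastDeploy APIs.

    import time
    import numpy as np

    def should_create_cache_tensor(profile: bool, splitwise_role: str) -> bool:
        # Mirrors the condition in the diff: allocate a local KV cache only during
        # profiling or in a "mixed" (non-disaggregated) deployment.
        return profile or splitwise_role == "mixed"

    def wait_for_cache_manager(cache_ready_signal, local_rank: int, poll_interval: float = 1.0):
        # In P/D disaggregation the cache manager owns the KV cache; poll until it
        # marks this rank ready, then attach rather than allocate.
        while cache_ready_signal[local_rank] != 1:
            time.sleep(poll_interval)

    if __name__ == "__main__":
        # Toy usage: a shared int32 array stands in for the cache_ready_signal IPCSignal.
        tensor_parallel_size, local_rank = 2, 0
        signal = np.zeros(tensor_parallel_size, dtype=np.int32)
        signal[local_rank] = 1  # pretend the cache manager finished creating the cache
        if not should_create_cache_tensor(profile=False, splitwise_role="prefill"):
            wait_for_cache_manager(signal, local_rank)
            print("cache ready; attaching to existing KV cache tensors")
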
Author: Yonghua Li
Date: 2026-01-06 14:19:38 +08:00
Committed by: GitHub
Parent: f3ebd64446
Commit: 682e1ab2d0
Changes: +26 -1
@@ -15,6 +15,7 @@
"""
import os
import time
from typing import List
import numpy as np
@@ -24,6 +25,7 @@ from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
@@ -205,7 +207,30 @@ class MTPProposer(Proposer):
        if kv_cache_quant_type == "block_wise_fp8":
            kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
        if not profile and self.scheduler_config.splitwise_role != "mixed":
            cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
            cache_ready_signal = IPCSignal(
                name="cache_ready_signal",
                array=cache_ready_signal_data,
                dtype=np.int32,
                suffix=self.parallel_config.engine_worker_queue_port,
                create=False,
            )
        # Check if gpu runner needs to create kv cache
        # 1. During profiling, it creates its own kv cache.
        # 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled.
        create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed"
        if not create_cache_tensor:
            logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
            while cache_ready_signal.value[local_rank] != 1:
                time.sleep(1)
            logger.info(f"OK! Stop waiting. {cache_ready_signal.value}")
        logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
        if not create_cache_tensor:
            cache_kvs_list = []
            for i in range(
                self.num_main_model_layers,