[RL][CI] Support Async R3 And Add Accuracy Test (#5937)

* add bs1 r3 test case

* async put

* r3 test case 1.0

* success run eb5

* refine test case

* pre-commit

* add eb45 & glm testcase

* format code

* add p2pstore requirements

* support only last turn

* R3 use worker log

* refine code & fix ci bug

* refine error message

* fix empty input bug

* Successfully set accuracy CI for eb45 and glm45

* refine code

* fix bug
This commit is contained in:
RAM
2026-01-14 20:25:06 +08:00
committed by GitHub
parent 9373f373dc
commit b3f59fd9b5
9 changed files with 443 additions and 20 deletions
+4 -2
View File
@@ -644,14 +644,16 @@ class FusedMoE(nn.Layer):
"""
topk_ids_hookfunc = None
if self.enable_routing_replay:
if forward_meta is not None: # forward_meta is None when execute empty_input_forward
# When execute empty_input_forward forward_meta is None. When execute mtp layer routing_replay_table is None.
if forward_meta is not None and forward_meta.routing_replay_table is not None:
moe_layer_idx = self.layer_idx - self.fd_config.model_config.moe_layer_start_index
topk_ids_hookfunc = partial(
save_routing_to_buffer,
routing_replay_table=forward_meta.routing_replay_table,
batch_id_per_token=forward_meta.batch_id_per_token,
seq_lens_decoder=forward_meta.seq_lens_decoder,
cu_seqlens_q=forward_meta.cu_seqlens_q,
layer_idx=self.layer_idx,
layer_idx=moe_layer_idx,
tp_size=self.fd_config.parallel_config.tensor_parallel_size,
ep_size=self.fd_config.parallel_config.expert_parallel_size,
tp_group=self.fd_config.parallel_config.tp_group,
@@ -26,6 +26,7 @@ import paddle
import paddle.distributed as dist
import triton
import triton.language as tl
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
@@ -110,6 +111,8 @@ def save_routing_to_buffer(
):
if tp_size > 1 and ep_size > 1:
token_num_per_rank = topk_ids.shape[0]
if token_num_per_rank == 0:
return
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
@@ -152,6 +155,7 @@ class RoutingReplayManager:
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.max_model_len = fd_config.model_config.max_model_len
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
self.only_last_turn = fd_config.routing_replay_config.only_last_turn
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
self.moe_top_k = fd_config.model_config.num_experts_per_tok
@@ -177,9 +181,10 @@ class RoutingReplayManager:
# Save requests that have been finished for the current slot
if batch_id in self.routing_batch_to_request:
pre_request_id = self._deregister_request(batch_id)
self._put_request_to_store(batch_id, pre_request_id)
asyncio.run(self._put_request_to_store(batch_id, pre_request_id))
# Register the new request
self.routing_batch_to_request[batch_id] = request_id
logger.info(f"[R3] Register request {request_id} with batch id {batch_id}")
def _deregister_request(self, batch_id: int) -> str:
"""
@@ -188,26 +193,35 @@ class RoutingReplayManager:
assert batch_id in self.routing_batch_to_request
return self.routing_batch_to_request.pop(batch_id)
def _put_request_to_store(
async def _put_request_to_store(
self,
batch_id: int,
request_id: str,
):
before_put_request_time = time.perf_counter()
if self.tp_rank == 0:
batch_buffer = self.routing_replay_table[batch_id]
tasks = []
for layer_id in range(self.num_moe_layers):
layer_buffer = batch_buffer[layer_id]
rollout_id = self.split_request_id(request_id)
self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
tasks.append(
self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
)
if self.only_last_turn:
prefix_batch = self.get_needed_clear_ids(rollout_id)
tasks.append(self.routing_store.clear_prefix_batch(roullout_id_prefixes=prefix_batch))
await asyncio.gather(*tasks)
logger.info(f"[R3] Async put {request_id} time cost: {time.perf_counter() - before_put_request_time}")
self._clear_table_slot(batch_id)
def put_table_to_store(self):
"""Put the routing table"""
logger.info("[R3] Put routing table to store.")
batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys()))
for batch_id in batch_ids:
request_id = self._deregister_request(batch_id)
self._put_request_to_store(batch_id, request_id)
asyncio.run(self._put_request_to_store(batch_id, request_id))
def _clear_table_slot(self, batch_id: int):
assert 0 <= batch_id < self.max_num_seqs
@@ -241,14 +255,39 @@ class RoutingReplayManager:
return self.routing_replay_table
def split_request_id(self, request_id: str):
"""Split the request id to get rollout id"""
"""
Split the request id to get rollout id.
request_id: "chatcmpl-request.user-uuid"
rollout_id: "request.user"
example: "chatcmpl-xxx_xxx_epoch_15:2:2:1-d9f16c5c-65f6-4815-b44d-14e2c581907c_0" -> "xxx_xxx_epoch_15:2:2:1"
"""
chat_type, tmp_str = request_id.split("-", 1)
# NOTE(gongshaotian): only support chatcmpl now
# assert chat_type == "chatcmpl"
assert (
chat_type == "chatcmpl"
), "Rollout Routing Replay only supports chatcmpl. Please check whether the request type and userid settings are correct."
reversed_tmp_str = tmp_str[::-1].split("-", 5)
rollout_id = reversed_tmp_str[-1][::-1]
return rollout_id
def get_needed_clear_ids(self, roullout_id: str) -> List[str]:
"""
Generate the prefix IDs for all closed multi-round tasks.
rollout_id: "xxx_xxx_epoch_15:2:2:1"
example: xxx_xxx_data_id:gen_id:turn_id:segment_id
"""
reversed_segment_id, reversed_turn_id, reversed_prefix_gen_id = roullout_id[::-1].split(":", 2)
prefix_gen_id = reversed_prefix_gen_id[::-1]
turn_id = eval(reversed_turn_id[::-1])
segment_id = eval(reversed_segment_id[::-1])
assert turn_id >= 0 and segment_id >= 0
prefix_batch = []
if turn_id > 0:
prefix_batch.append(f"{prefix_gen_id}:{(turn_id-1)}:{segment_id}")
return prefix_batch
def clear_request(self, batch_id: int):
"""Clear the routing indices of the request"""
self._clear_table_slot(batch_id)
@@ -262,7 +301,7 @@ class RoutingStoreBase(ABC):
self.fd_config = fd_config
@abstractmethod
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
async def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
"""Put the routing indices into store"""
raise NotImplementedError
@@ -283,6 +322,11 @@ class RoutingStoreBase(ABC):
"""Clear the routing indices store"""
raise NotImplementedError
@abstractmethod
async def clear_prefix_batch(self, roullout_id_prefixes: List[str]):
"""Clear the routing indices"""
raise NotImplementedError
class RoutingStoreLocal(RoutingStoreBase):
"""Routing Store using local memory"""
@@ -292,12 +336,17 @@ class RoutingStoreLocal(RoutingStoreBase):
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
self.clear_store()
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
async def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
routing_key = f"{rollout_id}_{layer_idx}"
# async put
time_before_put = time.perf_counter()
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
os.makedirs(dir_path, exist_ok=True)
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
paddle.save(routing_indices, file_path)
logger.info(f"[R3] The routing key {routing_key} put cost is {time.perf_counter()-time_before_put}s")
def get(
self,
@@ -334,6 +383,10 @@ class RoutingStoreLocal(RoutingStoreBase):
file_path = os.path.join(self.local_store_dir, file_name)
shutil.rmtree(file_path)
async def clear_prefix_batch(self, roullout_id_prefixes: List[str]):
# async delete
logger.info(f"[R3] clear_prefix_batch {roullout_id_prefixes}")
class RoutingStoreRDMA(RoutingStoreBase):
"""Routing Store using RDMA"""
@@ -351,16 +404,19 @@ class RoutingStoreRDMA(RoutingStoreBase):
self.p2p_client = P2PClient(p2pConfig)
self.clear_store()
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
async def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# async put
time_before_put = time.perf_counter()
routing_indices_pin = routing_indices.pin_memory()
routing_indices_pin = routing_indices.cpu()
routing_indices_np = routing_indices_pin.numpy()
asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s")
copy_time = time.perf_counter()
await self.p2p_client.put(rdma_rollout_key, routing_indices_np)
logger.info(
f"[R3] The routing key {rdma_rollout_key} copy cost is {copy_time-time_before_put}s, put cost is {time.perf_counter()-time_before_put}s"
)
def get(
self,
@@ -383,6 +439,11 @@ class RoutingStoreRDMA(RoutingStoreBase):
# sync delete
asyncio.run(self.p2p_client.delete(rdma_rollout_key))
async def clear_prefix_batch(self, roullout_id_prefixes: List[str]):
# async delete
await self.p2p_client.delete_prefix_batch(roullout_id_prefixes)
logger.info(f"[R3] clear_prefix_batch {roullout_id_prefixes}")
def clear_store(self):
"""Clear the routing indices store"""
# sync clear routing store