mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL][CI] Support Async R3 And Add Accuracy Test (#5937)
* add bs1 r3 test case * async put * r3 test case 1.0 * success run eb5 * refine test case * pre-commit * add eb45 & glm testcase * format code * add p2pstore requirements * support only last turn * R3 use worker log * refine code &fix ci bug * refine error mesg * fix empty input bug * Success set acc ci of eb45 and glm45 * refine code * fix bug
This commit is contained in:
@@ -644,14 +644,16 @@ class FusedMoE(nn.Layer):
|
||||
"""
|
||||
topk_ids_hookfunc = None
|
||||
if self.enable_routing_replay:
|
||||
if forward_meta is not None: # forward_meta is None when execute empty_input_forward
|
||||
# forward_meta is None when executing empty_input_forward; routing_replay_table is None when executing an MTP layer.
|
||||
if forward_meta is not None and forward_meta.routing_replay_table is not None:
|
||||
moe_layer_idx = self.layer_idx - self.fd_config.model_config.moe_layer_start_index
|
||||
topk_ids_hookfunc = partial(
|
||||
save_routing_to_buffer,
|
||||
routing_replay_table=forward_meta.routing_replay_table,
|
||||
batch_id_per_token=forward_meta.batch_id_per_token,
|
||||
seq_lens_decoder=forward_meta.seq_lens_decoder,
|
||||
cu_seqlens_q=forward_meta.cu_seqlens_q,
|
||||
layer_idx=self.layer_idx,
|
||||
layer_idx=moe_layer_idx,
|
||||
tp_size=self.fd_config.parallel_config.tensor_parallel_size,
|
||||
ep_size=self.fd_config.parallel_config.expert_parallel_size,
|
||||
tp_group=self.fd_config.parallel_config.tp_group,
|
||||
|
||||
@@ -26,6 +26,7 @@ import paddle
|
||||
import paddle.distributed as dist
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
@@ -110,6 +111,8 @@ def save_routing_to_buffer(
|
||||
):
|
||||
if tp_size > 1 and ep_size > 1:
|
||||
token_num_per_rank = topk_ids.shape[0]
|
||||
if token_num_per_rank == 0:
|
||||
return
|
||||
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
|
||||
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
|
||||
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
|
||||
@@ -152,6 +155,7 @@ class RoutingReplayManager:
|
||||
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
|
||||
self.max_model_len = fd_config.model_config.max_model_len
|
||||
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
|
||||
self.only_last_turn = fd_config.routing_replay_config.only_last_turn
|
||||
|
||||
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
self.moe_top_k = fd_config.model_config.num_experts_per_tok
|
||||
@@ -177,9 +181,10 @@ class RoutingReplayManager:
|
||||
# Save requests that have been finished for the current slot
|
||||
if batch_id in self.routing_batch_to_request:
|
||||
pre_request_id = self._deregister_request(batch_id)
|
||||
self._put_request_to_store(batch_id, pre_request_id)
|
||||
asyncio.run(self._put_request_to_store(batch_id, pre_request_id))
|
||||
# Register the new request
|
||||
self.routing_batch_to_request[batch_id] = request_id
|
||||
logger.info(f"[R3] Register request {request_id} with batch id {batch_id}")
|
||||
|
||||
def _deregister_request(self, batch_id: int) -> str:
|
||||
"""
|
||||
@@ -188,26 +193,35 @@ class RoutingReplayManager:
|
||||
assert batch_id in self.routing_batch_to_request
|
||||
return self.routing_batch_to_request.pop(batch_id)
|
||||
|
||||
def _put_request_to_store(
|
||||
async def _put_request_to_store(
|
||||
self,
|
||||
batch_id: int,
|
||||
request_id: str,
|
||||
):
|
||||
before_put_request_time = time.perf_counter()
|
||||
if self.tp_rank == 0:
|
||||
batch_buffer = self.routing_replay_table[batch_id]
|
||||
tasks = []
|
||||
for layer_id in range(self.num_moe_layers):
|
||||
layer_buffer = batch_buffer[layer_id]
|
||||
rollout_id = self.split_request_id(request_id)
|
||||
self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
|
||||
|
||||
tasks.append(
|
||||
self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
|
||||
)
|
||||
if self.only_last_turn:
|
||||
prefix_batch = self.get_needed_clear_ids(rollout_id)
|
||||
tasks.append(self.routing_store.clear_prefix_batch(roullout_id_prefixes=prefix_batch))
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(f"[R3] Async put {request_id} time cost: {time.perf_counter() - before_put_request_time}")
|
||||
self._clear_table_slot(batch_id)
|
||||
|
||||
def put_table_to_store(self):
|
||||
"""Put the routing table"""
|
||||
logger.info("[R3] Put routing table to store.")
|
||||
batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys()))
|
||||
for batch_id in batch_ids:
|
||||
request_id = self._deregister_request(batch_id)
|
||||
self._put_request_to_store(batch_id, request_id)
|
||||
asyncio.run(self._put_request_to_store(batch_id, request_id))
|
||||
|
||||
def _clear_table_slot(self, batch_id: int):
|
||||
assert 0 <= batch_id < self.max_num_seqs
|
||||
@@ -241,14 +255,39 @@ class RoutingReplayManager:
|
||||
return self.routing_replay_table
|
||||
|
||||
def split_request_id(self, request_id: str):
    """
    Derive the rollout id from a chat-completion request id.

    request_id: "chatcmpl-request.user-uuid"
    rollout_id: "request.user"
    example: "chatcmpl-xxx_xxx_epoch_15:2:2:1-d9f16c5c-65f6-4815-b44d-14e2c581907c_0" -> "xxx_xxx_epoch_15:2:2:1"

    Raises:
        ValueError: if request_id contains no "-" separator.
        AssertionError: if the part before the first "-" is not "chatcmpl".
    """
    chat_type, remainder = request_id.split("-", 1)
    # NOTE(gongshaotian): only support chatcmpl now
    assert (
        chat_type == "chatcmpl"
    ), "Rollout Routing Replay only supports chatcmpl. Please check whether the request type and userid settings are correct."
    # Split from the right at most five times: in the documented example this
    # peels off the five "-"-separated uuid groups, and whatever precedes them
    # (which may itself contain "-") is the rollout id.
    return remainder.rsplit("-", 5)[0]
|
||||
|
||||
def get_needed_clear_ids(self, roullout_id: str) -> List[str]:
    """
    Generate the prefix IDs for all closed multi-round tasks.

    rollout_id: "xxx_xxx_epoch_15:2:2:1"
    example: xxx_xxx_data_id:gen_id:turn_id:segment_id

    Returns a list holding the id of the previous turn
    ("<prefix>:<turn_id - 1>:<segment_id>") when turn_id > 0, otherwise an
    empty list.

    Raises:
        ValueError: if the id has fewer than two ":" separators, or if the
            turn / segment fields are not integer literals.
        AssertionError: if turn_id or segment_id is negative.
    """
    # The last two ":"-separated fields are turn_id and segment_id; everything
    # before them (which may itself contain ":") is kept verbatim as the prefix.
    prefix_gen_id, turn_text, segment_text = roullout_id.rsplit(":", 2)

    # Parse with int() rather than eval(): the id is derived from the request
    # id, so evaluating it as an expression would execute arbitrary code.
    turn_id = int(turn_text)
    segment_id = int(segment_text)

    assert turn_id >= 0 and segment_id >= 0
    prefix_batch = []
    if turn_id > 0:
        prefix_batch.append(f"{prefix_gen_id}:{(turn_id-1)}:{segment_id}")
    return prefix_batch
|
||||
|
||||
def clear_request(self, batch_id: int):
|
||||
"""Clear the routing indices of the request"""
|
||||
self._clear_table_slot(batch_id)
|
||||
@@ -262,7 +301,7 @@ class RoutingStoreBase(ABC):
|
||||
self.fd_config = fd_config
|
||||
|
||||
@abstractmethod
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
|
||||
async def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -283,6 +322,11 @@ class RoutingStoreBase(ABC):
|
||||
"""Clear the routing indices store"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
async def clear_prefix_batch(self, roullout_id_prefixes: List[str]):
    """
    Clear the routing indices selected by the given rollout id prefixes.

    Abstract coroutine: concrete stores must override it. The base
    implementation only raises NotImplementedError.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
class RoutingStoreLocal(RoutingStoreBase):
|
||||
"""Routing Store using local memory"""
|
||||
@@ -292,12 +336,17 @@ class RoutingStoreLocal(RoutingStoreBase):
|
||||
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
|
||||
self.clear_store()
|
||||
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
async def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
routing_key = f"{rollout_id}_{layer_idx}"
|
||||
|
||||
# async put
|
||||
time_before_put = time.perf_counter()
|
||||
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
|
||||
paddle.save(routing_indices, file_path)
|
||||
logger.info(f"[R3] The routing key {routing_key} put cost is {time.perf_counter()-time_before_put}s")
|
||||
|
||||
def get(
|
||||
self,
|
||||
@@ -334,6 +383,10 @@ class RoutingStoreLocal(RoutingStoreBase):
|
||||
file_path = os.path.join(self.local_store_dir, file_name)
|
||||
shutil.rmtree(file_path)
|
||||
|
||||
async def clear_prefix_batch(self, roullout_id_prefixes: List[str]):
    """
    Handle a prefix-clear request for the local store.

    NOTE(review): this method only logs the requested prefixes; nothing is
    removed from the local store directory here.
    """
    # async delete
    message = f"[R3] clear_prefix_batch {roullout_id_prefixes}"
    logger.info(message)
|
||||
|
||||
|
||||
class RoutingStoreRDMA(RoutingStoreBase):
|
||||
"""Routing Store using RDMA"""
|
||||
@@ -351,16 +404,19 @@ class RoutingStoreRDMA(RoutingStoreBase):
|
||||
self.p2p_client = P2PClient(p2pConfig)
|
||||
self.clear_store()
|
||||
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
async def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
|
||||
|
||||
# async put
|
||||
time_before_put = time.perf_counter()
|
||||
routing_indices_pin = routing_indices.pin_memory()
|
||||
routing_indices_pin = routing_indices.cpu()
|
||||
routing_indices_np = routing_indices_pin.numpy()
|
||||
asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
|
||||
print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s")
|
||||
copy_time = time.perf_counter()
|
||||
await self.p2p_client.put(rdma_rollout_key, routing_indices_np)
|
||||
logger.info(
|
||||
f"[R3] The routing key {rdma_rollout_key} copy cost is {copy_time-time_before_put}s, put cost is {time.perf_counter()-time_before_put}s"
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
@@ -383,6 +439,11 @@ class RoutingStoreRDMA(RoutingStoreBase):
|
||||
# sync delete
|
||||
asyncio.run(self.p2p_client.delete(rdma_rollout_key))
|
||||
|
||||
async def clear_prefix_batch(self, roullout_id_prefixes: List[str]):
    """
    Delete stored routing indices by rollout id prefix.

    Awaits the P2P client's delete_prefix_batch call with the given
    prefixes, then logs them.
    """
    # async delete
    client = self.p2p_client
    await client.delete_prefix_batch(roullout_id_prefixes)
    message = f"[R3] clear_prefix_batch {roullout_id_prefixes}"
    logger.info(message)
|
||||
|
||||
def clear_store(self):
|
||||
"""Clear the routing indices store"""
|
||||
# sync clear routing store
|
||||
|
||||
Reference in New Issue
Block a user