[RL][Cherry-Pick] Fix the out-of-bounds issue caused by int32 in the R3 kernel (#7155)

* [RL]Perf: Optimize batch delete prefix and fused put in R3 (#6604)

* Optimize delete batch and fused put

* refine code

* refine code

* refine code

* Support suspend r3

* [RL] Fix R3 Empty bug with TP=1 (#6777)

* Fix int32 overflow

* refine code

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
RAM
2026-04-21 16:51:09 +08:00
committed by GitHub
parent 3c7ca62dc3
commit d8cdda86cb
4 changed files with 74 additions and 40 deletions
+6
View File
@@ -1764,6 +1764,9 @@ class RouterConfig:
else: else:
self.metrics_port = self.api_server_port self.metrics_port = self.api_server_port
def __str__(self):
return json.dumps({key: value for key, value in self.__dict__.items()})
class CommitConfig: class CommitConfig:
""" """
@@ -1877,6 +1880,9 @@ class RoutingReplayConfig:
""" """
return json.dumps({key: value for key, value in self.__dict__.items()}) return json.dumps({key: value for key, value in self.__dict__.items()})
def __str__(self):
return self.to_json_string()
class FDConfig: class FDConfig:
""" """
+2
View File
@@ -253,6 +253,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
), ),
# Suspend rollouting routing replay
"FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))),
# train-infer consistency, used in RL # train-infer consistency, used in RL
# Whether to align RoPE and moe gate precision with training # Whether to align RoPE and moe gate precision with training
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
@@ -54,6 +54,7 @@ def _save_routing_kernel(
TOP_K, TOP_K,
NUM_HIDDEN_LAYERS, NUM_HIDDEN_LAYERS,
MAX_MODEL_LEN, MAX_MODEL_LEN,
MAX_NUM_SEQS,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
): ):
@@ -63,45 +64,37 @@ def _save_routing_kernel(
token_mask = token_offsets < TOKEN_NUM token_mask = token_offsets < TOKEN_NUM
k_offsets = tl.arange(0, BLOCK_SIZE_K) k_offsets = tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < TOP_K k_mask = k_offsets < TOP_K
topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
load_mask = token_mask[:, None] & k_mask[None, :] load_mask = token_mask[:, None] & k_mask[None, :]
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask) topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1)
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask) batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1)
pad_mask = token_mask & (batch_ids != -1)
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3] batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS)
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] pad_mask = token_mask & (batch_ids != -1) & batch_mask
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1] start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0)
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
token_relative_index = token_offsets - start_offsets token_relative_index = token_offsets - start_offsets
# [BLOCK_SIZE_M] len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0)
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
token_seq_pos = len_decoder + token_relative_index token_seq_pos = len_decoder + token_relative_index
STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
STRIDE_BUF_LAYER = TOP_K STRIDE_BUF_LAYER = TOP_K
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
output_ptrs = ( output_ptrs = (
ROUTING_REPLAY_TABLE_PTR ROUTING_REPLAY_TABLE_PTR
+ batch_ids[:, None] * STRIDE_BUF_SEQ + tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN + tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN
+ LAYER_IDX * STRIDE_BUF_LAYER + tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER
+ k_offsets[None, :] + k_offsets[None, :]
) )
pos_mask = token_seq_pos < MAX_MODEL_LEN pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN)
pos_mask = pos_mask & pad_mask pos_mask = pos_mask & pad_mask
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
pos_mask = pos_mask[:, None] & k_mask[None, :] pos_mask = pos_mask[:, None] & k_mask[None, :]
final_mask = load_mask & pos_mask final_mask = load_mask & pos_mask
@@ -120,10 +113,10 @@ def save_routing_to_buffer(
ep_size: int, ep_size: int,
tp_group: dist.communication.group.Group, tp_group: dist.communication.group.Group,
): ):
if tp_size > 1 and ep_size > 1:
token_num_per_rank = topk_ids.shape[0] token_num_per_rank = topk_ids.shape[0]
if token_num_per_rank == 0: if token_num_per_rank == 0:
return return
if tp_size > 1 and ep_size > 1:
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :] topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
@@ -150,6 +143,7 @@ def save_routing_to_buffer(
TOP_K=top_k, TOP_K=top_k,
NUM_HIDDEN_LAYERS=num_hidden_layers, NUM_HIDDEN_LAYERS=num_hidden_layers,
MAX_MODEL_LEN=max_model_len, MAX_MODEL_LEN=max_model_len,
MAX_NUM_SEQS=max_num_seqs,
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_K=BLOCK_SIZE_K,
) )
@@ -166,6 +160,7 @@ class RoutingReplayManager:
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
self.only_last_turn = fd_config.routing_replay_config.only_last_turn self.only_last_turn = fd_config.routing_replay_config.only_last_turn
self.use_fused_put = fd_config.routing_replay_config.use_fused_put self.use_fused_put = fd_config.routing_replay_config.use_fused_put
logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}")
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM": if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
self.moe_top_k = fd_config.model_config.num_experts_per_tok self.moe_top_k = fd_config.model_config.num_experts_per_tok
else: else:
@@ -186,6 +181,17 @@ class RoutingReplayManager:
) )
self._store_wrapper.start_store_warpper() self._store_wrapper.start_store_warpper()
# Suspend Routing Replay
self.suspend_routing_replay = False
self.update_suspend_routing_replay()
def update_suspend_routing_replay(self):
"""Allow RL to use R3 in different training rounds"""
# TODO(gongshaotian): Delete this func
suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0")
self.suspend_routing_replay = bool(int(suspend_routing_replay))
logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}")
def _init_routing_cache(self, dtype: str, total_block_num: int): def _init_routing_cache(self, dtype: str, total_block_num: int):
"""Initialize the device buffer and host buffer.""" """Initialize the device buffer and host buffer."""
@@ -341,6 +347,11 @@ class RoutingReplayManager:
seq_lens_decoder, seq_lens_decoder,
): ):
if self.tp_rank == 0: if self.tp_rank == 0:
# TODO(gongshaotian): Delete the suspend func
if self.suspend_routing_replay:
logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store")
return
before_put_request_time = time.perf_counter() before_put_request_time = time.perf_counter()
# Collect the routing of finished request # Collect the routing of finished request
@@ -351,16 +362,19 @@ class RoutingReplayManager:
if self.use_fused_put: if self.use_fused_put:
self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id) self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id)
# Only store the routing of last turn
if self.only_last_turn:
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id)
else: else:
for layer_id in range(self.num_moe_layers): for layer_id in range(self.num_moe_layers):
layer_buffer = batch_buffer[layer_id] layer_buffer = batch_buffer[layer_id]
self._store_wrapper.submit_put_task( self._store_wrapper.submit_put_task(
routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id
) )
# Only store the routing of last turn # Only store the routing of last turn
if self.only_last_turn: if self.only_last_turn:
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id)
logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}") logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}")
@@ -481,7 +495,6 @@ class StoreWrapper(object):
if qsize > self.queue_max_size * 0.8: if qsize > self.queue_max_size * 0.8:
logger.warning( logger.warning(
f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. "
f"Dropped tasks so far: {self._dropped_tasks}. "
"Consider increasing max_workers or queue_max_size." "Consider increasing max_workers or queue_max_size."
) )
logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}")
@@ -523,22 +536,26 @@ class StoreWrapper(object):
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s")
def submit_clear_prefix_batch_task(self, rollout_id) -> None: def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None:
"""Submit clear prefix batch task""" """Submit clear prefix batch task"""
if not self._sotre_process_running: if not self._sotre_process_running:
raise RuntimeError("Store not started.") raise RuntimeError("Store not started.")
prefix_batch = self.get_needed_clear_ids(rollout_id) prefix_batch_id = self.get_needed_clear_ids(rollout_id)
if prefix_batch_id is None:
if prefix_batch is None:
return return
start_time = time.perf_counter() start_time = time.perf_counter()
task: StoreTask = {"task_type": "clear_prefix_batch", "key": prefix_batch, "data": None} if layer_idx is not None:
rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}"
else:
rdma_rollout_key = prefix_batch_id
task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None}
try: try:
self._task_queue.put_nowait(task) self._task_queue.put_nowait(task)
except Exception: except Exception:
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
logger.info( logger.info(
f"[R3] Submit clear prefix batch task for key: {prefix_batch}, cost time: {time.perf_counter()-start_time} s" f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s"
) )
def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]: def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]:
@@ -615,7 +632,7 @@ class StoreProcess(Process):
self._task_queue.task_done() self._task_queue.task_done()
raise RuntimeError(f"Error during processing task. {e}") raise RuntimeError(f"Error during processing task. {e}")
logger.info(f"[Consumer Process {Process.current_process().pid}] Shutdown.") logger.info("RoutingReplay Consumer Process Shutdown.")
def process_put_task(self, store_task: StoreTask) -> None: def process_put_task(self, store_task: StoreTask) -> None:
try: try:
@@ -838,13 +855,18 @@ class RoutingStoreRDMA(RoutingStoreBase):
async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: async def put(self, routing_key: str, routing_indices: np.ndarray) -> None:
"""Put the routing indices into store""" """Put the routing indices into store"""
time_before_put = time.perf_counter() time_before_put = time.perf_counter()
if len(routing_indices.shape) == 3:
# NOTE(gongshaotian) Fused put with bytes data
routing_bytes = routing_indices.tobytes()
result = await self.p2p_client.put(routing_key, routing_bytes)
else:
result = await self.p2p_client.put(routing_key, routing_indices) result = await self.p2p_client.put(routing_key, routing_indices)
logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s")
return result return result
async def clear_prefix_batch(self, routing_prefix_key: str): async def clear_prefix_batch(self, routing_prefix_key: str):
time_before_clear = time.perf_counter() time_before_clear = time.perf_counter()
result = await self.p2p_client.delete_prefix_batch([routing_prefix_key]) result = await self.p2p_client.delete_batch([routing_prefix_key])
logger.info( logger.info(
f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s"
) )
+5 -1
View File
@@ -2857,9 +2857,13 @@ class GPUModelRunner(ModelRunnerBase):
# Recapture CUDAGraph # Recapture CUDAGraph
if self.use_cudagraph: if self.use_cudagraph:
self.capture_model() self.capture_model()
# Rollout Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
# TODO(gongshaotian): Delete suspend func
self.routing_replay_manager.update_suspend_routing_replay()
# Send single # Send single
self.dynamic_weight_manager.finalize_update(pid) self.dynamic_weight_manager.finalize_update(pid)
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
def update_weights(self, version: str = None, verify_checksum: bool = False): def update_weights(self, version: str = None, verify_checksum: bool = False):