mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL][Cherry-Pick] Fix the out-of-bounds issue caused by int32 in the R3 kernel (#7155)
* [RL]Perf: Optimize batch delete prefix and fused put in R3 (#6604) * Optimizate delete batch and fused put * refine code * refine code * refine code * Support suspend r3 * [RL] Fix R3 Empty bug with TP=1 (#6777) * Fix int32 overflow * refine code --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -1764,6 +1764,9 @@ class RouterConfig:
|
|||||||
else:
|
else:
|
||||||
self.metrics_port = self.api_server_port
|
self.metrics_port = self.api_server_port
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||||
|
|
||||||
|
|
||||||
class CommitConfig:
|
class CommitConfig:
|
||||||
"""
|
"""
|
||||||
@@ -1877,6 +1880,9 @@ class RoutingReplayConfig:
|
|||||||
"""
|
"""
|
||||||
return json.dumps({key: value for key, value in self.__dict__.items()})
|
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.to_json_string()
|
||||||
|
|
||||||
|
|
||||||
class FDConfig:
|
class FDConfig:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -253,6 +253,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
|
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
|
||||||
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
|
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
|
||||||
),
|
),
|
||||||
|
# Suspend rollouting routing replay
|
||||||
|
"FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))),
|
||||||
# train-infer consistency, used in RL
|
# train-infer consistency, used in RL
|
||||||
# Whether to align RoPE and moe gate precision with training
|
# Whether to align RoPE and moe gate precision with training
|
||||||
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
|
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ def _save_routing_kernel(
|
|||||||
TOP_K,
|
TOP_K,
|
||||||
NUM_HIDDEN_LAYERS,
|
NUM_HIDDEN_LAYERS,
|
||||||
MAX_MODEL_LEN,
|
MAX_MODEL_LEN,
|
||||||
|
MAX_NUM_SEQS,
|
||||||
BLOCK_SIZE_M: tl.constexpr,
|
BLOCK_SIZE_M: tl.constexpr,
|
||||||
BLOCK_SIZE_K: tl.constexpr,
|
BLOCK_SIZE_K: tl.constexpr,
|
||||||
):
|
):
|
||||||
@@ -63,45 +64,37 @@ def _save_routing_kernel(
|
|||||||
token_mask = token_offsets < TOKEN_NUM
|
token_mask = token_offsets < TOKEN_NUM
|
||||||
|
|
||||||
k_offsets = tl.arange(0, BLOCK_SIZE_K)
|
k_offsets = tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
|
||||||
k_mask = k_offsets < TOP_K
|
k_mask = k_offsets < TOP_K
|
||||||
|
|
||||||
topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
|
topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
|
||||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
|
||||||
|
|
||||||
load_mask = token_mask[:, None] & k_mask[None, :]
|
load_mask = token_mask[:, None] & k_mask[None, :]
|
||||||
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask)
|
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1)
|
||||||
|
|
||||||
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
|
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1)
|
||||||
pad_mask = token_mask & (batch_ids != -1)
|
|
||||||
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
|
batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS)
|
||||||
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
|
pad_mask = token_mask & (batch_ids != -1) & batch_mask
|
||||||
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
|
|
||||||
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
|
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0)
|
||||||
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
|
|
||||||
token_relative_index = token_offsets - start_offsets
|
token_relative_index = token_offsets - start_offsets
|
||||||
|
|
||||||
# [BLOCK_SIZE_M]
|
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0)
|
||||||
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
|
|
||||||
token_seq_pos = len_decoder + token_relative_index
|
token_seq_pos = len_decoder + token_relative_index
|
||||||
|
|
||||||
STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K
|
STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
|
||||||
STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K
|
STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
|
||||||
STRIDE_BUF_LAYER = TOP_K
|
STRIDE_BUF_LAYER = TOP_K
|
||||||
|
|
||||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
|
||||||
output_ptrs = (
|
output_ptrs = (
|
||||||
ROUTING_REPLAY_TABLE_PTR
|
ROUTING_REPLAY_TABLE_PTR
|
||||||
+ batch_ids[:, None] * STRIDE_BUF_SEQ
|
+ tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ
|
||||||
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN
|
+ tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN
|
||||||
+ LAYER_IDX * STRIDE_BUF_LAYER
|
+ tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER
|
||||||
+ k_offsets[None, :]
|
+ k_offsets[None, :]
|
||||||
)
|
)
|
||||||
|
|
||||||
pos_mask = token_seq_pos < MAX_MODEL_LEN
|
pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN)
|
||||||
pos_mask = pos_mask & pad_mask
|
pos_mask = pos_mask & pad_mask
|
||||||
|
|
||||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
|
||||||
pos_mask = pos_mask[:, None] & k_mask[None, :]
|
pos_mask = pos_mask[:, None] & k_mask[None, :]
|
||||||
|
|
||||||
final_mask = load_mask & pos_mask
|
final_mask = load_mask & pos_mask
|
||||||
@@ -120,10 +113,10 @@ def save_routing_to_buffer(
|
|||||||
ep_size: int,
|
ep_size: int,
|
||||||
tp_group: dist.communication.group.Group,
|
tp_group: dist.communication.group.Group,
|
||||||
):
|
):
|
||||||
if tp_size > 1 and ep_size > 1:
|
|
||||||
token_num_per_rank = topk_ids.shape[0]
|
token_num_per_rank = topk_ids.shape[0]
|
||||||
if token_num_per_rank == 0:
|
if token_num_per_rank == 0:
|
||||||
return
|
return
|
||||||
|
if tp_size > 1 and ep_size > 1:
|
||||||
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
|
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
|
||||||
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
|
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
|
||||||
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
|
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
|
||||||
@@ -150,6 +143,7 @@ def save_routing_to_buffer(
|
|||||||
TOP_K=top_k,
|
TOP_K=top_k,
|
||||||
NUM_HIDDEN_LAYERS=num_hidden_layers,
|
NUM_HIDDEN_LAYERS=num_hidden_layers,
|
||||||
MAX_MODEL_LEN=max_model_len,
|
MAX_MODEL_LEN=max_model_len,
|
||||||
|
MAX_NUM_SEQS=max_num_seqs,
|
||||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||||
)
|
)
|
||||||
@@ -166,6 +160,7 @@ class RoutingReplayManager:
|
|||||||
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
|
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
|
||||||
self.only_last_turn = fd_config.routing_replay_config.only_last_turn
|
self.only_last_turn = fd_config.routing_replay_config.only_last_turn
|
||||||
self.use_fused_put = fd_config.routing_replay_config.use_fused_put
|
self.use_fused_put = fd_config.routing_replay_config.use_fused_put
|
||||||
|
logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}")
|
||||||
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
|
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||||
self.moe_top_k = fd_config.model_config.num_experts_per_tok
|
self.moe_top_k = fd_config.model_config.num_experts_per_tok
|
||||||
else:
|
else:
|
||||||
@@ -186,6 +181,17 @@ class RoutingReplayManager:
|
|||||||
)
|
)
|
||||||
self._store_wrapper.start_store_warpper()
|
self._store_wrapper.start_store_warpper()
|
||||||
|
|
||||||
|
# Suspend Routing Replay
|
||||||
|
self.suspend_routing_replay = False
|
||||||
|
self.update_suspend_routing_replay()
|
||||||
|
|
||||||
|
def update_suspend_routing_replay(self):
|
||||||
|
"""Allow RL to use R3 in different training rounds"""
|
||||||
|
# TODO(gongshaotian): Delete this func
|
||||||
|
suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0")
|
||||||
|
self.suspend_routing_replay = bool(int(suspend_routing_replay))
|
||||||
|
logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}")
|
||||||
|
|
||||||
def _init_routing_cache(self, dtype: str, total_block_num: int):
|
def _init_routing_cache(self, dtype: str, total_block_num: int):
|
||||||
"""Initialize the device buffer and host buffer."""
|
"""Initialize the device buffer and host buffer."""
|
||||||
|
|
||||||
@@ -341,6 +347,11 @@ class RoutingReplayManager:
|
|||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
):
|
):
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
|
# TODO(gongshaotian): Delete the suspend func
|
||||||
|
if self.suspend_routing_replay:
|
||||||
|
logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store")
|
||||||
|
return
|
||||||
|
|
||||||
before_put_request_time = time.perf_counter()
|
before_put_request_time = time.perf_counter()
|
||||||
|
|
||||||
# Collect the routing of finished request
|
# Collect the routing of finished request
|
||||||
@@ -351,16 +362,19 @@ class RoutingReplayManager:
|
|||||||
|
|
||||||
if self.use_fused_put:
|
if self.use_fused_put:
|
||||||
self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id)
|
self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id)
|
||||||
|
# Only store the routing of last turn
|
||||||
|
if self.only_last_turn:
|
||||||
|
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for layer_id in range(self.num_moe_layers):
|
for layer_id in range(self.num_moe_layers):
|
||||||
layer_buffer = batch_buffer[layer_id]
|
layer_buffer = batch_buffer[layer_id]
|
||||||
self._store_wrapper.submit_put_task(
|
self._store_wrapper.submit_put_task(
|
||||||
routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id
|
routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only store the routing of last turn
|
# Only store the routing of last turn
|
||||||
if self.only_last_turn:
|
if self.only_last_turn:
|
||||||
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id)
|
self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id)
|
||||||
|
|
||||||
logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}")
|
logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}")
|
||||||
|
|
||||||
@@ -481,7 +495,6 @@ class StoreWrapper(object):
|
|||||||
if qsize > self.queue_max_size * 0.8:
|
if qsize > self.queue_max_size * 0.8:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. "
|
f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. "
|
||||||
f"Dropped tasks so far: {self._dropped_tasks}. "
|
|
||||||
"Consider increasing max_workers or queue_max_size."
|
"Consider increasing max_workers or queue_max_size."
|
||||||
)
|
)
|
||||||
logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}")
|
logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}")
|
||||||
@@ -523,22 +536,26 @@ class StoreWrapper(object):
|
|||||||
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
|
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
|
||||||
logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s")
|
logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s")
|
||||||
|
|
||||||
def submit_clear_prefix_batch_task(self, rollout_id) -> None:
|
def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None:
|
||||||
"""Submit clear prefix batch task"""
|
"""Submit clear prefix batch task"""
|
||||||
if not self._sotre_process_running:
|
if not self._sotre_process_running:
|
||||||
raise RuntimeError("Store not started.")
|
raise RuntimeError("Store not started.")
|
||||||
prefix_batch = self.get_needed_clear_ids(rollout_id)
|
prefix_batch_id = self.get_needed_clear_ids(rollout_id)
|
||||||
|
if prefix_batch_id is None:
|
||||||
if prefix_batch is None:
|
|
||||||
return
|
return
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
task: StoreTask = {"task_type": "clear_prefix_batch", "key": prefix_batch, "data": None}
|
if layer_idx is not None:
|
||||||
|
rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}"
|
||||||
|
else:
|
||||||
|
rdma_rollout_key = prefix_batch_id
|
||||||
|
|
||||||
|
task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None}
|
||||||
try:
|
try:
|
||||||
self._task_queue.put_nowait(task)
|
self._task_queue.put_nowait(task)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
|
raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ")
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[R3] Submit clear prefix batch task for key: {prefix_batch}, cost time: {time.perf_counter()-start_time} s"
|
f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]:
|
def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]:
|
||||||
@@ -615,7 +632,7 @@ class StoreProcess(Process):
|
|||||||
self._task_queue.task_done()
|
self._task_queue.task_done()
|
||||||
raise RuntimeError(f"Error during processing task. {e}")
|
raise RuntimeError(f"Error during processing task. {e}")
|
||||||
|
|
||||||
logger.info(f"[Consumer Process {Process.current_process().pid}] Shutdown.")
|
logger.info("RoutingReplay Consumer Process Shutdown.")
|
||||||
|
|
||||||
def process_put_task(self, store_task: StoreTask) -> None:
|
def process_put_task(self, store_task: StoreTask) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -838,13 +855,18 @@ class RoutingStoreRDMA(RoutingStoreBase):
|
|||||||
async def put(self, routing_key: str, routing_indices: np.ndarray) -> None:
|
async def put(self, routing_key: str, routing_indices: np.ndarray) -> None:
|
||||||
"""Put the routing indices into store"""
|
"""Put the routing indices into store"""
|
||||||
time_before_put = time.perf_counter()
|
time_before_put = time.perf_counter()
|
||||||
|
if len(routing_indices.shape) == 3:
|
||||||
|
# NOTE(gongshaotian) Fused put with bytes data
|
||||||
|
routing_bytes = routing_indices.tobytes()
|
||||||
|
result = await self.p2p_client.put(routing_key, routing_bytes)
|
||||||
|
else:
|
||||||
result = await self.p2p_client.put(routing_key, routing_indices)
|
result = await self.p2p_client.put(routing_key, routing_indices)
|
||||||
logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s")
|
logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def clear_prefix_batch(self, routing_prefix_key: str):
|
async def clear_prefix_batch(self, routing_prefix_key: str):
|
||||||
time_before_clear = time.perf_counter()
|
time_before_clear = time.perf_counter()
|
||||||
result = await self.p2p_client.delete_prefix_batch([routing_prefix_key])
|
result = await self.p2p_client.delete_batch([routing_prefix_key])
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s"
|
f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2857,9 +2857,13 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
# Recapture CUDAGraph
|
# Recapture CUDAGraph
|
||||||
if self.use_cudagraph:
|
if self.use_cudagraph:
|
||||||
self.capture_model()
|
self.capture_model()
|
||||||
|
# Rollout Routing Replay
|
||||||
|
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||||
|
# TODO(gongshaotian): Delete suspend func
|
||||||
|
self.routing_replay_manager.update_suspend_routing_replay()
|
||||||
|
|
||||||
# Send single
|
# Send single
|
||||||
self.dynamic_weight_manager.finalize_update(pid)
|
self.dynamic_weight_manager.finalize_update(pid)
|
||||||
|
|
||||||
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
|
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
|
||||||
|
|
||||||
def update_weights(self, version: str = None, verify_checksum: bool = False):
|
def update_weights(self, version: str = None, verify_checksum: bool = False):
|
||||||
|
|||||||
Reference in New Issue
Block a user