diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 480816adc8..25d6e4a35d 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1764,6 +1764,9 @@ class RouterConfig: else: self.metrics_port = self.api_server_port + def __str__(self): + return json.dumps({key: value for key, value in self.__dict__.items()}) + class CommitConfig: """ @@ -1877,6 +1880,9 @@ class RoutingReplayConfig: """ return json.dumps({key: value for key, value in self.__dict__.items()}) + def __str__(self): + return self.to_json_string() + class FDConfig: """ diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index d7feef972e..c6ca3d4549 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -253,6 +253,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), + # Suspend rollout routing replay + "FD_SUSPEND_ROUTING_REPLAY": lambda: bool(int(os.getenv("FD_SUSPEND_ROUTING_REPLAY", "0"))), # train-infer consistency, used in RL # Whether to align RoPE and moe gate precision with training "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index b27957bf0c..7bf526feb4 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -54,6 +54,7 @@ def _save_routing_kernel( TOP_K, NUM_HIDDEN_LAYERS, MAX_MODEL_LEN, + MAX_NUM_SEQS, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): @@ -63,45 +64,37 @@ def _save_routing_kernel( token_mask = token_offsets < TOKEN_NUM k_offsets = tl.arange(0, BLOCK_SIZE_K) - k_mask = k_offsets < TOP_K topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] - # [BLOCK_SIZE_M, BLOCK_SIZE_K] - load_mask = token_mask[:, None] & k_mask[None, :] - topk_vals = 
tl.load(topk_ids_ptrs, mask=load_mask) + topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1) - batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask) - pad_mask = token_mask & (batch_ids != -1) - # [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3] - # -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] - # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] - # -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1] - start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask) + batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1) + + batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS) + pad_mask = token_mask & (batch_ids != -1) & batch_mask + + start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0) token_relative_index = token_offsets - start_offsets - # [BLOCK_SIZE_M] - len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask) + len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0) token_seq_pos = len_decoder + token_relative_index - STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K - STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K + STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64) + STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64) STRIDE_BUF_LAYER = TOP_K - # [BLOCK_SIZE_M, BLOCK_SIZE_K] output_ptrs = ( ROUTING_REPLAY_TABLE_PTR - + batch_ids[:, None] * STRIDE_BUF_SEQ - + token_seq_pos[:, None] * STRIDE_BUF_TOKEN - + LAYER_IDX * STRIDE_BUF_LAYER + + tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ + + tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN + + tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER + k_offsets[None, :] ) - pos_mask = token_seq_pos < MAX_MODEL_LEN + pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN) pos_mask = pos_mask & pad_mask - - # [BLOCK_SIZE_M, BLOCK_SIZE_K] pos_mask = pos_mask[:, None] & k_mask[None, :] 
final_mask = load_mask & pos_mask @@ -120,10 +113,10 @@ def save_routing_to_buffer( ep_size: int, tp_group: dist.communication.group.Group, ): + token_num_per_rank = topk_ids.shape[0] + if token_num_per_rank == 0: + return if tp_size > 1 and ep_size > 1: - token_num_per_rank = topk_ids.shape[0] - if token_num_per_rank == 0: - return topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :] @@ -150,6 +143,7 @@ def save_routing_to_buffer( TOP_K=top_k, NUM_HIDDEN_LAYERS=num_hidden_layers, MAX_MODEL_LEN=max_model_len, + MAX_NUM_SEQS=max_num_seqs, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_K=BLOCK_SIZE_K, ) @@ -166,6 +160,7 @@ class RoutingReplayManager: self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index self.only_last_turn = fd_config.routing_replay_config.only_last_turn self.use_fused_put = fd_config.routing_replay_config.use_fused_put + logger.info(f"[R3] Rollout Routing Replay Congfig: {fd_config.routing_replay_config}") if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM": self.moe_top_k = fd_config.model_config.num_experts_per_tok else: @@ -186,6 +181,17 @@ class RoutingReplayManager: ) self._store_wrapper.start_store_warpper() + # Suspend Routing Replay + self.suspend_routing_replay = False + self.update_suspend_routing_replay() + + def update_suspend_routing_replay(self): + """Allow RL to use R3 in different training rounds""" + # TODO(gongshaotian): Delete this func + suspend_routing_replay = os.environ.get("FD_SUSPEND_ROUTING_REPLAY", "0") + self.suspend_routing_replay = bool(int(suspend_routing_replay)) + logger.info(f"[R3] Update FD_SUSPEND_ROUTING_REPLAY: {self.suspend_routing_replay}") + def _init_routing_cache(self, dtype: str, total_block_num: int): """Initialize the device buffer and host buffer.""" @@ -341,6 +347,11 @@ 
class RoutingReplayManager: seq_lens_decoder, ): if self.tp_rank == 0: + # TODO(gongshaotian): Delete the suspend func + if self.suspend_routing_replay: + logger.info(f"[R3] Suspend Routing Replay is enabled, skip putting request {request_id} to store") + return + before_put_request_time = time.perf_counter() # Collect the routing of finished request @@ -351,16 +362,19 @@ class RoutingReplayManager: if self.use_fused_put: self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id) + # Only store the routing of last turn + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) + else: for layer_id in range(self.num_moe_layers): layer_buffer = batch_buffer[layer_id] self._store_wrapper.submit_put_task( routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id ) - - # Only store the routing of last turn - if self.only_last_turn: - self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) + # Only store the routing of last turn + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id) logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}") @@ -481,7 +495,6 @@ class StoreWrapper(object): if qsize > self.queue_max_size * 0.8: logger.warning( f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " - f"Dropped tasks so far: {self._dropped_tasks}. " "Consider increasing max_workers or queue_max_size." ) logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") @@ -523,22 +536,26 @@ class StoreWrapper(object): raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. 
") logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") - def submit_clear_prefix_batch_task(self, rollout_id) -> None: + def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None: """Submit clear prefix batch task""" if not self._sotre_process_running: raise RuntimeError("Store not started.") - prefix_batch = self.get_needed_clear_ids(rollout_id) - - if prefix_batch is None: + prefix_batch_id = self.get_needed_clear_ids(rollout_id) + if prefix_batch_id is None: return start_time = time.perf_counter() - task: StoreTask = {"task_type": "clear_prefix_batch", "key": prefix_batch, "data": None} + if layer_idx is not None: + rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}" + else: + rdma_rollout_key = prefix_batch_id + + task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None} try: self._task_queue.put_nowait(task) except Exception: raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") logger.info( - f"[R3] Submit clear prefix batch task for key: {prefix_batch}, cost time: {time.perf_counter()-start_time} s" + f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s" ) def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]: @@ -615,7 +632,7 @@ class StoreProcess(Process): self._task_queue.task_done() raise RuntimeError(f"Error during processing task. 
{e}") - logger.info(f"[Consumer Process {Process.current_process().pid}] Shutdown.") + logger.info("RoutingReplay Consumer Process Shutdown.") def process_put_task(self, store_task: StoreTask) -> None: try: @@ -838,13 +855,18 @@ class RoutingStoreRDMA(RoutingStoreBase): async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: """Put the routing indices into store""" time_before_put = time.perf_counter() - result = await self.p2p_client.put(routing_key, routing_indices) + if len(routing_indices.shape) == 3: + # NOTE(gongshaotian) Fused put with bytes data + routing_bytes = routing_indices.tobytes() + result = await self.p2p_client.put(routing_key, routing_bytes) + else: + result = await self.p2p_client.put(routing_key, routing_indices) logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") return result async def clear_prefix_batch(self, routing_prefix_key: str): time_before_clear = time.perf_counter() - result = await self.p2p_client.delete_prefix_batch([routing_prefix_key]) + result = await self.p2p_client.delete_batch([routing_prefix_key]) logger.info( f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" ) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 473896eec0..95e4e6c75e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2857,9 +2857,13 @@ class GPUModelRunner(ModelRunnerBase): # Recapture CUDAGraph if self.use_cudagraph: self.capture_model() + # Rollout Routing Replay + if self.fd_config.routing_replay_config.enable_routing_replay: + # TODO(gongshaotian): Delete suspend func + self.routing_replay_manager.update_suspend_routing_replay() + # Send single self.dynamic_weight_manager.finalize_update(pid) - self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") def update_weights(self, version: str = None, 
verify_checksum: bool = False):