[Feature] support v1 update/clear api for RL (#6761)

* [Feature] support v1 update/clear api for RL

* [fix] fix execute_model and add sleep/wakeup api

* [fix] fix mtp and key_prefix

* [chore] move _update_key_prefix to resume method

* [fix] make the interface safe to call multiple times

* [fix] fix some tiny bugs

* [chore] make small changes against pr review

* [docs] add docs for weight update

* [test] add some tests and update docs

* [style] fix code style check

* [test] fix ci

* [fix] fix stale control responses when control method timed out

* [chore] remove unused code

* [chore] fix code style

* [chore] optimize tags and key_prefix

* [test] fix ci

* [chore] fix code style

* [test] fix ci

* [fix] fix ep control

* [fix] fix ep control for engine cache queue
This commit is contained in:
Yonghua Li
2026-03-25 19:18:46 +08:00
committed by GitHub
parent 48cfb608aa
commit a7f52c300d
26 changed files with 1857 additions and 392 deletions
+85
View File
@@ -54,6 +54,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.platforms import current_platform
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.utils import print_gpu_memory_use
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
if current_platform.is_iluvatar():
@@ -142,6 +143,9 @@ class GPUModelRunner(ModelRunnerBase):
self.cache_kvs_map: dict = {}
self.exist_prefill_flag = False
self.is_kvcache_sleeping = False
self.is_weight_sleeping = False
if self.speculative_decoding:
self._real_output_token_num_host = paddle.empty([1], dtype="int32").pin_memory()
self.output_token_num_event = paddle.device.cuda.Event()
@@ -288,6 +292,10 @@ class GPUModelRunner(ModelRunnerBase):
"""
return self.exist_prefill_flag
@property
def is_sleeping(self):
    """Whether any GPU memory (model weights or KV cache) is currently offloaded."""
    sleep_flags = (self.is_weight_sleeping, self.is_kvcache_sleeping)
    return any(sleep_flags)
def exist_decode(self):
"""
check whether decode stage exist
@@ -2673,6 +2681,83 @@ class GPUModelRunner(ModelRunnerBase):
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
    """Delegate an RDMA-based weight update to the dynamic weight manager.

    Args:
        version: Target weight version identifier (passed through unchanged).
        rsync_config: Transfer configuration forwarded to the manager.

    Returns:
        Whatever ``dynamic_weight_manager.update_weights_by_rdma`` returns.
    """
    manager = self.dynamic_weight_manager
    return manager.update_weights_by_rdma(version, rsync_config)
def sleep(self, tags):
    """Offload GPU memory selected by *tags* (comma-separated: "weight", "kv_cache").

    Each tag is handled independently and the call is idempotent: a tag whose
    memory is already offloaded is skipped with a log message. Previously an
    early ``return`` on an already-sleeping tag aborted the whole call, so
    ``sleep("weight,kv_cache")`` with weights already asleep never offloaded
    the KV cache and never ran ``empty_cache()`` or the finish logging.

    Args:
        tags: Comma-separated tag string selecting what to offload.
    """
    logger.info(f">>> start offloading memory, tags: {tags}")
    start_time = time.perf_counter()
    tag_set = set(tags.split(","))  # split once; membership-tested below
    # Clear weights, deepep_buffer, cudagraph, etc.
    if "weight" in tag_set:
        if self.is_weight_sleeping:
            logger.info("GPU model runner's weight is already sleeping, no need to sleep again!")
        else:
            if self.use_cudagraph:
                # NOTE(review): "grpah" spelling kept — it matches the model's API name.
                self.model.clear_grpah_opt_backend()
            if self.fd_config.parallel_config.enable_expert_parallel:
                self.dynamic_weight_manager.clear_deepep_buffer()
            self.dynamic_weight_manager.clear_model_weight()
            if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
                self.dynamic_weight_manager.clear_communication_group()
            self.is_weight_sleeping = True
    # Clear KV cache
    if "kv_cache" in tag_set:
        if self.is_kvcache_sleeping:
            logger.info("GPU model runner's kv cache is already sleeping, no need to sleep again!")
        else:
            # MTP speculative decoding keeps its own draft-model cache; clear it too.
            if self.spec_method == SpecMethod.MTP:
                self.proposer.clear_mtp_cache()
            self.clear_cache()
            self.is_kvcache_sleeping = True
    # Return freed blocks to the driver so other processes can use them.
    paddle.device.cuda.empty_cache()
    logger.info(f"<<< finish offloading memory! time cost: {time.perf_counter()-start_time:.3f}s")
    print_gpu_memory_use(f"After offloading memory [{tags}]", self.local_rank, self.device_id)
def wakeup(self, tags):
    """Reload GPU memory selected by *tags* (comma-separated: "weight", "kv_cache").

    The KV cache is rebuilt before weights so that CUDA Graph recapture (done
    while waking weights) sees a valid cache. Each tag is handled
    independently and the call is idempotent: an already-awake tag is skipped
    with a log message. Previously an early ``return`` on an already-awake tag
    aborted the whole call, so ``wakeup("kv_cache,weight")`` with the cache
    already awake never reloaded the weights. The CUDA-graph guard also used
    exact equality (``tags == "weight"``), so a multi-tag string containing
    "weight" but not "kv_cache" bypassed it; it now parses the tag list.

    Args:
        tags: Comma-separated tag string selecting what to reload.

    Raises:
        RuntimeError: If "weight" is requested without "kv_cache" while CUDA
            Graph is enabled and the KV cache is still offloaded.
    """
    tag_set = set(tags.split(","))
    # Recapturing CUDA graphs requires the KV cache to exist; refuse to wake
    # weights alone while the cache is still offloaded.
    if "weight" in tag_set and "kv_cache" not in tag_set and self.use_cudagraph and self.is_kvcache_sleeping:
        raise RuntimeError(
            "Waking up [weight] alone is not supported when CUDA Graph is enabled, "
            "as recapturing the graph requires the KV cache to be rebuilt first. "
            "Please wake up [kv_cache] first."
        )
    logger.info(f">>> start reloading memory, tags: {tags}")
    start_time = time.perf_counter()
    # Reset share_inputs to restore tensor shapes and values
    if self.spec_method == SpecMethod.MTP:
        self.proposer.model_inputs.reset_model_inputs()
    self.share_inputs.reset_share_inputs()
    # Reinitialize KV cache
    if "kv_cache" in tag_set:
        if not self.is_kvcache_sleeping:
            logger.info("GPU model runner's kv cache is not sleeping, no need to wakeup!")
        else:
            if self.spec_method == SpecMethod.MTP:
                self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
            self.initialize_kv_cache()
            self.is_kvcache_sleeping = False
    # Reload weights, deepep_buffer, cudagraph, etc.
    if "weight" in tag_set:
        if not self.is_weight_sleeping:
            logger.info("GPU model runner's weight is not sleeping, no need to wakeup!")
        else:
            if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
                self.dynamic_weight_manager.restart_communication_group()
            if self.fd_config.parallel_config.enable_expert_parallel:
                self.dynamic_weight_manager.recreate_deepep_buffer()
            self.dynamic_weight_manager.reload_model_weights()
            if self.use_cudagraph:
                # Recapture CUDA graphs against the freshly reloaded weights/cache.
                self.capture_model()
            self.is_weight_sleeping = False
    logger.info(f"<<< finish reloading memory! time cost: {time.perf_counter()-start_time:.3f}s")
    print_gpu_memory_use(f"After reloading memory [{tags}]", self.local_rank, self.device_id)
def padding_cudagraph_inputs(self) -> None:
"""
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.