[Cherry-Pick][RL] Add clear_graph_opt_backend for glm4_mtp (#7378) (#7379)

* add clear_graph func

* fix spell
This commit is contained in:
GoldPancake
2026-04-15 19:45:09 +08:00
committed by GitHub
parent 61bfe6e5b3
commit 26674bbbb6
14 changed files with 32 additions and 28 deletions
@@ -92,7 +92,7 @@ class GraphOptWrapper:
def __call__(self, **kwargs):
return self.graph_opt_backend(**kwargs)
def clear_grpah_opt_backend(self, fd_config):
def clear_graph_opt_backend(self, fd_config):
""" """
# TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights
assert (
@@ -1306,9 +1306,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
)
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class DeepSeekV3PretrainedModel(PretrainedModel):
@@ -701,9 +701,9 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
self.ernie.clear_graph_opt_backend(fd_config=self.fd_config)
@ModelRegistry.register_model_class(
@@ -829,9 +829,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
self.ernie.clear_graph_opt_backend(fd_config=self.fd_config)
class Ernie4_5_VLPretrainedModel(PretrainedModel):
+2 -2
View File
@@ -563,9 +563,9 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Glm4MoePretrainedModel(PretrainedModel):
@@ -369,3 +369,7 @@ class Glm4MTPForCausalLM(ModelForCasualLM):
)
return hidden_states
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
+2 -2
View File
@@ -417,9 +417,9 @@ class Qwen2ForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config)
self.qwen2.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen2PretrainedModel(PretrainedModel):
+2 -2
View File
@@ -341,9 +341,9 @@ class Qwen3ForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen3PretrainedModel(PretrainedModel):
@@ -382,9 +382,9 @@ class Qwen3VLForConditionalGeneration(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen3VLPretrainedModel(PretrainedModel):
+2 -2
View File
@@ -453,9 +453,9 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
return hidden_states
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
self.model.clear_graph_opt_backend(fd_config=self.fd_config)
class Qwen3MoePretrainedModel(PretrainedModel):
+3 -3
View File
@@ -2692,13 +2692,13 @@ class GPUModelRunner(ModelRunnerBase):
"""Dynamic model loader used to clear parameters, for RL"""
# Clear CUDAGraph
if self.use_cudagraph:
self.model.clear_grpah_opt_backend()
self.model.clear_graph_opt_backend()
# Clear parameters and send signal
self.dynamic_weight_manager.clear_parameters(
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
)
if self.spec_method == SpecMethod.MTP:
self.proposer.model.clear_grpah_opt_backend()
self.proposer.model.clear_graph_opt_backend()
self.proposer.clear_mtp_cache()
self.clear_cache()
paddle.device.cuda.empty_cache()
@@ -2752,7 +2752,7 @@ class GPUModelRunner(ModelRunnerBase):
logger.info("GPU model runner's weight is already sleeping, no need to sleep again!")
return
if self.use_cudagraph:
self.model.clear_grpah_opt_backend()
self.model.clear_graph_opt_backend()
if self.fd_config.parallel_config.enable_expert_parallel:
self.dynamic_weight_manager.clear_deepep_buffer()
self.dynamic_weight_manager.clear_model_weight()
+1 -1
View File
@@ -2511,7 +2511,7 @@ class MetaxModelRunner(ModelRunnerBase):
"""Dynamic model loader used to clear parameters, for RL"""
# Clear CUDAGraph
if self.use_cudagraph:
self.model.clear_grpah_opt_backend()
self.model.clear_graph_opt_backend()
# Clear parameters and send signal
self.dynamic_weight_manager.clear_parameters(
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
@@ -91,10 +91,10 @@ class TestModel1(paddle.nn.Layer):
return sublayer2_output
def clear_grpah_opt_backend(self):
def clear_graph_opt_backend(self):
""" """
self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config)
self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config)
self.sublayer1.clear_graph_opt_backend(fd_config=self.fd_config)
self.sublayer2.clear_graph_opt_backend(fd_config=self.fd_config)
class TestCUDAGrpahRecapture(unittest.TestCase):
@@ -152,7 +152,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
# Destroy
print_gpu_memory_use("before destroy", 0)
self.test_model1.clear_grpah_opt_backend()
self.test_model1.clear_graph_opt_backend()
print_gpu_memory_use("after destroy", 0)
def recapture_and_replay(self, input_tensor1, forward_meta1):
@@ -168,7 +168,7 @@ class TestCUDAGrpahRecapture(unittest.TestCase):
# Destroy
print_gpu_memory_use("before destroy", 0)
self.test_model1.clear_grpah_opt_backend()
self.test_model1.clear_graph_opt_backend()
print_gpu_memory_use("after destroy", 0)
+2 -2
View File
@@ -487,7 +487,7 @@ class TestSleepWakeupBehavior(unittest.TestCase):
runner.local_rank = 0
runner.device_id = 1
runner.num_gpu_blocks = 8
runner.model = Mock(clear_grpah_opt_backend=Mock())
runner.model = Mock(clear_graph_opt_backend=Mock())
runner.clear_cache = Mock()
runner.initialize_kv_cache = Mock()
runner.capture_model = Mock()
@@ -523,7 +523,7 @@ class TestSleepWakeupBehavior(unittest.TestCase):
runner.sleep("weight,kv_cache")
runner.model.clear_grpah_opt_backend.assert_called_once()
runner.model.clear_graph_opt_backend.assert_called_once()
runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once()
runner.dynamic_weight_manager.clear_model_weight.assert_called_once()
runner.dynamic_weight_manager.clear_communication_group.assert_called_once()