add reconstruct (#6675)

This commit is contained in:
bukejiyu
2026-03-10 11:25:37 +08:00
committed by GitHub
parent ecc5032176
commit 8e322f917e
3 changed files with 97 additions and 7 deletions
+63
View File
@@ -558,3 +558,66 @@ def get_sm_version():
prop = paddle.device.cuda.get_device_properties()
return prop.major * 10 + prop.minor
return 0
@paddle.no_grad()
def _move_param(src, device=None, blocking=True):
    """
    Move a parameter tensor to the target device in place and return it.

    The tensor's storage is copied to ``device``, the original storage is
    cleared, and ``src`` is re-pointed at the copied storage via
    ``_share_data_with`` — so the parameter object held by the model ends
    up living on the target device without being replaced.

    Args:
        src (Tensor): The parameter tensor to be moved.
        device (Optional[Union[str, paddle.Place]], optional): The target
            device, as a string (e.g. ``"gpu:0"``) or a place object.
            Defaults to None, which means using the current device.
        blocking (bool, optional): Whether to block until the copy is
            complete. Defaults to True.

    Returns:
        Tensor: ``src``, now backed by storage on the target device.
    """
    if isinstance(device, str):
        device = paddle.device._convert_to_place(device)
    dst = src._copy_to(device, blocking)
    dst_tensor = dst.value().get_tensor()
    src_tensor = src.value().get_tensor()
    # Free the old storage first, then alias src to the copied storage so
    # the move is observed in place by everything holding a reference.
    src_tensor._clear()
    src_tensor._share_data_with(dst_tensor)
    # Fix: the docstring promised a return value but the function returned
    # None; return src (now on the target device) to honor the contract.
    return src
def _reload_model(model):
    """
    Move a previously offloaded model back onto the currently active device
    (i.e. from CUDAPinnedPlace back to the GPU).
    """
    current_device = paddle.device.get_device()
    model.to(current_device)
def _offload_model(model):
    """
    Offload every initialized parameter of ``model`` from the GPU to
    CUDAPinnedPlace, releasing device memory while keeping the parameter
    objects usable for a later reload.
    """
    pinned_place = paddle.CUDAPinnedPlace()
    for _name, param in model.named_parameters():
        # Skip parameters with no storage yet.
        if not param._is_initialized():
            continue
        # Skip parameters already resident in pinned host memory.
        if isinstance(param.place, paddle.CUDAPinnedPlace):
            continue
        _move_param(param, pinned_place)
def reconstruct_memory(model):
"""
reconstruct_memory to avoid memory chunks
"""
if paddle.is_compiled_with_cuda():
_offload_model(model)
paddle.device.cuda.empty_cache()
_reload_model(model)
def need_memory_reconstruction(fd_config):
    """
    Decide whether the model needs a memory-reconstruction pass
    (offload + cache flush + reload) based on its architecture.

    Args:
        fd_config: Config object; ``fd_config.model_config.architectures[0]``
            is read as the architecture name.

    Returns:
        bool: True (and logs an info message) for architectures known to
        benefit from defragmentation, False otherwise.
    """
    _need_memory_reconstruction_archs = ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
    # Index once instead of repeating architectures[0] in the condition and log.
    arch = fd_config.model_config.architectures[0]
    if arch not in _need_memory_reconstruction_archs:
        return False
    logger.info(f"{arch} Performing model offload and reload to defragment GPU memory.")
    return True