[Optimization] add del to decrease peak memory in MoE prefill (#5863)

This commit is contained in:
周周周
2026-01-05 14:01:48 +08:00
committed by GitHub
parent e911ac2ce7
commit dc13344ab8
@@ -148,6 +148,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
"""
gate_out = gate(x.cast("float32"))
hidden_size = x.shape[1]
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
@@ -179,6 +181,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
token_all_num = sum(recv_num_tokens_per_expert_list)
# Note(ZKK):
# The code below contains many `del` statements, which is ugly,
# but MoE prefill reaches peak GPU memory here, so we manually
# delete each variable as soon as it is no longer used.
# 4. Compute ffn
if token_all_num > 0:
logger.debug(f"token_all_num {token_all_num}")
@@ -206,13 +213,14 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
True, # use_in_ep
token_all_num,
)
assert permute_input.shape[0] == token_all_num
del recv_x
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
permute_scale = permute_scale.transpose([1, 0]).contiguous().transpose([1, 0])
# up_gate_proj
ffn_out = paddle.empty(
(permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
(token_all_num, getattr(layer, self.added_weight_attrs[0]).shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
@@ -221,6 +229,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_out,
m_indices,
)
del permute_input
# swiglu
ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
@@ -228,11 +238,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
ffn_out, self.quant_config.weight_block_size[0]
)
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
del ffn_out
ffn_out = paddle.empty(
(ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
(token_all_num, getattr(layer, self.added_weight_attrs[1]).shape[1]),
dtype=paddle.bfloat16,
)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
@@ -241,6 +251,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_out,
m_indices,
)
del ffn_in_x
# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
@@ -251,9 +263,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
False, # norm_topk_prob
1.0,
)
del ffn_out
else:
tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16)
tmp_ffn_out = paddle.empty([0, hidden_size], paddle.bfloat16)
# 5. EP combine
event = deep_ep.Buffer.capture()