Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Optimization] add del to decrease peak memory in MoE prefill (#5863)
@@ -148,6 +148,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         """
         gate_out = gate(x.cast("float32"))
 
+        hidden_size = x.shape[1]
+
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
 
@@ -179,6 +181,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
 
         token_all_num = sum(recv_num_tokens_per_expert_list)
 
+        # Note(ZKK):
+        # The code below contains many `del` statements, which is ugly,
+        # but MoE prefill is where GPU memory usage peaks, so we manually
+        # del each variable as soon as it is no longer used.
+
         # 4. Compute ffn
         if token_all_num > 0:
             logger.debug(f"token_all_num {token_all_num}")
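The pattern that note describes, as a standalone sketch (illustrative names and shapes, not FastDeploy code): a tensor's buffer returns to the allocator once its last live reference dies, so an explicit `del` right after the last use lets the next large allocation reuse that block instead of growing the pool. This only helps when the local variable holds the last reference.

    import paddle

    def ffn_sketch(x, w_up, w_down):
        h = paddle.matmul(x, w_up)                           # up projection
        act = paddle.incubate.nn.functional.swiglu(h, None)  # halves the last dim
        del h    # last reference to h: its block is free before the
                 # down-projection output below is allocated
        out = paddle.matmul(act, w_down)
        del act  # dead now; lowers the peak for whatever the caller does next
        return out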
@@ -206,13 +213,14 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 True,  # use_in_ep
                 token_all_num,
             )
+            assert permute_input.shape[0] == token_all_num
+            del recv_x
 
-            permute_scale = permute_scale.transpose([1, 0]).contiguous()
-            permute_scale = permute_scale.transpose([1, 0])
+            permute_scale = permute_scale.transpose([1, 0]).contiguous().transpose([1, 0])
 
             # up_gate_proj
             ffn_out = paddle.empty(
-                (permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
+                (token_all_num, getattr(layer, self.added_weight_attrs[0]).shape[1]),
                 dtype=paddle.bfloat16,
             )
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
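The two deleted statements and the added one-liner are equivalent: both leave the logical shape unchanged while making the underlying storage contiguous in transposed (column-major) order, presumably the layout the FP8 grouped GEMM consumes. A sketch with assumed shapes:

    import paddle

    scale = paddle.rand([128, 4])  # assumed [tokens, num_blocks] scales
    # Materialize the data in [4, 128] order, then view it back as [128, 4]:
    # same logical shape, transposed memory layout.
    scale = scale.transpose([1, 0]).contiguous().transpose([1, 0])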
@@ -221,6 +229,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 ffn_out,
                 m_indices,
             )
+            del permute_input
+
             # swiglu
             ffn_out = paddle.incubate.nn.functional.swiglu(ffn_out, None)
 
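Freeing `permute_input` right here is possible because the grouped GEMM writes into a caller-provided output rather than returning a fresh tensor, so the caller controls when the output is allocated and when the inputs die. A toy stand-in for that out-parameter convention (plain `matmul`, not the real deep_gemm kernel):

    import paddle

    def gemm_into(lhs, rhs, out):
        # results land in the preallocated `out`; nothing new is returned
        paddle.assign(paddle.matmul(lhs, rhs), out)

    lhs, rhs = paddle.rand([8, 16]), paddle.rand([16, 32])
    out = paddle.empty([8, 32])  # allocated when the caller chooses
    gemm_into(lhs, rhs, out)
    del lhs, rhs                 # inputs can be dropped once the kernel has run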
@@ -228,11 +238,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             ffn_in_x, ffn_in_x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
                 ffn_out, self.quant_config.weight_block_size[0]
             )
-            ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous()
-            ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
+            ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0]).contiguous().transpose([1, 0])
 
+            del ffn_out
             ffn_out = paddle.empty(
-                (ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
+                (token_all_num, getattr(layer, self.added_weight_attrs[1]).shape[1]),
                 dtype=paddle.bfloat16,
             )
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
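Switching the row count from `ffn_out.shape[0]` to `token_all_num` is what lets `del ffn_out` move above the allocation: the new buffer's shape no longer reads from the tensor being freed, so the old block can be recycled for the new one. A minimal sketch of the reordering (illustrative sizes):

    import paddle

    token_all_num, hidden = 256, 1024
    ffn_out = paddle.rand([token_all_num, hidden])

    # Before: the shape was read from ffn_out, forcing it to stay alive
    # through the empty() call. After: the shape comes from token_all_num.
    del ffn_out
    ffn_out = paddle.empty([token_all_num, hidden], dtype=paddle.bfloat16)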
@@ -241,6 +251,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 ffn_out,
                 m_indices,
             )
+            del ffn_in_x
+
             # permute back per rank
             tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
                 ffn_out,
@@ -251,9 +263,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 False,  # norm_topk_prob
                 1.0,
             )
-
+            del ffn_out
         else:
-            tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16)
+            tmp_ffn_out = paddle.empty([0, hidden_size], paddle.bfloat16)
 
         # 5. EP combine
         event = deep_ep.Buffer.capture()
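The `hidden_size = x.shape[1]` added in the first hunk exists for this fallback: when no tokens are routed to this rank, the combine input is now built directly as an empty `[0, hidden_size]` bf16 tensor instead of casting `recv_x[0]`, keeping `recv_x` out of the branch entirely. Sketch (assumed size):

    import paddle

    hidden_size = 1024  # captured earlier from x.shape[1]
    tmp_ffn_out = paddle.empty([0, hidden_size], paddle.bfloat16)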