add dsv3 mixed deploy as EP16 TP8 (#6525)

This commit is contained in:
周周周
2026-02-27 14:08:25 +08:00
committed by GitHub
parent 16de778343
commit 1503443871
2 changed files with 48 additions and 5 deletions
+32
View File
@@ -37,3 +37,35 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4
```
**示例2** H800上16卡部署 blockwise_fp8 模型16K上下文的服务
```shell
MODEL_PATH=/models/DeepSeek-V3.2-Exp-BF16
export FD_DISABLE_CHUNKED_PREFILL=1
export FD_ATTENTION_BACKEND="MLA_ATTN"
export FLAGS_flash_attn_version=3
# 暂时只支持 tp_size为8ep_size 为 16的 配置
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export FD_ENABLE_MULTI_API_SERVER=1
python -m fastdeploy.entrypoints.openai.multi_api_server \
--ports "9811" \
--num-servers 1 \
--args --model "$model_path" \
--ips "10.95.247.24,10.95.244.147" \
--no-enable-prefix-caching \
--quantization block_wise_fp8 \
--disable-sequence-parallel-moe \
--tensor-parallel-size 8 \
--num-gpu-blocks-override 1024 \
--data-parallel-size 2 \
--max-model-len 16384 \
--enable-expert-parallel \
--max-num-seqs 20 \
--graph-optimization-config '{"use_cudagraph":true}' \
```
@@ -121,6 +121,10 @@ class DeepSeekV3MoE(nn.Layer):
super().__init__()
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.attn_tp_size = fd_config.parallel_config.tensor_parallel_size
if self.ep_size > 1:
self.tp_size = 1
self.norm_topk_prob = fd_config.model_config.norm_topk_prob
weight_key_map = {
@@ -190,6 +194,10 @@ class DeepSeekV3MoE(nn.Layer):
def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
""" """
shared_experts_out = self.shared_experts(hidden_states)
if self.attn_tp_size > 1 and self.ep_size > 1:
shared_experts_out = tensor_model_parallel_all_reduce(shared_experts_out)
moe_out = self.experts(hidden_states, self.gate, forward_meta)
moe_out = moe_out + shared_experts_out
# We do to TP all reduce after the sum of experts.
@@ -508,13 +516,16 @@ class DeepSeekV3DecoderLayer(nn.Layer):
mask_encoder_batch: paddle.Tensor,
):
""" """
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
if hidden_states.shape[0] > 0:
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
else:
residual = hidden_states
hidden_states = self.mlp(hidden_states, forward_meta)
return hidden_states, residual