[XPU] ep+tp all2all (#4836)

This commit is contained in:
zhupengyang
2025-11-06 17:26:14 +08:00
committed by GitHub
parent 901d559aa7
commit b54eb7ad81
8 changed files with 201 additions and 17 deletions
+9
View File
@@ -577,6 +577,15 @@ class ParallelConfig:
else:
self.pd_disaggregation_mode = "None"
# ep+tp strategy: "all_reduce" or "all_to_all"
# all_reduce: qkv_linear + attn + out_linear + allreduce
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
self.ep_tp_strategy = envs.FD_EP_TP_STRATEGY
assert self.ep_tp_strategy in [
"all_reduce",
"all_to_all",
], f"FD_EP_TP_STRATEGY: '{self.ep_tp_strategy}' is not supported, only supports 'all_reduce' or 'all_to_all'."
def set_communicate_group(self):
# different tp group id
# prevent different tp_groups using the same group_id
+4
View File
@@ -155,6 +155,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
# Enable offline perf test mode for PD disaggregation
"FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
# ep+tp strategy: "all_reduce" or "all_to_all"
# all_reduce: qkv_linear + attn + out_linear + allreduce
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
"FD_EP_TP_STRATEGY": lambda: os.getenv("FD_EP_TP_STRATEGY", "all_reduce"),
}
+32 -3
View File
@@ -799,6 +799,7 @@ class RowParallelLinear(LinearBase):
reduce_results: bool = True,
skip_quant: bool = False,
weight_dtype="",
layer_id: int = -1,
):
"""
Initialize a linear layer with additional parameters for inference and quantization.
@@ -815,14 +816,25 @@ class RowParallelLinear(LinearBase):
"""
self.fd_config = fd_config
self.skip_quant = False
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.tp_group = fd_config.parallel_config.tp_group
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
self.split_token = (
self.ep_size > 1
and self.tp_size > 1
and fd_config.parallel_config.ep_tp_strategy == "all_to_all"
and layer_id >= fd_config.model_config.moe_layer_start_index
and layer_id < fd_config.model_config.num_hidden_layers
)
# Split input_size when using TP inference.
self.input_size = divide(input_size, self.nranks)
if self.split_token:
self.input_size = input_size
else:
self.input_size = divide(input_size, self.nranks)
self.output_size = output_size
super().__init__(
@@ -854,13 +866,30 @@ class RowParallelLinear(LinearBase):
self.reduce_results = reduce_results
def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
    """Exchange per-rank token shards across the TP group and regroup them token-major.

    Used on the ep+tp "all_to_all" path: each rank holds a partial hidden
    slice for all tokens; after the all-to-all every rank holds the full
    hidden dimension for its share of the tokens.

    Args:
        x (paddle.Tensor): 2-D input of shape [token_num, dim].

    Returns:
        paddle.Tensor: Tensor of shape [token_num_pad // tp_size, hidden_size].
            NOTE: padding rows added below are not stripped here.
    """
    token_num = x.shape[0]
    # Round the token count up to a multiple of tp_size: alltoall requires
    # every rank to contribute equally sized shards.
    token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size
    if token_num_pad > token_num:
        x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
        x_new[:token_num, :] = x
        x = x_new
    out = paddle.zeros_like(x)
    paddle.distributed.alltoall(out, x, group=self.tp_group)
    # Received buffer is rank-major ([tp_size, tokens_per_rank, dim]);
    # transpose so each kept token's tp_size slices become contiguous.
    out.reshape_([self.tp_size, -1, x.shape[1]])
    out = paddle.transpose(out, [1, 0, 2])
    # NOTE(review): final reshape assumes x.shape[1] * tp_size == hidden_size
    # (x is the per-rank attention-output slice) — confirm against callers.
    out.reshape_([x.shape[0] // self.tp_size, self.hidden_size])
    return out
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
if self.split_token:
x = self.all2all_transpose(x)
if self.fd_config.quant_config:
out = self.quant_method.apply(self, x)
else:
out = paddle.matmul(x, self.weight)
if self.reduce_results and self.nranks > 1:
if self.reduce_results and self.nranks > 1 and not self.split_token:
out = tensor_model_parallel_all_reduce(out, self.tp_group)
if not self.fd_config.quant_config and self.add_bias:
out = paddle.add(out, self.bias)
+2 -1
View File
@@ -137,6 +137,7 @@ class FusedMoE(nn.Layer):
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.ep_tp_strategy = self.fd_config.parallel_config.ep_tp_strategy
# NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
if self.ep_size > 1:
self.tp_size = 1
@@ -612,7 +613,7 @@ class FusedMoE(nn.Layer):
"""
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
if self.ep_size > 1 and tp_size > 1 and token_num >= tp_size:
if self.ep_size > 1 and tp_size > 1 and self.ep_tp_strategy == "all_reduce" and token_num >= tp_size:
out = self.forward_split_allgather(x, gate)
else:
out = self.quant_method.apply(self, x, gate)
@@ -28,6 +28,7 @@ else:
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from .utils import get_tensor
@@ -47,6 +48,7 @@ class RMSNorm(nn.Layer):
quant_scale: float = None,
begin_norm_axis: int = 1,
dtype: str = None,
layer_id: int = -1,
) -> None:
"""
Initializes the RMSNormalization layer.
@@ -97,6 +99,30 @@ class RMSNorm(nn.Layer):
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
self.begin_norm_axis: int = begin_norm_axis
self.layer_id = layer_id
parallel_config = self.fd_config.parallel_config
self.ep_size = parallel_config.expert_parallel_size
self.tp_size = parallel_config.tensor_parallel_size
self.tp_rank = parallel_config.tensor_parallel_rank
self.tp_group = parallel_config.tp_group
self.ep_tp_strategy = parallel_config.ep_tp_strategy
self.moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
is_input_norm = prefix.endswith(".input_layernorm")
is_last_norm = prefix.endswith(".norm")
self.split_x = (
self.ep_size > 1
and self.tp_size > 1
and self.ep_tp_strategy == "all_to_all"
and self.layer_id == self.moe_layer_start_index
and is_input_norm
)
self.allgather_out = (
self.ep_size > 1
and self.tp_size > 1
and self.ep_tp_strategy == "all_to_all"
and ((self.layer_id > self.moe_layer_start_index and is_input_norm) or is_last_norm)
)
self.init_weight()
def init_weight(self):
@@ -124,7 +150,50 @@ class RMSNorm(nn.Layer):
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
self.weight.set_value(weight_tensor.astype(self._norm_weight_dtype))
def forward(self, x, residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
def split(self, x):
    """
    Return this rank's slice of ``x`` along the token dimension, zero-padded
    to a uniform per-rank length.

    Every rank produces a tensor of identical shape
    [ceil(token_num / tp_size), dim] — AllGather will hang when the data
    shapes on multi-ranks are different — so rows beyond this rank's valid
    token range are left as zeros.

    Args:
        x (paddle.Tensor): Input tensor to be split.

    Returns:
        paddle.Tensor: Splitted tensor.
    """
    total_tokens = x.shape[0]
    tokens_per_rank = (total_tokens + self.tp_size - 1) // self.tp_size
    # Clamp this rank's window to the valid token range (trailing ranks
    # may own fewer than tokens_per_rank real tokens).
    begin = min(self.tp_rank * tokens_per_rank, total_tokens)
    end = min((self.tp_rank + 1) * tokens_per_rank, total_tokens)
    chunk = paddle.zeros(shape=[tokens_per_rank, x.shape[1]], dtype=x.dtype)
    chunk[: end - begin, :] = x[begin:end, :]
    return chunk
def allgather(self, out, token_num):
    """
    Gather the output tensor from each tensor parallel rank.

    Inverse of ``split``: concatenates the equal-sized per-rank shards in
    rank order and drops the trailing zero-padding rows that ``split``
    introduced.

    Args:
        out (paddle.Tensor): This rank's shard to be gathered, shape
            [token_num_per_rank, dim]; must be identical in shape on
            every rank in the group.
        token_num (int): Number of valid tokens overall; gathered rows
            beyond this index are padding and are discarded.

    Returns:
        paddle.Tensor: Gathered tensor of shape [token_num, dim].
    """
    token_num_per_rank = out.shape[0]
    # Pre-allocated flat buffer receives all ranks' shards back-to-back.
    multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, out.shape[1]], dtype=out.dtype)
    paddle.distributed.all_gather(multi_outs, out, self.tp_group)
    # Padding only ever appears in the trailing ranks' shards, so a single
    # prefix slice recovers exactly the valid tokens.
    return multi_outs[:token_num, :]
def forward(
self,
x,
residual_input: Optional[paddle.Tensor] = None,
forward_meta: Optional[ForwardMeta] = None,
) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
@@ -165,10 +234,18 @@ class RMSNorm(nn.Layer):
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if residual_input is not None:
return norm_out[0].astype(x_dtype), norm_out[1].astype(residual_input_dtype)
out = norm_out[0].astype(x_dtype)
residual_out = norm_out[1].astype(residual_input_dtype) if residual_input is not None else None
if self.split_x:
residual_out = self.split(residual_out)
if self.allgather_out:
out = self.allgather(out, forward_meta.ids_remove_padding.shape[0])
if residual_input is None:
return out
else:
return norm_out[0].astype(x_dtype)
return out, residual_out
class LayerNorm(nn.Layer):
@@ -15,6 +15,7 @@
"""
import contextlib
import copy
import hashlib
import inspect
import json
@@ -267,8 +268,14 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
filtered_map[k] = weight_list[k]
if fd_config.parallel_config.tensor_parallel_size > 1:
no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
if fd_config.parallel_config.ep_tp_strategy == "all_to_all":
for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
k = f"ernie.layers.{i}.self_attn.o_proj.weight"
if k in weight_list:
no_tp_action_keys.append(k)
tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}
new_actions = {k: v for k, v in tp_actions.items() if k not in no_tp_action_keys}
state_dict = {}
# Get all safetensor file paths that need to be opened
@@ -235,6 +235,7 @@ class Ernie4_5_Attention(nn.Layer):
prefix=f"{prefix}.o_proj",
input_size=fd_config.model_config.head_dim * fd_config.model_config.num_attention_heads,
output_size=fd_config.model_config.hidden_size,
layer_id=layer_id,
)
self.attn = Attention(
fd_config=fd_config,
@@ -303,6 +304,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
layer_id=layer_id,
)
self.post_attention_layernorm = RMSNorm(
@@ -329,16 +331,27 @@ class Ernie4_5_DecoderLayer(nn.Layer):
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.input_layernorm(
hidden_states,
forward_meta=forward_meta,
)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(
hidden_states,
residual,
forward_meta=forward_meta,
)
hidden_states = self.self_attn(
hidden_states=hidden_states,
forward_meta=forward_meta,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(
hidden_states,
residual,
forward_meta=forward_meta,
)
hidden_states = self.mlp(hidden_states)
@@ -444,7 +457,7 @@ class Ernie4_5_Model(nn.Layer):
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, forward_meta=forward_meta)
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
out = forward_meta.attn_backend.reverse_transpose(out)
+48 -4
View File
@@ -251,7 +251,7 @@ if [ ${vl_test_exit_code} -ne 0 ]; then
fi
echo "============================开始 EP4TP1 测试!============================"
echo "============================开始 EP8TP1 测试!============================"
sleep 5
rm -rf log/*
rm -f core*
@@ -290,12 +290,12 @@ stop_processes
if [ ${ep_exit_code} -ne 0 ]; then
echo "log/workerlog.0"
cat log/workerlog.0
echo "EP4TP1 相关测试失败,请检查pr代码"
echo "EP8TP1 相关测试失败,请检查pr代码"
exit 1
fi
echo "============================开始 EP4TP4 测试!============================"
echo "============================开始 EP8TP8 allreduce 测试!============================"
sleep 5
rm -rf log/*
rm -f core*
@@ -323,11 +323,55 @@ unset BKCL_PCIE_RING
unset XSHMEM_MODE
unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
unset enable_expert_parallel
unset enable_tensor_parallel
stop_processes
if [ ${ep_exit_code} -ne 0 ]; then
echo "log/workerlog.0"
cat log/workerlog.0
echo "EP4TP4 相关测试失败,请检查pr代码"
echo "EP8TP8 allreduce 相关测试失败,请检查pr代码"
exit 1
fi
echo "============================开始 EP8TP8 all2all 测试!============================"
sleep 5
rm -rf log/*
rm -f core*
ipcrm --all=msg
xpu-smi
export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export BKCL_ENABLE_XDR=1
export BKCL_RDMA_NICS=xgbe1,xgbe2,xgbe3,xgbe4
export BKCL_TRACE_TOPO=1
export BKCL_PCIE_RING=1
export XSHMEM_MODE=1
export XSHMEM_QP_NUM_PER_RANK=32
export BKCL_RDMA_VERBS=1
export enable_expert_parallel=1
export enable_tensor_parallel=1
export EP_TP_SPLIT_MODE=1
python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
ep_exit_code=$?
unset BKCL_ENABLE_XDR
unset BKCL_RDMA_NICS
unset BKCL_TRACE_TOPO
unset BKCL_PCIE_RING
unset XSHMEM_MODE
unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
unset enable_expert_parallel
unset enable_tensor_parallel
unset EP_TP_SPLIT_MODE
stop_processes
if [ ${ep_exit_code} -ne 0 ]; then
echo "log/workerlog.0"
cat log/workerlog.0
echo "EP8TP8 all2all 相关测试失败,请检查pr代码"
exit 1
fi