mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] ep+tp all2all (#4836)
This commit is contained in:
@@ -577,6 +577,15 @@ class ParallelConfig:
|
||||
else:
|
||||
self.pd_disaggregation_mode = "None"
|
||||
|
||||
# ep+tp strategy: "all_reduce" or "all_to_all"
|
||||
# all_reduce: qkv_linear + attn + out_linear + allreduce
|
||||
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
|
||||
self.ep_tp_strategy = envs.FD_EP_TP_STRATEGY
|
||||
assert self.ep_tp_strategy in [
|
||||
"all_reduce",
|
||||
"all_to_all",
|
||||
], f"FD_EP_TP_STRATEGY: '{self.ep_tp_strategy}' is not supported, only supports 'all_reduce' or 'all_to_all'."
|
||||
|
||||
def set_communicate_group(self):
|
||||
# different tp group id
|
||||
# prevent different tp_groups using the same group_id
|
||||
|
||||
@@ -155,6 +155,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
|
||||
# Enable offline perf test mode for PD disaggregation
|
||||
"FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
|
||||
# ep+tp strategy: "all_reduce" or "all_to_all"
|
||||
# all_reduce: qkv_linear + attn + out_linear + allreduce
|
||||
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
|
||||
"FD_EP_TP_STRATEGY": lambda: os.getenv("FD_EP_TP_STRATEGY", "all_reduce"),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -799,6 +799,7 @@ class RowParallelLinear(LinearBase):
|
||||
reduce_results: bool = True,
|
||||
skip_quant: bool = False,
|
||||
weight_dtype="",
|
||||
layer_id: int = -1,
|
||||
):
|
||||
"""
|
||||
Initialize a linear layer with additional parameters for inference and quantization.
|
||||
@@ -815,14 +816,25 @@ class RowParallelLinear(LinearBase):
|
||||
"""
|
||||
self.fd_config = fd_config
|
||||
self.skip_quant = False
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
|
||||
self.split_token = (
|
||||
self.ep_size > 1
|
||||
and self.tp_size > 1
|
||||
and fd_config.parallel_config.ep_tp_strategy == "all_to_all"
|
||||
and layer_id >= fd_config.model_config.moe_layer_start_index
|
||||
and layer_id < fd_config.model_config.num_hidden_layers
|
||||
)
|
||||
|
||||
# Split input_size when using TP inference.
|
||||
self.input_size = divide(input_size, self.nranks)
|
||||
if self.split_token:
|
||||
self.input_size = input_size
|
||||
else:
|
||||
self.input_size = divide(input_size, self.nranks)
|
||||
self.output_size = output_size
|
||||
|
||||
super().__init__(
|
||||
@@ -854,13 +866,30 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
token_num = x.shape[0]
|
||||
token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size
|
||||
if token_num_pad > token_num:
|
||||
x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
|
||||
x_new[:token_num, :] = x
|
||||
x = x_new
|
||||
out = paddle.zeros_like(x)
|
||||
paddle.distributed.alltoall(out, x, group=self.tp_group)
|
||||
out.reshape_([self.tp_size, -1, x.shape[1]])
|
||||
out = paddle.transpose(out, [1, 0, 2])
|
||||
out.reshape_([x.shape[0] // self.tp_size, self.hidden_size])
|
||||
return out
|
||||
|
||||
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
if self.split_token:
|
||||
x = self.all2all_transpose(x)
|
||||
|
||||
if self.fd_config.quant_config:
|
||||
out = self.quant_method.apply(self, x)
|
||||
else:
|
||||
out = paddle.matmul(x, self.weight)
|
||||
|
||||
if self.reduce_results and self.nranks > 1:
|
||||
if self.reduce_results and self.nranks > 1 and not self.split_token:
|
||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
if not self.fd_config.quant_config and self.add_bias:
|
||||
out = paddle.add(out, self.bias)
|
||||
|
||||
@@ -137,6 +137,7 @@ class FusedMoE(nn.Layer):
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
||||
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.ep_tp_strategy = self.fd_config.parallel_config.ep_tp_strategy
|
||||
# NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
|
||||
if self.ep_size > 1:
|
||||
self.tp_size = 1
|
||||
@@ -612,7 +613,7 @@ class FusedMoE(nn.Layer):
|
||||
"""
|
||||
token_num = x.shape[0]
|
||||
tp_size = self.fd_config.parallel_config.tensor_parallel_size
|
||||
if self.ep_size > 1 and tp_size > 1 and token_num >= tp_size:
|
||||
if self.ep_size > 1 and tp_size > 1 and self.ep_tp_strategy == "all_reduce" and token_num >= tp_size:
|
||||
out = self.forward_split_allgather(x, gate)
|
||||
else:
|
||||
out = self.quant_method.apply(self, x, gate)
|
||||
|
||||
@@ -28,6 +28,7 @@ else:
|
||||
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
@@ -47,6 +48,7 @@ class RMSNorm(nn.Layer):
|
||||
quant_scale: float = None,
|
||||
begin_norm_axis: int = 1,
|
||||
dtype: str = None,
|
||||
layer_id: int = -1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the RMSNormalization layer.
|
||||
@@ -97,6 +99,30 @@ class RMSNorm(nn.Layer):
|
||||
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
|
||||
self.begin_norm_axis: int = begin_norm_axis
|
||||
|
||||
self.layer_id = layer_id
|
||||
parallel_config = self.fd_config.parallel_config
|
||||
self.ep_size = parallel_config.expert_parallel_size
|
||||
self.tp_size = parallel_config.tensor_parallel_size
|
||||
self.tp_rank = parallel_config.tensor_parallel_rank
|
||||
self.tp_group = parallel_config.tp_group
|
||||
self.ep_tp_strategy = parallel_config.ep_tp_strategy
|
||||
self.moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
|
||||
is_input_norm = prefix.endswith(".input_layernorm")
|
||||
is_last_norm = prefix.endswith(".norm")
|
||||
self.split_x = (
|
||||
self.ep_size > 1
|
||||
and self.tp_size > 1
|
||||
and self.ep_tp_strategy == "all_to_all"
|
||||
and self.layer_id == self.moe_layer_start_index
|
||||
and is_input_norm
|
||||
)
|
||||
self.allgather_out = (
|
||||
self.ep_size > 1
|
||||
and self.tp_size > 1
|
||||
and self.ep_tp_strategy == "all_to_all"
|
||||
and ((self.layer_id > self.moe_layer_start_index and is_input_norm) or is_last_norm)
|
||||
)
|
||||
|
||||
self.init_weight()
|
||||
|
||||
def init_weight(self):
|
||||
@@ -124,7 +150,50 @@ class RMSNorm(nn.Layer):
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key))
|
||||
self.weight.set_value(weight_tensor.astype(self._norm_weight_dtype))
|
||||
|
||||
def forward(self, x, residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
|
||||
def split(self, x):
|
||||
"""
|
||||
Split the input tensor across tensor parallel dimension.
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): Input tensor to be split.
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: Splitted tensor.
|
||||
"""
|
||||
token_num = x.shape[0]
|
||||
token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size
|
||||
# AllGather will hang when the data shapes on multi-ranks are different!
|
||||
start_offset = self.tp_rank * token_num_per_rank
|
||||
end_offset = (self.tp_rank + 1) * token_num_per_rank
|
||||
if start_offset >= token_num:
|
||||
start_offset = token_num
|
||||
if end_offset > token_num:
|
||||
end_offset = token_num
|
||||
part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
|
||||
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
|
||||
return part_x
|
||||
|
||||
def allgather(self, out, token_num):
|
||||
"""
|
||||
Gather the output tensor from each tensor parallel rank.
|
||||
|
||||
Args:
|
||||
out (paddle.Tensor): Output tensor to be gathered.
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: Gathered tensor.
|
||||
"""
|
||||
token_num_per_rank = out.shape[0]
|
||||
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, out.shape[1]], dtype=out.dtype)
|
||||
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
|
||||
return multi_outs[:token_num, :]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
residual_input: Optional[paddle.Tensor] = None,
|
||||
forward_meta: Optional[ForwardMeta] = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
|
||||
@@ -165,10 +234,18 @@ class RMSNorm(nn.Layer):
|
||||
quant_max_bound=self.quant_max_bound,
|
||||
quant_min_bound=self.quant_min_bound,
|
||||
)
|
||||
if residual_input is not None:
|
||||
return norm_out[0].astype(x_dtype), norm_out[1].astype(residual_input_dtype)
|
||||
out = norm_out[0].astype(x_dtype)
|
||||
residual_out = norm_out[1].astype(residual_input_dtype) if residual_input is not None else None
|
||||
|
||||
if self.split_x:
|
||||
residual_out = self.split(residual_out)
|
||||
if self.allgather_out:
|
||||
out = self.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
if residual_input is None:
|
||||
return out
|
||||
else:
|
||||
return norm_out[0].astype(x_dtype)
|
||||
return out, residual_out
|
||||
|
||||
|
||||
class LayerNorm(nn.Layer):
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
@@ -267,8 +268,14 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
|
||||
filtered_map[k] = weight_list[k]
|
||||
|
||||
if fd_config.parallel_config.tensor_parallel_size > 1:
|
||||
no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
|
||||
if fd_config.parallel_config.ep_tp_strategy == "all_to_all":
|
||||
for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
|
||||
k = f"ernie.layers.{i}.self_attn.o_proj.weight"
|
||||
if k in weight_list:
|
||||
no_tp_action_keys.append(k)
|
||||
tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
|
||||
new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}
|
||||
new_actions = {k: v for k, v in tp_actions.items() if k not in no_tp_action_keys}
|
||||
|
||||
state_dict = {}
|
||||
# Get all safetensor file paths that need to be opened
|
||||
|
||||
@@ -235,6 +235,7 @@ class Ernie4_5_Attention(nn.Layer):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=fd_config.model_config.head_dim * fd_config.model_config.num_attention_heads,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.attn = Attention(
|
||||
fd_config=fd_config,
|
||||
@@ -303,6 +304,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -329,16 +331,27 @@ class Ernie4_5_DecoderLayer(nn.Layer):
|
||||
):
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.input_layernorm(
|
||||
hidden_states,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states,
|
||||
residual,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states,
|
||||
residual,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
@@ -444,7 +457,7 @@ class Ernie4_5_Model(nn.Layer):
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, forward_meta=forward_meta)
|
||||
|
||||
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
|
||||
out = forward_meta.attn_backend.reverse_transpose(out)
|
||||
|
||||
+48
-4
@@ -251,7 +251,7 @@ if [ ${vl_test_exit_code} -ne 0 ]; then
|
||||
fi
|
||||
|
||||
|
||||
echo "============================开始 EP4TP1 测试!============================"
|
||||
echo "============================开始 EP8TP1 测试!============================"
|
||||
sleep 5
|
||||
rm -rf log/*
|
||||
rm -f core*
|
||||
@@ -290,12 +290,12 @@ stop_processes
|
||||
if [ ${ep_exit_code} -ne 0 ]; then
|
||||
echo "log/workerlog.0"
|
||||
cat log/workerlog.0
|
||||
echo "EP4TP1 相关测试失败,请检查pr代码"
|
||||
echo "EP8TP1 相关测试失败,请检查pr代码"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "============================开始 EP4TP4 测试!============================"
|
||||
echo "============================开始 EP8TP8 allreduce 测试!============================"
|
||||
sleep 5
|
||||
rm -rf log/*
|
||||
rm -f core*
|
||||
@@ -323,11 +323,55 @@ unset BKCL_PCIE_RING
|
||||
unset XSHMEM_MODE
|
||||
unset XSHMEM_QP_NUM_PER_RANK
|
||||
unset BKCL_RDMA_VERBS
|
||||
unset enable_expert_parallel
|
||||
unset enable_tensor_parallel
|
||||
stop_processes
|
||||
|
||||
if [ ${ep_exit_code} -ne 0 ]; then
|
||||
echo "log/workerlog.0"
|
||||
cat log/workerlog.0
|
||||
echo "EP4TP4 相关测试失败,请检查pr代码"
|
||||
echo "EP8TP8 allreduce 相关测试失败,请检查pr代码"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "============================开始 EP8TP8 all2all 测试!============================"
|
||||
sleep 5
|
||||
rm -rf log/*
|
||||
rm -f core*
|
||||
ipcrm --all=msg
|
||||
xpu-smi
|
||||
export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
export BKCL_ENABLE_XDR=1
|
||||
export BKCL_RDMA_NICS=xgbe1,xgbe2,xgbe3,xgbe4
|
||||
export BKCL_TRACE_TOPO=1
|
||||
export BKCL_PCIE_RING=1
|
||||
export XSHMEM_MODE=1
|
||||
export XSHMEM_QP_NUM_PER_RANK=32
|
||||
export BKCL_RDMA_VERBS=1
|
||||
|
||||
export enable_expert_parallel=1
|
||||
export enable_tensor_parallel=1
|
||||
export EP_TP_SPLIT_MODE=1
|
||||
|
||||
python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
|
||||
ep_exit_code=$?
|
||||
|
||||
unset BKCL_ENABLE_XDR
|
||||
unset BKCL_RDMA_NICS
|
||||
unset BKCL_TRACE_TOPO
|
||||
unset BKCL_PCIE_RING
|
||||
unset XSHMEM_MODE
|
||||
unset XSHMEM_QP_NUM_PER_RANK
|
||||
unset BKCL_RDMA_VERBS
|
||||
unset enable_expert_parallel
|
||||
unset enable_tensor_parallel
|
||||
unset EP_TP_SPLIT_MODE
|
||||
stop_processes
|
||||
|
||||
if [ ${ep_exit_code} -ne 0 ]; then
|
||||
echo "log/workerlog.0"
|
||||
cat log/workerlog.0
|
||||
echo "EP8TP8 all2all 相关测试失败,请检查pr代码"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
Reference in New Issue
Block a user