[XPU] change XPU EP interface from xDeepEP to paddle (#5706)

* add ENV VAR to control the low-latency buffer
Author: zccjjj
Date: 2026-01-21 18:23:45 +08:00
Committed by: GitHub
Parent: 490a6551dc
Commit: 14a64e9b3b
12 changed files with 76 additions and 93 deletions
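The commit message mentions an environment variable gating the low-latency buffer, but that wiring does not appear in the hunks excerpted here. As a rough illustration of the pattern only (the variable name and default below are hypothetical, not the commit's actual ones):

    import os

    # Hypothetical sketch: select the EP engine flavor from an environment
    # variable. "FD_EP_LOW_LATENCY" is an illustrative name, not the real one.
    use_low_latency = os.getenv("FD_EP_LOW_LATENCY", "1") == "1"
    engine_cls = DeepEPEngineLowLatency if use_low_latency else DeepEPEngineHighThroughput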
@@ -16,9 +16,9 @@
from abc import abstractmethod
import deep_ep
import paddle
from paddle import nn
from paddle.distributed.communication import deep_ep
import fastdeploy
from fastdeploy.config import MoEPhase
@@ -89,7 +89,6 @@ class DeepEPEngineHighThroughput(DeepEPEngineBase):
self.group,
int(1e9),
0,
num_experts=self.num_experts,
low_latency_mode=False,
num_qps_per_rank=1,
)
@@ -128,7 +127,6 @@ class DeepEPEngineLowLatency(DeepEPEngineBase):
self.group,
0,
num_rdma_bytes,
self.num_experts,
low_latency_mode=True,
num_qps_per_rank=self.num_experts // self.ep_size,
)
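Both constructor hunks above drop the explicit num_experts argument: under paddle's deep_ep the expert count is supplied per call rather than at buffer-construction time. A minimal sketch of the two configurations, assuming the callee is deep_ep.Buffer with the positional order (group, intranode_bytes, rdma_bytes) implied by the diff:

    from paddle.distributed.communication import deep_ep

    def make_buffer(group, num_experts, ep_size, num_rdma_bytes, low_latency):
        if low_latency:
            # Low-latency mode: RDMA-backed buffer, one QP per local expert.
            return deep_ep.Buffer(
                group,
                0,
                num_rdma_bytes,
                low_latency_mode=True,
                num_qps_per_rank=num_experts // ep_size,
            )
        # High-throughput mode: ~1 GB intranode buffer, no RDMA bytes.
        return deep_ep.Buffer(
            group,
            int(1e9),
            0,
            low_latency_mode=False,
            num_qps_per_rank=1,
        )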
@@ -139,6 +137,7 @@ class DeepEPEngineLowLatency(DeepEPEngineBase):
topk_idx: paddle.Tensor,
expertwise_scale,
use_fp8: bool = False,
quant_group_size: int = -1,
):
"""
Args:
@@ -156,25 +155,25 @@ class DeepEPEngineLowLatency(DeepEPEngineBase):
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
moe_in_w4a8_scale = None
return_recv_hook = True
(
packed_recv_x,
recv_expert_count,
handle,
event,
dispatch_hook,
valid_token_num,
) = self.deepep_engine.low_latency_dispatch(
hidden_states,
moe_in_w4a8_scale,
topk_idx,
expertwise_scale,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
use_fp8=use_fp8,
async_finish=False,
return_recv_hook=True,
async_finish=not return_recv_hook,
return_recv_hook=return_recv_hook,
num_per_channel=quant_group_size,
)
return packed_recv_x, recv_expert_count, handle, dispatch_hook, valid_token_num
return packed_recv_x, recv_expert_count, handle, event, dispatch_hook
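Caller-side, the return order is now (packed_recv_x, recv_expert_count, handle, event, dispatch_hook): valid_token_num is gone and a completion event is surfaced instead. A hedged usage sketch, where engine stands for a DeepEPEngineLowLatency instance:

    # Sketch: consume the new low_latency_dispatch return signature.
    packed_recv_x, recv_expert_count, handle, event, dispatch_hook = engine.low_latency_dispatch(
        hidden_states,
        topk_idx,
        expertwise_scale,
        use_fp8=True,
        quant_group_size=128,  # illustrative value; the diff's default is -1
    )
    if dispatch_hook is not None:
        dispatch_hook()  # receive hook, valid because return_recv_hook is True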
def low_latency_combine(
self,
@@ -187,15 +186,16 @@ class DeepEPEngineLowLatency(DeepEPEngineBase):
Return:
combined_hidden_states: [num_tokens, hidden_size]
"""
combined_hidden_states, combine_hook = self.deepep_engine.low_latency_combine(
return_recv_hook = True
combined_hidden_states, event, combine_hook = self.deepep_engine.low_latency_combine(
hidden_states,
topk_idx,
topk_weights,
handle,
async_finish=False,
return_recv_hook=True,
async_finish=not return_recv_hook,
return_recv_hook=return_recv_hook,
)
return combined_hidden_states, combine_hook
return combined_hidden_states, event, combine_hook
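low_latency_combine now surfaces the event as well. A matching caller sketch:

    # Sketch: the combine path mirrors dispatch; fire the receive hook and
    # keep the event in case the current stream must later wait on it.
    combined_hidden_states, event, combine_hook = engine.low_latency_combine(
        ffn_out, topk_idx, topk_weights, handle
    )
    if combine_hook is not None:
        combine_hook()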
def clean_low_latency_buffer(self):
"""
@@ -348,15 +348,37 @@ class XPUEPPrefillRunner(XPUEPRunner):
x: paddle.Tensor,
topk_idx: paddle.Tensor,
topk_weights: paddle.Tensor,
expert_alignment: int = 1,
*args,
**kwargs,
):
self.num_combined_tokens = x.shape[0]
x_scale = kwargs.get("x_scale", None)
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
event,
) = self.ep_engine.deepep_engine.get_dispatch_layout(
topk_idx,
self.ep_engine.num_experts,
previous_event=kwargs.get("previous_event", None),
allocate_on_comm_stream=False,
async_finish=self.ep_engine.async_finish,
)
x_scale_tensor = kwargs.get("x_scale", None)
dispatch_args = {
"x": (x, x_scale) if x_scale is not None else x,
"x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
"num_tokens_per_rank": num_tokens_per_rank,
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
"is_token_in_rank": is_token_in_rank,
"num_tokens_per_expert": num_tokens_per_expert,
"async_finish": self.ep_engine.async_finish,
"topk_idx": topk_idx,
"topk_weights": topk_weights,
"expert_alignment": expert_alignment,
"previous_event": event,
}
return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
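The prefill runner now computes the dispatch layout explicitly and chains its completion event into dispatch via previous_event, so the two kernels are ordered without a host-side sync. Condensed from the hunk above (runner stands for an XPUEPPrefillRunner instance):

    # Sketch: layout first, then dispatch, serialized through `event`.
    (
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
        is_token_in_rank,
        event,
    ) = runner.ep_engine.deepep_engine.get_dispatch_layout(
        topk_idx,
        runner.ep_engine.num_experts,
        async_finish=runner.ep_engine.async_finish,
    )
    out = runner.ep_engine.deepep_engine.dispatch(
        x=x,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        num_tokens_per_rank=num_tokens_per_rank,
        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
        num_tokens_per_expert=num_tokens_per_expert,
        is_token_in_rank=is_token_in_rank,
        expert_alignment=1,
        previous_event=event,  # order dispatch after the layout kernel
        async_finish=runner.ep_engine.async_finish,
    )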
@@ -365,15 +387,18 @@ class XPUEPPrefillRunner(XPUEPRunner):
tmp_ffn_out: paddle.Tensor,
handle: tuple,
recv_topk_weights: paddle.Tensor,
event=None,
):
combine_args = {
"x": tmp_ffn_out,
"handle": handle,
"async_finish": self.ep_engine.async_finish,
"topk_weights": recv_topk_weights,
"num_combined_tokens": self.num_combined_tokens,
"previous_event": event,
}
fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
fused_moe_out, _, event = self.ep_engine.deepep_engine.combine(**combine_args)
return fused_moe_out
return fused_moe_out, event
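Because combine() now also returns the event, callers running with async_finish must wait on it before reading the output; the XPUMoEMethod hunk further down does exactly this:

    # Sketch: caller-side synchronization on the combine event.
    fused_moe_out, event = runner.combine(tmp_ffn_out, handle, recv_topk_weights)
    if runner.ep_engine.async_finish:
        event.current_stream_wait()  # block the current stream until combine lands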
class XPUEPDecoderRunner(XPUEPRunner):
@@ -419,15 +444,20 @@ class XPUEPDecoderRunner(XPUEPRunner):
**kwargs,
):
expertwise_scale = kwargs.get("expertwise_scale", None)
use_fp8 = expertwise_scale is not None
use_fp8 = kwargs.get("use_fp8", False)
quant_group_size = kwargs.get("quant_group_size", -1)
(
recv_hidden_states,
recv_expert_count,
handle,
event,
dispatch_hook,
valid_token_num,
) = self.ep_engine.low_latency_dispatch(x, topk_idx, expertwise_scale, use_fp8)
) = self.ep_engine.low_latency_dispatch(x, topk_idx, expertwise_scale, use_fp8, quant_group_size)
if dispatch_hook is not None:
dispatch_hook()
# valid_token_num is optional:
# - if valid_token_num is None, it means that we CANNOT accurately know
# the size of the tensor, but the advantage is that it can reduce
@@ -435,15 +465,14 @@ class XPUEPDecoderRunner(XPUEPRunner):
# - if valid_token_num is NOT None, it means that we CAN accurately know
# the size of the tensor, but the disadvantage is that it will interrupt
# the process of kernel launch.
if valid_token_num is None and dispatch_hook is not None:
dispatch_hook()
if valid_token_num is None:
if recv_expert_count is None:
valid_token_num = -1
else:
valid_token_num = paddle.sum(recv_expert_count).item()
if isinstance(recv_hidden_states, tuple):
recv_x = recv_hidden_states[0]
recv_x_scale = recv_hidden_states[1]
recv_x_scale = recv_hidden_states[1].contiguous()
else:
recv_x = recv_hidden_states
recv_x_scale = None
@@ -451,7 +480,7 @@ class XPUEPDecoderRunner(XPUEPRunner):
return recv_x, recv_x_scale, recv_expert_count, handle, valid_token_num
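When the engine cannot report valid_token_num, the decode runner derives it from recv_expert_count, trading a device-to-host sync for an exact size. A commented sketch of the fallback, assuming valid_token_num arrives as None in that case:

    if valid_token_num is None:
        if recv_expert_count is None:
            valid_token_num = -1  # size unknown; downstream must allocate conservatively
        else:
            # .item() forces a device-to-host copy, which interrupts
            # kernel-launch pipelining, as the comment above explains.
            valid_token_num = paddle.sum(recv_expert_count).item()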
def combine(self, ffn_out, topk_idx, topk_weights, handle):
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
combined_hidden_states, event, combine_hook = self.ep_engine.low_latency_combine(
ffn_out, topk_idx, topk_weights, handle
)
if combine_hook is not None:
@@ -453,18 +453,19 @@ class XPUMoEMethod(MoEMethodBase):
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
# 2. Dynamic compute blockwise quantization scales
if "a_tokenwise_int8" in self.xpu_moe_quant_type:
if "a_tokenwise_int8" in self.xpu_moe_quant_type and x.shape[0] > 0:
x, x_scale = quant2d_per_token(x)
else:
x_scale = None
# 3. EP Dispatch
(
recv_x,
recv_x_scales,
recv_topk_idx,
recv_topk_weights,
recv_num_tokens_per_expert_list,
_,
handle,
event,
) = self.ep_prefill_runner.dispatch(
x,
topk_idx,
@@ -472,9 +473,13 @@ class XPUMoEMethod(MoEMethodBase):
x_scale=x_scale,
)
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
recv_x, recv_x_scales = recv_x if isinstance(recv_x, tuple) else (recv_x, None)
# 4. Compute ffn
token_num_per_expert = recv_num_tokens_per_expert_list.numpy().tolist()
token_all_num = sum(token_num_per_expert)
token_all_num = sum(recv_num_tokens_per_expert_list)
if "a_expertwise_int8" in self.xpu_moe_quant_type:
moe_dispatch_scale = getattr(layer, self.added_in_scale_attrs[0])
elif "a_tokenwise_int8" in self.xpu_moe_quant_type:
@@ -492,7 +497,7 @@ class XPUMoEMethod(MoEMethodBase):
recv_topk_idx,
recv_topk_weights,
moe_dispatch_scale,
token_num_per_expert,
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
)
@@ -521,8 +526,10 @@ class XPUMoEMethod(MoEMethodBase):
)
# 5. EP combine
handle = None
return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
return tmp_ffn_out
def apply_ep_decode(
self,