[BugFix] Fix skip_x_record_stream incompatibility across deep_ep versions (#7542)

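* fix skip_x_record_stream: pass num_worst_tokens / skip_x_record_stream only when the installed deep_ep dispatch/combine signatures accept them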

* fix

* optimize
This commit is contained in:
Yuanle Liu
2026-04-21 21:31:00 +08:00
committed by GitHub
parent c618a39562
commit 534c43c888
+13 -3
View File
@@ -14,6 +14,7 @@
# limitations under the License.
"""
import inspect
import traceback
from abc import abstractmethod
from types import ModuleType
@@ -602,6 +603,8 @@ class EPPrefillRunner(EPRunner):
use_internode_ll_two_stage=use_internode_ll_two_stage,
)
self.num_worst_tokens = prefill_num_worst_tokens
self._dispatch_parameters: Optional[set] = None
self._combine_parameters: Optional[set] = None
logger.info(f"prefill_num_worst_tokens {prefill_num_worst_tokens}")
def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
@@ -656,8 +659,12 @@ class EPPrefillRunner(EPRunner):
}
if envs.FD_USE_PFCC_DEEP_EP:
dispatch_args["num_worst_tokens"] = self.num_worst_tokens
dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0
if self._dispatch_parameters is None:
self._dispatch_parameters = set(inspect.signature(buffer.dispatch).parameters)
if "num_worst_tokens" in self._dispatch_parameters:
dispatch_args["num_worst_tokens"] = self.num_worst_tokens
if "skip_x_record_stream" in self._dispatch_parameters:
dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0
return buffer.dispatch(**dispatch_args)
@@ -683,7 +690,10 @@ class EPPrefillRunner(EPRunner):
}
if envs.FD_USE_PFCC_DEEP_EP:
combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0
if self._combine_parameters is None:
self._combine_parameters = set(inspect.signature(buffer.combine).parameters)
if "skip_x_record_stream" in self._combine_parameters:
combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0
fused_moe_out, _, event = buffer.combine(**combine_args)
return fused_moe_out, event