mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] Fix skip_x_record_stream incompatibility across deep_ep versions (#7542)
* fix skip_x_record_stream * fix * optim
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import traceback
|
||||
from abc import abstractmethod
|
||||
from types import ModuleType
|
||||
@@ -602,6 +603,8 @@ class EPPrefillRunner(EPRunner):
|
||||
use_internode_ll_two_stage=use_internode_ll_two_stage,
|
||||
)
|
||||
self.num_worst_tokens = prefill_num_worst_tokens
|
||||
self._dispatch_parameters: Optional[set] = None
|
||||
self._combine_parameters: Optional[set] = None
|
||||
logger.info(f"prefill_num_worst_tokens {prefill_num_worst_tokens}")
|
||||
|
||||
def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False):
|
||||
@@ -656,8 +659,12 @@ class EPPrefillRunner(EPRunner):
|
||||
}
|
||||
|
||||
if envs.FD_USE_PFCC_DEEP_EP:
|
||||
dispatch_args["num_worst_tokens"] = self.num_worst_tokens
|
||||
dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0
|
||||
if self._dispatch_parameters is None:
|
||||
self._dispatch_parameters = set(inspect.signature(buffer.dispatch).parameters)
|
||||
if "num_worst_tokens" in self._dispatch_parameters:
|
||||
dispatch_args["num_worst_tokens"] = self.num_worst_tokens
|
||||
if "skip_x_record_stream" in self._dispatch_parameters:
|
||||
dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0
|
||||
|
||||
return buffer.dispatch(**dispatch_args)
|
||||
|
||||
@@ -683,7 +690,10 @@ class EPPrefillRunner(EPRunner):
|
||||
}
|
||||
|
||||
if envs.FD_USE_PFCC_DEEP_EP:
|
||||
combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0
|
||||
if self._combine_parameters is None:
|
||||
self._combine_parameters = set(inspect.signature(buffer.combine).parameters)
|
||||
if "skip_x_record_stream" in self._combine_parameters:
|
||||
combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0
|
||||
|
||||
fused_moe_out, _, event = buffer.combine(**combine_args)
|
||||
return fused_moe_out, event
|
||||
|
||||
Reference in New Issue
Block a user