[Feature] Support EP prefill with num_worst_tokens (#6574)

* support num worst tokens

* support num worst tokens

* fix build error

* support num worst tokens: fix errors

* support num worst tokens: fix feild

* support num worst tokens: delete requiements

* replace permute and depermute op by pure cuda

* replace permute and depermute op by pure cuda

* fix ci

* fix op

* fix nan

* fix code style

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
RichardWooSJTU
2026-03-11 17:09:07 +08:00
committed by GitHub
parent 0466c7e8a8
commit 9f0778f991
21 changed files with 1775 additions and 166 deletions
+11
View File
@@ -546,6 +546,11 @@ class EngineArgs:
Flag to enable entropy output. Default is False (disabled).
"""
ep_prefill_use_worst_num_tokens: bool = False
"""
Flag to enable prefill_use_worst_num_tokens. Default is False (disabled).
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -1060,6 +1065,12 @@ class EngineArgs:
default=EngineArgs.shutdown_comm_group_if_worker_idle,
help="Shutdown communication group when worker is idle.",
)
parallel_group.add_argument(
"--ep-prefill-use-worst-num-tokens",
action="store_true",
default=EngineArgs.ep_prefill_use_worst_num_tokens,
help="Enable prefill use worst num tokens for EP.",
)
# Load group
load_group = parser.add_argument_group("Load Configuration")