[Optimization] xgrammar async compile, multi thread, speed up (#4835)

* xgrammar async compile, multi thread, speed up

* fix test_sampler.py & pre-commit err

* add redis version check && fix request.llm_engine_recv_req_timestamp

* xgrammar prefill & decode & v0

* fix test_gpu_prompt_logprobs.py

* add test_guided_decoding.py

* Update fastdeploy/scheduler/splitwise_scheduler.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix torch xgrammar unittest env

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Author: Daci
Date: 2025-11-14 18:05:26 +08:00
Committed by: GitHub
Parent: b925533051
Commit: 5fc12eddfe
11 changed files with 810 additions and 373 deletions
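
The title's "async compile, multi thread" refers to moving grammar compilation for guided decoding off the runner's critical path. The backend changes themselves are not part of the excerpt below; the following is only a minimal sketch of the pattern, assuming a hypothetical AsyncGrammarBackend and a placeholder _compile step rather than the actual xgrammar calls:

from concurrent.futures import Future, ThreadPoolExecutor


class AsyncGrammarBackend:
    """Sketch only: compile guided-decoding processors on a thread pool."""

    def __init__(self, max_workers: int = 4):
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._cache = {}  # schemata_key -> compiled processor

    def get_logits_processor(self, schemata_key, enable_thinking=False) -> Future:
        # Cache hit: wrap the ready processor in a completed Future so callers
        # only ever deal with one type.
        if schemata_key in self._cache:
            done = Future()
            done.set_result(self._cache[schemata_key])
            return done
        # Cache miss: compile on a worker thread; prefill can start meanwhile
        # and the Future is joined only once the processor is really needed.
        return self._executor.submit(self._compile, schemata_key, enable_thinking)

    def _compile(self, schemata_key, enable_thinking):
        # Placeholder for the real xgrammar compilation step.
        return ("compiled", schemata_key, enable_thinking)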
+34 -49
@@ -17,6 +17,7 @@
 import os
 import queue
 import time
+from concurrent.futures import Future
 from threading import Thread
 from typing import List, Optional, cast
@@ -287,7 +288,7 @@ class GPUModelRunner(ModelRunnerBase):
         else:
             self.proposer = None
 
-    def _init_logits_processor(self, request):
+    def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]:
         """
         init logits processor for guided decoding
         """
@@ -307,7 +308,7 @@ class GPUModelRunner(ModelRunnerBase):
             return (
                 self.guided_backend.get_logits_processor(
                     schemata_key=schemata_key,
-                    enable_thinking=True,
+                    enable_thinking=False,  # TODO cfg
                 ),
                 schemata_key,
             )
@@ -696,6 +697,7 @@
             length = len(request.prompt_token_ids)
             assert length > 0, "The prompt requested must not be empty."
+            logits_info = None
             prefill_tokens = []
             if (
                 request.guided_json is not None
@@ -704,7 +706,6 @@
                 or request.guided_grammar is not None
             ):
                 logits_info, schemata_key = self._init_logits_processor(request)
-                request.logits_processor, request.logits_cached = logits_info
                 request.schemata_key = schemata_key
 
             # Is Decode Node
@@ -874,7 +875,7 @@
             else:
                 self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
 
-            self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens)
+            self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
 
         self.share_inputs["not_need_stop"][0] = True
@@ -2006,34 +2007,36 @@ class GPUModelRunner(ModelRunnerBase):
             logger.info(f"SOT warmup the model with the batch size:{batch_size}")
         logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
 
-    def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
+    def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
         """
-        Get the index of the request that needs to be skipped during execution.
-        Args:
-            model_forward_batch: A list of requests to be executed by this runner.
-        Returns:
-            A list of indices corresponding to the requests that need to be skipped.
+        Get indices for guided decoding.
+        When Prefill is done, async compiled logits_processor must be joined.
         """
-        if (
-            not self.cache_config.enable_chunked_prefill
-            or self.guided_backend is None
-            or model_forward_batch is None
-            or envs.ENABLE_V1_KVCACHE_SCHEDULER
-        ):
+        if self.guided_backend is None:
             return []
 
-        skip_idx_list = []
-        for task in model_forward_batch:
-            if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
-                continue
-            skip_idx_list.append(task.idx)
+        prefill_done_idxs = []
+        for idx in range(0, num_running_requests):
+            if self.share_inputs["step_idx"][idx] == 0:
+                prefill_done_idxs.append(idx)
 
-        for task in self.restore_chunked_prefill_request.values():
-            if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
-                continue
-            skip_idx_list.append(task.idx)
+        if self.cache_config.enable_chunked_prefill:
+            if model_forward_batch is not None:
+                for task in model_forward_batch:
+                    # new Request with ChunkPrefill, unfinished, store
+                    if task.chunk_idx < len(task.prefill_chunk_info):
+                        if task.request_id not in self.restore_chunked_prefill_request:
+                            self.restore_chunked_prefill_request[task.request_id] = task
 
-        return skip_idx_list
+            for id, task in list(self.restore_chunked_prefill_request.items()):
+                # unfinished, remove
+                if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
+                    prefill_done_idxs.remove(task.idx)
+                # finished, add
+                if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
+                    prefill_done_idxs.append(task.idx)
+
+        return prefill_done_idxs
 
     def execute_model(
         self,
@@ -2050,9 +2053,10 @@ class GPUModelRunner(ModelRunnerBase):
             num_running_requests: batch_size
         """
         # 1. Prepare inputs of model and sampler.
-        skip_idx_list = self._get_skip_idx(model_forward_batch)
+        p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
         self._prepare_inputs()
-        self.sampler.pre_process(skip_idx_list)
+        self.sampler.pre_process(p_done_idxs)
 
         # 1.1 Update state of logits processor
         for proc in self.sampling_metadata.logits_processors:
@@ -2157,7 +2161,7 @@ class GPUModelRunner(ModelRunnerBase):
             sampler_output = self.sampler(
                 logits,
                 self.sampling_metadata,
-                skip_idx_list,
+                p_done_idxs,
             )
             if self.parallel_config.tensor_parallel_size > 1:
                 paddle.distributed.broadcast(
@@ -2244,7 +2248,7 @@ class GPUModelRunner(ModelRunnerBase):
                 line_break_id=self.model_config.line_break_id,
             )
 
         if self.guided_backend is not None and sampler_output is not None:
-            self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
+            self.sampler.post_process(sampler_output.sampled_token_ids)
 
         # 6. Speculative decode
         if self.speculative_decoding:
@@ -2268,7 +2272,6 @@ class GPUModelRunner(ModelRunnerBase):
             )
 
             self._update_chunked_prefill(model_forward_batch)
-            self._add_cache(model_forward_batch)
         elif self.speculative_decoding:
             speculate_schedule_cache(
                 self.share_inputs["draft_tokens"],
@@ -2325,24 +2328,6 @@
 
         return pooler_output
 
-    def _add_cache(self, model_forward_batch) -> None:
-        """
-        Add cache for guided decoding.
-        """
-        if self.guided_backend is None or model_forward_batch is None:
-            return
-
-        for request in model_forward_batch:
-            logits_cached = request.get("logits_cached", None)
-            if logits_cached is None or logits_cached:
-                continue
-
-            request.logits_cached = True
-            if isinstance(request.logits_processor, LogitsProcessorBase):
-                self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
-            else:
-                self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
-
     def _execute_empty_input(self) -> None:
         """
         In certain scenarios, such as during EP,