[Optimization] xgrammar async compile, multi thread, speed up (#4835)
* xgrammar async compile, multi thread, speed up
* fix test_sampler.py & pre-commit err
* add redis version check && fix request.llm_engine_recv_req_timestamp
* xgrammar prefill & decode & v0
* fix test_gpu_prompt_logprobs.py
* add test_guided_decoding.py
* Update fastdeploy/scheduler/splitwise_scheduler.py
* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
* Update fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
* Apply suggestions from code review
* fix torch xgrammar unittest env

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@@ -17,6 +17,7 @@
 import os
 import queue
 import time
+from concurrent.futures import Future
 from threading import Thread
 from typing import List, Optional, cast
 
@@ -287,7 +288,7 @@ class GPUModelRunner(ModelRunnerBase):
         else:
             self.proposer = None
 
-    def _init_logits_processor(self, request):
+    def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]:
         """
        init logits processor for guided decoding
        """
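
The new return annotation reflects the point of this PR: grammar compilation no longer blocks request insertion. Instead of waiting for the xgrammar processor to be built, `_init_logits_processor` can hand back a `concurrent.futures.Future` that later resolves to the compiled `LogitsProcessorBase`. A minimal sketch of that pattern, assuming a backend that owns a thread pool; the class, `executor`, `cache`, and `_compile_processor` names below are illustrative, not FastDeploy's actual API:

    from concurrent.futures import Future, ThreadPoolExecutor

    class GuidedBackendSketch:
        """Illustrative only: compile guided-decoding processors off the hot path."""

        def __init__(self, max_workers: int = 4):
            # Multiple workers let several schemata compile concurrently.
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.cache = {}  # schemata_key -> compiled processor

        def get_logits_processor(self, schemata_key, enable_thinking=False):
            cached = self.cache.get(schemata_key)
            if cached is not None:
                return cached  # already compiled: no Future needed
            # Not cached: submit compilation and return immediately with a Future.
            return self.executor.submit(self._compile_processor, schemata_key, enable_thinking)

        def _compile_processor(self, schemata_key, enable_thinking):
            ...  # build the xgrammar matcher/processor for this schema
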
@@ -307,7 +308,7 @@ class GPUModelRunner(ModelRunnerBase):
         return (
             self.guided_backend.get_logits_processor(
                 schemata_key=schemata_key,
-                enable_thinking=True,
+                enable_thinking=False,  # TODO cfg
             ),
             schemata_key,
         )
@@ -696,6 +697,7 @@ class GPUModelRunner(ModelRunnerBase):
         length = len(request.prompt_token_ids)
         assert length > 0, "The prompt requested must not be empty."
 
+        logits_info = None
         prefill_tokens = []
         if (
             request.guided_json is not None
@@ -704,7 +706,6 @@ class GPUModelRunner(ModelRunnerBase):
             or request.guided_grammar is not None
         ):
             logits_info, schemata_key = self._init_logits_processor(request)
-            request.logits_processor, request.logits_cached = logits_info
             request.schemata_key = schemata_key
 
         # Is Decode Node
@@ -874,7 +875,7 @@ class GPUModelRunner(ModelRunnerBase):
             else:
                 self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
 
-            self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens)
+            self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
 
             self.share_inputs["not_need_stop"][0] = True
 
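
Because compilation may still be in flight when a request is inserted, the sampler now receives `logits_info` directly, and that value may be a pending `Future` rather than a ready processor. A hedged sketch of the dispatch a consumer of `logits_info` could perform; this helper is hypothetical, not FastDeploy's actual `apply_logits_processor`:

    from concurrent.futures import Future

    def resolve_logits_processor(obj):
        """Illustrative: accept either a compiled processor or a pending Future.

        Calling .result() blocks only if compilation has not finished yet,
        so prefill work can overlap with grammar compilation.
        """
        if isinstance(obj, Future):
            return obj.result()
        return obj
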
@@ -2006,34 +2007,36 @@ class GPUModelRunner(ModelRunnerBase):
             logger.info(f"SOT warmup the model with the batch size:{batch_size}")
         logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
 
-    def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
+    def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_running_requests: int):
         """
-        Get the index of the request that needs to be skipped during execution.
-        Args:
-            model_forward_batch: A list of requests to be executed by this runner.
-        Returns:
-            A list of indices corresponding to the requests that need to be skipped.
+        Get indices for guided decoding.
+        When Prefill is done, async compiled logits_processor must be joined.
         """
-        if (
-            not self.cache_config.enable_chunked_prefill
-            or self.guided_backend is None
-            or model_forward_batch is None
-            or envs.ENABLE_V1_KVCACHE_SCHEDULER
-        ):
+        if self.guided_backend is None:
             return []
 
-        skip_idx_list = []
-        for task in model_forward_batch:
-            if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
-                continue
-            skip_idx_list.append(task.idx)
+        prefill_done_idxs = []
+        for idx in range(0, num_running_requests):
+            if self.share_inputs["step_idx"][idx] == 0:
+                prefill_done_idxs.append(idx)
 
-        for task in self.restore_chunked_prefill_request.values():
-            if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info):
-                continue
-            skip_idx_list.append(task.idx)
+        if self.cache_config.enable_chunked_prefill:
+            if model_forward_batch is not None:
+                for task in model_forward_batch:
+                    # new Request with ChunkPrefill, unfinished, store
+                    if task.chunk_idx < len(task.prefill_chunk_info):
+                        if task.request_id not in self.restore_chunked_prefill_request:
+                            self.restore_chunked_prefill_request[task.request_id] = task
 
-        return skip_idx_list
+            for id, task in list(self.restore_chunked_prefill_request.items()):
+                # unfinished, remove
+                if task.chunk_idx < len(task.prefill_chunk_info) and task.idx in prefill_done_idxs:
+                    prefill_done_idxs.remove(task.idx)
+                # finished, add
+                if task.chunk_idx == len(task.prefill_chunk_info) and task.idx not in prefill_done_idxs:
+                    prefill_done_idxs.append(task.idx)
+
+        return prefill_done_idxs
 
     def execute_model(
         self,
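
The renamed helper no longer collects slots to skip; it collects slots whose prefill just finished, which is exactly when an async-compiled processor must be joined. The `step_idx == 0` test marks slots that have not yet produced a decode step, and the chunked-prefill pass then removes slots still mid-chunk and adds slots whose final chunk just completed. A toy illustration of the base index computation (plain Python with made-up values, not runner code):

    # Hypothetical per-slot decode step counters for 4 running requests:
    # slots 0 and 2 have step_idx == 0, i.e. prefill done but no decode step yet.
    step_idx = [0, 3, 0, 1]
    num_running_requests = 4

    prefill_done_idxs = [i for i in range(num_running_requests) if step_idx[i] == 0]
    assert prefill_done_idxs == [0, 2]
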
@@ -2050,9 +2053,10 @@ class GPUModelRunner(ModelRunnerBase):
             num_running_requests: batch_size
         """
         # 1. Prepare inputs of model and sampler.
-        skip_idx_list = self._get_skip_idx(model_forward_batch)
+        p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests)
+
         self._prepare_inputs()
-        self.sampler.pre_process(skip_idx_list)
+        self.sampler.pre_process(p_done_idxs)
 
         # 1.1 Update state of logits processor
         for proc in self.sampling_metadata.logits_processors:
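
`pre_process` now takes the prefill-done indices instead of a skip list; presumably this is where pending compilations for those slots are joined, so the cost of `.result()` is paid exactly once, at the moment the processor is first needed for sampling. A minimal sketch under that assumption (the class and `logits_processors` field are illustrative, not the actual sampler):

    from concurrent.futures import Future

    class SamplerSketch:
        def __init__(self):
            self.logits_processors = {}  # slot idx -> processor or Future

        def pre_process(self, p_done_idxs):
            # Join async-compiled processors for slots whose prefill just
            # completed, replacing each pending Future with its result.
            for idx in p_done_idxs:
                proc = self.logits_processors.get(idx)
                if isinstance(proc, Future):
                    self.logits_processors[idx] = proc.result()
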
@@ -2157,7 +2161,7 @@ class GPUModelRunner(ModelRunnerBase):
         sampler_output = self.sampler(
             logits,
             self.sampling_metadata,
-            skip_idx_list,
+            p_done_idxs,
         )
         if self.parallel_config.tensor_parallel_size > 1:
             paddle.distributed.broadcast(
@@ -2244,7 +2248,7 @@ class GPUModelRunner(ModelRunnerBase):
             line_break_id=self.model_config.line_break_id,
         )
         if self.guided_backend is not None and sampler_output is not None:
-            self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
+            self.sampler.post_process(sampler_output.sampled_token_ids)
 
         # 6. Speculative decode
         if self.speculative_decoding:
@@ -2268,7 +2272,6 @@ class GPUModelRunner(ModelRunnerBase):
             )
 
             self._update_chunked_prefill(model_forward_batch)
-            self._add_cache(model_forward_batch)
         elif self.speculative_decoding:
             speculate_schedule_cache(
                 self.share_inputs["draft_tokens"],
@@ -2325,24 +2328,6 @@ class GPUModelRunner(ModelRunnerBase):
 
         return pooler_output
 
-    def _add_cache(self, model_forward_batch) -> None:
-        """
-        Add cache for guided decoding.
-        """
-        if self.guided_backend is None or model_forward_batch is None:
-            return
-
-        for request in model_forward_batch:
-            logits_cached = request.get("logits_cached", None)
-            if logits_cached is None or logits_cached:
-                continue
-
-            request.logits_cached = True
-            if isinstance(request.logits_processor, LogitsProcessorBase):
-                self.guided_backend.add_cache(request.schemata_key, request.logits_processor)
-            else:
-                self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result())
-
     def _execute_empty_input(self) -> None:
         """
         In certain scenarios, such as during EP,
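
With `_add_cache` deleted from the runner, populating the schemata cache presumably happens inside the guided-decoding backend instead. The removed method still documents the pattern worth preserving: resolve a `Future` before caching, so the cache only ever holds fully compiled processors. A self-contained sketch of that idea (an illustrative class, not the actual backend):

    from concurrent.futures import Future
    from threading import Lock

    class SchemataCacheSketch:
        """Illustrative schemata_key -> compiled-processor cache."""

        def __init__(self):
            self._cache = {}
            self._lock = Lock()  # guard against concurrent compile workers

        def add_cache(self, schemata_key, processor):
            # Mirror of the removed runner logic: a Future is joined before
            # insertion, so readers never see a half-compiled entry.
            if isinstance(processor, Future):
                processor = processor.result()
            with self._lock:
                self._cache.setdefault(schemata_key, processor)

        def get(self, schemata_key):
            with self._lock:
                return self._cache.get(schemata_key)
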