[Feature] support fd return decode response (#4407)
* [Feature] support fd return decode response
* Resolving conflicts
* fix
* fix
* fix
* fix
* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
@@ -33,6 +33,7 @@ from opentelemetry import trace
 from fastdeploy.engine.request import Request, RequestOutput, RequestType
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
+from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import (
     EngineCacheQueue,
     EngineWorkerQueue,
@@ -149,6 +150,16 @@ class EngineService:
         if self.cfg.scheduler_config.splitwise_role != "mixed":
             self.split_mode_get_tasks()

+    def create_data_processor(self):
+        self.input_processor = InputPreprocessor(
+            self.cfg.model_config,
+            self.cfg.structured_outputs_config.reasoning_parser,
+            self.cfg.limit_mm_per_prompt,
+            self.cfg.mm_processor_kwargs,
+            self.cfg.tool_parser,
+        )
+        self.data_processor = self.input_processor.create_processor()
+
     def _init_worker_monitor_signals(self):  # exist_task_signal: used by each worker process to detect whether there are new Tasks to handle
         current_suffix = int(
             self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
@@ -831,9 +842,23 @@ class EngineService:
                     f"traceback={traceback.format_exc()}"
                 )

+    def _decode_token(self, token_ids, req_id, is_end):
+        delta_text = ""
+        if envs.FD_ENABLE_RETURN_TEXT:
+            delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
+            if delta_text != "":
+                prefix_offset = self.data_processor.decode_status[req_id][0]
+                read_offset = self.data_processor.decode_status[req_id][1]
+                token_ids = cum_tokens[prefix_offset:read_offset]
+            else:
+                token_ids = []
+            if is_end:
+                del self.data_processor.decode_status[req_id]
+        return delta_text, token_ids
+
     def _zmq_send_generated_tokens(self):
         """
-        Receive output for zmq
+        Recieve output for zmq
         """
         while self.running:
             try:
@@ -842,10 +867,31 @@ class EngineService:
                     time.sleep(0.005)
                     continue
                 for request_id, contents in results.items():
-                    self.send_response_server.send_response(request_id, contents)
-
+                    new_contents = []
+                    for content in contents:
+                        decode_type = content.outputs.decode_type
+                        delta_text = ""
+                        if decode_type == 0:
+                            delta_text, token_ids = self._decode_token(
+                                token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
+                            )
+                        else:
+                            token_ids = content.outputs.token_ids
+                        if len(token_ids):
+                            content.outputs.token_ids = token_ids
+                            content.outputs.text = delta_text
+                            new_contents.append(content)
+                        elif content.finished:
+                            new_contents.append(content)
+                        else:
+                            llm_logger.warning(
+                                f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
+                            )
+                    if len(new_contents):
+                        llm_logger.info(f"Send response for request id: {request_id}")
+                        self.send_response_server.send_response(request_id, new_contents)
             except Exception as e:
-                self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
+                llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")

     def split_mode_get_tasks(self):
         """
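To make the new decode path easier to follow: _decode_token reads the per-request bookkeeping that the data processor keeps in decode_status, where index 0 is a prefix offset and index 1 a read offset into the request's accumulated token ids. The snippet below is a minimal, self-contained toy sketch of that offset pattern, not FastDeploy's implementation; ToyIncrementalDecoder, TOY_VOCAB, and the two-value return of its ids2tokens are illustrative assumptions (the real ids2tokens shown above returns a third value and runs an actual tokenizer).

    # Toy sketch only -- not FastDeploy code. Illustrates the per-request
    # (prefix_offset, read_offset, accumulated ids) bookkeeping that
    # _decode_token reads out of decode_status.
    TOY_VOCAB = {0: "Hel", 1: "lo", 2: ",", 3: " wor", 4: "ld", 5: "!"}


    class ToyIncrementalDecoder:
        def __init__(self):
            # req_id -> [prefix_offset, read_offset, accumulated token ids]
            self.decode_status = {}

        def _decode(self, ids):
            return "".join(TOY_VOCAB[i] for i in ids)

        def ids2tokens(self, token_ids, req_id):
            prefix_offset, read_offset, all_ids = self.decode_status.setdefault(req_id, [0, 0, []])
            all_ids.extend(token_ids)
            previous_text = self._decode(all_ids[prefix_offset:read_offset])
            new_text = self._decode(all_ids[prefix_offset:])
            delta_text = new_text[len(previous_text):]
            if delta_text:
                # Advance the window only when new text was actually produced.
                self.decode_status[req_id][0] = read_offset
                self.decode_status[req_id][1] = len(all_ids)
            return delta_text, all_ids


    decoder = ToyIncrementalDecoder()
    print(decoder.ids2tokens([0, 1], "req-1"))   # ('Hello', [0, 1])
    print(decoder.ids2tokens([2, 3], "req-1"))   # (', wor', [0, 1, 2, 3])
    del decoder.decode_status["req-1"]           # mirrors the is_end cleanup above

The final del corresponds to the is_end branch of _decode_token, which drops the per-request entry so finished requests do not leak decoding state.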
@@ -38,7 +38,6 @@ from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.common_engine import EngineService
 from fastdeploy.engine.expert_service import start_data_parallel_service
 from fastdeploy.engine.request import Request
-from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -87,13 +86,6 @@ class LLMEngine:
         self.running = True
         self.is_started = False

-        self.input_processor = InputPreprocessor(
-            cfg.model_config,
-            cfg.structured_outputs_config.reasoning_parser,
-            cfg.limit_mm_per_prompt,
-            cfg.mm_processor_kwargs,
-            cfg.tool_parser,
-        )
         self.engine = EngineService(cfg)

         if self.cfg.cache_config.num_gpu_blocks_override is None:
@@ -117,12 +109,12 @@ class LLMEngine:
         self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
         self._init_worker_signals()

-        self.data_processor = self.input_processor.create_processor()
-        self.engine.data_processor = self.data_processor
         # Launch components: scheduler, cache_manager, expert_service et.al.
         self.launch_components()

         self.engine.start()
+        self.engine.create_data_processor()
+        self.data_processor = self.engine.data_processor

         # If block numer is specified and model is deployed in mixed mode, start cache manager first
         if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
@@ -246,7 +238,7 @@ class LLMEngine:
             chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
             chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
             kwargs["chat_template_kwargs"] = chat_template_kwargs
-        request = self.data_processor.process_request(request, self.cfg.model_config.max_model_len, **kwargs)
+        request = self.engine.data_processor.process_request(request, self.cfg.model_config.max_model_len, **kwargs)
         request.prompt_token_ids_len = len(request.prompt_token_ids)
         request.need_prefill_tokens = request.prompt_token_ids_len
         input_ids_len = request.prompt_token_ids_len
@@ -482,9 +474,9 @@ class LLMEngine:
         py_script = os.path.join(current_dir_path, worker_path)

         ori_vocab_size = (
-            len(self.data_processor.tokenizer.sp_model)
-            if hasattr(self.data_processor.tokenizer, "sp_model")
-            else len(self.data_processor.tokenizer.vocab)
+            len(self.engine.data_processor.tokenizer.sp_model)
+            if hasattr(self.engine.data_processor.tokenizer, "sp_model")
+            else len(self.engine.data_processor.tokenizer.vocab)
         )

         think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
@@ -511,8 +503,8 @@ class LLMEngine:
             f" --total_block_num {self.cfg.cache_config.total_block_num}"
             f" --block_size {self.cfg.cache_config.block_size}"
             f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
-            f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
-            f" --pad_token_id {self.data_processor.pad_token_id}"
+            f" --eos_tokens_lens {self.engine.data_processor.eos_token_id_len}"
+            f" --pad_token_id {self.engine.data_processor.pad_token_id}"
             f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}"
             f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}"
             f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
@@ -611,7 +603,7 @@ class LLMEngine:
            for result in self._get_generated_tokens(req_id):
                is_end = result.finished
                if stream and not is_end:
-                    processed = self.data_processor.process_response(result)
+                    processed = self.engine.data_processor.process_response(result)
                    if processed is None:
                        continue
                    output = processed.to_dict()
@@ -619,7 +611,7 @@ class LLMEngine:

            # Exit loop if termination condition is met
            if is_end:
-                processed = self.data_processor.process_response(result)
+                processed = self.engine.data_processor.process_response(result)
                output = processed.to_dict()
                llm_logger.debug(f"Generate result: {output}")
                if not stream:
@@ -90,6 +90,8 @@ class ExpertService:

         start_time = time.time()
         self.engine.start()
+        if envs.FD_ENABLE_RETURN_TEXT:
+            self.engine.create_data_processor()
         if self.cfg.scheduler_config.name == "dp":
             self.cfg.init_cache_info()
             assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
@@ -118,6 +118,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
     # Whether to clear cpu cache when clearing model weights.
     "FD_ENABLE_SWAP_SPACE_CLEARING": lambda: int(os.getenv("FD_ENABLE_SWAP_SPACE_CLEARING", "0")),
+    # enable return text, used when FD_ENABLE_INTERNAL_ADAPTER=1
+    "FD_ENABLE_RETURN_TEXT": lambda: bool(int(os.getenv("FD_ENABLE_RETURN_TEXT", "0"))),
     # Used to truncate the string inserted during thinking when reasoning in a model. (</think> for ernie4_5_vl, \n</think>\n\n for ernie_x1)
     "FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", "</think>"),
     # Timeout for cache_transfer_manager process exit
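Usage note, as a hedged sketch built only from the names visible in the hunk above: the flag is parsed with bool(int(os.getenv(...))), so it is switched on by setting the string "1" in the engine process's environment, and the accompanying comment says it is intended for deployments already running with FD_ENABLE_INTERNAL_ADAPTER=1.

    import os

    # Set both flags in the environment of the process that runs the engine;
    # the entries above are lambdas, so values are read via os.getenv on access.
    os.environ["FD_ENABLE_INTERNAL_ADAPTER"] = "1"
    os.environ["FD_ENABLE_RETURN_TEXT"] = "1"

    # Mirrors the parsing shown above: anything other than a non-zero integer
    # string (e.g. "1") leaves the feature disabled.
    assert bool(int(os.getenv("FD_ENABLE_RETURN_TEXT", "0")))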