[Feature] support audio tts (#5333)

This commit is contained in:
ming1753
2025-12-03 21:06:48 +08:00
committed by GitHub
parent 83dbc4e5dd
commit 5f8d4aedea
5 changed files with 86 additions and 9 deletions
@@ -14,7 +14,8 @@
# limitations under the License.
"""
from typing import Any, List, Optional
import inspect
from typing import Any, Dict, List, Optional
from fastdeploy.entrypoints.openai.usage_calculator import count_tokens
from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest
@@ -34,12 +35,14 @@ class ChatResponseProcessor:
data_processor,
enable_mm_output: Optional[bool] = False,
eoi_token_id: Optional[int] = 101032,
eoa_token_id: Optional[int] = 2048,
eos_token_id: Optional[int] = 2,
decoder_base_url: Optional[str] = None,
):
self.data_processor = data_processor
self.enable_mm_output = enable_mm_output
self.eoi_token_id = eoi_token_id
self.eoa_token_id = eoa_token_id
self.eos_token_id = eos_token_id
if decoder_base_url is not None:
self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url)
@@ -47,6 +50,7 @@ class ChatResponseProcessor:
self.decoder_client = None
self._mm_buffer: List[Any] = [] # Buffer for accumulating image token_ids
self._end_image_code_request_output: Optional[Any] = None
self._audio_buffer: Dict[Any] = {}
self._multipart_buffer = []
def enable_multimodal_content(self):
@@ -80,16 +84,54 @@ class ChatResponseProcessor:
for request_output in request_outputs:
api_server_logger.debug(f"request_output {request_output}")
if not self.enable_mm_output:
yield self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
outputs = request_output.get("outputs", None)
token_ids = outputs.get("token_ids", None) if outputs is not None else None
req_id = request_output.get("request_id", None)
if outputs is not None and token_ids is not None and req_id is not None:
decode_type = request_output["outputs"].get("decode_type", 0) or 0
if decode_type == 0: # text
tts = req_id in self._audio_buffer
if token_ids[-1] == self.eos_token_id:
all_audio_tokens = self._audio_buffer.pop(req_id, [])
else:
all_audio_tokens = None
if inspect.iscoroutinefunction(self.data_processor.process_response_dict):
response = await self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
audio_tokens=all_audio_tokens,
tts=tts,
)
else:
response = self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
audio_tokens=all_audio_tokens,
tts=tts,
)
yield response
elif decode_type == 2: # audio
if self.eoa_token_id is not None and self.eoa_token_id in token_ids:
continue
if req_id in self._audio_buffer:
self._audio_buffer[req_id].append(token_ids)
else:
self._audio_buffer[req_id] = [token_ids]
else:
yield self.data_processor.process_response_dict(
response_dict=request_output,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
elif stream:
decode_type = request_output["outputs"].get("decode_type", 0)
token_ids = request_output["outputs"]["token_ids"]
if decode_type == 0:
if decode_type == 0: # text
if self.eoi_token_id and self.eoi_token_id in token_ids:
if self._mm_buffer:
all_tokens = self._mm_buffer
@@ -118,7 +160,7 @@ class ChatResponseProcessor:
request_output["outputs"]["multipart"] = [text]
yield request_output
elif decode_type == 1:
elif decode_type == 1: # image
self._mm_buffer.append(token_ids)
self._end_image_code_request_output = request_output
else: