mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
feat: add post-processing step for pool_output (#4462)
* feat: add post-processing step for pool_output
* bugfix
* fix: test_serving_embedding
* fix test_request_to_batch_dicts
* fix: code style
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
"""
|
||||
|
||||
import base64
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Literal, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -99,11 +100,13 @@ class OpenAIServingEmbedding(ZmqOpenAIServing):
|
||||
|
||||
for idx, prompt in enumerate(request_prompts):
|
||||
request_dict = self._request_to_dict(ctx)
|
||||
request_dict["request_id"] = f"{ctx.request_id}-{idx}"
|
||||
request_dict["request_id"] = f"{ctx.request_id}_{idx}"
|
||||
request_dict["prompt"] = prompt
|
||||
request_dicts.append(request_dict)
|
||||
else:
|
||||
request_dicts = [self._request_to_dict(ctx)]
|
||||
request_dict = self._request_to_dict(ctx)
|
||||
request_dict["request_id"] = f"{ctx.request_id}_0"
|
||||
request_dicts = [request_dict]
|
||||
return request_dicts
|
||||
|
||||
async def create_embedding(self, request: EmbeddingRequest):
|
||||
@@ -118,9 +121,20 @@ class OpenAIServingEmbedding(ZmqOpenAIServing):
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
generation = self.handle(ctx)
|
||||
async for response in generation:
|
||||
return response
|
||||
idx = 0
|
||||
response: EmbeddingResponse = None
|
||||
generators: AsyncGenerator[EmbeddingResponse, None] = self.handle(ctx)
|
||||
async for r in generators:
|
||||
r.data[0].index = idx
|
||||
idx += 1
|
||||
if response is None:
|
||||
response = r
|
||||
else:
|
||||
response.data.append(r.data[0])
|
||||
response.usage.prompt_tokens += r.usage.prompt_tokens
|
||||
response.usage.total_tokens += r.usage.total_tokens
|
||||
|
||||
return response
|
||||
|
||||
@override
|
||||
def _build_response(self, ctx: ServeContext):
|
||||
|
||||
Reference in New Issue
Block a user