mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[LogProbs]Enable prompt logprobs output and modify data transmission method for the online interface. (#5089)
* add prompt logprobs * Merge prompt_logprobs_tensors and prompt_logprobs * fix param check * trigger ci * fix unitest * fix logprobs bug
This commit is contained in:
@@ -229,8 +229,15 @@ class ModelConfig:
|
||||
self.think_end_id = args.get("think_end_id", -1)
|
||||
self.im_patch_id = args.get("image_patch_id", -1)
|
||||
self.line_break_id = args.get("line_break_id", -1)
|
||||
if self.max_logprobs < -1:
|
||||
|
||||
num_max_logprobs = args.get("max_logprobs", None)
|
||||
if num_max_logprobs is not None and num_max_logprobs < -1:
|
||||
raise ValueError(" The possible values for max_logprobs can't be less than -1 ")
|
||||
if self.ori_vocab_size is not None and num_max_logprobs is not None:
|
||||
if num_max_logprobs > self.ori_vocab_size:
|
||||
raise ValueError(
|
||||
f" The possible values for max_logprobs can't be greater than the vocabulary size {self.ori_vocab_size}"
|
||||
)
|
||||
|
||||
self._post_init()
|
||||
|
||||
|
||||
@@ -31,12 +31,7 @@ from fastdeploy.engine.pooling_params import PoolingParams
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.entrypoints.openai.protocol import ToolCall
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
from fastdeploy.worker.output import (
|
||||
LogprobsLists,
|
||||
LogprobsTensors,
|
||||
PromptLogprobs,
|
||||
SampleLogprobs,
|
||||
)
|
||||
from fastdeploy.worker.output import LogprobsLists, PromptLogprobs, SampleLogprobs
|
||||
|
||||
|
||||
class RequestStatus(Enum):
|
||||
@@ -519,7 +514,6 @@ class RequestOutput:
|
||||
prompt: Optional[str] = None,
|
||||
prompt_token_ids: Optional[list[int]] = None,
|
||||
prompt_logprobs: Optional[PromptLogprobs] = None,
|
||||
prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
|
||||
output_type: Optional[int] = 3,
|
||||
outputs: CompletionOutput = None,
|
||||
finished: bool = False,
|
||||
@@ -537,7 +531,6 @@ class RequestOutput:
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.prompt_logprobs_tensors = prompt_logprobs_tensors
|
||||
self.output_type = output_type
|
||||
self.outputs = outputs
|
||||
self.finished = finished
|
||||
|
||||
@@ -16,12 +16,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from fastdeploy import envs
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingParams:
|
||||
@@ -207,12 +208,17 @@ class SamplingParams:
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
|
||||
)
|
||||
if self.logprobs is not None and self.logprobs < -1:
|
||||
raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
|
||||
if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
|
||||
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
|
||||
raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")
|
||||
|
||||
if not envs.FD_USE_GET_SAVE_OUTPUT_V1: # False (0)
|
||||
if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 20):
|
||||
raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
|
||||
if self.prompt_logprobs is not None:
|
||||
raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.")
|
||||
else: # True (1)
|
||||
if self.logprobs is not None and self.logprobs < -1:
|
||||
raise ValueError(f"logprobs must be a non-negative value or -1, got {self.logprobs}.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
|
||||
raise ValueError(f"prompt_logprobs a must be non-negative value or -1, got {self.prompt_logprobs}.")
|
||||
|
||||
if not 0 <= self.seed <= 922337203685477580:
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
|
||||
|
||||
@@ -56,10 +56,11 @@ class EngineClient:
|
||||
EngineClient is a class that handles the communication between the client and the server.
|
||||
"""
|
||||
|
||||
def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1):
|
||||
def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20):
|
||||
self.fd_config = fd_config
|
||||
self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size
|
||||
self.enable_mm = self.fd_config.model_config.enable_mm
|
||||
self.max_logprobs = max_logprobs
|
||||
input_processor = InputPreprocessor(
|
||||
self.fd_config.model_config,
|
||||
self.fd_config.structured_outputs_config.reasoning_parser,
|
||||
@@ -70,6 +71,11 @@ class EngineClient:
|
||||
)
|
||||
self.enable_logprob = self.fd_config.model_config.enable_logprob
|
||||
self.data_processor = input_processor.create_processor()
|
||||
self.ori_vocab_size = (
|
||||
len(self.data_processor.tokenizer.sp_model)
|
||||
if hasattr(self.data_processor.tokenizer, "sp_model")
|
||||
else len(self.data_processor.tokenizer.vocab)
|
||||
)
|
||||
self.max_model_len = self.fd_config.model_config.max_model_len
|
||||
self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching
|
||||
self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed"
|
||||
@@ -424,6 +430,53 @@ class EngineClient:
|
||||
elif logprobs:
|
||||
raise ParameterError("logprobs", "Invalid type for 'logprobs'")
|
||||
|
||||
max_logprobs = self.max_logprobs
|
||||
if max_logprobs == -1:
|
||||
max_logprobs = self.ori_vocab_size
|
||||
if max_logprobs < -1:
|
||||
err_msg = f"Invalid 'max_logprobs': must be >= -1, got {max_logprobs}."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("max_logprobs", err_msg)
|
||||
if max_logprobs > self.ori_vocab_size:
|
||||
err_msg = f"Invalid 'max_logprobs': must be <= vocab_size {self.ori_vocab_size}, got {max_logprobs}."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("max_logprobs", err_msg)
|
||||
|
||||
prompt_logprobs = data.get("prompt_logprobs", None)
|
||||
|
||||
if prompt_logprobs is not None:
|
||||
if not self.enable_logprob:
|
||||
err_msg = "`enable_logprob` is disabled, please enable it in startup config."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ParameterError("prompt_logprobs", err_msg)
|
||||
|
||||
if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
err_msg = "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ParameterError("prompt_logprobs", err_msg)
|
||||
|
||||
if self.enable_prefix_caching:
|
||||
err_msg = "prompt_logprobs is not support when prefix caching is enabled."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ParameterError("prompt_logprobs", err_msg)
|
||||
|
||||
if prompt_logprobs == -1 and self.ori_vocab_size > max_logprobs:
|
||||
err_msg = f"The requested value of ({self.ori_vocab_size}) for prompt_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("prompt_logprobs", err_msg)
|
||||
|
||||
if prompt_logprobs < -1:
|
||||
err_msg = (
|
||||
f"prompt_logprobs must be a non-negative value or -1; the current value is {prompt_logprobs}."
|
||||
)
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("prompt_logprobs", err_msg)
|
||||
|
||||
if prompt_logprobs > max_logprobs:
|
||||
err_msg = f"Number of prompt_logprobs requested ({prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("prompt_logprobs", err_msg)
|
||||
|
||||
# enable_logprob
|
||||
if top_logprobs:
|
||||
if not self.enable_logprob:
|
||||
@@ -437,15 +490,26 @@ class EngineClient:
|
||||
api_server_logger.error(err_msg)
|
||||
raise ParameterError("top_logprobs", err_msg)
|
||||
|
||||
if top_logprobs < 0:
|
||||
err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ParameterError("top_logprobs", err_msg)
|
||||
if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
if top_logprobs < 0 or top_logprobs > 20:
|
||||
err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("top_logprobs", err_msg)
|
||||
else:
|
||||
if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
|
||||
err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("top_logprobs", err_msg)
|
||||
|
||||
if top_logprobs > 20:
|
||||
err_msg = "Invalid value for 'top_logprobs': must be <= 20."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ParameterError("top_logprobs", err_msg)
|
||||
if top_logprobs < -1:
|
||||
err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("top_logprobs", err_msg)
|
||||
|
||||
if top_logprobs > max_logprobs:
|
||||
err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
|
||||
api_server_logger.error(err_msg)
|
||||
raise ValueError("top_logprobs", err_msg)
|
||||
|
||||
def check_health(self, time_interval_threashold=30):
|
||||
"""
|
||||
|
||||
@@ -335,23 +335,40 @@ class LLM:
|
||||
current_sampling_params = sampling_params[i]
|
||||
else:
|
||||
current_sampling_params = sampling_params
|
||||
if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None:
|
||||
raise ValueError("prompt_logprobs is not supported with streaming.")
|
||||
|
||||
ori_vocab_size = (
|
||||
len(self.llm_engine.data_processor.tokenizer.sp_model)
|
||||
if hasattr(self.llm_engine.data_processor.tokenizer, "sp_model")
|
||||
else len(self.llm_engine.data_processor.tokenizer.vocab)
|
||||
)
|
||||
max_logprobs = self.llm_engine.cfg.model_config.max_logprobs
|
||||
if max_logprobs == -1:
|
||||
max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
|
||||
max_logprobs = ori_vocab_size
|
||||
if max_logprobs < -1:
|
||||
raise ValueError(f"max_logprobs ({max_logprobs}) can't be less than -1.")
|
||||
if max_logprobs > ori_vocab_size:
|
||||
raise ValueError(f"max_logprobs ({max_logprobs}) exceeds vocabulary size ({ori_vocab_size}).")
|
||||
|
||||
if current_sampling_params.logprobs is not None:
|
||||
num_logprobs = current_sampling_params.logprobs
|
||||
if num_logprobs == -1:
|
||||
num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
|
||||
if num_logprobs == -1 and ori_vocab_size > max_logprobs:
|
||||
raise ValueError(
|
||||
f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
|
||||
)
|
||||
if num_logprobs > max_logprobs:
|
||||
raise ValueError(
|
||||
f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
|
||||
)
|
||||
if current_sampling_params.prompt_logprobs is not None:
|
||||
if self.llm_engine.cfg.cache_config.enable_prefix_caching:
|
||||
raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
|
||||
if kwargs.get("stream"):
|
||||
raise ValueError("prompt_logprobs is not supported with streaming.")
|
||||
num_prompt_logprobs = current_sampling_params.prompt_logprobs
|
||||
if num_prompt_logprobs == -1:
|
||||
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
|
||||
if num_prompt_logprobs == -1 and ori_vocab_size > max_logprobs:
|
||||
raise ValueError(
|
||||
f"Number of prompt_logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
|
||||
)
|
||||
if num_prompt_logprobs > max_logprobs:
|
||||
raise ValueError(
|
||||
f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
|
||||
@@ -436,7 +453,7 @@ class LLM:
|
||||
prompt_token_ranks = ranks.tolist()
|
||||
prompt_logprobs = logprobs.tolist()
|
||||
token_ids = token_ids.tolist()
|
||||
result: Optional[PromptLogprobs] = []
|
||||
result: Optional[PromptLogprobs] = [None]
|
||||
# Make Logprob for each position.
|
||||
for pos in range(num_prompt_tokens):
|
||||
# Handle flattening.
|
||||
@@ -548,11 +565,11 @@ class LLM:
|
||||
result.outputs.logprobs = self._build_sample_logprobs(
|
||||
result.outputs.top_logprobs, topk_logprobs
|
||||
)
|
||||
if result.prompt_logprobs_tensors and num_prompt_logprobs:
|
||||
if result.prompt_logprobs is not None and num_prompt_logprobs is not None:
|
||||
if num_prompt_logprobs == -1:
|
||||
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
|
||||
result.prompt_logprobs = self._build_prompt_logprobs(
|
||||
result.prompt_logprobs_tensors, num_prompt_logprobs
|
||||
result.prompt_logprobs, num_prompt_logprobs
|
||||
)
|
||||
|
||||
output[pos] = result
|
||||
|
||||
@@ -181,6 +181,7 @@ async def lifespan(app: FastAPI):
|
||||
port=int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "0")),
|
||||
fd_config=fd_config,
|
||||
workers=args.workers,
|
||||
max_logprobs=args.max_logprobs,
|
||||
)
|
||||
await engine_client.connection_manager.initialize()
|
||||
app.state.dynamic_load_weight = args.dynamic_load_weight
|
||||
|
||||
@@ -21,9 +21,17 @@ import time
|
||||
import uuid
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
ValidationInfo,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from fastdeploy.engine.pooling_params import PoolingParams
|
||||
from fastdeploy.worker.output import PromptLogprobs
|
||||
|
||||
|
||||
class InvalidParameterException(Exception):
|
||||
@@ -214,10 +222,12 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
Chat completion response choice.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
index: int
|
||||
message: ChatMessage
|
||||
logprobs: Optional[LogProbs] = None
|
||||
draft_logprobs: Optional[LogProbs] = None
|
||||
prompt_logprobs: Optional[PromptLogprobs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
|
||||
|
||||
|
||||
@@ -275,10 +285,12 @@ class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
Chat completion response choice for stream response.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
logprobs: Optional[LogProbs] = None
|
||||
draft_logprobs: Optional[LogProbs] = None
|
||||
prompt_logprobs: Optional[PromptLogprobs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
|
||||
arrival_time: Optional[float] = None
|
||||
|
||||
@@ -301,6 +313,7 @@ class CompletionResponseChoice(BaseModel):
|
||||
Completion response choice.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
index: int
|
||||
text: str
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
@@ -310,6 +323,7 @@ class CompletionResponseChoice(BaseModel):
|
||||
arrival_time: Optional[float] = None
|
||||
logprobs: Optional[CompletionLogprobs] = None
|
||||
draft_logprobs: Optional[CompletionLogprobs] = None
|
||||
prompt_logprobs: Optional[PromptLogprobs] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
|
||||
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
|
||||
@@ -344,11 +358,13 @@ class CompletionResponseStreamChoice(BaseModel):
|
||||
Completion response choice for stream response.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
index: int
|
||||
text: str
|
||||
arrival_time: float = None
|
||||
logprobs: Optional[CompletionLogprobs] = None
|
||||
draft_logprobs: Optional[CompletionLogprobs] = None
|
||||
prompt_logprobs: Optional[PromptLogprobs] = None
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
completion_token_ids: Optional[List[int]] = None
|
||||
prompt_tokens: Optional[str] = None
|
||||
@@ -437,6 +453,7 @@ class CompletionRequest(BaseModel):
|
||||
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
|
||||
logprobs: Optional[int] = None
|
||||
include_draft_logprobs: Optional[bool] = False
|
||||
prompt_logprobs: Optional[int] = None
|
||||
# For logits and logprobs post processing
|
||||
temp_scaled_logprobs: bool = False
|
||||
top_p_normalized_logprobs: bool = False
|
||||
@@ -569,6 +586,18 @@ class CompletionRequest(BaseModel):
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
    """Reject `logprobs` / `prompt_logprobs` values below -1 before field validation.

    Runs as a "before" model validator, so `data` is the raw request payload.
    Only the lower bound is enforced here: any value >= -1 passes this check.

    Args:
        data: raw request payload; read with `.get`, so a mapping is expected.

    Returns:
        The unmodified `data`.

    Raises:
        ValueError: if `logprobs` or `prompt_logprobs` is present and < -1.
    """
    if (logprobs := data.get("logprobs")) is not None:
        if logprobs < -1:
            # -1 itself is accepted, so the message must say "or equal to".
            raise ValueError("`logprobs` must be greater than or equal to -1.")

    if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
        if prompt_logprobs < -1:
            raise ValueError("`prompt_logprobs` must be greater than or equal to -1.")

    return data
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""
|
||||
@@ -583,6 +612,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
|
||||
logprobs: Optional[bool] = False
|
||||
top_logprobs: Optional[int] = 0
|
||||
prompt_logprobs: Optional[int] = None
|
||||
include_draft_logprobs: Optional[bool] = False
|
||||
|
||||
# For logits and logprobs post processing
|
||||
@@ -651,6 +681,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
|
||||
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
|
||||
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
|
||||
req_dict["prompt_logprobs"] = self.prompt_logprobs
|
||||
req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
|
||||
req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs
|
||||
|
||||
@@ -751,12 +782,15 @@ class ChatCompletionRequest(BaseModel):
|
||||
def check_logprobs(cls, data):
|
||||
|
||||
if (top_logprobs := data.get("top_logprobs")) is not None:
|
||||
if top_logprobs < 0:
|
||||
raise ValueError("`top_logprobs` must be a positive value.")
|
||||
if top_logprobs < -1:
|
||||
raise ValueError("`top_logprobs` must be a greater than -1.")
|
||||
|
||||
if top_logprobs > 0 and not data.get("logprobs"):
|
||||
if not data.get("logprobs"):
|
||||
raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
|
||||
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if prompt_logprobs < -1:
|
||||
raise ValueError("`prompt_logprobs` must be a greater than -1.")
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Iterable
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -47,9 +49,17 @@ from fastdeploy.utils import (
|
||||
ErrorType,
|
||||
ParameterError,
|
||||
api_server_logger,
|
||||
clamp_prompt_logprobs,
|
||||
get_host_ip,
|
||||
)
|
||||
from fastdeploy.worker.output import LogprobsLists
|
||||
from fastdeploy.worker.output import (
|
||||
Logprob,
|
||||
LogprobsLists,
|
||||
LogprobsTensors,
|
||||
PromptLogprobs,
|
||||
)
|
||||
|
||||
NONES = itertools.repeat(None)
|
||||
|
||||
|
||||
class OpenAIServingChat:
|
||||
@@ -287,6 +297,17 @@ class OpenAIServingChat:
|
||||
num_input_image_tokens = res.get("num_input_image_tokens", 0)
|
||||
num_input_video_tokens = res.get("num_input_video_tokens", 0)
|
||||
for i in range(num_choices):
|
||||
prompt_logprobs_res: Optional[PromptLogprobs] = None
|
||||
prompt_logprobs_tensors = res.get("prompt_logprobs", None)
|
||||
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
|
||||
num_prompt_logprobs = (
|
||||
request.prompt_logprobs
|
||||
if request.prompt_logprobs != -1
|
||||
else self.engine_client.ori_vocab_size
|
||||
)
|
||||
prompt_logprobs_res = self._build_prompt_logprobs(
|
||||
prompt_logprobs_tensors, num_prompt_logprobs
|
||||
)
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(
|
||||
@@ -296,6 +317,7 @@ class OpenAIServingChat:
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=None,
|
||||
),
|
||||
prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res),
|
||||
)
|
||||
if response_processor.enable_multimodal_content():
|
||||
choice.delta.multimodal_content = [
|
||||
@@ -344,12 +366,16 @@ class OpenAIServingChat:
|
||||
logprobs_res: Optional[LogProbs] = None
|
||||
draft_logprobs_res: Optional[LogProbs] = None
|
||||
if request.logprobs and output_top_logprobs is not None:
|
||||
logprobs_res = self._create_chat_logprobs(
|
||||
output_top_logprobs, request.logprobs, request.top_logprobs
|
||||
num_top_logprobs = (
|
||||
request.top_logprobs if request.top_logprobs != -1 else self.engine_client.ori_vocab_size
|
||||
)
|
||||
logprobs_res = self._create_chat_logprobs(
|
||||
output_top_logprobs, request.logprobs, num_top_logprobs
|
||||
)
|
||||
|
||||
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
|
||||
draft_logprobs_res = self._create_chat_logprobs(
|
||||
output_draft_top_logprobs, request.logprobs, request.top_logprobs
|
||||
output_draft_top_logprobs, request.logprobs, num_top_logprobs
|
||||
)
|
||||
|
||||
delta_message = DeltaMessage(
|
||||
@@ -496,6 +522,7 @@ class OpenAIServingChat:
|
||||
enable_mm_output=self.enable_mm_output,
|
||||
decoder_base_url=self.tokenizer_base_url,
|
||||
)
|
||||
prompt_logprobs_res_list = [[] for _ in range(num_choices)]
|
||||
choices = []
|
||||
while num_choices > 0:
|
||||
if self.engine_client.check_model_weight_status():
|
||||
@@ -538,9 +565,12 @@ class OpenAIServingChat:
|
||||
output_top_logprobs = output["top_logprobs"]
|
||||
output_draft_top_logprobs = output["draft_top_logprobs"]
|
||||
if output_top_logprobs is not None:
|
||||
num_top_logprobs = (
|
||||
request.top_logprobs if request.top_logprobs != -1 else self.engine_client.ori_vocab_size
|
||||
)
|
||||
# logprobs
|
||||
logprobs_res = self._create_chat_logprobs(
|
||||
output_top_logprobs, request.logprobs, request.top_logprobs
|
||||
output_top_logprobs, request.logprobs, num_top_logprobs
|
||||
)
|
||||
if logprobs_res and logprobs_res.content is not None:
|
||||
logprob_contents[idx].extend(logprobs_res.content)
|
||||
@@ -548,11 +578,20 @@ class OpenAIServingChat:
|
||||
# draft_logprobs
|
||||
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
|
||||
draft_logprobs_res = self._create_chat_logprobs(
|
||||
output_draft_top_logprobs, request.logprobs, request.top_logprobs
|
||||
output_draft_top_logprobs, request.logprobs, num_top_logprobs
|
||||
)
|
||||
if draft_logprobs_res and draft_logprobs_res.content is not None:
|
||||
draft_logprob_contents[idx].extend(draft_logprobs_res.content)
|
||||
|
||||
prompt_logprobs_tensors = data.get("prompt_logprobs", None)
|
||||
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
|
||||
num_prompt_logprobs = (
|
||||
request.prompt_logprobs
|
||||
if request.prompt_logprobs != -1
|
||||
else self.engine_client.ori_vocab_size
|
||||
)
|
||||
prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs)
|
||||
if prompt_logprobs_res:
|
||||
prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
|
||||
if data["finished"]:
|
||||
num_choices -= 1
|
||||
reasoning_num_tokens[idx] = data["outputs"].get("reasoning_token_num", 0)
|
||||
@@ -573,6 +612,7 @@ class OpenAIServingChat:
|
||||
logprob_contents=logprob_contents,
|
||||
draft_logprob_contents=draft_logprob_contents,
|
||||
response_processor=response_processor,
|
||||
prompt_logprobs_res_list=prompt_logprobs_res_list,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
choices.append(choice)
|
||||
@@ -624,6 +664,7 @@ class OpenAIServingChat:
|
||||
num_image_tokens: list,
|
||||
logprob_contents: list,
|
||||
draft_logprob_contents: list,
|
||||
prompt_logprobs_res_list: list,
|
||||
response_processor: ChatResponseProcessor,
|
||||
max_tokens: int,
|
||||
) -> ChatCompletionResponseChoice:
|
||||
@@ -649,11 +690,14 @@ class OpenAIServingChat:
|
||||
message.content = output["text"]
|
||||
|
||||
logprobs_full_res = None
|
||||
draft_logprobs_full_res = None
|
||||
prompt_logprobs_full_res = None
|
||||
if logprob_contents[idx]:
|
||||
logprobs_full_res = LogProbs(content=logprob_contents[idx])
|
||||
draft_logprobs_full_res = None
|
||||
if draft_logprob_contents[idx]:
|
||||
draft_logprobs_full_res = LogProbs(content=draft_logprob_contents[idx])
|
||||
if prompt_logprobs_res_list[idx]:
|
||||
prompt_logprobs_full_res = prompt_logprobs_res_list[idx]
|
||||
|
||||
num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
|
||||
num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
|
||||
@@ -675,6 +719,7 @@ class OpenAIServingChat:
|
||||
message=message,
|
||||
logprobs=logprobs_full_res,
|
||||
draft_logprobs=draft_logprobs_full_res,
|
||||
prompt_logprobs=prompt_logprobs_full_res,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
@@ -780,3 +825,86 @@ class OpenAIServingChat:
|
||||
else:
|
||||
enable_thinking = True
|
||||
return enable_thinking
|
||||
|
||||
def _build_prompt_logprobs(
    self,
    prompt_logprobs_tensors: LogprobsTensors,
    num_prompt_logprobs: int,
) -> PromptLogprobs:
    """Convert the worker's prompt-logprob tensors into per-position Logprob dicts.

    Args:
        prompt_logprobs_tensors: unpacked as (token_ids, logprobs, ranks).
            `logprobs` must be 2-D, [num_prompt_tokens, num_logprobs];
            `token_ids` is flattened and sliced with the same row width.
        num_prompt_logprobs: number of top logprobs requested by the user;
            forwarded unchanged to `_make_logprob_dict`.

    Returns:
        A list whose first element is None, followed by one
        `dict[token_id, Logprob]` per row of `logprobs`.
        NOTE(review): the leading None presumably stands for the first prompt
        token, which has no preceding context to score against -- confirm.
    """
    token_ids, logprobs, ranks = prompt_logprobs_tensors

    # Detokenize non-incrementally, one token id at a time.
    # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
    decoded_tokens = [
        self.engine_client.data_processor.process_logprob_response(token_id)
        for token_id in token_ids.flatten().tolist()
    ]

    # Recover shapes.
    num_prompt_tokens, num_logprobs = logprobs.shape

    # Convert the tensors to plain Python lists.
    prompt_token_ranks = ranks.tolist()
    prompt_logprobs = logprobs.tolist()
    token_ids = token_ids.tolist()

    result: PromptLogprobs = [None]
    # Make Logprob for each position.
    for pos in range(num_prompt_tokens):
        # Slice this position's row out of the flattened decoded tokens.
        # `decoded_tokens` is always a list here, so no None fallback is needed.
        offset = pos * num_logprobs
        offset_end = offset + num_logprobs

        # Update with the Logprob dictionary for this pos.
        result.append(
            self._make_logprob_dict(
                prompt_logprobs[pos],
                token_ids[pos],
                decoded_tokens[offset:offset_end],
                prompt_token_ranks[pos],
                num_prompt_logprobs,
            )
        )
    return result
|
||||
|
||||
@staticmethod
def _make_logprob_dict(
    logprobs: list[float],
    logprob_token_ids: list[int],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> dict[int, Logprob]:
    """Build the token-id -> Logprob mapping for a single position.

    Args:
        logprobs: log probabilities for this position.
        logprob_token_ids: token ids aligned with `logprobs`.
        decoded_tokens: decoded text aligned with `logprob_token_ids`.
        rank: rank assigned to the first entry.
        num_logprobs: number of top logprobs requested by the user (in
            addition to the first entry); -1 means "use len(logprobs)".

    Returns:
        dict[token id, Logprob]. The inputs are zipped, so the shortest one
        bounds the number of entries.
    """
    top_k = len(logprobs) if num_logprobs == -1 else num_logprobs

    # The first entry carries the supplied rank; the rest are numbered 1..top_k.
    # A token id that appears twice simply overwrites its earlier entry.
    rank_iter = itertools.chain((rank,), range(1, top_k + 1))

    logprob_by_token: dict[int, Logprob] = {}
    for tok_id, tok_logprob, tok_rank, tok_text in zip(logprob_token_ids, logprobs, rank_iter, decoded_tokens):
        logprob_by_token[tok_id] = Logprob(
            logprob=tok_logprob,
            rank=tok_rank,
            decoded_token=tok_text,
        )
    return logprob_by_token
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Iterable
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -43,9 +45,17 @@ from fastdeploy.utils import (
|
||||
ErrorType,
|
||||
ParameterError,
|
||||
api_server_logger,
|
||||
clamp_prompt_logprobs,
|
||||
get_host_ip,
|
||||
)
|
||||
from fastdeploy.worker.output import LogprobsLists
|
||||
from fastdeploy.worker.output import (
|
||||
Logprob,
|
||||
LogprobsLists,
|
||||
LogprobsTensors,
|
||||
PromptLogprobs,
|
||||
)
|
||||
|
||||
NONES = itertools.repeat(None)
|
||||
|
||||
|
||||
class OpenAIServingCompletion:
|
||||
@@ -249,6 +259,7 @@ class OpenAIServingCompletion:
|
||||
aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)]
|
||||
aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)]
|
||||
aggregated_token_ids = [[] for _ in range(num_choices)]
|
||||
aggregated_prompt_logprobs_tensors = [None] * num_choices
|
||||
completion_batched_token_ids = [[] for _ in range(num_choices)]
|
||||
current_waiting_time = 0
|
||||
while num_choices > 0:
|
||||
@@ -293,6 +304,10 @@ class OpenAIServingCompletion:
|
||||
aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1])
|
||||
aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2])
|
||||
|
||||
output_prompt_logprobs_tensors = data.get("prompt_logprobs") or None
|
||||
if output_prompt_logprobs_tensors is not None:
|
||||
aggregated_prompt_logprobs_tensors[rid] = output_prompt_logprobs_tensors
|
||||
|
||||
aggregated_token_ids[rid].extend(data["outputs"]["token_ids"])
|
||||
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
@@ -305,6 +320,7 @@ class OpenAIServingCompletion:
|
||||
data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid]
|
||||
data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid]
|
||||
data["outputs"]["token_ids"] = aggregated_token_ids[rid]
|
||||
data["prompt_logprobs_tensors"] = aggregated_prompt_logprobs_tensors[rid]
|
||||
valid_results[rid] = data
|
||||
num_choices -= 1
|
||||
break
|
||||
@@ -426,8 +442,18 @@ class OpenAIServingCompletion:
|
||||
idx = int(res["request_id"].split("_")[-1])
|
||||
if res.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(res["error_msg"]))
|
||||
|
||||
prompt_logprobs_res: Optional[PromptLogprobs] = None
|
||||
if first_iteration[idx]:
|
||||
prompt_logprobs_tensors = res.get("prompt_logprobs", None)
|
||||
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
|
||||
num_prompt_logprobs = (
|
||||
request.prompt_logprobs
|
||||
if request.prompt_logprobs != -1
|
||||
else self.engine_client.ori_vocab_size
|
||||
)
|
||||
prompt_logprobs_res = self._build_prompt_logprobs(
|
||||
prompt_logprobs_tensors, num_prompt_logprobs
|
||||
)
|
||||
if request.return_token_ids:
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
@@ -440,6 +466,7 @@ class OpenAIServingCompletion:
|
||||
prompt_token_ids=list(
|
||||
prompt_batched_token_ids[idx // (1 if request.n is None else request.n)]
|
||||
),
|
||||
prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res),
|
||||
prompt_tokens=prompt_tokens_list[
|
||||
idx // (1 if request.n is None else request.n)
|
||||
],
|
||||
@@ -468,13 +495,16 @@ class OpenAIServingCompletion:
|
||||
output_draft_top_logprobs = output["draft_top_logprobs"]
|
||||
logprobs_res: Optional[CompletionLogprobs] = None
|
||||
draft_logprobs_res: Optional[CompletionLogprobs] = None
|
||||
if request.logprobs and output_top_logprobs is not None:
|
||||
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
|
||||
if request.logprobs is not None and output_top_logprobs is not None:
|
||||
num_logprobs = (
|
||||
request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size
|
||||
)
|
||||
logprobs_res = self._create_completion_logprobs(output_top_logprobs, num_logprobs, 0)
|
||||
|
||||
# draft logprobs
|
||||
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
|
||||
draft_logprobs_res = self._create_completion_logprobs(
|
||||
output_draft_top_logprobs, request.logprobs, 0
|
||||
output_draft_top_logprobs, num_logprobs, 0
|
||||
)
|
||||
output_tokens[idx] += len(output.get("token_ids", [])) or 0
|
||||
num_cache_tokens[idx] += output.get("num_cache_tokens") or 0
|
||||
@@ -492,6 +522,7 @@ class OpenAIServingCompletion:
|
||||
reasoning_content="",
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res),
|
||||
draft_logprobs=draft_logprobs_res,
|
||||
)
|
||||
if not res["finished"] and "delta_message" in output:
|
||||
@@ -602,15 +633,22 @@ class OpenAIServingCompletion:
|
||||
output_draft_top_logprobs = output.get("draft_top_logprobs") or None
|
||||
|
||||
aggregated_logprobs: Optional[CompletionLogprobs] = None
|
||||
num_logprobs = request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size
|
||||
if output_top_logprobs is not None:
|
||||
aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
|
||||
aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, num_logprobs, 0)
|
||||
|
||||
aggregated_draft_logprobs: Optional[CompletionLogprobs] = None
|
||||
if output_draft_top_logprobs is not None:
|
||||
aggregated_draft_logprobs = self._create_completion_logprobs(
|
||||
output_draft_top_logprobs, request.logprobs, 0
|
||||
output_draft_top_logprobs, num_logprobs, 0
|
||||
)
|
||||
|
||||
prompt_logprobs_res: Optional[PromptLogprobs] = None
|
||||
prompt_logprobs_tensors = final_res.get("prompt_logprobs_tensors", None)
|
||||
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
|
||||
num_prompt_logprobs = (
|
||||
request.prompt_logprobs if request.prompt_logprobs != -1 else self.engine_client.ori_vocab_size
|
||||
)
|
||||
prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs)
|
||||
if request.echo:
|
||||
prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n))
|
||||
token_ids = [*prompt_token_ids, *output["token_ids"]]
|
||||
@@ -641,6 +679,7 @@ class OpenAIServingCompletion:
|
||||
tool_calls=output.get("tool_call"),
|
||||
logprobs=aggregated_logprobs,
|
||||
draft_logprobs=aggregated_draft_logprobs,
|
||||
prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
@@ -749,13 +788,13 @@ class OpenAIServingCompletion:
|
||||
[tid], clean_up_tokenization_spaces=False
|
||||
)
|
||||
if "\ufffd" in token_str:
|
||||
token_bytes = token_str.encode("utf-8", errors="replace")
|
||||
raw_token = self.engine_client.data_processor.tokenizer.convert_ids_to_tokens(tid)
|
||||
token_bytes = raw_token.encode("utf-8", errors="replace")
|
||||
token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
|
||||
if idx == 0:
|
||||
tokens.append(token_str)
|
||||
token_logprobs.append(lp)
|
||||
else:
|
||||
top_logprobs[token_str] = lp
|
||||
top_logprobs[token_str] = lp
|
||||
idx += 1
|
||||
|
||||
# Construct the sampled token object (avoid sharing references with top_logprob_entries)
|
||||
@@ -770,3 +809,86 @@ class OpenAIServingCompletion:
|
||||
except Exception as e:
|
||||
api_server_logger.error(f"Error in _build_logprobs_response: {str(e)}, {str(traceback.format_exc())}")
|
||||
return None
|
||||
|
||||
def _build_prompt_logprobs(
|
||||
self,
|
||||
prompt_logprobs_tensors: LogprobsTensors,
|
||||
num_prompt_logprobs: int,
|
||||
):
|
||||
"""Update with prompt logprobs from worker.
|
||||
Args:
|
||||
prompt_logprobs_tensors: tuple containing the prompt logprobs
|
||||
tensors.
|
||||
"""
|
||||
|
||||
token_ids, logprobs, ranks = prompt_logprobs_tensors
|
||||
|
||||
# Detokenize non-incrementally.
|
||||
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
|
||||
decoded_tokens = [
|
||||
self.engine_client.data_processor.process_logprob_response(token_id)
|
||||
for token_id in token_ids.flatten().tolist()
|
||||
]
|
||||
|
||||
# Recover shapes.
|
||||
num_prompt_tokens, num_logprobs = logprobs.shape
|
||||
|
||||
# Pythonize the paddle tensors.
|
||||
prompt_token_ranks = ranks.tolist()
|
||||
prompt_logprobs = logprobs.tolist()
|
||||
token_ids = token_ids.tolist()
|
||||
result: Optional[PromptLogprobs] = [None]
|
||||
# Make Logprob for each position.
|
||||
for pos in range(num_prompt_tokens):
|
||||
# Handle flattening.
|
||||
offset = pos * num_logprobs
|
||||
offset_end = offset + num_logprobs
|
||||
decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
|
||||
|
||||
# Update with the Logprob dictionary for this pos.
|
||||
result.append(
|
||||
self._make_logprob_dict(
|
||||
prompt_logprobs[pos],
|
||||
token_ids[pos],
|
||||
decoded_tokens_for_pos,
|
||||
prompt_token_ranks[pos],
|
||||
num_prompt_logprobs,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
def _make_logprob_dict(
    logprobs: list[float],
    logprob_token_ids: list[int],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> dict[int, Logprob]:
    """Build the ``token id -> Logprob`` mapping for a single position.

    Args:
        logprobs: log probabilities for this position.
        logprob_token_ids: token ids, aligned with ``logprobs``.
        decoded_tokens: decoded text per token, aligned with ``logprobs``.
        rank: rank given to the first entry of the aligned inputs.
        num_logprobs: how many top-k entries to keep after the first one;
            ``-1`` means ``len(logprobs)``.

    Returns:
        dict[token id, Logprob]
    """
    top_k = len(logprobs) if num_logprobs == -1 else num_logprobs

    # The first column carries the caller-supplied rank; the following
    # columns are ranked 1..top_k in order. zip() stops at the shortest
    # input, so at most top_k + 1 entries are produced.
    rank_per_column = itertools.chain((rank,), range(1, top_k + 1))

    # A token id that appears twice simply overwrites its earlier entry,
    # so no special case is needed when the first token is also in the top-k.
    entries = {}
    for tok_id, tok_logprob, tok_rank, tok_text in zip(
        logprob_token_ids, logprobs, rank_per_column, decoded_tokens
    ):
        entries[tok_id] = Logprob(
            logprob=tok_logprob,
            rank=tok_rank,
            decoded_token=tok_text,
        )
    return entries
|
||||
|
||||
@@ -18,9 +18,9 @@ import asyncio
|
||||
import heapq
|
||||
import random
|
||||
import time
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
|
||||
import aiozmq
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy.engine.args_utils import EngineArgs
|
||||
@@ -124,7 +124,7 @@ class DealerConnectionManager:
|
||||
while self.running:
|
||||
try:
|
||||
raw_data = await dealer.read()
|
||||
response = msgpack.unpackb(raw_data[-1])
|
||||
response = ForkingPickler.loads(raw_data[-1])
|
||||
_zmq_metrics_stats = ZMQMetricsStats()
|
||||
_zmq_metrics_stats.msg_recv_total += 1
|
||||
if "zmq_send_time" in response:
|
||||
|
||||
@@ -153,7 +153,7 @@ class ZmqServerBase(ABC):
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
result = ForkingPickler.dumps([result.to_dict()])
|
||||
return result
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
@@ -278,12 +278,12 @@ class ZmqServerBase(ABC):
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(new_data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in new_data])
|
||||
result = ForkingPickler.dumps([response.to_dict() for response in new_data])
|
||||
with self.response_token_lock:
|
||||
|
||||
_zmq_metrics_stats = ZMQMetricsStats()
|
||||
try:
|
||||
self.socket.send_multipart([self.req_dict[req_id], b"", result])
|
||||
self.socket.send_multipart([self.req_dict[req_id], b"", result], copy=False)
|
||||
_zmq_metrics_stats.msg_bytes_send_total += len(result)
|
||||
except Exception as e:
|
||||
_zmq_metrics_stats.msg_send_failed_total += 1
|
||||
|
||||
@@ -291,7 +291,7 @@ class TokenProcessor:
|
||||
llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}")
|
||||
if getattr(stream_data, "prompt_logprobs", None) is not None:
|
||||
try:
|
||||
result.prompt_logprobs_tensors = stream_data.prompt_logprobs
|
||||
result.prompt_logprobs = stream_data.prompt_logprobs
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}")
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
|
||||
@@ -52,6 +52,7 @@ from typing_extensions import TypeIs, assert_never
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
|
||||
from fastdeploy.logger.logger import FastDeployLogger
|
||||
from fastdeploy.worker.output import PromptLogprobs
|
||||
|
||||
T = TypeVar("T")
|
||||
from typing import Callable, List, Optional
|
||||
@@ -1073,6 +1074,21 @@ def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T
|
||||
return _optional_type
|
||||
|
||||
|
||||
def clamp_prompt_logprobs(
    prompt_logprobs: "PromptLogprobs | None",
) -> "PromptLogprobs | None":
    """Replace ``-inf`` prompt logprobs with ``-9999.0``, in place.

    Args:
        prompt_logprobs: list of per-position ``{token_id: Logprob}`` dicts
            (positions may be ``None``), or ``None``.

    Returns:
        The same list object that was passed in (or ``None``), with every
        entry whose ``logprob`` equals ``-inf`` clamped to ``-9999.0``.
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    # NOTE(review): presumably clamped so the value survives JSON
    # serialization of the response - confirm against the API layer.
    clamped_value = -9999.0
    neg_inf = float("-inf")

    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        # Only values of existing keys are rebound below, so the dict never
        # changes size while it is being iterated.
        for token_id, entry in logprob_dict.items():
            if entry.logprob != neg_inf:
                continue
            try:
                entry.logprob = clamped_value
            except AttributeError:
                # NamedTuple-style entries are immutable: attribute
                # assignment raises, so store an updated copy instead.
                logprob_dict[token_id] = entry._replace(logprob=clamped_value)
    return prompt_logprobs
|
||||
|
||||
|
||||
def to_numpy(tasks: List[Any]):
|
||||
"""
|
||||
Convert PaddlePaddle tensors in multimodal inputs to NumPy arrays.
|
||||
|
||||
@@ -30,7 +30,6 @@ class Logprob(NamedTuple):
|
||||
decoded_token: Optional[str] = None
|
||||
|
||||
|
||||
PromptLogprobs = list[dict[int, Logprob] | None]
|
||||
# [{token_id, logprob}] for tokens sampled from the top-k
|
||||
SampleLogprobs = list[dict[int, Logprob]]
|
||||
|
||||
@@ -125,6 +124,9 @@ class LogprobsTensors(NamedTuple):
|
||||
)
|
||||
|
||||
|
||||
PromptLogprobs = LogprobsTensors | list[dict[int, Logprob] | None]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplerOutput:
|
||||
""" """
|
||||
|
||||
+1
-1
@@ -40,7 +40,7 @@ opentelemetry-instrumentation-mysql
|
||||
opentelemetry-distro
|
||||
opentelemetry-exporter-otlp
|
||||
opentelemetry-instrumentation-fastapi
|
||||
opentelemetry-instrumentation-logging
|
||||
opentelemetry-instrumentation-logging>=0.57b0
|
||||
partial_json_parser
|
||||
msgspec
|
||||
einops
|
||||
|
||||
@@ -37,5 +37,5 @@ opentelemetry-instrumentation-mysql
|
||||
opentelemetry-distro
|
||||
opentelemetry-exporter-otlp
|
||||
opentelemetry-instrumentation-fastapi
|
||||
opentelemetry-instrumentation-logging
|
||||
opentelemetry-instrumentation-logging>=0.57b0
|
||||
partial_json_parser
|
||||
|
||||
@@ -37,7 +37,7 @@ opentelemetry-instrumentation-mysql
|
||||
opentelemetry-distro
|
||||
opentelemetry-exporter-otlp
|
||||
opentelemetry-instrumentation-fastapi
|
||||
opentelemetry-instrumentation-logging
|
||||
opentelemetry-instrumentation-logging>=0.57b0
|
||||
partial_json_parser
|
||||
msgspec
|
||||
safetensors==0.7.0rc0
|
||||
|
||||
@@ -40,7 +40,7 @@ opentelemetry-instrumentation-mysql
|
||||
opentelemetry-distro
|
||||
opentelemetry-exporter-otlp
|
||||
opentelemetry-instrumentation-fastapi
|
||||
opentelemetry-instrumentation-logging
|
||||
opentelemetry-instrumentation-logging>=0.57b0
|
||||
partial_json_parser
|
||||
msgspec
|
||||
einops
|
||||
|
||||
@@ -26,17 +26,28 @@ class TestSamplingParamsVerification(unittest.TestCase):
|
||||
|
||||
def test_logprobs_valid_values(self):
|
||||
"""Test valid logprobs values"""
|
||||
# Test None value (should pass)
|
||||
params = SamplingParams(logprobs=None)
|
||||
params._verify_args() # Should not raise
|
||||
# Test None value (should pass in both modes)
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
params = SamplingParams(logprobs=None)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test -1 value (should pass)
|
||||
params = SamplingParams(logprobs=-1)
|
||||
params._verify_args() # Should not raise
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(logprobs=None)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test 0 value (should pass)
|
||||
params = SamplingParams(logprobs=0)
|
||||
params._verify_args() # Should not raise
|
||||
# Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(logprobs=-1)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test 0 value (should pass in both modes based on actual behavior)
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
params = SamplingParams(logprobs=0)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(logprobs=0)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test 20 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "0")
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
@@ -44,13 +55,23 @@ class TestSamplingParamsVerification(unittest.TestCase):
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
def test_logprobs_invalid_less_than_minus_one(self):
|
||||
"""Test logprobs less than -1 should raise ValueError"""
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(logprobs=-2)
|
||||
params._verify_args()
|
||||
"""Test logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(logprobs=-2)
|
||||
params._verify_args()
|
||||
|
||||
self.assertIn("logprobs must be greater than -1", str(cm.exception))
|
||||
self.assertIn("got -2", str(cm.exception))
|
||||
self.assertIn("logprobs must be a non-negative value or -1", str(cm.exception))
|
||||
self.assertIn("got -2", str(cm.exception))
|
||||
|
||||
def test_logprobs_invalid_less_than_zero(self):
|
||||
"""Test logprobs less than 0 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0" """
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(logprobs=-1)
|
||||
params._verify_args()
|
||||
|
||||
self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", str(cm.exception))
|
||||
|
||||
def test_logprobs_greater_than_20_with_v1_disabled(self):
|
||||
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled"""
|
||||
@@ -59,7 +80,7 @@ class TestSamplingParamsVerification(unittest.TestCase):
|
||||
params = SamplingParams(logprobs=21)
|
||||
params._verify_args()
|
||||
|
||||
self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(cm.exception))
|
||||
self.assertEqual("Invalid value for 'top_logprobs': must be between 0 and 20.", str(cm.exception))
|
||||
|
||||
def test_logprobs_greater_than_20_with_v1_enabled(self):
|
||||
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
|
||||
@@ -74,46 +95,67 @@ class TestSamplingParamsVerification(unittest.TestCase):
|
||||
|
||||
def test_prompt_logprobs_valid_values(self):
|
||||
"""Test valid prompt_logprobs values"""
|
||||
# Test None value (should pass)
|
||||
params = SamplingParams(prompt_logprobs=None)
|
||||
params._verify_args() # Should not raise
|
||||
# Test None value (should pass in both modes based on actual behavior)
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
params = SamplingParams(prompt_logprobs=None)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test -1 value (should pass)
|
||||
params = SamplingParams(prompt_logprobs=-1)
|
||||
params._verify_args() # Should not raise
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(prompt_logprobs=None)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test 0 value (should pass)
|
||||
params = SamplingParams(prompt_logprobs=0)
|
||||
params._verify_args() # Should not raise
|
||||
# Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(prompt_logprobs=-1)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test positive values (should pass)
|
||||
params = SamplingParams(prompt_logprobs=10)
|
||||
params._verify_args() # Should not raise
|
||||
# Test 0 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(prompt_logprobs=0)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test positive values (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(prompt_logprobs=10)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
def test_prompt_logprobs_invalid_less_than_minus_one(self):
|
||||
"""Test prompt_logprobs less than -1 should raise ValueError"""
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(prompt_logprobs=-2)
|
||||
params._verify_args()
|
||||
"""Test prompt_logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(prompt_logprobs=-2)
|
||||
params._verify_args()
|
||||
|
||||
self.assertIn("prompt_logprobs must be greater than or equal to -1", str(cm.exception))
|
||||
self.assertIn("got -2", str(cm.exception))
|
||||
self.assertIn("prompt_logprobs a must be non-negative value or -1", str(cm.exception))
|
||||
self.assertIn("got -2", str(cm.exception))
|
||||
|
||||
def test_combined_logprobs_and_prompt_logprobs(self):
|
||||
"""Test both logprobs and prompt_logprobs together"""
|
||||
# Test valid combination
|
||||
params = SamplingParams(logprobs=5, prompt_logprobs=3)
|
||||
params._verify_args() # Should not raise
|
||||
# Test valid combination when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(logprobs=5, prompt_logprobs=3)
|
||||
params._verify_args() # Should not raise
|
||||
|
||||
# Test invalid logprobs with valid prompt_logprobs
|
||||
with self.assertRaises(ValueError):
|
||||
params = SamplingParams(logprobs=-2, prompt_logprobs=5)
|
||||
params._verify_args()
|
||||
# Test invalid logprobs with valid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
with self.assertRaises(ValueError):
|
||||
params = SamplingParams(logprobs=-2, prompt_logprobs=5)
|
||||
params._verify_args()
|
||||
|
||||
# Test valid logprobs with invalid prompt_logprobs
|
||||
with self.assertRaises(ValueError):
|
||||
params = SamplingParams(logprobs=5, prompt_logprobs=-2)
|
||||
params._verify_args()
|
||||
# Test valid logprobs with invalid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
with self.assertRaises(ValueError):
|
||||
params = SamplingParams(logprobs=5, prompt_logprobs=-2)
|
||||
params._verify_args()
|
||||
|
||||
# Test prompt_logprobs not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(logprobs=5, prompt_logprobs=3)
|
||||
params._verify_args()
|
||||
self.assertIn(
|
||||
"prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
|
||||
)
|
||||
|
||||
def test_logprobs_boundary_values(self):
|
||||
"""Test boundary values for logprobs"""
|
||||
@@ -130,14 +172,16 @@ class TestSamplingParamsVerification(unittest.TestCase):
|
||||
|
||||
def test_prompt_logprobs_boundary_values(self):
|
||||
"""Test boundary values for prompt_logprobs"""
|
||||
# Test boundary value -1 (should pass)
|
||||
params = SamplingParams(prompt_logprobs=-1)
|
||||
params._verify_args() # Should pass
|
||||
# Test boundary value -1 (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(prompt_logprobs=-1)
|
||||
params._verify_args() # Should pass
|
||||
|
||||
# Test boundary value just below -1 (should fail)
|
||||
with self.assertRaises(ValueError):
|
||||
params = SamplingParams(prompt_logprobs=-2)
|
||||
params._verify_args()
|
||||
# Test boundary value just below -1 (should fail when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
with self.assertRaises(ValueError):
|
||||
params = SamplingParams(prompt_logprobs=-2)
|
||||
params._verify_args()
|
||||
|
||||
def test_environment_variable_handling(self):
|
||||
"""Test different environment variable values"""
|
||||
@@ -167,55 +211,111 @@ class TestSamplingParamsVerification(unittest.TestCase):
|
||||
if original_value is not None:
|
||||
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = original_value
|
||||
|
||||
# Test prompt_logprobs behavior with different environment variables
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(prompt_logprobs=5)
|
||||
params._verify_args()
|
||||
self.assertIn(
|
||||
"prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(prompt_logprobs=5)
|
||||
params._verify_args() # Should pass
|
||||
|
||||
def test_error_message_formatting(self):
|
||||
"""Test that error messages are properly formatted"""
|
||||
# Test logprobs error message
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(logprobs=-5)
|
||||
params._verify_args()
|
||||
# Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(logprobs=-5)
|
||||
params._verify_args()
|
||||
|
||||
error_msg = str(cm.exception)
|
||||
self.assertIn("logprobs must be greater than -1", error_msg)
|
||||
self.assertIn("got -5", error_msg)
|
||||
error_msg = str(cm.exception)
|
||||
self.assertIn("logprobs must be a non-negative value or -1", error_msg)
|
||||
self.assertIn("got -5", error_msg)
|
||||
|
||||
# Test prompt_logprobs error message
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(prompt_logprobs=-10)
|
||||
params._verify_args()
|
||||
# Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(logprobs=-1)
|
||||
params._verify_args()
|
||||
|
||||
error_msg = str(cm.exception)
|
||||
self.assertIn("prompt_logprobs must be greater than or equal to -1", error_msg)
|
||||
self.assertIn("got -10", error_msg)
|
||||
error_msg = str(cm.exception)
|
||||
self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", error_msg)
|
||||
|
||||
# Test prompt_logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(prompt_logprobs=-10)
|
||||
params._verify_args()
|
||||
|
||||
error_msg = str(cm.exception)
|
||||
self.assertIn("prompt_logprobs a must be non-negative value or -1", error_msg)
|
||||
self.assertIn("got -10", error_msg)
|
||||
|
||||
# Test prompt_logprobs not supported error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
params = SamplingParams(prompt_logprobs=5)
|
||||
params._verify_args()
|
||||
|
||||
error_msg = str(cm.exception)
|
||||
self.assertIn("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", error_msg)
|
||||
|
||||
def test_post_init_calls_verify_args(self):
|
||||
"""Test that __post_init__ calls _verify_args"""
|
||||
# This should call _verify_args internally
|
||||
params = SamplingParams(logprobs=5, prompt_logprobs=3)
|
||||
# This should call _verify_args internally when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(logprobs=5, prompt_logprobs=3)
|
||||
|
||||
# The params should be successfully created without errors
|
||||
self.assertEqual(params.logprobs, 5)
|
||||
self.assertEqual(params.prompt_logprobs, 3)
|
||||
# The params should be successfully created without errors
|
||||
self.assertEqual(params.logprobs, 5)
|
||||
self.assertEqual(params.prompt_logprobs, 3)
|
||||
|
||||
# Test that invalid values are caught during initialization
|
||||
with self.assertRaises(ValueError):
|
||||
SamplingParams(logprobs=-2)
|
||||
# Test that invalid values are caught during initialization
|
||||
with self.assertRaises(ValueError):
|
||||
SamplingParams(logprobs=-2)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
SamplingParams(prompt_logprobs=-2)
|
||||
with self.assertRaises(ValueError):
|
||||
SamplingParams(prompt_logprobs=-2)
|
||||
|
||||
# Test that prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
with self.assertRaises(ValueError):
|
||||
SamplingParams(prompt_logprobs=3)
|
||||
|
||||
# Test that logprobs < 0 is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
|
||||
with self.assertRaises(ValueError):
|
||||
SamplingParams(logprobs=-1)
|
||||
|
||||
def test_logprobs_with_other_parameters(self):
|
||||
"""Test logprobs validation with other sampling parameters"""
|
||||
# Test with temperature
|
||||
params = SamplingParams(logprobs=5, temperature=0.8)
|
||||
params._verify_args() # Should pass
|
||||
# Test with temperature when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(logprobs=5, temperature=0.8)
|
||||
params._verify_args() # Should pass
|
||||
|
||||
# Test with top_p
|
||||
params = SamplingParams(logprobs=5, top_p=0.9)
|
||||
params._verify_args() # Should pass
|
||||
# Test with top_p when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(logprobs=5, top_p=0.9)
|
||||
params._verify_args() # Should pass
|
||||
|
||||
# Test with all parameters
|
||||
params = SamplingParams(logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100)
|
||||
params._verify_args() # Should pass
|
||||
# Test with all parameters when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
params = SamplingParams(
|
||||
logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
|
||||
)
|
||||
params._verify_args() # Should pass
|
||||
|
||||
# Test that prompt_logprobs fails when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
|
||||
with self.assertRaises(ValueError):
|
||||
params = SamplingParams(
|
||||
logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
|
||||
)
|
||||
params._verify_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -453,6 +453,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
|
||||
num_input_image_tokens = [0, 0]
|
||||
num_input_video_tokens = [0, 0]
|
||||
num_image_tokens = [0, 0]
|
||||
prompt_logprobs_res_list = [[], []]
|
||||
max_tokens_list = [10, 1]
|
||||
|
||||
for idx, case in enumerate(test_cases):
|
||||
@@ -469,6 +470,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
|
||||
num_image_tokens=num_image_tokens,
|
||||
logprob_contents=logprob_contents,
|
||||
draft_logprob_contents=draft_logprob_contents,
|
||||
prompt_logprobs_res_list=prompt_logprobs_res_list,
|
||||
response_processor=mock_response_processor,
|
||||
max_tokens=max_tokens_list[idx],
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,5 @@
|
||||
from unittest.mock import MagicMock
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -14,56 +15,147 @@ class DummyModelConfig:
|
||||
self.ori_vocab_size = ori_vocab_size
|
||||
|
||||
|
||||
class DummyCacheConfig:
    """Minimal cache-config stand-in exposing only ``enable_prefix_caching``."""

    def __init__(self, enable_prefix_caching=False):
        # Stored verbatim; prefix caching is off unless a test turns it on.
        self.enable_prefix_caching = enable_prefix_caching
|
||||
|
||||
|
||||
class DummyLLMEngineConfig:
    """Engine-config stand-in bundling a model config and a cache config."""

    def __init__(self, model_config=None, cache_config=None):
        # Any falsy argument (including the default None) is replaced by a
        # freshly built dummy config.
        self.model_config = model_config if model_config else DummyModelConfig()
        self.cache_config = cache_config if cache_config else DummyCacheConfig()
|
||||
|
||||
|
||||
class DummyLLMEngine:
    """Engine stand-in with a dummy config and a mocked data processor."""

    def __init__(self, model_config=None, cache_config=None):
        self.cfg = DummyLLMEngineConfig(model_config, cache_config)

        processor = self.data_processor = MagicMock()

        # Mock tokenizer whose sp_model and vocab both report a length of 100.
        tokenizer = processor.tokenizer = MagicMock()
        tokenizer.sp_model = MagicMock()
        tokenizer.sp_model.__len__ = MagicMock(return_value=100)
        tokenizer.vocab = MagicMock()
        tokenizer.vocab.__len__ = MagicMock(return_value=100)

        # Decoding returns a deterministic string built from the first id.
        processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"

        self.add_requests = MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm():
|
||||
llm = LLM.__new__(LLM)
|
||||
llm.llm_engine = MagicMock()
|
||||
llm.llm_engine.add_requests = MagicMock()
|
||||
llm.llm_engine.cfg.model_config = DummyModelConfig(max_logprobs=10, ori_vocab_size=100)
|
||||
# Mock the data_processor.process_logprob_response method to return proper strings
|
||||
llm.llm_engine.data_processor = MagicMock()
|
||||
llm.llm_engine.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
|
||||
llm.llm_engine = DummyLLMEngine()
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_with_prefix_caching():
|
||||
llm = LLM.__new__(LLM)
|
||||
llm.llm_engine = DummyLLMEngine(cache_config=DummyCacheConfig(enable_prefix_caching=True))
|
||||
return llm
|
||||
|
||||
|
||||
def test_prompt_logprobs_not_supported_with_stream(mock_llm):
|
||||
sampling = SamplingParams(prompt_logprobs=5)
|
||||
with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
|
||||
mock_llm._add_request(["hi"], sampling, stream=True)
|
||||
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(prompt_logprobs=5)
|
||||
with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
|
||||
mock_llm._add_request(["hi"], sampling, stream=True)
|
||||
|
||||
|
||||
def test_prompt_logprobs_not_supported_with_prefix_caching(mock_llm_with_prefix_caching):
|
||||
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(prompt_logprobs=5)
|
||||
with pytest.raises(ValueError, match="prompt_logprobs is not supported with prefix caching enabled"):
|
||||
mock_llm_with_prefix_caching._add_request(["hi"], sampling)
|
||||
|
||||
|
||||
def test_num_logprobs_exceeds_max(mock_llm):
|
||||
sampling = SamplingParams(logprobs=20)
|
||||
with pytest.raises(ValueError, match="Number of logprobs requested"):
|
||||
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs > 20
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(logprobs=20)
|
||||
with pytest.raises(ValueError, match="Number of logprobs requested"):
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
|
||||
|
||||
def test_max_logprobs_exceeds_vocab_size(mock_llm):
|
||||
# Test case where max_logprobs > ori_vocab_size
|
||||
mock_llm.llm_engine.cfg.model_config.max_logprobs = 150 # > vocab size (100)
|
||||
with pytest.raises(ValueError, match="max_logprobs \\(150\\) exceeds vocabulary size \\(100\\)"):
|
||||
mock_llm._add_request(["hi"], SamplingParams())
|
||||
|
||||
|
||||
def test_max_logprobs_less_than_minus_one(mock_llm):
|
||||
# Test case where max_logprobs < -1
|
||||
mock_llm.llm_engine.cfg.model_config.max_logprobs = -2
|
||||
with pytest.raises(ValueError, match="max_logprobs \\(-2\\) can't be less than -1"):
|
||||
mock_llm._add_request(["hi"], SamplingParams())
|
||||
|
||||
|
||||
def test_logprobs_minus_one_uses_vocab_size(mock_llm):
|
||||
# Test that logprobs=-1 uses vocab size
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(logprobs=-1)
|
||||
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1 # Allow unlimited
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
mock_llm.llm_engine.add_requests.assert_called_once()
|
||||
|
||||
|
||||
def test_num_prompt_logprobs_exceeds_max(mock_llm):
|
||||
sampling = SamplingParams(prompt_logprobs=20)
|
||||
with pytest.raises(ValueError, match="Number of logprobs requested"):
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(prompt_logprobs=20)
|
||||
with pytest.raises(ValueError, match="Number of logprobs requested"):
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
|
||||
|
||||
def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
|
||||
sampling = SamplingParams(logprobs=-1)
|
||||
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
|
||||
mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 30
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
mock_llm.llm_engine.add_requests.assert_called_once()
|
||||
# Get the first argument (tasks) which should be a dict
|
||||
call_args = mock_llm.llm_engine.add_requests.call_args
|
||||
tasks = call_args[0][0] # First positional argument
|
||||
assert isinstance(tasks, dict)
|
||||
assert "prompt" in tasks
|
||||
assert "request_id" in tasks
|
||||
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs=-1
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(logprobs=-1)
|
||||
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
mock_llm.llm_engine.add_requests.assert_called_once()
|
||||
# Get the first argument (tasks) which should be a dict
|
||||
call_args = mock_llm.llm_engine.add_requests.call_args
|
||||
tasks = call_args[0][0] # First positional argument
|
||||
assert isinstance(tasks, dict)
|
||||
assert "prompt" in tasks
|
||||
assert "request_id" in tasks
|
||||
|
||||
|
||||
def test_prompt_logprobs_equal_to_minus_one(mock_llm):
|
||||
sampling = SamplingParams(prompt_logprobs=-1)
|
||||
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support and allow -1
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(prompt_logprobs=-1)
|
||||
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
mock_llm.llm_engine.add_requests.assert_called_once()
|
||||
|
||||
|
||||
def test_dynamic_vocab_size_from_sp_model(mock_llm):
|
||||
# Test that ori_vocab_size is dynamically obtained from sp_model
|
||||
mock_llm.llm_engine.data_processor.tokenizer.sp_model.__len__.return_value = 200
|
||||
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
|
||||
mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 25
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
mock_llm.llm_engine.add_requests.assert_called_once()
|
||||
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(logprobs=-1)
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
# Should use the dynamic vocab size (200)
|
||||
mock_llm.llm_engine.add_requests.assert_called_once()
|
||||
|
||||
|
||||
def test_dynamic_vocab_size_from_vocab_fallback(mock_llm):
|
||||
# Test fallback to vocab when sp_model is not available
|
||||
del mock_llm.llm_engine.data_processor.tokenizer.sp_model
|
||||
mock_llm.llm_engine.data_processor.tokenizer.vocab.__len__.return_value = 300
|
||||
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
|
||||
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
sampling = SamplingParams(logprobs=-1)
|
||||
mock_llm._add_request(["hi"], sampling)
|
||||
# Should use the vocab size (300)
|
||||
mock_llm.llm_engine.add_requests.assert_called_once()
|
||||
|
||||
|
||||
def test_build_prompt_logprobs_basic(mock_llm):
|
||||
@@ -77,12 +169,13 @@ def test_build_prompt_logprobs_basic(mock_llm):
|
||||
|
||||
# 检查结果格式
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert len(result) == 3
|
||||
for pos_dict in result:
|
||||
assert isinstance(pos_dict, dict)
|
||||
for logprob_obj in pos_dict.values():
|
||||
assert isinstance(logprob_obj, Logprob)
|
||||
assert logprob_obj.decoded_token.startswith("TOKEN_")
|
||||
if pos_dict is not None:
|
||||
assert isinstance(pos_dict, dict)
|
||||
for logprob_obj in pos_dict.values():
|
||||
assert isinstance(logprob_obj, Logprob)
|
||||
assert logprob_obj.decoded_token.startswith("TOKEN_")
|
||||
|
||||
|
||||
def test_build_prompt_logprobs_handles_minus_one(mock_llm):
|
||||
@@ -94,7 +187,7 @@ def test_build_prompt_logprobs_handles_minus_one(mock_llm):
|
||||
result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
pos_dict = result[0]
|
||||
assert len(result) == 2
|
||||
pos_dict = result[1]
|
||||
assert 7 in pos_dict
|
||||
assert pos_dict[7].decoded_token == "TOKEN_7"
|
||||
|
||||
@@ -137,7 +137,7 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
|
||||
result = self.processor._process_batch_output_use_zmq([stream_data])
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIsNone(getattr(result[0], "prompt_logprobs_tensors", None))
|
||||
self.assertIsNone(getattr(result[0], "prompt_logprobs", None))
|
||||
|
||||
def test_process_batch_with_stop_flag(self):
|
||||
"""Test processing when stop flag is True"""
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from fastdeploy.utils import clamp_prompt_logprobs
|
||||
from fastdeploy.worker.output import Logprob
|
||||
|
||||
|
||||
class TestClampPromptLogprobs(unittest.TestCase):
|
||||
def test_none_input(self):
|
||||
"""Test case when input is None"""
|
||||
result = clamp_prompt_logprobs(None)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_empty_list(self):
|
||||
"""Test empty list input"""
|
||||
result = clamp_prompt_logprobs([])
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_normal_logprobs(self):
|
||||
"""Test normal logprobs values (without -inf)"""
|
||||
logprob_dict = {
|
||||
1: Logprob(logprob=-2.5, rank=1, decoded_token="hello"),
|
||||
2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
|
||||
}
|
||||
prompt_logprobs = [logprob_dict]
|
||||
|
||||
result = clamp_prompt_logprobs(prompt_logprobs)
|
||||
|
||||
# Original values should remain unchanged
|
||||
self.assertEqual(result[0][1].logprob, -2.5)
|
||||
self.assertEqual(result[0][2].logprob, -1.0)
|
||||
|
||||
def test_negative_inf_logprobs_raises_error(self):
|
||||
"""Test that logprobs containing -inf raises AttributeError"""
|
||||
logprob_dict = {
|
||||
1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
|
||||
2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
|
||||
}
|
||||
prompt_logprobs = [logprob_dict]
|
||||
|
||||
# Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
|
||||
with self.assertRaises(AttributeError) as context:
|
||||
clamp_prompt_logprobs(prompt_logprobs)
|
||||
|
||||
self.assertIn("can't set attribute", str(context.exception))
|
||||
|
||||
def test_multiple_negative_inf_raises_error(self):
|
||||
"""Test that multiple -inf logprobs values raise AttributeError"""
|
||||
logprob_dict = {
|
||||
1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
|
||||
2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"),
|
||||
3: Logprob(logprob=-0.5, rank=3, decoded_token="test"),
|
||||
}
|
||||
prompt_logprobs = [logprob_dict]
|
||||
|
||||
# Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
|
||||
with self.assertRaises(AttributeError):
|
||||
clamp_prompt_logprobs(prompt_logprobs)
|
||||
|
||||
def test_none_dict_in_list(self):
|
||||
"""Test case when list contains None"""
|
||||
prompt_logprobs = [None]
|
||||
|
||||
result = clamp_prompt_logprobs(prompt_logprobs)
|
||||
|
||||
# None should be skipped
|
||||
self.assertIsNone(result[0])
|
||||
|
||||
def test_multiple_dicts_normal_values(self):
|
||||
"""Test multiple dictionaries case (without -inf)"""
|
||||
logprob_dict1 = {
|
||||
1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
|
||||
}
|
||||
logprob_dict2 = {
|
||||
2: Logprob(logprob=-2.0, rank=1, decoded_token="world"),
|
||||
}
|
||||
prompt_logprobs = [logprob_dict1, logprob_dict2]
|
||||
|
||||
result = clamp_prompt_logprobs(prompt_logprobs)
|
||||
|
||||
# Should return normally, values remain unchanged
|
||||
self.assertEqual(result[0][1].logprob, -2.0)
|
||||
self.assertEqual(result[1][2].logprob, -2.0)
|
||||
|
||||
def test_mixed_values_without_inf(self):
|
||||
"""Test mixed values case (without -inf)"""
|
||||
logprob_dict = {
|
||||
1: Logprob(logprob=-9999.0, rank=1, decoded_token="hello"),
|
||||
2: Logprob(logprob=-9999.0, rank=2, decoded_token="world"),
|
||||
3: Logprob(logprob=0.0, rank=3, decoded_token="test"),
|
||||
4: Logprob(logprob=-1.5, rank=4, decoded_token="again"),
|
||||
}
|
||||
prompt_logprobs = [logprob_dict]
|
||||
|
||||
result = clamp_prompt_logprobs(prompt_logprobs)
|
||||
|
||||
# All values should remain unchanged
|
||||
self.assertEqual(result[0][1].logprob, -9999.0)
|
||||
self.assertEqual(result[0][2].logprob, -9999.0)
|
||||
self.assertEqual(result[0][3].logprob, 0.0)
|
||||
self.assertEqual(result[0][4].logprob, -1.5)
|
||||
|
||||
def test_return_same_object(self):
|
||||
"""Test that function returns the same object (in-place modification attempt)"""
|
||||
logprob_dict = {
|
||||
1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
|
||||
}
|
||||
prompt_logprobs = [logprob_dict]
|
||||
|
||||
result = clamp_prompt_logprobs(prompt_logprobs)
|
||||
|
||||
# Should return the same object (function attempts in-place modification)
|
||||
self.assertIs(result, prompt_logprobs)
|
||||
self.assertIs(result[0], prompt_logprobs[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -16,6 +16,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -163,44 +164,46 @@ class TestGPUPromptLogprobs(unittest.TestCase):
|
||||
return model_runner
|
||||
|
||||
def test_prompt_logprobs(self):
|
||||
model_runner = self.setup_model_runner()
|
||||
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
|
||||
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
|
||||
model_runner = self.setup_model_runner()
|
||||
|
||||
req: Request = Request(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
arrival_time=None,
|
||||
request_id="asd1",
|
||||
prompt_token_ids=[1, 2, 3, 4],
|
||||
prompt_token_ids_len=4,
|
||||
prefill_start_index=0,
|
||||
prefill_end_index=4,
|
||||
sampling_params=SamplingParams(prompt_logprobs=-1),
|
||||
)
|
||||
req.idx = 0
|
||||
model_runner.prompt_logprobs_reqs = {req.request_id: req}
|
||||
req: Request = Request(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
arrival_time=None,
|
||||
request_id="asd1",
|
||||
prompt_token_ids=[1, 2, 3, 4],
|
||||
prompt_token_ids_len=4,
|
||||
prefill_start_index=0,
|
||||
prefill_end_index=4,
|
||||
sampling_params=SamplingParams(prompt_logprobs=-1),
|
||||
)
|
||||
req.idx = 0
|
||||
model_runner.prompt_logprobs_reqs = {req.request_id: req}
|
||||
|
||||
hidden_states = paddle.rand(
|
||||
[len(req.prompt_token_ids) - 1, model_runner.fd_config.model_config.hidden_size], dtype="bfloat16"
|
||||
)
|
||||
ref_logits = model_runner.model.compute_logits(hidden_states)
|
||||
ref_raw_logprobs = model_runner.sampler.compute_logprobs(ref_logits)
|
||||
token_is = paddle.to_tensor(req.prompt_token_ids[1:], dtype="int64")
|
||||
hidden_states = paddle.rand(
|
||||
[len(req.prompt_token_ids) - 1, model_runner.fd_config.model_config.hidden_size], dtype="bfloat16"
|
||||
)
|
||||
ref_logits = model_runner.model.compute_logits(hidden_states)
|
||||
ref_raw_logprobs = model_runner.sampler.compute_logprobs(ref_logits)
|
||||
token_is = paddle.to_tensor(req.prompt_token_ids[1:], dtype="int64")
|
||||
|
||||
ref_token_ids, ref_logprobs, ref_ranks = model_runner.sampler.gather_logprobs(
|
||||
ref_raw_logprobs, model_runner.fd_config.model_config.ori_vocab_size, token_is
|
||||
)
|
||||
prompt_logprobs = model_runner._get_prompt_logprobs_list(hidden_states)[0]
|
||||
np.testing.assert_allclose(ref_logprobs.numpy(), prompt_logprobs.logprobs.numpy(), rtol=1e-04, atol=1e-04)
|
||||
np.testing.assert_allclose(
|
||||
ref_token_ids.numpy(), prompt_logprobs.logprob_token_ids.numpy(), rtol=1e-04, atol=1e-04
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
ref_ranks.numpy(), prompt_logprobs.selected_token_ranks.numpy(), rtol=1e-04, atol=1e-04
|
||||
)
|
||||
ref_token_ids, ref_logprobs, ref_ranks = model_runner.sampler.gather_logprobs(
|
||||
ref_raw_logprobs, model_runner.fd_config.model_config.ori_vocab_size, token_is
|
||||
)
|
||||
prompt_logprobs = model_runner._get_prompt_logprobs_list(hidden_states)[0]
|
||||
np.testing.assert_allclose(ref_logprobs.numpy(), prompt_logprobs.logprobs.numpy(), rtol=1e-04, atol=1e-04)
|
||||
np.testing.assert_allclose(
|
||||
ref_token_ids.numpy(), prompt_logprobs.logprob_token_ids.numpy(), rtol=1e-04, atol=1e-04
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
ref_ranks.numpy(), prompt_logprobs.selected_token_ranks.numpy(), rtol=1e-04, atol=1e-04
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user