[LogProbs]Enable prompt logprobs output and modify data transmission method for the online interface. (#5089)

* add prompt logprobs

* Merge prompt_logprobs_tensors and prompt_logprobs

* fix param check

* trigger ci

* fix unit test

* fix logprobs bug
This commit is contained in:
qwes5s5
2025-12-02 13:49:51 +08:00
committed by GitHub
parent af39819fcd
commit 117980dd4e
27 changed files with 4947 additions and 233 deletions
+8 -1
View File
@@ -229,8 +229,15 @@ class ModelConfig:
self.think_end_id = args.get("think_end_id", -1)
self.im_patch_id = args.get("image_patch_id", -1)
self.line_break_id = args.get("line_break_id", -1)
if self.max_logprobs < -1:
num_max_logprobs = args.get("max_logprobs", None)
if num_max_logprobs is not None and num_max_logprobs < -1:
raise ValueError(" The possible values for max_logprobs can't be less than -1 ")
if self.ori_vocab_size is not None and num_max_logprobs is not None:
if num_max_logprobs > self.ori_vocab_size:
raise ValueError(
f" The possible values for max_logprobs can't be greater than the vocabulary size {self.ori_vocab_size}"
)
self._post_init()
+1 -8
View File
@@ -31,12 +31,7 @@ from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import ToolCall
from fastdeploy.utils import data_processor_logger
from fastdeploy.worker.output import (
LogprobsLists,
LogprobsTensors,
PromptLogprobs,
SampleLogprobs,
)
from fastdeploy.worker.output import LogprobsLists, PromptLogprobs, SampleLogprobs
class RequestStatus(Enum):
@@ -519,7 +514,6 @@ class RequestOutput:
prompt: Optional[str] = None,
prompt_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[PromptLogprobs] = None,
prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
output_type: Optional[int] = 3,
outputs: CompletionOutput = None,
finished: bool = False,
@@ -537,7 +531,6 @@ class RequestOutput:
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_logprobs = prompt_logprobs
self.prompt_logprobs_tensors = prompt_logprobs_tensors
self.output_type = output_type
self.outputs = outputs
self.finished = finished
+13 -7
View File
@@ -16,12 +16,13 @@
from __future__ import annotations
import os
import random
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, List, Optional, Union
from fastdeploy import envs
@dataclass
class SamplingParams:
@@ -207,12 +208,17 @@ class SamplingParams:
raise ValueError(
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
)
if self.logprobs is not None and self.logprobs < -1:
raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")
if not envs.FD_USE_GET_SAVE_OUTPUT_V1: # False (0)
if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 20):
raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
if self.prompt_logprobs is not None:
raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.")
else: # True (1)
if self.logprobs is not None and self.logprobs < -1:
raise ValueError(f"logprobs must be a non-negative value or -1, got {self.logprobs}.")
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
raise ValueError(f"prompt_logprobs a must be non-negative value or -1, got {self.prompt_logprobs}.")
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
+73 -9
View File
@@ -56,10 +56,11 @@ class EngineClient:
EngineClient is a class that handles the communication between the client and the server.
"""
def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1):
def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20):
self.fd_config = fd_config
self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size
self.enable_mm = self.fd_config.model_config.enable_mm
self.max_logprobs = max_logprobs
input_processor = InputPreprocessor(
self.fd_config.model_config,
self.fd_config.structured_outputs_config.reasoning_parser,
@@ -70,6 +71,11 @@ class EngineClient:
)
self.enable_logprob = self.fd_config.model_config.enable_logprob
self.data_processor = input_processor.create_processor()
self.ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, "sp_model")
else len(self.data_processor.tokenizer.vocab)
)
self.max_model_len = self.fd_config.model_config.max_model_len
self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching
self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed"
@@ -424,6 +430,53 @@ class EngineClient:
elif logprobs:
raise ParameterError("logprobs", "Invalid type for 'logprobs'")
max_logprobs = self.max_logprobs
if max_logprobs == -1:
max_logprobs = self.ori_vocab_size
if max_logprobs < -1:
err_msg = f"Invalid 'max_logprobs': must be >= -1, got {max_logprobs}."
api_server_logger.error(err_msg)
raise ValueError("max_logprobs", err_msg)
if max_logprobs > self.ori_vocab_size:
err_msg = f"Invalid 'max_logprobs': must be <= vocab_size {self.ori_vocab_size}, got {max_logprobs}."
api_server_logger.error(err_msg)
raise ValueError("max_logprobs", err_msg)
prompt_logprobs = data.get("prompt_logprobs", None)
if prompt_logprobs is not None:
if not self.enable_logprob:
err_msg = "`enable_logprob` is disabled, please enable it in startup config."
api_server_logger.error(err_msg)
raise ParameterError("prompt_logprobs", err_msg)
if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
err_msg = "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled."
api_server_logger.error(err_msg)
raise ParameterError("prompt_logprobs", err_msg)
if self.enable_prefix_caching:
err_msg = "prompt_logprobs is not support when prefix caching is enabled."
api_server_logger.error(err_msg)
raise ParameterError("prompt_logprobs", err_msg)
if prompt_logprobs == -1 and self.ori_vocab_size > max_logprobs:
err_msg = f"The requested value of ({self.ori_vocab_size}) for prompt_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
api_server_logger.error(err_msg)
raise ValueError("prompt_logprobs", err_msg)
if prompt_logprobs < -1:
err_msg = (
f"prompt_logprobs must be a non-negative value or -1; the current value is {prompt_logprobs}."
)
api_server_logger.error(err_msg)
raise ValueError("prompt_logprobs", err_msg)
if prompt_logprobs > max_logprobs:
err_msg = f"Number of prompt_logprobs requested ({prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
api_server_logger.error(err_msg)
raise ValueError("prompt_logprobs", err_msg)
# enable_logprob
if top_logprobs:
if not self.enable_logprob:
@@ -437,15 +490,26 @@ class EngineClient:
api_server_logger.error(err_msg)
raise ParameterError("top_logprobs", err_msg)
if top_logprobs < 0:
err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
api_server_logger.error(err_msg)
raise ParameterError("top_logprobs", err_msg)
if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
if top_logprobs < 0 or top_logprobs > 20:
err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
api_server_logger.error(err_msg)
raise ValueError("top_logprobs", err_msg)
else:
if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
api_server_logger.error(err_msg)
raise ValueError("top_logprobs", err_msg)
if top_logprobs > 20:
err_msg = "Invalid value for 'top_logprobs': must be <= 20."
api_server_logger.error(err_msg)
raise ParameterError("top_logprobs", err_msg)
if top_logprobs < -1:
err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
api_server_logger.error(err_msg)
raise ValueError("top_logprobs", err_msg)
if top_logprobs > max_logprobs:
err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
api_server_logger.error(err_msg)
raise ValueError("top_logprobs", err_msg)
def check_health(self, time_interval_threashold=30):
"""
+27 -10
View File
@@ -335,23 +335,40 @@ class LLM:
current_sampling_params = sampling_params[i]
else:
current_sampling_params = sampling_params
if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None:
raise ValueError("prompt_logprobs is not supported with streaming.")
ori_vocab_size = (
len(self.llm_engine.data_processor.tokenizer.sp_model)
if hasattr(self.llm_engine.data_processor.tokenizer, "sp_model")
else len(self.llm_engine.data_processor.tokenizer.vocab)
)
max_logprobs = self.llm_engine.cfg.model_config.max_logprobs
if max_logprobs == -1:
max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
max_logprobs = ori_vocab_size
if max_logprobs < -1:
raise ValueError(f"max_logprobs ({max_logprobs}) can't be less than -1.")
if max_logprobs > ori_vocab_size:
raise ValueError(f"max_logprobs ({max_logprobs}) exceeds vocabulary size ({ori_vocab_size}).")
if current_sampling_params.logprobs is not None:
num_logprobs = current_sampling_params.logprobs
if num_logprobs == -1:
num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
if num_logprobs == -1 and ori_vocab_size > max_logprobs:
raise ValueError(
f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
)
if num_logprobs > max_logprobs:
raise ValueError(
f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
)
if current_sampling_params.prompt_logprobs is not None:
if self.llm_engine.cfg.cache_config.enable_prefix_caching:
raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
if kwargs.get("stream"):
raise ValueError("prompt_logprobs is not supported with streaming.")
num_prompt_logprobs = current_sampling_params.prompt_logprobs
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
if num_prompt_logprobs == -1 and ori_vocab_size > max_logprobs:
raise ValueError(
f"Number of prompt_logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
)
if num_prompt_logprobs > max_logprobs:
raise ValueError(
f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
@@ -436,7 +453,7 @@ class LLM:
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
result: Optional[PromptLogprobs] = []
result: Optional[PromptLogprobs] = [None]
# Make Logprob for each position.
for pos in range(num_prompt_tokens):
# Handle flattening.
@@ -548,11 +565,11 @@ class LLM:
result.outputs.logprobs = self._build_sample_logprobs(
result.outputs.top_logprobs, topk_logprobs
)
if result.prompt_logprobs_tensors and num_prompt_logprobs:
if result.prompt_logprobs is not None and num_prompt_logprobs is not None:
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
result.prompt_logprobs = self._build_prompt_logprobs(
result.prompt_logprobs_tensors, num_prompt_logprobs
result.prompt_logprobs, num_prompt_logprobs
)
output[pos] = result
@@ -181,6 +181,7 @@ async def lifespan(app: FastAPI):
port=int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "0")),
fd_config=fd_config,
workers=args.workers,
max_logprobs=args.max_logprobs,
)
await engine_client.connection_manager.initialize()
app.state.dynamic_load_weight = args.dynamic_load_weight
+38 -4
View File
@@ -21,9 +21,17 @@ import time
import uuid
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
ValidationInfo,
field_validator,
model_validator,
)
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.worker.output import PromptLogprobs
class InvalidParameterException(Exception):
@@ -214,10 +222,12 @@ class ChatCompletionResponseChoice(BaseModel):
Chat completion response choice.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
prompt_logprobs: Optional[PromptLogprobs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
@@ -275,10 +285,12 @@ class ChatCompletionResponseStreamChoice(BaseModel):
Chat completion response choice for stream response.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
prompt_logprobs: Optional[PromptLogprobs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
arrival_time: Optional[float] = None
@@ -301,6 +313,7 @@ class CompletionResponseChoice(BaseModel):
Completion response choice.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
index: int
text: str
prompt_token_ids: Optional[List[int]] = None
@@ -310,6 +323,7 @@ class CompletionResponseChoice(BaseModel):
arrival_time: Optional[float] = None
logprobs: Optional[CompletionLogprobs] = None
draft_logprobs: Optional[CompletionLogprobs] = None
prompt_logprobs: Optional[PromptLogprobs] = None
reasoning_content: Optional[str] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
@@ -344,11 +358,13 @@ class CompletionResponseStreamChoice(BaseModel):
Completion response choice for stream response.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
index: int
text: str
arrival_time: float = None
logprobs: Optional[CompletionLogprobs] = None
draft_logprobs: Optional[CompletionLogprobs] = None
prompt_logprobs: Optional[PromptLogprobs] = None
prompt_token_ids: Optional[List[int]] = None
completion_token_ids: Optional[List[int]] = None
prompt_tokens: Optional[str] = None
@@ -437,6 +453,7 @@ class CompletionRequest(BaseModel):
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
logprobs: Optional[int] = None
include_draft_logprobs: Optional[bool] = False
prompt_logprobs: Optional[int] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
@@ -569,6 +586,18 @@ class CompletionRequest(BaseModel):
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
    """Validate logprob-count fields before model construction.

    Both `logprobs` and `prompt_logprobs` accept -1 (meaning "all vocab
    entries") or any non-negative count; anything below -1 is rejected.

    Args:
        data: raw request payload dict (mode="before" validator input).

    Returns:
        The unmodified payload dict when validation passes.

    Raises:
        ValueError: if `logprobs` or `prompt_logprobs` is less than -1.
    """
    if (logprobs := data.get("logprobs")) is not None:
        if logprobs < -1:
            # -1 itself is legal (it means "full vocabulary"), so the
            # bound is "greater than or equal to -1".
            raise ValueError("`logprobs` must be greater than or equal to -1.")
    if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
        if prompt_logprobs < -1:
            raise ValueError("`prompt_logprobs` must be greater than or equal to -1.")
    return data
class ChatCompletionRequest(BaseModel):
"""
@@ -583,6 +612,7 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
prompt_logprobs: Optional[int] = None
include_draft_logprobs: Optional[bool] = False
# For logits and logprobs post processing
@@ -651,6 +681,7 @@ class ChatCompletionRequest(BaseModel):
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
req_dict["prompt_logprobs"] = self.prompt_logprobs
req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs
@@ -751,12 +782,15 @@ class ChatCompletionRequest(BaseModel):
def check_logprobs(cls, data):
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a positive value.")
if top_logprobs < -1:
raise ValueError("`top_logprobs` must be a greater than -1.")
if top_logprobs > 0 and not data.get("logprobs"):
if not data.get("logprobs"):
raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if prompt_logprobs < -1:
raise ValueError("`prompt_logprobs` must be a greater than -1.")
return data
+136 -8
View File
@@ -15,9 +15,11 @@
"""
import asyncio
import itertools
import time
import traceback
import uuid
from collections.abc import Iterable
from typing import List, Optional
import numpy as np
@@ -47,9 +49,17 @@ from fastdeploy.utils import (
ErrorType,
ParameterError,
api_server_logger,
clamp_prompt_logprobs,
get_host_ip,
)
from fastdeploy.worker.output import LogprobsLists
from fastdeploy.worker.output import (
Logprob,
LogprobsLists,
LogprobsTensors,
PromptLogprobs,
)
NONES = itertools.repeat(None)
class OpenAIServingChat:
@@ -287,6 +297,17 @@ class OpenAIServingChat:
num_input_image_tokens = res.get("num_input_image_tokens", 0)
num_input_video_tokens = res.get("num_input_video_tokens", 0)
for i in range(num_choices):
prompt_logprobs_res: Optional[PromptLogprobs] = None
prompt_logprobs_tensors = res.get("prompt_logprobs", None)
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
num_prompt_logprobs = (
request.prompt_logprobs
if request.prompt_logprobs != -1
else self.engine_client.ori_vocab_size
)
prompt_logprobs_res = self._build_prompt_logprobs(
prompt_logprobs_tensors, num_prompt_logprobs
)
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
@@ -296,6 +317,7 @@ class OpenAIServingChat:
prompt_token_ids=None,
completion_token_ids=None,
),
prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res),
)
if response_processor.enable_multimodal_content():
choice.delta.multimodal_content = [
@@ -344,12 +366,16 @@ class OpenAIServingChat:
logprobs_res: Optional[LogProbs] = None
draft_logprobs_res: Optional[LogProbs] = None
if request.logprobs and output_top_logprobs is not None:
logprobs_res = self._create_chat_logprobs(
output_top_logprobs, request.logprobs, request.top_logprobs
num_top_logprobs = (
request.top_logprobs if request.top_logprobs != -1 else self.engine_client.ori_vocab_size
)
logprobs_res = self._create_chat_logprobs(
output_top_logprobs, request.logprobs, num_top_logprobs
)
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
draft_logprobs_res = self._create_chat_logprobs(
output_draft_top_logprobs, request.logprobs, request.top_logprobs
output_draft_top_logprobs, request.logprobs, num_top_logprobs
)
delta_message = DeltaMessage(
@@ -496,6 +522,7 @@ class OpenAIServingChat:
enable_mm_output=self.enable_mm_output,
decoder_base_url=self.tokenizer_base_url,
)
prompt_logprobs_res_list = [[] for _ in range(num_choices)]
choices = []
while num_choices > 0:
if self.engine_client.check_model_weight_status():
@@ -538,9 +565,12 @@ class OpenAIServingChat:
output_top_logprobs = output["top_logprobs"]
output_draft_top_logprobs = output["draft_top_logprobs"]
if output_top_logprobs is not None:
num_top_logprobs = (
request.top_logprobs if request.top_logprobs != -1 else self.engine_client.ori_vocab_size
)
# logprobs
logprobs_res = self._create_chat_logprobs(
output_top_logprobs, request.logprobs, request.top_logprobs
output_top_logprobs, request.logprobs, num_top_logprobs
)
if logprobs_res and logprobs_res.content is not None:
logprob_contents[idx].extend(logprobs_res.content)
@@ -548,11 +578,20 @@ class OpenAIServingChat:
# draft_logprobs
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
draft_logprobs_res = self._create_chat_logprobs(
output_draft_top_logprobs, request.logprobs, request.top_logprobs
output_draft_top_logprobs, request.logprobs, num_top_logprobs
)
if draft_logprobs_res and draft_logprobs_res.content is not None:
draft_logprob_contents[idx].extend(draft_logprobs_res.content)
prompt_logprobs_tensors = data.get("prompt_logprobs", None)
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
num_prompt_logprobs = (
request.prompt_logprobs
if request.prompt_logprobs != -1
else self.engine_client.ori_vocab_size
)
prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs)
if prompt_logprobs_res:
prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
if data["finished"]:
num_choices -= 1
reasoning_num_tokens[idx] = data["outputs"].get("reasoning_token_num", 0)
@@ -573,6 +612,7 @@ class OpenAIServingChat:
logprob_contents=logprob_contents,
draft_logprob_contents=draft_logprob_contents,
response_processor=response_processor,
prompt_logprobs_res_list=prompt_logprobs_res_list,
max_tokens=max_tokens,
)
choices.append(choice)
@@ -624,6 +664,7 @@ class OpenAIServingChat:
num_image_tokens: list,
logprob_contents: list,
draft_logprob_contents: list,
prompt_logprobs_res_list: list,
response_processor: ChatResponseProcessor,
max_tokens: int,
) -> ChatCompletionResponseChoice:
@@ -649,11 +690,14 @@ class OpenAIServingChat:
message.content = output["text"]
logprobs_full_res = None
draft_logprobs_full_res = None
prompt_logprobs_full_res = None
if logprob_contents[idx]:
logprobs_full_res = LogProbs(content=logprob_contents[idx])
draft_logprobs_full_res = None
if draft_logprob_contents[idx]:
draft_logprobs_full_res = LogProbs(content=draft_logprob_contents[idx])
if prompt_logprobs_res_list[idx]:
prompt_logprobs_full_res = prompt_logprobs_res_list[idx]
num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
@@ -675,6 +719,7 @@ class OpenAIServingChat:
message=message,
logprobs=logprobs_full_res,
draft_logprobs=draft_logprobs_full_res,
prompt_logprobs=prompt_logprobs_full_res,
finish_reason=finish_reason,
)
@@ -780,3 +825,86 @@ class OpenAIServingChat:
else:
enable_thinking = True
return enable_thinking
def _build_prompt_logprobs(
    self,
    prompt_logprobs_tensors: LogprobsTensors,
    num_prompt_logprobs: int,
):
    """Build per-position prompt logprob dictionaries from worker tensors.

    Args:
        prompt_logprobs_tensors: tuple-like of (token_ids, logprobs, ranks)
            tensors produced by the worker; token_ids/logprobs have shape
            [num_prompt_tokens, num_logprobs] — assumed per the flatten/
            reshape logic below, TODO confirm against the worker output.
        num_prompt_logprobs: number of logprobs requested per position;
            -1 means "all available" (resolved inside _make_logprob_dict).

    Returns:
        A list whose first element is None (the first prompt token has no
        preceding context, hence no logprob entry) followed by one
        dict[token_id, Logprob] per remaining prompt position.
    """
    token_ids, logprobs, ranks = prompt_logprobs_tensors
    # Detokenize non-incrementally.
    # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
    decoded_tokens = [
        self.engine_client.data_processor.process_logprob_response(token_id)
        for token_id in token_ids.flatten().tolist()
    ]
    # Recover shapes.
    num_prompt_tokens, num_logprobs = logprobs.shape
    # Pythonize the paddle tensors.
    prompt_token_ranks = ranks.tolist()
    prompt_logprobs = logprobs.tolist()
    token_ids = token_ids.tolist()
    # Leading None placeholder for position 0 (no logprob for the first token).
    result: Optional[PromptLogprobs] = [None]
    # Make Logprob for each position.
    for pos in range(num_prompt_tokens):
        # Handle flattening.
        offset = pos * num_logprobs
        offset_end = offset + num_logprobs
        # NOTE(review): decoded_tokens is always a list here (built above),
        # so the NONES branch never triggers — confirm whether the None
        # case is reachable via another code path.
        decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
        # Update with the Logprob dictionary for this pos.
        result.append(
            self._make_logprob_dict(
                prompt_logprobs[pos],
                token_ids[pos],
                decoded_tokens_for_pos,
                prompt_token_ranks[pos],
                num_prompt_logprobs,
            )
        )
    return result
@staticmethod
def _make_logprob_dict(
    logprobs: list[float],
    logprob_token_ids: list[int],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> dict[int, Logprob]:
    """Make a Logprob dictionary for a position.

    Args:
        logprobs: list of log probabilities
        logprob_token_ids: list of top token ids
        decoded_tokens: list of decoded top tokens
        rank: rank of the sampled token
        num_logprobs: number of logprobs requested
            by the user (in addition to sampled logprob);
            -1 means "return all available entries"

    Returns:
        dict[token id, Logprob]
    """
    # -1 sentinel: expand to every logprob the worker produced.
    if num_logprobs == -1:
        num_logprobs = len(logprobs)
    # We do not need a special case for the sampled token
    # being in the topk, since inserting duplicated data
    # into a dictionary twice is the same as doing it once.
    topk_ranks = range(1, num_logprobs + 1)
    # First entry carries the sampled token's true rank; the remaining
    # entries get ranks 1..num_logprobs. zip() truncates to the shortest
    # input, which bounds the dict size.
    ranks = itertools.chain((rank,), topk_ranks)
    return {
        token_id: Logprob(
            logprob=logprob,
            rank=rank,
            decoded_token=token,
        )
        for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
    }
@@ -15,9 +15,11 @@
"""
import asyncio
import itertools
import time
import traceback
import uuid
from collections.abc import Iterable
from typing import List, Optional
import numpy as np
@@ -43,9 +45,17 @@ from fastdeploy.utils import (
ErrorType,
ParameterError,
api_server_logger,
clamp_prompt_logprobs,
get_host_ip,
)
from fastdeploy.worker.output import LogprobsLists
from fastdeploy.worker.output import (
Logprob,
LogprobsLists,
LogprobsTensors,
PromptLogprobs,
)
NONES = itertools.repeat(None)
class OpenAIServingCompletion:
@@ -249,6 +259,7 @@ class OpenAIServingCompletion:
aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)]
aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)]
aggregated_token_ids = [[] for _ in range(num_choices)]
aggregated_prompt_logprobs_tensors = [None] * num_choices
completion_batched_token_ids = [[] for _ in range(num_choices)]
current_waiting_time = 0
while num_choices > 0:
@@ -293,6 +304,10 @@ class OpenAIServingCompletion:
aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1])
aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2])
output_prompt_logprobs_tensors = data.get("prompt_logprobs") or None
if output_prompt_logprobs_tensors is not None:
aggregated_prompt_logprobs_tensors[rid] = output_prompt_logprobs_tensors
aggregated_token_ids[rid].extend(data["outputs"]["token_ids"])
self.engine_client.data_processor.process_response_dict(
@@ -305,6 +320,7 @@ class OpenAIServingCompletion:
data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid]
data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid]
data["outputs"]["token_ids"] = aggregated_token_ids[rid]
data["prompt_logprobs_tensors"] = aggregated_prompt_logprobs_tensors[rid]
valid_results[rid] = data
num_choices -= 1
break
@@ -426,8 +442,18 @@ class OpenAIServingCompletion:
idx = int(res["request_id"].split("_")[-1])
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
prompt_logprobs_res: Optional[PromptLogprobs] = None
if first_iteration[idx]:
prompt_logprobs_tensors = res.get("prompt_logprobs", None)
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
num_prompt_logprobs = (
request.prompt_logprobs
if request.prompt_logprobs != -1
else self.engine_client.ori_vocab_size
)
prompt_logprobs_res = self._build_prompt_logprobs(
prompt_logprobs_tensors, num_prompt_logprobs
)
if request.return_token_ids:
chunk = CompletionStreamResponse(
id=request_id,
@@ -440,6 +466,7 @@ class OpenAIServingCompletion:
prompt_token_ids=list(
prompt_batched_token_ids[idx // (1 if request.n is None else request.n)]
),
prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res),
prompt_tokens=prompt_tokens_list[
idx // (1 if request.n is None else request.n)
],
@@ -468,13 +495,16 @@ class OpenAIServingCompletion:
output_draft_top_logprobs = output["draft_top_logprobs"]
logprobs_res: Optional[CompletionLogprobs] = None
draft_logprobs_res: Optional[CompletionLogprobs] = None
if request.logprobs and output_top_logprobs is not None:
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
if request.logprobs is not None and output_top_logprobs is not None:
num_logprobs = (
request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size
)
logprobs_res = self._create_completion_logprobs(output_top_logprobs, num_logprobs, 0)
# draft logprobs
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
draft_logprobs_res = self._create_completion_logprobs(
output_draft_top_logprobs, request.logprobs, 0
output_draft_top_logprobs, num_logprobs, 0
)
output_tokens[idx] += len(output.get("token_ids", [])) or 0
num_cache_tokens[idx] += output.get("num_cache_tokens") or 0
@@ -492,6 +522,7 @@ class OpenAIServingCompletion:
reasoning_content="",
arrival_time=arrival_time,
logprobs=logprobs_res,
prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res),
draft_logprobs=draft_logprobs_res,
)
if not res["finished"] and "delta_message" in output:
@@ -602,15 +633,22 @@ class OpenAIServingCompletion:
output_draft_top_logprobs = output.get("draft_top_logprobs") or None
aggregated_logprobs: Optional[CompletionLogprobs] = None
num_logprobs = request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size
if output_top_logprobs is not None:
aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, num_logprobs, 0)
aggregated_draft_logprobs: Optional[CompletionLogprobs] = None
if output_draft_top_logprobs is not None:
aggregated_draft_logprobs = self._create_completion_logprobs(
output_draft_top_logprobs, request.logprobs, 0
output_draft_top_logprobs, num_logprobs, 0
)
prompt_logprobs_res: Optional[PromptLogprobs] = None
prompt_logprobs_tensors = final_res.get("prompt_logprobs_tensors", None)
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
num_prompt_logprobs = (
request.prompt_logprobs if request.prompt_logprobs != -1 else self.engine_client.ori_vocab_size
)
prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs)
if request.echo:
prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n))
token_ids = [*prompt_token_ids, *output["token_ids"]]
@@ -641,6 +679,7 @@ class OpenAIServingCompletion:
tool_calls=output.get("tool_call"),
logprobs=aggregated_logprobs,
draft_logprobs=aggregated_draft_logprobs,
prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res),
finish_reason=finish_reason,
)
choices.append(choice_data)
@@ -749,13 +788,13 @@ class OpenAIServingCompletion:
[tid], clean_up_tokenization_spaces=False
)
if "\ufffd" in token_str:
token_bytes = token_str.encode("utf-8", errors="replace")
raw_token = self.engine_client.data_processor.tokenizer.convert_ids_to_tokens(tid)
token_bytes = raw_token.encode("utf-8", errors="replace")
token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
if idx == 0:
tokens.append(token_str)
token_logprobs.append(lp)
else:
top_logprobs[token_str] = lp
top_logprobs[token_str] = lp
idx += 1
# Construct the sampled token object (avoid sharing references with top_logprob_entries)
@@ -770,3 +809,86 @@ class OpenAIServingCompletion:
except Exception as e:
api_server_logger.error(f"Error in _build_logprobs_response: {str(e)}, {str(traceback.format_exc())}")
return None
def _build_prompt_logprobs(
    self,
    prompt_logprobs_tensors: LogprobsTensors,
    num_prompt_logprobs: int,
):
    """Convert worker prompt-logprob tensors into per-token Logprob dicts.

    Args:
        prompt_logprobs_tensors: tuple-like of (token_ids, logprobs, ranks);
            logprobs is assumed to be [num_prompt_tokens, num_logprobs] —
            TODO confirm against the worker's output contract.
        num_prompt_logprobs: requested logprobs per prompt position; -1
            means all available (handled in _make_logprob_dict).

    Returns:
        [None, dict, dict, ...] — the leading None stands for the first
        prompt token, which has no logprob.
    """
    token_ids, logprobs, ranks = prompt_logprobs_tensors
    # Detokenize non-incrementally.
    # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
    decoded_tokens = [
        self.engine_client.data_processor.process_logprob_response(token_id)
        for token_id in token_ids.flatten().tolist()
    ]
    # Recover shapes.
    num_prompt_tokens, num_logprobs = logprobs.shape
    # Pythonize the paddle tensors.
    prompt_token_ranks = ranks.tolist()
    prompt_logprobs = logprobs.tolist()
    token_ids = token_ids.tolist()
    # Position 0 gets None: no context precedes the first prompt token.
    result: Optional[PromptLogprobs] = [None]
    # Make Logprob for each position.
    for pos in range(num_prompt_tokens):
        # Handle flattening.
        offset = pos * num_logprobs
        offset_end = offset + num_logprobs
        # NOTE(review): decoded_tokens cannot be None at this point (it is
        # built by the comprehension above); the NONES fallback looks
        # unreachable — confirm.
        decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
        # Update with the Logprob dictionary for this pos.
        result.append(
            self._make_logprob_dict(
                prompt_logprobs[pos],
                token_ids[pos],
                decoded_tokens_for_pos,
                prompt_token_ranks[pos],
                num_prompt_logprobs,
            )
        )
    return result
@staticmethod
def _make_logprob_dict(
    logprobs: list[float],
    logprob_token_ids: list[int],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> dict[int, Logprob]:
    """Assemble the token-id -> Logprob mapping for one position.

    Args:
        logprobs: log probabilities, sampled token first then top-k.
        logprob_token_ids: token ids aligned with ``logprobs``.
        decoded_tokens: decoded strings aligned with ``logprobs``.
        rank: rank of the sampled token.
        num_logprobs: number of top logprobs requested by the user
            (in addition to the sampled logprob); -1 means all.

    Returns:
        dict mapping token id to ``Logprob``.
    """
    if num_logprobs == -1:
        num_logprobs = len(logprobs)
    # The sampled token leads with its true rank; the top-k entries follow
    # with ranks 1..num_logprobs. If the sampled token also appears in the
    # top-k, the second insertion simply overwrites the first with
    # identical data, so no special-casing is needed.
    entry_ranks = itertools.chain((rank,), range(1, num_logprobs + 1))
    logprob_dict: dict[int, Logprob] = {}
    for tid, lp, rnk, tok in zip(logprob_token_ids, logprobs, entry_ranks, decoded_tokens):
        logprob_dict[tid] = Logprob(
            logprob=lp,
            rank=rnk,
            decoded_token=tok,
        )
    return logprob_dict
+2 -2
View File
@@ -18,9 +18,9 @@ import asyncio
import heapq
import random
import time
from multiprocessing.reduction import ForkingPickler
import aiozmq
import msgpack
import zmq
from fastdeploy.engine.args_utils import EngineArgs
@@ -124,7 +124,7 @@ class DealerConnectionManager:
while self.running:
try:
raw_data = await dealer.read()
response = msgpack.unpackb(raw_data[-1])
response = ForkingPickler.loads(raw_data[-1])
_zmq_metrics_stats = ZMQMetricsStats()
_zmq_metrics_stats.msg_recv_total += 1
if "zmq_send_time" in response:
+3 -3
View File
@@ -153,7 +153,7 @@ class ZmqServerBase(ABC):
if len(data) > 1:
for response in data[1:]:
result.add(response)
result = msgpack.packb([result.to_dict()])
result = ForkingPickler.dumps([result.to_dict()])
return result
def receive_json_once(self, block=False):
@@ -278,12 +278,12 @@ class ZmqServerBase(ABC):
if self.aggregate_send:
result = self.pack_aggregated_data(new_data)
else:
result = msgpack.packb([response.to_dict() for response in new_data])
result = ForkingPickler.dumps([response.to_dict() for response in new_data])
with self.response_token_lock:
_zmq_metrics_stats = ZMQMetricsStats()
try:
self.socket.send_multipart([self.req_dict[req_id], b"", result])
self.socket.send_multipart([self.req_dict[req_id], b"", result], copy=False)
_zmq_metrics_stats.msg_bytes_send_total += len(result)
except Exception as e:
_zmq_metrics_stats.msg_send_failed_total += 1
+1 -1
View File
@@ -291,7 +291,7 @@ class TokenProcessor:
llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}")
if getattr(stream_data, "prompt_logprobs", None) is not None:
try:
result.prompt_logprobs_tensors = stream_data.prompt_logprobs
result.prompt_logprobs = stream_data.prompt_logprobs
except Exception as e:
llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}")
if self.tokens_counter[task_id] == 0:
+16
View File
@@ -52,6 +52,7 @@ from typing_extensions import TypeIs, assert_never
from fastdeploy import envs
from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
from fastdeploy.logger.logger import FastDeployLogger
from fastdeploy.worker.output import PromptLogprobs
T = TypeVar("T")
from typing import Callable, List, Optional
@@ -1073,6 +1074,21 @@ def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T
return _optional_type
def clamp_prompt_logprobs(
    prompt_logprobs: "PromptLogprobs | None",
) -> "PromptLogprobs | None":
    """Clamp ``-inf`` prompt logprobs to a large finite negative value (-9999.0).

    ``Logprob`` is a NamedTuple, so its fields cannot be assigned in place:
    the previous ``logprob_values.logprob = -9999.0`` raised
    ``AttributeError`` on exactly the ``-inf`` entries this function is
    meant to clamp. Each offending entry is instead replaced in its dict
    with a ``_replace``-copy carrying the clamped value.

    Args:
        prompt_logprobs: list of per-position ``dict[token_id, Logprob]``
            (entries may be ``None``), or ``None``.

    Returns:
        The same list object, with ``-inf`` logprobs clamped in place.
    """
    if prompt_logprobs is None:
        return prompt_logprobs
    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        # Materialize items() first: we rebind values for existing keys,
        # which is safe, but a snapshot makes the iteration unambiguous.
        for token_id, logprob_values in list(logprob_dict.items()):
            if logprob_values.logprob == float("-inf"):
                logprob_dict[token_id] = logprob_values._replace(logprob=-9999.0)
    return prompt_logprobs
def to_numpy(tasks: List[Any]):
"""
Convert PaddlePaddle tensors in multimodal inputs to NumPy arrays.
+3 -1
View File
@@ -30,7 +30,6 @@ class Logprob(NamedTuple):
decoded_token: Optional[str] = None
PromptLogprobs = list[dict[int, Logprob] | None]
# [{token_id, logprob}] for tokens sampled from the top-k
SampleLogprobs = list[dict[int, Logprob]]
@@ -125,6 +124,9 @@ class LogprobsTensors(NamedTuple):
)
PromptLogprobs = LogprobsTensors | list[dict[int, Logprob] | None]
@dataclass
class SamplerOutput:
""" """
+1 -1
View File
@@ -40,7 +40,7 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-logging
opentelemetry-instrumentation-logging>=0.57b0
partial_json_parser
msgspec
einops
+1 -1
View File
@@ -37,5 +37,5 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro 
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-logging
opentelemetry-instrumentation-logging>=0.57b0
partial_json_parser
+1 -1
View File
@@ -37,7 +37,7 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-logging
opentelemetry-instrumentation-logging>=0.57b0
partial_json_parser
msgspec
safetensors==0.7.0rc0
+1 -1
View File
@@ -40,7 +40,7 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-logging
opentelemetry-instrumentation-logging>=0.57b0
partial_json_parser
msgspec
einops
+185 -85
View File
@@ -26,17 +26,28 @@ class TestSamplingParamsVerification(unittest.TestCase):
def test_logprobs_valid_values(self):
"""Test valid logprobs values"""
# Test None value (should pass)
params = SamplingParams(logprobs=None)
params._verify_args() # Should not raise
# Test None value (should pass in both modes)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
params = SamplingParams(logprobs=None)
params._verify_args() # Should not raise
# Test -1 value (should pass)
params = SamplingParams(logprobs=-1)
params._verify_args() # Should not raise
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=None)
params._verify_args() # Should not raise
# Test 0 value (should pass)
params = SamplingParams(logprobs=0)
params._verify_args() # Should not raise
# Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=-1)
params._verify_args() # Should not raise
# Test 0 value (should pass in both modes based on actual behavior)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
params = SamplingParams(logprobs=0)
params._verify_args() # Should not raise
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=0)
params._verify_args() # Should not raise
# Test 20 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "0")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
@@ -44,13 +55,23 @@ class TestSamplingParamsVerification(unittest.TestCase):
params._verify_args() # Should not raise
def test_logprobs_invalid_less_than_minus_one(self):
"""Test logprobs less than -1 should raise ValueError"""
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=-2)
params._verify_args()
"""Test logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=-2)
params._verify_args()
self.assertIn("logprobs must be greater than -1", str(cm.exception))
self.assertIn("got -2", str(cm.exception))
self.assertIn("logprobs must be a non-negative value or -1", str(cm.exception))
self.assertIn("got -2", str(cm.exception))
def test_logprobs_invalid_less_than_zero(self):
"""Test logprobs less than 0 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0" """
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=-1)
params._verify_args()
self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", str(cm.exception))
def test_logprobs_greater_than_20_with_v1_disabled(self):
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled"""
@@ -59,7 +80,7 @@ class TestSamplingParamsVerification(unittest.TestCase):
params = SamplingParams(logprobs=21)
params._verify_args()
self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(cm.exception))
self.assertEqual("Invalid value for 'top_logprobs': must be between 0 and 20.", str(cm.exception))
def test_logprobs_greater_than_20_with_v1_enabled(self):
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
@@ -74,46 +95,67 @@ class TestSamplingParamsVerification(unittest.TestCase):
def test_prompt_logprobs_valid_values(self):
"""Test valid prompt_logprobs values"""
# Test None value (should pass)
params = SamplingParams(prompt_logprobs=None)
params._verify_args() # Should not raise
# Test None value (should pass in both modes based on actual behavior)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
params = SamplingParams(prompt_logprobs=None)
params._verify_args() # Should not raise
# Test -1 value (should pass)
params = SamplingParams(prompt_logprobs=-1)
params._verify_args() # Should not raise
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(prompt_logprobs=None)
params._verify_args() # Should not raise
# Test 0 value (should pass)
params = SamplingParams(prompt_logprobs=0)
params._verify_args() # Should not raise
# Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(prompt_logprobs=-1)
params._verify_args() # Should not raise
# Test positive values (should pass)
params = SamplingParams(prompt_logprobs=10)
params._verify_args() # Should not raise
# Test 0 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(prompt_logprobs=0)
params._verify_args() # Should not raise
# Test positive values (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(prompt_logprobs=10)
params._verify_args() # Should not raise
def test_prompt_logprobs_invalid_less_than_minus_one(self):
"""Test prompt_logprobs less than -1 should raise ValueError"""
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=-2)
params._verify_args()
"""Test prompt_logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=-2)
params._verify_args()
self.assertIn("prompt_logprobs must be greater than or equal to -1", str(cm.exception))
self.assertIn("got -2", str(cm.exception))
self.assertIn("prompt_logprobs a must be non-negative value or -1", str(cm.exception))
self.assertIn("got -2", str(cm.exception))
def test_combined_logprobs_and_prompt_logprobs(self):
"""Test both logprobs and prompt_logprobs together"""
# Test valid combination
params = SamplingParams(logprobs=5, prompt_logprobs=3)
params._verify_args() # Should not raise
# Test valid combination when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=5, prompt_logprobs=3)
params._verify_args() # Should not raise
# Test invalid logprobs with valid prompt_logprobs
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=-2, prompt_logprobs=5)
params._verify_args()
# Test invalid logprobs with valid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=-2, prompt_logprobs=5)
params._verify_args()
# Test valid logprobs with invalid prompt_logprobs
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=5, prompt_logprobs=-2)
params._verify_args()
# Test valid logprobs with invalid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=5, prompt_logprobs=-2)
params._verify_args()
# Test prompt_logprobs not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=5, prompt_logprobs=3)
params._verify_args()
self.assertIn(
"prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
)
def test_logprobs_boundary_values(self):
"""Test boundary values for logprobs"""
@@ -130,14 +172,16 @@ class TestSamplingParamsVerification(unittest.TestCase):
def test_prompt_logprobs_boundary_values(self):
"""Test boundary values for prompt_logprobs"""
# Test boundary value -1 (should pass)
params = SamplingParams(prompt_logprobs=-1)
params._verify_args() # Should pass
# Test boundary value -1 (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(prompt_logprobs=-1)
params._verify_args() # Should pass
# Test boundary value just below -1 (should fail)
with self.assertRaises(ValueError):
params = SamplingParams(prompt_logprobs=-2)
params._verify_args()
# Test boundary value just below -1 (should fail when FD_USE_GET_SAVE_OUTPUT_V1 is "1")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
with self.assertRaises(ValueError):
params = SamplingParams(prompt_logprobs=-2)
params._verify_args()
def test_environment_variable_handling(self):
"""Test different environment variable values"""
@@ -167,55 +211,111 @@ class TestSamplingParamsVerification(unittest.TestCase):
if original_value is not None:
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = original_value
# Test prompt_logprobs behavior with different environment variables
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=5)
params._verify_args()
self.assertIn(
"prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception)
)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(prompt_logprobs=5)
params._verify_args() # Should pass
def test_error_message_formatting(self):
"""Test that error messages are properly formatted"""
# Test logprobs error message
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=-5)
params._verify_args()
# Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=-5)
params._verify_args()
error_msg = str(cm.exception)
self.assertIn("logprobs must be greater than -1", error_msg)
self.assertIn("got -5", error_msg)
error_msg = str(cm.exception)
self.assertIn("logprobs must be a non-negative value or -1", error_msg)
self.assertIn("got -5", error_msg)
# Test prompt_logprobs error message
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=-10)
params._verify_args()
# Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=-1)
params._verify_args()
error_msg = str(cm.exception)
self.assertIn("prompt_logprobs must be greater than or equal to -1", error_msg)
self.assertIn("got -10", error_msg)
error_msg = str(cm.exception)
self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", error_msg)
# Test prompt_logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=-10)
params._verify_args()
error_msg = str(cm.exception)
self.assertIn("prompt_logprobs a must be non-negative value or -1", error_msg)
self.assertIn("got -10", error_msg)
# Test prompt_logprobs not supported error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=5)
params._verify_args()
error_msg = str(cm.exception)
self.assertIn("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", error_msg)
def test_post_init_calls_verify_args(self):
"""Test that __post_init__ calls _verify_args"""
# This should call _verify_args internally
params = SamplingParams(logprobs=5, prompt_logprobs=3)
# This should call _verify_args internally when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=5, prompt_logprobs=3)
# The params should be successfully created without errors
self.assertEqual(params.logprobs, 5)
self.assertEqual(params.prompt_logprobs, 3)
# The params should be successfully created without errors
self.assertEqual(params.logprobs, 5)
self.assertEqual(params.prompt_logprobs, 3)
# Test that invalid values are caught during initialization
with self.assertRaises(ValueError):
SamplingParams(logprobs=-2)
# Test that invalid values are caught during initialization
with self.assertRaises(ValueError):
SamplingParams(logprobs=-2)
with self.assertRaises(ValueError):
SamplingParams(prompt_logprobs=-2)
with self.assertRaises(ValueError):
SamplingParams(prompt_logprobs=-2)
# Test that prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError):
SamplingParams(prompt_logprobs=3)
# Test that logprobs < 0 is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
with self.assertRaises(ValueError):
SamplingParams(logprobs=-1)
def test_logprobs_with_other_parameters(self):
"""Test logprobs validation with other sampling parameters"""
# Test with temperature
params = SamplingParams(logprobs=5, temperature=0.8)
params._verify_args() # Should pass
# Test with temperature when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=5, temperature=0.8)
params._verify_args() # Should pass
# Test with top_p
params = SamplingParams(logprobs=5, top_p=0.9)
params._verify_args() # Should pass
# Test with top_p when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=5, top_p=0.9)
params._verify_args() # Should pass
# Test with all parameters
params = SamplingParams(logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100)
params._verify_args() # Should pass
# Test with all parameters when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(
logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
)
params._verify_args() # Should pass
# Test that prompt_logprobs fails when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError):
params = SamplingParams(
logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100
)
params._verify_args()
if __name__ == "__main__":
@@ -453,6 +453,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
num_input_image_tokens = [0, 0]
num_input_video_tokens = [0, 0]
num_image_tokens = [0, 0]
prompt_logprobs_res_list = [[], []]
max_tokens_list = [10, 1]
for idx, case in enumerate(test_cases):
@@ -469,6 +470,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
num_image_tokens=num_image_tokens,
logprob_contents=logprob_contents,
draft_logprob_contents=draft_logprob_contents,
prompt_logprobs_res_list=prompt_logprobs_res_list,
response_processor=mock_response_processor,
max_tokens=max_tokens_list[idx],
)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+130 -37
View File
@@ -1,4 +1,5 @@
from unittest.mock import MagicMock
import os
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
@@ -14,56 +15,147 @@ class DummyModelConfig:
self.ori_vocab_size = ori_vocab_size
class DummyCacheConfig:
def __init__(self, enable_prefix_caching=False):
self.enable_prefix_caching = enable_prefix_caching
class DummyLLMEngineConfig:
def __init__(self, model_config=None, cache_config=None):
self.model_config = model_config or DummyModelConfig()
self.cache_config = cache_config or DummyCacheConfig()
class DummyLLMEngine:
def __init__(self, model_config=None, cache_config=None):
self.cfg = DummyLLMEngineConfig(model_config, cache_config)
self.data_processor = MagicMock()
# Mock tokenizer with sp_model attribute
self.data_processor.tokenizer = MagicMock()
self.data_processor.tokenizer.sp_model = MagicMock()
self.data_processor.tokenizer.sp_model.__len__ = MagicMock(return_value=100)
self.data_processor.tokenizer.vocab = MagicMock()
self.data_processor.tokenizer.vocab.__len__ = MagicMock(return_value=100)
self.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
self.add_requests = MagicMock()
@pytest.fixture
def mock_llm():
llm = LLM.__new__(LLM)
llm.llm_engine = MagicMock()
llm.llm_engine.add_requests = MagicMock()
llm.llm_engine.cfg.model_config = DummyModelConfig(max_logprobs=10, ori_vocab_size=100)
# Mock the data_processor.process_logprob_response method to return proper strings
llm.llm_engine.data_processor = MagicMock()
llm.llm_engine.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
llm.llm_engine = DummyLLMEngine()
return llm
@pytest.fixture
def mock_llm_with_prefix_caching():
llm = LLM.__new__(LLM)
llm.llm_engine = DummyLLMEngine(cache_config=DummyCacheConfig(enable_prefix_caching=True))
return llm
def test_prompt_logprobs_not_supported_with_stream(mock_llm):
sampling = SamplingParams(prompt_logprobs=5)
with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
mock_llm._add_request(["hi"], sampling, stream=True)
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(prompt_logprobs=5)
with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
mock_llm._add_request(["hi"], sampling, stream=True)
def test_prompt_logprobs_not_supported_with_prefix_caching(mock_llm_with_prefix_caching):
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(prompt_logprobs=5)
with pytest.raises(ValueError, match="prompt_logprobs is not supported with prefix caching enabled"):
mock_llm_with_prefix_caching._add_request(["hi"], sampling)
def test_num_logprobs_exceeds_max(mock_llm):
sampling = SamplingParams(logprobs=20)
with pytest.raises(ValueError, match="Number of logprobs requested"):
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs > 20
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(logprobs=20)
with pytest.raises(ValueError, match="Number of logprobs requested"):
mock_llm._add_request(["hi"], sampling)
def test_max_logprobs_exceeds_vocab_size(mock_llm):
# Test case where max_logprobs > ori_vocab_size
mock_llm.llm_engine.cfg.model_config.max_logprobs = 150 # > vocab size (100)
with pytest.raises(ValueError, match="max_logprobs \\(150\\) exceeds vocabulary size \\(100\\)"):
mock_llm._add_request(["hi"], SamplingParams())
def test_max_logprobs_less_than_minus_one(mock_llm):
# Test case where max_logprobs < -1
mock_llm.llm_engine.cfg.model_config.max_logprobs = -2
with pytest.raises(ValueError, match="max_logprobs \\(-2\\) can't be less than -1"):
mock_llm._add_request(["hi"], SamplingParams())
def test_logprobs_minus_one_uses_vocab_size(mock_llm):
# Test that logprobs=-1 uses vocab size
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(logprobs=-1)
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1 # Allow unlimited
mock_llm._add_request(["hi"], sampling)
mock_llm.llm_engine.add_requests.assert_called_once()
def test_num_prompt_logprobs_exceeds_max(mock_llm):
sampling = SamplingParams(prompt_logprobs=20)
with pytest.raises(ValueError, match="Number of logprobs requested"):
mock_llm._add_request(["hi"], sampling)
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(prompt_logprobs=20)
with pytest.raises(ValueError, match="Number of logprobs requested"):
mock_llm._add_request(["hi"], sampling)
def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
sampling = SamplingParams(logprobs=-1)
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 30
mock_llm._add_request(["hi"], sampling)
mock_llm.llm_engine.add_requests.assert_called_once()
# Get the first argument (tasks) which should be a dict
call_args = mock_llm.llm_engine.add_requests.call_args
tasks = call_args[0][0] # First positional argument
assert isinstance(tasks, dict)
assert "prompt" in tasks
assert "request_id" in tasks
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs=-1
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(logprobs=-1)
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
mock_llm._add_request(["hi"], sampling)
mock_llm.llm_engine.add_requests.assert_called_once()
# Get the first argument (tasks) which should be a dict
call_args = mock_llm.llm_engine.add_requests.call_args
tasks = call_args[0][0] # First positional argument
assert isinstance(tasks, dict)
assert "prompt" in tasks
assert "request_id" in tasks
def test_prompt_logprobs_equal_to_minus_one(mock_llm):
sampling = SamplingParams(prompt_logprobs=-1)
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support and allow -1
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(prompt_logprobs=-1)
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
mock_llm._add_request(["hi"], sampling)
mock_llm.llm_engine.add_requests.assert_called_once()
def test_dynamic_vocab_size_from_sp_model(mock_llm):
# Test that ori_vocab_size is dynamically obtained from sp_model
mock_llm.llm_engine.data_processor.tokenizer.sp_model.__len__.return_value = 200
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 25
mock_llm._add_request(["hi"], sampling)
mock_llm.llm_engine.add_requests.assert_called_once()
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(logprobs=-1)
mock_llm._add_request(["hi"], sampling)
# Should use the dynamic vocab size (200)
mock_llm.llm_engine.add_requests.assert_called_once()
def test_dynamic_vocab_size_from_vocab_fallback(mock_llm):
# Test fallback to vocab when sp_model is not available
del mock_llm.llm_engine.data_processor.tokenizer.sp_model
mock_llm.llm_engine.data_processor.tokenizer.vocab.__len__.return_value = 300
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
sampling = SamplingParams(logprobs=-1)
mock_llm._add_request(["hi"], sampling)
# Should use the vocab size (300)
mock_llm.llm_engine.add_requests.assert_called_once()
def test_build_prompt_logprobs_basic(mock_llm):
@@ -77,12 +169,13 @@ def test_build_prompt_logprobs_basic(mock_llm):
# 检查结果格式
assert isinstance(result, list)
assert len(result) == 2
assert len(result) == 3
for pos_dict in result:
assert isinstance(pos_dict, dict)
for logprob_obj in pos_dict.values():
assert isinstance(logprob_obj, Logprob)
assert logprob_obj.decoded_token.startswith("TOKEN_")
if pos_dict is not None:
assert isinstance(pos_dict, dict)
for logprob_obj in pos_dict.values():
assert isinstance(logprob_obj, Logprob)
assert logprob_obj.decoded_token.startswith("TOKEN_")
def test_build_prompt_logprobs_handles_minus_one(mock_llm):
@@ -94,7 +187,7 @@ def test_build_prompt_logprobs_handles_minus_one(mock_llm):
result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1)
assert isinstance(result, list)
assert len(result) == 1
pos_dict = result[0]
assert len(result) == 2
pos_dict = result[1]
assert 7 in pos_dict
assert pos_dict[7].decoded_token == "TOKEN_7"
@@ -137,7 +137,7 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
result = self.processor._process_batch_output_use_zmq([stream_data])
self.assertEqual(len(result), 1)
self.assertIsNone(getattr(result[0], "prompt_logprobs_tensors", None))
self.assertIsNone(getattr(result[0], "prompt_logprobs", None))
def test_process_batch_with_stop_flag(self):
"""Test processing when stop flag is True"""
+133
View File
@@ -0,0 +1,133 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from fastdeploy.utils import clamp_prompt_logprobs
from fastdeploy.worker.output import Logprob
class TestClampPromptLogprobs(unittest.TestCase):
    """Unit tests for ``fastdeploy.utils.clamp_prompt_logprobs``.

    NOTE(review): several tests below assert that clamping an ``-inf``
    value raises ``AttributeError``. That pins the CURRENT behavior —
    ``Logprob`` is a NamedTuple whose fields cannot be assigned — rather
    than the clamping the function's name promises; confirm this is the
    intended contract.
    """

    def test_none_input(self):
        """``None`` input is passed through unchanged."""
        result = clamp_prompt_logprobs(None)
        self.assertIsNone(result)

    def test_empty_list(self):
        """An empty list is returned as-is."""
        result = clamp_prompt_logprobs([])
        self.assertEqual(result, [])

    def test_normal_logprobs(self):
        """Finite logprob values (no -inf) are left untouched."""
        logprob_dict = {
            1: Logprob(logprob=-2.5, rank=1, decoded_token="hello"),
            2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
        }
        prompt_logprobs = [logprob_dict]
        result = clamp_prompt_logprobs(prompt_logprobs)
        # Original values should remain unchanged
        self.assertEqual(result[0][1].logprob, -2.5)
        self.assertEqual(result[0][2].logprob, -1.0)

    def test_negative_inf_logprobs_raises_error(self):
        """A single -inf entry triggers AttributeError (NamedTuple is immutable)."""
        logprob_dict = {
            1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
            2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
        }
        prompt_logprobs = [logprob_dict]
        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
        with self.assertRaises(AttributeError) as context:
            clamp_prompt_logprobs(prompt_logprobs)
        self.assertIn("can't set attribute", str(context.exception))

    def test_multiple_negative_inf_raises_error(self):
        """Multiple -inf entries: the first mutation attempt already raises."""
        logprob_dict = {
            1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
            2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"),
            3: Logprob(logprob=-0.5, rank=3, decoded_token="test"),
        }
        prompt_logprobs = [logprob_dict]
        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
        with self.assertRaises(AttributeError):
            clamp_prompt_logprobs(prompt_logprobs)

    def test_none_dict_in_list(self):
        """``None`` placeholders inside the list are skipped, not touched."""
        prompt_logprobs = [None]
        result = clamp_prompt_logprobs(prompt_logprobs)
        # None should be skipped
        self.assertIsNone(result[0])

    def test_multiple_dicts_normal_values(self):
        """Multiple position dicts with finite values pass through unchanged."""
        logprob_dict1 = {
            1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
        }
        logprob_dict2 = {
            2: Logprob(logprob=-2.0, rank=1, decoded_token="world"),
        }
        prompt_logprobs = [logprob_dict1, logprob_dict2]
        result = clamp_prompt_logprobs(prompt_logprobs)
        # Should return normally, values remain unchanged
        self.assertEqual(result[0][1].logprob, -2.0)
        self.assertEqual(result[1][2].logprob, -2.0)

    def test_mixed_values_without_inf(self):
        """Already-clamped (-9999.0), zero, and negative values stay as-is."""
        logprob_dict = {
            1: Logprob(logprob=-9999.0, rank=1, decoded_token="hello"),
            2: Logprob(logprob=-9999.0, rank=2, decoded_token="world"),
            3: Logprob(logprob=0.0, rank=3, decoded_token="test"),
            4: Logprob(logprob=-1.5, rank=4, decoded_token="again"),
        }
        prompt_logprobs = [logprob_dict]
        result = clamp_prompt_logprobs(prompt_logprobs)
        # All values should remain unchanged
        self.assertEqual(result[0][1].logprob, -9999.0)
        self.assertEqual(result[0][2].logprob, -9999.0)
        self.assertEqual(result[0][3].logprob, 0.0)
        self.assertEqual(result[0][4].logprob, -1.5)

    def test_return_same_object(self):
        """The function returns its input (in-place modification contract)."""
        logprob_dict = {
            1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
        }
        prompt_logprobs = [logprob_dict]
        result = clamp_prompt_logprobs(prompt_logprobs)
        # Should return the same object (function attempts in-place modification)
        self.assertIs(result, prompt_logprobs)
        self.assertIs(result[0], prompt_logprobs[0])
# Allow running this test module directly (e.g. `python this_file.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
+38 -35
View File
@@ -16,6 +16,7 @@ import json
import os
import time
import unittest
from unittest.mock import patch
import numpy as np
import paddle
@@ -163,44 +164,46 @@ class TestGPUPromptLogprobs(unittest.TestCase):
return model_runner
def test_prompt_logprobs(self):
model_runner = self.setup_model_runner()
# Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
model_runner = self.setup_model_runner()
req: Request = Request(
prompt=None,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
arrival_time=None,
request_id="asd1",
prompt_token_ids=[1, 2, 3, 4],
prompt_token_ids_len=4,
prefill_start_index=0,
prefill_end_index=4,
sampling_params=SamplingParams(prompt_logprobs=-1),
)
req.idx = 0
model_runner.prompt_logprobs_reqs = {req.request_id: req}
req: Request = Request(
prompt=None,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
arrival_time=None,
request_id="asd1",
prompt_token_ids=[1, 2, 3, 4],
prompt_token_ids_len=4,
prefill_start_index=0,
prefill_end_index=4,
sampling_params=SamplingParams(prompt_logprobs=-1),
)
req.idx = 0
model_runner.prompt_logprobs_reqs = {req.request_id: req}
hidden_states = paddle.rand(
[len(req.prompt_token_ids) - 1, model_runner.fd_config.model_config.hidden_size], dtype="bfloat16"
)
ref_logits = model_runner.model.compute_logits(hidden_states)
ref_raw_logprobs = model_runner.sampler.compute_logprobs(ref_logits)
token_is = paddle.to_tensor(req.prompt_token_ids[1:], dtype="int64")
hidden_states = paddle.rand(
[len(req.prompt_token_ids) - 1, model_runner.fd_config.model_config.hidden_size], dtype="bfloat16"
)
ref_logits = model_runner.model.compute_logits(hidden_states)
ref_raw_logprobs = model_runner.sampler.compute_logprobs(ref_logits)
token_is = paddle.to_tensor(req.prompt_token_ids[1:], dtype="int64")
ref_token_ids, ref_logprobs, ref_ranks = model_runner.sampler.gather_logprobs(
ref_raw_logprobs, model_runner.fd_config.model_config.ori_vocab_size, token_is
)
prompt_logprobs = model_runner._get_prompt_logprobs_list(hidden_states)[0]
np.testing.assert_allclose(ref_logprobs.numpy(), prompt_logprobs.logprobs.numpy(), rtol=1e-04, atol=1e-04)
np.testing.assert_allclose(
ref_token_ids.numpy(), prompt_logprobs.logprob_token_ids.numpy(), rtol=1e-04, atol=1e-04
)
np.testing.assert_allclose(
ref_ranks.numpy(), prompt_logprobs.selected_token_ranks.numpy(), rtol=1e-04, atol=1e-04
)
ref_token_ids, ref_logprobs, ref_ranks = model_runner.sampler.gather_logprobs(
ref_raw_logprobs, model_runner.fd_config.model_config.ori_vocab_size, token_is
)
prompt_logprobs = model_runner._get_prompt_logprobs_list(hidden_states)[0]
np.testing.assert_allclose(ref_logprobs.numpy(), prompt_logprobs.logprobs.numpy(), rtol=1e-04, atol=1e-04)
np.testing.assert_allclose(
ref_token_ids.numpy(), prompt_logprobs.logprob_token_ids.numpy(), rtol=1e-04, atol=1e-04
)
np.testing.assert_allclose(
ref_ranks.numpy(), prompt_logprobs.selected_token_ranks.numpy(), rtol=1e-04, atol=1e-04
)
if __name__ == "__main__":