mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] support logits processors (#4515)
* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor * [chore] fix code style * [fix] add unit test & fix existing bugs * [feat] add engine/worker arg --logits-processors * [fix] redefine user args as logits_processors_args and fix some bugs * [fix] fix test_sampler * Update fastdeploy/model_executor/logits_processor/builtin.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/model_executor/logits_processor/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/model_executor/test_logits_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * [fix] fix typo * Update fastdeploy/engine/sampling_params.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * [fix] fix bracelet * [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits * [doc] add docs * [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests * [fix] fix redundant code * [feat] skip apply() if no bias is specified --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -970,6 +970,9 @@ class PlasAttentionConfig:
|
||||
"""
|
||||
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
|
||||
|
||||
def __str__(self) -> str:
|
||||
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||
|
||||
|
||||
class EarlyStopConfig:
|
||||
def __init__(
|
||||
@@ -1071,6 +1074,9 @@ class LoadConfig:
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||
|
||||
|
||||
class PoolerConfig:
|
||||
"""Controls the behavior of output pooling in pooling models."""
|
||||
@@ -1339,11 +1345,15 @@ class StructuredOutputsConfig:
|
||||
self.guided_decoding_backend: Optional[str] = None
|
||||
# disable any whitespace for guided decoding
|
||||
self.disable_any_whitespace: bool = True
|
||||
self.logits_processors: Optional[list[str]] = None
|
||||
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key) and value != "None":
|
||||
setattr(self, key, value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||
|
||||
|
||||
class FDConfig:
|
||||
"""
|
||||
|
||||
@@ -422,6 +422,16 @@ class EngineArgs:
|
||||
Flag to specify the dtype of lm_head as FP32. Default is False (Using model default dtype).
|
||||
"""
|
||||
|
||||
logits_processors: Optional[List[str]] = None
|
||||
"""
|
||||
A list of FQCNs (Fully Qualified Class Names) of logits processors supported by the service.
|
||||
A fully qualified class name (FQCN) is a string that uniquely identifies a class within a Python module.
|
||||
|
||||
- To enable builtin logits processors, add builtin module paths and class names to the list. Currently support:
|
||||
- fastdeploy.model_executor.logits_processor:LogitBiasLogitsProcessor
|
||||
- To enable custom logits processors, add your dotted paths to module and class names to the list.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to set default tokenizer if not provided.
|
||||
@@ -687,6 +697,13 @@ class EngineArgs:
|
||||
default=EngineArgs.lm_head_fp32,
|
||||
help="Specify the dtype of lm_head weight as float32.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--logits-processors",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=EngineArgs.logits_processors,
|
||||
help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
|
||||
)
|
||||
|
||||
# Parallel processing parameters group
|
||||
parallel_group = parser.add_argument_group("Parallel Configuration")
|
||||
|
||||
@@ -535,6 +535,8 @@ class LLMEngine:
|
||||
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
|
||||
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
|
||||
)
|
||||
if self.cfg.structured_outputs_config.logits_processors is not None:
|
||||
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
|
||||
|
||||
worker_append_flag = {
|
||||
"enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
|
||||
|
||||
@@ -103,6 +103,7 @@ class SamplingParams:
|
||||
bad_words: Optional[List[str]] = None
|
||||
guided_decoding: Optional[GuidedDecodingParams] = None
|
||||
bad_words_token_ids: Optional[List[int]] = None
|
||||
logits_processors_args: Optional[dict[str, Any]] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
|
||||
@@ -136,6 +137,7 @@ class SamplingParams:
|
||||
bad_words=None,
|
||||
guided_decoding=None,
|
||||
bad_words_token_ids=None,
|
||||
logits_processors_args=None,
|
||||
) -> SamplingParams:
|
||||
"""Create instance from command line arguments"""
|
||||
return cls(
|
||||
@@ -158,6 +160,7 @@ class SamplingParams:
|
||||
bad_words=bad_words,
|
||||
guided_decoding=guided_decoding,
|
||||
bad_words_token_ids=bad_words_token_ids,
|
||||
logits_processors_args=logits_processors_args,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -208,6 +211,24 @@ class SamplingParams:
|
||||
if not 0 <= self.seed <= 922337203685477580:
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
|
||||
|
||||
# Verify logits processors arguments
|
||||
if self.logits_processors_args is not None:
|
||||
if self.logits_processors_args.get("logit_bias") is not None:
|
||||
logit_bias = self.logits_processors_args.get("logit_bias")
|
||||
if not isinstance(logit_bias, dict):
|
||||
raise TypeError(f"logit_bias must be a dict, but got {type(logit_bias)}")
|
||||
elif not all(isinstance(k, int) and isinstance(v, float) for k, v in logit_bias.items()):
|
||||
# try to cast the dict to the correct type first
|
||||
try:
|
||||
cast_logit_bias = {}
|
||||
for k, v in logit_bias.items():
|
||||
cast_logit_bias[int(k)] = float(v)
|
||||
self.logits_processors_args["logit_bias"] = cast_logit_bias
|
||||
except:
|
||||
raise TypeError(
|
||||
"failed to cast logit_bias to the correct {key -> value} type, expected {int -> float}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeamSearchParams:
|
||||
|
||||
@@ -266,14 +266,11 @@ class ResourceManagerV1(ResourceManager):
|
||||
del self.req_dict[preempted_req.request_id]
|
||||
self._free_blocks(preempted_req)
|
||||
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
|
||||
main_process_metrics.num_requests_running.dec(1)
|
||||
else:
|
||||
self._free_blocks(preempted_req)
|
||||
preempted_req.cached_block_num = 0
|
||||
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
|
||||
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
|
||||
main_process_metrics.num_requests_waiting.inc(1)
|
||||
main_process_metrics.num_requests_running.dec(1)
|
||||
preempted_reqs.append(preempted_req)
|
||||
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
|
||||
|
||||
@@ -651,8 +648,6 @@ class ResourceManagerV1(ResourceManager):
|
||||
request, self.config.cache_config.block_size, request.num_computed_tokens
|
||||
)
|
||||
request.status = RequestStatus.RUNNING
|
||||
main_process_metrics.num_requests_waiting.dec(1)
|
||||
main_process_metrics.num_requests_running.inc(1)
|
||||
if self.config.scheduler_config.splitwise_role == "mixed":
|
||||
allocated_position = self.get_available_position()
|
||||
request.idx = allocated_position
|
||||
|
||||
@@ -461,6 +461,7 @@ class CompletionRequest(BaseModel):
|
||||
include_stop_str_in_output: Optional[bool] = False
|
||||
bad_words: Optional[List[str]] = None
|
||||
bad_words_token_ids: Optional[List[int]] = None
|
||||
logits_processors_args: Optional[Dict] = None
|
||||
# doc: end-completion-sampling-params
|
||||
|
||||
# doc: start-completion-extra-params
|
||||
@@ -613,6 +614,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
bad_words_token_ids: Optional[List[int]] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
logits_processors_args: Optional[Dict] = None
|
||||
# doc: end-chat-completion-sampling-params
|
||||
|
||||
# doc: start-chat-completion-extra-params
|
||||
|
||||
@@ -19,6 +19,8 @@ from typing import Dict, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.logits_processor import LogitsProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
@@ -54,6 +56,7 @@ class SamplingMetadata:
|
||||
temp_scaled_logprobs: Optional[paddle.Tensor] = None
|
||||
top_p_normalized_logprobs: Optional[paddle.Tensor] = None
|
||||
share_inputs: Optional[Dict[str, paddle.Tensor]] = None
|
||||
logits_processors: Optional[list[LogitsProcessor]] = None
|
||||
# Add for HPU post-processing
|
||||
seq_lens_encoder: Optional[paddle.Tensor] = None
|
||||
seq_lens_decoder: Optional[paddle.Tensor] = None
|
||||
|
||||
@@ -53,9 +53,9 @@ def top_p_normalize_probs_paddle(
|
||||
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
|
||||
|
||||
|
||||
class SamplerProcessor:
|
||||
class GuidedDecoding:
|
||||
"""
|
||||
SamplingProcessor for guided decoding.
|
||||
processor for guided decoding.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -75,7 +75,7 @@ class SamplerProcessor:
|
||||
future: Optional[Any] = None,
|
||||
prefill_tokens: List[int] = [],
|
||||
):
|
||||
"""add logits processor to SamplerProcessor"""
|
||||
"""add logits processor to GuidedDecoding"""
|
||||
with self.logits_lock:
|
||||
if future is None:
|
||||
if ids in self.logits_processor:
|
||||
@@ -216,7 +216,7 @@ class Sampler(nn.Layer):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.processor = SamplerProcessor()
|
||||
self.guided_decoding = GuidedDecoding()
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
|
||||
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
|
||||
if (
|
||||
@@ -230,19 +230,19 @@ class Sampler(nn.Layer):
|
||||
|
||||
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
|
||||
"""set reasoning parser"""
|
||||
self.processor.apply_reasoning_parser(reasoning_parser)
|
||||
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
|
||||
|
||||
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
|
||||
"""apply logits processor to sampler"""
|
||||
self.processor.add_logits_processor(ids, future, prefill_tokens)
|
||||
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
"""pre process before running"""
|
||||
self.processor.pre_process(skip_idx_list)
|
||||
self.guided_decoding.pre_process(skip_idx_list)
|
||||
|
||||
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
|
||||
"""post process after running"""
|
||||
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
||||
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
|
||||
|
||||
def compute_logprobs(
|
||||
self,
|
||||
@@ -332,7 +332,7 @@ class Sampler(nn.Layer):
|
||||
skip_idx_list: List[int] = [],
|
||||
) -> SamplerOutput:
|
||||
""" """
|
||||
logits = self.processor.apply_token_mask(logits, skip_idx_list)
|
||||
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
|
||||
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
@@ -341,6 +341,9 @@ class Sampler(nn.Layer):
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = logits.clone()
|
||||
|
||||
for proc in sampling_metadata.logits_processors or []:
|
||||
logits = proc.apply(logits)
|
||||
|
||||
logits = apply_penalty_multi_scores(
|
||||
sampling_metadata.pre_token_ids,
|
||||
sampling_metadata.prompt_ids,
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
from .base import LogitsProcessor
|
||||
from .builtin import LogitBiasLogitsProcessor
|
||||
|
||||
|
||||
def load_class(spec: str):
|
||||
"""
|
||||
Load a class from a string spec.
|
||||
|
||||
If the spec is in the form 'package.module:ClassName', loads ClassName from the specified module.
|
||||
If the spec does not contain a colon, it is treated as the name of a builtin class from
|
||||
'fastdeploy.model_executor.logits_processor'.
|
||||
|
||||
Args:
|
||||
spec (str): The class specifier string.
|
||||
|
||||
Returns:
|
||||
type: The loaded class object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the spec is invalid.
|
||||
ImportError: If the module cannot be imported.
|
||||
AttributeError: If the class cannot be found in the module.
|
||||
"""
|
||||
try:
|
||||
if ":" in spec:
|
||||
module_path, class_name = spec.split(":", 1)
|
||||
else:
|
||||
module_path = "fastdeploy.model_executor.logits_processor"
|
||||
class_name = spec
|
||||
module = import_module(module_path)
|
||||
obj = getattr(module, class_name)
|
||||
return obj
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid spec {spec!r}; expected 'module:ClassName'.") from e
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import module {module_path}") from e
|
||||
except AttributeError as e:
|
||||
raise AttributeError(f"Module {module_path} does not have attribute {class_name}") from e
|
||||
|
||||
|
||||
def build_logits_processors(fd_config):
|
||||
logit_procs = []
|
||||
for fqcn in fd_config.structured_outputs_config.logits_processors or []:
|
||||
logit_procs.append(load_class(fqcn)(fd_config))
|
||||
return logit_procs
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_logits_processors",
|
||||
"LogitsProcessor",
|
||||
"LogitBiasLogitsProcessor",
|
||||
]
|
||||
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from paddle import Tensor
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
|
||||
class LogitsProcessor(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, fd_config: FDConfig) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_state(self, share_inputs: dict) -> None:
|
||||
"""Called when there are new output tokens, prior to each forward pass.
|
||||
|
||||
Each field in the `share_inputs` dict typically stores information for all request
|
||||
slots. It has a `stop_flags` array that indicates whether a slot currently has a
|
||||
running request (`False` means the slot is active). Therefore, it is recommended to
|
||||
filter entries by `stop_flags` to keep only data for the current batch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, logits: Tensor) -> Tensor:
|
||||
"""Apply LogitsProcessor to batch logits tensor.
|
||||
|
||||
The updated tensor must be returned but may be modified in-place.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.logits_processor.base import LogitsProcessor
|
||||
|
||||
|
||||
class LogitBiasLogitsProcessor(LogitsProcessor):
|
||||
"""
|
||||
Maintains per-request logit biases and applies them to logits.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
self.device = paddle.device.get_device()
|
||||
self.dtype = fd_config.model_config.dtype
|
||||
self.batch_ids: list[int] = []
|
||||
self.token_ids: list[int] = []
|
||||
self.biases: list[float] = []
|
||||
|
||||
def update_state(self, share_inputs: dict):
|
||||
"""Build per-step logit-bias state from request slots and move it to device."""
|
||||
|
||||
# Retrive inference states from share_inputs
|
||||
stop_flags = share_inputs["stop_flags"]
|
||||
logits_processors_args = share_inputs["logits_processors_args"]
|
||||
logits_processors_args = [a for a, f in zip(logits_processors_args, stop_flags) if not f]
|
||||
|
||||
# Get bias states for each request
|
||||
self.batch_ids = []
|
||||
self.token_ids: list[int] = []
|
||||
self.biases: list[float] = []
|
||||
for batch_id, logit_proc_args in enumerate(logits_processors_args):
|
||||
tok_id_bias_map = logit_proc_args.get("logit_bias") or {}
|
||||
self.batch_ids.extend([batch_id] * len(tok_id_bias_map))
|
||||
self.token_ids.extend(tok_id_bias_map.keys())
|
||||
self.biases.extend(tok_id_bias_map.values())
|
||||
|
||||
return
|
||||
|
||||
def apply(self, logits: paddle.Tensor) -> paddle.Tensor:
|
||||
"""Apply logit bias to logits: [batch_size, vocab_size]"""
|
||||
# Skip if no bias is applied
|
||||
if len(self.biases) == 0:
|
||||
return logits
|
||||
|
||||
# Make bias indices and bias tensor
|
||||
bias_indices = (
|
||||
paddle.tensor(self.batch_ids, dtype="int32").to(self.device),
|
||||
paddle.tensor(self.token_ids, dtype="int32").to(self.device),
|
||||
)
|
||||
bias_tensor = paddle.tensor(self.biases, device=self.device, dtype=self.dtype)
|
||||
logits[bias_indices] += bias_tensor
|
||||
return logits
|
||||
@@ -87,6 +87,7 @@ from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
|
||||
from fastdeploy.model_executor.logits_processor import build_logits_processors
|
||||
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
|
||||
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
|
||||
from fastdeploy.output.pooler import PoolerOutput
|
||||
@@ -624,6 +625,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
else:
|
||||
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
|
||||
|
||||
# For logits processors
|
||||
self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}
|
||||
|
||||
if len(multi_vision_inputs["images_lst"]) > 0:
|
||||
self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
|
||||
|
||||
@@ -1188,6 +1192,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
self.share_inputs["image_features"] = None
|
||||
|
||||
# For logits processors
|
||||
self.share_inputs["logits_processors"] = build_logits_processors(self.fd_config)
|
||||
self.share_inputs["logits_processors_args"] = [{} for _ in range(max_num_seqs)]
|
||||
logger.info(f"Enabled logits processors: {self.share_inputs['logits_processors']}")
|
||||
|
||||
def _prepare_inputs(self) -> None:
|
||||
"""Prepare the model inputs"""
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
@@ -1265,6 +1274,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
stop_flags=self.share_inputs["stop_flags"],
|
||||
temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
|
||||
top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
|
||||
logits_processors=self.share_inputs["logits_processors"],
|
||||
share_inputs=self.share_inputs,
|
||||
)
|
||||
|
||||
@@ -1993,6 +2003,10 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self._prepare_inputs()
|
||||
self.sampler.pre_process(skip_idx_list)
|
||||
|
||||
# 1.1 Update state of logits processor
|
||||
for proc in self.sampling_metadata.logits_processors:
|
||||
proc.update_state(self.share_inputs)
|
||||
|
||||
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
|
||||
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
||||
# when there is data on other runner, the current runner is required to execute part of the model.
|
||||
|
||||
@@ -686,6 +686,14 @@ def parse_args():
|
||||
help="Override configuration for the pooler.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--logits-processors",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[],
|
||||
help="FQCNs (Fully Qualified Class Names) of logits processors supported by the service.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
Reference in New Issue
Block a user