[Feature] Support logprobs_mode (#4567)

This commit is contained in:
chen
2025-10-27 14:27:48 +08:00
committed by GitHub
parent acd331780c
commit 5c63a089f6
9 changed files with 130 additions and 5 deletions
@@ -199,7 +199,7 @@ class Sampler(nn.Layer):
Sampler for normal generation.
"""
def __init__(self, fd_config: FDConfig = None):
def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprobs"):
""" """
super().__init__()
if (
@@ -217,6 +217,7 @@ class Sampler(nn.Layer):
raise NotImplementedError
self.processor = SamplerProcessor()
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
if (
fd_config is not None
@@ -335,7 +336,10 @@ class Sampler(nn.Layer):
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = logits.clone()
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
@@ -352,6 +356,12 @@ class Sampler(nn.Layer):
sampling_metadata.eos_token_ids,
)
if num_logprobs is not None:
if self.logprobs_mode == "processed_logprobs":
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
elif self.logprobs_mode == "processed_logits":
raw_logprobs = logits.clone()
probs = F.softmax(logits)
probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)
@@ -437,6 +447,7 @@ class SpeculativeSampler(nn.Layer):
self.forward = self.forward_cuda
else:
raise NotImplementedError
self.logprobs_mode = fd_config.model_config.logprobs_mode
self.speculative_verify_window = fd_config.speculative_config.verify_window
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
@@ -644,7 +655,10 @@ class SpeculativeSampler(nn.Layer):
share_inputs["seq_lens_encoder"],
share_inputs["accept_num"],
)
raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = target_logtis.clone()
logprobs_tensors = None
token_ids = share_inputs["accept_tokens"]
@@ -677,6 +691,7 @@ class MTPSampler(nn.Layer):
self.forward = self.forward_cuda
else:
raise NotImplementedError
self.logprobs_mode = fd_config.model_config.logprobs_mode
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
@@ -808,7 +823,12 @@ class MTPSampler(nn.Layer):
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
if num_logprobs is not None and share_inputs["substep"] == 0:
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"][:real_token_num, :], sampling_metadata)
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(
share_inputs["draft_logits"][:real_token_num, :], sampling_metadata
)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = share_inputs["draft_logits"][:real_token_num, :].clone()
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,