mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Feature] Support logprobs_mode (#4567)
This commit is contained in:
@@ -199,7 +199,7 @@ class Sampler(nn.Layer):
|
||||
Sampler for normal generation.
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig = None):
|
||||
def __init__(self, fd_config: FDConfig = None, logprobs_mode: str = "raw_logprobs"):
|
||||
""" """
|
||||
super().__init__()
|
||||
if (
|
||||
@@ -217,6 +217,7 @@ class Sampler(nn.Layer):
|
||||
raise NotImplementedError
|
||||
|
||||
self.processor = SamplerProcessor()
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
|
||||
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
|
||||
if (
|
||||
fd_config is not None
|
||||
@@ -335,7 +336,10 @@ class Sampler(nn.Layer):
|
||||
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
|
||||
if self.logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = logits.clone()
|
||||
|
||||
logits = apply_penalty_multi_scores(
|
||||
sampling_metadata.pre_token_ids,
|
||||
@@ -352,6 +356,12 @@ class Sampler(nn.Layer):
|
||||
sampling_metadata.eos_token_ids,
|
||||
)
|
||||
|
||||
if num_logprobs is not None:
|
||||
if self.logprobs_mode == "processed_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
|
||||
elif self.logprobs_mode == "processed_logits":
|
||||
raw_logprobs = logits.clone()
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)
|
||||
@@ -437,6 +447,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
self.forward = self.forward_cuda
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode
|
||||
self.speculative_verify_window = fd_config.speculative_config.verify_window
|
||||
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
|
||||
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
|
||||
@@ -644,7 +655,10 @@ class SpeculativeSampler(nn.Layer):
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["accept_num"],
|
||||
)
|
||||
raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
|
||||
if self.logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = target_logtis.clone()
|
||||
|
||||
logprobs_tensors = None
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
@@ -677,6 +691,7 @@ class MTPSampler(nn.Layer):
|
||||
self.forward = self.forward_cuda
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
"""pre process before running"""
|
||||
@@ -808,7 +823,12 @@ class MTPSampler(nn.Layer):
|
||||
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
|
||||
if num_logprobs is not None and share_inputs["substep"] == 0:
|
||||
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
|
||||
raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"][:real_token_num, :], sampling_metadata)
|
||||
if self.logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(
|
||||
share_inputs["draft_logits"][:real_token_num, :], sampling_metadata
|
||||
)
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = share_inputs["draft_logits"][:real_token_num, :].clone()
|
||||
|
||||
logits = apply_speculative_penalty_multi_scores(
|
||||
sampling_metadata.pre_token_ids,
|
||||
|
||||
Reference in New Issue
Block a user