mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
Support keep sampling mask (#6725)
* naive version * return list(int) * fix bug: first_token's sampling mask missing * pre-commit * support mtp * pre-commit * fix ut * fix zmq name conflicts * fix ut * add ut * fix ut timeout * optimize performance * fix * support top_k mask * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * update comment * update comment * update comment --------- Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@ import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
@@ -93,6 +94,98 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le
|
||||
return top_p_padding, top_k_padding, topp_seed
|
||||
|
||||
|
||||
def _compute_sampling_mask(
    probs: paddle.Tensor,
    top_p: paddle.Tensor,
    top_k: Optional[paddle.Tensor] = None,
    top_k_list: Optional[list] = None,
) -> List[np.ndarray]:
    """
    Compute a combined top-k + top-p (nucleus) sampling mask as sparse
    retained-token indices.

    Processing order:
      1. Sort probs descending once (shared by top-k and top-p stages).
      2. top-k mask  — zero out positions beyond top_k[i] in sorted order.
      3. top-k renorm — renormalise in-place after truncation.
      4. top-p mask  — cumsum on the already-sorted renormed probs; no
         second argsort needed.
      5. intersect   — AND of the two masks, applied on GPU before D2H.

    Either filter can be disabled:
      - top-k is skipped when top_k_list is None or all values <= 0.
      - top-p[i] >= 1.0 → keep all tokens for that request.

    The highest-probability token is always retained, so no row can come
    back empty (previously top_p[i] <= 0 produced an empty index set
    because the strict `<` comparison rejected even the first token).

    Args:
        probs: [num_reqs, vocab_size] softmax probabilities (GPU).
        top_p: [num_reqs, 1] top-p threshold per request (GPU).
        top_k: [num_reqs, 1] top-k per request (GPU, int); 0 = disabled.
        top_k_list: Python list of top-k values; used to decide whether any
            top-k filtering is needed at all.

    Returns:
        List of length num_reqs; element i is a 1-D int64 numpy array of the
        retained vocab indices for request i.
    """
    real_bsz = probs.shape[0]
    vocab_size = probs.shape[1]
    top_p = top_p[:real_bsz]  # [B, 1]

    # Explicit bool: `a and b and c` can otherwise evaluate to None / [].
    has_top_k = bool(top_k is not None and top_k_list and any(x > 0 for x in top_k_list))

    # ------------------------------------------------------------------
    # Stage 1: single sort — descending by probability.
    # sorted_indices / sorted_probs are reused by both top-k and top-p.
    # ------------------------------------------------------------------
    sorted_indices = paddle.argsort(probs, axis=-1, descending=True)  # [B, V]
    sorted_probs = paddle.take_along_axis(probs, sorted_indices, axis=-1)  # [B, V]

    # ------------------------------------------------------------------
    # Stage 2: top-k mask (GPU, no D2H)
    # ------------------------------------------------------------------
    if has_top_k:
        top_k = top_k[:real_bsz]  # [B, 1]
        # col_idx[0, j] == j; compare against per-row top_k threshold.
        col_idx = paddle.arange(vocab_size, dtype=top_k.dtype).unsqueeze(0)  # [1, V]
        # top_k == 0 means "disabled" → keep all columns for that row.
        effective_k = paddle.where(top_k > 0, top_k, paddle.full_like(top_k, vocab_size))
        topk_mask = col_idx < effective_k  # [B, V], True = inside top-k

        # Zero out tail, then renorm row-wise. The clip guards against a
        # division by zero when an entire row is truncated away.
        masked_sorted_probs = paddle.where(topk_mask, sorted_probs, paddle.zeros_like(sorted_probs))
        row_sums = masked_sorted_probs.sum(axis=-1, keepdim=True).clip(min=1e-9)
        renorm_sorted_probs = masked_sorted_probs / row_sums  # [B, V]
    else:
        topk_mask = None
        renorm_sorted_probs = sorted_probs

    # ------------------------------------------------------------------
    # Stage 3: top-p mask on already-sorted renormed probs (no re-sort).
    # Token j is kept when the cumulative mass BEFORE it is < top_p, i.e.
    # the nucleus is the smallest prefix whose mass reaches top_p.
    # ------------------------------------------------------------------
    cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1)  # [B, V]
    topp_mask = (cum_probs - renorm_sorted_probs) < top_p  # [B, V]
    # When top_p[i] >= 1.0, keep the entire row.
    topp_mask = paddle.where(
        (top_p >= 1.0).expand_as(topp_mask),
        paddle.ones_like(topp_mask),
        topp_mask,
    )
    # Robustness: always keep the top-1 token. A no-op for top_p > 0 (the
    # first sorted column has cumulative-before == 0 < top_p already), but
    # it prevents an empty retained set when top_p <= 0. Column 0 is also
    # always inside topk_mask (effective_k >= 1), so the intersection below
    # cannot drop it either.
    keep_first_col = paddle.arange(vocab_size, dtype="int64").unsqueeze(0) == 0  # [1, V]
    topp_mask = topp_mask | keep_first_col

    # ------------------------------------------------------------------
    # Stage 4: intersect on GPU, then minimal D2H.
    # ------------------------------------------------------------------
    final_mask = topk_mask & topp_mask if has_top_k else topp_mask  # [B, V]

    k_per_row = final_mask.astype("int32").sum(axis=-1)  # [B]
    max_k = int(k_per_row.max().item())

    # Transfer only the leading max_k columns — typically max_k << vocab_size.
    indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy()  # [B, max_k]
    mask_window_cpu = final_mask[:, :max_k].cpu().numpy()  # [B, max_k]

    return [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)]
|
||||
|
||||
|
||||
class GuidedDecoding:
|
||||
"""
|
||||
processor for guided decoding.
|
||||
@@ -525,6 +618,18 @@ class Sampler(nn.Layer):
|
||||
probs = F.softmax(logits)
|
||||
|
||||
probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)
|
||||
|
||||
# Compute sampling mask BEFORE top_k_top_p_sampling modifies probs.
|
||||
# Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
|
||||
sampling_mask = None
|
||||
if sampling_metadata.keep_sampling_mask:
|
||||
sampling_mask = _compute_sampling_mask(
|
||||
probs,
|
||||
sampling_metadata.top_p,
|
||||
top_k=sampling_metadata.top_k,
|
||||
top_k_list=sampling_metadata.top_k_list,
|
||||
)
|
||||
|
||||
_, next_tokens = top_k_top_p_sampling(
|
||||
probs,
|
||||
sampling_metadata.top_p,
|
||||
@@ -548,6 +653,7 @@ class Sampler(nn.Layer):
|
||||
sampled_token_ids=next_tokens,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
logits=logits,
|
||||
sampling_mask=sampling_mask,
|
||||
)
|
||||
|
||||
return sampler_output
|
||||
@@ -805,8 +911,14 @@ class SpeculativeSampler(nn.Layer):
|
||||
)
|
||||
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
keep_sampling_mask = sampling_metadata.keep_sampling_mask
|
||||
# Compute batch indexing offsets needed by both logprobs and sampling_mask.
|
||||
batch_token_num = None
|
||||
if num_logprobs is not None:
|
||||
ori_cu_batch_token_offset = None
|
||||
cu_batch_token_offset = None
|
||||
real_bsz = None
|
||||
accept_nums = None
|
||||
if num_logprobs is not None or keep_sampling_mask:
|
||||
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
|
||||
batch_token_num = paddle.where(
|
||||
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
|
||||
@@ -817,13 +929,14 @@ class SpeculativeSampler(nn.Layer):
|
||||
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
|
||||
"int32"
|
||||
)
|
||||
cu_batch_token_offset = paddle.concat(
|
||||
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
|
||||
).astype("int32")
|
||||
accept_nums = share_inputs["accept_num"][:real_bsz].reshape([-1])
|
||||
cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_nums)]).astype("int32")
|
||||
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
|
||||
target_logits = paddle.empty(
|
||||
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
|
||||
)
|
||||
|
||||
# Extract target logits/probs at accepted positions (shared by logprobs and sampling_mask).
|
||||
# When both are enabled, reuse target_logits to derive target_probs (avoid a second kernel call).
|
||||
total_accepted = int(accept_nums.sum().item())
|
||||
target_logits = paddle.empty([total_accepted, logits.shape[1]], dtype=logits.dtype)
|
||||
speculate_get_target_logits(
|
||||
target_logits,
|
||||
logits,
|
||||
@@ -833,6 +946,12 @@ class SpeculativeSampler(nn.Layer):
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["accept_num"],
|
||||
)
|
||||
if keep_sampling_mask:
|
||||
# Derive target probs from already-extracted target_logits; avoids a second kernel call.
|
||||
target_probs = F.softmax(target_logits, axis=-1)
|
||||
|
||||
raw_logprobs = None
|
||||
if num_logprobs is not None:
|
||||
if self.logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
@@ -846,12 +965,35 @@ class SpeculativeSampler(nn.Layer):
|
||||
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
|
||||
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
|
||||
|
||||
# Compute sampling mask at accepted token positions.
|
||||
# Shape: [total_accepted_tokens, vocab_size], bool (CPU).
|
||||
sampling_mask = None
|
||||
if keep_sampling_mask:
|
||||
# Expand top_p from [batch, 1] to [total_accepted, 1].
|
||||
accept_top_p = sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1)
|
||||
accept_top_k = None
|
||||
if (
|
||||
sampling_metadata.top_k is not None
|
||||
and sampling_metadata.top_k_list
|
||||
and any(x > 0 for x in sampling_metadata.top_k_list)
|
||||
):
|
||||
accept_top_k = (
|
||||
sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1)
|
||||
)
|
||||
sampling_mask = _compute_sampling_mask(
|
||||
target_probs,
|
||||
accept_top_p,
|
||||
top_k=accept_top_k,
|
||||
top_k_list=sampling_metadata.top_k_list,
|
||||
)
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=share_inputs["accept_tokens"],
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
logits=logits,
|
||||
sampling_mask=sampling_mask,
|
||||
)
|
||||
|
||||
return sampler_output
|
||||
|
||||
Reference in New Issue
Block a user