Support keep sampling mask (#6725)

* naive version

* return list(int)

* fix bug: first_token's sampling mask miss

* pre-commit

* support mtp

* pre-commit

* fix ut

* fix zmq name conflicts

* fix ut

* add ut

* fix ut timeout

* optimize performance

* fix

* support top_k mask

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* update comment

* update comment

* update comment

---------

Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yuanle Liu
2026-03-18 11:07:31 +08:00
committed by GitHub
parent a714c1f8d4
commit 7f5f2113c2
24 changed files with 502 additions and 9 deletions
@@ -19,6 +19,7 @@ import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, List, Optional
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
@@ -93,6 +94,98 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le
return top_p_padding, top_k_padding, topp_seed
def _compute_sampling_mask(
    probs: paddle.Tensor,
    top_p: paddle.Tensor,
    top_k: Optional[paddle.Tensor] = None,
    top_k_list: Optional[list] = None,
) -> List[np.ndarray]:
    """
    Build the combined top-k / top-p (nucleus) retention mask and return it
    as sparse per-request index arrays.

    The work is ordered so that a single descending sort is shared by both
    filters:
      1. sort probs descending once (one argsort for everything);
      2. optional top-k truncation in sorted order, followed by a row-wise
         renormalisation;
      3. top-p cutoff via cumsum over the (possibly renormed) sorted probs —
         no second argsort needed;
      4. AND the two masks on device, then copy only the leading window
         of columns to host.

    Either filter may be inactive:
      - top-k is skipped entirely when top_k_list is None or contains no
        positive value;
      - a row with top_p >= 1.0 keeps its full vocabulary.

    Args:
        probs: [num_reqs, vocab_size] softmax probabilities (GPU).
        top_p: [num_reqs, 1] per-request nucleus threshold (GPU).
        top_k: [num_reqs, 1] per-request top-k (GPU, int); 0 = disabled.
        top_k_list: Python-side list of top-k values, used only to decide
            whether any top-k filtering is required at all.

    Returns:
        A list of length num_reqs; element i is a 1-D int64 numpy array of
        the retained vocab indices for request i.
    """
    batch_size = probs.shape[0]
    vocab = probs.shape[1]
    top_p = top_p[:batch_size]  # [B, 1]

    use_top_k = bool(top_k is not None and top_k_list and any(k > 0 for k in top_k_list))

    # ------------------------------------------------------------------
    # Stage 1: one descending sort, shared by the top-k and top-p stages.
    # ------------------------------------------------------------------
    order = paddle.argsort(probs, axis=-1, descending=True)  # [B, V]
    probs_desc = paddle.take_along_axis(probs, order, axis=-1)  # [B, V]

    # ------------------------------------------------------------------
    # Stage 2: top-k mask (stays on GPU, no D2H).
    # ------------------------------------------------------------------
    keep_topk = None
    if use_top_k:
        per_row_k = top_k[:batch_size]  # [B, 1]
        # Column rank in sorted order; compared against each row's k.
        ranks = paddle.arange(vocab, dtype=per_row_k.dtype).unsqueeze(0)  # [1, V]
        # k == 0 means "disabled" -> pretend k is the full vocab for that row.
        limit = paddle.where(per_row_k > 0, per_row_k, paddle.full_like(per_row_k, vocab))
        keep_topk = ranks < limit  # [B, V], True = inside top-k
        # Drop the tail, then renormalise each row.
        truncated = paddle.where(keep_topk, probs_desc, paddle.zeros_like(probs_desc))
        denom = truncated.sum(axis=-1, keepdim=True).clip(min=1e-9)
        probs_desc = truncated / denom  # [B, V]

    # ------------------------------------------------------------------
    # Stage 3: top-p mask on the already-sorted (renormed) probs.
    # ------------------------------------------------------------------
    running = paddle.cumsum(probs_desc, axis=-1)  # [B, V]
    keep_topp = (running - probs_desc) < top_p  # [B, V]
    # top_p >= 1.0 means "no nucleus truncation" for that row.
    keep_topp = paddle.where(
        (top_p >= 1.0).expand_as(keep_topp),
        paddle.ones_like(keep_topp),
        keep_topp,
    )

    # ------------------------------------------------------------------
    # Stage 4: intersect on GPU, then transfer the minimum to host.
    # ------------------------------------------------------------------
    combined = keep_topp if keep_topk is None else keep_topk & keep_topp  # [B, V]
    kept_per_row = combined.astype("int32").sum(axis=-1)  # [B]
    window = int(kept_per_row.max().item())

    # Only the leading `window` sorted columns cross D2H; typically << V.
    idx_host = order[:, :window].cpu().numpy()  # [B, window]
    mask_host = combined[:, :window].cpu().numpy()  # [B, window]
    return [idx_host[row, mask_host[row]] for row in range(batch_size)]
class GuidedDecoding:
"""
processor for guided decoding.
@@ -525,6 +618,18 @@ class Sampler(nn.Layer):
probs = F.softmax(logits)
probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)
# Compute sampling mask BEFORE top_k_top_p_sampling modifies probs.
# Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
sampling_mask = None
if sampling_metadata.keep_sampling_mask:
sampling_mask = _compute_sampling_mask(
probs,
sampling_metadata.top_p,
top_k=sampling_metadata.top_k,
top_k_list=sampling_metadata.top_k_list,
)
_, next_tokens = top_k_top_p_sampling(
probs,
sampling_metadata.top_p,
@@ -548,6 +653,7 @@ class Sampler(nn.Layer):
sampled_token_ids=next_tokens,
logprobs_tensors=logprobs_tensors,
logits=logits,
sampling_mask=sampling_mask,
)
return sampler_output
@@ -805,8 +911,14 @@ class SpeculativeSampler(nn.Layer):
)
num_logprobs = sampling_metadata.max_num_logprobs
keep_sampling_mask = sampling_metadata.keep_sampling_mask
# Compute batch indexing offsets needed by both logprobs and sampling_mask.
batch_token_num = None
if num_logprobs is not None:
ori_cu_batch_token_offset = None
cu_batch_token_offset = None
real_bsz = None
accept_nums = None
if num_logprobs is not None or keep_sampling_mask:
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = paddle.where(
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
@@ -817,13 +929,14 @@ class SpeculativeSampler(nn.Layer):
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat(
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
).astype("int32")
accept_nums = share_inputs["accept_num"][:real_bsz].reshape([-1])
cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_nums)]).astype("int32")
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
target_logits = paddle.empty(
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
)
# Extract target logits/probs at accepted positions (shared by logprobs and sampling_mask).
# When both are enabled, reuse target_logits to derive target_probs (avoid a second kernel call).
total_accepted = int(accept_nums.sum().item())
target_logits = paddle.empty([total_accepted, logits.shape[1]], dtype=logits.dtype)
speculate_get_target_logits(
target_logits,
logits,
@@ -833,6 +946,12 @@ class SpeculativeSampler(nn.Layer):
share_inputs["seq_lens_encoder"],
share_inputs["accept_num"],
)
if keep_sampling_mask:
# Derive target probs from already-extracted target_logits; avoids a second kernel call.
target_probs = F.softmax(target_logits, axis=-1)
raw_logprobs = None
if num_logprobs is not None:
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
elif self.logprobs_mode == "raw_logits":
@@ -846,12 +965,35 @@ class SpeculativeSampler(nn.Layer):
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
# Compute sampling mask at accepted token positions.
# Shape: [total_accepted_tokens, vocab_size], bool (CPU).
sampling_mask = None
if keep_sampling_mask:
# Expand top_p from [batch, 1] to [total_accepted, 1].
accept_top_p = sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1)
accept_top_k = None
if (
sampling_metadata.top_k is not None
and sampling_metadata.top_k_list
and any(x > 0 for x in sampling_metadata.top_k_list)
):
accept_top_k = (
sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1)
)
sampling_mask = _compute_sampling_mask(
target_probs,
accept_top_p,
top_k=accept_top_k,
top_k_list=sampling_metadata.top_k_list,
)
sampler_output = SamplerOutput(
sampled_token_ids=share_inputs["accept_tokens"],
logprobs_tensors=logprobs_tensors,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
logits=logits,
sampling_mask=sampling_mask,
)
return sampler_output