mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
Revert "[KSM] fix logz when top_k (#7225)"
This reverts commit f83673daac.
This commit is contained in:
@@ -197,37 +197,13 @@ def _compute_sampling_mask(
|
||||
max_k = int(k_per_row.max().item())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Stage 5: compute logZ for renormalization
|
||||
#
|
||||
# Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the
|
||||
# distribution actually sampled from (top-k truncated + top-p nucleus).
|
||||
#
|
||||
# When top_k is active the sampling pipeline first renormalises to
|
||||
# π_topk, then applies top-p on π_topk. The total log-normaliser
|
||||
# that maps π_full → π_mask absorbs both steps:
|
||||
#
|
||||
# logZ = log Z_topk + log Z_topp_on_renorm
|
||||
#
|
||||
# where Z_topk = Σ_{j∈topk} π_full(j) (= row_sums, already computed)
|
||||
# Z_topp = Σ_{k∈K} π_topk(k) (sum of renorm probs in K)
|
||||
#
|
||||
# Substituting:
|
||||
# log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp
|
||||
# = log π_topk(k) - log Z_topp ✓
|
||||
#
|
||||
# When top_k is absent Z_topk = 1 → logZ = log Z_topp as before.
|
||||
# Stage 5: compute logZ_K for renormalization
|
||||
# Z_K = sum(probs[i] * final_mask[i]) for each request i
|
||||
# logZ_K = log(Z_K), with small constant to avoid log(0)
|
||||
# ------------------------------------------------------------------
|
||||
if has_top_k:
|
||||
# Z_topp: sum of renormed probs that survive the final mask
|
||||
candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs))
|
||||
z_topp = candidate_probs.sum(axis=-1) # [B]
|
||||
# row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B]
|
||||
log_z_topk = paddle.log(row_sums.squeeze(-1))
|
||||
logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy() # [B]
|
||||
else:
|
||||
candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
|
||||
z_k = candidate_probs.sum(axis=-1) # [B]
|
||||
logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B]
|
||||
candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
|
||||
z_k = candidate_probs.sum(axis=-1) # [B]
|
||||
logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B]
|
||||
|
||||
# Transfer only the leading max_k columns — typically max_k << vocab_size.
|
||||
indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k]
|
||||
|
||||
Reference in New Issue
Block a user