Revert "[KSM] fix logz when top_k (#7225)"

This reverts commit f83673daac.
This commit is contained in:
Yuanle Liu
2026-04-14 00:43:36 -07:00
committed by GitHub
parent 19b0038234
commit 82e25eb3b3
@@ -197,37 +197,13 @@ def _compute_sampling_mask(
max_k = int(k_per_row.max().item()) max_k = int(k_per_row.max().item())
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Stage 5: compute logZ for renormalization # Stage 5: compute logZ_K for renormalization
# # Z_K = sum(probs[i] * final_mask[i]) for each request i
# Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the # logZ_K = log(Z_K), with small constant to avoid log(0)
# distribution actually sampled from (top-k truncated + top-p nucleus).
#
# When top_k is active the sampling pipeline first renormalises to
# π_topk, then applies top-p on π_topk. The total log-normaliser
# that maps π_full → π_mask absorbs both steps:
#
# logZ = log Z_topk + log Z_topp_on_renorm
#
# where Z_topk = Σ_{j∈topk} π_full(j) (= row_sums, already computed)
# Z_topp = Σ_{k∈K} π_topk(k) (sum of renorm probs in K)
#
# Substituting:
# log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp
# = log π_topk(k) - log Z_topp ✓
#
# When top_k is absent Z_topk = 1 → logZ = log Z_topp as before.
# ------------------------------------------------------------------ # ------------------------------------------------------------------
if has_top_k: candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
# Z_topp: sum of renormed probs that survive the final mask z_k = candidate_probs.sum(axis=-1) # [B]
candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs)) logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B]
z_topp = candidate_probs.sum(axis=-1) # [B]
# row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B]
log_z_topk = paddle.log(row_sums.squeeze(-1))
logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy() # [B]
else:
candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
z_k = candidate_probs.sum(axis=-1) # [B]
logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B]
# Transfer only the leading max_k columns — typically max_k << vocab_size. # Transfer only the leading max_k columns — typically max_k << vocab_size.
indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k] indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k]