From 82e25eb3b380ac42d447a385d9542add8b659445 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Tue, 14 Apr 2026 00:43:36 -0700
Subject: [PATCH] Revert "[KSM] fix logz when top_k (#7225)"

This reverts commit f83673daac0a495afe239bc31d43a3f3398fb035.
---
 .../model_executor/layers/sample/sampler.py   | 36 ++++----------
 1 file changed, 6 insertions(+), 30 deletions(-)

diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index e2d62a63c1..a0fc666bca 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -197,37 +197,13 @@ def _compute_sampling_mask(
         max_k = int(k_per_row.max().item())
 
         # ------------------------------------------------------------------
-        # Stage 5: compute logZ for renormalization
-        #
-        # Goal: log π_mask(k) = log π_full(k) - logZ, where π_mask is the
-        # distribution actually sampled from (top-k truncated + top-p nucleus).
-        #
-        # When top_k is active the sampling pipeline first renormalises to
-        # π_topk, then applies top-p on π_topk. The total log-normaliser
-        # that maps π_full → π_mask absorbs both steps:
-        #
-        #     logZ = log Z_topk + log Z_topp_on_renorm
-        #
-        # where Z_topk = Σ_{j∈topk} π_full(j)   (= row_sums, already computed)
-        #       Z_topp = Σ_{k∈K} π_topk(k)      (sum of renorm probs in K)
-        #
-        # Substituting:
-        #     log π_mask(k) = log π_full(k) - log Z_topk - log Z_topp
-        #                   = log π_topk(k) - log Z_topp   ✓
-        #
-        # When top_k is absent Z_topk = 1 → logZ = log Z_topp as before.
+        # Stage 5: compute logZ_K for renormalization
+        #   Z_K = sum(probs[i] * final_mask[i]) for each request i
+        #   logZ_K = log(Z_K), with small constant to avoid log(0)
         # ------------------------------------------------------------------
-        if has_top_k:
-            # Z_topp: sum of renormed probs that survive the final mask
-            candidate_probs = paddle.where(final_mask, renorm_sorted_probs, paddle.zeros_like(renorm_sorted_probs))
-            z_topp = candidate_probs.sum(axis=-1)  # [B]
-            # row_sums: [B, 1] already clipped ≥ 1e-9, squeeze to [B]
-            log_z_topk = paddle.log(row_sums.squeeze(-1))
-            logz_per_batch = (log_z_topk + paddle.log(z_topp + 1e-10)).cpu().numpy()  # [B]
-        else:
-            candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
-            z_k = candidate_probs.sum(axis=-1)  # [B]
-            logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy()  # [B]
+        candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs))
+        z_k = candidate_probs.sum(axis=-1)  # [B]
+        logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy()  # [B]
 
         # Transfer only the leading max_k columns — typically max_k << vocab_size.
         indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy()  # [B, max_k]