[Speculative Decoding] Unify Spec and non-spec branch (#6685)

* optimize spec-inference architecture

* delete debug log

* optimize spec_method usage && fix unit test

* add claude unit-test skill

* fix some ugly bug

* enhance robustness and bounds check

* unify method & spec_method to method to avoid bug

* activate CI

* fix unit test

* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel

* fix logprob bug && optimize verify kernel

* fix exist_decode() check
This commit is contained in:
freeliuzc
2026-03-11 14:58:44 +08:00
committed by GitHub
parent b6190de557
commit cf7934a4b2
41 changed files with 3428 additions and 392 deletions
@@ -14,11 +14,15 @@
# limitations under the License.
"""
from typing import Callable, List, Optional, Tuple
import paddle
import paddle.nn.functional as F
import triton
import triton.language as tl
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors
@triton.jit
@@ -80,3 +84,134 @@ def batched_count_greater_than(x: paddle.Tensor, y: paddle.Tensor) -> paddle.Ten
out = (x >= y).sum(-1)
return out
def gather_logprobs(
    logprobs: paddle.Tensor,
    num_logprobs: int,
    token_ids: paddle.Tensor,
) -> LogprobsTensors:
    """
    Gather logprobs for topk and sampled/prompt token.

    Args:
        logprobs: (num tokens) x (vocab) tensor. NOTE: clipped in place to the
            dtype's finite minimum as a side effect of this call.
        num_logprobs: minimum number of logprobs to retain per token; when < 1
            only the sampled/prompt token's logprob is returned.
        token_ids: prompt tokens (if prompt logprobs) or sampled tokens
            (if sampled logprobs); 1D token ID tensor with (num tokens) elements.
            Must be int64.

    Returns:
        LogprobsTensors with top-k indices, top-k logprobs, and token ranks.
    """
    assert token_ids.dtype == paddle.int64
    sampled_ids = token_ids.unsqueeze(1)
    # Replace -inf entries in place so the rank comparison stays finite.
    logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
    sampled_logprobs = paddle.take_along_axis(logprobs, sampled_ids, axis=-1)
    ranks = batched_count_greater_than(logprobs, sampled_logprobs)

    # Guard clause: no top-k requested — return only the sampled token column.
    if num_logprobs < 1:
        return LogprobsTensors(sampled_ids, sampled_logprobs, ranks)

    topk_values, topk_ids = paddle.topk(logprobs, num_logprobs, axis=-1)
    merged_ids = paddle.concat([sampled_ids, topk_ids], axis=1)
    merged_logprobs = paddle.concat([sampled_logprobs, topk_values], axis=1)
    return LogprobsTensors(merged_ids, merged_logprobs, ranks)
def build_output_logprobs(
    logits: paddle.Tensor,
    sampling_metadata,
    # NOTE(review): indexed with string keys throughout ("seq_lens_this_time",
    # "accept_tokens", ...), so this is a mapping, not a List — annotation
    # corrected; confirm the declared type at the call sites.
    share_inputs: dict,
    is_naive: bool = False,
    logprobs_mode: str = "default",
    compute_logprobs_fn: Optional[Callable] = None,
) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]:
    """
    Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes.

    This is a standalone function (not tied to any sampler) so that both
    naive and speculative decoding paths can share the same logprob logic.
    For NAIVE mode: logits are already per-token, no extraction needed.
    For speculative mode: extracts target logits for accepted token positions.

    Side effects (speculative mode only): writes "batch_token_num" and
    "cu_batch_token_offset" back into ``share_inputs`` for downstream use.

    Args:
        logits: Model output logits.
        sampling_metadata: Sampling parameters and metadata; only
            ``max_num_logprobs`` is read here.
        share_inputs: Shared input tensors, keyed by name.
        is_naive: True for NAIVE mode (single token per request).
        logprobs_mode: One of "raw_logprobs", "raw_logits", or "default".
        compute_logprobs_fn: Callable for computing logprobs with temperature
            scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs".

    Returns:
        tuple: (logprobs_tensors, cu_batch_token_offset). Both are None when
        logprobs are not requested; cu_batch_token_offset is also None in
        NAIVE mode.
    """
    num_logprobs = sampling_metadata.max_num_logprobs
    logprobs_tensors = None
    cu_batch_token_offset = None
    # Early exit: caller did not request logprobs.
    if num_logprobs is None:
        return logprobs_tensors, cu_batch_token_offset

    # Number of currently-active requests in the (possibly padded) batch.
    real_bsz = share_inputs["seq_lens_this_time"].shape[0]
    if is_naive:
        # NAIVE mode: one token per request, logits are already correct
        output_logits = logits
        token_ids = share_inputs["accept_tokens"][:real_bsz, 0]
    else:
        # Speculative mode: extract target logits for accepted positions
        from fastdeploy.model_executor.layers.sample.ops import (
            speculate_get_target_logits,
        )

        # Per-request token count: 1 for prefill requests (encoder len != 0),
        # otherwise the number of tokens processed this step.
        batch_token_num = paddle.where(
            share_inputs["seq_lens_encoder"][:real_bsz] != 0,
            paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
            share_inputs["seq_lens_this_time"],
        ).flatten()
        share_inputs["batch_token_num"] = batch_token_num
        # Exclusive prefix sums: offsets into the raw logits rows...
        ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
            "int32"
        )
        # ...and offsets into the accepted-token rows of output_logits.
        cu_batch_token_offset = paddle.concat(
            [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
        ).astype("int32")
        share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
        # One row per accepted token across all active requests.
        output_logits = paddle.empty(
            [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]],
            dtype=logits.dtype,
        )
        # Custom op: scatter/gather the target-model logits for accepted
        # positions into output_logits (fills it in place).
        speculate_get_target_logits(
            output_logits,
            logits,
            cu_batch_token_offset,
            ori_cu_batch_token_offset,
            share_inputs["seq_lens_this_time"],
            share_inputs["seq_lens_encoder"],
            share_inputs["accept_num"],
        )
        # Select the first accept_num[i] tokens of each request's row.
        # NOTE(review): accept_tokens/accept_num are used unsliced here while
        # output_logits covers only [:real_bsz] — assumes entries beyond
        # real_bsz have accept_num == 0; confirm against the caller.
        idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
        mask = idx < share_inputs["accept_num"].unsqueeze(1)
        token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)

    # Compute logprobs with temperature scaling and top_p normalization
    if logprobs_mode == "raw_logprobs":
        raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata)
    elif logprobs_mode == "raw_logits":
        raw_logprobs = output_logits.clone()
    else:
        # "default": plain log-softmax over the (extracted) logits.
        raw_logprobs = F.log_softmax(output_logits, axis=-1)

    logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
    return logprobs_tensors, cu_batch_token_offset