mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-25 01:55:45 +08:00
[Speculative Decoding] Unify Spec and non-spec branch (#6685)
* optimize spec-inference architecture * delete debug log * optimize spec_method usage && fix unit_test * add claude unit-test skill * fix some ugly bug * enhance robustness and bounds check * unify method & spec_method to method to avoid bug * activate CI * fix unit test * Unify logprobs computation for naive and speculative decoding, fix CUDA kernel * fix logprob bug && optimize verify kernel * fix exist_decode() judge
This commit is contained in:
@@ -14,11 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Tuple

import paddle
import paddle.nn.functional as F
import triton
import triton.language as tl

from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -80,3 +84,134 @@ def batched_count_greater_than(x: paddle.Tensor, y: paddle.Tensor) -> paddle.Ten
|
||||
out = (x >= y).sum(-1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def gather_logprobs(
    logprobs: paddle.Tensor,
    num_logprobs: int,
    token_ids: paddle.Tensor,
) -> LogprobsTensors:
    """
    Gather logprobs for topk and sampled/prompt token.

    Args:
        logprobs: (num tokens) x (vocab) tensor
        num_logprobs: minimum number of logprobs to retain per token
        token_ids: prompt tokens (if prompt logprobs) or sampled tokens
            (if sampled logprobs); 1D token ID tensor with (num tokens) elements.
            Must be int64.

    Returns:
        LogprobsTensors with top-k indices, top-k logprobs, and token ranks.
    """
    assert token_ids.dtype == paddle.int64
    # Add a trailing axis so the sampled token forms column 0 of the output.
    sampled_ids = token_ids.unsqueeze(1)
    # NOTE: in-place clamp — intentionally mutates the caller's tensor so that
    # -inf entries become the dtype's finite minimum before gathering.
    logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
    sampled_logprobs = paddle.take_along_axis(logprobs, sampled_ids, axis=-1)

    # Rank = how many vocab entries score at least as high as the sampled token.
    token_ranks = batched_count_greater_than(logprobs, sampled_logprobs)

    # Default: only the sampled token's column; prepend it to the top-k
    # columns when any top-k logprobs were requested.
    indices = sampled_ids
    top_logprobs = sampled_logprobs
    if num_logprobs >= 1:
        topk_vals, topk_idx = paddle.topk(logprobs, num_logprobs, axis=-1)
        indices = paddle.concat([sampled_ids, topk_idx], axis=1)
        top_logprobs = paddle.concat([sampled_logprobs, topk_vals], axis=1)

    return LogprobsTensors(indices, top_logprobs, token_ranks)
||||
|
||||
|
||||
def build_output_logprobs(
    logits: paddle.Tensor,
    sampling_metadata,
    share_inputs: Dict[str, paddle.Tensor],
    is_naive: bool = False,
    logprobs_mode: str = "default",
    compute_logprobs_fn: Optional[Callable] = None,
) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]:
    """
    Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes.

    This is a standalone function (not tied to any sampler) so that both
    naive and speculative decoding paths can share the same logprob logic.

    For NAIVE mode: logits are already per-token, no extraction needed.
    For speculative mode: extracts target logits for accepted token positions.

    Side effects (speculative mode only): writes "batch_token_num" and
    "cu_batch_token_offset" back into `share_inputs` for downstream use.

    Args:
        logits: Model output logits.
        sampling_metadata: Sampling parameters and metadata; must expose
            `max_num_logprobs` (None disables logprob computation).
        share_inputs: Mapping of shared input tensors keyed by name
            (e.g. "seq_lens_this_time", "accept_tokens", "accept_num").
        is_naive: True for NAIVE mode (single token per request).
        logprobs_mode: One of "raw_logprobs", "raw_logits", or "default".
        compute_logprobs_fn: Callable for computing logprobs with temperature
            scaling and top_p normalization. Required (non-None) when
            logprobs_mode == "raw_logprobs".

    Returns:
        tuple: (logprobs_tensors, cu_batch_token_offset); cu_batch_token_offset
        is None in NAIVE mode.
    """
    num_logprobs = sampling_metadata.max_num_logprobs
    logprobs_tensors = None
    cu_batch_token_offset = None

    # Logprobs were not requested for this batch: nothing to compute.
    if num_logprobs is None:
        return logprobs_tensors, cu_batch_token_offset

    real_bsz = share_inputs["seq_lens_this_time"].shape[0]

    if is_naive:
        # NAIVE mode: one token per request, logits are already correct
        output_logits = logits
        token_ids = share_inputs["accept_tokens"][:real_bsz, 0]
    else:
        # Speculative mode: extract target logits for accepted positions
        from fastdeploy.model_executor.layers.sample.ops import (
            speculate_get_target_logits,
        )

        # Requests still in prefill (seq_lens_encoder != 0) contribute exactly
        # one token; decode requests contribute seq_lens_this_time tokens.
        batch_token_num = paddle.where(
            share_inputs["seq_lens_encoder"][:real_bsz] != 0,
            paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
            share_inputs["seq_lens_this_time"],
        ).flatten()
        share_inputs["batch_token_num"] = batch_token_num

        # Prefix-sum offsets into the raw logits (per proposed token) and into
        # the compacted output (per accepted token), both with a leading 0.
        ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
            "int32"
        )
        cu_batch_token_offset = paddle.concat(
            [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
        ).astype("int32")
        share_inputs["cu_batch_token_offset"] = cu_batch_token_offset

        # Compact buffer holding only the logits of accepted token positions.
        output_logits = paddle.empty(
            [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]],
            dtype=logits.dtype,
        )
        speculate_get_target_logits(
            output_logits,
            logits,
            cu_batch_token_offset,
            ori_cu_batch_token_offset,
            share_inputs["seq_lens_this_time"],
            share_inputs["seq_lens_encoder"],
            share_inputs["accept_num"],
        )

        # Flatten only the accepted tokens of each request into a 1D tensor,
        # aligned row-for-row with output_logits.
        idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
        mask = idx < share_inputs["accept_num"].unsqueeze(1)
        token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)

    # Compute logprobs with temperature scaling and top_p normalization
    if logprobs_mode == "raw_logprobs":
        if compute_logprobs_fn is None:
            raise ValueError("compute_logprobs_fn is required when logprobs_mode == 'raw_logprobs'")
        raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata)
    elif logprobs_mode == "raw_logits":
        raw_logprobs = output_logits.clone()
    else:
        raw_logprobs = F.log_softmax(output_logits, axis=-1)

    logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

    return logprobs_tensors, cu_batch_token_offset
|
||||
|
||||
Reference in New Issue
Block a user