[Speculative Decoding] Unify Spec and non-spec branch (#6685)

* optimize spec-inference architecture

* delete debug log

* optimize spec_method usage and fix unit test

* add claude unit-test skill

* fix some ugly bug

* enhance robustness and bounds check

* unify method & spec_method to method to avoid bug

* activate CI

* fix unit test

* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel

* fix logprob bug && optimize verify kernel

* fix exist_decode() judgment logic
This commit is contained in:
freeliuzc
2026-03-11 14:58:44 +08:00
committed by GitHub
parent b6190de557
commit cf7934a4b2
41 changed files with 3428 additions and 392 deletions
+214 -100
View File
@@ -32,19 +32,22 @@ from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.early_stopper import (
get_early_stopper_cls_from_stragegy,
)
from fastdeploy.model_executor.layers.sample.logprobs import batched_count_greater_than
from fastdeploy.model_executor.layers.sample.logprobs import (
batched_count_greater_than,
build_output_logprobs,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
min_p_sampling,
reasoning_phase_token_constraint,
speculate_get_target_logits,
speculate_insert_first_token,
top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
from fastdeploy.reasoning import ReasoningParser
from fastdeploy.spec_decode import SpecMethod, VerifyStrategy
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -638,6 +641,18 @@ class SpeculativeSampler(nn.Layer):
self.line_break_id = fd_config.model_config.line_break_id
self.enf_gen_phase_tag = fd_config.speculative_config.enf_gen_phase_tag
# Verify strategy derived from config (replaces env vars in CUDA kernel)
spec_config = fd_config.speculative_config
# Verify strategy enum: VerifyStrategy.TOPP/GREEDY/TARGET_MATCH
# Use .value (0/1/2) when passing to CUDA kernel
self.spec_method = spec_config.method
self.verify_strategy = spec_config.verify_strategy
self.prefill_one_step_stop = fd_config.parallel_config.prefill_one_step_stop
# Accept policy from config (can be overridden by function parameters)
self.config_accept_all = spec_config.accept_policy == "accept_all"
self.config_reject_all = spec_config.accept_policy == "reject_all"
def pre_process(self, skip_idx_list: "Optional[List[int]]" = None):
    """Hook invoked before the sampler runs; currently a no-op.

    Args:
        skip_idx_list: Optional list of batch indices to skip. ``None``
            (the default) is treated the same as an empty list.

    Note:
        The default was changed from the mutable ``[]`` to ``None``:
        a mutable default is shared across all calls (flake8-bugbear
        B006) and would alias if any subclass ever mutated it. The
        annotation is quoted so it is never evaluated at definition
        time, keeping the change safe even if ``Optional`` is not
        imported at module level.
    """
    pass
@@ -750,6 +765,157 @@ class SpeculativeSampler(nn.Layer):
return LogprobsTensors(indices, top_logprobs, token_ranks)
def _verify_and_sample(
    self,
    logits: paddle.Tensor,
    probs: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
    max_model_len: int,
    share_inputs: List[paddle.Tensor],
    accept_all_drafts: bool = False,
    reject_all_drafts: bool = False,
) -> SamplerOutput:
    """
    Verify draft tokens against target model output and produce final samples.

    This is the core speculative decoding step: it prepares the
    strategy-specific reference data (sampled tokens, argmax tokens, or a
    top-p candidate set), then delegates acceptance/rejection to the
    ``verify_draft_tokens`` CUDA kernel, which writes accepted tokens and
    counts in place into ``share_inputs``.

    Args:
        logits: Target model raw logits (passed through to SamplerOutput).
        probs: Target model softmax output over the vocabulary.
        sampling_metadata: Sampling parameters and metadata.
        max_model_len: Maximum model sequence length.
        share_inputs: Shared input tensors including draft_tokens,
            accept_tokens, accept_num, etc. Mutated in place by the kernel.
        accept_all_drafts: Force accept all draft tokens (debug mode).
        reject_all_drafts: Force reject all draft tokens (debug mode).

    Returns:
        SamplerOutput whose sampled_token_ids/token_num_per_batch alias the
        kernel-updated ``accept_tokens``/``accept_num`` tensors (no logprobs
        here; logprobs are built later in forward_cuda).
    """
    # Imported lazily: these GPU ops only exist on CUDA builds.
    from fastdeploy.model_executor.ops.gpu import (
        top_p_candidates,
        verify_draft_tokens,
    )

    # Strategy-specific inputs for the verify kernel; unused ones stay None:
    #   TARGET_MATCH: target_tokens = stochastically sampled ids, no candidates
    #   GREEDY:       target_tokens = deterministic argmax, no candidates
    #   TOPP:         full top-p candidate set, target_tokens unused
    target_tokens, candidate_ids, candidate_scores, candidate_lens = None, None, None, None

    if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
        # Only TARGET_MATCH needs stochastic sampling from the target model.
        top_p, top_k, topp_seed = padding_sampling_params(
            sampling_metadata.top_p,
            sampling_metadata.top_k,
            sampling_metadata.seed,
            share_inputs["seq_lens_this_time"],
            share_inputs["seq_lens_encoder"],
        )
        _, target_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
    elif self.verify_strategy == VerifyStrategy.GREEDY:
        # Deterministic argmax; no candidate set needed.
        target_tokens = paddle.argmax(probs, axis=-1)
    elif self.verify_strategy == VerifyStrategy.TOPP:
        # Full candidate set from the top-p nucleus; target_tokens unused.
        candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
            probs,
            sampling_metadata.top_p,
            share_inputs["batch_id_per_token_output"],
            self.speculative_max_candidate_len,
            max_model_len,
        )
    else:
        raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")

    # Accept policy: config default OR-ed with per-call overrides, so either
    # source can force the debug behavior. Benchmark mode implies reject-all.
    final_accept_all = self.config_accept_all or accept_all_drafts
    final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode

    # NOTE: arguments are positional and must match the CUDA kernel signature
    # exactly — do not reorder.
    verify_draft_tokens(
        # Core I/O (written in place by the kernel)
        share_inputs["accept_tokens"],  # step_output_ids
        share_inputs["accept_num"],  # step_output_len
        share_inputs["draft_tokens"],  # step_input_ids
        # Target model outputs (None unless TARGET_MATCH/GREEDY)
        target_tokens,
        # Candidate set (None unless TOPP)
        candidate_ids,
        candidate_scores,
        candidate_lens,
        # Sampling params
        sampling_metadata.top_p,
        # Metadata
        share_inputs["stop_flags"],
        share_inputs["seq_lens_encoder"],
        share_inputs["seq_lens_this_time"],
        sampling_metadata.eos_token_ids,
        share_inputs["is_block_step"],
        share_inputs["cu_seqlens_q_output"],
        share_inputs["reasoning_status"],
        # max_dec_len / step_idx for EOS/max-len detection, read-only
        share_inputs["max_dec_len"],
        share_inputs["step_idx"],
        # Config — strategy passed as its integer enum value for the kernel
        max_model_len,
        self.speculative_verify_window,
        self.verify_strategy.value,
        final_reject_all,
        final_accept_all,
    )

    # Returned tensors alias share_inputs; the kernel has already filled them.
    return SamplerOutput(
        sampled_token_ids=share_inputs["accept_tokens"],
        logprobs_tensors=None,
        token_num_per_batch=share_inputs["accept_num"],
        logits=logits,
    )
def _normal_sample(
    self,
    logits: paddle.Tensor,
    probs: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
    share_inputs: List[paddle.Tensor],
) -> SamplerOutput:
    """
    Plain sampling path with no draft-token verification (NAIVE mode).

    Applies the optional min-p filter, draws one token per sequence with
    top-k/top-p sampling, and writes the result into column 0 of
    ``share_inputs["accept_tokens"]`` so downstream code sees the same
    layout as the speculative path.

    Args:
        logits: Target model output logits (passed through to SamplerOutput).
        probs: Target model softmax output.
        sampling_metadata: Sampling parameters and metadata.
        share_inputs: Shared input tensors; ``accept_tokens`` is mutated.

    Returns:
        SamplerOutput with the sampled tokens. No logprobs here — they are
        computed later in forward_cuda.

    NOTE(review): this path only writes ``accept_tokens``; it assumes
    ``share_inputs["accept_num"]`` already holds 1 per sequence for NAIVE
    mode — confirm against the caller/initialization.
    """
    # Optional min-p filtering before sampling.
    filtered_probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)

    # Draw one token per sequence via top-k/top-p sampling.
    _, sampled_ids = top_k_top_p_sampling(
        filtered_probs,
        sampling_metadata.top_p,
        sampling_metadata.top_k,
        sampling_metadata.top_k_list,
        topp_seed=sampling_metadata.seed,
    )

    # NAIVE mode stores its single sampled token in slot 0 of accept_tokens,
    # mirroring the layout the speculative path produces.
    accept_tokens = share_inputs["accept_tokens"]
    num_sampled = sampled_ids.shape[0]
    accept_tokens[:num_sampled, 0] = sampled_ids.squeeze(-1)

    return SamplerOutput(
        sampled_token_ids=accept_tokens,
        logprobs_tensors=None,
        token_num_per_batch=share_inputs["accept_num"],
        logits=logits,
    )
def forward_cuda(
self,
logits: paddle.Tensor,
@@ -758,10 +924,26 @@ class SpeculativeSampler(nn.Layer):
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
reject_all_drafts: bool = False,
) -> paddle.Tensor:
""" """
) -> SamplerOutput:
"""
Forward pass for speculative sampling.
from fastdeploy.model_executor.ops.gpu import speculate_verify, top_p_candidates
Routes between:
- NAIVE mode: Normal sampling without draft verification
- MTP/Ngram mode: Draft token verification + sampling
Args:
logits: Target model output logits
sampling_metadata: Sampling parameters and metadata
max_model_len: Maximum model sequence length
share_inputs: Shared input tensors
accept_all_drafts: Force accept all draft tokens (debug mode)
reject_all_drafts: Force reject all draft tokens (debug mode)
Returns:
SamplerOutput with sampled/accepted tokens
"""
# Apply speculative penalty scores (shared path)
if sampling_metadata.token_ids_all is not None:
token_ids_all = sampling_metadata.token_ids_all
@@ -788,11 +970,11 @@ class SpeculativeSampler(nn.Layer):
max_model_len,
)
# Apply reasoning phase constraint if enabled
if self.enf_gen_phase_tag:
reasoning_phase_token_constraint(
logits,
token_ids_all,
prompt_lens,
sampling_metadata.pre_token_ids,
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
@@ -808,102 +990,34 @@ class SpeculativeSampler(nn.Layer):
probs = F.softmax(logits)
top_p, top_k, topp_seed = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
share_inputs["batch_id_per_token_output"],
self.speculative_max_candidate_len,
max_model_len,
)
speculate_verify(
sampled_token_ids,
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],
share_inputs["stop_flags"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs[
"draft_tokens"
], # Both input and output, need to write the last 1 token accepted to position 0.
share_inputs["seq_lens_this_time"],
verify_tokens,
verify_scores,
share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["cu_seqlens_q_output"],
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
share_inputs["reasoning_status"],
max_model_len,
self.speculative_verify_window,
True, # enable_topp
(self.speculative_benchmark_mode or reject_all_drafts),
accept_all_drafts,
)
num_logprobs = sampling_metadata.max_num_logprobs
batch_token_num = None
if num_logprobs is not None:
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = paddle.where(
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
share_inputs["seq_lens_this_time"],
).flatten()
share_inputs["batch_token_num"] = batch_token_num
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat(
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
).astype("int32")
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
target_logits = paddle.empty(
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
)
speculate_get_target_logits(
target_logits,
# Route based on spec_method
is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
if is_naive:
sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs)
else:
sampler_output = self._verify_and_sample(
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["accept_num"],
probs,
sampling_metadata,
max_model_len,
share_inputs,
accept_all_drafts,
reject_all_drafts,
)
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = target_logits.clone()
logprobs_tensors = None
if num_logprobs is not None:
token_ids = share_inputs["accept_tokens"]
idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
mask = idx < share_inputs["accept_num"].unsqueeze(1)
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=share_inputs["accept_tokens"],
logprobs_tensors=logprobs_tensors,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
logits=logits,
)
# Build logprobs via unified path (outside of sampling logic)
if sampling_metadata.max_num_logprobs is not None:
logprobs_tensors, cu_batch_token_offset = build_output_logprobs(
logits,
sampling_metadata,
share_inputs,
is_naive=is_naive,
logprobs_mode=self.logprobs_mode,
compute_logprobs_fn=self.compute_logprobs,
)
sampler_output.logprobs_tensors = logprobs_tensors
if cu_batch_token_offset is not None:
sampler_output.cu_batch_token_offset = cu_batch_token_offset
return sampler_output
def forward_xpu(