mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 17:49:42 +08:00
[Speculative Decoding] Unify Spec and non-spec branch (#6685)
* optimize spec-inference architecture * delete debug log * optimize spec_method usage && fix unit_test * add claude unit-test skill * fix some ugly bug * enhance robustness and bounds check * unify method & spec_method to method to avoid bug * activate CI * fix unit test * Unify logprobs computation for naive and speculative decoding, fix CUDA kernel * fix logprob bug && optimize verify kernel * fix exist_decode() judge
This commit is contained in:
@@ -32,19 +32,22 @@ from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
|
||||
from fastdeploy.model_executor.layers.sample.early_stopper import (
|
||||
get_early_stopper_cls_from_stragegy,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.sample.logprobs import batched_count_greater_than
|
||||
from fastdeploy.model_executor.layers.sample.logprobs import (
|
||||
batched_count_greater_than,
|
||||
build_output_logprobs,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.model_executor.layers.sample.ops import (
|
||||
apply_penalty_multi_scores,
|
||||
apply_speculative_penalty_multi_scores,
|
||||
min_p_sampling,
|
||||
reasoning_phase_token_constraint,
|
||||
speculate_get_target_logits,
|
||||
speculate_insert_first_token,
|
||||
top_k_top_p_sampling,
|
||||
)
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.reasoning import ReasoningParser
|
||||
from fastdeploy.spec_decode import SpecMethod, VerifyStrategy
|
||||
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
|
||||
|
||||
|
||||
@@ -638,6 +641,18 @@ class SpeculativeSampler(nn.Layer):
|
||||
self.line_break_id = fd_config.model_config.line_break_id
|
||||
self.enf_gen_phase_tag = fd_config.speculative_config.enf_gen_phase_tag
|
||||
|
||||
# Verify strategy derived from config (replaces env vars in CUDA kernel)
|
||||
spec_config = fd_config.speculative_config
|
||||
# Verify strategy enum: VerifyStrategy.TOPP/GREEDY/TARGET_MATCH
|
||||
# Use .value (0/1/2) when passing to CUDA kernel
|
||||
self.spec_method = spec_config.method
|
||||
self.verify_strategy = spec_config.verify_strategy
|
||||
self.prefill_one_step_stop = fd_config.parallel_config.prefill_one_step_stop
|
||||
|
||||
# Accept policy from config (can be overridden by function parameters)
|
||||
self.config_accept_all = spec_config.accept_policy == "accept_all"
|
||||
self.config_reject_all = spec_config.accept_policy == "reject_all"
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
"""pre process before running"""
|
||||
pass
|
||||
@@ -750,6 +765,157 @@ class SpeculativeSampler(nn.Layer):
|
||||
|
||||
return LogprobsTensors(indices, top_logprobs, token_ranks)
|
||||
|
||||
def _verify_and_sample(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
probs: paddle.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
max_model_len: int,
|
||||
share_inputs: List[paddle.Tensor],
|
||||
accept_all_drafts: bool = False,
|
||||
reject_all_drafts: bool = False,
|
||||
) -> SamplerOutput:
|
||||
"""
|
||||
Verify draft tokens against target model output and produce final samples.
|
||||
|
||||
This is the core speculative decoding logic that compares draft tokens
|
||||
with target model predictions to determine acceptance/rejection.
|
||||
|
||||
Args:
|
||||
logits: Target model raw logits
|
||||
probs: Target model softmax output
|
||||
sampling_metadata: Sampling parameters and metadata
|
||||
max_model_len: Maximum model sequence length
|
||||
share_inputs: Shared input tensors including draft_tokens, accept_tokens, etc.
|
||||
accept_all_drafts: Force accept all draft tokens (debug mode)
|
||||
reject_all_drafts: Force reject all draft tokens (debug mode)
|
||||
|
||||
Returns:
|
||||
SamplerOutput with accepted tokens and metadata
|
||||
"""
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
top_p_candidates,
|
||||
verify_draft_tokens,
|
||||
)
|
||||
|
||||
# Prepare strategy-specific tensors
|
||||
# TARGET_MATCH: needs target_tokens=sampled, candidates=None
|
||||
# GREEDY: needs target_tokens=argmax, candidates=None
|
||||
# TOPP: needs target_tokens=None, candidates=full top_p set
|
||||
target_tokens, candidate_ids, candidate_scores, candidate_lens = None, None, None, None
|
||||
|
||||
if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
|
||||
# Only TARGET_MATCH needs stochastic sampling
|
||||
top_p, top_k, topp_seed = padding_sampling_params(
|
||||
sampling_metadata.top_p,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.seed,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
)
|
||||
_, target_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
|
||||
elif self.verify_strategy == VerifyStrategy.GREEDY:
|
||||
# GREEDY: deterministic argmax in target_tokens, no candidates needed
|
||||
target_tokens = paddle.argmax(probs, axis=-1)
|
||||
elif self.verify_strategy == VerifyStrategy.TOPP: # TOPP
|
||||
# TOPP: needs full candidate set, target_tokens unused
|
||||
candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
|
||||
probs,
|
||||
sampling_metadata.top_p,
|
||||
share_inputs["batch_id_per_token_output"],
|
||||
self.speculative_max_candidate_len,
|
||||
max_model_len,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")
|
||||
|
||||
# Accept policy: config default OR function parameter (OR logic)
|
||||
final_accept_all = self.config_accept_all or accept_all_drafts
|
||||
final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode
|
||||
|
||||
verify_draft_tokens(
|
||||
# Core I/O
|
||||
share_inputs["accept_tokens"], # step_output_ids
|
||||
share_inputs["accept_num"], # step_output_len
|
||||
share_inputs["draft_tokens"], # step_input_ids
|
||||
# Target model outputs
|
||||
target_tokens,
|
||||
# Candidate set (strategy-dependent usage)
|
||||
candidate_ids,
|
||||
candidate_scores,
|
||||
candidate_lens,
|
||||
# Sampling params
|
||||
sampling_metadata.top_p,
|
||||
# Metadata
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["cu_seqlens_q_output"],
|
||||
share_inputs["reasoning_status"],
|
||||
# max_dec_len / step_idx for EOS/max-len detection, only read
|
||||
share_inputs["max_dec_len"],
|
||||
share_inputs["step_idx"],
|
||||
# Config
|
||||
max_model_len,
|
||||
self.speculative_verify_window,
|
||||
self.verify_strategy.value,
|
||||
final_reject_all,
|
||||
final_accept_all,
|
||||
)
|
||||
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=share_inputs["accept_tokens"],
|
||||
logprobs_tensors=None,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
logits=logits,
|
||||
)
|
||||
|
||||
def _normal_sample(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
probs: paddle.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
share_inputs: List[paddle.Tensor],
|
||||
) -> SamplerOutput:
|
||||
"""
|
||||
Normal sampling without draft token verification.
|
||||
|
||||
Used by NAIVE mode: directly samples from target model output
|
||||
and writes results to share_inputs["accept_tokens"]/["accept_num"].
|
||||
|
||||
Args:
|
||||
probs: Target model softmax output
|
||||
logits: Target model output logits
|
||||
sampling_metadata: Sampling parameters and metadata
|
||||
share_inputs: Shared input tensors
|
||||
|
||||
Returns:
|
||||
SamplerOutput with sampled tokens (no logprobs; logprobs are computed in forward_cuda)
|
||||
"""
|
||||
# Apply min_p sampling if configured
|
||||
probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list)
|
||||
|
||||
# Sample tokens
|
||||
_, next_tokens = top_k_top_p_sampling(
|
||||
probs,
|
||||
sampling_metadata.top_p,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_k_list,
|
||||
topp_seed=sampling_metadata.seed,
|
||||
)
|
||||
|
||||
# For NAIVE mode: write directly to accept_tokens/accept_num
|
||||
share_inputs["accept_tokens"][: next_tokens.shape[0], 0] = next_tokens.squeeze(-1)
|
||||
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=share_inputs["accept_tokens"],
|
||||
logprobs_tensors=None,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
logits=logits,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
@@ -758,10 +924,26 @@ class SpeculativeSampler(nn.Layer):
|
||||
share_inputs: List[paddle.Tensor],
|
||||
accept_all_drafts: bool = False,
|
||||
reject_all_drafts: bool = False,
|
||||
) -> paddle.Tensor:
|
||||
""" """
|
||||
) -> SamplerOutput:
|
||||
"""
|
||||
Forward pass for speculative sampling.
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import speculate_verify, top_p_candidates
|
||||
Routes between:
|
||||
- NAIVE mode: Normal sampling without draft verification
|
||||
- MTP/Ngram mode: Draft token verification + sampling
|
||||
|
||||
Args:
|
||||
logits: Target model output logits
|
||||
sampling_metadata: Sampling parameters and metadata
|
||||
max_model_len: Maximum model sequence length
|
||||
share_inputs: Shared input tensors
|
||||
accept_all_drafts: Force accept all draft tokens (debug mode)
|
||||
reject_all_drafts: Force reject all draft tokens (debug mode)
|
||||
|
||||
Returns:
|
||||
SamplerOutput with sampled/accepted tokens
|
||||
"""
|
||||
# Apply speculative penalty scores (shared path)
|
||||
|
||||
if sampling_metadata.token_ids_all is not None:
|
||||
token_ids_all = sampling_metadata.token_ids_all
|
||||
@@ -788,11 +970,11 @@ class SpeculativeSampler(nn.Layer):
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
# Apply reasoning phase constraint if enabled
|
||||
if self.enf_gen_phase_tag:
|
||||
reasoning_phase_token_constraint(
|
||||
logits,
|
||||
token_ids_all,
|
||||
prompt_lens,
|
||||
sampling_metadata.pre_token_ids,
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
@@ -808,102 +990,34 @@ class SpeculativeSampler(nn.Layer):
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
top_p, top_k, topp_seed = padding_sampling_params(
|
||||
sampling_metadata.top_p,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.seed,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
)
|
||||
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
|
||||
|
||||
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
||||
probs,
|
||||
sampling_metadata.top_p,
|
||||
share_inputs["batch_id_per_token_output"],
|
||||
self.speculative_max_candidate_len,
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
speculate_verify(
|
||||
sampled_token_ids,
|
||||
share_inputs["accept_tokens"],
|
||||
share_inputs["accept_num"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs[
|
||||
"draft_tokens"
|
||||
], # Both input and output, need to write the last 1 token accepted to position 0.
|
||||
share_inputs["seq_lens_this_time"],
|
||||
verify_tokens,
|
||||
verify_scores,
|
||||
share_inputs["max_dec_len"],
|
||||
sampling_metadata.eos_token_ids,
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["cu_seqlens_q_output"],
|
||||
actual_candidate_len,
|
||||
share_inputs["actual_draft_token_num"],
|
||||
sampling_metadata.top_p,
|
||||
share_inputs["reasoning_status"],
|
||||
max_model_len,
|
||||
self.speculative_verify_window,
|
||||
True, # enable_topp
|
||||
(self.speculative_benchmark_mode or reject_all_drafts),
|
||||
accept_all_drafts,
|
||||
)
|
||||
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
batch_token_num = None
|
||||
if num_logprobs is not None:
|
||||
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
|
||||
batch_token_num = paddle.where(
|
||||
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
|
||||
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
|
||||
share_inputs["seq_lens_this_time"],
|
||||
).flatten()
|
||||
share_inputs["batch_token_num"] = batch_token_num
|
||||
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
|
||||
"int32"
|
||||
)
|
||||
cu_batch_token_offset = paddle.concat(
|
||||
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
|
||||
).astype("int32")
|
||||
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
|
||||
target_logits = paddle.empty(
|
||||
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
|
||||
)
|
||||
speculate_get_target_logits(
|
||||
target_logits,
|
||||
# Route based on spec_method
|
||||
is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
|
||||
if is_naive:
|
||||
sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs)
|
||||
else:
|
||||
sampler_output = self._verify_and_sample(
|
||||
logits,
|
||||
cu_batch_token_offset,
|
||||
ori_cu_batch_token_offset,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["accept_num"],
|
||||
probs,
|
||||
sampling_metadata,
|
||||
max_model_len,
|
||||
share_inputs,
|
||||
accept_all_drafts,
|
||||
reject_all_drafts,
|
||||
)
|
||||
if self.logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = target_logits.clone()
|
||||
|
||||
logprobs_tensors = None
|
||||
if num_logprobs is not None:
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
|
||||
mask = idx < share_inputs["accept_num"].unsqueeze(1)
|
||||
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
|
||||
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=share_inputs["accept_tokens"],
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
logits=logits,
|
||||
)
|
||||
|
||||
# Build logprobs via unified path (outside of sampling logic)
|
||||
if sampling_metadata.max_num_logprobs is not None:
|
||||
logprobs_tensors, cu_batch_token_offset = build_output_logprobs(
|
||||
logits,
|
||||
sampling_metadata,
|
||||
share_inputs,
|
||||
is_naive=is_naive,
|
||||
logprobs_mode=self.logprobs_mode,
|
||||
compute_logprobs_fn=self.compute_logprobs,
|
||||
)
|
||||
sampler_output.logprobs_tensors = logprobs_tensors
|
||||
if cu_batch_token_offset is not None:
|
||||
sampler_output.cu_batch_token_offset = cu_batch_token_offset
|
||||
return sampler_output
|
||||
|
||||
def forward_xpu(
|
||||
|
||||
Reference in New Issue
Block a user