[Speculative Decoding] Unify Spec and non-spec branch (#6685)

* optimize spec-inference architecture

* delete debug log

* optimize spec_method usage && fix unit tests

* add claude unit-test skill

* fix some ugly bugs

* enhance robustness and bounds checks

* unify method & spec_method into method to avoid bugs

* activate CI

* fix unit test

* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel

* fix logprob bug && optimize verify kernel

* fix exist_decode() check
freeliuzc
2026-03-11 14:58:44 +08:00
committed by GitHub
parent b6190de557
commit cf7934a4b2
41 changed files with 3428 additions and 392 deletions
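The hunks below replace the string-typed `speculative_method` comparisons with a `SpecMethod` enum imported from fastdeploy.spec_decode and move proposer construction into the enum itself. The enum's definition is not part of this file's diff; the sketch below only illustrates the shape implied by its usage here (members MTP, SUFFIX, NAIVE and a create_proposer factory) together with the proposer classes already imported in the old code. Member string values and the exact signature beyond what the diff shows are assumptions.

# Illustrative sketch only -- the real SpecMethod lives in fastdeploy.spec_decode.
# Member names and the create_proposer() call shape are taken from this diff;
# everything else (string values, defaults) is assumed.
from enum import Enum

from fastdeploy.spec_decode import MTPProposer, NgramProposer, SuffixProposer


class SpecMethod(Enum):
    NAIVE = "naive"    # no draft model: shared verify/update path, proposer stays None
    NGRAM = "ngram"
    MTP = "mtp"
    SUFFIX = "suffix"

    def create_proposer(self, fd_config, main_model=None, local_rank=None,
                        device_id=None, share_inputs=None):
        # Factory used by GPUModelRunner._init_speculative_proposer below (assumed body).
        if self is SpecMethod.NGRAM:
            return NgramProposer(fd_config)
        if self is SpecMethod.SUFFIX:
            return SuffixProposer(fd_config)
        if self is SpecMethod.MTP:
            return MTPProposer(fd_config, main_model, local_rank, device_id, share_inputs)
        return None  # NAIVE needs no proposer

With something of this shape, `self.spec_method == SpecMethod.MTP` replaces the scattered string checks such as `self.speculative_method in ["mtp"]`, while `spec_method is None` still means speculative decoding is disabled entirely.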
@@ -53,6 +53,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.platforms import current_platform
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
if current_platform.is_iluvatar():
@@ -78,17 +79,6 @@ else:
unset_data_ipc,
)
from fastdeploy.model_executor.pre_and_post_process import (
async_set_value,
post_process,
pre_process,
rebuild_padding,
save_output_normal,
)
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
from fastdeploy.spec_decode import MTPProposer, NgramProposer, SuffixProposer
import zmq
from fastdeploy import envs
@@ -100,6 +90,13 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.model_executor.pre_and_post_process import (
async_set_value,
post_process,
pre_process,
rebuild_padding,
save_output_normal,
)
from fastdeploy.output.pooler import PoolerOutput
from fastdeploy.worker.model_runner_base import (
DistributedOut,
@@ -124,8 +121,8 @@ class GPUModelRunner(ModelRunnerBase):
self.rank = rank
self.local_rank = local_rank
self.device_id = device_id
self.speculative_method = self.fd_config.speculative_config.method
self.speculative_decoding = self.speculative_method is not None
self.spec_method = self.fd_config.speculative_config.method
self.speculative_decoding = self.spec_method is not None
self.enable_logprob = fd_config.model_config.enable_logprob
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
@@ -296,7 +293,9 @@ class GPUModelRunner(ModelRunnerBase):
"""
check whether decode stage exist
"""
return (self.share_inputs["seq_lens_decoder"] > 0).any().cpu().numpy().item()
seq_lens_decoder = self.share_inputs["seq_lens_decoder"]
stop_flags = self.share_inputs["stop_flags"].squeeze(1)
return ((seq_lens_decoder > 0) & ~stop_flags).any().cpu().numpy().item()
def _resolve_current_launch_token_num(
self, cached_token_num: int, token_num_event, is_dummy_or_profile_run: bool
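The exist_decode() fix above narrows the decode-stage check: a slot whose seq_lens_decoder is still positive but whose stop flag is already set no longer counts as an active decode request. A minimal NumPy stand-in for the new predicate (shapes and values are made up for illustration; the real tensors in share_inputs are paddle tensors):

import numpy as np

# Hypothetical per-slot state: slots 0 and 2 have decoded tokens but are already stopped.
seq_lens_decoder = np.array([12, 0, 7])
stop_flags = np.array([[True], [False], [True]])  # shape [num_slots, 1], squeezed as above

old_exist_decode = (seq_lens_decoder > 0).any()                             # True
new_exist_decode = ((seq_lens_decoder > 0) & ~stop_flags.squeeze(1)).any()  # False

print(old_exist_decode, new_exist_decode)  # stopped slots no longer force the decode path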
@@ -428,21 +427,19 @@ class GPUModelRunner(ModelRunnerBase):
"""
Init speculative proposer
"""
if self.speculative_method == "ngram":
self.proposer = NgramProposer(self.fd_config)
elif self.speculative_method == "mtp":
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
self.proposer = MTPProposer(
self.fd_config,
self.get_model(),
self.local_rank,
self.device_id,
self.share_inputs,
)
elif self.speculative_method == "suffix":
self.proposer = SuffixProposer(self.fd_config)
else:
if self.spec_method is None:
self.proposer = None
return
# MTP-specific: swap seq_lens_this_time to the buffer tensor
if self.spec_method == SpecMethod.MTP:
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
self.proposer = self.spec_method.create_proposer(
self.fd_config,
main_model=self.get_model(),
local_rank=self.local_rank,
device_id=self.device_id,
share_inputs=self.share_inputs,
)
def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]:
"""
@@ -868,7 +865,7 @@ class GPUModelRunner(ModelRunnerBase):
self.prompt_logprobs_reqs[request.request_id] = request
self.forward_batch_reqs_list[idx] = request
if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None:
if self.speculative_decoding and self.spec_method == SpecMethod.SUFFIX and self.proposer is not None:
if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
@@ -984,7 +981,7 @@ class GPUModelRunner(ModelRunnerBase):
self._process_mm_features(req_dicts)
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests]
if self.speculative_method in ["mtp"]:
if self.spec_method == SpecMethod.MTP:
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
@@ -1228,7 +1225,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs.condense()
reorder_split_prefill_and_decode(input_batch=self.share_inputs)
if self.speculative_decoding:
if self.speculative_method == "mtp":
if self.spec_method == SpecMethod.MTP:
self.proposer.reorder_inputs()
def load_model(self) -> None:
@@ -1249,7 +1246,7 @@ class GPUModelRunner(ModelRunnerBase):
if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
if self.fd_config.speculative_config.method == "mtp":
if self.spec_method == SpecMethod.MTP:
self.dynamic_weight_manager = DynamicWeightManager(
self.fd_config, [self.model, self.proposer.model], self.local_rank
)
@@ -1745,15 +1742,19 @@ class GPUModelRunner(ModelRunnerBase):
think_end_id=self.model_config.think_end_id,
splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode",
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
is_naive_mode=(self.speculative_decoding and self.proposer is None),
prefill_one_step_stop=self.parallel_config.prefill_one_step_stop,
)
self.exist_prefill_flag = False
if self.speculative_decoding:
if self.speculative_method == "mtp":
if self.spec_method == SpecMethod.MTP:
self.proposer.run(
full_hidden_states=model_output,
step_use_cudagraph=self.forward_meta.step_use_cudagraph,
is_dummy_run=True,
)
elif self.spec_method == SpecMethod.NAIVE:
pass
else:
self.proposer.prepare_dummy_speculative_drafts(share_inputs=self.share_inputs, batch_size=batch_size)
return sampler_output
@@ -1789,7 +1790,7 @@ class GPUModelRunner(ModelRunnerBase):
max_dec_len_list=max_dec_len_list,
block_num=block_num,
)
if self.speculative_method in ["mtp"]:
if self.spec_method == SpecMethod.MTP:
self.proposer.dummy_prefill_inputs(
num_tokens=num_tokens,
batch_size=batch_size,
@@ -1803,7 +1804,6 @@ class GPUModelRunner(ModelRunnerBase):
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
self.padding_cudagraph_inputs()
# 3. Run model
if self.enable_mm:
model_output = self.model(
self.forward_meta.ids_remove_padding,
@@ -1877,7 +1877,7 @@ class GPUModelRunner(ModelRunnerBase):
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
elif self.speculative_decoding:
elif self.speculative_decoding and self.spec_method == SpecMethod.MTP:
# Capture Target Model without bsz 1
for capture_size in sorted(capture_sizes, reverse=True):
expected_decode_len = self.speculative_config.num_speculative_tokens * 2 + 1
@@ -2116,7 +2116,6 @@ class GPUModelRunner(ModelRunnerBase):
# 2. Padding inputs for cuda graph
self.padding_cudagraph_inputs()
# 3. Execute model
if self.enable_mm:
model_output = self.model(
@@ -2330,7 +2329,7 @@ class GPUModelRunner(ModelRunnerBase):
enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False),
)
if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
if self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill":
skip_save_output = True
else:
skip_save_output = False
@@ -2348,19 +2347,25 @@ class GPUModelRunner(ModelRunnerBase):
think_end_id=self.model_config.think_end_id,
splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode",
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
is_naive_mode=(self.speculative_decoding and self.proposer is None),
prefill_one_step_stop=self.parallel_config.prefill_one_step_stop,
)
if self.guided_backend is not None and sampler_output is not None:
self.sampler.post_process(sampler_output.sampled_token_ids)
# 6. Speculative decode
if self.speculative_decoding:
if self.speculative_method == "mtp":
# 6. Speculative decode -- proposer run (method="naive" has proposer=None, skip)
# For naive mode: seq_lens_this_time is already reset to 1 inside
# unified_update_model_status kernel. For MTP/Ngram, the proposer
# will overwrite it with (draft_count + 1) below.
if self.speculative_decoding and self.proposer is not None:
if self.spec_method == SpecMethod.MTP:
self.proposer.run(
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
)
elif self.speculative_method == "suffix":
self.proposer.run(share_inputs=self.share_inputs)
elif self.spec_method == SpecMethod.NAIVE:
pass
else:
self.proposer.run(share_inputs=self.share_inputs)
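To make the comment above concrete: in naive mode the unified update kernel leaves each running slot expecting a single token next step, while a drafting proposer rewrites the slot length to cover its drafts plus one. The values below are only an illustration of that relationship, not output from a real run:

# Illustration of the seq_lens_this_time convention described in the comment above (example values).
num_draft_tokens = 2                                     # drafts produced by the proposer this step
seq_lens_this_time_naive = 1                             # reset inside unified_update_model_status
seq_lens_this_time_with_proposer = num_draft_tokens + 1  # (draft_count + 1), written by the proposer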
@@ -2483,7 +2488,7 @@ class GPUModelRunner(ModelRunnerBase):
# TODO(gongshaotian): Optimize the management logic of kvcache
self.num_gpu_blocks = self.cache_config.total_block_num
self.initialize_kv_cache(profile=True)
if self.speculative_method in ["mtp"]:
if self.spec_method == SpecMethod.MTP:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
# 1. Profile with multimodal encoder & encoder cache
@@ -2499,7 +2504,7 @@ class GPUModelRunner(ModelRunnerBase):
)
# 3. gc
if self.speculative_method in ["mtp"]:
if self.spec_method == SpecMethod.MTP:
self.proposer.clear_mtp_cache(profile=True)
self.clear_cache(profile=True)
@@ -2530,7 +2535,7 @@ class GPUModelRunner(ModelRunnerBase):
}
)
if self.speculative_method in ["mtp"]:
if self.spec_method == SpecMethod.MTP:
self.proposer.update_mtp_block_num(num_gpu_blocks)
def cal_theortical_kvcache(self):
@@ -2561,7 +2566,7 @@ class GPUModelRunner(ModelRunnerBase):
# NOTE(liuzichang): Implement multi-layer MTP architecture in the future
num_layers = (
self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio
if self.speculative_method in ["mtp"]
if self.spec_method == SpecMethod.MTP
else self.model_config.num_hidden_layers
)
@@ -2620,7 +2625,7 @@ class GPUModelRunner(ModelRunnerBase):
self.dynamic_weight_manager.clear_parameters(
pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
)
if self.speculative_method in ["mtp"]:
if self.spec_method == SpecMethod.MTP:
self.proposer.clear_mtp_cache()
self.clear_cache()
paddle.device.cuda.empty_cache()
@@ -2646,7 +2651,7 @@ class GPUModelRunner(ModelRunnerBase):
# Reset share_inputs
self.share_inputs.reset_share_inputs()
if self.speculative_method in ["mtp"]:
if self.spec_method == SpecMethod.MTP:
self.proposer.model_inputs.reset_model_inputs()
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
self.initialize_kv_cache()