[Speculative Decoding] Unify Spec and non-spec branch (#6685)
* optimize spec-inference architecture
* delete debug log
* optimize spec_method usage && fix unit_test
* add claude unit-test skill
* fix some ugly bugs
* enhance robustness and bounds check
* unify method & spec_method to method to avoid bugs
* activate CI
* fix unit test
* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel
* fix logprob bug && optimize verify kernel
* fix exist_decode() judgment
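The core of the change is visible in the diff below: string comparisons such as `self.speculative_method == "mtp"` are replaced by a `SpecMethod` enum, and proposer construction moves behind `SpecMethod.create_proposer(...)`. A minimal sketch of that pattern follows; the member names and the `create_proposer` keywords come from the diff, while the stub proposer class and everything else are illustrative and not the actual `fastdeploy.spec_decode` implementation:

```python
# Hedged sketch of the unified dispatch this PR introduces; the real
# fastdeploy.spec_decode.SpecMethod may differ. Only the member names and the
# create_proposer(...) keyword arguments are taken from the diff below.
from enum import Enum


class _StubProposer:
    """Placeholder standing in for NgramProposer / SuffixProposer / MTPProposer."""

    def __init__(self, fd_config, **kwargs):
        self.fd_config = fd_config
        self.kwargs = kwargs


class SpecMethod(Enum):
    NGRAM = "ngram"
    MTP = "mtp"
    SUFFIX = "suffix"
    NAIVE = "naive"  # target model only: no proposer, one token per step

    def create_proposer(self, fd_config, **kwargs):
        # Each method knows how to build its own proposer, so the model runner
        # no longer needs an if/elif chain over raw strings like "mtp"/"suffix".
        if self is SpecMethod.NAIVE:
            return None
        return _StubProposer(fd_config, **kwargs)


if __name__ == "__main__":
    method = SpecMethod("mtp")
    proposer = method.create_proposer(
        fd_config=None, main_model=None, local_rank=0, device_id=0, share_inputs={}
    )
    print(method, type(proposer).__name__)
```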
@@ -53,6 +53,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
 from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.platforms import current_platform
+from fastdeploy.spec_decode import SpecMethod
 from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
 
 if current_platform.is_iluvatar():
@@ -78,17 +79,6 @@ else:
         unset_data_ipc,
     )
 
-from fastdeploy.model_executor.pre_and_post_process import (
-    async_set_value,
-    post_process,
-    pre_process,
-    rebuild_padding,
-    save_output_normal,
-)
-
-if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
-    from fastdeploy.spec_decode import MTPProposer, NgramProposer, SuffixProposer
-
 import zmq
 
 from fastdeploy import envs
@@ -100,6 +90,13 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
 from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
+from fastdeploy.model_executor.pre_and_post_process import (
+    async_set_value,
+    post_process,
+    pre_process,
+    rebuild_padding,
+    save_output_normal,
+)
 from fastdeploy.output.pooler import PoolerOutput
 from fastdeploy.worker.model_runner_base import (
     DistributedOut,
@@ -124,8 +121,8 @@ class GPUModelRunner(ModelRunnerBase):
         self.rank = rank
         self.local_rank = local_rank
         self.device_id = device_id
-        self.speculative_method = self.fd_config.speculative_config.method
-        self.speculative_decoding = self.speculative_method is not None
+        self.spec_method = self.fd_config.speculative_config.method
+        self.speculative_decoding = self.spec_method is not None
         self.enable_logprob = fd_config.model_config.enable_logprob
         self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
         self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
@@ -296,7 +293,9 @@ class GPUModelRunner(ModelRunnerBase):
         """
         check whether decode stage exist
         """
-        return (self.share_inputs["seq_lens_decoder"] > 0).any().cpu().numpy().item()
+        seq_lens_decoder = self.share_inputs["seq_lens_decoder"]
+        stop_flags = self.share_inputs["stop_flags"].squeeze(1)
+        return ((seq_lens_decoder > 0) & ~stop_flags).any().cpu().numpy().item()
 
     def _resolve_current_launch_token_num(
         self, cached_token_num: int, token_num_event, is_dummy_or_profile_run: bool
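The rewritten `exist_decode()` check above no longer reports decode work for slots that have already hit a stop condition. A toy illustration with numpy follows; the real code operates on paddle tensors held in `share_inputs` (roughly `[max_num_seqs, 1]`, with `stop_flags` squeezed), so the flat arrays and values here are made up:

```python
import numpy as np

# One entry per batch slot; flattened stand-ins for the share_inputs tensors.
seq_lens_decoder = np.array([128, 64, 0])     # decoded lengths so far
stop_flags = np.array([True, True, False])    # slots that already finished

# Old-style check: any slot with decoded tokens counts, even if it stopped.
old_exist_decode = bool((seq_lens_decoder > 0).any())                  # True
# New-style check: a slot only counts while it is still running.
new_exist_decode = bool(((seq_lens_decoder > 0) & ~stop_flags).any())  # False
print(old_exist_decode, new_exist_decode)
```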
@@ -428,21 +427,19 @@ class GPUModelRunner(ModelRunnerBase):
         """
         Init speculative proposer
         """
-        if self.speculative_method == "ngram":
-            self.proposer = NgramProposer(self.fd_config)
-        elif self.speculative_method == "mtp":
-            self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
-            self.proposer = MTPProposer(
-                self.fd_config,
-                self.get_model(),
-                self.local_rank,
-                self.device_id,
-                self.share_inputs,
-            )
-        elif self.speculative_method == "suffix":
-            self.proposer = SuffixProposer(self.fd_config)
-        else:
+        if self.spec_method is None:
             self.proposer = None
+            return
+        # MTP-specific: swap seq_lens_this_time to the buffer tensor
+        if self.spec_method == SpecMethod.MTP:
+            self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
+        self.proposer = self.spec_method.create_proposer(
+            self.fd_config,
+            main_model=self.get_model(),
+            local_rank=self.local_rank,
+            device_id=self.device_id,
+            share_inputs=self.share_inputs,
+        )
 
     def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]:
         """
@@ -868,7 +865,7 @@ class GPUModelRunner(ModelRunnerBase):
                 self.prompt_logprobs_reqs[request.request_id] = request
             self.forward_batch_reqs_list[idx] = request
 
-            if self.speculative_decoding and self.speculative_method == "suffix" and self.proposer is not None:
+            if self.speculative_decoding and self.spec_method == SpecMethod.SUFFIX and self.proposer is not None:
                 if isinstance(request.prompt_token_ids, np.ndarray):
                     prompt_token_ids = request.prompt_token_ids.tolist()
                 else:
@@ -984,7 +981,7 @@ class GPUModelRunner(ModelRunnerBase):
             self._process_mm_features(req_dicts)
 
         self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests]
-        if self.speculative_method in ["mtp"]:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
 
     def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
@@ -1228,7 +1225,7 @@ class GPUModelRunner(ModelRunnerBase):
            self.share_inputs.condense()
            reorder_split_prefill_and_decode(input_batch=self.share_inputs)
        if self.speculative_decoding:
-            if self.speculative_method == "mtp":
+            if self.spec_method == SpecMethod.MTP:
                self.proposer.reorder_inputs()
 
    def load_model(self) -> None:
@@ -1249,7 +1246,7 @@ class GPUModelRunner(ModelRunnerBase):
         if self.fd_config.load_config.dynamic_load_weight:
             from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
 
-            if self.fd_config.speculative_config.method == "mtp":
+            if self.spec_method == SpecMethod.MTP:
                 self.dynamic_weight_manager = DynamicWeightManager(
                     self.fd_config, [self.model, self.proposer.model], self.local_rank
                 )
@@ -1745,15 +1742,19 @@ class GPUModelRunner(ModelRunnerBase):
             think_end_id=self.model_config.think_end_id,
             splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode",
             enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
+            is_naive_mode=(self.speculative_decoding and self.proposer is None),
+            prefill_one_step_stop=self.parallel_config.prefill_one_step_stop,
         )
         self.exist_prefill_flag = False
         if self.speculative_decoding:
-            if self.speculative_method == "mtp":
+            if self.spec_method == SpecMethod.MTP:
                 self.proposer.run(
                     full_hidden_states=model_output,
                     step_use_cudagraph=self.forward_meta.step_use_cudagraph,
                     is_dummy_run=True,
                 )
+            elif self.spec_method == SpecMethod.NAIVE:
+                pass
             else:
                 self.proposer.prepare_dummy_speculative_drafts(share_inputs=self.share_inputs, batch_size=batch_size)
         return sampler_output
@@ -1789,7 +1790,7 @@ class GPUModelRunner(ModelRunnerBase):
             max_dec_len_list=max_dec_len_list,
             block_num=block_num,
         )
-        if self.speculative_method in ["mtp"]:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.dummy_prefill_inputs(
                 num_tokens=num_tokens,
                 batch_size=batch_size,
@@ -1803,7 +1804,6 @@ class GPUModelRunner(ModelRunnerBase):
         self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
         self.padding_cudagraph_inputs()
 
-        # 3. Run model
         if self.enable_mm:
             model_output = self.model(
                 self.forward_meta.ids_remove_padding,
@@ -1877,7 +1877,7 @@ class GPUModelRunner(ModelRunnerBase):
                 logger.info(
                     f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
                 )
-        elif self.speculative_decoding:
+        elif self.speculative_decoding and self.spec_method == SpecMethod.MTP:
             # Capture Target Model without bsz 1
             for capture_size in sorted(capture_sizes, reverse=True):
                 expected_decode_len = self.speculative_config.num_speculative_tokens * 2 + 1
@@ -2116,7 +2116,6 @@ class GPUModelRunner(ModelRunnerBase):
 
         # 2. Padding inputs for cuda graph
         self.padding_cudagraph_inputs()
-
         # 3. Execute model
         if self.enable_mm:
             model_output = self.model(
@@ -2330,7 +2329,7 @@ class GPUModelRunner(ModelRunnerBase):
             enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False),
         )
 
-        if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
+        if self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill":
             skip_save_output = True
         else:
             skip_save_output = False
@@ -2348,19 +2347,25 @@ class GPUModelRunner(ModelRunnerBase):
             think_end_id=self.model_config.think_end_id,
             splitwise_role_is_decode=self.scheduler_config.splitwise_role == "decode",
             enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
+            is_naive_mode=(self.speculative_decoding and self.proposer is None),
+            prefill_one_step_stop=self.parallel_config.prefill_one_step_stop,
         )
 
         if self.guided_backend is not None and sampler_output is not None:
             self.sampler.post_process(sampler_output.sampled_token_ids)
 
-        # 6. Speculative decode
-        if self.speculative_decoding:
-            if self.speculative_method == "mtp":
+        # 6. Speculative decode -- proposer run (method="naive" has proposer=None, skip)
+        # For naive mode: seq_lens_this_time is already reset to 1 inside
+        # unified_update_model_status kernel. For MTP/Ngram, the proposer
+        # will overwrite it with (draft_count + 1) below.
+
+        if self.speculative_decoding and self.proposer is not None:
+            if self.spec_method == SpecMethod.MTP:
                 self.proposer.run(
                     full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
                 )
-            elif self.speculative_method == "suffix":
-                self.proposer.run(share_inputs=self.share_inputs)
+            elif self.spec_method == SpecMethod.NAIVE:
+                pass
             else:
                 self.proposer.run(share_inputs=self.share_inputs)
 
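The comment block in the hunk above pins down the contract on `seq_lens_this_time` after sampling: for the naive method the update kernel has already reset each running slot to 1, while the MTP/Ngram proposers overwrite it with `draft_count + 1`. A hedged restatement of that contract as a tiny helper; it is illustrative only and not a FastDeploy API:

```python
# Illustrative-only helper restating the seq_lens_this_time contract from the
# comment in the diff above; it is not part of FastDeploy.
def next_step_seq_len(method: str, draft_count: int = 0) -> int:
    if method == "naive":
        # unified_update_model_status already reset the slot to decode 1 token
        return 1
    # MTP / Ngram / Suffix: verify draft_count draft tokens plus the bonus token
    return draft_count + 1


assert next_step_seq_len("naive") == 1
assert next_step_seq_len("mtp", draft_count=3) == 4
```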
@@ -2483,7 +2488,7 @@ class GPUModelRunner(ModelRunnerBase):
         # TODO(gongshaotian): Optimize the management logic of kvcache
         self.num_gpu_blocks = self.cache_config.total_block_num
         self.initialize_kv_cache(profile=True)
-        if self.speculative_method in ["mtp"]:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
 
         # 1. Profile with multimodal encoder & encoder cache
@@ -2499,7 +2504,7 @@ class GPUModelRunner(ModelRunnerBase):
         )
 
         # 3. gc
-        if self.speculative_method in ["mtp"]:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.clear_mtp_cache(profile=True)
         self.clear_cache(profile=True)
 
@@ -2530,7 +2535,7 @@ class GPUModelRunner(ModelRunnerBase):
             }
         )
 
-        if self.speculative_method in ["mtp"]:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.update_mtp_block_num(num_gpu_blocks)
 
     def cal_theortical_kvcache(self):
@@ -2561,7 +2566,7 @@ class GPUModelRunner(ModelRunnerBase):
         # NOTE(liuzichang): Implement multi-layer MTP architecture in the future
         num_layers = (
             self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio
-            if self.speculative_method in ["mtp"]
+            if self.spec_method == SpecMethod.MTP
             else self.model_config.num_hidden_layers
         )
 
@@ -2620,7 +2625,7 @@ class GPUModelRunner(ModelRunnerBase):
             self.dynamic_weight_manager.clear_parameters(
                 pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle
             )
-        if self.speculative_method in ["mtp"]:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.clear_mtp_cache()
         self.clear_cache()
         paddle.device.cuda.empty_cache()
@@ -2646,7 +2651,7 @@ class GPUModelRunner(ModelRunnerBase):
 
         # Reset share_inputs
         self.share_inputs.reset_share_inputs()
-        if self.speculative_method in ["mtp"]:
+        if self.spec_method == SpecMethod.MTP:
             self.proposer.model_inputs.reset_model_inputs()
             self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
         self.initialize_kv_cache()