mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Speculative Decoding] Support mtp expert-parallel and support different modality deploy (#7018)
* support mtp ep and support different modality
* fix default arg
This commit is contained in:
@@ -102,6 +102,8 @@ class MTPProposer(Proposer):
|
||||
self.num_main_model_layers = self.model_config.num_hidden_layers
|
||||
self.local_rank = local_rank
|
||||
self.device_id = device_id
|
||||
self.use_attn_mask_offset = self.enable_mm and self.fd_config.deploy_modality != "text"
|
||||
|
||||
self._update_mtp_config(main_model)
|
||||
self._load_model()
|
||||
self.target_model_inputs = target_model_inputs
|
||||
@@ -162,6 +164,8 @@ class MTPProposer(Proposer):
|
||||
self.model_config.quantization = self.speculative_config.quantization
|
||||
self.model_config.start_layer_index = self.num_main_model_layers
|
||||
self.speculative_config.model_type = "mtp"
|
||||
if not self.use_attn_mask_offset:
|
||||
self.model_config.causal = True
|
||||
|
||||
def _load_model(self):
|
||||
"""
|
||||
@@ -503,7 +507,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"][idx : idx + 1] = (
|
||||
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
|
||||
)
|
||||
if self.enable_mm:
|
||||
if self.use_attn_mask_offset:
|
||||
inputs = request.multimodal_inputs
|
||||
self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = (
|
||||
paddle.to_tensor(
|
||||
@@ -662,7 +666,7 @@ class MTPProposer(Proposer):
|
||||
kv_batch_ids=self.model_inputs["kv_batch_ids"],
|
||||
kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
|
||||
kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
|
||||
attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.enable_mm else None,
|
||||
attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None,
|
||||
)
|
||||
|
||||
# Initialize attention metadata
|
||||
@@ -888,7 +892,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
)
|
||||
|
||||
if self.enable_mm:
|
||||
if self.use_attn_mask_offset:
|
||||
attn_mask_offsets = update_attn_mask_offsets(
|
||||
ids_remove_padding,
|
||||
getattr(
|
||||
|
||||
Reference in New Issue
Block a user