mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Support mtp expert-parallel and support different modality deploy (#7018)
* support mtp ep and support different modality
* fix default arg
This commit is contained in:
@@ -17,7 +17,13 @@
|
||||
import paddle
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig
|
||||
from fastdeploy.config import (
|
||||
CacheConfig,
|
||||
DeployModality,
|
||||
FDConfig,
|
||||
ModelConfig,
|
||||
SpeculativeConfig,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
|
||||
from fastdeploy.model_executor.logits_processor import build_logits_processors
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -829,20 +835,23 @@ class ProposerInputBatch(InputBatch):
|
||||
)
|
||||
# attn_mask
|
||||
if self.enable_mm:
|
||||
self.attn_mask_offsets = paddle.full(
|
||||
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
|
||||
fill_value=-1,
|
||||
dtype="int32",
|
||||
)
|
||||
self.attn_mask_offsets_full = paddle.full(
|
||||
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
|
||||
)
|
||||
self.attn_mask_offsets_decoder = paddle.full([self.scheduler_config.max_num_seqs, 1], -1, dtype="int32")
|
||||
self.decode_states = paddle.full(
|
||||
[self.scheduler_config.max_num_seqs, self.speculative_config.num_speculative_tokens + 1],
|
||||
-1,
|
||||
dtype="int32",
|
||||
)
|
||||
if self.fd_config.deploy_modality != DeployModality.TEXT:
|
||||
self.attn_mask_offsets = paddle.full(
|
||||
shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len],
|
||||
fill_value=-1,
|
||||
dtype="int32",
|
||||
)
|
||||
self.attn_mask_offsets_full = paddle.full(
|
||||
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32"
|
||||
)
|
||||
self.attn_mask_offsets_decoder = paddle.full(
|
||||
[self.scheduler_config.max_num_seqs, 1], -1, dtype="int32"
|
||||
)
|
||||
|
||||
def swap_states(self, i1, i2) -> None:
|
||||
def swap_data(tensor, idx1, idx2):
|
||||
@@ -864,7 +873,7 @@ class ProposerInputBatch(InputBatch):
|
||||
swap_data(self.input_ids_len, i1, i2)
|
||||
swap_data(self.mask_rollback, i1, i2)
|
||||
swap_data(self.recompute_token_num, i1, i2)
|
||||
if self.enable_mm:
|
||||
if self.enable_mm and self.fd_config.deploy_modality != DeployModality.TEXT:
|
||||
swap_data(self.attn_mask_offsets_full, i1, i2)
|
||||
swap_data(self.attn_mask_offsets_decoder, i1, i2)
|
||||
|
||||
@@ -998,10 +1007,11 @@ class ProposerInputBatch(InputBatch):
|
||||
|
||||
# Reset multimodal tensors if enabled
|
||||
if self.enable_mm:
|
||||
fill_paddle_tensor(self, "attn_mask_offsets", -1)
|
||||
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
|
||||
fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
|
||||
fill_paddle_tensor(self, "decode_states", -1)
|
||||
if self.fd_config.deploy_modality != DeployModality.TEXT:
|
||||
fill_paddle_tensor(self, "attn_mask_offsets", -1)
|
||||
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
|
||||
fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1)
|
||||
|
||||
logger.info("model_inputs reset completed")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user