diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index ddebd15d41..a638ac33c7 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -140,8 +140,6 @@ class AppendAttentionBackend(AttentionBackend):
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
             fd_config.model_config, "use_3d_rope", False
         )
-        if fd_config.speculative_config.model_type != "main":
-            self.rope_3d = False
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method: str = fd_config.speculative_config.method
         self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py
index 92d03f2071..9071453ab9 100644
--- a/fastdeploy/worker/input_batch.py
+++ b/fastdeploy/worker/input_batch.py
@@ -709,15 +709,7 @@ class ProposerInputBatch(InputBatch):
             dtype="bfloat16",
         )
 
-        tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
-
-        self.rope_emb = get_rope(
-            rotary_dim=self.model_config.head_dim,
-            position_ids=tmp_position_ids,
-            base=self.model_config.rope_theta,
-            model_config=self.model_config,
-            partial_rotary_factor=self.model_config.partial_rotary_factor,
-        )
+        self.rope_emb = paddle.clone(self.target_model_input_batch["rope_emb"])
 
         # self.caches = self.cache_kvs
         # Inherit generation hyperparameters from the main model for consistency
@@ -915,14 +907,35 @@ class ProposerInputBatch(InputBatch):
         fill_paddle_tensor(self, "target_hidden_states", 0)
 
         # Reset rope embedding by recreating with default position_ids
-        tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
-        self.rope_emb = get_rope(
-            rotary_dim=self.model_config.head_dim,
-            position_ids=tmp_position_ids,
-            base=self.model_config.rope_theta,
-            model_config=self.model_config,
-            partial_rotary_factor=self.model_config.partial_rotary_factor,
-        )
+        if self.enable_mm:
+            head_dim = self.model_config.head_dim
+            if "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type:
+                rope_head_dim = head_dim
+            else:
+                rope_head_dim = head_dim // 2
+
+            self.rope_emb = paddle.full(
+                shape=[
+                    self.scheduler_config.max_num_seqs,
+                    2,
+                    1,
+                    self.model_config.max_model_len,
+                    1,
+                    rope_head_dim,
+                ],
+                fill_value=0,
+                dtype="float32",
+            )
+            self.image_features = None
+            self.image_features_list = None
+        else:
+            self.rope_emb = get_rope(
+                rotary_dim=self.model_config.head_dim,
+                position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
+                base=self.model_config.rope_theta,
+                model_config=self.model_config,
+                partial_rotary_factor=self.model_config.partial_rotary_factor,
+            )
 
         # Reset generation hyperparameters from the main model
         self.top_p = self.target_model_input_batch["top_p"]
diff --git a/tests/spec_decode/test_mtp_proposer.py b/tests/spec_decode/test_mtp_proposer.py
index 6679ce4286..ca7758f659 100644
--- a/tests/spec_decode/test_mtp_proposer.py
+++ b/tests/spec_decode/test_mtp_proposer.py
@@ -22,6 +22,7 @@ from utils import FakeModelConfig, get_default_test_fd_config
 
 from fastdeploy.config import SpeculativeConfig
 from fastdeploy.engine.request import Request, RequestType
+from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.spec_decode.mtp import MTPProposer
 
 
@@ -112,6 +113,30 @@ class TestMTPProposer(unittest.TestCase):
             "is_block_step": paddle.zeros([2], dtype="bool"),
             "actual_draft_token_num": paddle.zeros([2], dtype="int32"),
         }
+        self.update_model_inputs_rope()
+
+    def update_model_inputs_rope(self):
+        if self.fd_config.model_config.enable_mm:
+            self.target_model_inputs["rope_emb"] = paddle.full(
+                shape=[
+                    self.fd_config.scheduler_config.max_num_seqs,
+                    2,
+                    1,
+                    self.fd_config.model_config.max_model_len,
+                    1,
+                    self.fd_config.model_config.head_dim,
+                ],
+                fill_value=0,
+                dtype="float32",
+            )
+        else:
+            self.target_model_inputs["rope_emb"] = get_rope(
+                rotary_dim=self.fd_config.model_config.head_dim,
+                position_ids=paddle.arange(self.fd_config.model_config.max_model_len).reshape((1, -1)),
+                base=self.fd_config.model_config.rope_theta,
+                model_config=self.fd_config.model_config,
+                partial_rotary_factor=self.fd_config.model_config.partial_rotary_factor,
+            )
 
     @patch("fastdeploy.spec_decode.mtp.get_model_loader")
     @patch("fastdeploy.spec_decode.mtp.get_attention_backend")
@@ -604,6 +629,7 @@ class TestMTPProposer(unittest.TestCase):
         mock_rope.return_value = paddle.zeros([1, 2048, 64])
 
         self.fd_config.model_config.enable_mm = True
+        self.update_model_inputs_rope()
         proposer = MTPProposer(
             self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
         )