mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
This reverts commit c5eb6b65e7.
This commit is contained in:
@@ -22,7 +22,6 @@ from utils import FakeModelConfig, get_default_test_fd_config
|
||||
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.engine.request import Request, RequestType
|
||||
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
|
||||
from fastdeploy.spec_decode.mtp import MTPProposer
|
||||
|
||||
|
||||
@@ -113,30 +112,6 @@ class TestMTPProposer(unittest.TestCase):
|
||||
"is_block_step": paddle.zeros([2], dtype="bool"),
|
||||
"actual_draft_token_num": paddle.zeros([2], dtype="int32"),
|
||||
}
|
||||
self.update_model_inputs_rope()
|
||||
|
||||
def update_model_inputs_rope(self):
    """Refresh ``self.target_model_inputs["rope_emb"]`` from ``self.fd_config``.

    With ``model_config.enable_mm`` set, the entry becomes a zero-filled
    float32 tensor of shape
    ``[max_num_seqs, 2, 1, max_model_len, 1, head_dim]``. Otherwise it is
    whatever ``get_rope`` returns for position ids ``0 .. max_model_len - 1``
    reshaped to a single row.
    """
    model_cfg = self.fd_config.model_config

    if not model_cfg.enable_mm:
        # Non-multimodal path: delegate to the rotary-embedding helper.
        self.target_model_inputs["rope_emb"] = get_rope(
            rotary_dim=model_cfg.head_dim,
            position_ids=paddle.arange(model_cfg.max_model_len).reshape((1, -1)),
            base=model_cfg.rope_theta,
            model_config=model_cfg,
            partial_rotary_factor=model_cfg.partial_rotary_factor,
        )
        return

    # Multimodal path: a zero placeholder sized from scheduler/model limits.
    rope_shape = [
        self.fd_config.scheduler_config.max_num_seqs,
        2,
        1,
        model_cfg.max_model_len,
        1,
        model_cfg.head_dim,
    ]
    self.target_model_inputs["rope_emb"] = paddle.full(
        shape=rope_shape,
        fill_value=0,
        dtype="float32",
    )
|
||||
|
||||
@patch("fastdeploy.spec_decode.mtp.get_model_loader")
|
||||
@patch("fastdeploy.spec_decode.mtp.get_attention_backend")
|
||||
@@ -629,7 +604,6 @@ class TestMTPProposer(unittest.TestCase):
|
||||
mock_rope.return_value = paddle.zeros([1, 2048, 64])
|
||||
|
||||
self.fd_config.model_config.enable_mm = True
|
||||
self.update_model_inputs_rope()
|
||||
proposer = MTPProposer(
|
||||
self.fd_config, self.main_model, self.local_rank, self.device_id, self.target_model_inputs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user