Split enable_mm (#7183)

Co-authored-by: liuruian <liuruian@MacBook-Pro.local>
This commit is contained in:
K11OntheBoat
2026-04-08 11:25:41 +08:00
committed by GitHub
parent 8496ec71a6
commit bb48bcbaa2
33 changed files with 109 additions and 69 deletions
@@ -138,9 +138,7 @@ class AppendAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.causal: bool = getattr(fd_config.model_config, "causal", True)
@@ -136,7 +136,7 @@ class DSAAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
@@ -269,9 +269,7 @@ class FlashAttentionBackend(AttentionBackend):
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
# Note(ZKK): here must be consistent with append_attn_backend.py
@@ -123,9 +123,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
@@ -263,7 +263,7 @@ class MLAAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
@@ -89,7 +89,7 @@ class MhaAttnBackend(AttentionBackend):
# note: scale need to change if using MLA
self.scale = 1.0 / sqrt(head_dim)
self.dtype = paddle.get_default_dtype()
self.enable_mm = fd_config.model_config.enable_mm
self.enable_mm = fd_config.enable_mm_runtime
self.rope_batch_stride = self.max_context_len * self.head_dim if self.enable_mm else 0
if "paddleocr" in fd_config.model_config.model_type:
self.is_interleaved_rope_mode = False
@@ -219,7 +219,7 @@ class HPUAttentionBackend(AttentionBackend_HPU):
self.block_size = llm_config.cache_config.block_size
self.max_seq_len = llm_config.model_config.max_model_len
self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta
self.rope_3d = getattr(llm_config.model_config, "rope_3d", False)
self.rope_3d = llm_config.enable_rope_3d_runtime
self.causal = getattr(llm_config.model_config, "causal", True)
self.speculative_method = llm_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
@@ -101,7 +101,7 @@ class FlashAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
@@ -128,7 +128,7 @@ class FlashAttentionBackend(AttentionBackend):
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.enable_mm = fd_config.model_config.enable_mm
self.enable_mm = fd_config.enable_mm_runtime
self.model_type = fd_config.model_config.model_type
self.is_neox_style = False
if "paddleocr" in fd_config.model_config.model_type:
@@ -105,7 +105,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
@@ -88,9 +88,7 @@ class XPUAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)