Split enable_mm (#7183)

Co-authored-by: liuruian <liuruian@MacBook-Pro.local>
This commit is contained in:
K11OntheBoat
2026-04-08 11:25:41 +08:00
committed by GitHub
parent 8496ec71a6
commit bb48bcbaa2
33 changed files with 109 additions and 69 deletions
@@ -138,9 +138,7 @@ class AppendAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.causal: bool = getattr(fd_config.model_config, "causal", True)
@@ -136,7 +136,7 @@ class DSAAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
@@ -269,9 +269,7 @@ class FlashAttentionBackend(AttentionBackend):
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
# Note(ZKK): here must be consistent with append_attn_backend.py
@@ -123,9 +123,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
@@ -263,7 +263,7 @@ class MLAAttentionBackend(AttentionBackend):
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.rope_3d: bool = fd_config.enable_rope_3d_runtime
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None