[Cleanup] Replace torch proxy alias with public compat API (#7348)

This commit is contained in:
Nyako Shigure
2026-04-13 11:43:26 +08:00
committed by GitHub
parent cb03958b52
commit d659099415
20 changed files with 38 additions and 54 deletions
@@ -84,7 +84,7 @@ def init_flash_attn_version():
sm_version = get_sm_version()
if sm_version >= 100:
try:
paddle.compat.enable_torch_proxy(scope={"cutlass"})
paddle.enable_compat(scope={"cutlass"})
from flash_mask.cute.interface import flashmask_attention as fa4
global flashmask_attention_v4
@@ -18,7 +18,7 @@ from __future__ import annotations
import paddle
paddle.enable_compat(scope={"flash_mla"}) # Enable torch proxy before importing flash_mla
paddle.enable_compat(scope={"flash_mla"}) # Enable compat mode before importing flash_mla
import math
import os
from dataclasses import dataclass, field
@@ -805,9 +805,9 @@ def enable_batch_invariant_mode():
if _batch_invariant_MODE:
return
if hasattr(paddle, "compat") and hasattr(paddle.compat, "enable_torch_proxy"):
paddle.compat.enable_torch_proxy()
# TODO(liujundong): Enabling torch proxy here has a global effect.
if hasattr(paddle, "enable_compat"):
paddle.enable_compat()
# TODO(liujundong): Enabling paddle.enable_compat() here has a global effect.
# Do NOT call this function from module import time,
# otherwise it may affect other test cases during pytest collection.
# (ex: Could not import module 'PretrainedTokenizer' or No module named 'paddle.distributed.tensor')
+2 -2
View File
@@ -39,8 +39,8 @@ def load_deep_ep() -> ModuleType:
try:
if envs.FD_USE_PFCC_DEEP_EP:
# Enable torch proxy before importing deep_ep (required by PFCC/PaddleFleet variants)
paddle.compat.enable_torch_proxy(scope={"deep_ep"})
# Enable compat mode before importing deep_ep (required by PFCC/PaddleFleet variants)
paddle.enable_compat(scope={"deep_ep"})
try:
import paddlefleet.ops.deep_ep as deep_ep # type: ignore
@@ -68,7 +68,7 @@ def load_deep_gemm():
if current_platform.is_cuda():
if get_sm_version() >= 100:
# SM100 should use PFCC DeepGemm
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
paddle.enable_compat(scope={"deep_gemm"})
try:
import logging
@@ -35,7 +35,7 @@ from fastdeploy.utils import get_logger
from ..moe import FusedMoE
from .quant_base import QuantConfigBase, QuantMethodBase
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
paddle.enable_compat(scope={"flashinfer"})
logger = get_logger("config", "config.log")
@@ -38,7 +38,7 @@ from .quant_base import QuantConfigBase, QuantMethodBase, is_nvfp4_supported
# Only import flashinfer on supported GPUs (Blackwell-class cards)
if is_nvfp4_supported():
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
paddle.enable_compat(scope={"flashinfer"})
from flashinfer import fp4_quantize, mm_fp4
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe