mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Cleanup] Replace torch proxy alias with public compat API (#7348)
This commit is contained in:
@@ -84,7 +84,7 @@ def init_flash_attn_version():
|
||||
sm_version = get_sm_version()
|
||||
if sm_version >= 100:
|
||||
try:
|
||||
paddle.compat.enable_torch_proxy(scope={"cutlass"})
|
||||
paddle.enable_compat(scope={"cutlass"})
|
||||
from flash_mask.cute.interface import flashmask_attention as fa4
|
||||
|
||||
global flashmask_attention_v4
|
||||
|
||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import paddle
|
||||
|
||||
paddle.enable_compat(scope={"flash_mla"}) # Enable torch proxy before importing flash_mla
|
||||
paddle.enable_compat(scope={"flash_mla"}) # Enable paddle.enable_compat before importing flash_mla
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -805,9 +805,9 @@ def enable_batch_invariant_mode():
|
||||
if _batch_invariant_MODE:
|
||||
return
|
||||
|
||||
if hasattr(paddle, "compat") and hasattr(paddle.compat, "enable_torch_proxy"):
|
||||
paddle.compat.enable_torch_proxy()
|
||||
# TODO(liujundong): Enabling torch proxy here has a global effect.
|
||||
if hasattr(paddle, "enable_compat"):
|
||||
paddle.enable_compat()
|
||||
# TODO(liujundong): Enabling paddle.enable_compat() here has a global effect.
|
||||
# Do NOT call this function from module import time,
|
||||
# otherwise it may affect other test cases during pytest collection.
|
||||
# (ex: Could not import module 'PretrainedTokenizer' or No module named 'paddle.distributed.tensor')
|
||||
|
||||
@@ -39,8 +39,8 @@ def load_deep_ep() -> ModuleType:
|
||||
|
||||
try:
|
||||
if envs.FD_USE_PFCC_DEEP_EP:
|
||||
# Enable torch proxy before importing deep_ep (required by PFCC/PaddleFleet variants)
|
||||
paddle.compat.enable_torch_proxy(scope={"deep_ep"})
|
||||
# Enable paddle.enable_compat before importing deep_ep (required by PFCC/PaddleFleet variants)
|
||||
paddle.enable_compat(scope={"deep_ep"})
|
||||
try:
|
||||
import paddlefleet.ops.deep_ep as deep_ep # type: ignore
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ def load_deep_gemm():
|
||||
if current_platform.is_cuda():
|
||||
if get_sm_version() >= 100:
|
||||
# SM100 should use PFCC DeepGemm
|
||||
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
|
||||
paddle.enable_compat(scope={"deep_gemm"})
|
||||
try:
|
||||
import logging
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from fastdeploy.utils import get_logger
|
||||
from ..moe import FusedMoE
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
|
||||
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
|
||||
paddle.enable_compat(scope={"flashinfer"})
|
||||
|
||||
logger = get_logger("config", "config.log")
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from .quant_base import QuantConfigBase, QuantMethodBase, is_nvfp4_supported
|
||||
|
||||
# Only import flashinfer on supported GPUs (B卡)
|
||||
if is_nvfp4_supported():
|
||||
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
|
||||
paddle.enable_compat(scope={"flashinfer"})
|
||||
|
||||
from flashinfer import fp4_quantize, mm_fp4
|
||||
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
|
||||
|
||||
Reference in New Issue
Block a user