[SOT] Change warnings to errors and remove fallback operations (#4378)

* Change warnings to errors and remove fallback operations

* fix unittest

* fix code style
Author: Ryan
Date:   2025-10-17 11:27:04 +08:00 (committed by GitHub)
Parent: 0413c32b8f
Commit: 6160145f82
3 changed files with 27 additions and 28 deletions
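The net effect: during SOT warmup, conditions that previously logged a warning and silently fell back to the dynamic forward (multiple generated graphs, or a break graph in the generated code) now raise a `RuntimeError`. A minimal sketch of that contract change, with a hypothetical `compile_to_graphs` standing in for SOT's symbolic translation:

```python
# Sketch only: `compile_to_graphs` is a made-up stand-in for SOT's symbolic
# translation; just the control flow mirrors the hunks below.
def warmup(forward_fn, compile_to_graphs):
    graphs = compile_to_graphs(forward_fn)
    # Before this commit: logger.warning(...) and fall back to dynamic mode.
    # After: fail fast so a misconfigured model cannot silently run unoptimized.
    if len(graphs) > 1:
        raise RuntimeError("multiple generated graphs; mark all dynamic dims")
    return graphs[0]
```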
@@ -21,7 +21,6 @@ from typing import Callable, Optional, TypeVar, get_type_hints
 from paddle.jit import sot
 from paddle.jit.dy2static.utils import Backend as ToStaticBackend
-from paddleformers.utils.log import logger
 from typing_extensions import ParamSpec
 
 from fastdeploy.config import FDConfig
@@ -46,11 +45,10 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
     forward_type_hints = get_type_hints(forward_fn)
     static_forward_fn = sot.symbolic_translate(forward_fn, training=False, backend=backend)
     unsafe_static_forward_fn = None
-    need_warmup = True
 
     @functools.wraps(forward_fn)
     def warmup_impl(self, *args, **kwargs):
-        nonlocal unsafe_static_forward_fn, need_warmup
+        nonlocal unsafe_static_forward_fn
         bound_args = forward_sig.bind(self, *args, **kwargs)
         bound_args.apply_defaults()
         for name, arg in bound_args.arguments.items():
@@ -66,17 +64,11 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
         ]
         # Check there is only one generated graph
         if len(new_guarded_codes) > 1:
-            logger.warning("Model has multiple generated codes, please check that all dynamic dims are marked.")
-            unsafe_static_forward_fn = None
-            need_warmup = False
-            return result
+            raise RuntimeError("Model has multiple generated codes, please check that all dynamic dims are marked.")
         # Check the generated code has no break graph
         new_code = new_guarded_codes[0][0][0]
         if any(name.startswith("$") for name in new_code.co_names):  # TODO(SigureMo): It's an internal impl
-            logger.warning("Model has a break graph, please set env SOT_LOG_LEVEL=3 to check it.")
-            unsafe_static_forward_fn = None
-            need_warmup = False
-            return result
+            raise RuntimeError("Model has a break graph, please set env SOT_LOG_LEVEL=3 to check it.")
         unsafe_static_forward_fn = types.FunctionType(
             new_code,
             forward_fn.__globals__,
@@ -88,15 +80,12 @@ def apply_to_static_optimization(fn: Callable[P, T], backend: ToStaticBackend) -
     @functools.wraps(forward_fn)
     def static_forward(self, *args, **kwargs):
+        nonlocal unsafe_static_forward_fn
         if in_profile_run_mode():
             return forward_fn(self, *args, **kwargs)
-        nonlocal need_warmup
-        is_warmup = in_warmup_mode() and need_warmup
-        if is_warmup:
+        if in_warmup_mode():
             return warmup_impl(self, *args, **kwargs)
-        nonlocal unsafe_static_forward_fn
-        if unsafe_static_forward_fn is None:
-            return static_forward_fn(self, *args, **kwargs)
+        assert unsafe_static_forward_fn is not None
         return unsafe_static_forward_fn(self, *args, **kwargs)
 
     return static_forward
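With the fallback gone, `static_forward` has exactly three paths: profiling runs stay dynamic, warmup calls compile (or raise), and every steady-state call must hit the compiled function. A self-contained toy of that dispatch, using hypothetical mode-check callables rather than the fastdeploy helpers:

```python
import functools

def to_static_fail_fast(forward_fn, compile_fn, in_profile_run_mode, in_warmup_mode):
    """Toy version of the rewritten dispatch; all four arguments are hypothetical."""
    compiled = None

    @functools.wraps(forward_fn)
    def static_forward(*args, **kwargs):
        nonlocal compiled
        if in_profile_run_mode():   # profiling always runs the dynamic forward
            return forward_fn(*args, **kwargs)
        if in_warmup_mode():        # warmup compiles (or raises) and runs dynamically
            compiled = compile_fn(forward_fn)
            return forward_fn(*args, **kwargs)
        # No fallback left: warmup must already have produced a compiled function.
        assert compiled is not None, "run a warmup pass before steady-state calls"
        return compiled(*args, **kwargs)

    return static_forward
```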
@@ -32,6 +32,7 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
     support_graph_optimization,
 )
+from fastdeploy.model_executor.graph_optimization.utils import sot_warmup_guard
 
 @support_graph_optimization
@@ -46,7 +47,7 @@ class Attention(nn.Layer):
     def forward(
         self,
-        ids_remove_padding,
+        ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
     ):
         hidden_states = self.embed_tokens(forward_meta.ids_remove_padding)
@@ -58,7 +59,7 @@ class Attention(nn.Layer):
     def forward_dynamic(
         self,
-        ids_remove_padding,
+        ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
     ):
         hidden_states = self.embed_tokens(forward_meta.ids_remove_padding)
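Note that the decorator calls `get_type_hints(forward_fn)` in the first hunk, so the `paddle.Tensor` annotations added here are visible to it; presumably this is how tensor inputs are identified during warmup, though the diff does not spell that out. The annotation style the tests adopt:

```python
import paddle
from fastdeploy.model_executor.forward_meta import ForwardMeta

# get_type_hints(forward) now reports
# {"ids_remove_padding": paddle.Tensor, "forward_meta": ForwardMeta}.
def forward(self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta):
    ...
```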
@@ -164,15 +165,23 @@ class TestGraphOptBackend(unittest.TestCase):
         """
         test_model = Attention(fd_config=fd_config, **self.model_config)
+        with sot_warmup_guard(True):
+            _ = test_model(ids_remove_padding=self.input_tensor, forward_meta=self.forward_meta)
 
         # Run model test
         output = test_model(ids_remove_padding=self.input_tensor, forward_meta=self.forward_meta)
 
         # Validate results if comparison is requested
         if compare_with_baseline:
             np.testing.assert_allclose(
-                self.baseline_result, output.numpy(), err_msg=f"Test {test_name} failed: output mismatch"
+                self.baseline_result,
+                output.numpy(),
+                err_msg=f"Test {test_name} failed: output mismatch",
+                atol=1e-6,  # for CINN
             )
+        paddle.jit.sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().clear()
 
     def test_dynamic_graph(self):
         """Test dynamic graph mode"""
         fd_config = self._setup_test_config(graph_opt_level=0, use_cudagraph=False)
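The warmup pass is now explicit in the tests, and the opcode cache is cleared after each case. Usage pattern as the diff applies it (the fixtures `test_model`, `input_tensor`, and `forward_meta` are assumed from the test class):

```python
import paddle
from fastdeploy.model_executor.graph_optimization.utils import sot_warmup_guard

# Warm up once under the guard so the decorator takes the warmup path,
# then call normally to exercise the compiled function.
with sot_warmup_guard(True):
    _ = test_model(ids_remove_padding=input_tensor, forward_meta=forward_meta)
output = test_model(ids_remove_padding=input_tensor, forward_meta=forward_meta)

# Clear SOT's code cache between parametrized cases so one case's compiled
# graphs cannot leak into the next.
paddle.jit.sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().clear()
```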
@@ -14,17 +14,17 @@
 # limitations under the License.
 """
-import os
-
-os.environ["FLAGS_cuda_graph_blacklist"] = "pd_op.matmul,pd_op.transpose"
 import unittest
 from unittest.mock import Mock
 
 import paddle
 import paddle.nn as nn
 
+from fastdeploy.model_executor.graph_optimization.utils import sot_warmup_guard
+
+paddle.set_flags({"FLAGS_cuda_graph_blacklist": "pd_op.matmul,pd_op.transpose"})
 from fastdeploy.config import (
     CacheConfig,
     FDConfig,
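The blacklist flag now goes through Paddle's flag API instead of the process environment. `FLAGS_*` environment variables are generally read when Paddle initializes, so once `paddle` is already imported, `paddle.set_flags` is the reliable way to change a flag. A small sketch:

```python
import paddle

# Same intent as the removed os.environ assignment, but applied to the
# already-initialized runtime rather than relying on import-time pickup.
paddle.set_flags({"FLAGS_cuda_graph_blacklist": "pd_op.matmul,pd_op.transpose"})

# Optional sanity check that the value took effect.
print(paddle.get_flags(["FLAGS_cuda_graph_blacklist"]))
```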
@@ -77,10 +77,10 @@ class TestModel(nn.Layer):
         super().__init__()
         self.model = Attention(fd_config)
 
-    def forward(self, ids_remove_padding, forward_meta: ForwardMeta):
+    def forward(self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta):
         return self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
 
-    def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta):
+    def forward_correct(self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta):
         return self.model.forward_dynamic(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
@@ -111,7 +111,8 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
         forward_meta1 = ForwardMeta(input_ids=x, ids_remove_padding=x, step_use_cudagraph=True)
 
         # Trigger Capture
-        _ = test_model1(x, forward_meta=forward_meta1)
+        with sot_warmup_guard(True):
+            _ = test_model1(x, forward_meta=forward_meta1)
 
         # Replay
         _ = test_model1(x, forward_meta=forward_meta1)
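The resulting capture/replay lifecycle, with names taken from the hunk above (the surrounding fixtures are assumed from the test):

```python
# Phase 1: warmup + capture. sot_warmup_guard(True) routes this call through
# the warmup path, which compiles the graph or raises; with
# step_use_cudagraph=True the same call also records the CUDA graph.
with sot_warmup_guard(True):
    _ = test_model1(x, forward_meta=forward_meta1)

# Phase 2: replay. Later identical calls replay the captured graph and must
# hit the compiled function, since the dynamic fallback no longer exists.
_ = test_model1(x, forward_meta=forward_meta1)
```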