mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Feature][BugFix][OP] Enhance Deterministic Inference Mode with Kernel-level Fixes and Batch-invariant BMM (#6610)
* add fa deter * add ut * add long sentence * fix basic * fix bugs * fix adn * fix first * fix single * fix single * fix single test * refine * add more test * refine comments * add comments of bmm * fix ci * remove probe * add * remove not need * refine tests * fix comments and refine code * refine code * refine test * refine test * mv 4cards tests * fix tests * add * fix comments * fix cover * fix cover --------- Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from .batch_invariant_ops import (
|
||||
disable_batch_invariant_mode,
|
||||
enable_batch_invariant_mode,
|
||||
get_batch_invariant_attention_block_size,
|
||||
init_deterministic_mode,
|
||||
is_batch_invariant_mode_enabled,
|
||||
log_softmax,
|
||||
matmul_persistent,
|
||||
@@ -17,6 +18,7 @@ __all__ = [
|
||||
"is_batch_invariant_mode_enabled",
|
||||
"disable_batch_invariant_mode",
|
||||
"enable_batch_invariant_mode",
|
||||
"init_deterministic_mode",
|
||||
"matmul_persistent",
|
||||
"log_softmax",
|
||||
"mean_dim",
|
||||
|
||||
@@ -466,12 +466,201 @@ def mean_dim(
|
||||
return output
|
||||
|
||||
|
||||
# The bmm_kernel_persistent kernel and bmm_persistent wrapper below are adapted from
|
||||
# SGLang (https://github.com/sgl-project/sglang), licensed under Apache License 2.0.
|
||||
# Original source:
|
||||
# sglang/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
|
||||
# which itself was adapted from:
|
||||
# https://github.com/thinking-machines-lab/batch_invariant_ops
|
||||
# We thank the SGLang authors and the Thinking Machines Lab for their contributions.
|
||||
|
||||
|
||||
@triton.jit  # pragma: no cover
def bmm_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    B,
    M,
    N,
    K,  #
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    stride_cb,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
):
    """
    Batched matrix multiplication kernel that processes batches in parallel.
    Each tile processes a (BLOCK_SIZE_M, BLOCK_SIZE_N) output block for a specific batch.
    Uses persistent kernel approach with fixed tile traversal order for determinism.

    A_LARGE / B_LARGE / C_LARGE flag that the corresponding tensor has more than
    2**31 elements, so its offsets must be widened to int64 before pointer math.
    """
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles_per_batch = num_pid_m * num_pid_n
    num_tiles_total = B * num_tiles_per_batch

    # Mask offsets stay int32: they are only compared against K, never used
    # for addressing.
    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    # Process tiles in a deterministic order: batch-major ordering.
    # Each program strides through the global tile list by NUM_SMS, so the
    # tile -> program assignment is a pure function of the launch shape.
    for tile_id in tl.range(start_pid, num_tiles_total, NUM_SMS, flatten=True):
        # Decompose tile_id into batch and within-batch tile
        batch_idx = tile_id // num_tiles_per_batch
        tile_in_batch = tile_id % num_tiles_per_batch

        # _compute_pid maps the flat within-batch tile index to (pid_m, pid_n)
        # using GROUP_SIZE_M swizzling (defined elsewhere in this module).
        pid_m, pid_n = _compute_pid(tile_in_batch, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        # Clamp out-of-range rows/cols to 0 (loads stay in-bounds; the store
        # mask below discards the duplicated results).
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        # Add batch offset
        if A_LARGE or B_LARGE:
            batch_idx_typed = batch_idx.to(tl.int64)
        else:
            batch_idx_typed = batch_idx

        # Accumulate in float32 regardless of input dtype.
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

            a_ptrs = a_ptr + (batch_idx_typed * stride_ab + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (batch_idx_typed * stride_bb + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

            # Only the K dimension needs masking here; M/N were clamped above.
            a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)

        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = c_ptr + batch_idx_typed * stride_cb + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        # Cast the float32 accumulator back to the output element type.
        c = accumulator.to(c_ptr.dtype.element_ty)
        tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def bmm_persistent(a: paddle.Tensor, b: paddle.Tensor) -> paddle.Tensor:
    """Batch-invariant batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N).

    Launches ``bmm_kernel_persistent`` on a grid capped at one program per SM,
    so tiles are always consumed in the same fixed order and the reduction is
    deterministic regardless of batch composition.
    """
    assert a.ndim == 3 and b.ndim == 3, f"bmm_persistent expects 3D tensors, got shapes {a.shape} and {b.shape}"
    assert a.shape[0] == b.shape[0], "Batch sizes must match"
    assert a.shape[2] == b.shape[1], "Incompatible dimensions"
    assert a.dtype == b.dtype, f"Incompatible dtypes: a={a.dtype}, b={b.dtype}"

    batch_count = a.shape[0]
    m_dim = a.shape[1]
    k_dim = a.shape[2]
    n_dim = b.shape[2]
    dtype = a.dtype

    sm_count = get_compute_units()
    out = paddle.empty((batch_count, m_dim, n_dim), dtype=dtype)

    # Fixed (not autotuned) per-dtype tile shapes and launch parameters, so
    # the tile decomposition -- and therefore the result -- never varies.
    tile_configs = {
        paddle.bfloat16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        paddle.float16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        paddle.float32: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
    }

    if dtype not in tile_configs:
        raise ValueError(
            f"Unsupported dtype {dtype} for bmm_persistent. " f"Supported dtypes are: {list(tile_configs.keys())}"
        )
    cfg = tile_configs[dtype]

    tiles_per_batch = triton.cdiv(m_dim, cfg["BLOCK_SIZE_M"]) * triton.cdiv(n_dim, cfg["BLOCK_SIZE_N"])
    total_tiles = batch_count * tiles_per_batch
    launch_grid = (min(sm_count, total_tiles),)

    bmm_kernel_persistent[launch_grid](
        a,
        b,
        out,  #
        batch_count,
        m_dim,
        n_dim,
        k_dim,  #
        a.stride(0),
        a.stride(1),
        a.stride(2),  #
        b.stride(0),
        b.stride(1),
        b.stride(2),  #
        out.stride(0),
        out.stride(1),
        out.stride(2),  #
        NUM_SMS=sm_count,  #
        # Use element counts instead of numel() to avoid cudaErrorStreamCaptureImplicit
        # during CUDA Graph capture
        A_LARGE=int(batch_count * m_dim * k_dim > 2**31),
        B_LARGE=int(batch_count * k_dim * n_dim > 2**31),
        C_LARGE=int(batch_count * m_dim * n_dim > 2**31),
        **cfg,
    )
    return out
|
||||
|
||||
|
||||
def bmm_batch_invariant(x, y):
    """Drop-in replacement for paddle._C_ops.bmm.

    Routes the batched matmul through the deterministic persistent kernel.
    """
    result = bmm_persistent(x, y)
    return result
|
||||
|
||||
|
||||
def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False, out=None):
    """Drop-in replacement for paddle._C_ops.matmul using matmul_persistent.

    Args:
        a: Left operand.
        b: Right operand.
        transpose_x: If True, the left operand is transposed (via ``.T``)
            before the multiply.
        transpose_y: If True, the right operand is transposed before the
            multiply.
        out: Accepted for signature compatibility but ignored; the result is
            always a freshly allocated tensor.
            # NOTE(review): confirm no caller relies on writing into ``out``.

    Returns:
        ``matmul_persistent(a, b)`` after the requested transposes.
    """
    if transpose_x:
        a = a.T
    if transpose_y:
        b = b.T
    # Fix: removed the unreachable duplicate
    # ``result = matmul_persistent(a, b); return result`` that followed this
    # return statement (dead code).
    return matmul_persistent(a, b)
|
||||
|
||||
|
||||
def addmm_batch_invariant(
|
||||
@@ -522,7 +711,7 @@ def mean_batch_invariant(
|
||||
return result
|
||||
|
||||
|
||||
# Original paddle._C_ops entry points, saved by enable_batch_invariant_mode()
# so disable_batch_invariant_mode() can restore them.
# Fixes: the previous code assigned this dict twice (the first literal was
# immediately overwritten), and initialized it with keys "_log_softmax" /
# "mean_dim" while the enable/disable code actually reads and writes
# "log_softmax" / "mean" -- the keys below match the real usage.
_original_ops = {"mm": None, "addmm": None, "log_softmax": None, "mean": None, "bmm": None}

# True while the batch-invariant op patches are installed.
_batch_invariant_MODE = False
||||
@@ -552,15 +741,26 @@ def enable_batch_invariant_mode():
|
||||
_original_ops["addmm"] = paddle._C_ops.addmm
|
||||
_original_ops["log_softmax"] = paddle._C_ops.log_softmax
|
||||
_original_ops["mean"] = paddle._C_ops.mean
|
||||
_original_ops["bmm"] = paddle._C_ops.bmm
|
||||
|
||||
paddle._C_ops.matmul = mm_batch_invariant
|
||||
paddle._C_ops.addmm = addmm_batch_invariant
|
||||
paddle._C_ops.log_softmax = _log_softmax_batch_invariant
|
||||
paddle._C_ops.mean = mean_batch_invariant
|
||||
paddle._C_ops.bmm = bmm_batch_invariant
|
||||
|
||||
_batch_invariant_MODE = True
|
||||
|
||||
|
||||
def init_deterministic_mode():
    """One-stop initialization for deterministic mode.

    Call after worker creation but before model loading.
    """
    # Idempotent: a second call is a no-op once the mode is already active.
    if is_batch_invariant_mode_enabled():
        return
    enable_batch_invariant_mode()
|
||||
def disable_batch_invariant_mode():
|
||||
global _batch_invariant_MODE, _original_ops
|
||||
if not _batch_invariant_MODE:
|
||||
@@ -574,6 +774,8 @@ def disable_batch_invariant_mode():
|
||||
paddle._C_ops.log_softmax = _original_ops["log_softmax"]
|
||||
if _original_ops["mean"]:
|
||||
paddle._C_ops.mean = _original_ops["mean"]
|
||||
if _original_ops["bmm"]:
|
||||
paddle._C_ops.bmm = _original_ops["bmm"]
|
||||
|
||||
_batch_invariant_MODE = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user