[Feature][BugFix][OP] Enhance Deterministic Inference Mode with Kernel-level Fixes and Batch-invariant BMM (#6610)

* add fa deter

* add ut

* add long sentence

* fix basic

* fix bugs

* fix adn

* fix first

* fix single

* fix single

* fix single test

* refine

* add more test

* refine comments

* add comments of bmm

* fix ci

* remove probe

* add

* remove not need

* refine tests

* fix comments and refine code

* refine code

* refine test

* refine test

* mv 4cards tests

* fix tests

* add

* fix comments

* fix cover

* fix cover

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-09 10:27:53 +08:00
committed by GitHub
parent 3a85ecf3bc
commit 30f9f33f34
23 changed files with 3563 additions and 153 deletions
@@ -3,6 +3,7 @@ from .batch_invariant_ops import (
disable_batch_invariant_mode,
enable_batch_invariant_mode,
get_batch_invariant_attention_block_size,
init_deterministic_mode,
is_batch_invariant_mode_enabled,
log_softmax,
matmul_persistent,
@@ -17,6 +18,7 @@ __all__ = [
"is_batch_invariant_mode_enabled",
"disable_batch_invariant_mode",
"enable_batch_invariant_mode",
"init_deterministic_mode",
"matmul_persistent",
"log_softmax",
"mean_dim",
@@ -466,12 +466,201 @@ def mean_dim(
return output
# The bmm_kernel_persistent kernel and bmm_persistent wrapper below are adapted from
# SGLang (https://github.com/sgl-project/sglang), licensed under Apache License 2.0.
# Original source:
# sglang/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
# which itself was adapted from:
# https://github.com/thinking-machines-lab/batch_invariant_ops
# We thank the SGLang authors and the Thinking Machines Lab for their contributions.
@triton.jit # pragma: no cover
def bmm_kernel_persistent(
a_ptr,
b_ptr,
c_ptr, #
B,
M,
N,
K, #
stride_ab,
stride_am,
stride_ak,
stride_bb,
stride_bk,
stride_bn,
stride_cb,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr, #
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
NUM_SMS: tl.constexpr, #
A_LARGE: tl.constexpr,
B_LARGE: tl.constexpr,
C_LARGE: tl.constexpr,
):
"""
Batched matrix multiplication kernel that processes batches in parallel.
Each tile processes a (BLOCK_SIZE_M, BLOCK_SIZE_N) output block for a specific batch.
Uses persistent kernel approach with fixed tile traversal order for determinism.
"""
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles_per_batch = num_pid_m * num_pid_n
num_tiles_total = B * num_tiles_per_batch
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Process tiles in a deterministic order: batch-major ordering
for tile_id in tl.range(start_pid, num_tiles_total, NUM_SMS, flatten=True):
# Decompose tile_id into batch and within-batch tile
batch_idx = tile_id // num_tiles_per_batch
tile_in_batch = tile_id % num_tiles_per_batch
pid_m, pid_n = _compute_pid(tile_in_batch, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
if A_LARGE:
offs_am = offs_am.to(tl.int64)
if B_LARGE:
offs_bn = offs_bn.to(tl.int64)
offs_am = tl.where(offs_am < M, offs_am, 0)
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
# Add batch offset
if A_LARGE or B_LARGE:
batch_idx_typed = batch_idx.to(tl.int64)
else:
batch_idx_typed = batch_idx
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
if A_LARGE or B_LARGE:
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
else:
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (batch_idx_typed * stride_ab + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (batch_idx_typed * stride_bb + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if C_LARGE:
offs_cm = offs_cm.to(tl.int64)
offs_cn = offs_cn.to(tl.int64)
c_ptrs = c_ptr + batch_idx_typed * stride_cb + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.store(c_ptrs, c, mask=c_mask)
def bmm_persistent(a: paddle.Tensor, b: paddle.Tensor) -> paddle.Tensor:
"""Batch-invariant batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)"""
assert a.ndim == 3 and b.ndim == 3, f"bmm_persistent expects 3D tensors, got shapes {a.shape} and {b.shape}"
assert a.shape[0] == b.shape[0], "Batch sizes must match"
assert a.shape[2] == b.shape[1], "Incompatible dimensions"
assert a.dtype == b.dtype, f"Incompatible dtypes: a={a.dtype}, b={b.dtype}"
B = a.shape[0]
M = a.shape[1]
K = a.shape[2]
N = b.shape[2]
dtype = a.dtype
NUM_SMS = get_compute_units()
c = paddle.empty((B, M, N), dtype=dtype)
configs = {
paddle.bfloat16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
paddle.float16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
paddle.float32: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
}
config = configs.get(dtype)
if config is None:
raise ValueError(
f"Unsupported dtype {dtype} for bmm_persistent. " f"Supported dtypes are: {list(configs.keys())}"
)
num_tiles_per_batch = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(N, config["BLOCK_SIZE_N"])
num_tiles_total = B * num_tiles_per_batch
grid = (min(NUM_SMS, num_tiles_total),)
bmm_kernel_persistent[grid](
a,
b,
c, #
B,
M,
N,
K, #
a.stride(0),
a.stride(1),
a.stride(2), #
b.stride(0),
b.stride(1),
b.stride(2), #
c.stride(0),
c.stride(1),
c.stride(2), #
NUM_SMS=NUM_SMS, #
# Use element counts instead of numel() to avoid cudaErrorStreamCaptureImplicit
# during CUDA Graph capture
A_LARGE=int(B * M * K > 2**31),
B_LARGE=int(B * K * N > 2**31),
C_LARGE=int(B * M * N > 2**31),
**config,
)
return c
def bmm_batch_invariant(x, y):
"""Drop-in replacement for paddle._C_ops.bmm"""
return bmm_persistent(x, y)
def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False, out=None):
if transpose_x:
a = a.T
if transpose_y:
b = b.T
return matmul_persistent(a, b)
result = matmul_persistent(a, b)
return result
def addmm_batch_invariant(
@@ -522,7 +711,7 @@ def mean_batch_invariant(
return result
_original_ops = {"mm": None, "addmm": None, "_log_softmax": None, "mean_dim": None}
_original_ops = {"mm": None, "addmm": None, "_log_softmax": None, "mean_dim": None, "bmm": None}
_batch_invariant_MODE = False
@@ -552,15 +741,26 @@ def enable_batch_invariant_mode():
_original_ops["addmm"] = paddle._C_ops.addmm
_original_ops["log_softmax"] = paddle._C_ops.log_softmax
_original_ops["mean"] = paddle._C_ops.mean
_original_ops["bmm"] = paddle._C_ops.bmm
paddle._C_ops.matmul = mm_batch_invariant
paddle._C_ops.addmm = addmm_batch_invariant
paddle._C_ops.log_softmax = _log_softmax_batch_invariant
paddle._C_ops.mean = mean_batch_invariant
paddle._C_ops.bmm = bmm_batch_invariant
_batch_invariant_MODE = True
def init_deterministic_mode():
"""One-stop initialization for deterministic mode.
Call after worker creation but before model loading.
"""
if not is_batch_invariant_mode_enabled():
enable_batch_invariant_mode()
def disable_batch_invariant_mode():
global _batch_invariant_MODE, _original_ops
if not _batch_invariant_MODE:
@@ -574,6 +774,8 @@ def disable_batch_invariant_mode():
paddle._C_ops.log_softmax = _original_ops["log_softmax"]
if _original_ops["mean"]:
paddle._C_ops.mean = _original_ops["mean"]
if _original_ops["bmm"]:
paddle._C_ops.bmm = _original_ops["bmm"]
_batch_invariant_MODE = False