[Fix bug] w4afp8 的nblock固定为256,并且fa3的append attn 增加mask参数 (#3771)

* fix w4afp8

* 增加集中式配置

* codestyle

* fix fa3 append attn
This commit is contained in:
yangjianfengo1
2025-09-02 19:17:01 +08:00
committed by GitHub
parent b6a4115369
commit 8e1b35a09b
3 changed files with 4 additions and 5 deletions
@@ -75,12 +75,8 @@ void DisPatchW4AFp8Gemm(
const int64_t K,
cudaStream_t stream) {
int kBlockN = (max_tokens + 15) / 16 * 16;
int kBlockN = 256;
int TailN = 0;
if (kBlockN > 256) {
TailN = kBlockN % 256;
kBlockN = 256;
}
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
GEMM_SWITCH_BF16(
M, K, batch_size, token_padding_size, kBlockN, TailN,