[BugFix] fix w4afp8 tp=8 (#5868)

* fix w4afp8 tp=8

* fix
This commit is contained in:
lizexu123
2026-01-05 18:59:02 +08:00
committed by GitHub
parent 6f14b180e3
commit 1d3ae7c024
2 changed files with 10 additions and 1 deletions
@@ -39,10 +39,18 @@ void MoeFastHardamardWrapper(const T *x_data,
bool FLAGS_hardamard_use_diagonal_block_matrix = true;
constexpr int kThreads = 128;
if (FLAGS_hardamard_use_diagonal_block_matrix) {
const int VecSize = hadamard_block_size / kThreads;
// Force effective_block_size to be at least 128 to prevent VecSize from
// being 0 when hadamard_block_size < 128 (since VecSize =
// hadamard_block_size / kThreads)
const int effective_block_size =
(hadamard_block_size < 128) ? 128 : hadamard_block_size;
const int VecSize = effective_block_size / kThreads;
const int logN = int(ceil(std::log2(kThreads * VecSize)));
constexpr int kNChunks = 1;
DISPATCH_SP_VS(VecSize, VEC_SIZE, {DISPATCH_SP_logN(logN, kLogN, {
MoeFastHardamardImplWrapper<T,
OutT,