diff --git a/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.cu index 763eb5d109..ab8faac1e9 100644 --- a/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.cu +++ b/custom_ops/gpu_ops/moe/moe_fast_hardamard_kernel.cu @@ -39,10 +39,18 @@ void MoeFastHardamardWrapper(const T *x_data, bool FLAGS_hardamard_use_diagonal_block_matrix = true; constexpr int kThreads = 128; + if (FLAGS_hardamard_use_diagonal_block_matrix) { - const int VecSize = hadamard_block_size / kThreads; + // Force effective_block_size to be at least 128 to prevent VecSize from + // being 0 when hadamard_block_size < 128 (since VecSize = + // hadamard_block_size / kThreads) + const int effective_block_size = + (hadamard_block_size < 128) ? 128 : hadamard_block_size; + + const int VecSize = effective_block_size / kThreads; const int logN = int(ceil(std::log2(kThreads * VecSize))); constexpr int kNChunks = 1; + DISPATCH_SP_VS(VecSize, VEC_SIZE, {DISPATCH_SP_logN(logN, kLogN, { MoeFastHardamardImplWrapper