[Feature] Support NVFP4 Flashinfer-cutedsl MoE on SM100 (#6963)

This commit is contained in:
mpgemm
2026-03-30 11:37:04 +08:00
committed by GitHub
parent 61a9079c60
commit 1a1d048774
9 changed files with 1249 additions and 75 deletions
@@ -172,17 +172,19 @@ std::vector<paddle::Tensor> DepermutePrefillCombine(
case paddle::DataType::FLOAT8_E4M3FN: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4)
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 6)
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
PD_THROW("Unsupported topk value, must be 4, 6 or 8");
}
}
case paddle::DataType::BFLOAT16: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4)
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 6)
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
PD_THROW("Unsupported topk value, must be 4, 6 or 8");
}
}
default:
@@ -217,10 +217,12 @@ std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 6)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
PD_THROW("Unsupported topk value, must be 4 or 6 or 8");
}
}
case paddle::DataType::INT32: {