mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support NVFP4 Flashinfer-cutedsl MoE on SM100 (#6963)
This commit is contained in:
@@ -172,17 +172,19 @@ std::vector<paddle::Tensor> DepermutePrefillCombine(
|
||||
case paddle::DataType::FLOAT8_E4M3FN: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 4)
|
||||
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 6)
|
||||
DISPATCH_TOPK(paddle::DataType::FLOAT8_E4M3FN, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
PD_THROW("Unsupported topk value, must be 4, 6 or 8");
|
||||
}
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 4)
|
||||
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 6)
|
||||
DISPATCH_TOPK(paddle::DataType::BFLOAT16, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
PD_THROW("Unsupported topk value, must be 4, 6 or 8");
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -217,10 +217,12 @@ std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
|
||||
switch (topk) {
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 4)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 6)
|
||||
DISPATCH_TOPK(
|
||||
paddle::DataType::BFLOAT16, paddle::DataType::FLOAT32, 8)
|
||||
default:
|
||||
PD_THROW("Unsupported topk value, must be 4 or 8");
|
||||
PD_THROW("Unsupported topk value, must be 4 or 6 or 8");
|
||||
}
|
||||
}
|
||||
case paddle::DataType::INT32: {
|
||||
|
||||
Reference in New Issue
Block a user