mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
Feat/blackwell sm100 support (#2670)
* Add initial support for NVIDIA Blackwell (SM100) architecture
This change introduces initial support for the NVIDIA Blackwell GPU
architecture, specifically targeting SM100 (Compute Capability 10.x)
with '100a' architecture-specific features (e.g., for CUTLASS).
Key changes:
- Updated custom_ops/setup_ops.py to generate appropriate gencode
flags (arch=compute_100a,code=sm_100a) when '100' is specified
in FD_BUILDING_ARCS. Requires CUDA 12.9+.
- Updated custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h:
- Added CutlassTileConfigSM100 enum (with placeholder tile shapes).
- Added BLACKWELL to CandidateConfigTypeParam.
- Updated CutlassGemmConfig struct with is_sm100 flag,
tile_config_sm100, and new constructor for SM100.
- Modified toString() and fromString() for SM100 support.
- Updated custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu:
- Added get_candidate_tiles_sm100() (with placeholder tiles).
- Added placeholder mcast support functions for SM100.
- Updated get_candidate_configs() to include SM100 paths using
the BLACKWELL flag and new SM100 config types.
- Updated build.sh with comments to guide users on specifying '100'
for Blackwell in FD_BUILDING_ARCS.
Further work:
- Optimal CUTLASS tile configurations for SM100 need to be researched
and updated in cutlass_heuristic.cu.
- Kernel auto-generation scripts in custom_ops/utils/ may need
SM100-specific versions if Blackwell's hardware features for FP8/TMA
differ significantly from SM90.
- Compatibility of third-party libraries (CUTLASS v3.8.0, DeepGEMM)
with Blackwell should be fully verified.
* Feat: Implement detailed Blackwell (SM100) CUTLASS heuristics
This change integrates specific, expert-provided CUTLASS heuristic
configurations for the NVIDIA Blackwell (SM100) GPU architecture,
replacing previous placeholders. This includes:
- Updated `custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h`:
- Populated `CutlassTileConfigSM100` enum with specific tile shapes
(e.g., CtaShape64x64x128B, CtaShape128x128x128B) suitable for SM100.
- Added `FP4_ONLY` to `CandidateConfigTypeParam` for new FP4 paths.
- Updated `custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu`:
- Implemented `get_candidate_tiles_sm100` with detailed logic for
selecting tile configurations based on GROUPED_GEMM and FP4_ONLY flags,
using the new SM100 tile enums.
- Implemented `supports_mcast_along_m_sm100` and
`supports_mcast_along_n_sm100` with specific tile checks for Blackwell.
- Updated the `sm == 100` (Blackwell) block in `get_candidate_configs`
to use these new helper functions and accurately populate candidate
kernel configurations for various cluster shapes.
- `custom_ops/setup_ops.py` remains configured to compile for
`arch=compute_100a,code=sm_100a` with CUDA 12.9+ for these features.
This aligns the codebase with heuristic configurations similar to those
in upstream TensorRT-LLM / CUTLASS for Blackwell, enabling more
performant kernel selection on this new architecture.
---------
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -76,6 +76,34 @@ enum class SplitKStyle
|
||||
// SPLIT_K_PARALLEL // Not supported yet
|
||||
};
|
||||
|
||||
// New enum for SM100 (Blackwell) Tile Configs
|
||||
// Placeholder values - actual optimal values need research
|
||||
enum class CutlassTileConfigSM100
|
||||
{
|
||||
// Signals that we should run heuristics do choose a config
|
||||
Undefined,
|
||||
|
||||
// Signals that we should run heuristics do choose a config
|
||||
ChooseWithHeuristic,
|
||||
|
||||
// Actual SM100 tile configs based on user input (K-tile is 128B)
|
||||
CtaShape64x64x128B,
|
||||
CtaShape64x128x128B,
|
||||
CtaShape64x256x128B,
|
||||
CtaShape128x64x128B,
|
||||
CtaShape128x128x128B,
|
||||
CtaShape128x256x128B,
|
||||
CtaShape256x64x128B,
|
||||
CtaShape256x128x128B,
|
||||
CtaShape256x256x128B
|
||||
// Note: The user-provided list for get_candidate_tiles_sm100 also includes
|
||||
// CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm cases.
|
||||
// These are already covered by the list above if general suffices.
|
||||
// If they need distinct enum values, they should be added.
|
||||
// For now, keeping the enum concise with unique shapes mentioned for general use.
|
||||
};
|
||||
|
||||
|
||||
enum class CutlassTileConfigSM90
|
||||
{
|
||||
// Signals that we should run heuristics do choose a config
|
||||
@@ -132,9 +160,11 @@ struct CutlassGemmConfig
|
||||
WEIGHT_ONLY = 1u << 0,
|
||||
SIMT_ONLY = 1u << 1,
|
||||
INT8_ONLY = 1u << 2,
|
||||
HOPPER = 1u << 3,
|
||||
HOPPER = 1u << 3, // SM90
|
||||
GROUPED_GEMM = 1u << 4,
|
||||
FP8_ONLY = 1u << 5,
|
||||
BLACKWELL = 1u << 6, // SM100
|
||||
FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths
|
||||
};
|
||||
|
||||
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
|
||||
@@ -149,7 +179,17 @@ struct CutlassGemmConfig
|
||||
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
|
||||
bool is_sm90 = false;
|
||||
|
||||
CutlassGemmConfig() {}
|
||||
// config options for sm100 (Blackwell)
|
||||
// Assuming SM100 might use similar schedule/cluster types as SM90 for now.
|
||||
// These might need to become SM100-specific if Blackwell introduces new concepts.
|
||||
CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
|
||||
// MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types
|
||||
// EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example
|
||||
// ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example
|
||||
bool is_sm100 = false;
|
||||
|
||||
|
||||
CutlassGemmConfig() : is_sm90(false), is_sm100(false) {}
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
|
||||
: tile_config(tile_config)
|
||||
@@ -157,37 +197,64 @@ struct CutlassGemmConfig
|
||||
, split_k_factor(split_k_factor)
|
||||
, stages(stages)
|
||||
, is_sm90(false)
|
||||
, is_sm100(false)
|
||||
{
|
||||
}
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
|
||||
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
|
||||
: tile_config_sm90(tile_config_sm90)
|
||||
, mainloop_schedule(mainloop_schedule)
|
||||
, epilogue_schedule(epilogue_schedule)
|
||||
, cluster_shape(cluster_shape)
|
||||
// Constructor for SM90
|
||||
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in,
|
||||
EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
|
||||
: tile_config_sm90(tile_config_sm90_in)
|
||||
, mainloop_schedule(mainloop_schedule_in)
|
||||
, epilogue_schedule(epilogue_schedule_in)
|
||||
, cluster_shape(cluster_shape_in)
|
||||
, is_sm90(true)
|
||||
, is_sm100(false)
|
||||
{
|
||||
}
|
||||
|
||||
// Constructor for SM100 (Blackwell)
|
||||
// Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now.
|
||||
// These might need to be new SM100-specific types if Blackwell's TMA differs significantly.
|
||||
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in,
|
||||
EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
|
||||
: tile_config_sm100(tile_config_sm100_in)
|
||||
, mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge
|
||||
, epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100
|
||||
, cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100
|
||||
, is_sm90(false) // Explicitly false
|
||||
, is_sm100(true)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
std::string toString() const
|
||||
{
|
||||
std::stringstream tactic;
|
||||
tactic << "Cutlass GEMM Tactic";
|
||||
if (tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
|
||||
if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic)
|
||||
{
|
||||
assert(is_sm90 && "Invalid cutlass GEMM config");
|
||||
tactic << "\n\tstyle=TMA"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config_sm90
|
||||
assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100");
|
||||
tactic << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable
|
||||
<< "\n\ttile shape ID: " << (int) tile_config_sm100
|
||||
<< "\n\tcluster shape ID: " << (int) cluster_shape
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule
|
||||
<< "\n\tepi sched: " << (int) epilogue_schedule;
|
||||
}
|
||||
else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
|
||||
{
|
||||
assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90");
|
||||
tactic << "\n\tstyle=TMA_SM90"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config_sm90
|
||||
<< "\n\tcluster shape ID: " << (int) cluster_shape
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule
|
||||
<< "\n\tepi sched: " << (int) epilogue_schedule;
|
||||
}
|
||||
else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
|
||||
{
|
||||
assert(!is_sm90 && "Invalid cutlass GEMM config");
|
||||
assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible");
|
||||
tactic << "\n\tstyle=compatible"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config
|
||||
<< "\n\ttile shape ID: " << (int) tile_config
|
||||
<< "\n\tstages: " << (int) stages
|
||||
<< "\n\tsplit_k_style: " << (int) split_k_style
|
||||
<< "\n\tsplit k: " << (int) split_k_factor;
|
||||
@@ -204,9 +271,24 @@ struct CutlassGemmConfig
|
||||
std::istringstream stream(str);
|
||||
std::string line;
|
||||
|
||||
is_sm90 = false; // Reset flags
|
||||
is_sm100 = false;
|
||||
|
||||
while (std::getline(stream, line)) {
|
||||
if (line.find("style=TMA") != std::string::npos) {
|
||||
if (line.find("style=TMA_SM100") != std::string::npos) {
|
||||
is_sm100 = true;
|
||||
is_sm90 = false;
|
||||
std::getline(stream, line);
|
||||
tile_config_sm100 = static_cast<cutlass_extensions::CutlassTileConfigSM100>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
cluster_shape = static_cast<cutlass_extensions::ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
mainloop_schedule = static_cast<cutlass_extensions::MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
} else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first
|
||||
is_sm90 = true;
|
||||
is_sm100 = false;
|
||||
std::getline(stream, line);
|
||||
tile_config_sm90 = static_cast<cutlass_extensions::CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
@@ -217,6 +299,7 @@ struct CutlassGemmConfig
|
||||
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
} else if (line.find("style=compatible") != std::string::npos) {
|
||||
is_sm90 = false;
|
||||
is_sm100 = false;
|
||||
std::getline(stream, line);
|
||||
tile_config = static_cast<cutlass_extensions::CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
@@ -233,7 +316,14 @@ struct CutlassGemmConfig
|
||||
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
|
||||
{
|
||||
// clang-format off
|
||||
if (config.is_sm90)
|
||||
if (config.is_sm100)
|
||||
{
|
||||
out << "tile_config_sm100_enum: " << int(config.tile_config_sm100)
|
||||
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule) // Assuming same schedule types for now
|
||||
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule) // Assuming same schedule types for now
|
||||
<< ", cluster_shape_enum: " << int(config.cluster_shape); // Assuming same cluster types for now
|
||||
}
|
||||
else if (config.is_sm90)
|
||||
{
|
||||
out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
|
||||
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
|
||||
|
||||
Reference in New Issue
Block a user