Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -24,11 +24,11 @@
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
PD_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
PD_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
/**
@@ -38,44 +38,50 @@
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
* into code that will be executed on the device where it is defined.
*/
template <typename Kernel> struct enable_sm90_or_later : Kernel {
template <typename... Args> CUTLASS_DEVICE void operator()(Args &&...args) {
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args &&...args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <paddle::DataType D> class CutlassDtypeTraits;
template <paddle::DataType D>
class CutlassDtypeTraits;
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
};
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
public:
template <>
class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
public:
typedef cutlass::half_t DataType;
typedef paddle::float16 data_t;
};
template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
public:
template <>
class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
public:
typedef cutlass::bfloat16_t DataType;
typedef paddle::bfloat16 data_t;
};
class CutlassGemmConfigMannager {
public:
public:
static CutlassGemmConfigMannager &getInstance() {
static CutlassGemmConfigMannager instance;
return instance;
}
CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
CutlassGemmConfigMannager &
operator=(const CutlassGemmConfigMannager &) = delete;
CutlassGemmConfigMannager &operator=(const CutlassGemmConfigMannager &) =
delete;
void up_date_configs(const nlohmann::json &j) {
std::lock_guard<std::mutex> lock(mutex_);
@@ -102,7 +108,7 @@ public:
return &json_;
}
private:
private:
void save_gemm_best_configs_(const std::string &config_file_path) {
std::ifstream file(config_file_path);
if (!file.good()) {