mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -24,11 +24,11 @@
|
||||
/**
|
||||
* Helper function for checking CUTLASS errors
|
||||
*/
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
PD_CHECK(error == cutlass::Status::kSuccess, \
|
||||
cutlassGetStatusString(error)); \
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
PD_CHECK(error == cutlass::Status::kSuccess, \
|
||||
cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -38,44 +38,50 @@
|
||||
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
* into code that will be executed on the device where it is defined.
|
||||
*/
|
||||
template <typename Kernel> struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args> CUTLASS_DEVICE void operator()(Args &&...args) {
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args &&...args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <paddle::DataType D> class CutlassDtypeTraits;
|
||||
template <paddle::DataType D>
|
||||
class CutlassDtypeTraits;
|
||||
|
||||
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
|
||||
public:
|
||||
template <>
|
||||
class CutlassDtypeTraits<paddle::DataType::FLOAT32> {
|
||||
public:
|
||||
typedef float DataType;
|
||||
typedef float data_t;
|
||||
};
|
||||
|
||||
template <> class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
|
||||
public:
|
||||
template <>
|
||||
class CutlassDtypeTraits<paddle::DataType::FLOAT16> {
|
||||
public:
|
||||
typedef cutlass::half_t DataType;
|
||||
typedef paddle::float16 data_t;
|
||||
};
|
||||
|
||||
template <> class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
|
||||
public:
|
||||
template <>
|
||||
class CutlassDtypeTraits<paddle::DataType::BFLOAT16> {
|
||||
public:
|
||||
typedef cutlass::bfloat16_t DataType;
|
||||
typedef paddle::bfloat16 data_t;
|
||||
};
|
||||
|
||||
class CutlassGemmConfigMannager {
|
||||
public:
|
||||
public:
|
||||
static CutlassGemmConfigMannager &getInstance() {
|
||||
static CutlassGemmConfigMannager instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete;
|
||||
CutlassGemmConfigMannager &
|
||||
operator=(const CutlassGemmConfigMannager &) = delete;
|
||||
CutlassGemmConfigMannager &operator=(const CutlassGemmConfigMannager &) =
|
||||
delete;
|
||||
|
||||
void up_date_configs(const nlohmann::json &j) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
@@ -102,7 +108,7 @@ public:
|
||||
return &json_;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
void save_gemm_best_configs_(const std::string &config_file_path) {
|
||||
std::ifstream file(config_file_path);
|
||||
if (!file.good()) {
|
||||
|
||||
Reference in New Issue
Block a user