[Metax] adapt DeepSeek (#4498)

This commit is contained in:
xiaozude
2025-10-24 10:14:53 +08:00
committed by GitHub
parent 8718fa34b2
commit f7069b8057
19 changed files with 1538 additions and 324 deletions
@@ -14,7 +14,9 @@
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
#include "paddle/phi/core/memory/memcpy.h"
#endif
#include "utils.cuh"
template <int THREADBLOCK_SIZE>
@@ -15,6 +15,7 @@
#include <cuda_runtime.h>
#include <stdint.h>
#include <cooperative_groups/memcpy_async.h>
enum class SharedMemFillMode { kFillZero, kNoFill };
@@ -42,18 +43,35 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R,
}
// Commits all previously issued, not-yet-committed cp.async copies into a
// new async-copy group; the group is later awaited with wait_group<n>().
__device__ __forceinline__ void commit_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  // Metax build: the load paths in this header fall back to synchronous
  // copies instead of cp.async, so there is no group to commit — no-op.
  {}
#else
  // requires SM80+ (cp.async PTX instruction family)
  asm volatile("cp.async.commit_group;\n" ::);
#endif
}
// Blocks the calling thread until at most `n` of the most recently committed
// cp.async groups are still in flight (n == 0 waits for all of them).
template <size_t n>
__device__ __forceinline__ void wait_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  // Metax build: wait on all outstanding block-level async copies instead of
  // the n-deep cp.async group queue. Note `n` is ignored here, i.e. this is
  // stricter than the CUDA path (waits for everything).
  // NOTE(review): the Metax load paths in this diff use plain memcpy rather
  // than cooperative_groups::memcpy_async, so this wait may be a no-op in
  // practice — confirm against the Metax toolchain semantics.
  cooperative_groups::wait(cooperative_groups::this_thread_block());
#else
  // requires SM80+ (cp.async PTX instruction family); `n` must be a
  // compile-time constant because it is bound via the "n" asm constraint.
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
#endif
}
template <PrefetchMode prefetch_mode, typename T>
__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
#else
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
asm volatile(
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
@@ -68,6 +86,7 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
"n"(16),
"r"(16));
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -76,6 +95,28 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
}
} else {
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
@@ -115,6 +156,7 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
"n"(16));
}
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -123,6 +165,17 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 8);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
asm volatile(
@@ -141,6 +194,7 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
"l"(gmem_ptr),
"n"(8));
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -149,6 +203,17 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 4);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
asm volatile(
@@ -167,6 +232,7 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
"l"(gmem_ptr),
"n"(4));
}
#endif
}
template <size_t num_bits, PrefetchMode prefetch_mode, typename T>