[Metax] adapt DeepSeek (#4498)

This commit is contained in:
xiaozude
2025-10-24 10:14:53 +08:00
committed by GitHub
parent 8718fa34b2
commit f7069b8057
19 changed files with 1538 additions and 324 deletions
@@ -14,7 +14,9 @@
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
#include "paddle/phi/core/memory/memcpy.h"
#endif
#include "utils.cuh"
template <int THREADBLOCK_SIZE>
@@ -15,6 +15,7 @@
#include <cuda_runtime.h>
#include <stdint.h>
#include <cooperative_groups/memcpy_async.h>
enum class SharedMemFillMode { kFillZero, kNoFill };
@@ -42,18 +43,35 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R,
}
__device__ __forceinline__ void commit_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
{}
#else
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
template <size_t n>
__device__ __forceinline__ void wait_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
cooperative_groups::wait(cooperative_groups::this_thread_block());
#else
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
#endif
}
template <PrefetchMode prefetch_mode, typename T>
__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
#else
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
asm volatile(
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
@@ -68,6 +86,7 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
"n"(16),
"r"(16));
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -76,6 +95,28 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
}
} else {
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
@@ -115,6 +156,7 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
"n"(16));
}
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -123,6 +165,17 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 8);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
asm volatile(
@@ -141,6 +194,7 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
"l"(gmem_ptr),
"n"(8));
}
#endif
}
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -149,6 +203,17 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 4);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
asm volatile(
@@ -167,6 +232,7 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
"l"(gmem_ptr),
"n"(4));
}
#endif
}
template <size_t num_bits, PrefetchMode prefetch_mode, typename T>
+4 -1
View File
@@ -595,10 +595,13 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
#endif
inline int GetSMVersion() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
return 80;
#else
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;
#endif
}
inline bool GetMlaUseTensorcore() {
+15 -1
View File
@@ -18,6 +18,7 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
#include <cuda/std/limits>
namespace cg = cooperative_groups;
@@ -601,7 +602,7 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
topk_sum += cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
@@ -658,6 +659,11 @@ void invokeNoAuxTc(T* scores,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
#else
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
@@ -671,6 +677,7 @@ void invokeNoAuxTc(T* scores,
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
#endif
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -678,6 +685,12 @@ void invokeNoAuxTc(T* scores,
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks, BLOCK_SIZE, dynamic_smem_in_bytes, stream>>>(
scores, group_scores, topk_values, topk_indices, scores_with_bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor);
#else
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
@@ -691,6 +704,7 @@ void invokeNoAuxTc(T* scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
#endif
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
+8 -1
View File
@@ -601,9 +601,16 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"gpu_ops/read_data_ipc.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/share_external_data.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/append_attn/mla_cache_kernel.cu",
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/moe/moe_topk_select.cu",
"gpu_ops/recover_decode_task.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",