[Models][OP][Optimization] Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM (#6689)

* Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM
This commit is contained in:
AIbin
2026-03-10 15:05:14 +08:00
committed by GitHub
parent 25c479312d
commit c3aceb6bdc
22 changed files with 8022 additions and 143 deletions
@@ -0,0 +1,616 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* DeepSeek Kv3.2 (DsKv3.2) Attention WriteCache Implementation
*
* This file implements writecache operations for DeepSeek MLA (Multi-head
* Latent Attention) with FP8 quantization support, migrated from vLLM.
*
* Key features:
* 1. DS MLA FP8 cache format (656 bytes per token):
* - 512 bytes: quantized NoPE part (fp8_e4m3)
* - 16 bytes: scale factors (4 x float32)
* - 128 bytes: RoPE part (64 x bf16, unquantized)
*
* 2. Standard MLA cache format (kv_lora_rank + pe_dim elements)
*
* 3. Indexer K quantization and cache operations
*/
#include "ds_mla_cache_kernel.cuh"
#include "helper.h"
#include "remote_cache_kv_ipc.h"
//==============================================================================
// DS MLA FP8 WriteCache Implementation
//==============================================================================
/**
* Prefill stage: Write KV cache with DS MLA FP8 format
*/
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto num_tokens = slot_mapping.dims()[0];
auto kv_lora_rank = 512; // DS MLA uses 512
auto pe_dim = 64; // DS MLA uses 64
auto block_size = meta_data.block_size;
// Entry size for DS MLA FP8: 512 (fp8) + 16 (scales) + 128 (rope bf16) = 656
// bytes
const int entry_size = 656;
// Launch kernel with 96 threads (64 for NoPE, 32 for RoPE)
dim3 grid(num_tokens);
dim3 block(96);
const auto& kv_cache_dims = kv_cache->dims();
int block_stride = kv_cache->strides()[0];
int entry_stride = entry_size;
int kv_c_stride = kv_nope.strides()[0];
int k_pe_stride = kv_pe.strides()[0];
ds_mla::concat_and_cache_ds_mla_kernel<DataType_><<<grid, block, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
slot_mapping.data<int64_t>(),
block_stride,
entry_stride,
kv_c_stride,
k_pe_stride,
kv_lora_rank,
pe_dim,
block_size);
// Handle PD disaggregation signal
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
/**
* Decode stage: Write KV cache with DS MLA FP8 format
*/
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeDSMLAWriteCacheFP8(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto num_tokens = slot_mapping.dims()[0];
auto kv_lora_rank = 512;
auto pe_dim = 64;
auto block_size = meta_data.block_size;
const int entry_size = 656;
dim3 grid(num_tokens);
dim3 block(96);
const auto& kv_cache_dims = kv_cache->dims();
int block_stride = kv_cache->strides()[0];
int entry_stride = entry_size;
int kv_c_stride = kv_nope.strides()[0];
int k_pe_stride = kv_pe.strides()[0];
ds_mla::concat_and_cache_ds_mla_kernel<DataType_><<<grid, block, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
slot_mapping.data<int64_t>(),
block_stride,
entry_stride,
kv_c_stride,
k_pe_stride,
kv_lora_rank,
pe_dim,
block_size);
return {};
}
//==============================================================================
// Standard MLA WriteCache Implementation
//==============================================================================
/**
* Prefill stage: Write KV cache with standard MLA format
*/
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillDSMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const float* scale,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto num_tokens = slot_mapping.dims()[0];
auto kv_lora_rank = meta_data.head_dims_v;
auto pe_dim = meta_data.head_dims - meta_data.head_dims_v;
auto block_size = meta_data.block_size;
const auto& kv_cache_dims = kv_cache->dims();
int block_stride = kv_cache->strides()[0];
int entry_stride = kv_cache->strides()[1];
int kv_c_stride = kv_nope.strides()[0];
int k_pe_stride = kv_pe.strides()[0];
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
ds_mla::concat_and_cache_mla_kernel<DataType_, DataType_>
<<<grid, block, 0, stream>>>(
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
slot_mapping.data<int64_t>(),
block_stride,
entry_stride,
kv_c_stride,
k_pe_stride,
kv_lora_rank,
pe_dim,
block_size,
scale);
// Handle PD disaggregation signal
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
//==============================================================================
// Indexer K Quantization and Cache Operations
//==============================================================================
/**
* Quantize K tensor to FP8 and write to cache
*/
template <paddle::DataType T>
std::vector<paddle::Tensor> IndexerKQuantAndCache(
const paddle::Tensor& k,
const paddle::Tensor& slot_mapping,
const int head_dim,
const int quant_block_size,
const int cache_block_size,
const int cache_stride,
const bool use_ue8m0,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
int num_tokens = k.dims()[0];
constexpr int vec_size = 4;
dim3 grid(num_tokens,
(head_dim + quant_block_size * vec_size - 1) /
(quant_block_size * vec_size));
dim3 block(32, vec_size);
ds_mla::indexer_k_quant_and_cache_kernel<DataType_>
<<<grid, block, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(k.data<data_t>())),
reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
slot_mapping.data<int64_t>(),
head_dim,
quant_block_size,
cache_block_size,
cache_stride,
use_ue8m0);
return {};
}
/**
* Gather K from quantized cache
*/
void CpGatherIndexerKQuantCache(const paddle::Tensor& kv_cache,
paddle::Tensor& dst_k,
paddle::Tensor& dst_scale,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seq_lens,
cudaStream_t& stream) {
int batch_size = block_table.dims()[0];
int num_tokens = dst_k.dims()[0];
int head_dim = dst_k.dims()[1];
int quant_block_size = head_dim * 4 / dst_scale.dims()[1];
constexpr int vec_size = 16;
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
ds_mla::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
<<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE, \
(head_dim + 8 * vec_size - 1) / (8 * vec_size)), \
dim3(8, BLOCK_Y_SIZE), \
0, \
stream>>>(reinterpret_cast<const char*>(kv_cache.data<uint8_t>()), \
reinterpret_cast<char*>(dst_k.data<uint8_t>()), \
reinterpret_cast<char*>(dst_scale.data<float>()), \
block_table.data<int>(), \
cu_seq_lens.data<int>(), \
batch_size, \
dst_k.strides()[0], \
dst_k.dims()[1], \
kv_cache.strides()[0], \
kv_cache.strides()[1], \
kv_cache.dims()[1], \
block_table.dims()[1], \
num_tokens, \
quant_block_size);
if (num_tokens < 32) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
} else if (num_tokens < 64) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2);
} else if (num_tokens < 128) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4);
} else if (num_tokens < 256) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8);
} else if (num_tokens < 512) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16);
} else {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
}
#undef CALL_CP_GATHER_INDEXER_K_QUANT_CACHE
}
//==============================================================================
// Kernel Entry Points
//==============================================================================
/**
* DS MLA WriteCache entry point - supports both FP8 and standard formats
*/
std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& scale,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool is_prefill) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = seq_lens_decoder.dims()[0];
const float* scale_ptr = scale ? scale.get().data<float>() : nullptr;
if (cache_quant_type_str == "fp8_ds_mla") {
// FP8 DS MLA format
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
if (is_prefill) {
return PrefillDSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
} else {
return DecodeDSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
false,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
case paddle::DataType::FLOAT16: {
if (is_prefill) {
return PrefillDSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
} else {
return DecodeDSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
false,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
default:
PD_THROW("Unsupported dtype for DS MLA FP8 cache");
}
} else {
// Standard MLA format (auto/bf16/fp16)
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillDSMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
scale_ptr,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillDSMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
scale_ptr,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
default:
PD_THROW("Unsupported dtype for DS MLA cache");
}
}
return {};
}
/**
* Indexer K Quant and Cache entry point
*/
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
const paddle::Tensor& k,
const paddle::Tensor& kv_cache,
const paddle::Tensor& slot_mapping,
const int64_t quant_block_size,
const std::string& scale_fmt) {
cudaStream_t stream = k.stream();
int num_tokens = k.dims()[0];
int head_dim = k.dims()[1];
int cache_block_size = kv_cache.dims()[1];
int cache_stride = kv_cache.dims()[2];
bool use_ue8m0 = scale_fmt == "ue8m0";
switch (k.dtype()) {
case paddle::DataType::BFLOAT16: {
return IndexerKQuantAndCache<paddle::DataType::BFLOAT16>(
k,
slot_mapping,
head_dim,
quant_block_size,
cache_block_size,
cache_stride,
use_ue8m0,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return IndexerKQuantAndCache<paddle::DataType::FLOAT16>(
k,
slot_mapping,
head_dim,
quant_block_size,
cache_block_size,
cache_stride,
use_ue8m0,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
default:
PD_THROW("Unsupported dtype for Indexer K Quant");
}
return {};
}
/**
* Gather Indexer K from Quant Cache entry point
*/
std::vector<paddle::Tensor> CpGatherIndexerKQuantCacheKernel(
const paddle::Tensor& kv_cache,
paddle::Tensor& dst_k,
paddle::Tensor& dst_scale,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seq_lens) {
cudaStream_t stream = kv_cache.stream();
CpGatherIndexerKQuantCache(
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens, stream);
return {};
}
//==============================================================================
// Paddle Custom Operator Registration
//==============================================================================
PD_BUILD_STATIC_OP(ds_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"slot_mapping",
"seq_lens",
"seq_lens_decoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
paddle::Optional("kv_signal_data"),
paddle::Optional("scale")})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int",
"is_prefill: bool"})
.SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel));
PD_BUILD_STATIC_OP(indexer_k_quant_and_cache)
.Inputs({"k", "kv_cache", "slot_mapping"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"quant_block_size: int64_t", "scale_fmt: std::string"})
.SetKernelFn(PD_KERNEL(IndexerKQuantAndCacheKernel));
PD_BUILD_STATIC_OP(cp_gather_indexer_k_quant_cache)
.Inputs({"kv_cache", "dst_k", "dst_scale", "block_table", "cu_seq_lens"})
.Outputs({"dst_k_out", "dst_scale_out"})
.SetInplaceMap({{"dst_k", "dst_k_out"}, {"dst_scale", "dst_scale_out"}})
.SetKernelFn(PD_KERNEL(CpGatherIndexerKQuantCacheKernel));
@@ -0,0 +1,548 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cfloat>
#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"
// FP8 scale divisor constant (for SM90+)
#if defined(__gfx942__)
constexpr float kFp8ScaleDivisorDS = 224.f;
#else
constexpr float kFp8ScaleDivisorDS = 448.f;
#endif
namespace ds_mla {
/**
* FP8 scaled conversion utilities
*/
template <typename OutT, typename InT>
__device__ __forceinline__ OutT fp8_scaled_convert(InT src, float scale) {
return static_cast<OutT>(static_cast<float>(src) / scale);
}
template <>
__device__ __forceinline__ uint8_t
fp8_scaled_convert<uint8_t, __nv_bfloat16>(__nv_bfloat16 src, float scale) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
float val = __bfloat162float(src) / scale;
val = fminf(fmaxf(val, -448.0f), 448.0f);
__nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
return *reinterpret_cast<uint8_t*>(&fp8_val);
#else
return 0;
#endif
}
template <>
__device__ __forceinline__ uint8_t
fp8_scaled_convert<uint8_t, half>(half src, float scale) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
float val = __half2float(src) / scale;
val = fminf(fmaxf(val, -448.0f), 448.0f);
__nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
return *reinterpret_cast<uint8_t*>(&fp8_val);
#else
return 0;
#endif
}
/**
* DeepSeek MLA FP8 Cache Write Kernel
*
* Cache format (fp8_ds_mla - 656 bytes per token):
* - First 512 bytes: quantized NoPE part (512 x fp8_e4m3)
* - Next 16 bytes: scale factors (4 x float32, one per 128 fp8 values)
* - Last 128 bytes: RoPE part (64 x bfloat16, not quantized)
*
* Thread organization:
* - First 2 warps (64 threads): handle NoPE FP8 quantization
* - Last 1 warp (32 threads): handle RoPE copy
* - Total: 96 threads per block
*/
template <typename scalar_t>
__global__ void concat_and_cache_ds_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
uint8_t* __restrict__ kv_cache, // [num_blocks, block_size,
// cache_entry_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, // stride per block in cache
const int entry_stride, // stride per token entry in cache
const int kv_c_stride, // stride for kv_c input
const int k_pe_stride, // stride for k_pe input
const int kv_lora_rank, // 512 for DS MLA
const int pe_dim, // 64 for DS MLA
const int block_size // number of tokens per cache block
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int64_t dst_idx_start =
block_idx * block_stride + block_offset * entry_stride;
// Cast kv_cache to 16-bit for RoPE values
scalar_t* kv_cache_16bit =
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
// The last warp handles the RoPE part
if (threadIdx.x >= 64) {
// Each thread handles two elements of RoPE
const int8_t pe_idx_start = (threadIdx.x - 64) * 2;
const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start;
// Vectorized load of two 16-bit values, performed as one 32-bit load
const int32_t vals = *reinterpret_cast<const int32_t*>(&k_pe[src_idx]);
// RoPE values start after the packed 8-bit NoPE values and the 32-bit
// scales Position: kv_lora_rank/2 (256 bytes in 16-bit units) + 8 (16 bytes
// of scales in 16-bit units)
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start;
// Vectorized store of two 16-bit values
*reinterpret_cast<int32_t*>(&kv_cache_16bit[dst_idx]) = vals;
return;
}
// The first two warps handle the NoPE part
const int8_t warp_idx = threadIdx.x >> 5;
const int8_t lane_idx = threadIdx.x & 31;
const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
// Each thread handles 8 elements of NoPE
const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8);
// Vectorized load of eight 16-bit values
const int4 vals_i4 = *reinterpret_cast<const int4*>(&kv_c[src_idx_start]);
const scalar_t* vals = reinterpret_cast<const scalar_t*>(&vals_i4);
// Max absolute value of this thread's elements
float max_abs = fmaxf(fmaxf(fmaxf(fabsf(static_cast<float>(vals[0])),
fabsf(static_cast<float>(vals[1]))),
fmaxf(fabsf(static_cast<float>(vals[2])),
fabsf(static_cast<float>(vals[3])))),
fmaxf(fmaxf(fabsf(static_cast<float>(vals[4])),
fabsf(static_cast<float>(vals[5]))),
fmaxf(fabsf(static_cast<float>(vals[6])),
fabsf(static_cast<float>(vals[7])))));
// Warp-level reduction to find the max absolute value in each half-warp
#pragma unroll
for (int offset = 8; offset > 0; offset /= 2) {
max_abs = fmaxf(max_abs, __shfl_xor_sync(0xFFFF, max_abs, offset, 16));
}
// Compute the scale for the tile
float tile_scale = fmaxf(max_abs / kFp8ScaleDivisorDS, FLT_MIN);
// The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) {
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
kv_cache_32bit[dst_idx] = tile_scale;
}
// Now all threads in the block scale and write their elements
const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8);
uint8_t result[8];
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] = fp8_scaled_convert<uint8_t, scalar_t>(vals[i], tile_scale);
}
// Store as aligned 64-bit writes
*reinterpret_cast<uint64_t*>(&kv_cache[dst_idx_base]) =
*reinterpret_cast<const uint64_t*>(result);
}
/**
* Standard MLA Cache Write Kernel (non-FP8)
*
* For auto/bf16/fp16 cache types
*/
template <typename scalar_t, typename cache_t>
__global__ void concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride,
const int entry_stride,
const int kv_c_stride,
const int k_pe_stride,
const int kv_lora_rank,
const int pe_dim,
const int block_size,
const float* scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
// Copy kv_c (NoPE part)
for (int i = threadIdx.x; i < kv_lora_rank; i += blockDim.x) {
const int64_t src_idx = token_idx * kv_c_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i;
kv_cache[dst_idx] = static_cast<cache_t>(kv_c[src_idx]);
}
// Copy k_pe (RoPE part)
for (int i = threadIdx.x; i < pe_dim; i += blockDim.x) {
const int64_t src_idx = token_idx * k_pe_stride + i;
const int64_t dst_idx = block_idx * block_stride +
block_offset * entry_stride + kv_lora_rank + i;
kv_cache[dst_idx] = static_cast<cache_t>(k_pe[src_idx]);
}
}
/**
* Indexer K Quantization and Cache Kernel
*
* Quantizes K values to FP8 and stores them in cache with scale factors
* Cache layout: [quantized_k (head_dim bytes)] + [scales
* (head_dim/quant_block_size * 4 bytes)]
*/
template <typename scalar_t>
__global__ void indexer_k_quant_and_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
uint8_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int head_dim,
const int quant_block_size, // typically 128
const int cache_block_size,
const int cache_stride,
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;
if (slot_idx < 0 || head_dim_idx >= head_dim) {
return;
}
// Load 4 values at once using float2 (for bf16/fp16)
float2 k_val = reinterpret_cast<const float2*>(
k)[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
float amax = 0.0f;
for (int i = 0; i < VEC_SIZE; i++) {
amax = fmaxf(amax, fabsf(static_cast<float>(k_val_ptr[i])));
}
// Warp reduction to find max across quant_block_size elements
for (int mask = 16; mask > 0; mask /= 2) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask));
}
float scale = fmaxf(amax, 1e-4f) / kFp8ScaleDivisorDS;
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] =
fp8_scaled_convert<uint8_t, scalar_t>(k_val_ptr[i], scale);
}
// First thread in warp writes the scale
if (threadIdx.x == 0) {
const int64_t dst_scale_idx =
block_idx * cache_block_size * cache_stride +
cache_block_size * head_dim +
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
}
}
/**
* Gather Indexer K from Quantized Cache Kernel
*
* Gathers and dequantizes K values from the cache
*/
template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel(
const char* __restrict__ kv_cache, // [num_blocks, block_size,
// cache_stride]
char* __restrict__ dst_k, // [num_tokens, head_dim]
char* __restrict__ dst_scale, // [num_tokens, head_dim/quant_block_size*4]
const int* __restrict__ block_table, // [batch_size, num_blocks]
const int* __restrict__ cu_seq_lens, // [batch_size + 1]
const int batch_size,
const int64_t token_stride,
const int64_t head_dim,
const int64_t block_stride,
const int64_t cache_token_stride,
const int64_t cache_block_size,
const int num_blocks,
const int num_tokens,
const int quant_block_size) {
constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
// Find batch index within a block
__shared__ int batch_idx[BLOCK_Y_SIZE];
for (int iter = 0; iter < (batch_size + blockDim.x - 1) / blockDim.x;
iter++) {
int tid = iter * blockDim.x + threadIdx.x;
if (tid < batch_size) {
const int seq_start = cu_seq_lens[tid];
const int seq_end = cu_seq_lens[tid + 1];
if (token_idx >= seq_start && token_idx < seq_end) {
batch_idx[threadIdx.y] = tid;
}
}
}
__syncwarp();
if (head_idx >= head_dim || token_idx >= num_tokens) {
return;
}
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
const int block_id = block_table[batch_idx[threadIdx.y] * num_blocks +
inbatch_seq_idx / cache_block_size];
const int64_t src_block_offset = block_id * block_stride;
const int64_t cache_inblock_offset =
(inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
if (threadIdx.x == 0) {
const int64_t src_scale_offset =
src_block_offset + cache_block_size * head_dim +
cache_inblock_offset * 4 / quant_block_size;
reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
}
}
/**
* Prefill DS MLA Write Cache Kernel
*
* Writes prefill KV data to DS MLA cache format
*/
template <typename T, int VecSize = 1>
__global__ void prefill_ds_mla_cache_kernel(
const T* __restrict__ kv_nope, // [num_tokens, kv_num_heads * nope_size]
const T* __restrict__ kv_pe, // [num_tokens, kv_num_heads * pe_size]
uint8_t* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// entry_size]
const int* __restrict__ block_tables,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens,
const int* __restrict__ seq_lens_decoder,
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size, // 512 for DS MLA
const int pe_size, // 64 for DS MLA
const int block_size,
const int entry_size, // 656 for DS MLA FP8
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
(token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
const uint32_t block_offset = ori_seq_id % block_size;
if (bias < nope_hidden_size) {
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
// For DS MLA FP8, NoPE part goes to first 512 bytes
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * entry_size +
hi * block_size * entry_size + block_offset * entry_size + h_bias;
const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
// Convert to FP8 and store
for (int i = 0; i < VecSize; i++) {
float val = static_cast<float>(src_vec.val[i]);
val = fminf(fmaxf(val, -448.0f), 448.0f);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
__nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
kv_cache[tgt_idx + i] = *reinterpret_cast<uint8_t*>(&fp8_val);
#endif
}
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
// RoPE part goes after NoPE (512 bytes) + scales (16 bytes)
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * entry_size +
hi * block_size * entry_size + block_offset * entry_size + nope_size +
16 + h_bias * 2; // *2 for bf16
const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
// Copy RoPE without quantization (as bf16/fp16)
T* tgt_ptr = reinterpret_cast<T*>(&kv_cache[tgt_idx]);
for (int i = 0; i < VecSize; i++) {
tgt_ptr[i] = src_vec.val[i];
}
}
}
}
/**
* Decode DS MLA Write Cache Kernel
*/
template <typename T, int VecSize = 1>
__global__ void decode_ds_mla_cache_kernel(
const T* __restrict__ kv_nope,
const T* __restrict__ kv_pe,
uint8_t* __restrict__ kv_cache,
const int* __restrict__ block_tables,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens,
const int* __restrict__ seq_lens_encoder,
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const int entry_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
if (bias < nope_hidden_size) {
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * entry_size +
hi * block_size * entry_size + block_offset * entry_size + h_bias;
const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
for (int i = 0; i < VecSize; i++) {
float val = static_cast<float>(src_vec.val[i]);
val = fminf(fmaxf(val, -448.0f), 448.0f);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
__nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
kv_cache[tgt_idx + i] = *reinterpret_cast<uint8_t*>(&fp8_val);
#endif
}
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * entry_size +
hi * block_size * entry_size + block_offset * entry_size + nope_size +
16 + h_bias * 2;
const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
T* tgt_ptr = reinterpret_cast<T*>(&kv_cache[tgt_idx]);
for (int i = 0; i < VecSize; i++) {
tgt_ptr[i] = src_vec.val[i];
}
}
}
}
} // namespace ds_mla
+55
View File
@@ -1133,6 +1133,47 @@ std::vector<paddle::Tensor> get_attn_mask_q(
const paddle::optional<paddle::Tensor>& attn_mask_kv,
const int kv_token_num);
void RadixTopkRaggedTransform(
paddle::Tensor& input,
paddle::Tensor& output_indices,
const paddle::Tensor& offsets,
paddle::Tensor& lengths,
paddle::optional<paddle::Tensor>& seq_len_decoder,
paddle::optional<paddle::Tensor>& batch_id_per_token,
paddle::optional<paddle::Tensor>& maybe_row_states_buffer,
int top_k,
int q_num_heads = 0);
std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& scale,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool is_prefill);
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
const paddle::Tensor& k,
const paddle::Tensor& kv_cache,
const paddle::Tensor& slot_mapping,
const int64_t quant_block_size,
const std::string& scale_fmt);
std::vector<paddle::Tensor> CpGatherIndexerKQuantCacheKernel(
const paddle::Tensor& kv_cache,
paddle::Tensor& dst_k,
paddle::Tensor& dst_scale,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seq_lens);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num",
&GetExpertTokenNum,
@@ -1736,4 +1777,18 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("custom_numpy_to_tensor",
&CustomNumpyToTensor,
"custom_numpy_to_tensor function");
m.def("radix_topk_ragged_transform",
&RadixTopkRaggedTransform,
"radix_topk_ragged_transform function");
m.def("dsk_attn_write_cache", &DSMLAWriteCacheKernel, "dsk_attn_write_cache");
m.def("indexer_k_quant_and_cache",
&IndexerKQuantAndCacheKernel,
"indexer_k_quant_and_cache");
m.def("cp_gather_indexer_k_quant_cache",
&CpGatherIndexerKQuantCacheKernel,
"cp_gather_indexer_k_quant_cache");
}
+2 -1
View File
@@ -662,7 +662,8 @@ inline const char *getEnvVar(const char *varName) {
inline bool checkAttentionBackend() {
const char *backend = getEnvVar("FD_ATTENTION_BACKEND");
if (backend && std::strcmp(backend, "MLA_ATTN") == 0) {
if (backend && (std::strcmp(backend, "MLA_ATTN") == 0 ||
std::strcmp(backend, "DSA_ATTN") == 0)) {
return true;
}
return false;
@@ -0,0 +1,115 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_EXCEPTION_H_
#define FLASHINFER_EXCEPTION_H_
#include <exception>
#include <iostream>
#include <sstream>
#define FLASHINFER_ERROR(message) \
throw flashinfer::Error(__FUNCTION__, __FILE__, __LINE__, message)
// Base case for empty arguments
inline void write_to_stream(std::ostringstream& oss) {
// No-op for empty arguments
}
template <typename T>
void write_to_stream(std::ostringstream& oss, T&& val) {
oss << std::forward<T>(val);
}
template <typename T, typename... Args>
void write_to_stream(std::ostringstream& oss, T&& val, Args&&... args) {
oss << std::forward<T>(val) << " ";
write_to_stream(oss, std::forward<Args>(args)...);
}
// Helper macro to handle empty __VA_ARGS__
#define FLASHINFER_CHECK_IMPL(condition, message) \
if (!(condition)) { \
FLASHINFER_ERROR(message); \
}
// Main macro that handles both cases
#define FLASHINFER_CHECK(condition, ...) \
do { \
if (!(condition)) { \
std::ostringstream oss; \
write_to_stream(oss, ##__VA_ARGS__); \
std::string msg = oss.str(); \
if (msg.empty()) { \
msg = "Check failed: " #condition; \
} \
FLASHINFER_ERROR(msg); \
} \
} while (0)
// Warning macro
#define FLASHINFER_WARN(...) \
do { \
std::ostringstream oss; \
write_to_stream(oss, ##__VA_ARGS__); \
std::string msg = oss.str(); \
if (msg.empty()) { \
msg = "Warning triggered"; \
} \
flashinfer::Warning(__FUNCTION__, __FILE__, __LINE__, msg).emit(); \
} while (0)
namespace flashinfer {
class Error : public std::exception {
private:
std::string message_;
public:
Error(const std::string& func,
const std::string& file,
int line,
const std::string& message) {
std::ostringstream oss;
oss << "Error in function '" << func << "' "
<< "at " << file << ":" << line << ": " << message;
message_ = oss.str();
}
virtual const char* what() const noexcept override {
return message_.c_str();
}
};
class Warning {
private:
std::string message_;
public:
Warning(const std::string& func,
const std::string& file,
int line,
const std::string& message) {
std::ostringstream oss;
oss << "Warning in function '" << func << "' "
<< "at " << file << ":" << line << ": " << message;
message_ = oss.str();
}
void emit() const { std::cerr << message_ << std::endl; }
};
} // namespace flashinfer
#endif // FLASHINFER_EXCEPTION_H_
@@ -0,0 +1,144 @@
#include "indexer_topk.cuh"
#include <cuda_bf16.h>
#include "paddle/extension.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/utils/optional.h"
#include "append_attn/mem_util.cuh"
#include "append_attn/mma_tensor_op.cuh"
#include "append_attn/utils.cuh"
#include "helper.h"
// using namespace flashinfer;
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <paddle::DataType T>
cudaError_t DispatchTopK(paddle::Tensor& input,
paddle::Tensor& output_indices,
const paddle::Tensor& offsets,
paddle::Tensor& lengths,
uint32_t num_rows,
const int32_t* seq_len_decoder,
const int32_t* batch_id_per_token,
uint32_t top_k,
uint32_t q_num_heads,
uint32_t max_len,
flashinfer::sampling::RadixRowState* row_states_ptr,
cudaStream_t stream) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
cudaError_t status;
status =
flashinfer::sampling::TopKRaggedTransformDispatch<DataType_, int32_t>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
static_cast<int32_t*>(output_indices.data<int32_t>()),
static_cast<const int32_t*>(offsets.data<int32_t>()),
static_cast<int32_t*>(lengths.data<int32_t>()),
num_rows,
seq_len_decoder,
batch_id_per_token,
static_cast<uint32_t>(top_k),
static_cast<uint32_t>(q_num_heads),
max_len,
row_states_ptr,
stream);
return status;
}
void RadixTopkRaggedTransform(
paddle::Tensor& input,
paddle::Tensor& output_indices,
const paddle::Tensor& offsets,
paddle::Tensor& lengths,
paddle::optional<paddle::Tensor>& seq_len_decoder,
paddle::optional<paddle::Tensor>& batch_id_per_token,
paddle::optional<paddle::Tensor>& maybe_row_states_buffer,
int top_k,
int q_num_heads = 0) {
// CHECK_INPUT(input);
// CHECK_INPUT(output_indices);
// CHECK_INPUT(offsets);
// CHECK_INPUT(lengths);
// CHECK_DIM(2, input); // input: (num_rows, max_len)
// CHECK_DIM(2, output_indices); // output_indices: (num_rows, top_k)
// CHECK_DIM(1, offsets); // offsets: (num_rows,)
// CHECK_DIM(1, lengths); // lengths: (num_rows,)
unsigned int num_rows = input.dims()[0];
unsigned int max_len = input.dims()[1];
static cudaStream_t stream = input.stream();
cudaError_t status;
auto input_dtype = input.dtype();
// sampling::RadixRowState* row_states_ptr = nullptr;
// if (maybe_row_states_buffer.has_value()) {
// row_states_ptr =
// static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
// }
flashinfer::sampling::RadixRowState* row_states_ptr = nullptr;
if (maybe_row_states_buffer) {
auto& tensor_ptr = maybe_row_states_buffer.get();
row_states_ptr = reinterpret_cast<flashinfer::sampling::RadixRowState*>(
tensor_ptr.data<uint8_t>());
}
const int32_t* seq_len_ptr = nullptr;
if (seq_len_decoder) {
auto& tensor_ptr = seq_len_decoder.get();
seq_len_ptr = static_cast<const int32_t*>(tensor_ptr.data<int32_t>());
}
const int32_t* batch_id_per_token_ptr = nullptr;
if (batch_id_per_token) {
auto& tensor_ptr = batch_id_per_token.get();
batch_id_per_token_ptr =
static_cast<const int32_t*>(tensor_ptr.data<int32_t>());
}
if (input_dtype == paddle::DataType::BFLOAT16) {
status = DispatchTopK<paddle::DataType::BFLOAT16>(input,
output_indices,
offsets,
lengths,
num_rows,
seq_len_ptr,
batch_id_per_token_ptr,
top_k,
q_num_heads,
max_len,
row_states_ptr,
stream);
} else if (input_dtype == paddle::DataType::FLOAT32) {
status = DispatchTopK<paddle::DataType::FLOAT32>(input,
output_indices,
offsets,
lengths,
num_rows,
seq_len_ptr,
batch_id_per_token_ptr,
top_k,
q_num_heads,
max_len,
row_states_ptr,
stream);
}
}
PD_BUILD_STATIC_OP(radix_topk_ragged_transform)
.Inputs({"input",
"output_indices",
"offsets",
"lengths",
paddle::Optional("seq_len_decoder"),
paddle::Optional("batch_id_per_token"),
paddle::Optional("maybe_row_states_buffer")})
.Attrs({"top_k : int", "q_num_heads : int"})
.SetKernelFn(PD_KERNEL(RadixTopkRaggedTransform));
File diff suppressed because it is too large Load Diff
+534
View File
@@ -0,0 +1,534 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_UTILS_CUH_
#define FLASHINFER_UTILS_CUH_
#include <cuda_bf16.h>
#include <cuda_device_runtime_api.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <atomic>
#include <cstdint>
#include <iostream>
#include <type_traits>
#include <vector>
#include "exception.h"
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
// macro to turn off fp16 qk reduction to reduce binary
#ifndef FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION
#define FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION 0
#endif
#ifndef NDEBUG
#define FLASHINFER_CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
<< ") " << __FILE__ << ": line " << __LINE__ \
<< " at function " << STR(func) << std::endl; \
return e; \
} \
}
#else
#define FLASHINFER_CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
return e; \
} \
}
#endif
#define DISPATCH_USE_FP16_QK_REDUCTION( \
use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, ...) \
if (use_fp16_qk_reduction) { \
FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \
} else { \
constexpr bool USE_FP16_QK_REDUCTION = false; \
__VA_ARGS__ \
}
#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) \
if (num_mma_q == 1) { \
constexpr size_t NUM_MMA_Q = 1; \
__VA_ARGS__ \
} else if (num_mma_q == 2) { \
constexpr size_t NUM_MMA_Q = 2; \
__VA_ARGS__ \
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported num_mma_q: " << num_mma_q; \
FLASHINFER_ERROR(err_msg.str()); \
}
#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \
if (max_mma_kv >= 8) { \
constexpr size_t NUM_MMA_KV = 8; \
__VA_ARGS__ \
} else if (max_mma_kv >= 4) { \
constexpr size_t NUM_MMA_KV = 4; \
__VA_ARGS__ \
} else if (max_mma_kv >= 2) { \
constexpr size_t NUM_MMA_KV = 2; \
__VA_ARGS__ \
} else if (max_mma_kv >= 1) { \
constexpr size_t NUM_MMA_KV = 1; \
__VA_ARGS__ \
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \
FLASHINFER_ERROR(err_msg.str()); \
}
#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \
switch (cta_tile_q) { \
case 128: { \
constexpr uint32_t CTA_TILE_Q = 128; \
__VA_ARGS__ \
break; \
} \
case 64: { \
constexpr uint32_t CTA_TILE_Q = 64; \
__VA_ARGS__ \
break; \
} \
case 16: { \
constexpr uint32_t CTA_TILE_Q = 16; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \
FLASHINFER_ERROR(err_msg.str()); \
} \
}
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 1) { \
constexpr size_t GROUP_SIZE = 1; \
__VA_ARGS__ \
} else if (group_size == 2) { \
constexpr size_t GROUP_SIZE = 2; \
__VA_ARGS__ \
} else if (group_size == 3) { \
constexpr size_t GROUP_SIZE = 3; \
__VA_ARGS__ \
} else if (group_size == 4) { \
constexpr size_t GROUP_SIZE = 4; \
__VA_ARGS__ \
} else if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported group_size: " << group_size; \
FLASHINFER_ERROR(err_msg.str()); \
}
#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \
switch (mask_mode) { \
case MaskMode::kNone: { \
constexpr MaskMode MASK_MODE = MaskMode::kNone; \
__VA_ARGS__ \
break; \
} \
case MaskMode::kCausal: { \
constexpr MaskMode MASK_MODE = MaskMode::kCausal; \
__VA_ARGS__ \
break; \
} \
case MaskMode::kCustom: { \
constexpr MaskMode MASK_MODE = MaskMode::kCustom; \
__VA_ARGS__ \
break; \
} \
case MaskMode::kMultiItemScoring: { \
constexpr MaskMode MASK_MODE = MaskMode::kMultiItemScoring; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported mask_mode: " << int(mask_mode); \
FLASHINFER_ERROR(err_msg.str()); \
} \
}
// convert head_dim to compile-time constant
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 64: { \
constexpr size_t HEAD_DIM = 64; \
__VA_ARGS__ \
break; \
} \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 256: { \
constexpr size_t HEAD_DIM = 256; \
__VA_ARGS__ \
break; \
} \
case 512: { \
constexpr size_t HEAD_DIM = 512; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported head_dim: " << head_dim; \
FLASHINFER_ERROR(err_msg.str()); \
} \
}
// convert interleave to compile-time constant
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
constexpr bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
constexpr bool INTERLEAVE = false; \
__VA_ARGS__ \
}
#define DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, ...) \
switch (rope_dim) { \
case 16: { \
constexpr uint32_t ROPE_DIM = 16; \
__VA_ARGS__ \
break; \
} \
case 32: { \
constexpr uint32_t ROPE_DIM = 32; \
__VA_ARGS__ \
break; \
} \
case 64: { \
constexpr uint32_t ROPE_DIM = 64; \
__VA_ARGS__ \
break; \
} \
case 128: { \
constexpr uint32_t ROPE_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 256: { \
constexpr uint32_t ROPE_DIM = 256; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported ROPE_DIM: " << rope_dim; \
err_msg << ". Supported values: 16, 32, 64, 128, 256"; \
err_msg << " in DISPATCH_ROPE_DIM"; \
FLASHINFER_ERROR(err_msg.str()); \
} \
}
#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
switch (pos_encoding_mode) { \
case PosEncodingMode::kNone: { \
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \
__VA_ARGS__ \
break; \
} \
case PosEncodingMode::kRoPELlama: { \
constexpr PosEncodingMode POS_ENCODING_MODE = \
PosEncodingMode::kRoPELlama; \
__VA_ARGS__ \
break; \
} \
case PosEncodingMode::kALiBi: { \
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
FLASHINFER_ERROR(err_msg.str()); \
} \
}
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
switch (aligned_vec_size) { \
case 16: { \
constexpr size_t ALIGNED_VEC_SIZE = 16; \
__VA_ARGS__ \
break; \
} \
case 8: { \
constexpr size_t ALIGNED_VEC_SIZE = 8; \
__VA_ARGS__ \
break; \
} \
case 4: { \
constexpr size_t ALIGNED_VEC_SIZE = 4; \
__VA_ARGS__ \
break; \
} \
case 2: { \
constexpr size_t ALIGNED_VEC_SIZE = 2; \
__VA_ARGS__ \
break; \
} \
case 1: { \
constexpr size_t ALIGNED_VEC_SIZE = 1; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
FLASHINFER_ERROR(err_msg.str()); \
} \
}
#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM( \
compute_capacity, NUM_STAGES_SMEM, ...) \
if (compute_capacity.first >= 8) { \
constexpr uint32_t NUM_STAGES_SMEM = 2; \
__VA_ARGS__ \
} else { \
constexpr uint32_t NUM_STAGES_SMEM = 1; \
__VA_ARGS__ \
}
namespace flashinfer {
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 ceil_div(const T1 x,
const T2 y) noexcept {
return (x + y - 1) / y;
}
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 round_up(const T1 x,
const T2 y) noexcept {
return ceil_div(x, y) * y;
}
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 round_down(
const T1 x, const T2 y) noexcept {
return (x / y) * y;
}
inline std::pair<int, int> GetCudaComputeCapability() {
int device_id = 0;
cudaGetDevice(&device_id);
int major = 0, minor = 0;
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
return std::make_pair(major, minor);
}
// This function is thread-safe and cached the sm_count.
// But it will only check the current CUDA device, thus assuming each process
// handles single GPU.
inline int GetCudaMultiProcessorCount() {
static std::atomic<int> sm_count{0};
int cached = sm_count.load(std::memory_order_relaxed);
if (cached == 0) {
int device_id;
cudaGetDevice(&device_id);
cudaDeviceProp device_prop;
cudaGetDeviceProperties(&device_prop, device_id);
cached = device_prop.multiProcessorCount;
sm_count.store(cached, std::memory_order_relaxed);
}
return cached;
}
template <typename T>
inline void DebugPrintCUDAArray(T* device_ptr,
size_t size,
std::string prefix = "") {
std::vector<T> host_array(size);
std::cout << prefix;
cudaMemcpy(
host_array.data(), device_ptr, size * sizeof(T), cudaMemcpyDeviceToHost);
for (size_t i = 0; i < size; ++i) {
std::cout << host_array[i] << " ";
}
std::cout << std::endl;
}
inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len,
uint32_t head_dim) {
if (avg_packed_qo_len > 64 && head_dim < 256) {
return 128;
} else {
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first >= 8) {
// Ampere or newer
if (avg_packed_qo_len > 16) {
// avg_packed_qo_len <= 64
return 64;
} else {
// avg_packed_qo_len <= 16
return 16;
}
} else {
// NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
return 64;
}
}
}
inline int UpPowerOfTwo(int x) {
// Returns the smallest power of two greater than or equal to x
if (x <= 0) return 1;
--x;
x |= x >> 1;
x |= x >> 2;
x |= x >> 4;
x |= x >> 8;
x |= x >> 16;
return x + 1;
}
#define LOOP_SPLIT_MASK(iter, COND1, COND2, ...) \
{ \
_Pragma("unroll 1") for (; (COND1); (iter) -= 1) { \
constexpr bool WITH_MASK = true; \
__VA_ARGS__ \
} \
_Pragma("unroll 1") for (; (COND2); (iter) -= 1) { \
constexpr bool WITH_MASK = false; \
__VA_ARGS__ \
} \
}
/*!
* \brief Return x - y if x > y, otherwise return 0.
*/
__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
uint32_t y) {
return (x > y) ? x - y : 0U;
}
// ======================= PTX Memory Utility Functions =======================
// Non-atomic global memory access with cache streaming hint (cs)
// These are useful for streaming memory access patterns where data is used once
/*!
* \brief Get the lane ID within a warp (0-31)
*/
__forceinline__ __device__ int get_lane_id() {
int lane_id;
asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
return lane_id;
}
/*!
* \brief Non-atomic global load for int (4 bytes) with cache streaming hint
*/
__forceinline__ __device__ int ld_na_global_v1(const int* addr) {
int val;
asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr));
return val;
}
/*!
* \brief Non-atomic global load for int2 (8 bytes) with cache streaming hint
*/
__forceinline__ __device__ int2 ld_na_global_v2(const int2* addr) {
int2 val;
asm volatile("ld.global.cs.v2.b32 {%0, %1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "l"(addr));
return val;
}
/*!
* \brief Non-atomic global store for int (4 bytes) with cache streaming hint
*/
__forceinline__ __device__ void st_na_global_v1(int* addr, int val) {
asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
}
/*!
* \brief Non-atomic global store for int2 (8 bytes) with cache streaming hint
*/
__forceinline__ __device__ void st_na_global_v2(int2* addr, int2 val) {
asm volatile("st.global.cs.v2.b32 [%0], {%1, %2};" ::"l"(addr),
"r"(val.x),
"r"(val.y));
}
/*!
* \brief Prefetch data to L2 cache
*/
template <typename T>
__forceinline__ __device__ void prefetch_L2(const T* addr) {
asm volatile("prefetch.global.L2 [%0];" ::"l"(addr));
}
__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) {
uint32_t tmp = a;
a = b;
b = tmp;
}
__device__ __forceinline__ uint32_t dim2_offset(const uint32_t& dim_a,
const uint32_t& idx_b,
const uint32_t& idx_a) {
return idx_b * dim_a + idx_a;
}
__device__ __forceinline__ uint32_t dim3_offset(const uint32_t& dim_b,
const uint32_t& dim_a,
const uint32_t& idx_c,
const uint32_t& idx_b,
const uint32_t& idx_a) {
return (idx_c * dim_b + idx_b) * dim_a + idx_a;
}
__device__ __forceinline__ uint32_t dim4_offset(const uint32_t& dim_c,
const uint32_t& dim_b,
const uint32_t& dim_a,
const uint32_t& idx_d,
const uint32_t& idx_c,
const uint32_t& idx_b,
const uint32_t& idx_a) {
return ((idx_d * dim_c + idx_c) * dim_b + idx_b) * dim_a + idx_a;
}
#define DEFINE_HAS_MEMBER(member) \
template <typename T, typename = void> \
struct has_##member : std::false_type {}; \
template <typename T> \
struct has_##member<T, std::void_t<decltype(std::declval<T>().member)>> \
: std::true_type {}; \
template <typename T> \
inline constexpr bool has_##member##_v = has_##member<T>::value;
} // namespace flashinfer
#endif // FLASHINFER_UTILS_CUH_
File diff suppressed because it is too large Load Diff