[Models][OP][Optimization] Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM (#6689)

* Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM
This commit is contained in:
AIbin
2026-03-10 15:05:14 +08:00
committed by GitHub
parent 25c479312d
commit c3aceb6bdc
22 changed files with 8022 additions and 143 deletions
@@ -0,0 +1,616 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* DeepSeek Kv3.2 (DsKv3.2) Attention WriteCache Implementation
*
* This file implements writecache operations for DeepSeek MLA (Multi-head
* Latent Attention) with FP8 quantization support, migrated from vLLM.
*
* Key features:
* 1. DS MLA FP8 cache format (656 bytes per token):
* - 512 bytes: quantized NoPE part (fp8_e4m3)
* - 16 bytes: scale factors (4 x float32)
* - 128 bytes: RoPE part (64 x bf16, unquantized)
*
* 2. Standard MLA cache format (kv_lora_rank + pe_dim elements)
*
* 3. Indexer K quantization and cache operations
*/
#include "ds_mla_cache_kernel.cuh"
#include "helper.h"
#include "remote_cache_kv_ipc.h"
//==============================================================================
// DS MLA FP8 WriteCache Implementation
//==============================================================================
/**
* Prefill stage: Write KV cache with DS MLA FP8 format
*/
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto num_tokens = slot_mapping.dims()[0];
auto kv_lora_rank = 512; // DS MLA uses 512
auto pe_dim = 64; // DS MLA uses 64
auto block_size = meta_data.block_size;
// Entry size for DS MLA FP8: 512 (fp8) + 16 (scales) + 128 (rope bf16) = 656
// bytes
const int entry_size = 656;
// Launch kernel with 96 threads (64 for NoPE, 32 for RoPE)
dim3 grid(num_tokens);
dim3 block(96);
const auto& kv_cache_dims = kv_cache->dims();
int block_stride = kv_cache->strides()[0];
int entry_stride = entry_size;
int kv_c_stride = kv_nope.strides()[0];
int k_pe_stride = kv_pe.strides()[0];
ds_mla::concat_and_cache_ds_mla_kernel<DataType_><<<grid, block, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
slot_mapping.data<int64_t>(),
block_stride,
entry_stride,
kv_c_stride,
k_pe_stride,
kv_lora_rank,
pe_dim,
block_size);
// Handle PD disaggregation signal
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
/**
* Decode stage: Write KV cache with DS MLA FP8 format
*/
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeDSMLAWriteCacheFP8(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto num_tokens = slot_mapping.dims()[0];
auto kv_lora_rank = 512;
auto pe_dim = 64;
auto block_size = meta_data.block_size;
const int entry_size = 656;
dim3 grid(num_tokens);
dim3 block(96);
const auto& kv_cache_dims = kv_cache->dims();
int block_stride = kv_cache->strides()[0];
int entry_stride = entry_size;
int kv_c_stride = kv_nope.strides()[0];
int k_pe_stride = kv_pe.strides()[0];
ds_mla::concat_and_cache_ds_mla_kernel<DataType_><<<grid, block, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
slot_mapping.data<int64_t>(),
block_stride,
entry_stride,
kv_c_stride,
k_pe_stride,
kv_lora_rank,
pe_dim,
block_size);
return {};
}
//==============================================================================
// Standard MLA WriteCache Implementation
//==============================================================================
/**
* Prefill stage: Write KV cache with standard MLA format
*/
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillDSMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const float* scale,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto num_tokens = slot_mapping.dims()[0];
auto kv_lora_rank = meta_data.head_dims_v;
auto pe_dim = meta_data.head_dims - meta_data.head_dims_v;
auto block_size = meta_data.block_size;
const auto& kv_cache_dims = kv_cache->dims();
int block_stride = kv_cache->strides()[0];
int entry_stride = kv_cache->strides()[1];
int kv_c_stride = kv_nope.strides()[0];
int k_pe_stride = kv_pe.strides()[0];
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
ds_mla::concat_and_cache_mla_kernel<DataType_, DataType_>
<<<grid, block, 0, stream>>>(
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
slot_mapping.data<int64_t>(),
block_stride,
entry_stride,
kv_c_stride,
k_pe_stride,
kv_lora_rank,
pe_dim,
block_size,
scale);
// Handle PD disaggregation signal
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
//==============================================================================
// Indexer K Quantization and Cache Operations
//==============================================================================
/**
* Quantize K tensor to FP8 and write to cache
*/
template <paddle::DataType T>
std::vector<paddle::Tensor> IndexerKQuantAndCache(
const paddle::Tensor& k,
const paddle::Tensor& slot_mapping,
const int head_dim,
const int quant_block_size,
const int cache_block_size,
const int cache_stride,
const bool use_ue8m0,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
int num_tokens = k.dims()[0];
constexpr int vec_size = 4;
dim3 grid(num_tokens,
(head_dim + quant_block_size * vec_size - 1) /
(quant_block_size * vec_size));
dim3 block(32, vec_size);
ds_mla::indexer_k_quant_and_cache_kernel<DataType_>
<<<grid, block, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(k.data<data_t>())),
reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
slot_mapping.data<int64_t>(),
head_dim,
quant_block_size,
cache_block_size,
cache_stride,
use_ue8m0);
return {};
}
/**
* Gather K from quantized cache
*/
void CpGatherIndexerKQuantCache(const paddle::Tensor& kv_cache,
paddle::Tensor& dst_k,
paddle::Tensor& dst_scale,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seq_lens,
cudaStream_t& stream) {
int batch_size = block_table.dims()[0];
int num_tokens = dst_k.dims()[0];
int head_dim = dst_k.dims()[1];
int quant_block_size = head_dim * 4 / dst_scale.dims()[1];
constexpr int vec_size = 16;
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
ds_mla::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
<<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE, \
(head_dim + 8 * vec_size - 1) / (8 * vec_size)), \
dim3(8, BLOCK_Y_SIZE), \
0, \
stream>>>(reinterpret_cast<const char*>(kv_cache.data<uint8_t>()), \
reinterpret_cast<char*>(dst_k.data<uint8_t>()), \
reinterpret_cast<char*>(dst_scale.data<float>()), \
block_table.data<int>(), \
cu_seq_lens.data<int>(), \
batch_size, \
dst_k.strides()[0], \
dst_k.dims()[1], \
kv_cache.strides()[0], \
kv_cache.strides()[1], \
kv_cache.dims()[1], \
block_table.dims()[1], \
num_tokens, \
quant_block_size);
if (num_tokens < 32) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
} else if (num_tokens < 64) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2);
} else if (num_tokens < 128) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4);
} else if (num_tokens < 256) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8);
} else if (num_tokens < 512) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16);
} else {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
}
#undef CALL_CP_GATHER_INDEXER_K_QUANT_CACHE
}
//==============================================================================
// Kernel Entry Points
//==============================================================================
/**
* DS MLA WriteCache entry point - supports both FP8 and standard formats
*/
std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& scale,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool is_prefill) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = seq_lens_decoder.dims()[0];
const float* scale_ptr = scale ? scale.get().data<float>() : nullptr;
if (cache_quant_type_str == "fp8_ds_mla") {
// FP8 DS MLA format
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
if (is_prefill) {
return PrefillDSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
} else {
return DecodeDSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
false,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
case paddle::DataType::FLOAT16: {
if (is_prefill) {
return PrefillDSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
} else {
return DecodeDSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
false,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
default:
PD_THROW("Unsupported dtype for DS MLA FP8 cache");
}
} else {
// Standard MLA format (auto/bf16/fp16)
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillDSMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
scale_ptr,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillDSMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
scale_ptr,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
default:
PD_THROW("Unsupported dtype for DS MLA cache");
}
}
return {};
}
/**
* Indexer K Quant and Cache entry point
*/
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
const paddle::Tensor& k,
const paddle::Tensor& kv_cache,
const paddle::Tensor& slot_mapping,
const int64_t quant_block_size,
const std::string& scale_fmt) {
cudaStream_t stream = k.stream();
int num_tokens = k.dims()[0];
int head_dim = k.dims()[1];
int cache_block_size = kv_cache.dims()[1];
int cache_stride = kv_cache.dims()[2];
bool use_ue8m0 = scale_fmt == "ue8m0";
switch (k.dtype()) {
case paddle::DataType::BFLOAT16: {
return IndexerKQuantAndCache<paddle::DataType::BFLOAT16>(
k,
slot_mapping,
head_dim,
quant_block_size,
cache_block_size,
cache_stride,
use_ue8m0,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return IndexerKQuantAndCache<paddle::DataType::FLOAT16>(
k,
slot_mapping,
head_dim,
quant_block_size,
cache_block_size,
cache_stride,
use_ue8m0,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
default:
PD_THROW("Unsupported dtype for Indexer K Quant");
}
return {};
}
/**
* Gather Indexer K from Quant Cache entry point
*/
std::vector<paddle::Tensor> CpGatherIndexerKQuantCacheKernel(
const paddle::Tensor& kv_cache,
paddle::Tensor& dst_k,
paddle::Tensor& dst_scale,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seq_lens) {
cudaStream_t stream = kv_cache.stream();
CpGatherIndexerKQuantCache(
kv_cache, dst_k, dst_scale, block_table, cu_seq_lens, stream);
return {};
}
//==============================================================================
// Paddle Custom Operator Registration
//==============================================================================
PD_BUILD_STATIC_OP(ds_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"slot_mapping",
"seq_lens",
"seq_lens_decoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
paddle::Optional("kv_signal_data"),
paddle::Optional("scale")})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int",
"is_prefill: bool"})
.SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel));
PD_BUILD_STATIC_OP(indexer_k_quant_and_cache)
.Inputs({"k", "kv_cache", "slot_mapping"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"quant_block_size: int64_t", "scale_fmt: std::string"})
.SetKernelFn(PD_KERNEL(IndexerKQuantAndCacheKernel));
PD_BUILD_STATIC_OP(cp_gather_indexer_k_quant_cache)
.Inputs({"kv_cache", "dst_k", "dst_scale", "block_table", "cu_seq_lens"})
.Outputs({"dst_k_out", "dst_scale_out"})
.SetInplaceMap({{"dst_k", "dst_k_out"}, {"dst_scale", "dst_scale_out"}})
.SetKernelFn(PD_KERNEL(CpGatherIndexerKQuantCacheKernel));
@@ -0,0 +1,548 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cfloat>
#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"
// FP8 scale divisor constant (for SM90+)
#if defined(__gfx942__)
constexpr float kFp8ScaleDivisorDS = 224.f;
#else
constexpr float kFp8ScaleDivisorDS = 448.f;
#endif
namespace ds_mla {
/**
* FP8 scaled conversion utilities
*/
template <typename OutT, typename InT>
__device__ __forceinline__ OutT fp8_scaled_convert(InT src, float scale) {
return static_cast<OutT>(static_cast<float>(src) / scale);
}
template <>
__device__ __forceinline__ uint8_t
fp8_scaled_convert<uint8_t, __nv_bfloat16>(__nv_bfloat16 src, float scale) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
float val = __bfloat162float(src) / scale;
val = fminf(fmaxf(val, -448.0f), 448.0f);
__nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
return *reinterpret_cast<uint8_t*>(&fp8_val);
#else
return 0;
#endif
}
template <>
__device__ __forceinline__ uint8_t
fp8_scaled_convert<uint8_t, half>(half src, float scale) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
float val = __half2float(src) / scale;
val = fminf(fmaxf(val, -448.0f), 448.0f);
__nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
return *reinterpret_cast<uint8_t*>(&fp8_val);
#else
return 0;
#endif
}
/**
* DeepSeek MLA FP8 Cache Write Kernel
*
* Cache format (fp8_ds_mla - 656 bytes per token):
* - First 512 bytes: quantized NoPE part (512 x fp8_e4m3)
* - Next 16 bytes: scale factors (4 x float32, one per 128 fp8 values)
* - Last 128 bytes: RoPE part (64 x bfloat16, not quantized)
*
* Thread organization:
* - First 2 warps (64 threads): handle NoPE FP8 quantization
* - Last 1 warp (32 threads): handle RoPE copy
* - Total: 96 threads per block
*/
template <typename scalar_t>
__global__ void concat_and_cache_ds_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
uint8_t* __restrict__ kv_cache, // [num_blocks, block_size,
// cache_entry_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, // stride per block in cache
const int entry_stride, // stride per token entry in cache
const int kv_c_stride, // stride for kv_c input
const int k_pe_stride, // stride for k_pe input
const int kv_lora_rank, // 512 for DS MLA
const int pe_dim, // 64 for DS MLA
const int block_size // number of tokens per cache block
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int64_t dst_idx_start =
block_idx * block_stride + block_offset * entry_stride;
// Cast kv_cache to 16-bit for RoPE values
scalar_t* kv_cache_16bit =
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
// The last warp handles the RoPE part
if (threadIdx.x >= 64) {
// Each thread handles two elements of RoPE
const int8_t pe_idx_start = (threadIdx.x - 64) * 2;
const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start;
// Vectorized load of two 16-bit values, performed as one 32-bit load
const int32_t vals = *reinterpret_cast<const int32_t*>(&k_pe[src_idx]);
// RoPE values start after the packed 8-bit NoPE values and the 32-bit
// scales Position: kv_lora_rank/2 (256 bytes in 16-bit units) + 8 (16 bytes
// of scales in 16-bit units)
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start;
// Vectorized store of two 16-bit values
*reinterpret_cast<int32_t*>(&kv_cache_16bit[dst_idx]) = vals;
return;
}
// The first two warps handle the NoPE part
const int8_t warp_idx = threadIdx.x >> 5;
const int8_t lane_idx = threadIdx.x & 31;
const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
// Each thread handles 8 elements of NoPE
const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8);
// Vectorized load of eight 16-bit values
const int4 vals_i4 = *reinterpret_cast<const int4*>(&kv_c[src_idx_start]);
const scalar_t* vals = reinterpret_cast<const scalar_t*>(&vals_i4);
// Max absolute value of this thread's elements
float max_abs = fmaxf(fmaxf(fmaxf(fabsf(static_cast<float>(vals[0])),
fabsf(static_cast<float>(vals[1]))),
fmaxf(fabsf(static_cast<float>(vals[2])),
fabsf(static_cast<float>(vals[3])))),
fmaxf(fmaxf(fabsf(static_cast<float>(vals[4])),
fabsf(static_cast<float>(vals[5]))),
fmaxf(fabsf(static_cast<float>(vals[6])),
fabsf(static_cast<float>(vals[7])))));
// Warp-level reduction to find the max absolute value in each half-warp
#pragma unroll
for (int offset = 8; offset > 0; offset /= 2) {
max_abs = fmaxf(max_abs, __shfl_xor_sync(0xFFFF, max_abs, offset, 16));
}
// Compute the scale for the tile
float tile_scale = fmaxf(max_abs / kFp8ScaleDivisorDS, FLT_MIN);
// The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) {
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
kv_cache_32bit[dst_idx] = tile_scale;
}
// Now all threads in the block scale and write their elements
const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8);
uint8_t result[8];
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] = fp8_scaled_convert<uint8_t, scalar_t>(vals[i], tile_scale);
}
// Store as aligned 64-bit writes
*reinterpret_cast<uint64_t*>(&kv_cache[dst_idx_base]) =
*reinterpret_cast<const uint64_t*>(result);
}
/**
* Standard MLA Cache Write Kernel (non-FP8)
*
* For auto/bf16/fp16 cache types
*/
template <typename scalar_t, typename cache_t>
__global__ void concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride,
const int entry_stride,
const int kv_c_stride,
const int k_pe_stride,
const int kv_lora_rank,
const int pe_dim,
const int block_size,
const float* scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
// Copy kv_c (NoPE part)
for (int i = threadIdx.x; i < kv_lora_rank; i += blockDim.x) {
const int64_t src_idx = token_idx * kv_c_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i;
kv_cache[dst_idx] = static_cast<cache_t>(kv_c[src_idx]);
}
// Copy k_pe (RoPE part)
for (int i = threadIdx.x; i < pe_dim; i += blockDim.x) {
const int64_t src_idx = token_idx * k_pe_stride + i;
const int64_t dst_idx = block_idx * block_stride +
block_offset * entry_stride + kv_lora_rank + i;
kv_cache[dst_idx] = static_cast<cache_t>(k_pe[src_idx]);
}
}
/**
* Indexer K Quantization and Cache Kernel
*
* Quantizes K values to FP8 and stores them in cache with scale factors
* Cache layout: [quantized_k (head_dim bytes)] + [scales
* (head_dim/quant_block_size * 4 bytes)]
*/
template <typename scalar_t>
__global__ void indexer_k_quant_and_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
uint8_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int head_dim,
const int quant_block_size, // typically 128
const int cache_block_size,
const int cache_stride,
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;
if (slot_idx < 0 || head_dim_idx >= head_dim) {
return;
}
// Load 4 values at once using float2 (for bf16/fp16)
float2 k_val = reinterpret_cast<const float2*>(
k)[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
float amax = 0.0f;
for (int i = 0; i < VEC_SIZE; i++) {
amax = fmaxf(amax, fabsf(static_cast<float>(k_val_ptr[i])));
}
// Warp reduction to find max across quant_block_size elements
for (int mask = 16; mask > 0; mask /= 2) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask));
}
float scale = fmaxf(amax, 1e-4f) / kFp8ScaleDivisorDS;
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] =
fp8_scaled_convert<uint8_t, scalar_t>(k_val_ptr[i], scale);
}
// First thread in warp writes the scale
if (threadIdx.x == 0) {
const int64_t dst_scale_idx =
block_idx * cache_block_size * cache_stride +
cache_block_size * head_dim +
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
}
}
/**
* Gather Indexer K from Quantized Cache Kernel
*
* Gathers and dequantizes K values from the cache
*/
template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel(
const char* __restrict__ kv_cache, // [num_blocks, block_size,
// cache_stride]
char* __restrict__ dst_k, // [num_tokens, head_dim]
char* __restrict__ dst_scale, // [num_tokens, head_dim/quant_block_size*4]
const int* __restrict__ block_table, // [batch_size, num_blocks]
const int* __restrict__ cu_seq_lens, // [batch_size + 1]
const int batch_size,
const int64_t token_stride,
const int64_t head_dim,
const int64_t block_stride,
const int64_t cache_token_stride,
const int64_t cache_block_size,
const int num_blocks,
const int num_tokens,
const int quant_block_size) {
constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
// Find batch index within a block
__shared__ int batch_idx[BLOCK_Y_SIZE];
for (int iter = 0; iter < (batch_size + blockDim.x - 1) / blockDim.x;
iter++) {
int tid = iter * blockDim.x + threadIdx.x;
if (tid < batch_size) {
const int seq_start = cu_seq_lens[tid];
const int seq_end = cu_seq_lens[tid + 1];
if (token_idx >= seq_start && token_idx < seq_end) {
batch_idx[threadIdx.y] = tid;
}
}
}
__syncwarp();
if (head_idx >= head_dim || token_idx >= num_tokens) {
return;
}
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
const int block_id = block_table[batch_idx[threadIdx.y] * num_blocks +
inbatch_seq_idx / cache_block_size];
const int64_t src_block_offset = block_id * block_stride;
const int64_t cache_inblock_offset =
(inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
if (threadIdx.x == 0) {
const int64_t src_scale_offset =
src_block_offset + cache_block_size * head_dim +
cache_inblock_offset * 4 / quant_block_size;
reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
}
}
/**
* Prefill DS MLA Write Cache Kernel
*
* Writes prefill KV data to DS MLA cache format
*/
template <typename T, int VecSize = 1>
__global__ void prefill_ds_mla_cache_kernel(
const T* __restrict__ kv_nope, // [num_tokens, kv_num_heads * nope_size]
const T* __restrict__ kv_pe, // [num_tokens, kv_num_heads * pe_size]
uint8_t* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// entry_size]
const int* __restrict__ block_tables,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens,
const int* __restrict__ seq_lens_decoder,
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size, // 512 for DS MLA
const int pe_size, // 64 for DS MLA
const int block_size,
const int entry_size, // 656 for DS MLA FP8
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
(token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
const uint32_t block_offset = ori_seq_id % block_size;
if (bias < nope_hidden_size) {
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
// For DS MLA FP8, NoPE part goes to first 512 bytes
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * entry_size +
hi * block_size * entry_size + block_offset * entry_size + h_bias;
const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
// Convert to FP8 and store
for (int i = 0; i < VecSize; i++) {
float val = static_cast<float>(src_vec.val[i]);
val = fminf(fmaxf(val, -448.0f), 448.0f);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
__nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
kv_cache[tgt_idx + i] = *reinterpret_cast<uint8_t*>(&fp8_val);
#endif
}
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
// RoPE part goes after NoPE (512 bytes) + scales (16 bytes)
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * entry_size +
hi * block_size * entry_size + block_offset * entry_size + nope_size +
16 + h_bias * 2; // *2 for bf16
const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
// Copy RoPE without quantization (as bf16/fp16)
T* tgt_ptr = reinterpret_cast<T*>(&kv_cache[tgt_idx]);
for (int i = 0; i < VecSize; i++) {
tgt_ptr[i] = src_vec.val[i];
}
}
}
}
/**
* Decode DS MLA Write Cache Kernel
*/
template <typename T, int VecSize = 1>
__global__ void decode_ds_mla_cache_kernel(
const T* __restrict__ kv_nope,
const T* __restrict__ kv_pe,
uint8_t* __restrict__ kv_cache,
const int* __restrict__ block_tables,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens,
const int* __restrict__ seq_lens_encoder,
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const int entry_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
if (bias < nope_hidden_size) {
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * entry_size +
hi * block_size * entry_size + block_offset * entry_size + h_bias;
const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
for (int i = 0; i < VecSize; i++) {
float val = static_cast<float>(src_vec.val[i]);
val = fminf(fmaxf(val, -448.0f), 448.0f);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
__nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
kv_cache[tgt_idx + i] = *reinterpret_cast<uint8_t*>(&fp8_val);
#endif
}
} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx =
block_idx * kv_num_heads * block_size * entry_size +
hi * block_size * entry_size + block_offset * entry_size + nope_size +
16 + h_bias * 2;
const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
T* tgt_ptr = reinterpret_cast<T*>(&kv_cache[tgt_idx]);
for (int i = 0; i < VecSize; i++) {
tgt_ptr[i] = src_vec.val[i];
}
}
}
}
} // namespace ds_mla