mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Models][OP][Optimization] Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM (#6689)
* Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM
This commit is contained in:
@@ -0,0 +1,616 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/**
|
||||
* DeepSeek Kv3.2 (DsKv3.2) Attention WriteCache Implementation
|
||||
*
|
||||
* This file implements writecache operations for DeepSeek MLA (Multi-head
|
||||
* Latent Attention) with FP8 quantization support, migrated from vLLM.
|
||||
*
|
||||
* Key features:
|
||||
* 1. DS MLA FP8 cache format (656 bytes per token):
|
||||
* - 512 bytes: quantized NoPE part (fp8_e4m3)
|
||||
* - 16 bytes: scale factors (4 x float32)
|
||||
* - 128 bytes: RoPE part (64 x bf16, unquantized)
|
||||
*
|
||||
* 2. Standard MLA cache format (kv_lora_rank + pe_dim elements)
|
||||
*
|
||||
* 3. Indexer K quantization and cache operations
|
||||
*/
|
||||
|
||||
#include "ds_mla_cache_kernel.cuh"
|
||||
#include "helper.h"
|
||||
#include "remote_cache_kv_ipc.h"
|
||||
|
||||
//==============================================================================
|
||||
// DS MLA FP8 WriteCache Implementation
|
||||
//==============================================================================
|
||||
|
||||
/**
 * Prefill stage: Write KV cache with DS MLA FP8 format.
 *
 * Launches ds_mla::concat_and_cache_ds_mla_kernel with one CUDA block per
 * token and 96 threads per block (the first 64 threads quantize the NoPE
 * part to fp8_e4m3, the last warp copies the RoPE part unquantized).
 * Each cache entry is 656 bytes: 512 B fp8 NoPE + 16 B scales (4 x float32)
 * + 128 B 16-bit RoPE.
 *
 * After the kernel is enqueued, a PD-disaggregation "cache write completed"
 * host callback may be scheduled on the same stream, gated by the
 * FLAGS_fmt_write_cache_completed_signal and
 * FLAGS_use_pd_disaggregation_per_chunk environment variables.
 *
 * NOTE(review): seq_lens, seq_lens_decoder, batch_id_per_token,
 * cu_seqlens_q, block_tables and max_seq_len are not read by this FP8 path;
 * they are kept for signature parity with the other writecache helpers.
 *
 * @param kv_nope      [num_tokens, kv_lora_rank] latent (NoPE) part.
 * @param kv_pe        [num_tokens, pe_dim] RoPE part.
 * @param slot_mapping [num_tokens] destination slot per token (-1 = padded).
 * @param kv_cache     Paged byte cache, updated in place.
 * @return Empty vector; the op works in place on kv_cache.
 */
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& slot_mapping,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto num_tokens = slot_mapping.dims()[0];
  auto kv_lora_rank = 512;  // DS MLA uses 512
  auto pe_dim = 64;         // DS MLA uses 64
  auto block_size = meta_data.block_size;

  // Entry size for DS MLA FP8: 512 (fp8) + 16 (scales) + 128 (rope bf16)
  // = 656 bytes
  const int entry_size = 656;

  // Launch kernel with 96 threads (64 for NoPE, 32 for RoPE)
  dim3 grid(num_tokens);
  dim3 block(96);

  int block_stride = kv_cache->strides()[0];
  int entry_stride = entry_size;
  int kv_c_stride = kv_nope.strides()[0];
  int k_pe_stride = kv_pe.strides()[0];

  ds_mla::concat_and_cache_ds_mla_kernel<DataType_><<<grid, block, 0, stream>>>(
      reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
      reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
      slot_mapping.data<int64_t>(),
      block_stride,
      entry_stride,
      kv_c_stride,
      k_pe_stride,
      kv_lora_rank,
      pe_dim,
      block_size);

  // Handle PD disaggregation signal: host callbacks run on the stream after
  // the cache write completes.
  const char* fmt_write_cache_completed_signal_str =
      std::getenv("FLAGS_fmt_write_cache_completed_signal");
  const char* FLAGS_use_pd_disaggregation_per_chunk =
      std::getenv("FLAGS_use_pd_disaggregation_per_chunk");

  if (fmt_write_cache_completed_signal_str &&
      (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
       std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
    if (FLAGS_use_pd_disaggregation_per_chunk &&
        (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
         std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
      // Per-chunk signaling does not need the signal tensor.
      cudaLaunchHostFunc(
          stream,
          &(RemoteCacheKvIpc::
                save_cache_kv_complete_signal_layerwise_per_query),
          (void*)nullptr);
    } else {
      if (kv_signal_data) {
        cudaLaunchHostFunc(
            stream,
            &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
            (void*)(const_cast<int64_t*>(
                kv_signal_data.get().data<int64_t>())));
      }
    }
  }
  return {};
}
|
||||
|
||||
/**
 * Decode stage: Write KV cache with DS MLA FP8 format.
 *
 * Same kernel and cache layout as the prefill FP8 path (one CUDA block per
 * token, 96 threads, 656-byte entries), but without PD-disaggregation
 * signaling.
 *
 * NOTE(review): seq_lens, seq_lens_encoder, batch_id_per_token,
 * cu_seqlens_q, block_tables, max_seq_len and speculate_decoder are not
 * read here; they are kept for signature parity with the other writecache
 * helpers.
 *
 * @return Empty vector; kv_cache is updated in place.
 */
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeDSMLAWriteCacheFP8(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& slot_mapping,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const int max_seq_len,
    const bool speculate_decoder,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto num_tokens = slot_mapping.dims()[0];
  auto kv_lora_rank = 512;  // DS MLA uses 512
  auto pe_dim = 64;         // DS MLA uses 64
  auto block_size = meta_data.block_size;
  // 512 (fp8) + 16 (scales) + 128 (rope) = 656 bytes per token
  const int entry_size = 656;

  dim3 grid(num_tokens);
  dim3 block(96);  // 64 threads for NoPE, 32 for RoPE

  int block_stride = kv_cache->strides()[0];
  int entry_stride = entry_size;
  int kv_c_stride = kv_nope.strides()[0];
  int k_pe_stride = kv_pe.strides()[0];

  ds_mla::concat_and_cache_ds_mla_kernel<DataType_><<<grid, block, 0, stream>>>(
      reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
      reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
      slot_mapping.data<int64_t>(),
      block_stride,
      entry_stride,
      kv_c_stride,
      k_pe_stride,
      kv_lora_rank,
      pe_dim,
      block_size);

  return {};
}
|
||||
|
||||
//==============================================================================
|
||||
// Standard MLA WriteCache Implementation
|
||||
//==============================================================================
|
||||
|
||||
/**
 * Prefill stage: Write KV cache with standard (unquantized) MLA format.
 *
 * Cache entries hold kv_lora_rank + pe_dim elements of the native dtype.
 * One CUDA block per token; block width = min(kv_lora_rank, 512) threads,
 * each striding over the NoPE and then the RoPE elements.
 *
 * After the kernel is enqueued, a PD-disaggregation completion host
 * callback may be scheduled on the same stream, gated by the
 * FLAGS_fmt_write_cache_completed_signal and
 * FLAGS_use_pd_disaggregation_per_chunk environment variables.
 *
 * NOTE(review): seq_lens, seq_lens_decoder, batch_id_per_token,
 * cu_seqlens_q, block_tables and max_seq_len are not read here; kept for
 * signature parity with the FP8 path.
 *
 * @param scale Forwarded to the kernel; with cache_t == scalar_t the
 *              non-quantized kernel body does not read it.
 * @return Empty vector; kv_cache is updated in place.
 */
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillDSMLAWriteCache(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& slot_mapping,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const float* scale,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto num_tokens = slot_mapping.dims()[0];
  auto kv_lora_rank = meta_data.head_dims_v;
  auto pe_dim = meta_data.head_dims - meta_data.head_dims_v;
  auto block_size = meta_data.block_size;

  int block_stride = kv_cache->strides()[0];
  int entry_stride = kv_cache->strides()[1];
  int kv_c_stride = kv_nope.strides()[0];
  int k_pe_stride = kv_pe.strides()[0];

  dim3 grid(num_tokens);
  dim3 block(std::min(kv_lora_rank, 512));

  ds_mla::concat_and_cache_mla_kernel<DataType_, DataType_>
      <<<grid, block, 0, stream>>>(
          reinterpret_cast<DataType_*>(
              const_cast<data_t*>(kv_nope.data<data_t>())),
          reinterpret_cast<DataType_*>(
              const_cast<data_t*>(kv_pe.data<data_t>())),
          reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
          slot_mapping.data<int64_t>(),
          block_stride,
          entry_stride,
          kv_c_stride,
          k_pe_stride,
          kv_lora_rank,
          pe_dim,
          block_size,
          scale);

  // Handle PD disaggregation signal (same gating as the FP8 prefill path).
  const char* fmt_write_cache_completed_signal_str =
      std::getenv("FLAGS_fmt_write_cache_completed_signal");
  const char* FLAGS_use_pd_disaggregation_per_chunk =
      std::getenv("FLAGS_use_pd_disaggregation_per_chunk");

  if (fmt_write_cache_completed_signal_str &&
      (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
       std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
    if (FLAGS_use_pd_disaggregation_per_chunk &&
        (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
         std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
      cudaLaunchHostFunc(
          stream,
          &(RemoteCacheKvIpc::
                save_cache_kv_complete_signal_layerwise_per_query),
          (void*)nullptr);
    } else {
      if (kv_signal_data) {
        cudaLaunchHostFunc(
            stream,
            &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
            (void*)(const_cast<int64_t*>(
                kv_signal_data.get().data<int64_t>())));
      }
    }
  }
  return {};
}
|
||||
|
||||
//==============================================================================
|
||||
// Indexer K Quantization and Cache Operations
|
||||
//==============================================================================
|
||||
|
||||
/**
 * Quantize the indexer K tensor to FP8 and scatter it into the cache.
 *
 * Launch shape: grid (num_tokens, ceil(head_dim / (quant_block_size * 4)))
 * with (32, 4) threads per block; each thread handles 4 contiguous
 * elements of head_dim.
 *
 * @param use_ue8m0 Forwarded to the kernel to select power-of-two scales.
 * @return Empty vector; kv_cache is updated in place.
 */
template <paddle::DataType T>
std::vector<paddle::Tensor> IndexerKQuantAndCache(
    const paddle::Tensor& k,
    const paddle::Tensor& slot_mapping,
    const int head_dim,
    const int quant_block_size,
    const int cache_block_size,
    const int cache_stride,
    const bool use_ue8m0,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int token_count = k.dims()[0];

  // One grid column per quant_block_size * vec_size elements of head_dim.
  constexpr int vec_size = 4;
  const int span = quant_block_size * vec_size;
  const int dim_blocks = (head_dim + span - 1) / span;

  dim3 launch_grid(token_count, dim_blocks);
  dim3 launch_block(32, vec_size);

  ds_mla::indexer_k_quant_and_cache_kernel<DataType_>
      <<<launch_grid, launch_block, 0, stream>>>(
          reinterpret_cast<DataType_*>(const_cast<data_t*>(k.data<data_t>())),
          reinterpret_cast<uint8_t*>(kv_cache->data<uint8_t>()),
          slot_mapping.data<int64_t>(),
          head_dim,
          quant_block_size,
          cache_block_size,
          cache_stride,
          use_ue8m0);

  return {};
}
|
||||
|
||||
/**
 * Gather K from quantized cache.
 *
 * Host launcher for ds_mla::cp_gather_indexer_k_quant_cache_kernel: copies
 * quantized K bytes and their scales out of the paged kv_cache into the
 * contiguous dst_k / dst_scale tensors, walking block_table with
 * cu_seq_lens.
 *
 * The BLOCK_Y_SIZE template parameter (tokens handled per block along y)
 * is chosen from num_tokens below so that small batches still launch
 * several blocks.
 *
 * NOTE(review): dst_scale's float data is passed to the kernel as char*;
 * presumably the kernel reinterprets the bytes internally — confirm
 * against the kernel signature.
 */
void CpGatherIndexerKQuantCache(const paddle::Tensor& kv_cache,
                                paddle::Tensor& dst_k,
                                paddle::Tensor& dst_scale,
                                const paddle::Tensor& block_table,
                                const paddle::Tensor& cu_seq_lens,
                                cudaStream_t& stream) {
  int batch_size = block_table.dims()[0];
  int num_tokens = dst_k.dims()[0];
  int head_dim = dst_k.dims()[1];
  // dst_scale holds 4-byte scales; derive elements-per-scale from its width.
  int quant_block_size = head_dim * 4 / dst_scale.dims()[1];

  constexpr int vec_size = 16;

// Grid: x = token groups of BLOCK_Y_SIZE, y = head_dim chunks of
// 8 * vec_size; block: 8 x BLOCK_Y_SIZE threads.
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE)                  \
  ds_mla::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE>              \
      <<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE,               \
              (head_dim + 8 * vec_size - 1) / (8 * vec_size)),              \
         dim3(8, BLOCK_Y_SIZE),                                             \
         0,                                                                 \
         stream>>>(reinterpret_cast<const char*>(kv_cache.data<uint8_t>()), \
                   reinterpret_cast<char*>(dst_k.data<uint8_t>()),          \
                   reinterpret_cast<char*>(dst_scale.data<float>()),        \
                   block_table.data<int>(),                                 \
                   cu_seq_lens.data<int>(),                                 \
                   batch_size,                                              \
                   dst_k.strides()[0],                                      \
                   dst_k.dims()[1],                                         \
                   kv_cache.strides()[0],                                   \
                   kv_cache.strides()[1],                                   \
                   kv_cache.dims()[1],                                      \
                   block_table.dims()[1],                                   \
                   num_tokens,                                              \
                   quant_block_size);

  // Scale BLOCK_Y_SIZE with the token count.
  if (num_tokens < 32) {
    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
  } else if (num_tokens < 64) {
    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2);
  } else if (num_tokens < 128) {
    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4);
  } else if (num_tokens < 256) {
    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8);
  } else if (num_tokens < 512) {
    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16);
  } else {
    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
  }

#undef CALL_CP_GATHER_INDEXER_K_QUANT_CACHE
}
|
||||
|
||||
//==============================================================================
|
||||
// Kernel Entry Points
|
||||
//==============================================================================
|
||||
|
||||
/**
 * DS MLA WriteCache entry point - supports both FP8 and standard formats.
 *
 * Builds AppendAttnMetaData from tensor shapes, then dispatches on
 * (cache_quant_type_str, kv_pe.dtype(), is_prefill):
 *   - "fp8_ds_mla": 656-byte quantized entries via
 *     Prefill/DecodeDSMLAWriteCacheFP8.
 *   - otherwise (auto/bf16/fp16): PrefillDSMLAWriteCache for both prefill
 *     and decode tokens (the kernel is shape-driven, one block per token).
 *
 * Shapes as read below:
 *   kv_nope:  [num_tokens, kv_num_heads * nope_size]
 *   kv_cache: [num_blocks, kv_num_heads, block_size, head_dims]
 *
 * @return Empty vector; kv_cache is updated in place.
 */
std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& slot_mapping,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const paddle::optional<paddle::Tensor>& scale,
    const std::string& cache_quant_type_str,
    const int max_seq_len,
    const bool is_prefill) {
  cudaStream_t stream = kv_pe.stream();
  AppendAttnMetaData meta_data;

  const auto& kv_nope_dims = kv_nope.dims();
  const auto& kv_cache_dims = kv_cache.dims();

  meta_data.kv_num_heads = kv_cache_dims[1];
  // Per-head latent width inferred from the flattened kv_nope last dim.
  const auto nope_size =
      kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
  meta_data.token_nums = kv_nope_dims[0];
  meta_data.head_dims = kv_cache_dims[3];
  meta_data.head_dims_v = nope_size;
  meta_data.max_blocks_per_seq = block_tables.dims()[1];
  meta_data.block_size = kv_cache_dims[2];
  meta_data.batch_size = seq_lens_decoder.dims()[0];

  const float* scale_ptr = scale ? scale.get().data<float>() : nullptr;

  if (cache_quant_type_str == "fp8_ds_mla") {
    // FP8 DS MLA format (656-byte entries).
    // NOTE(review): the decode call passes seq_lens_decoder into the
    // parameter named seq_lens_encoder in DecodeDSMLAWriteCacheFP8; the
    // argument is unused there, but confirm the intended naming.
    switch (kv_pe.dtype()) {
      case paddle::DataType::BFLOAT16: {
        if (is_prefill) {
          return PrefillDSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
              meta_data,
              kv_nope,
              kv_pe,
              slot_mapping,
              seq_lens,
              seq_lens_decoder,
              batch_id_per_token,
              cu_seqlens_q,
              block_tables,
              kv_signal_data,
              max_seq_len,
              stream,
              const_cast<paddle::Tensor*>(&kv_cache));
        } else {
          return DecodeDSMLAWriteCacheFP8<paddle::DataType::BFLOAT16>(
              meta_data,
              kv_nope,
              kv_pe,
              slot_mapping,
              seq_lens,
              seq_lens_decoder,
              batch_id_per_token,
              cu_seqlens_q,
              block_tables,
              max_seq_len,
              false,
              stream,
              const_cast<paddle::Tensor*>(&kv_cache));
        }
      }
      case paddle::DataType::FLOAT16: {
        if (is_prefill) {
          return PrefillDSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
              meta_data,
              kv_nope,
              kv_pe,
              slot_mapping,
              seq_lens,
              seq_lens_decoder,
              batch_id_per_token,
              cu_seqlens_q,
              block_tables,
              kv_signal_data,
              max_seq_len,
              stream,
              const_cast<paddle::Tensor*>(&kv_cache));
        } else {
          return DecodeDSMLAWriteCacheFP8<paddle::DataType::FLOAT16>(
              meta_data,
              kv_nope,
              kv_pe,
              slot_mapping,
              seq_lens,
              seq_lens_decoder,
              batch_id_per_token,
              cu_seqlens_q,
              block_tables,
              max_seq_len,
              false,
              stream,
              const_cast<paddle::Tensor*>(&kv_cache));
        }
      }
      default:
        PD_THROW("Unsupported dtype for DS MLA FP8 cache");
    }
  } else {
    // Standard MLA format (auto/bf16/fp16), unquantized entries.
    switch (kv_pe.dtype()) {
      case paddle::DataType::BFLOAT16: {
        return PrefillDSMLAWriteCache<paddle::DataType::BFLOAT16>(
            meta_data,
            kv_nope,
            kv_pe,
            slot_mapping,
            seq_lens,
            seq_lens_decoder,
            batch_id_per_token,
            cu_seqlens_q,
            block_tables,
            kv_signal_data,
            scale_ptr,
            max_seq_len,
            stream,
            const_cast<paddle::Tensor*>(&kv_cache));
      }
      case paddle::DataType::FLOAT16: {
        return PrefillDSMLAWriteCache<paddle::DataType::FLOAT16>(
            meta_data,
            kv_nope,
            kv_pe,
            slot_mapping,
            seq_lens,
            seq_lens_decoder,
            batch_id_per_token,
            cu_seqlens_q,
            block_tables,
            kv_signal_data,
            scale_ptr,
            max_seq_len,
            stream,
            const_cast<paddle::Tensor*>(&kv_cache));
      }
      default:
        PD_THROW("Unsupported dtype for DS MLA cache");
    }
  }
  return {};
}
|
||||
|
||||
/**
 * Indexer K Quant and Cache entry point.
 *
 * Dispatches on k.dtype() (bf16/fp16) to the typed IndexerKQuantAndCache
 * launcher. kv_cache is read as [num_blocks, cache_block_size,
 * cache_stride] and updated in place.
 *
 * @param quant_block_size Elements per quantization scale.
 * @param scale_fmt        "ue8m0" selects power-of-two scales.
 * @return Empty vector; kv_cache is updated in place.
 */
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
    const paddle::Tensor& k,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& slot_mapping,
    const int64_t quant_block_size,
    const std::string& scale_fmt) {
  cudaStream_t stream = k.stream();
  int head_dim = k.dims()[1];
  int cache_block_size = kv_cache.dims()[1];
  int cache_stride = kv_cache.dims()[2];
  bool use_ue8m0 = scale_fmt == "ue8m0";

  switch (k.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return IndexerKQuantAndCache<paddle::DataType::BFLOAT16>(
          k,
          slot_mapping,
          head_dim,
          quant_block_size,
          cache_block_size,
          cache_stride,
          use_ue8m0,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
    }
    case paddle::DataType::FLOAT16: {
      return IndexerKQuantAndCache<paddle::DataType::FLOAT16>(
          k,
          slot_mapping,
          head_dim,
          quant_block_size,
          cache_block_size,
          cache_stride,
          use_ue8m0,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
    }
    default:
      PD_THROW("Unsupported dtype for Indexer K Quant");
  }
  return {};
}
|
||||
|
||||
/**
 * Gather Indexer K from Quant Cache entry point.
 *
 * Thin host wrapper: picks the stream off kv_cache and forwards all
 * tensors to CpGatherIndexerKQuantCache, which chooses the launch shape.
 * dst_k / dst_scale are written in place; the returned vector is empty.
 */
std::vector<paddle::Tensor> CpGatherIndexerKQuantCacheKernel(
    const paddle::Tensor& kv_cache,
    paddle::Tensor& dst_k,
    paddle::Tensor& dst_scale,
    const paddle::Tensor& block_table,
    const paddle::Tensor& cu_seq_lens) {
  cudaStream_t gather_stream = kv_cache.stream();
  CpGatherIndexerKQuantCache(kv_cache,
                             dst_k,
                             dst_scale,
                             block_table,
                             cu_seq_lens,
                             gather_stream);
  return {};
}
|
||||
|
||||
//==============================================================================
|
||||
// Paddle Custom Operator Registration
|
||||
//==============================================================================
|
||||
|
||||
// DS MLA writecache op: writes kv_nope/kv_pe into the paged kv_cache
// (in place; kv_cache is aliased to kv_cache_out).
PD_BUILD_STATIC_OP(ds_mla_write_cache)
    .Inputs({"kv_nope",
             "kv_pe",
             "kv_cache",
             "slot_mapping",
             "seq_lens",
             "seq_lens_decoder",
             "batch_id_per_token",
             "cu_seqlens_q",
             "block_tables",
             paddle::Optional("kv_signal_data"),
             paddle::Optional("scale")})
    .Outputs({"kv_cache_out"})
    .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
    .Attrs({"cache_quant_type_str: std::string",
            "max_seq_len: int",
            "is_prefill: bool"})
    .SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel));

// FP8-quantize indexer K and scatter it into its cache (in place).
PD_BUILD_STATIC_OP(indexer_k_quant_and_cache)
    .Inputs({"k", "kv_cache", "slot_mapping"})
    .Outputs({"kv_cache_out"})
    .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
    .Attrs({"quant_block_size: int64_t", "scale_fmt: std::string"})
    .SetKernelFn(PD_KERNEL(IndexerKQuantAndCacheKernel));

// Gather quantized indexer K and scales from the paged cache
// (in place on dst_k / dst_scale).
PD_BUILD_STATIC_OP(cp_gather_indexer_k_quant_cache)
    .Inputs({"kv_cache", "dst_k", "dst_scale", "block_table", "cu_seq_lens"})
    .Outputs({"dst_k_out", "dst_scale_out"})
    .SetInplaceMap({{"dst_k", "dst_k_out"}, {"dst_scale", "dst_scale_out"}})
    .SetKernelFn(PD_KERNEL(CpGatherIndexerKQuantCacheKernel));
|
||||
@@ -0,0 +1,548 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cfloat>
|
||||
#include "helper.h"
|
||||
#include "mem_util.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
// FP8 E4M3 max-magnitude divisor: 448 for OCP E4M3; 224 for the FNUZ E4M3 variant on AMD gfx942 (MI300)
|
||||
#if defined(__gfx942__)
|
||||
constexpr float kFp8ScaleDivisorDS = 224.f;
|
||||
#else
|
||||
constexpr float kFp8ScaleDivisorDS = 448.f;
|
||||
#endif
|
||||
|
||||
namespace ds_mla {
|
||||
|
||||
/**
 * FP8 scaled conversion utilities.
 *
 * Generic fallback: divide by `scale` in float, then cast to OutT.
 * Specializations below handle bf16/half -> fp8_e4m3 byte conversion.
 */
template <typename OutT, typename InT>
__device__ __forceinline__ OutT fp8_scaled_convert(InT src, float scale) {
  const float descaled = static_cast<float>(src) / scale;
  return static_cast<OutT>(descaled);
}
|
||||
|
||||
// bf16 -> fp8_e4m3 byte conversion (requires SM89+ for __nv_fp8_e4m3).
// Divides by `scale`, clamps to the E4M3 range [-448, 448], then returns
// the raw fp8 byte. Below SM89 this compiles to `return 0` — presumably
// the FP8 cache path is never taken there; confirm with callers.
// NOTE(review): the clamp uses the OCP bound 448 even though
// kFp8ScaleDivisorDS is 224 on gfx942 (FNUZ) — confirm intended for ROCm.
template <>
__device__ __forceinline__ uint8_t
fp8_scaled_convert<uint8_t, __nv_bfloat16>(__nv_bfloat16 src, float scale) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  float val = __bfloat162float(src) / scale;
  val = fminf(fmaxf(val, -448.0f), 448.0f);
  __nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
  // Return the raw encoded byte.
  return *reinterpret_cast<uint8_t*>(&fp8_val);
#else
  return 0;
#endif
}
|
||||
|
||||
// half -> fp8_e4m3 byte conversion (requires SM89+ for __nv_fp8_e4m3).
// Same contract as the bf16 specialization above: divide by `scale`,
// clamp to [-448, 448], return the raw fp8 byte; `return 0` below SM89.
template <>
__device__ __forceinline__ uint8_t
fp8_scaled_convert<uint8_t, half>(half src, float scale) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  float val = __half2float(src) / scale;
  val = fminf(fmaxf(val, -448.0f), 448.0f);
  __nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
  // Return the raw encoded byte.
  return *reinterpret_cast<uint8_t*>(&fp8_val);
#else
  return 0;
#endif
}
|
||||
|
||||
/**
 * DeepSeek MLA FP8 Cache Write Kernel
 *
 * Cache format (fp8_ds_mla - 656 bytes per token):
 *   - First 512 bytes: quantized NoPE part (512 x fp8_e4m3)
 *   - Next 16 bytes: scale factors (4 x float32, one per 128 fp8 values)
 *   - Last 128 bytes: RoPE part (64 x 16-bit values, not quantized)
 *
 * Launch contract: grid.x = num_tokens, blockDim.x = 96.
 *   - Threads 0-63 (two full warps) quantize the NoPE part; each thread
 *     loads 8 elements, each 16-lane half-warp forms one 128-element tile
 *     sharing one scale.
 *   - Threads 64-95 (one warp) copy the RoPE part, 2 elements per thread.
 *
 * Requires 16-byte-aligned kv_c rows (int4 loads) and 4-byte-aligned k_pe
 * rows (int32 loads); kv_lora_rank is expected to be 512 and pe_dim 64.
 */
template <typename scalar_t>
__global__ void concat_and_cache_ds_mla_kernel(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    uint8_t* __restrict__ kv_cache,     // [num_blocks, block_size,
                                        //  cache_entry_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,  // stride per block in cache
    const int entry_stride,  // stride per token entry in cache
    const int kv_c_stride,   // stride for kv_c input
    const int k_pe_stride,   // stride for k_pe input
    const int kv_lora_rank,  // 512 for DS MLA
    const int pe_dim,        // 64 for DS MLA
    const int block_size     // number of tokens per cache block
) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];

  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int64_t dst_idx_start =
      block_idx * block_stride + block_offset * entry_stride;

  // Cast kv_cache to 16-bit for RoPE values
  scalar_t* kv_cache_16bit =
      reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);

  // The last warp handles the RoPE part
  if (threadIdx.x >= 64) {
    // Each thread handles two elements of RoPE
    const int8_t pe_idx_start = (threadIdx.x - 64) * 2;
    const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start;

    // Vectorized load of two 16-bit values, performed as one 32-bit load
    const int32_t vals = *reinterpret_cast<const int32_t*>(&k_pe[src_idx]);

    // RoPE values start after the packed 8-bit NoPE values (kv_lora_rank/2
    // 16-bit units = 512 bytes) and the 32-bit scales (8 16-bit units =
    // 16 bytes).
    const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start;

    // Vectorized store of two 16-bit values
    *reinterpret_cast<int32_t*>(&kv_cache_16bit[dst_idx]) = vals;
    return;
  }

  // The first two warps handle the NoPE part
  const int8_t warp_idx = threadIdx.x >> 5;
  const int8_t lane_idx = threadIdx.x & 31;
  // Four 128-element tiles per token; one tile per 16-lane half-warp.
  const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);

  // Each thread handles 8 elements of NoPE
  const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8);

  // Vectorized load of eight 16-bit values
  const int4 vals_i4 = *reinterpret_cast<const int4*>(&kv_c[src_idx_start]);
  const scalar_t* vals = reinterpret_cast<const scalar_t*>(&vals_i4);

  // Max absolute value of this thread's elements
  float max_abs = fmaxf(fmaxf(fmaxf(fabsf(static_cast<float>(vals[0])),
                                    fabsf(static_cast<float>(vals[1]))),
                              fmaxf(fabsf(static_cast<float>(vals[2])),
                                    fabsf(static_cast<float>(vals[3])))),
                        fmaxf(fmaxf(fabsf(static_cast<float>(vals[4])),
                                    fabsf(static_cast<float>(vals[5]))),
                              fmaxf(fabsf(static_cast<float>(vals[6])),
                                    fabsf(static_cast<float>(vals[7])))));

// Butterfly reduction to find the max absolute value in each half-warp.
// FIX: the participation mask must name every executing lane of the warp
// (all 32 lanes of both NoPE warps run this); the previous 0xFFFF mask
// excluded lanes 16-31, which is undefined behavior for __shfl_xor_sync.
// The width argument of 16 is what confines the butterfly to each
// 16-lane half-warp.
#pragma unroll
  for (int offset = 8; offset > 0; offset /= 2) {
    max_abs = fmaxf(max_abs, __shfl_xor_sync(0xFFFFFFFF, max_abs, offset, 16));
  }

  // Compute the scale for the tile (FLT_MIN guards divide-by-zero later)
  float tile_scale = fmaxf(max_abs / kFp8ScaleDivisorDS, FLT_MIN);

  // The first lane of each half-warp writes the tile's scale to kv_cache
  if ((lane_idx == 0) || (lane_idx == 16)) {
    float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
    // Scales live right after the 512 fp8 bytes = kv_lora_rank/4 floats in.
    const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
    kv_cache_32bit[dst_idx] = tile_scale;
  }

  // Now all NoPE threads scale and write their 8 elements
  const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8);

  uint8_t result[8];
#pragma unroll
  for (int i = 0; i < 8; i++) {
    result[i] = fp8_scaled_convert<uint8_t, scalar_t>(vals[i], tile_scale);
  }

  // Store as one aligned 64-bit write
  *reinterpret_cast<uint64_t*>(&kv_cache[dst_idx_base]) =
      *reinterpret_cast<const uint64_t*>(result);
}
|
||||
|
||||
/**
 * Standard MLA Cache Write Kernel (non-FP8)
 *
 * For auto/bf16/fp16 cache types: copies kv_lora_rank NoPE elements
 * followed by pe_dim RoPE elements into one cache entry per token.
 * Launch contract: grid.x = num_tokens; blockDim.x threads stride over
 * the entry. `scale` is accepted for signature parity but not read by
 * this unquantized variant.
 */
template <typename scalar_t, typename cache_t>
__global__ void concat_and_cache_mla_kernel(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size,
                                     //  (kv_lora_rank + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,
    const int entry_stride,
    const int kv_c_stride,
    const int k_pe_stride,
    const int kv_lora_rank,
    const int pe_dim,
    const int block_size,
    const float* scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot = slot_mapping[token_idx];

  // Padded tokens carry slot -1 and are skipped.
  if (slot < 0) {
    return;
  }

  const int64_t dst_base =
      (slot / block_size) * block_stride + (slot % block_size) * entry_stride;
  const int64_t nope_src_base = token_idx * kv_c_stride;
  const int64_t pe_src_base = token_idx * k_pe_stride;
  const int total_dim = kv_lora_rank + pe_dim;

  // Single strided sweep over the whole entry: positions below
  // kv_lora_rank come from kv_c (NoPE), the rest from k_pe (RoPE).
  for (int i = threadIdx.x; i < total_dim; i += blockDim.x) {
    const scalar_t src_val = (i < kv_lora_rank)
                                 ? kv_c[nope_src_base + i]
                                 : k_pe[pe_src_base + (i - kv_lora_rank)];
    kv_cache[dst_base + i] = static_cast<cache_t>(src_val);
  }
}
|
||||
|
||||
/**
 * Indexer K Quantization and Cache Kernel
 *
 * Quantizes K values to FP8 and stores them in cache with scale factors.
 * Cache layout per block: [quantized_k (head_dim bytes per token)] followed
 * by [scales (head_dim / quant_block_size * 4 bytes per token)].
 *
 * Launch contract (from the host launcher): grid (num_tokens,
 * ceil(head_dim / (quant_block_size * VEC_SIZE))), block (32, VEC_SIZE);
 * each thread loads VEC_SIZE = 4 elements.
 * NOTE(review): the full-warp shuffle reduction below reduces over
 * 32 lanes x VEC_SIZE = 128 elements, i.e. it assumes
 * quant_block_size == 128 — confirm for other configurations.
 */
template <typename scalar_t>
__global__ void indexer_k_quant_and_cache_kernel(
    const scalar_t* __restrict__ k,  // [num_tokens, head_dim]
    uint8_t* __restrict__ kv_cache,  // [num_blocks, block_size, cache_stride]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int head_dim,
    const int quant_block_size,  // typically 128
    const int cache_block_size,
    const int cache_stride,
    const bool use_ue8m0  // use ue8m0 scale format
) {
  constexpr int VEC_SIZE = 4;
  const int64_t token_idx = blockIdx.x;
  // Flat element offset covered by this thread (VEC_SIZE elems per thread).
  const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
                                threadIdx.y * blockDim.x + threadIdx.x) *
                               VEC_SIZE;
  const int64_t slot_idx = slot_mapping[token_idx];
  // Computed before the guard below; harmless for negative slots since no
  // memory is touched until after the early return.
  const int64_t block_idx = slot_idx / cache_block_size;
  const int64_t block_offset = slot_idx % cache_block_size;

  // slot_idx == -1 marks padded tokens.
  if (slot_idx < 0 || head_dim_idx >= head_dim) {
    return;
  }

  // Load 4 values at once using float2 (for bf16/fp16)
  float2 k_val = reinterpret_cast<const float2*>(
      k)[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
  scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);

  // Per-thread max absolute value of the 4 loaded elements.
  float amax = 0.0f;
  for (int i = 0; i < VEC_SIZE; i++) {
    amax = fmaxf(amax, fabsf(static_cast<float>(k_val_ptr[i])));
  }

  // Warp reduction to find max across quant_block_size elements
  for (int mask = 16; mask > 0; mask /= 2) {
    amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask));
  }

  // 1e-4 floor avoids a zero scale; divide by the fp8 max magnitude.
  float scale = fmaxf(amax, 1e-4f) / kFp8ScaleDivisorDS;

  if (use_ue8m0) {
    // Round the scale up to a power of two (ue8m0 format).
    scale = exp2f(ceilf(log2f(scale)));
  }

  const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
                             block_offset * head_dim + head_dim_idx;

  for (int i = 0; i < VEC_SIZE; i++) {
    kv_cache[dst_offset + i] =
        fp8_scaled_convert<uint8_t, scalar_t>(k_val_ptr[i], scale);
  }

  // First thread in warp writes the scale (scales live after the
  // cache_block_size * head_dim quantized bytes of the block).
  if (threadIdx.x == 0) {
    const int64_t dst_scale_idx =
        block_idx * cache_block_size * cache_stride +
        cache_block_size * head_dim +
        (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
    reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
  }
}
|
||||
|
||||
/**
 * Gather Indexer K from Quantized Cache Kernel
 *
 * Gathers quantized K bytes and their per-quant-block scales from the paged
 * cache into contiguous per-token output buffers. Note: bytes and scales are
 * copied as-is; no dequantization happens here.
 *
 * Thread mapping: token_idx comes from (blockIdx.x, threadIdx.y); head_idx
 * covers head_dim in float4-wide (16-byte) chunks via (blockIdx.y,
 * threadIdx.x). The per-row `__syncwarp()` below implies each threadIdx.y
 * row is one warp — i.e. blockDim.x == 32; TODO confirm at the launch site.
 */
template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel(
    const char* __restrict__ kv_cache,  // [num_blocks, block_size,
                                        // cache_stride]
    char* __restrict__ dst_k,  // [num_tokens, head_dim]
    char* __restrict__ dst_scale,  // [num_tokens, head_dim/quant_block_size*4]
    const int* __restrict__ block_table,  // [batch_size, num_blocks]
    const int* __restrict__ cu_seq_lens,  // [batch_size + 1]
    const int batch_size,
    const int64_t token_stride,
    const int64_t head_dim,
    const int64_t block_stride,
    const int64_t cache_token_stride,
    const int64_t cache_block_size,
    const int num_blocks,
    const int num_tokens,
    const int quant_block_size) {
  constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
  const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
  const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;

  // Find batch index within a block: each lane scans a slice of
  // cu_seq_lens; the lane whose [seq_start, seq_end) contains this row's
  // token_idx publishes the batch id to shared memory. Exactly one lane can
  // match because cu_seq_lens ranges are disjoint. If token_idx >=
  // num_tokens, batch_idx[threadIdx.y] stays unwritten, but such rows
  // return at the bounds check before reading it.
  __shared__ int batch_idx[BLOCK_Y_SIZE];
  for (int iter = 0; iter < (batch_size + blockDim.x - 1) / blockDim.x;
       iter++) {
    int tid = iter * blockDim.x + threadIdx.x;
    if (tid < batch_size) {
      const int seq_start = cu_seq_lens[tid];
      const int seq_end = cu_seq_lens[tid + 1];
      if (token_idx >= seq_start && token_idx < seq_end) {
        batch_idx[threadIdx.y] = tid;
      }
    }
  }

  // Warp-scope barrier: makes the batch_idx write visible to the other
  // lanes of this row before they read it below.
  __syncwarp();

  if (head_idx >= head_dim || token_idx >= num_tokens) {
    return;
  }

  // Translate the token's position within its sequence into a physical
  // cache block via the block table.
  const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
  const int block_id = block_table[batch_idx[threadIdx.y] * num_blocks +
                                   inbatch_seq_idx / cache_block_size];
  const int64_t src_block_offset = block_id * block_stride;
  const int64_t cache_inblock_offset =
      (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
  const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
  const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;

  // 16-byte vectorized copy of the quantized K bytes. Offsets are in bytes;
  // assumes both buffers are 16-byte aligned at these offsets — TODO confirm.
  reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
      reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];

  // Lane 0 copies the one float scale covering this quant block; scales live
  // after the cache_block_size * head_dim quantized bytes of the block.
  if (threadIdx.x == 0) {
    const int64_t src_scale_offset =
        src_block_offset + cache_block_size * head_dim +
        cache_inblock_offset * 4 / quant_block_size;
    reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
        reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
  }
}
|
||||
|
||||
/**
 * Prefill DS MLA Write Cache Kernel
 *
 * Writes prefill KV data into the DS MLA cache layout (entry_size bytes per
 * token, e.g. 656 = 512-byte fp8 NoPE + 16-byte scales + 128-byte bf16 RoPE).
 * The NoPE part is clamped to the fp8_e4m3 range and stored as fp8; the RoPE
 * part is stored unquantized as T.
 *
 * Grid-stride loop over elem_cnt = num_tokens * (nope + pe hidden size).
 *
 * NOTE(review): the 16-byte scale region between NoPE and RoPE is never
 * written by this kernel — presumably filled elsewhere or implied to be 1.0;
 * confirm against the reader.
 * NOTE(review): the fp8 store is compiled only for __CUDA_ARCH__ >= 890;
 * on older architectures the NoPE bytes are silently left unwritten —
 * confirm this kernel is only dispatched on SM89+.
 */
template <typename T, int VecSize = 1>
__global__ void prefill_ds_mla_cache_kernel(
    const T* __restrict__ kv_nope,  // [num_tokens, kv_num_heads * nope_size]
    const T* __restrict__ kv_pe,  // [num_tokens, kv_num_heads * pe_size]
    uint8_t* __restrict__ kv_cache,  // [num_blocks, kv_num_heads, block_size,
                                     // entry_size]
    const int* __restrict__ block_tables,
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,
    const int* __restrict__ seq_lens_decoder,
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,  // 512 for DS MLA
    const int pe_size,  // 64 for DS MLA
    const int block_size,
    const int entry_size,  // 656 for DS MLA FP8
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    // Decompose the flat element index into (token, position-in-hidden).
    const uint32_t token_idx = linear_index / hidden_size;
    const uint32_t bias = linear_index % hidden_size;
    const uint32_t ori_bi = batch_id_per_token[token_idx];

    // Skip batches with no active sequence.
    if (seq_lens[ori_bi] == 0) continue;

    // Absolute position of this token in its sequence (query offset plus
    // already-decoded prefix length).
    const uint32_t ori_seq_id =
        (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
    const uint32_t block_offset = ori_seq_id % block_size;

    if (bias < nope_hidden_size) {
      // NoPE element: byte offset h_bias within the token's entry.
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;

      // For DS MLA FP8, NoPE part goes to first 512 bytes
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * entry_size +
          hi * block_size * entry_size + block_offset * entry_size + h_bias;
      const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias;

      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);

      // Convert to FP8 and store
      for (int i = 0; i < VecSize; i++) {
        float val = static_cast<float>(src_vec.val[i]);
        // Clamp into the representable fp8_e4m3 range (+/-448).
        val = fminf(fmaxf(val, -448.0f), 448.0f);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
        __nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
        kv_cache[tgt_idx + i] = *reinterpret_cast<uint8_t*>(&fp8_val);
#endif
      }
    } else {
      // RoPE part goes after NoPE (nope_size bytes) + scales (16 bytes).
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;

      // h_bias * 2 converts an element index to a byte offset; assumes
      // sizeof(T) == 2 (bf16/fp16) — TODO confirm instantiations.
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * entry_size +
          hi * block_size * entry_size + block_offset * entry_size + nope_size +
          16 + h_bias * 2;  // *2 for bf16
      const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias;

      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);

      // Copy RoPE without quantization (as bf16/fp16)
      T* tgt_ptr = reinterpret_cast<T*>(&kv_cache[tgt_idx]);
      for (int i = 0; i < VecSize; i++) {
        tgt_ptr[i] = src_vec.val[i];
      }
    }
  }
}
|
||||
|
||||
/**
 * Decode DS MLA Write Cache Kernel
 *
 * Writes one decode-step token per batch into the DS MLA cache layout:
 * the NoPE part is clamped to the fp8_e4m3 range and stored as fp8, the
 * RoPE part is stored unquantized as T, after the nope_size-byte NoPE
 * region plus the 16-byte scale region.
 *
 * Grid-stride loop over elem_cnt = batch_size * (nope + pe hidden size).
 */
template <typename T, int VecSize = 1>
__global__ void decode_ds_mla_cache_kernel(
    const T* __restrict__ kv_nope,
    const T* __restrict__ kv_pe,
    uint8_t* __restrict__ kv_cache,
    const int* __restrict__ block_tables,
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,
    const int* __restrict__ seq_lens_encoder,
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int kv_num_heads,
    const int nope_size,
    const int pe_size,
    const int block_size,
    const int entry_size,
    const uint32_t elem_cnt) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const uint32_t nope_hidden_size = kv_num_heads * nope_size;
  const uint32_t pe_hidden_size = kv_num_heads * pe_size;
  const int64_t hidden_size = nope_hidden_size + pe_hidden_size;

  for (int32_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int ori_bi = linear_index / hidden_size;
    const int bias = linear_index % hidden_size;
    const int start_token_idx = cu_seqlens_q[ori_bi];

    // Bug fix: `continue`, not `return`. This is a grid-stride loop, so one
    // thread services elements of several batches; `return` would also drop
    // every later decode batch assigned to this thread once a prefill batch
    // (seq_lens_encoder > 0) is hit. The prefill kernel uses the same
    // skip-with-continue pattern.
    if (seq_lens_encoder[ori_bi] > 0) continue;

    // seq_lens[ori_bi] is the index of the slot being written this step.
    const int write_seq_id = seq_lens[ori_bi];
    if (write_seq_id == 0) continue;

    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
    const int block_idx = block_table_now[write_seq_id / block_size];
    const int block_offset = write_seq_id % block_size;

    if (bias < nope_hidden_size) {
      // NoPE element: byte offset h_bias within the token's entry.
      const uint32_t inner_bias = bias;
      const uint32_t hi = inner_bias / nope_size;
      const uint32_t h_bias = inner_bias % nope_size;

      // NOTE(review): offsets are 32-bit, matching the prefill kernel; very
      // large caches could overflow — confirm total cache bytes < 2^32.
      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * entry_size +
          hi * block_size * entry_size + block_offset * entry_size + h_bias;
      const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias;

      Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);

      for (int i = 0; i < VecSize; i++) {
        float val = static_cast<float>(src_vec.val[i]);
        // Clamp into the representable fp8_e4m3 range (+/-448).
        val = fminf(fmaxf(val, -448.0f), 448.0f);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
        __nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val);
        kv_cache[tgt_idx + i] = *reinterpret_cast<uint8_t*>(&fp8_val);
#endif
      }
    } else {
      // RoPE element, stored unquantized; h_bias * 2 converts the element
      // index to a byte offset (assumes sizeof(T) == 2, bf16/fp16).
      const uint32_t inner_bias = bias - nope_hidden_size;
      const uint32_t hi = inner_bias / pe_size;
      const uint32_t h_bias = inner_bias % pe_size;

      const uint32_t tgt_idx =
          block_idx * kv_num_heads * block_size * entry_size +
          hi * block_size * entry_size + block_offset * entry_size + nope_size +
          16 + h_bias * 2;
      const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias;

      Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);

      T* tgt_ptr = reinterpret_cast<T*>(&kv_cache[tgt_idx]);
      for (int i = 0; i < VecSize; i++) {
        tgt_ptr[i] = src_vec.val[i];
      }
    }
  }
}
|
||||
|
||||
} // namespace ds_mla
|
||||
@@ -1133,6 +1133,47 @@ std::vector<paddle::Tensor> get_attn_mask_q(
|
||||
const paddle::optional<paddle::Tensor>& attn_mask_kv,
|
||||
const int kv_token_num);
|
||||
|
||||
// Radix-select top-k over ragged rows; dispatched by dtype (bf16/fp32) in
// indexer_topk.cu. Optional buffers let callers pass decoder sequence
// lengths, per-token batch ids, and a persistent row-state workspace.
void RadixTopkRaggedTransform(
    paddle::Tensor& input,
    paddle::Tensor& output_indices,
    const paddle::Tensor& offsets,
    paddle::Tensor& lengths,
    paddle::optional<paddle::Tensor>& seq_len_decoder,
    paddle::optional<paddle::Tensor>& batch_id_per_token,
    paddle::optional<paddle::Tensor>& maybe_row_states_buffer,
    int top_k,
    int q_num_heads = 0);

// Writes MLA KV (NoPE + RoPE parts) into the paged DS MLA cache; the
// prefill vs. decode kernel path is selected by `is_prefill`, and the cache
// byte layout by `cache_quant_type_str`.
std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& slot_mapping,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const paddle::optional<paddle::Tensor>& scale,
    const std::string& cache_quant_type_str,
    const int max_seq_len,
    const bool is_prefill);

// Quantizes indexer K to FP8 (per-quant-block scales, optionally ue8m0
// formatted) and writes it into the paged cache via slot_mapping.
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
    const paddle::Tensor& k,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& slot_mapping,
    const int64_t quant_block_size,
    const std::string& scale_fmt);

// Gathers quantized indexer K bytes and their scales from the paged cache
// into contiguous per-token buffers (no dequantization).
std::vector<paddle::Tensor> CpGatherIndexerKQuantCacheKernel(
    const paddle::Tensor& kv_cache,
    paddle::Tensor& dst_k,
    paddle::Tensor& dst_scale,
    const paddle::Tensor& block_table,
    const paddle::Tensor& cu_seq_lens);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("get_expert_token_num",
|
||||
&GetExpertTokenNum,
|
||||
@@ -1736,4 +1777,18 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("custom_numpy_to_tensor",
|
||||
&CustomNumpyToTensor,
|
||||
"custom_numpy_to_tensor function");
|
||||
|
||||
m.def("radix_topk_ragged_transform",
|
||||
&RadixTopkRaggedTransform,
|
||||
"radix_topk_ragged_transform function");
|
||||
|
||||
m.def("dsk_attn_write_cache", &DSMLAWriteCacheKernel, "dsk_attn_write_cache");
|
||||
|
||||
m.def("indexer_k_quant_and_cache",
|
||||
&IndexerKQuantAndCacheKernel,
|
||||
"indexer_k_quant_and_cache");
|
||||
|
||||
m.def("cp_gather_indexer_k_quant_cache",
|
||||
&CpGatherIndexerKQuantCacheKernel,
|
||||
"cp_gather_indexer_k_quant_cache");
|
||||
}
|
||||
|
||||
@@ -662,7 +662,8 @@ inline const char *getEnvVar(const char *varName) {
|
||||
|
||||
inline bool checkAttentionBackend() {
|
||||
const char *backend = getEnvVar("FD_ATTENTION_BACKEND");
|
||||
if (backend && std::strcmp(backend, "MLA_ATTN") == 0) {
|
||||
if (backend && (std::strcmp(backend, "MLA_ATTN") == 0 ||
|
||||
std::strcmp(backend, "DSA_ATTN") == 0)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
/*
|
||||
* Copyright (c) 2024 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef FLASHINFER_EXCEPTION_H_
|
||||
#define FLASHINFER_EXCEPTION_H_
|
||||
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#define FLASHINFER_ERROR(message) \
|
||||
throw flashinfer::Error(__FUNCTION__, __FILE__, __LINE__, message)
|
||||
|
||||
// Recursion terminator: nothing left to write.
inline void write_to_stream(std::ostringstream& oss) {}

// Writes the final value with no trailing separator.
template <typename T>
void write_to_stream(std::ostringstream& oss, T&& value) {
  oss << std::forward<T>(value);
}

// Writes `value` followed by a single space, then recurses on the rest, so
// values are space-separated with no trailing space.
template <typename T, typename... Rest>
void write_to_stream(std::ostringstream& oss, T&& value, Rest&&... rest) {
  oss << std::forward<T>(value) << " ";
  write_to_stream(oss, std::forward<Rest>(rest)...);
}
|
||||
|
||||
// Helper macro to handle empty __VA_ARGS__
|
||||
#define FLASHINFER_CHECK_IMPL(condition, message) \
|
||||
if (!(condition)) { \
|
||||
FLASHINFER_ERROR(message); \
|
||||
}
|
||||
|
||||
// Main macro that handles both cases
|
||||
#define FLASHINFER_CHECK(condition, ...) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
std::ostringstream oss; \
|
||||
write_to_stream(oss, ##__VA_ARGS__); \
|
||||
std::string msg = oss.str(); \
|
||||
if (msg.empty()) { \
|
||||
msg = "Check failed: " #condition; \
|
||||
} \
|
||||
FLASHINFER_ERROR(msg); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Warning macro
|
||||
#define FLASHINFER_WARN(...) \
|
||||
do { \
|
||||
std::ostringstream oss; \
|
||||
write_to_stream(oss, ##__VA_ARGS__); \
|
||||
std::string msg = oss.str(); \
|
||||
if (msg.empty()) { \
|
||||
msg = "Warning triggered"; \
|
||||
} \
|
||||
flashinfer::Warning(__FUNCTION__, __FILE__, __LINE__, msg).emit(); \
|
||||
} while (0)
|
||||
|
||||
namespace flashinfer {
|
||||
// Exception carrying a pre-formatted location-aware message; thrown by the
// FLASHINFER_ERROR / FLASHINFER_CHECK macros.
class Error : public std::exception {
 private:
  std::string message_;

 public:
  // Formats "Error in function '<func>' at <file>:<line>: <message>".
  Error(const std::string& func,
        const std::string& file,
        int line,
        const std::string& message)
      : message_("Error in function '" + func + "' " + "at " + file + ":" +
                 std::to_string(line) + ": " + message) {}

  // Returns the formatted message; valid for the lifetime of this object.
  virtual const char* what() const noexcept override {
    return message_.c_str();
  }
};
|
||||
|
||||
// Non-throwing counterpart of Error: formats a location-aware message and
// prints it to stderr on emit(); used by the FLASHINFER_WARN macro.
class Warning {
 private:
  std::string message_;

 public:
  // Formats "Warning in function '<func>' at <file>:<line>: <message>".
  Warning(const std::string& func,
          const std::string& file,
          int line,
          const std::string& message)
      : message_("Warning in function '" + func + "' " + "at " + file + ":" +
                 std::to_string(line) + ": " + message) {}

  // Writes the message to stderr, newline-terminated (std::endl flushes).
  void emit() const { std::cerr << message_ << std::endl; }
};
|
||||
|
||||
} // namespace flashinfer
|
||||
|
||||
#endif // FLASHINFER_EXCEPTION_H_
|
||||
@@ -0,0 +1,144 @@
|
||||
|
||||
#include "indexer_topk.cuh"
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
#include "paddle/phi/api/ext/op_meta_info.h"
|
||||
#include "paddle/utils/optional.h"
|
||||
|
||||
#include "append_attn/mem_util.cuh"
|
||||
#include "append_attn/mma_tensor_op.cuh"
|
||||
#include "append_attn/utils.cuh"
|
||||
#include "helper.h"
|
||||
|
||||
// using namespace flashinfer;
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
// Thin dtype-dispatch wrapper: maps the compile-time paddle::DataType tag T
// to its device type via PDTraits and forwards all arguments to
// flashinfer::sampling::TopKRaggedTransformDispatch.
// Returns the CUDA status from the dispatched kernel launch.
template <paddle::DataType T>
cudaError_t DispatchTopK(paddle::Tensor& input,
                         paddle::Tensor& output_indices,
                         const paddle::Tensor& offsets,
                         paddle::Tensor& lengths,
                         uint32_t num_rows,
                         const int32_t* seq_len_decoder,
                         const int32_t* batch_id_per_token,
                         uint32_t top_k,
                         uint32_t q_num_heads,
                         uint32_t max_len,
                         flashinfer::sampling::RadixRowState* row_states_ptr,
                         cudaStream_t stream) {
  // DataType_: device-side element type; data_t: paddle-facing accessor type.
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  cudaError_t status;
  status =
      flashinfer::sampling::TopKRaggedTransformDispatch<DataType_, int32_t>(
          reinterpret_cast<DataType_*>(input.data<data_t>()),
          static_cast<int32_t*>(output_indices.data<int32_t>()),
          static_cast<const int32_t*>(offsets.data<int32_t>()),
          static_cast<int32_t*>(lengths.data<int32_t>()),
          num_rows,
          seq_len_decoder,
          batch_id_per_token,
          static_cast<uint32_t>(top_k),
          static_cast<uint32_t>(q_num_heads),
          max_len,
          row_states_ptr,
          stream);
  return status;
}
|
||||
|
||||
// Host entry point for the radix_topk_ragged_transform custom op.
// input: (num_rows, max_len) bf16/fp32 scores; output_indices:
// (num_rows, top_k) int32; offsets/lengths: per-row int32 metadata.
// Optional tensors provide decoder sequence lengths, per-token batch ids
// and a reusable RadixRowState workspace.
void RadixTopkRaggedTransform(
    paddle::Tensor& input,
    paddle::Tensor& output_indices,
    const paddle::Tensor& offsets,
    paddle::Tensor& lengths,
    paddle::optional<paddle::Tensor>& seq_len_decoder,
    paddle::optional<paddle::Tensor>& batch_id_per_token,
    paddle::optional<paddle::Tensor>& maybe_row_states_buffer,
    int top_k,
    int q_num_heads = 0) {
  unsigned int num_rows = input.dims()[0];
  unsigned int max_len = input.dims()[1];

  // Bug fix: the stream must be re-read on every call. The original code
  // declared this `static`, which froze the stream captured on the FIRST
  // invocation and silently reused it for all later calls, even when the
  // input tensor lived on a different stream.
  cudaStream_t stream = input.stream();
  auto input_dtype = input.dtype();

  // Optional persistent row-state workspace (raw byte buffer reinterpreted
  // as RadixRowState records).
  flashinfer::sampling::RadixRowState* row_states_ptr = nullptr;
  if (maybe_row_states_buffer) {
    auto& tensor_ptr = maybe_row_states_buffer.get();
    row_states_ptr = reinterpret_cast<flashinfer::sampling::RadixRowState*>(
        tensor_ptr.data<uint8_t>());
  }

  const int32_t* seq_len_ptr = nullptr;
  if (seq_len_decoder) {
    auto& tensor_ptr = seq_len_decoder.get();
    seq_len_ptr = static_cast<const int32_t*>(tensor_ptr.data<int32_t>());
  }
  const int32_t* batch_id_per_token_ptr = nullptr;
  if (batch_id_per_token) {
    auto& tensor_ptr = batch_id_per_token.get();
    batch_id_per_token_ptr =
        static_cast<const int32_t*>(tensor_ptr.data<int32_t>());
  }

  // Initialize status so it is never read uninitialized, reject unsupported
  // dtypes explicitly, and surface kernel launch failures instead of
  // silently discarding them (the original never checked `status`).
  cudaError_t status = cudaSuccess;
  if (input_dtype == paddle::DataType::BFLOAT16) {
    status = DispatchTopK<paddle::DataType::BFLOAT16>(input,
                                                      output_indices,
                                                      offsets,
                                                      lengths,
                                                      num_rows,
                                                      seq_len_ptr,
                                                      batch_id_per_token_ptr,
                                                      top_k,
                                                      q_num_heads,
                                                      max_len,
                                                      row_states_ptr,
                                                      stream);
  } else if (input_dtype == paddle::DataType::FLOAT32) {
    status = DispatchTopK<paddle::DataType::FLOAT32>(input,
                                                     output_indices,
                                                     offsets,
                                                     lengths,
                                                     num_rows,
                                                     seq_len_ptr,
                                                     batch_id_per_token_ptr,
                                                     top_k,
                                                     q_num_heads,
                                                     max_len,
                                                     row_states_ptr,
                                                     stream);
  } else {
    PD_THROW(
        "radix_topk_ragged_transform: unsupported input dtype; "
        "expected bfloat16 or float32.");
  }
  if (status != cudaSuccess) {
    PD_THROW("radix_topk_ragged_transform failed: ",
             cudaGetErrorString(status));
  }
}
|
||||
|
||||
// Registers RadixTopkRaggedTransform as the custom op
// "radix_topk_ragged_transform". Inputs mirror the function's parameter
// order; the last three are optional, matching the paddle::optional
// parameters. top_k / q_num_heads are passed as int attributes.
PD_BUILD_STATIC_OP(radix_topk_ragged_transform)
    .Inputs({"input",
             "output_indices",
             "offsets",
             "lengths",
             paddle::Optional("seq_len_decoder"),
             paddle::Optional("batch_id_per_token"),
             paddle::Optional("maybe_row_states_buffer")})
    .Attrs({"top_k : int", "q_num_heads : int"})
    .SetKernelFn(PD_KERNEL(RadixTopkRaggedTransform));
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,534 @@
|
||||
/*
|
||||
* Copyright (c) 2023 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef FLASHINFER_UTILS_CUH_
|
||||
#define FLASHINFER_UTILS_CUH_
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "exception.h"
|
||||
|
||||
#define STR_HELPER(x) #x
|
||||
#define STR(x) STR_HELPER(x)
|
||||
|
||||
// macro to turn off fp16 qk reduction to reduce binary
|
||||
#ifndef FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION
|
||||
#define FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION 0
|
||||
#endif
|
||||
|
||||
#ifndef NDEBUG
|
||||
#define FLASHINFER_CUDA_CALL(func, ...) \
|
||||
{ \
|
||||
cudaError_t e = (func); \
|
||||
if (e != cudaSuccess) { \
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
|
||||
<< ") " << __FILE__ << ": line " << __LINE__ \
|
||||
<< " at function " << STR(func) << std::endl; \
|
||||
return e; \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define FLASHINFER_CUDA_CALL(func, ...) \
|
||||
{ \
|
||||
cudaError_t e = (func); \
|
||||
if (e != cudaSuccess) { \
|
||||
return e; \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
#define DISPATCH_USE_FP16_QK_REDUCTION( \
|
||||
use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, ...) \
|
||||
if (use_fp16_qk_reduction) { \
|
||||
FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \
|
||||
} else { \
|
||||
constexpr bool USE_FP16_QK_REDUCTION = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) \
|
||||
if (num_mma_q == 1) { \
|
||||
constexpr size_t NUM_MMA_Q = 1; \
|
||||
__VA_ARGS__ \
|
||||
} else if (num_mma_q == 2) { \
|
||||
constexpr size_t NUM_MMA_Q = 2; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported num_mma_q: " << num_mma_q; \
|
||||
FLASHINFER_ERROR(err_msg.str()); \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \
|
||||
if (max_mma_kv >= 8) { \
|
||||
constexpr size_t NUM_MMA_KV = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (max_mma_kv >= 4) { \
|
||||
constexpr size_t NUM_MMA_KV = 4; \
|
||||
__VA_ARGS__ \
|
||||
} else if (max_mma_kv >= 2) { \
|
||||
constexpr size_t NUM_MMA_KV = 2; \
|
||||
__VA_ARGS__ \
|
||||
} else if (max_mma_kv >= 1) { \
|
||||
constexpr size_t NUM_MMA_KV = 1; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \
|
||||
FLASHINFER_ERROR(err_msg.str()); \
|
||||
}
|
||||
|
||||
#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \
|
||||
switch (cta_tile_q) { \
|
||||
case 128: { \
|
||||
constexpr uint32_t CTA_TILE_Q = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 64: { \
|
||||
constexpr uint32_t CTA_TILE_Q = 64; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 16: { \
|
||||
constexpr uint32_t CTA_TILE_Q = 16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \
|
||||
FLASHINFER_ERROR(err_msg.str()); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 1) { \
|
||||
constexpr size_t GROUP_SIZE = 1; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 2) { \
|
||||
constexpr size_t GROUP_SIZE = 2; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 3) { \
|
||||
constexpr size_t GROUP_SIZE = 3; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 4) { \
|
||||
constexpr size_t GROUP_SIZE = 4; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported group_size: " << group_size; \
|
||||
FLASHINFER_ERROR(err_msg.str()); \
|
||||
}
|
||||
|
||||
#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \
|
||||
switch (mask_mode) { \
|
||||
case MaskMode::kNone: { \
|
||||
constexpr MaskMode MASK_MODE = MaskMode::kNone; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case MaskMode::kCausal: { \
|
||||
constexpr MaskMode MASK_MODE = MaskMode::kCausal; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case MaskMode::kCustom: { \
|
||||
constexpr MaskMode MASK_MODE = MaskMode::kCustom; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case MaskMode::kMultiItemScoring: { \
|
||||
constexpr MaskMode MASK_MODE = MaskMode::kMultiItemScoring; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported mask_mode: " << int(mask_mode); \
|
||||
FLASHINFER_ERROR(err_msg.str()); \
|
||||
} \
|
||||
}
|
||||
|
||||
// convert head_dim to compile-time constant
|
||||
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 64: { \
|
||||
constexpr size_t HEAD_DIM = 64; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 256: { \
|
||||
constexpr size_t HEAD_DIM = 256; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 512: { \
|
||||
constexpr size_t HEAD_DIM = 512; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported head_dim: " << head_dim; \
|
||||
FLASHINFER_ERROR(err_msg.str()); \
|
||||
} \
|
||||
}
|
||||
|
||||
// convert interleave to compile-time constant
|
||||
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
|
||||
if (interleave) { \
|
||||
constexpr bool INTERLEAVE = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
constexpr bool INTERLEAVE = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#define DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, ...) \
|
||||
switch (rope_dim) { \
|
||||
case 16: { \
|
||||
constexpr uint32_t ROPE_DIM = 16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 32: { \
|
||||
constexpr uint32_t ROPE_DIM = 32; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 64: { \
|
||||
constexpr uint32_t ROPE_DIM = 64; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 128: { \
|
||||
constexpr uint32_t ROPE_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 256: { \
|
||||
constexpr uint32_t ROPE_DIM = 256; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported ROPE_DIM: " << rope_dim; \
|
||||
err_msg << ". Supported values: 16, 32, 64, 128, 256"; \
|
||||
err_msg << " in DISPATCH_ROPE_DIM"; \
|
||||
FLASHINFER_ERROR(err_msg.str()); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
|
||||
switch (pos_encoding_mode) { \
|
||||
case PosEncodingMode::kNone: { \
|
||||
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case PosEncodingMode::kRoPELlama: { \
|
||||
constexpr PosEncodingMode POS_ENCODING_MODE = \
|
||||
PosEncodingMode::kRoPELlama; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case PosEncodingMode::kALiBi: { \
|
||||
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
|
||||
FLASHINFER_ERROR(err_msg.str()); \
|
||||
} \
|
||||
}
|
||||
|
||||
// Dispatch a runtime vector width (in elements; one of 16/8/4/2/1) to a
// compile-time constant `ALIGNED_VEC_SIZE` and expand the trailing code once
// per case. Unknown widths report an error through FLASHINFER_ERROR.
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...)  \
  switch (aligned_vec_size) {                                               \
    case 16: {                                                              \
      constexpr size_t ALIGNED_VEC_SIZE = 16;                               \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    case 8: {                                                               \
      constexpr size_t ALIGNED_VEC_SIZE = 8;                                \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    case 4: {                                                               \
      constexpr size_t ALIGNED_VEC_SIZE = 4;                                \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    case 2: {                                                               \
      constexpr size_t ALIGNED_VEC_SIZE = 2;                                \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    case 1: {                                                               \
      constexpr size_t ALIGNED_VEC_SIZE = 1;                                \
      __VA_ARGS__                                                           \
      break;                                                                \
    }                                                                       \
    default: {                                                              \
      std::ostringstream err_msg;                                           \
      err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size;      \
      FLASHINFER_ERROR(err_msg.str());                                      \
    }                                                                       \
  }
|
||||
|
||||
// Choose the number of shared-memory pipeline stages for the decode kernel
// from the device compute capability: 2 stages on SM80+ (compute_capacity is
// a (major, minor) pair), 1 stage on older architectures.
#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM( \
    compute_capacity, NUM_STAGES_SMEM, ...)          \
  if (compute_capacity.first >= 8) {                 \
    constexpr uint32_t NUM_STAGES_SMEM = 2;          \
    __VA_ARGS__                                      \
  } else {                                           \
    constexpr uint32_t NUM_STAGES_SMEM = 1;          \
    __VA_ARGS__                                      \
  }
|
||||
|
||||
namespace flashinfer {
|
||||
|
||||
// Integer ceiling division: smallest integral value >= x / y.
// NOTE(review): assumes x >= 0 and y > 0; `x + y - 1` can overflow or give a
// floor result for negative operands — confirm callers only pass unsigned
// extents.
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 ceil_div(const T1 x,
                                                          const T2 y) noexcept {
  return (x + y - 1) / y;
}
|
||||
|
||||
// Round x up to the nearest multiple of y (via ceil_div, so the same
// non-negative-operand assumption applies).
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 round_up(const T1 x,
                                                          const T2 y) noexcept {
  return ceil_div(x, y) * y;
}
|
||||
|
||||
// Round x down to the nearest multiple of y (truncating integer division).
template <typename T1, typename T2>
__forceinline__ __device__ __host__ constexpr T1 round_down(
    const T1 x, const T2 y) noexcept {
  return (x / y) * y;
}
|
||||
|
||||
// Query the (major, minor) compute capability of the current CUDA device.
// NOTE(review): the return codes of the runtime calls are ignored; if any of
// them fails this returns the zero-initialized pair (0, 0) — confirm callers
// tolerate that.
inline std::pair<int, int> GetCudaComputeCapability() {
  int device_id = 0;
  cudaGetDevice(&device_id);
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
  return std::make_pair(major, minor);
}
|
||||
|
||||
// This function is thread-safe and caches the sm_count.
// But it will only check the current CUDA device, thus assuming each process
// handles a single GPU.
inline int GetCudaMultiProcessorCount() {
  // Relaxed ordering is sufficient: every thread that races on first use
  // computes and stores the same value, so the worst case is a repeated
  // (idempotent) device query.
  static std::atomic<int> sm_count{0};
  int cached = sm_count.load(std::memory_order_relaxed);
  if (cached == 0) {
    int device_id;
    cudaGetDevice(&device_id);
    cudaDeviceProp device_prop;
    cudaGetDeviceProperties(&device_prop, device_id);
    cached = device_prop.multiProcessorCount;
    sm_count.store(cached, std::memory_order_relaxed);
  }
  return cached;
}
|
||||
|
||||
// Copy `size` elements of T from device memory to host and print them to
// stdout, optionally preceded by `prefix`. Debug helper only: performs a
// blocking device-to-host cudaMemcpy, never call it on a hot path.
//
// Fix: the original ignored the cudaMemcpy return code and would print the
// uninitialized host buffer on failure; errors are now reported instead.
template <typename T>
inline void DebugPrintCUDAArray(T* device_ptr,
                                size_t size,
                                std::string prefix = "") {
  std::vector<T> host_array(size);
  std::cout << prefix;
  // cudaMemcpyDeviceToHost to pageable memory is synchronizing, so the data
  // is valid as soon as the call returns successfully.
  cudaError_t err = cudaMemcpy(
      host_array.data(), device_ptr, size * sizeof(T), cudaMemcpyDeviceToHost);
  if (err != cudaSuccess) {
    std::cout << "[DebugPrintCUDAArray] cudaMemcpy failed: "
              << cudaGetErrorString(err) << std::endl;
    return;
  }
  for (size_t i = 0; i < size; ++i) {
    std::cout << host_array[i] << " ";
  }
  std::cout << std::endl;
}
|
||||
|
||||
// Choose the FA2 CTA tile size along the query dimension (128/64/16) from the
// average packed query length and the head dimension. Larger tiles amortize
// KV loads for long queries; small tiles avoid wasted work for short ones.
inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len,
                                     uint32_t head_dim) {
  if (avg_packed_qo_len > 64 && head_dim < 256) {
    return 128;
  } else {
    auto compute_capacity = GetCudaComputeCapability();
    if (compute_capacity.first >= 8) {
      // Ampere or newer
      if (avg_packed_qo_len > 16) {
        // avg_packed_qo_len <= 64
        return 64;
      } else {
        // avg_packed_qo_len <= 16
        return 16;
      }
    } else {
      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
      return 64;
    }
  }
}
|
||||
|
||||
// Return the smallest power of two greater than or equal to x.
// Any non-positive input maps to 1.
inline int UpPowerOfTwo(int x) {
  if (x <= 0) return 1;
  // Double a candidate power of two until it reaches x.
  int pow2 = 1;
  while (pow2 < x) {
    pow2 <<= 1;
  }
  return pow2;
}
|
||||
|
||||
// Split a descending loop over `iter` into two phases: while COND1 holds the
// body runs with compile-time WITH_MASK = true, then while COND2 holds it
// runs with WITH_MASK = false — so only the iterations that need masking pay
// for it. `_Pragma("unroll 1")` suppresses unrolling of both loops.
#define LOOP_SPLIT_MASK(iter, COND1, COND2, ...)       \
  {                                                    \
    _Pragma("unroll 1") for (; (COND1); (iter) -= 1) { \
      constexpr bool WITH_MASK = true;                 \
      __VA_ARGS__                                      \
    }                                                  \
    _Pragma("unroll 1") for (; (COND2); (iter) -= 1) { \
      constexpr bool WITH_MASK = false;                \
      __VA_ARGS__                                      \
    }                                                  \
  }
|
||||
|
||||
/*!
 * \brief Return x - y if x > y, otherwise return 0 (saturating subtraction).
 */
__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
                                                           uint32_t y) {
  if (x <= y) {
    return 0U;
  }
  return x - y;
}
|
||||
|
||||
// ======================= PTX Memory Utility Functions =======================
|
||||
// Non-atomic global memory access with cache streaming hint (cs)
|
||||
// These are useful for streaming memory access patterns where data is used once
|
||||
|
||||
/*!
 * \brief Get the lane ID within a warp (0-31)
 *
 * Reads the %laneid special register directly via inline PTX rather than
 * deriving it from threadIdx.
 */
__forceinline__ __device__ int get_lane_id() {
  int lane_id;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
  return lane_id;
}
|
||||
|
||||
/*!
 * \brief Non-atomic global load for int (4 bytes) with cache streaming hint
 *
 * The `.cs` qualifier marks the data as streaming (evict-first), so
 * single-use loads do not pollute the cache.
 */
__forceinline__ __device__ int ld_na_global_v1(const int* addr) {
  int val;
  asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr));
  return val;
}
|
||||
|
||||
/*!
 * \brief Non-atomic global load for int2 (8 bytes) with cache streaming hint
 *
 * Vectorized (v2.b32) variant of ld_na_global_v1; `addr` must be 8-byte
 * aligned as required for int2 accesses.
 */
__forceinline__ __device__ int2 ld_na_global_v2(const int2* addr) {
  int2 val;
  asm volatile("ld.global.cs.v2.b32 {%0, %1}, [%2];"
               : "=r"(val.x), "=r"(val.y)
               : "l"(addr));
  return val;
}
|
||||
|
||||
/*!
 * \brief Non-atomic global store for int (4 bytes) with cache streaming hint
 *
 * Store counterpart of ld_na_global_v1; `.cs` marks the line evict-first.
 */
__forceinline__ __device__ void st_na_global_v1(int* addr, int val) {
  asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
}
|
||||
|
||||
/*!
 * \brief Non-atomic global store for int2 (8 bytes) with cache streaming hint
 *
 * Vectorized (v2.b32) variant of st_na_global_v1; `addr` must be 8-byte
 * aligned as required for int2 accesses.
 */
__forceinline__ __device__ void st_na_global_v2(int2* addr, int2 val) {
  asm volatile("st.global.cs.v2.b32 [%0], {%1, %2};" ::"l"(addr),
               "r"(val.x),
               "r"(val.y));
}
|
||||
|
||||
/*!
 * \brief Prefetch data to L2 cache
 *
 * Issues a non-blocking `prefetch.global.L2` hint for the line containing
 * `addr`; useful for hiding latency of upcoming global reads.
 */
template <typename T>
__forceinline__ __device__ void prefetch_L2(const T* addr) {
  asm volatile("prefetch.global.L2 [%0];" ::"l"(addr));
}
|
||||
|
||||
// Exchange the values of a and b in place (device-side uint32_t swap).
__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) {
  uint32_t tmp = a;
  a = b;
  b = tmp;
}
|
||||
|
||||
// Linearize a 2-D index (idx_b, idx_a) into a row-major offset; dim_a is the
// extent of the innermost (fastest-varying) dimension.
__device__ __forceinline__ uint32_t dim2_offset(const uint32_t& dim_a,
                                                const uint32_t& idx_b,
                                                const uint32_t& idx_a) {
  return idx_b * dim_a + idx_a;
}
|
||||
|
||||
// Linearize a 3-D index (idx_c, idx_b, idx_a) into a row-major offset;
// dim_b and dim_a are the extents of the two inner dimensions.
__device__ __forceinline__ uint32_t dim3_offset(const uint32_t& dim_b,
                                                const uint32_t& dim_a,
                                                const uint32_t& idx_c,
                                                const uint32_t& idx_b,
                                                const uint32_t& idx_a) {
  return (idx_c * dim_b + idx_b) * dim_a + idx_a;
}
|
||||
|
||||
// Linearize a 4-D index (idx_d, idx_c, idx_b, idx_a) into a row-major offset;
// dim_c, dim_b, dim_a are the extents of the three inner dimensions.
__device__ __forceinline__ uint32_t dim4_offset(const uint32_t& dim_c,
                                                const uint32_t& dim_b,
                                                const uint32_t& dim_a,
                                                const uint32_t& idx_d,
                                                const uint32_t& idx_c,
                                                const uint32_t& idx_b,
                                                const uint32_t& idx_a) {
  return ((idx_d * dim_c + idx_c) * dim_b + idx_b) * dim_a + idx_a;
}
|
||||
|
||||
// Generate a member-detection trait: `has_<member><T>::value` (and the
// shorthand `has_<member>_v<T>`) is true iff T has a member named `member`,
// detected via SFINAE on std::void_t over decltype of the member access.
#define DEFINE_HAS_MEMBER(member)                                         \
  template <typename T, typename = void>                                  \
  struct has_##member : std::false_type {};                               \
  template <typename T>                                                   \
  struct has_##member<T, std::void_t<decltype(std::declval<T>().member)>> \
      : std::true_type {};                                                \
  template <typename T>                                                   \
  inline constexpr bool has_##member##_v = has_##member<T>::value;
|
||||
|
||||
} // namespace flashinfer
|
||||
|
||||
#endif // FLASHINFER_UTILS_CUH_
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user