// FastDeploy/custom_ops/iluvatar_ops/mixed_fused_attn.cu
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "iluvatar_context.h"
template <paddle::DataType T>
void MixedFusedPagedAttnKernel(
const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& prefill_block_table,
const paddle::Tensor& decode_block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& prefill_rope_sin,
const paddle::Tensor& prefill_rope_cos,
const paddle::optional<paddle::Tensor>& decode_rope_sin,
const paddle::optional<paddle::Tensor>& decode_rope_cos,
int prefill_num_tokens,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
int rope_batch_stride,
bool is_interleaved_rope_mode,
paddle::Tensor& out) {
typedef PDTraits<T> traits_;
typedef typename traits_::data_t data_t;
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(qkv.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
const auto& dtype = qkv.dtype();
cuinferDataType_t cuinfer_data_type;
cudaDataType_t cu_data_type;
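// Only FP16 and BF16 reach this kernel; the dispatcher below rejects other dtypes.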
if (dtype == paddle::DataType::FLOAT16) {
cuinfer_data_type = CUINFER_DATA_HALF;
cu_data_type = CUDA_R_16F;
} else {
cuinfer_data_type = CUINFER_DATA_BFLOAT16;
cu_data_type = CUDA_R_16BF;
}
const auto& qkv_dims = qkv.dims();
const auto& kv_cache_dims = k_cache.dims();
const auto& prefill_block_table_dims = prefill_block_table.dims();
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
int prefill_batch_size = prefill_block_table_dims[0];
int num_tokens = qkv_dims[0];
int decode_num_tokens = num_tokens - prefill_num_tokens;
int num_total_heads = num_heads + 2 * num_kv_heads;
int max_num_blocks_per_seq = prefill_block_table_dims[1];
int qkv_stride = qkv.strides()[0];
int num_blocks = kv_cache_dims[0];
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
int block_table_stride = prefill_block_table.strides()[0];
const float* prefill_rope_sin_ptr = prefill_rope_sin.data<float>();
const float* prefill_rope_cos_ptr = prefill_rope_cos.data<float>();
const auto& prefill_rope_dims = prefill_rope_sin.dims();
std::vector<int> prefill_rope_shape_vec, prefill_rope_stride_vec;
int prefill_rope_ndim;
if (prefill_rope_dims.size() == 4) {
// [prefill_batch_size, max_seq_len, 1, head_dim]
PADDLE_ENFORCE_EQ(
prefill_rope_dims[0],
prefill_batch_size,
common::errors::InvalidArgument(
"prefill_rope_dims[0] must be equal to prefill_batch_size"));
prefill_rope_shape_vec =
std::vector<int>({prefill_batch_size, max_seq_len, head_dim});
prefill_rope_stride_vec =
std::vector<int>({max_seq_len * head_dim, head_dim, 1});
prefill_rope_ndim = 3;
} else if (prefill_rope_dims.size() == 3) {
// [max_seq_len, 1, head_dim]
prefill_rope_shape_vec = std::vector<int>({max_seq_len, head_dim});
prefill_rope_stride_vec = std::vector<int>({head_dim, 1});
prefill_rope_ndim = 2;
} else {
PD_THROW("Unsupported prefill_rope_ndim = %d for Paged attn",
prefill_rope_ndim);
}
const float* decode_rope_sin_ptr =
decode_rope_sin ? decode_rope_sin.get().data<float>() : nullptr;
const float* decode_rope_cos_ptr =
decode_rope_cos ? decode_rope_cos.get().data<float>() : nullptr;
cuinferAttentionRopeMode_t rope_mode =
is_interleaved_rope_mode ? CUINFER_ATTEN_NORMAL : CUINFER_ATTEN_OCRV1;
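// Descriptors for the prefill fused-attention call: packed QKV, cumulative QKV
// sequence lengths, the prefill block table, the prefill output slice, the paged
// K/V caches, and the RoPE cos/sin tables.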
cuinferTensorDescriptor_t qkv_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
qkv_desc,
cuinfer_data_type,
3,
std::vector<int>({prefill_num_tokens, num_total_heads, head_dim}).data(),
std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));
cuinferTensorDescriptor_t qkv_seqlens_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
qkv_seqlens_desc,
CUINFER_DATA_INT32,
1,
std::vector<int>({prefill_batch_size + 1}).data(),
std::vector<int>({1}).data()));
cuinferTensorDescriptor_t block_table_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
block_table_desc,
CUINFER_DATA_INT32,
2,
std::vector<int>({prefill_batch_size, block_table_stride}).data(),
std::vector<int>({block_table_stride, 1}).data()));
cuinferTensorDescriptor_t o_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
o_desc,
cuinfer_data_type,
3,
std::vector<int>({prefill_num_tokens, num_heads, head_dim}).data(),
std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));
cuinferTensorDescriptor_t k_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
k_cache_desc,
cuinfer_data_type,
4,
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
std::vector<int>({num_kv_heads * block_size * head_dim,
block_size * head_dim,
head_dim,
1})
.data()));
cuinferTensorDescriptor_t v_cache_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
v_cache_desc,
cuinfer_data_type,
4,
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
std::vector<int>({num_kv_heads * block_size * head_dim,
block_size * head_dim,
head_dim,
1})
.data()));
cuinferTensorDescriptor_t cos_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(cos_desc,
CUINFER_DATA_FLOAT,
prefill_rope_ndim,
prefill_rope_shape_vec.data(),
prefill_rope_stride_vec.data()));
cuinferTensorDescriptor_t sin_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(sin_desc,
CUINFER_DATA_FLOAT,
prefill_rope_ndim,
prefill_rope_shape_vec.data(),
prefill_rope_stride_vec.data()));
cuinferHandle_t cuinfer_handle =
iluvatar::getContextInstance()->getIxInferHandle();
CUINFER_CHECK(cuinferSetStream(cuinfer_handle, stream));
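// Query and allocate the scratch workspace needed by the prefill fused FMHA + RoPE kernel.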
size_t prefill_workspace_size = 0;
CUINFER_CHECK(
cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens,
num_heads,
num_kv_heads,
head_dim,
q_rope,
k_rope,
v_rope,
cuinfer_data_type,
cuinfer_data_type,
cuinfer_data_type,
&prefill_workspace_size));
auto* allocator = paddle::GetAllocator(qkv.place());
phi::Allocator::AllocationPtr prefill_tmp_workspace =
allocator->Allocate(prefill_workspace_size);
void* prefill_workspace_ptr = prefill_tmp_workspace->ptr();
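// Prefill pass: fused attention with RoPE over the first prefill_num_tokens tokens.
// The paged K/V caches and the prefill block table are handed to the kernel so the
// freshly computed K/V can be placed into the cache for later decode steps.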
CUINFER_CHECK(
cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
qkv_desc,
qkv.data(),
qkv_seqlens_desc,
cu_seqlens_qkv.data<int32_t>(),
block_table_desc,
prefill_block_table.data<int32_t>(),
o_desc,
out.data(),
k_cache_desc,
k_cache.data(),
v_cache_desc,
v_cache.data(),
prefill_workspace_ptr,
prefill_workspace_size,
cos_desc,
prefill_rope_cos_ptr,
sin_desc,
prefill_rope_sin_ptr,
prefill_batch_size,
num_heads,
num_kv_heads,
head_dim,
causal,
scale,
q_rope,
k_rope,
v_rope,
rope_mode));
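// Decode pass: paged attention over the remaining decode_num_tokens tokens,
// reading K/V from the block-based caches. First query the required workspace size.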
size_t decode_workspace_size = 0;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens,
num_heads,
num_kv_heads,
head_dim,
block_size,
max_seq_len,
&decode_workspace_size));
phi::Allocator::AllocationPtr decode_tmp_workspace =
allocator->Allocate(decode_workspace_size);
void* decode_workspace_ptr = decode_tmp_workspace->ptr();
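// Decode tokens are packed after the prefill tokens, so offset the QKV and output
// pointers by prefill_num_tokens rows.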
void* decode_qkv_ptr =
(void*)(qkv.data<data_t>() + prefill_num_tokens * qkv_stride);
void* decode_out_ptr =
(void*)(out.data<data_t>() + prefill_num_tokens * out.strides()[0]);
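// Collect the decode-side options (scaling, softcap, sliding window, ALiBi mode,
// CUDA-graph capture, RoPE tables/stride and rope mode) into the positional
// argument struct consumed by cuInferPageAttentionV7.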
PageAttentionWithKVCacheArguments args{static_cast<float>(scale),
1.0,
1.0,
static_cast<float>(softcap),
window_left,
window_right,
causal,
use_sqrt_alibi,
enable_cuda_graph,
false,
nullptr,
decode_qkv_ptr,
decode_qkv_ptr,
decode_workspace_ptr,
true,
decode_rope_sin_ptr,
decode_rope_cos_ptr,
nullptr,
nullptr,
nullptr,
nullptr,
1,
0,
0,
nullptr,
static_cast<size_t>(rope_batch_stride),
rope_mode};
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
decode_out_ptr,
cu_data_type,
decode_qkv_ptr,
cu_data_type,
decode_num_tokens,
num_heads,
num_kv_heads,
head_dim,
qkv_stride,
kv_block_stride,
kv_head_stride,
k_cache.data(),
cu_data_type,
v_cache.data(),
cu_data_type,
block_size,
max_num_blocks_per_seq,
max_seq_len,
decode_block_table.data<int32_t>(),
seq_lens.data<int32_t>(),
args));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
}
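// Dtype dispatcher: allocates the [num_tokens, num_heads * head_dim] output and
// routes to the templated kernel for BF16 / FP16.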
std::vector<paddle::Tensor> MixedFusedPagedAttn(
const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& prefill_block_table,
const paddle::Tensor& decode_block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& prefill_rope_sin,
const paddle::Tensor& prefill_rope_cos,
const paddle::optional<paddle::Tensor>& decode_rope_sin,
const paddle::optional<paddle::Tensor>& decode_rope_cos,
int prefill_num_tokens,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
int rope_batch_stride,
bool is_interleaved_rope_mode) {
const auto dtype = qkv.dtype();
auto out =
paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
switch (dtype) {
case paddle::DataType::BFLOAT16:
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(
qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
prefill_rope_sin,
prefill_rope_cos,
decode_rope_sin,
decode_rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
rope_batch_stride,
is_interleaved_rope_mode,
out);
break;
case paddle::DataType::FLOAT16:
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(
qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
prefill_rope_sin,
prefill_rope_cos,
decode_rope_sin,
decode_rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
rope_batch_stride,
is_interleaved_rope_mode,
out);
break;
default:
PD_THROW("Unsupported data type for mixed paged attn");
}
return {out};
}
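// InferShape/InferDtype for the static op. Custom-op InferShape parameters are
// matched positionally: input shapes in Inputs() order, then attributes in Attrs()
// order, so prefill_num_tokens has to be accepted before num_heads and head_dim
// even though only the latter two are used.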
std::vector<std::vector<int64_t>> MixedFusedPagedAttnInferShape(
const std::vector<int64_t>& qkv_shape,
int prefill_num_tokens,
int num_heads,
int head_dim) {
return {{qkv_shape[0], num_heads * head_dim}};
}
std::vector<paddle::DataType> MixedFusedPagedAttnInferDtype(
const paddle::DataType& qkv_dtype) {
return {qkv_dtype};
}
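// Static-op registration: exposes the kernel as "mixed_fused_paged_attn" with the
// inputs, single output and attributes declared below. The Python-side wrapper
// import depends on how the FastDeploy build packages these custom ops.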
PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
.Inputs({"qkv",
"k_cache",
"v_cache",
"prefill_block_table",
"decode_block_table",
"cu_seqlens_qkv",
"seq_lens",
"prefill_rope_sin",
"prefill_rope_cos",
paddle::Optional("decode_rope_sin"),
paddle::Optional("decode_rope_cos")})
.Outputs({"out"})
.Attrs({"prefill_num_tokens:int",
"num_heads: int",
"head_dim:int",
"num_kv_heads:int",
"block_size:int",
"max_seq_len:int",
"scale:float",
"causal:bool",
"q_rope:bool",
"k_rope:bool",
"v_rope:bool",
"window_left:int",
"window_right:int",
"softcap:float",
"enable_cuda_graph:bool",
"use_sqrt_alibi:bool",
"rope_batch_stride:int",
"is_interleaved_rope_mode:bool"})
.SetKernelFn(PD_KERNEL(MixedFusedPagedAttn))
.SetInferShapeFn(PD_INFER_SHAPE(MixedFusedPagedAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MixedFusedPagedAttnInferDtype));