mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
481 lines
18 KiB
Plaintext
481 lines
18 KiB
Plaintext
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "helper.h"
|
|
#include "iluvatar_context.h"
|
|
|
|
template <paddle::DataType T>
|
|
void MixedFusedPagedAttnKernel(
|
|
const paddle::Tensor& qkv,
|
|
paddle::Tensor& k_cache,
|
|
paddle::Tensor& v_cache,
|
|
const paddle::Tensor& prefill_block_table,
|
|
const paddle::Tensor& decode_block_table,
|
|
const paddle::Tensor& cu_seqlens_qkv,
|
|
const paddle::Tensor& seq_lens,
|
|
const paddle::Tensor& prefill_rope_sin,
|
|
const paddle::Tensor& prefill_rope_cos,
|
|
const paddle::optional<paddle::Tensor>& decode_rope_sin,
|
|
const paddle::optional<paddle::Tensor>& decode_rope_cos,
|
|
int prefill_num_tokens,
|
|
int num_heads,
|
|
int head_dim,
|
|
int num_kv_heads,
|
|
int block_size,
|
|
int max_seq_len,
|
|
float scale,
|
|
bool causal,
|
|
bool q_rope,
|
|
bool k_rope,
|
|
bool v_rope,
|
|
int window_left,
|
|
int window_right,
|
|
float softcap,
|
|
bool enable_cuda_graph,
|
|
bool use_sqrt_alibi,
|
|
int rope_batch_stride,
|
|
bool is_interleaved_rope_mode,
|
|
paddle::Tensor& out) {
|
|
typedef PDTraits<T> traits_;
|
|
typedef typename traits_::data_t data_t;
|
|
|
|
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
|
paddle::experimental::DeviceContextPool::Instance().Get(qkv.place()));
|
|
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
|
|
|
const auto& dtype = qkv.dtype();
|
|
cuinferDataType_t cuinfer_data_type;
|
|
cudaDataType_t cu_data_type;
|
|
if (dtype == paddle::DataType::FLOAT16) {
|
|
cuinfer_data_type = CUINFER_DATA_HALF;
|
|
cu_data_type = CUDA_R_16F;
|
|
} else {
|
|
cuinfer_data_type = CUINFER_DATA_BFLOAT16;
|
|
cu_data_type = CUDA_R_16BF;
|
|
}
|
|
|
|
const auto& qkv_dims = qkv.dims();
|
|
const auto& kv_cache_dims = k_cache.dims();
|
|
const auto& prefill_block_table_dims = prefill_block_table.dims();
|
|
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
|
|
|
|
int prefill_batch_size = prefill_block_table_dims[0];
|
|
int num_tokens = qkv_dims[0];
|
|
int decode_num_tokens = num_tokens - prefill_num_tokens;
|
|
int num_total_heads = num_heads + 2 * num_kv_heads;
|
|
int max_num_blocks_per_seq = prefill_block_table_dims[1];
|
|
int qkv_stride = qkv.strides()[0];
|
|
int num_blocks = kv_cache_dims[0];
|
|
|
|
int kv_block_stride = k_cache.strides()[0];
|
|
int kv_head_stride = k_cache.strides()[1];
|
|
int block_table_stride = prefill_block_table.strides()[0];
|
|
const float* prefill_rope_sin_ptr = prefill_rope_sin.data<float>();
|
|
const float* prefill_rope_cos_ptr = prefill_rope_cos.data<float>();
|
|
const auto& prefill_rope_dims = prefill_rope_sin.dims();
|
|
std::vector<int> prefill_rope_shape_vec, prefill_rope_stride_vec;
|
|
int prefill_rope_ndim;
|
|
if (prefill_rope_dims.size() == 4) {
|
|
// [prefill_batch_size, max_seq_len, 1, head_dim]
|
|
PADDLE_ENFORCE_EQ(
|
|
prefill_rope_dims[0],
|
|
prefill_batch_size,
|
|
common::errors::InvalidArgument(
|
|
"prefill_rope_dims[0] must be equal to prefill_batch_size"));
|
|
prefill_rope_shape_vec =
|
|
std::vector<int>({prefill_batch_size, max_seq_len, head_dim});
|
|
prefill_rope_stride_vec =
|
|
std::vector<int>({max_seq_len * head_dim, head_dim, 1});
|
|
prefill_rope_ndim = 3;
|
|
} else if (prefill_rope_dims.size() == 3) {
|
|
// [max_seq_len, 1, head_dim]
|
|
prefill_rope_shape_vec = std::vector<int>({max_seq_len, head_dim});
|
|
prefill_rope_stride_vec = std::vector<int>({head_dim, 1});
|
|
prefill_rope_ndim = 2;
|
|
} else {
|
|
PD_THROW("Unsupported prefill_rope_ndim = %d for Paged attn",
|
|
prefill_rope_ndim);
|
|
}
|
|
|
|
const float* decode_rope_sin_ptr =
|
|
decode_rope_sin ? decode_rope_sin.get().data<float>() : nullptr;
|
|
const float* decode_rope_cos_ptr =
|
|
decode_rope_cos ? decode_rope_cos.get().data<float>() : nullptr;
|
|
cuinferAttentionRopeMode_t rope_mode =
|
|
is_interleaved_rope_mode ? CUINFER_ATTEN_NORMAL : CUINFER_ATTEN_OCRV1;
|
|
|
|
cuinferTensorDescriptor_t qkv_desc;
|
|
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
|
|
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
|
qkv_desc,
|
|
cuinfer_data_type,
|
|
3,
|
|
std::vector<int>({prefill_num_tokens, num_total_heads, head_dim}).data(),
|
|
std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));
|
|
|
|
cuinferTensorDescriptor_t qkv_seqlens_desc;
|
|
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
|
|
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
|
qkv_seqlens_desc,
|
|
CUINFER_DATA_INT32,
|
|
1,
|
|
std::vector<int>({prefill_batch_size + 1}).data(),
|
|
std::vector<int>({1}).data()));
|
|
|
|
cuinferTensorDescriptor_t block_table_desc;
|
|
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
|
|
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
|
block_table_desc,
|
|
CUINFER_DATA_INT32,
|
|
2,
|
|
std::vector<int>({prefill_batch_size, block_table_stride}).data(),
|
|
std::vector<int>({block_table_stride, 1}).data()));
|
|
|
|
cuinferTensorDescriptor_t o_desc;
|
|
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
|
|
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
|
o_desc,
|
|
cuinfer_data_type,
|
|
3,
|
|
std::vector<int>({prefill_num_tokens, num_heads, head_dim}).data(),
|
|
std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));
|
|
|
|
cuinferTensorDescriptor_t k_cache_desc;
|
|
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
|
|
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
|
k_cache_desc,
|
|
cuinfer_data_type,
|
|
4,
|
|
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
|
std::vector<int>({num_kv_heads * block_size * head_dim,
|
|
block_size * head_dim,
|
|
head_dim,
|
|
1})
|
|
.data()));
|
|
|
|
cuinferTensorDescriptor_t v_cache_desc;
|
|
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
|
|
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
|
v_cache_desc,
|
|
cuinfer_data_type,
|
|
4,
|
|
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
|
std::vector<int>({num_kv_heads * block_size * head_dim,
|
|
block_size * head_dim,
|
|
head_dim,
|
|
1})
|
|
.data()));
|
|
|
|
cuinferTensorDescriptor_t cos_desc;
|
|
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
|
|
CUINFER_CHECK(cuinferSetTensorNdDescriptor(cos_desc,
|
|
CUINFER_DATA_FLOAT,
|
|
prefill_rope_ndim,
|
|
prefill_rope_shape_vec.data(),
|
|
prefill_rope_stride_vec.data()));
|
|
|
|
cuinferTensorDescriptor_t sin_desc;
|
|
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
|
|
CUINFER_CHECK(cuinferSetTensorNdDescriptor(sin_desc,
|
|
CUINFER_DATA_FLOAT,
|
|
prefill_rope_ndim,
|
|
prefill_rope_shape_vec.data(),
|
|
prefill_rope_stride_vec.data()));
|
|
|
|
cuinferHandle_t cuinfer_handle =
|
|
iluvatar::getContextInstance()->getIxInferHandle();
|
|
CUINFER_CHECK(cuinferSetStream(cuinfer_handle, stream));
|
|
|
|
size_t prefill_workspace_size = 0;
|
|
CUINFER_CHECK(
|
|
cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens,
|
|
num_heads,
|
|
num_kv_heads,
|
|
head_dim,
|
|
q_rope,
|
|
k_rope,
|
|
v_rope,
|
|
cuinfer_data_type,
|
|
cuinfer_data_type,
|
|
cuinfer_data_type,
|
|
&prefill_workspace_size));
|
|
|
|
auto* allocator = paddle::GetAllocator(qkv.place());
|
|
|
|
phi::Allocator::AllocationPtr prefill_tmp_workspace =
|
|
allocator->Allocate(prefill_workspace_size);
|
|
void* prefill_workspace_ptr = prefill_tmp_workspace->ptr();
|
|
|
|
CUINFER_CHECK(
|
|
cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
|
|
qkv_desc,
|
|
qkv.data(),
|
|
qkv_seqlens_desc,
|
|
cu_seqlens_qkv.data<int32_t>(),
|
|
block_table_desc,
|
|
prefill_block_table.data<int32_t>(),
|
|
o_desc,
|
|
out.data(),
|
|
k_cache_desc,
|
|
k_cache.data(),
|
|
v_cache_desc,
|
|
v_cache.data(),
|
|
prefill_workspace_ptr,
|
|
prefill_workspace_size,
|
|
cos_desc,
|
|
prefill_rope_cos_ptr,
|
|
sin_desc,
|
|
prefill_rope_sin_ptr,
|
|
prefill_batch_size,
|
|
num_heads,
|
|
num_kv_heads,
|
|
head_dim,
|
|
causal,
|
|
scale,
|
|
q_rope,
|
|
k_rope,
|
|
v_rope,
|
|
rope_mode));
|
|
|
|
size_t decode_workspace_size = 0;
|
|
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens,
|
|
num_heads,
|
|
num_kv_heads,
|
|
head_dim,
|
|
block_size,
|
|
max_seq_len,
|
|
&decode_workspace_size));
|
|
|
|
phi::Allocator::AllocationPtr decode_tmp_workspace =
|
|
allocator->Allocate(decode_workspace_size);
|
|
void* decode_workspace_ptr = decode_tmp_workspace->ptr();
|
|
|
|
void* decode_qkv_ptr =
|
|
(void*)(qkv.data<data_t>() + prefill_num_tokens * qkv_stride);
|
|
void* decode_out_ptr =
|
|
(void*)(out.data<data_t>() + prefill_num_tokens * out.strides()[0]);
|
|
|
|
PageAttentionWithKVCacheArguments args{static_cast<float>(scale),
|
|
1.0,
|
|
1.0,
|
|
static_cast<float>(softcap),
|
|
window_left,
|
|
window_right,
|
|
causal,
|
|
use_sqrt_alibi,
|
|
enable_cuda_graph,
|
|
false,
|
|
nullptr,
|
|
decode_qkv_ptr,
|
|
decode_qkv_ptr,
|
|
decode_workspace_ptr,
|
|
true,
|
|
decode_rope_sin_ptr,
|
|
decode_rope_cos_ptr,
|
|
nullptr,
|
|
nullptr,
|
|
nullptr,
|
|
nullptr,
|
|
1,
|
|
0,
|
|
0,
|
|
nullptr,
|
|
static_cast<size_t>(rope_batch_stride),
|
|
rope_mode};
|
|
|
|
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
|
decode_out_ptr,
|
|
cu_data_type,
|
|
decode_qkv_ptr,
|
|
cu_data_type,
|
|
decode_num_tokens,
|
|
num_heads,
|
|
num_kv_heads,
|
|
head_dim,
|
|
qkv_stride,
|
|
kv_block_stride,
|
|
kv_head_stride,
|
|
k_cache.data(),
|
|
cu_data_type,
|
|
v_cache.data(),
|
|
cu_data_type,
|
|
block_size,
|
|
max_num_blocks_per_seq,
|
|
max_seq_len,
|
|
decode_block_table.data<int32_t>(),
|
|
seq_lens.data<int32_t>(),
|
|
args));
|
|
|
|
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
|
|
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
|
|
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
|
|
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
|
|
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
|
|
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
|
|
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
|
|
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
|
|
}
|
|
|
|
std::vector<paddle::Tensor> MixedFusedPagedAttn(
|
|
const paddle::Tensor& qkv,
|
|
paddle::Tensor& k_cache,
|
|
paddle::Tensor& v_cache,
|
|
const paddle::Tensor& prefill_block_table,
|
|
const paddle::Tensor& decode_block_table,
|
|
const paddle::Tensor& cu_seqlens_qkv,
|
|
const paddle::Tensor& seq_lens,
|
|
const paddle::Tensor& prefill_rope_sin,
|
|
const paddle::Tensor& prefill_rope_cos,
|
|
const paddle::optional<paddle::Tensor>& decode_rope_sin,
|
|
const paddle::optional<paddle::Tensor>& decode_rope_cos,
|
|
int prefill_num_tokens,
|
|
int num_heads,
|
|
int head_dim,
|
|
int num_kv_heads,
|
|
int block_size,
|
|
int max_seq_len,
|
|
float scale,
|
|
bool causal,
|
|
bool q_rope,
|
|
bool k_rope,
|
|
bool v_rope,
|
|
int window_left,
|
|
int window_right,
|
|
float softcap,
|
|
bool enable_cuda_graph,
|
|
bool use_sqrt_alibi,
|
|
int rope_batch_stride,
|
|
bool is_interleaved_rope_mode) {
|
|
const auto dtype = qkv.dtype();
|
|
auto out =
|
|
paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
|
|
|
|
switch (dtype) {
|
|
case paddle::DataType::BFLOAT16:
|
|
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(
|
|
qkv,
|
|
k_cache,
|
|
v_cache,
|
|
prefill_block_table,
|
|
decode_block_table,
|
|
cu_seqlens_qkv,
|
|
seq_lens,
|
|
prefill_rope_sin,
|
|
prefill_rope_cos,
|
|
decode_rope_sin,
|
|
decode_rope_cos,
|
|
prefill_num_tokens,
|
|
num_heads,
|
|
head_dim,
|
|
num_kv_heads,
|
|
block_size,
|
|
max_seq_len,
|
|
scale,
|
|
causal,
|
|
q_rope,
|
|
k_rope,
|
|
v_rope,
|
|
window_left,
|
|
window_right,
|
|
softcap,
|
|
enable_cuda_graph,
|
|
use_sqrt_alibi,
|
|
rope_batch_stride,
|
|
is_interleaved_rope_mode,
|
|
out);
|
|
break;
|
|
case paddle::DataType::FLOAT16:
|
|
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(
|
|
qkv,
|
|
k_cache,
|
|
v_cache,
|
|
prefill_block_table,
|
|
decode_block_table,
|
|
cu_seqlens_qkv,
|
|
seq_lens,
|
|
prefill_rope_sin,
|
|
prefill_rope_cos,
|
|
decode_rope_sin,
|
|
decode_rope_cos,
|
|
prefill_num_tokens,
|
|
num_heads,
|
|
head_dim,
|
|
num_kv_heads,
|
|
block_size,
|
|
max_seq_len,
|
|
scale,
|
|
causal,
|
|
q_rope,
|
|
k_rope,
|
|
v_rope,
|
|
window_left,
|
|
window_right,
|
|
softcap,
|
|
enable_cuda_graph,
|
|
use_sqrt_alibi,
|
|
rope_batch_stride,
|
|
is_interleaved_rope_mode,
|
|
out);
|
|
break;
|
|
default:
|
|
PD_THROW("Unsupported data type for mixed paged attn");
|
|
}
|
|
return {out};
|
|
}
|
|
|
|
// Shape inference: the single output is 2-D, keeping qkv's leading dimension
// and using num_heads * head_dim as the second dimension.
std::vector<std::vector<int64_t>> MixedFusedPagedAttnInferShape(
    const std::vector<int64_t>& qkv_shape, int num_heads, int head_dim) {
  const int64_t leading_dim = qkv_shape[0];
  const int64_t hidden_dim = num_heads * head_dim;

  std::vector<std::vector<int64_t>> out_shapes;
  out_shapes.push_back(std::vector<int64_t>{leading_dim, hidden_dim});
  return out_shapes;
}
|
|
|
|
std::vector<paddle::DataType> MixedFusedPagedAttnInferDtype(
|
|
const paddle::DataType& qkv_dtype) {
|
|
return {qkv_dtype};
|
|
}
|
|
|
|
// Registers the `mixed_fused_paged_attn` custom op. Inputs and attributes are
// listed in the same order as the parameters of MixedFusedPagedAttn.
PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
    .Inputs({// packed activations and the paged KV caches
             "qkv", "k_cache", "v_cache",
             // block tables for the prefill and decode parts
             "prefill_block_table", "decode_block_table",
             // sequence-length metadata
             "cu_seqlens_qkv", "seq_lens",
             // rope tables; the decode ones may be omitted
             "prefill_rope_sin", "prefill_rope_cos",
             paddle::Optional("decode_rope_sin"),
             paddle::Optional("decode_rope_cos")})
    .Outputs({"out"})
    // NOTE(review): "num_heads: int" has a space after the colon, unlike the
    // other entries; it is kept as-is — confirm the attribute parser
    // tolerates it.
    .Attrs({"prefill_num_tokens:int", "num_heads: int", "head_dim:int",
            "num_kv_heads:int", "block_size:int", "max_seq_len:int",
            "scale:float", "causal:bool",
            "q_rope:bool", "k_rope:bool", "v_rope:bool",
            "window_left:int", "window_right:int", "softcap:float",
            "enable_cuda_graph:bool", "use_sqrt_alibi:bool",
            "rope_batch_stride:int", "is_interleaved_rope_mode:bool"})
    .SetKernelFn(PD_KERNEL(MixedFusedPagedAttn))
    .SetInferShapeFn(PD_INFER_SHAPE(MixedFusedPagedAttnInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MixedFusedPagedAttnInferDtype));
|