mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Others]get_block_shape_and_split_kv_block clean code (#5123)
This commit is contained in:
@@ -79,7 +79,7 @@ __global__ void GetMaxLenKernel(const int *seq_lens_decoder,
|
||||
max_lens[2] = total_max_len_decoder;
|
||||
max_lens[3] = total;
|
||||
max_lens[4] = total_just_dec;
|
||||
max_lens[8] = total_max_len_kv;
|
||||
max_lens[5] = total_max_len_kv;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,8 +273,7 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
const int block_size,
|
||||
const int decoder_step_token_num) {
|
||||
const int block_size) {
|
||||
auto stream = seq_lens_encoder.stream();
|
||||
int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
@@ -302,10 +301,9 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
int max_dec_len_this_time = max_len_cpu_ptr[2];
|
||||
int max_enc_dec_len_this_time = max_len_cpu_ptr[3];
|
||||
int max_just_dec_len_this_time = max_len_cpu_ptr[4];
|
||||
int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
|
||||
int max_system_len = max_len_cpu_ptr[6];
|
||||
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
|
||||
int max_kv_len_this_time = max_len_cpu_ptr[8];
|
||||
int max_kv_len_this_time = max_len_cpu_ptr[5];
|
||||
|
||||
const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0];
|
||||
|
||||
// decoder
|
||||
if (max_dec_len_this_time > 0) {
|
||||
@@ -343,25 +341,15 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
|
||||
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
|
||||
|
||||
// NOTE: (changwenbin) When using auto_chunk,
|
||||
// decode_max_tile_size must take into account the maximum case, where *
|
||||
// 1024 can cover 128K. const uint32_t decoder_batch_shape =
|
||||
// seq_lens_decoder.dims()[0] * 1024;
|
||||
|
||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
||||
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape =
|
||||
bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_batch_ids.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
decoder_batch_ele_num * sizeof(int32_t),
|
||||
stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
decoder_batch_ele_num * sizeof(int32_t),
|
||||
stream));
|
||||
|
||||
split_block_for_mla<<<1, 32, 0, stream>>>(
|
||||
@@ -374,22 +362,15 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
chunk_size);
|
||||
|
||||
} else {
|
||||
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value
|
||||
// should be taken here
|
||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
||||
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape =
|
||||
bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_batch_ids.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
decoder_batch_ele_num * sizeof(int32_t),
|
||||
stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
decoder_batch_ele_num * sizeof(int32_t),
|
||||
stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
@@ -413,13 +394,6 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
}
|
||||
} else {
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
}
|
||||
|
||||
// encoder
|
||||
@@ -486,8 +460,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
const int block_size,
|
||||
const int decoder_step_token_num) {
|
||||
const int block_size) {
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -498,8 +471,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
const int block_size,
|
||||
const int decoder_step_token_num) {
|
||||
const int block_size) {
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -527,8 +499,7 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
.Attrs({"encoder_block_shape_q: int",
|
||||
"decoder_block_shape_q: int",
|
||||
"group_size: int",
|
||||
"block_size: int",
|
||||
"decoder_step_token_num: int"})
|
||||
"block_size: int"})
|
||||
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
|
||||
|
||||
Reference in New Issue
Block a user