mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] Fix async D2H copy bug & flash mask attention cache V out-of-bounds bug (#7221)
This commit is contained in:
@@ -296,7 +296,7 @@ void GetBlockShapeAndSplitKVBlock(
|
|||||||
if (!phi::backends::gpu::IsCUDAGraphCapturing())
|
if (!phi::backends::gpu::IsCUDAGraphCapturing())
|
||||||
#endif
|
#endif
|
||||||
max_len_tensor_cpu.copy_(
|
max_len_tensor_cpu.copy_(
|
||||||
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
|
max_len_tensor_gpu, max_len_tensor_cpu.place(), true);
|
||||||
|
|
||||||
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
|
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
|
||||||
int max_len_this_time = max_len_cpu_ptr[0];
|
int max_len_this_time = max_len_cpu_ptr[0];
|
||||||
@@ -378,7 +378,7 @@ void GetBlockShapeAndSplitKVBlock(
|
|||||||
if (!phi::backends::gpu::IsCUDAGraphCapturing())
|
if (!phi::backends::gpu::IsCUDAGraphCapturing())
|
||||||
#endif
|
#endif
|
||||||
decoder_num_blocks_cpu.copy_(
|
decoder_num_blocks_cpu.copy_(
|
||||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// mla_backend does not need to run the following code.
|
// mla_backend does not need to run the following code.
|
||||||
@@ -409,7 +409,7 @@ void GetBlockShapeAndSplitKVBlock(
|
|||||||
block_size);
|
block_size);
|
||||||
|
|
||||||
kv_num_blocks_x_cpu.copy_(
|
kv_num_blocks_x_cpu.copy_(
|
||||||
kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
|
kv_num_blocks_x, kv_num_blocks_x_cpu.place(), true);
|
||||||
// Clear buffer
|
// Clear buffer
|
||||||
const uint32_t encoder_max_tile_size_per_bs_q =
|
const uint32_t encoder_max_tile_size_per_bs_q =
|
||||||
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||||
@@ -433,7 +433,7 @@ void GetBlockShapeAndSplitKVBlock(
|
|||||||
encoder_block_shape_q,
|
encoder_block_shape_q,
|
||||||
group_size);
|
group_size);
|
||||||
encoder_num_blocks_x_cpu.copy_(
|
encoder_num_blocks_x_cpu.copy_(
|
||||||
encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
|
encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -87,9 +87,9 @@ std::vector<paddle::Tensor> PreCacheLenConcat(
|
|||||||
bsz,
|
bsz,
|
||||||
block_size);
|
block_size);
|
||||||
paddle::Tensor pre_cache_num_blocks_cpu =
|
paddle::Tensor pre_cache_num_blocks_cpu =
|
||||||
pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
|
pre_cache_num_blocks.copy_to(paddle::CPUPlace(), true);
|
||||||
paddle::Tensor kv_token_num_cpu =
|
paddle::Tensor kv_token_num_cpu =
|
||||||
kv_token_num.copy_to(paddle::CPUPlace(), false);
|
kv_token_num.copy_to(paddle::CPUPlace(), true);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
cu_seqlens_k,
|
cu_seqlens_k,
|
||||||
|
|||||||
@@ -490,6 +490,23 @@ struct CollectiveMainloopAttn {
|
|||||||
|
|
||||||
softmax.rescale_o(tOrO, scores_scale);
|
softmax.rescale_o(tOrO, scores_scale);
|
||||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||||
|
if (seq_len_k - n_block * kBlockN < kBlockN) {
|
||||||
|
int valid_k = seq_len_k - n_block * kBlockN;
|
||||||
|
auto sVt_this = sVt(_, _, smem_pipe_read_v.index());
|
||||||
|
constexpr int kHdLo = decltype(get<0, 0>(shape(sVt_this)))::value;
|
||||||
|
constexpr int kHdHi = decltype(get<0, 1>(shape(sVt_this)))::value;
|
||||||
|
if (thread_idx >= valid_k && thread_idx < kBlockN) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int hd_hi = 0; hd_hi < kHdHi; ++hd_hi) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int hd_lo = 0; hd_lo < kHdLo; ++hd_lo) {
|
||||||
|
sVt_this(make_coord(make_coord(hd_lo, hd_hi), thread_idx)) =
|
||||||
|
Element(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cutlass::arch::fence_view_async_shared();
|
||||||
|
}
|
||||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(
|
gemm</*zero_init=*/false, /*wg_wait=*/-1>(
|
||||||
tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||||
warp_scheduler_barrier_arrive();
|
warp_scheduler_barrier_arrive();
|
||||||
|
|||||||
Reference in New Issue
Block a user