mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Add return value checks for all XPU kernel launches (#6666)
* [XPU] Add return value checks for all XPU kernel launches
  - Add the -fxpu-launch-return compiler flag in CMakeLists.txt to enable kernel launch return values
  - Add KERNEL_ASSERT_SUCCESS(ctx, ret_xre) checks after every XPU kernel launch across 45 wrapper files (55 launch sites in total)
  - Cover both the main wrapper/ and mtp_wrapper/ directories
  - Properly handle multiple kernel launches in the same function scope by reusing the ret_xre variable
* [XPU] Code style fix
This commit is contained in:
@@ -360,7 +360,7 @@ foreach(xpu_wrapper IN LISTS xpu_wrappers)
|
||||
-I${CMAKE_CURRENT_SOURCE_DIR}/src/wrapper -D_GNU_SOURCE
|
||||
-D__STDC_LIMIT_MACROS -DNDEBUG ${TOOLCHAIN_ARGS} --target=${TARGET_ARCH}
|
||||
-fPIC -Wunused-variable -Werror -Wreorder -fvisibility=hidden
|
||||
--xpu-host-only ${HOST_XPU_FLAGS}
|
||||
-fxpu-launch-return --xpu-host-only ${HOST_XPU_FLAGS}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS wrapper_build/${wrapper_name}.wrapper.d
|
||||
COMMENT wrapper_build/${wrapper_name}.wrapper.o
|
||||
|
||||
@@ -98,16 +98,18 @@ static int xpu3_wrapper(api::Context *ctx,
|
||||
auto eb_adjust_batch_kernel =
|
||||
xpu3::plugin::eb_adjust_batch<XPU_INDEX_TYPE_TX, XPU_INDEX_TYPE_TY>;
|
||||
// NOTE: Don't change 16 to 64, because kernel use gsm
|
||||
eb_adjust_batch_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INDEX_TYPE_TX *>(const_cast<TX *>(x)),
|
||||
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y),
|
||||
encoder_seqs_lods.xpu,
|
||||
decoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
int32_t ret_xre =
|
||||
eb_adjust_batch_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INDEX_TYPE_TX *>(const_cast<TX *>(x)),
|
||||
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y),
|
||||
encoder_seqs_lods.xpu,
|
||||
decoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -76,15 +76,17 @@ static int xpu3_wrapper(api::Context *ctx,
|
||||
int64_t hidden_dim) {
|
||||
auto eb_gather_next_token_kernel = xpu3::plugin::eb_gather_next_token<TX, TY>;
|
||||
// NOTE: Don't change 16 to 64, because kernel use gsm
|
||||
eb_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
const_cast<TX *>(x),
|
||||
y,
|
||||
encoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
int32_t ret_xre =
|
||||
eb_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
const_cast<TX *>(x),
|
||||
y,
|
||||
encoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -194,27 +194,29 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int max_decoder_block_num) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto free_and_dispatch_block_kernel = xpu3::plugin::free_and_dispatch_block;
|
||||
free_and_dispatch_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num);
|
||||
int32_t ret_xre =
|
||||
free_and_dispatch_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -118,21 +118,25 @@ static int xpu3_wrapper(Context *ctx,
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto get_padding_offset = xpu3::plugin::get_padding_offset;
|
||||
auto remove_padding = xpu3::plugin::remove_padding;
|
||||
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(padding_offset,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre =
|
||||
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
padding_offset,
|
||||
cum_offsets_out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
cum_offsets,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bs);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
ret_xre = remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64 *>(x_remove_padding),
|
||||
reinterpret_cast<const XPU_INT64 *>(input_ids),
|
||||
seq_lens,
|
||||
cum_offsets_out,
|
||||
max_seq_len,
|
||||
bs);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -93,16 +93,18 @@ static int xpu3_wrapper(Context* ctx,
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto limit_thinking_content_length_kernel_v1 =
|
||||
xpu3::plugin::limit_thinking_content_length_kernel_v1;
|
||||
limit_thinking_content_length_kernel_v1<<<1, 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64*>(next_tokens),
|
||||
max_think_lens,
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
reinterpret_cast<const XPU_INT64*>(eos_token_ids),
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
bs,
|
||||
eos_token_id_len);
|
||||
int32_t ret_xre =
|
||||
limit_thinking_content_length_kernel_v1<<<1, 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64*>(next_tokens),
|
||||
max_think_lens,
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
reinterpret_cast<const XPU_INT64*>(eos_token_ids),
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
bs,
|
||||
eos_token_id_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -98,15 +98,17 @@ static int xpu3_wrapper(Context* ctx,
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto limit_thinking_content_length_kernel_v2 =
|
||||
xpu3::plugin::limit_thinking_content_length_kernel_v2;
|
||||
limit_thinking_content_length_kernel_v2<<<1, 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64*>(next_tokens),
|
||||
max_think_lens,
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
bs);
|
||||
int32_t ret_xre =
|
||||
limit_thinking_content_length_kernel_v2<<<1, 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64*>(next_tokens),
|
||||
max_think_lens,
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
bs);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
const int bsz,
|
||||
const int actual_draft_token_num,
|
||||
const int input_token_num) {
|
||||
xpu3::plugin::ComputeOrderKernel<<<1, 1, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre = xpu3::plugin::ComputeOrderKernel<<<1, 1, ctx->xpu_stream>>>(
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -119,6 +119,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
bsz,
|
||||
actual_draft_token_num,
|
||||
input_token_num);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -76,13 +76,15 @@ static int xpu3_wrapper(Context* ctx,
|
||||
int* output_token_num,
|
||||
int bsz) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::ComputeSelfOrderKernel<<<1, 1, ctx->xpu_stream>>>(
|
||||
last_seq_lens_this_time,
|
||||
seq_lens_this_time,
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
src_map,
|
||||
output_token_num,
|
||||
bsz);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::ComputeSelfOrderKernel<<<1, 1, ctx->xpu_stream>>>(
|
||||
last_seq_lens_this_time,
|
||||
seq_lens_this_time,
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
src_map,
|
||||
output_token_num,
|
||||
bsz);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
const bool* base_model_stop_flags,
|
||||
int bsz,
|
||||
int base_model_draft_token_len) {
|
||||
xpu3::plugin::
|
||||
int32_t ret_xre = xpu3::plugin::
|
||||
draft_model_postprocess<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const xpu3::int64_t*>(base_model_draft_tokens),
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -90,6 +90,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
base_model_stop_flags,
|
||||
bsz,
|
||||
base_model_draft_token_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -240,37 +240,39 @@ static int xpu3_wrapper(api::Context* ctx,
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
|
||||
// NOTE: Don't change 16 to 64, because kernel use gsm
|
||||
xpu3::plugin::draft_model_preprocess<<<1, 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64*>(draft_tokens),
|
||||
reinterpret_cast<XPU_INT64*>(input_ids),
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64*>(step_idx),
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
reinterpret_cast<XPU_INT64*>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64*>(accept_tokens),
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
reinterpret_cast<const XPU_INT64*>(base_model_step_idx),
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::draft_model_preprocess<<<1, 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64*>(draft_tokens),
|
||||
reinterpret_cast<XPU_INT64*>(input_ids),
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64*>(step_idx),
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
reinterpret_cast<XPU_INT64*>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64*>(accept_tokens),
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
reinterpret_cast<const XPU_INT64*>(base_model_step_idx),
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -178,28 +178,30 @@ static int xpu2or3_wrapper(Context* ctx,
|
||||
const bool prefill_one_step_stop) {
|
||||
ctx_guard RAII_GUARD(ctx);
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::draft_model_update<<<1, 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64*>(inter_next_tokens),
|
||||
reinterpret_cast<XPU_INT64*>(draft_tokens),
|
||||
reinterpret_cast<XPU_INT64*>(pre_ids),
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64*>(step_idx),
|
||||
output_cum_offsets,
|
||||
stop_flags,
|
||||
not_need_stop,
|
||||
reinterpret_cast<const XPU_INT64*>(max_dec_len),
|
||||
reinterpret_cast<const XPU_INT64*>(end_ids),
|
||||
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
|
||||
bsz,
|
||||
max_draft_token,
|
||||
pre_id_length,
|
||||
max_base_model_draft_token,
|
||||
end_ids_len,
|
||||
max_seq_len,
|
||||
substep,
|
||||
prefill_one_step_stop);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::draft_model_update<<<1, 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64*>(inter_next_tokens),
|
||||
reinterpret_cast<XPU_INT64*>(draft_tokens),
|
||||
reinterpret_cast<XPU_INT64*>(pre_ids),
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64*>(step_idx),
|
||||
output_cum_offsets,
|
||||
stop_flags,
|
||||
not_need_stop,
|
||||
reinterpret_cast<const XPU_INT64*>(max_dec_len),
|
||||
reinterpret_cast<const XPU_INT64*>(end_ids),
|
||||
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
|
||||
bsz,
|
||||
max_draft_token,
|
||||
pre_id_length,
|
||||
max_base_model_draft_token,
|
||||
end_ids_len,
|
||||
max_seq_len,
|
||||
substep,
|
||||
prefill_one_step_stop);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
+12
-10
@@ -104,16 +104,18 @@ static int xpu3_wrapper(api::Context *ctx,
|
||||
auto eb_mtp_gather_next_token_kernel =
|
||||
xpu3::plugin::eb_mtp_gather_next_token<TX, TY>;
|
||||
// NOTE: Don't change 16 to 64, because kernel use gsm
|
||||
eb_mtp_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
const_cast<TX *>(x),
|
||||
y,
|
||||
encoder_seqs_lods.xpu,
|
||||
decoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
int32_t ret_xre =
|
||||
eb_mtp_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
const_cast<TX *>(x),
|
||||
y,
|
||||
encoder_seqs_lods.xpu,
|
||||
decoder_seqs_lods.xpu,
|
||||
encoder_batch_map.xpu,
|
||||
decoder_batch_map.xpu,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -104,7 +104,9 @@ static int xpu3_wrapper(api::Context *ctx,
|
||||
auto eb_recover_batch_sequence_kernel =
|
||||
xpu3::plugin::eb_recover_batch_sequence<TX, TY>;
|
||||
// NOTE: Don't change 16 to 64, because kernel use gsm
|
||||
eb_recover_batch_sequence_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre = eb_recover_batch_sequence_kernel<<<ctx->ncluster(),
|
||||
16,
|
||||
ctx->xpu_stream>>>(
|
||||
const_cast<TX *>(x),
|
||||
y,
|
||||
encoder_seqs_lods.xpu,
|
||||
@@ -114,6 +116,7 @@ static int xpu3_wrapper(api::Context *ctx,
|
||||
en_batch,
|
||||
de_batch,
|
||||
hidden_dim);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -154,7 +154,7 @@ static int xpu2or3_wrapper(Context *ctx,
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
auto mtp_free_and_dispatch_block = xpu3::plugin::mtp_free_and_dispatch_block;
|
||||
mtp_free_and_dispatch_block<<<12, 64, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre = mtp_free_and_dispatch_block<<<12, 64, ctx->xpu_stream>>>(
|
||||
base_model_stop_flags,
|
||||
stop_flags,
|
||||
batch_drop,
|
||||
@@ -169,6 +169,7 @@ static int xpu2or3_wrapper(Context *ctx,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_draft_tokens);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -57,9 +57,10 @@ static int xpu3_wrapper(Context* ctx,
|
||||
T* output,
|
||||
int dim_embed,
|
||||
int elem_cnt) {
|
||||
xpu3::plugin::rebuildHiddenStatesKernel<T>
|
||||
int32_t ret_xre = xpu3::plugin::rebuildHiddenStatesKernel<T>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
input, position_map, output, dim_embed, elem_cnt);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -52,9 +52,10 @@ static int xpu3_wrapper(Context* ctx,
|
||||
T* output,
|
||||
int dim_embed,
|
||||
int elem_cnt) {
|
||||
xpu3::plugin::rebuildSelfHiddenStatesKernel<T>
|
||||
int32_t ret_xre = xpu3::plugin::rebuildSelfHiddenStatesKernel<T>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
input, src_map, output, dim_embed, elem_cnt);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+18
-16
@@ -63,22 +63,24 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int num_extra_tokens) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto recover_spec_decode_task = xpu3::plugin::recover_spec_decode_task;
|
||||
recover_spec_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
block_tables,
|
||||
is_block_step,
|
||||
reinterpret_cast<XPU_INT64 *>(draft_tokens),
|
||||
reinterpret_cast<const XPU_INT64 *>(step_draft_tokens),
|
||||
step_seq_lens_this_time,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
draft_tokens_len,
|
||||
num_extra_tokens);
|
||||
int32_t ret_xre =
|
||||
recover_spec_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
block_tables,
|
||||
is_block_step,
|
||||
reinterpret_cast<XPU_INT64 *>(draft_tokens),
|
||||
reinterpret_cast<const XPU_INT64 *>(step_draft_tokens),
|
||||
step_seq_lens_this_time,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
draft_tokens_len,
|
||||
num_extra_tokens);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+4
-2
@@ -42,8 +42,10 @@ static int xpu2or3_wrapper(Context* ctx,
|
||||
const int* seq_lens_decoder,
|
||||
const int max_bsz) {
|
||||
ctx_guard RAII_GUARD(ctx);
|
||||
xpu3::plugin::speculate_clear_accept_nums<<<1, 64, ctx->xpu_stream>>>(
|
||||
accept_num, seq_lens_decoder, max_bsz);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::speculate_clear_accept_nums<<<1, 64, ctx->xpu_stream>>>(
|
||||
accept_num, seq_lens_decoder, max_bsz);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
+27
-25
@@ -209,31 +209,33 @@ static int xpu3_wrapper(Context *ctx,
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto speculate_free_and_dispatch_block_kernel =
|
||||
xpu3::plugin::speculate_free_and_dispatch_block;
|
||||
speculate_free_and_dispatch_block_kernel<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(first_token_ids),
|
||||
accept_num,
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num,
|
||||
max_draft_tokens);
|
||||
int32_t ret_xre =
|
||||
speculate_free_and_dispatch_block_kernel<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(first_token_ids),
|
||||
accept_num,
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num,
|
||||
max_draft_tokens);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+24
-22
@@ -101,28 +101,30 @@ static int xpu3_wrapper(Context *ctx,
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto speculate_free_and_reschedule =
|
||||
xpu3::plugin::speculate_free_and_reschedule;
|
||||
speculate_free_and_reschedule<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num,
|
||||
max_draft_tokens);
|
||||
int32_t ret_xre =
|
||||
speculate_free_and_reschedule<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
encoder_block_lens,
|
||||
is_block_step,
|
||||
step_block_list,
|
||||
step_len,
|
||||
recover_block_list,
|
||||
recover_len,
|
||||
need_block_list,
|
||||
need_block_len,
|
||||
used_list_len,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
max_decoder_block_num,
|
||||
max_draft_tokens);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -96,18 +96,20 @@ static int xpu3_wrapper(Context* ctx,
|
||||
const int* seq_lens_encoder,
|
||||
const int real_bsz,
|
||||
const int vocab_size) {
|
||||
xpu3::plugin::speculate_get_logits<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
draft_logits,
|
||||
next_token_num,
|
||||
batch_token_num,
|
||||
cu_next_token_offset,
|
||||
cu_batch_token_offset,
|
||||
logits,
|
||||
first_token_logits,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
real_bsz,
|
||||
vocab_size);
|
||||
int32_t ret_xre = xpu3::plugin::
|
||||
speculate_get_logits<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
draft_logits,
|
||||
next_token_num,
|
||||
batch_token_num,
|
||||
cu_next_token_offset,
|
||||
cu_batch_token_offset,
|
||||
logits,
|
||||
first_token_logits,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
real_bsz,
|
||||
vocab_size);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+11
-9
@@ -61,15 +61,17 @@ static int xpu2or3_wrapper(Context* ctx,
|
||||
const int bsz,
|
||||
const int max_seq_len) {
|
||||
ctx_guard RAII_GUARD(ctx);
|
||||
xpu3::plugin::speculate_get_output_padding_offset<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
output_cum_offsets_tmp,
|
||||
seq_lens_output,
|
||||
bsz,
|
||||
max_seq_len);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::speculate_get_output_padding_offset<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
output_cum_offsets_tmp,
|
||||
seq_lens_output,
|
||||
bsz,
|
||||
max_seq_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+4
-2
@@ -112,7 +112,7 @@ static int xpu3_wrapper_remove_padding(Context* ctx,
|
||||
int bsz,
|
||||
int token_num_data) {
|
||||
using XPU_T = typename XPUIndexType<T>::type;
|
||||
xpu3::plugin::speculate_remove_padding<XPU_T>
|
||||
int32_t ret_xre = xpu3::plugin::speculate_remove_padding<XPU_T>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
static_cast<XPU_T*>(static_cast<void*>(output_data)),
|
||||
static_cast<const XPU_T*>(static_cast<const void*>(input_data)),
|
||||
@@ -124,6 +124,7 @@ static int xpu3_wrapper_remove_padding(Context* ctx,
|
||||
max_draft_tokens,
|
||||
bsz,
|
||||
token_num_data);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
|
||||
return api::SUCCESS;
|
||||
}
|
||||
@@ -137,7 +138,7 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx,
|
||||
const int* seq_lens,
|
||||
const int max_seq_len,
|
||||
int bsz) {
|
||||
xpu3::plugin::
|
||||
int32_t ret_xre = xpu3::plugin::
|
||||
speculate_get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
batch_id_per_token,
|
||||
cum_offsets_out,
|
||||
@@ -147,6 +148,7 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx,
|
||||
seq_lens,
|
||||
max_seq_len,
|
||||
bsz);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -58,13 +58,14 @@ static int xpu2or3_wrapper(Context* ctx,
|
||||
const int* seq_lens_decoder,
|
||||
const int real_bsz) {
|
||||
ctx_guard RAII_GUARD(ctx);
|
||||
xpu3::plugin::
|
||||
int32_t ret_xre = xpu3::plugin::
|
||||
speculate_get_seq_lens_output<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
seq_lens_output,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
real_bsz);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
+2
-1
@@ -78,7 +78,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
int dim_embed,
|
||||
int elem_nums,
|
||||
T* out) {
|
||||
xpu3::plugin::RebuildAppendPaddingKernel<T>
|
||||
int32_t ret_xre = xpu3::plugin::RebuildAppendPaddingKernel<T>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(full_hidden_states,
|
||||
cum_offsets,
|
||||
seq_len_encoder,
|
||||
@@ -88,6 +88,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
dim_embed,
|
||||
elem_nums,
|
||||
out);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -136,29 +136,31 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int pre_id_length) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto recover_block_kernel = xpu3::plugin::speculate_recover_block;
|
||||
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
recover_block_list, // [bsz]
|
||||
recover_len,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
ori_seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx),
|
||||
encoder_block_lens,
|
||||
used_list_len,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
length,
|
||||
pre_id_length);
|
||||
int32_t ret_xre =
|
||||
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
recover_block_list, // [bsz]
|
||||
recover_len,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
ori_seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx),
|
||||
encoder_block_lens,
|
||||
used_list_len,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
length,
|
||||
pre_id_length);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+26
-24
@@ -161,30 +161,32 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const bool prefill_one_step_stop) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
using XPU_TI = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::speculate_schedule_cache<<<1, 64, ctx->xpu_stream>>>(
|
||||
(const XPU_TI *)draft_tokens,
|
||||
block_tables,
|
||||
stop_flags,
|
||||
(const XPU_TI *)prompt_lens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
reinterpret_cast<XPU_TI *>(step_draft_tokens),
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
reinterpret_cast<XPU_TI *>(accept_tokens),
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
(const XPU_TI *)stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_tokens_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
prefill_one_step_stop);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::speculate_schedule_cache<<<1, 64, ctx->xpu_stream>>>(
|
||||
(const XPU_TI *)draft_tokens,
|
||||
block_tables,
|
||||
stop_flags,
|
||||
(const XPU_TI *)prompt_lens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
reinterpret_cast<XPU_TI *>(step_draft_tokens),
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
reinterpret_cast<XPU_TI *>(accept_tokens),
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
(const XPU_TI *)stop_nums,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_tokens_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq,
|
||||
prefill_one_step_stop);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+20
-18
@@ -137,24 +137,26 @@ static int xpu2or3_wrapper(Context* ctx,
|
||||
const int stop_seqs_max_len,
|
||||
const int pre_ids_len) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::speculate_set_stop_value_multi_seqs<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
reinterpret_cast<XPU_INT64*>(accept_tokens),
|
||||
accept_nums,
|
||||
reinterpret_cast<const XPU_INT64*>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
reinterpret_cast<const XPU_INT64*>(stop_seqs),
|
||||
stop_seqs_len,
|
||||
seq_lens,
|
||||
reinterpret_cast<const XPU_INT64*>(end_ids),
|
||||
reinterpret_cast<const XPU_INT64*>(min_tokens),
|
||||
bs,
|
||||
accept_tokens_len,
|
||||
stop_seqs_bs,
|
||||
stop_seqs_max_len,
|
||||
pre_ids_len);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::speculate_set_stop_value_multi_seqs<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
reinterpret_cast<XPU_INT64*>(accept_tokens),
|
||||
accept_nums,
|
||||
reinterpret_cast<const XPU_INT64*>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
reinterpret_cast<const XPU_INT64*>(stop_seqs),
|
||||
stop_seqs_len,
|
||||
seq_lens,
|
||||
reinterpret_cast<const XPU_INT64*>(end_ids),
|
||||
reinterpret_cast<const XPU_INT64*>(min_tokens),
|
||||
bs,
|
||||
accept_tokens_len,
|
||||
stop_seqs_bs,
|
||||
stop_seqs_max_len,
|
||||
pre_ids_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+15
-13
@@ -82,19 +82,21 @@ static int xpu2or3_wrapper(Context *ctx,
|
||||
ctx_guard RAII_GUARD(ctx);
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
|
||||
xpu3::plugin::speculate_set_value_by_flag_and_id<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
|
||||
reinterpret_cast<const XPU_INT64 *>(accept_tokens),
|
||||
accept_num,
|
||||
stop_flags,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx),
|
||||
bs,
|
||||
length,
|
||||
max_draft_tokens);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::speculate_set_value_by_flag_and_id<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
|
||||
reinterpret_cast<const XPU_INT64 *>(accept_tokens),
|
||||
accept_num,
|
||||
stop_flags,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx),
|
||||
bs,
|
||||
length,
|
||||
max_draft_tokens);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
+22
-13
@@ -267,17 +267,21 @@ static int xpu3_wrapper(Context* ctx,
|
||||
int ret = api::constant<int>(ctx, repeat_times, token_num * length, 0);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
|
||||
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64*>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64*>(cur_len),
|
||||
repeat_times,
|
||||
output_padding_offset,
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
token_num,
|
||||
max_seq_len);
|
||||
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre =
|
||||
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64*>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64*>(cur_len),
|
||||
repeat_times,
|
||||
output_padding_offset,
|
||||
bs,
|
||||
length,
|
||||
length_id,
|
||||
token_num,
|
||||
max_seq_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
ret_xre = min_length_logits_process_kernel<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
logits,
|
||||
reinterpret_cast<const XPU_INT64*>(cur_len),
|
||||
reinterpret_cast<const XPU_INT64*>(min_len),
|
||||
@@ -290,7 +294,10 @@ static int xpu3_wrapper(Context* ctx,
|
||||
end_length,
|
||||
token_num,
|
||||
max_seq_len);
|
||||
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
ret_xre = update_value_by_repeat_times_kernel<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
repeat_times,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
@@ -302,9 +309,10 @@ static int xpu3_wrapper(Context* ctx,
|
||||
length,
|
||||
token_num,
|
||||
max_seq_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
|
||||
if (bad_words && length_bad_words > 0) {
|
||||
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
ret_xre = ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
logits,
|
||||
reinterpret_cast<const XPU_INT64*>(bad_words),
|
||||
output_padding_offset,
|
||||
@@ -313,6 +321,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
length_bad_words,
|
||||
token_num,
|
||||
max_seq_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int max_draft_tokens) {
|
||||
constexpr int BlockSize = 512;
|
||||
using XPU_TI = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::speculate_update<BlockSize>
|
||||
int32_t ret_xre = xpu3::plugin::speculate_update<BlockSize>
|
||||
<<<1, 64, ctx->xpu_stream>>>(seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
not_need_stop,
|
||||
@@ -150,6 +150,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_draft_tokens);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int max_draft_tokens) {
|
||||
constexpr int BlockSize = 512;
|
||||
using XPU_TI = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::speculate_update_v3<BlockSize>
|
||||
int32_t ret_xre = xpu3::plugin::speculate_update_v3<BlockSize>
|
||||
<<<1, 64, ctx->xpu_stream>>>(seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
not_need_stop,
|
||||
@@ -154,6 +154,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_draft_tokens);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -382,7 +382,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const bool accept_all_drafts,
|
||||
const bool use_target_sampling) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
|
||||
int32_t ret_xre = xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64 *>(sampled_token_ids),
|
||||
reinterpret_cast<XPU_INT64 *>(accept_tokens),
|
||||
@@ -413,6 +413,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
benchmark_mode,
|
||||
accept_all_drafts,
|
||||
use_target_sampling);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||
|
||||
@@ -110,7 +110,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
int candidate_len,
|
||||
int max_seq_len) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
xpu3::plugin::top_p_candidates<T, MaxLength, TopPBeamTopK>
|
||||
int32_t ret_xre = xpu3::plugin::top_p_candidates<T, MaxLength, TopPBeamTopK>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
src,
|
||||
top_ps,
|
||||
@@ -122,6 +122,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
token_num,
|
||||
candidate_len,
|
||||
max_seq_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
int real_bsz,
|
||||
int max_model_len,
|
||||
int decode_states_len) {
|
||||
xpu3::plugin::
|
||||
int32_t ret_xre = xpu3::plugin::
|
||||
update_attn_mask_offsets<<<ctx->ncluster(), 1, ctx->xpu_stream>>>(
|
||||
attn_mask_offsets,
|
||||
seq_lens_this_time,
|
||||
@@ -131,6 +131,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
real_bsz,
|
||||
max_model_len,
|
||||
decode_states_len);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -100,16 +100,18 @@ static int xpu3_wrapper(Context *ctx,
|
||||
using XPU_TID = typename XPUIndexType<T>::type;
|
||||
auto set_stop_value_multi_ends =
|
||||
xpu3::plugin::set_stop_value_multi_ends<XPU_TID>;
|
||||
set_stop_value_multi_ends<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
reinterpret_cast<XPU_TID *>(topk_ids),
|
||||
reinterpret_cast<XPU_TID *>(next_tokens),
|
||||
reinterpret_cast<const XPU_TID *>(end_ids),
|
||||
seq_lens,
|
||||
bs,
|
||||
end_length,
|
||||
beam_search,
|
||||
prefill_one_step_stop);
|
||||
int32_t ret_xre =
|
||||
set_stop_value_multi_ends<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
reinterpret_cast<XPU_TID *>(topk_ids),
|
||||
reinterpret_cast<XPU_TID *>(next_tokens),
|
||||
reinterpret_cast<const XPU_TID *>(end_ids),
|
||||
seq_lens,
|
||||
bs,
|
||||
end_length,
|
||||
beam_search,
|
||||
prefill_one_step_stop);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -99,7 +99,9 @@ static int xpu3_wrapper(Context *ctx,
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto set_value_by_flags_and_idx_kernel =
|
||||
xpu3::plugin::set_value_by_flags_and_idx;
|
||||
set_value_by_flags_and_idx_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre = set_value_by_flags_and_idx_kernel<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
|
||||
reinterpret_cast<const XPU_INT64 *>(input_ids),
|
||||
@@ -109,6 +111,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
bs,
|
||||
length,
|
||||
length_input_ids);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -203,14 +203,18 @@ static int xpu3_wrapper(Context *ctx,
|
||||
int ret = api::constant<int>(ctx, repeat_times, bs * length, 0);
|
||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||
|
||||
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(cur_len),
|
||||
repeat_times,
|
||||
bs,
|
||||
length,
|
||||
length_id);
|
||||
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre =
|
||||
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(cur_len),
|
||||
repeat_times,
|
||||
bs,
|
||||
length,
|
||||
length_id);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
ret_xre = min_length_logits_process_kernel<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(
|
||||
logits,
|
||||
reinterpret_cast<const XPU_INT64 *>(cur_len),
|
||||
reinterpret_cast<const XPU_INT64 *>(min_len),
|
||||
@@ -219,23 +223,28 @@ static int xpu3_wrapper(Context *ctx,
|
||||
length,
|
||||
length_id,
|
||||
end_length);
|
||||
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
repeat_times,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
logits,
|
||||
bs,
|
||||
length);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
ret_xre =
|
||||
update_value_by_repeat_times_kernel<<<ctx->ncluster(),
|
||||
64,
|
||||
ctx->xpu_stream>>>(repeat_times,
|
||||
penalty_scores,
|
||||
frequency_scores,
|
||||
presence_scores,
|
||||
temperatures,
|
||||
logits,
|
||||
bs,
|
||||
length);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
|
||||
if (bad_words && length_bad_words > 0) {
|
||||
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
ret_xre = ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
logits,
|
||||
reinterpret_cast<const XPU_INT64 *>(bad_words),
|
||||
bs,
|
||||
length,
|
||||
length_bad_words);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
}
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
@@ -193,7 +193,9 @@ int xpu3_wrapper_input_scale(api::Context *ctx,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
auto func = xpu3::plugin::quant2d_per_channel_cluster<TX, TSCALE, TY>;
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, scale, y, m, n);
|
||||
int32_t ret_xre =
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, scale, y, m, n);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
@@ -230,7 +232,9 @@ int xpu3_wrapper_output_scale(api::Context *ctx,
|
||||
func = xpu3::plugin::quant2d_per_channel_cached<TX, TSCALE, TY, 32>;
|
||||
}
|
||||
}
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
|
||||
int32_t ret_xre =
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
template <typename TX,
|
||||
@@ -245,7 +249,9 @@ int xpu3_wrapper_output_scale(api::Context *ctx,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
auto func = xpu3::plugin::quant2d_per_channel_bign<TX, TSCALE, TY>;
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
|
||||
int32_t ret_xre =
|
||||
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -129,28 +129,30 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int pre_id_length) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto recover_block_kernel = xpu3::plugin::recover_block;
|
||||
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
recover_block_list, // [bsz]
|
||||
recover_len,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx),
|
||||
encoder_block_lens,
|
||||
used_list_len,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
length,
|
||||
pre_id_length);
|
||||
int32_t ret_xre =
|
||||
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
recover_block_list, // [bsz]
|
||||
recover_len,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
ori_seq_lens_encoder,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
block_tables,
|
||||
free_list,
|
||||
free_list_len,
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(pre_ids),
|
||||
reinterpret_cast<const XPU_INT64 *>(step_idx),
|
||||
encoder_block_lens,
|
||||
used_list_len,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
length,
|
||||
pre_id_length);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -52,17 +52,19 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int block_size) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto recover_decode_task = xpu3::plugin::recover_decode_task;
|
||||
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
block_tables,
|
||||
is_block_step,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
int32_t ret_xre =
|
||||
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
block_tables,
|
||||
is_block_step,
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
int64_t image_token_num,
|
||||
int64_t hidden_size,
|
||||
bool is_scatter) {
|
||||
xpu3::plugin::text_image_gather_scatter<T>
|
||||
int32_t ret_xre = xpu3::plugin::text_image_gather_scatter<T>
|
||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(input,
|
||||
text_input,
|
||||
image_input,
|
||||
@@ -128,6 +128,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
image_token_num,
|
||||
hidden_size,
|
||||
is_scatter);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -55,8 +55,10 @@ static int xpu3_wrapper(Context* ctx,
|
||||
int* text_index, // y1
|
||||
int* image_index, // y2
|
||||
const int64_t token_num) {
|
||||
xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
|
||||
token_type_ids, text_index, image_index, token_num);
|
||||
int32_t ret_xre =
|
||||
xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
|
||||
token_type_ids, text_index, image_index, token_num);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int input_ids_stride) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto update_inputs = xpu3::plugin::update_inputs;
|
||||
update_inputs<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre = update_inputs<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
@@ -107,6 +107,7 @@ static int xpu3_wrapper(Context *ctx,
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
}
|
||||
}
|
||||
// kernel 内要做 reduce,只能用 1 个 cluster
|
||||
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
|
||||
int32_t ret_xre = update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
|
||||
not_need_stop,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
@@ -96,6 +96,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
prefill_one_step_stop);
|
||||
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user