[XPU] Add return value checks for all XPU kernel launches (#6666)

* [XPU] Add return value checks for all XPU kernel launches

- Add the -fxpu-launch-return compiler flag in CMakeLists.txt so that
  kernel launch expressions yield a return value that can be checked
- Add KERNEL_ASSERT_SUCCESS(ctx, ret_xre) checks after every XPU
  kernel launch across 45 wrapper files (55 launch sites total)
- Cover both the main wrapper/ and mtp_wrapper/ directories
- Handle multiple kernel launches in the same function scope by
  reusing the ret_xre variable

* [XPU] code style fix
This commit is contained in:
mayang002
2026-03-10 10:45:18 +08:00
committed by GitHub
parent 28f7727a3d
commit ecc5032176
46 changed files with 498 additions and 401 deletions
+1 -1
View File
@@ -360,7 +360,7 @@ foreach(xpu_wrapper IN LISTS xpu_wrappers)
-I${CMAKE_CURRENT_SOURCE_DIR}/src/wrapper -D_GNU_SOURCE
-D__STDC_LIMIT_MACROS -DNDEBUG ${TOOLCHAIN_ARGS} --target=${TARGET_ARCH}
-fPIC -Wunused-variable -Werror -Wreorder -fvisibility=hidden
--xpu-host-only ${HOST_XPU_FLAGS}
-fxpu-launch-return --xpu-host-only ${HOST_XPU_FLAGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS wrapper_build/${wrapper_name}.wrapper.d
COMMENT wrapper_build/${wrapper_name}.wrapper.o
@@ -98,16 +98,18 @@ static int xpu3_wrapper(api::Context *ctx,
auto eb_adjust_batch_kernel =
xpu3::plugin::eb_adjust_batch<XPU_INDEX_TYPE_TX, XPU_INDEX_TYPE_TY>;
// NOTE: Don't change 16 to 64, because the kernel uses gsm
eb_adjust_batch_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INDEX_TYPE_TX *>(const_cast<TX *>(x)),
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y),
encoder_seqs_lods.xpu,
decoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
int32_t ret_xre =
eb_adjust_batch_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INDEX_TYPE_TX *>(const_cast<TX *>(x)),
reinterpret_cast<XPU_INDEX_TYPE_TY *>(y),
encoder_seqs_lods.xpu,
decoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -76,15 +76,17 @@ static int xpu3_wrapper(api::Context *ctx,
int64_t hidden_dim) {
auto eb_gather_next_token_kernel = xpu3::plugin::eb_gather_next_token<TX, TY>;
// NOTE: Don't change 16 to 64, because the kernel uses gsm
eb_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
const_cast<TX *>(x),
y,
encoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
int32_t ret_xre =
eb_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
const_cast<TX *>(x),
y,
encoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -194,27 +194,29 @@ static int xpu3_wrapper(Context *ctx,
const int max_decoder_block_num) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto free_and_dispatch_block_kernel = xpu3::plugin::free_and_dispatch_block;
free_and_dispatch_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num);
int32_t ret_xre =
free_and_dispatch_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -118,21 +118,25 @@ static int xpu3_wrapper(Context *ctx,
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto get_padding_offset = xpu3::plugin::get_padding_offset;
auto remove_padding = xpu3::plugin::remove_padding;
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
cum_offsets,
seq_lens,
max_seq_len,
bs);
remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
int32_t ret_xre =
get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
padding_offset,
cum_offsets_out,
cu_seqlens_q,
cu_seqlens_k,
cum_offsets,
seq_lens,
max_seq_len,
bs);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
ret_xre = remove_padding<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64 *>(x_remove_padding),
reinterpret_cast<const XPU_INT64 *>(input_ids),
seq_lens,
cum_offsets_out,
max_seq_len,
bs);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -93,16 +93,18 @@ static int xpu3_wrapper(Context* ctx,
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto limit_thinking_content_length_kernel_v1 =
xpu3::plugin::limit_thinking_content_length_kernel_v1;
limit_thinking_content_length_kernel_v1<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64*>(next_tokens),
max_think_lens,
reinterpret_cast<const XPU_INT64*>(step_idx),
reinterpret_cast<const XPU_INT64*>(eos_token_ids),
limit_think_status,
stop_flags,
think_end_id,
bs,
eos_token_id_len);
int32_t ret_xre =
limit_thinking_content_length_kernel_v1<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64*>(next_tokens),
max_think_lens,
reinterpret_cast<const XPU_INT64*>(step_idx),
reinterpret_cast<const XPU_INT64*>(eos_token_ids),
limit_think_status,
stop_flags,
think_end_id,
bs,
eos_token_id_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -98,15 +98,17 @@ static int xpu3_wrapper(Context* ctx,
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto limit_thinking_content_length_kernel_v2 =
xpu3::plugin::limit_thinking_content_length_kernel_v2;
limit_thinking_content_length_kernel_v2<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64*>(next_tokens),
max_think_lens,
reinterpret_cast<const XPU_INT64*>(step_idx),
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
bs);
int32_t ret_xre =
limit_thinking_content_length_kernel_v2<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64*>(next_tokens),
max_think_lens,
reinterpret_cast<const XPU_INT64*>(step_idx),
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
bs);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -108,7 +108,7 @@ static int xpu3_wrapper(Context* ctx,
const int bsz,
const int actual_draft_token_num,
const int input_token_num) {
xpu3::plugin::ComputeOrderKernel<<<1, 1, ctx->xpu_stream>>>(
int32_t ret_xre = xpu3::plugin::ComputeOrderKernel<<<1, 1, ctx->xpu_stream>>>(
seq_lens_this_time,
seq_lens_encoder,
base_model_seq_lens_this_time,
@@ -119,6 +119,7 @@ static int xpu3_wrapper(Context* ctx,
bsz,
actual_draft_token_num,
input_token_num);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -76,13 +76,15 @@ static int xpu3_wrapper(Context* ctx,
int* output_token_num,
int bsz) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::ComputeSelfOrderKernel<<<1, 1, ctx->xpu_stream>>>(
last_seq_lens_this_time,
seq_lens_this_time,
reinterpret_cast<const XPU_INT64*>(step_idx),
src_map,
output_token_num,
bsz);
int32_t ret_xre =
xpu3::plugin::ComputeSelfOrderKernel<<<1, 1, ctx->xpu_stream>>>(
last_seq_lens_this_time,
seq_lens_this_time,
reinterpret_cast<const XPU_INT64*>(step_idx),
src_map,
output_token_num,
bsz);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -82,7 +82,7 @@ static int xpu3_wrapper(Context* ctx,
const bool* base_model_stop_flags,
int bsz,
int base_model_draft_token_len) {
xpu3::plugin::
int32_t ret_xre = xpu3::plugin::
draft_model_postprocess<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const xpu3::int64_t*>(base_model_draft_tokens),
base_model_seq_lens_this_time,
@@ -90,6 +90,7 @@ static int xpu3_wrapper(Context* ctx,
base_model_stop_flags,
bsz,
base_model_draft_token_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -240,37 +240,39 @@ static int xpu3_wrapper(api::Context* ctx,
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
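// NOTE: Don't change 16 to 64, because the kernel uses gsm
// NOTE(review): the launch below already uses 64 - confirm whether this note
// still applies to draft_model_preprocess.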
xpu3::plugin::draft_model_preprocess<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64*>(draft_tokens),
reinterpret_cast<XPU_INT64*>(input_ids),
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx),
not_need_stop,
is_block_step,
batch_drop,
reinterpret_cast<XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(accept_tokens),
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
reinterpret_cast<const XPU_INT64*>(base_model_step_idx),
base_model_stop_flags,
base_model_is_block_step,
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1);
int32_t ret_xre =
xpu3::plugin::draft_model_preprocess<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64*>(draft_tokens),
reinterpret_cast<XPU_INT64*>(input_ids),
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx),
not_need_stop,
is_block_step,
batch_drop,
reinterpret_cast<XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(accept_tokens),
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
reinterpret_cast<const XPU_INT64*>(base_model_step_idx),
base_model_stop_flags,
base_model_is_block_step,
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -178,28 +178,30 @@ static int xpu2or3_wrapper(Context* ctx,
const bool prefill_one_step_stop) {
ctx_guard RAII_GUARD(ctx);
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::draft_model_update<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64*>(inter_next_tokens),
reinterpret_cast<XPU_INT64*>(draft_tokens),
reinterpret_cast<XPU_INT64*>(pre_ids),
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx),
output_cum_offsets,
stop_flags,
not_need_stop,
reinterpret_cast<const XPU_INT64*>(max_dec_len),
reinterpret_cast<const XPU_INT64*>(end_ids),
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len,
max_seq_len,
substep,
prefill_one_step_stop);
int32_t ret_xre =
xpu3::plugin::draft_model_update<<<1, 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64*>(inter_next_tokens),
reinterpret_cast<XPU_INT64*>(draft_tokens),
reinterpret_cast<XPU_INT64*>(pre_ids),
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx),
output_cum_offsets,
stop_flags,
not_need_stop,
reinterpret_cast<const XPU_INT64*>(max_dec_len),
reinterpret_cast<const XPU_INT64*>(end_ids),
reinterpret_cast<XPU_INT64*>(base_model_draft_tokens),
bsz,
max_draft_token,
pre_id_length,
max_base_model_draft_token,
end_ids_len,
max_seq_len,
substep,
prefill_one_step_stop);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -104,16 +104,18 @@ static int xpu3_wrapper(api::Context *ctx,
auto eb_mtp_gather_next_token_kernel =
xpu3::plugin::eb_mtp_gather_next_token<TX, TY>;
// NOTE: Don't change 16 to 64, because the kernel uses gsm
eb_mtp_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
const_cast<TX *>(x),
y,
encoder_seqs_lods.xpu,
decoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
int32_t ret_xre =
eb_mtp_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
const_cast<TX *>(x),
y,
encoder_seqs_lods.xpu,
decoder_seqs_lods.xpu,
encoder_batch_map.xpu,
decoder_batch_map.xpu,
en_batch,
de_batch,
hidden_dim);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -104,7 +104,9 @@ static int xpu3_wrapper(api::Context *ctx,
auto eb_recover_batch_sequence_kernel =
xpu3::plugin::eb_recover_batch_sequence<TX, TY>;
// NOTE: Don't change 16 to 64, because the kernel uses gsm
eb_recover_batch_sequence_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
int32_t ret_xre = eb_recover_batch_sequence_kernel<<<ctx->ncluster(),
16,
ctx->xpu_stream>>>(
const_cast<TX *>(x),
y,
encoder_seqs_lods.xpu,
@@ -114,6 +116,7 @@ static int xpu3_wrapper(api::Context *ctx,
en_batch,
de_batch,
hidden_dim);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -154,7 +154,7 @@ static int xpu2or3_wrapper(Context *ctx,
WRAPPER_UNIMPLEMENTED(ctx);
}
auto mtp_free_and_dispatch_block = xpu3::plugin::mtp_free_and_dispatch_block;
mtp_free_and_dispatch_block<<<12, 64, ctx->xpu_stream>>>(
int32_t ret_xre = mtp_free_and_dispatch_block<<<12, 64, ctx->xpu_stream>>>(
base_model_stop_flags,
stop_flags,
batch_drop,
@@ -169,6 +169,7 @@ static int xpu2or3_wrapper(Context *ctx,
block_size,
block_num_per_seq,
max_draft_tokens);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -57,9 +57,10 @@ static int xpu3_wrapper(Context* ctx,
T* output,
int dim_embed,
int elem_cnt) {
xpu3::plugin::rebuildHiddenStatesKernel<T>
int32_t ret_xre = xpu3::plugin::rebuildHiddenStatesKernel<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
input, position_map, output, dim_embed, elem_cnt);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -52,9 +52,10 @@ static int xpu3_wrapper(Context* ctx,
T* output,
int dim_embed,
int elem_cnt) {
xpu3::plugin::rebuildSelfHiddenStatesKernel<T>
int32_t ret_xre = xpu3::plugin::rebuildSelfHiddenStatesKernel<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
input, src_map, output, dim_embed, elem_cnt);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -63,22 +63,24 @@ static int xpu3_wrapper(Context *ctx,
const int num_extra_tokens) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_spec_decode_task = xpu3::plugin::recover_spec_decode_task;
recover_spec_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
reinterpret_cast<XPU_INT64 *>(draft_tokens),
reinterpret_cast<const XPU_INT64 *>(step_draft_tokens),
step_seq_lens_this_time,
bsz,
block_num_per_seq,
block_size,
draft_tokens_len,
num_extra_tokens);
int32_t ret_xre =
recover_spec_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
reinterpret_cast<XPU_INT64 *>(draft_tokens),
reinterpret_cast<const XPU_INT64 *>(step_draft_tokens),
step_seq_lens_this_time,
bsz,
block_num_per_seq,
block_size,
draft_tokens_len,
num_extra_tokens);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -42,8 +42,10 @@ static int xpu2or3_wrapper(Context* ctx,
const int* seq_lens_decoder,
const int max_bsz) {
ctx_guard RAII_GUARD(ctx);
xpu3::plugin::speculate_clear_accept_nums<<<1, 64, ctx->xpu_stream>>>(
accept_num, seq_lens_decoder, max_bsz);
int32_t ret_xre =
xpu3::plugin::speculate_clear_accept_nums<<<1, 64, ctx->xpu_stream>>>(
accept_num, seq_lens_decoder, max_bsz);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -209,31 +209,33 @@ static int xpu3_wrapper(Context *ctx,
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto speculate_free_and_dispatch_block_kernel =
xpu3::plugin::speculate_free_and_dispatch_block;
speculate_free_and_dispatch_block_kernel<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
accept_num,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
int32_t ret_xre =
speculate_free_and_dispatch_block_kernel<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
accept_num,
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -101,28 +101,30 @@ static int xpu3_wrapper(Context *ctx,
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto speculate_free_and_reschedule =
xpu3::plugin::speculate_free_and_reschedule;
speculate_free_and_reschedule<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
int32_t ret_xre =
speculate_free_and_reschedule<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_len,
recover_block_list,
recover_len,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(first_token_ids),
bsz,
block_size,
block_num_per_seq,
max_decoder_block_num,
max_draft_tokens);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -96,18 +96,20 @@ static int xpu3_wrapper(Context* ctx,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size) {
xpu3::plugin::speculate_get_logits<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
draft_logits,
next_token_num,
batch_token_num,
cu_next_token_offset,
cu_batch_token_offset,
logits,
first_token_logits,
seq_lens_this_time,
seq_lens_encoder,
real_bsz,
vocab_size);
int32_t ret_xre = xpu3::plugin::
speculate_get_logits<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
draft_logits,
next_token_num,
batch_token_num,
cu_next_token_offset,
cu_batch_token_offset,
logits,
first_token_logits,
seq_lens_this_time,
seq_lens_encoder,
real_bsz,
vocab_size);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -61,15 +61,17 @@ static int xpu2or3_wrapper(Context* ctx,
const int bsz,
const int max_seq_len) {
ctx_guard RAII_GUARD(ctx);
xpu3::plugin::speculate_get_output_padding_offset<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
output_padding_offset,
output_cum_offsets,
output_cum_offsets_tmp,
seq_lens_output,
bsz,
max_seq_len);
int32_t ret_xre =
xpu3::plugin::speculate_get_output_padding_offset<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
output_padding_offset,
output_cum_offsets,
output_cum_offsets_tmp,
seq_lens_output,
bsz,
max_seq_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -112,7 +112,7 @@ static int xpu3_wrapper_remove_padding(Context* ctx,
int bsz,
int token_num_data) {
using XPU_T = typename XPUIndexType<T>::type;
xpu3::plugin::speculate_remove_padding<XPU_T>
int32_t ret_xre = xpu3::plugin::speculate_remove_padding<XPU_T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
static_cast<XPU_T*>(static_cast<void*>(output_data)),
static_cast<const XPU_T*>(static_cast<const void*>(input_data)),
@@ -124,6 +124,7 @@ static int xpu3_wrapper_remove_padding(Context* ctx,
max_draft_tokens,
bsz,
token_num_data);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -137,7 +138,7 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx,
const int* seq_lens,
const int max_seq_len,
int bsz) {
xpu3::plugin::
int32_t ret_xre = xpu3::plugin::
speculate_get_padding_offset<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
batch_id_per_token,
cum_offsets_out,
@@ -147,6 +148,7 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx,
seq_lens,
max_seq_len,
bsz);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -58,13 +58,14 @@ static int xpu2or3_wrapper(Context* ctx,
const int* seq_lens_decoder,
const int real_bsz) {
ctx_guard RAII_GUARD(ctx);
xpu3::plugin::
int32_t ret_xre = xpu3::plugin::
speculate_get_seq_lens_output<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
seq_lens_output,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
real_bsz);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -78,7 +78,7 @@ static int xpu3_wrapper(Context* ctx,
int dim_embed,
int elem_nums,
T* out) {
xpu3::plugin::RebuildAppendPaddingKernel<T>
int32_t ret_xre = xpu3::plugin::RebuildAppendPaddingKernel<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(full_hidden_states,
cum_offsets,
seq_len_encoder,
@@ -88,6 +88,7 @@ static int xpu3_wrapper(Context* ctx,
dim_embed,
elem_nums,
out);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -136,29 +136,31 @@ static int xpu3_wrapper(Context *ctx,
const int pre_id_length) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_block_kernel = xpu3::plugin::speculate_recover_block;
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
ori_seq_lens_decoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(step_idx),
encoder_block_lens,
used_list_len,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
bsz,
block_num_per_seq,
length,
pre_id_length);
int32_t ret_xre =
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
ori_seq_lens_decoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(step_idx),
encoder_block_lens,
used_list_len,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
bsz,
block_num_per_seq,
length,
pre_id_length);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -161,30 +161,32 @@ static int xpu3_wrapper(Context *ctx,
const bool prefill_one_step_stop) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
using XPU_TI = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_schedule_cache<<<1, 64, ctx->xpu_stream>>>(
(const XPU_TI *)draft_tokens,
block_tables,
stop_flags,
(const XPU_TI *)prompt_lens,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
reinterpret_cast<XPU_TI *>(step_draft_tokens),
step_seq_lens_this_time,
accept_num,
reinterpret_cast<XPU_TI *>(accept_tokens),
is_block_step,
not_need_stop,
(const XPU_TI *)stop_nums,
real_bsz,
max_bsz,
max_next_step_tokens,
draft_tokens_len,
accept_tokens_len,
block_size,
block_num_per_seq,
prefill_one_step_stop);
int32_t ret_xre =
xpu3::plugin::speculate_schedule_cache<<<1, 64, ctx->xpu_stream>>>(
(const XPU_TI *)draft_tokens,
block_tables,
stop_flags,
(const XPU_TI *)prompt_lens,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
reinterpret_cast<XPU_TI *>(step_draft_tokens),
step_seq_lens_this_time,
accept_num,
reinterpret_cast<XPU_TI *>(accept_tokens),
is_block_step,
not_need_stop,
(const XPU_TI *)stop_nums,
real_bsz,
max_bsz,
max_next_step_tokens,
draft_tokens_len,
accept_tokens_len,
block_size,
block_num_per_seq,
prefill_one_step_stop);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -137,24 +137,26 @@ static int xpu2or3_wrapper(Context* ctx,
const int stop_seqs_max_len,
const int pre_ids_len) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_set_stop_value_multi_seqs<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
stop_flags,
reinterpret_cast<XPU_INT64*>(accept_tokens),
accept_nums,
reinterpret_cast<const XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(step_idx),
reinterpret_cast<const XPU_INT64*>(stop_seqs),
stop_seqs_len,
seq_lens,
reinterpret_cast<const XPU_INT64*>(end_ids),
reinterpret_cast<const XPU_INT64*>(min_tokens),
bs,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
int32_t ret_xre =
xpu3::plugin::speculate_set_stop_value_multi_seqs<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
stop_flags,
reinterpret_cast<XPU_INT64*>(accept_tokens),
accept_nums,
reinterpret_cast<const XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(step_idx),
reinterpret_cast<const XPU_INT64*>(stop_seqs),
stop_seqs_len,
seq_lens,
reinterpret_cast<const XPU_INT64*>(end_ids),
reinterpret_cast<const XPU_INT64*>(min_tokens),
bs,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
pre_ids_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -82,19 +82,21 @@ static int xpu2or3_wrapper(Context *ctx,
ctx_guard RAII_GUARD(ctx);
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_set_value_by_flag_and_id<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
reinterpret_cast<const XPU_INT64 *>(accept_tokens),
accept_num,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<const XPU_INT64 *>(step_idx),
bs,
length,
max_draft_tokens);
int32_t ret_xre =
xpu3::plugin::speculate_set_value_by_flag_and_id<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
reinterpret_cast<const XPU_INT64 *>(accept_tokens),
accept_num,
stop_flags,
seq_lens_encoder,
seq_lens_decoder,
reinterpret_cast<const XPU_INT64 *>(step_idx),
bs,
length,
max_draft_tokens);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -267,17 +267,21 @@ static int xpu3_wrapper(Context* ctx,
int ret = api::constant<int>(ctx, repeat_times, token_num * length, 0);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(cur_len),
repeat_times,
output_padding_offset,
bs,
length,
length_id,
token_num,
max_seq_len);
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
int32_t ret_xre =
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(cur_len),
repeat_times,
output_padding_offset,
bs,
length,
length_id,
token_num,
max_seq_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
ret_xre = min_length_logits_process_kernel<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
logits,
reinterpret_cast<const XPU_INT64*>(cur_len),
reinterpret_cast<const XPU_INT64*>(min_len),
@@ -290,7 +294,10 @@ static int xpu3_wrapper(Context* ctx,
end_length,
token_num,
max_seq_len);
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
ret_xre = update_value_by_repeat_times_kernel<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
repeat_times,
penalty_scores,
frequency_scores,
@@ -302,9 +309,10 @@ static int xpu3_wrapper(Context* ctx,
length,
token_num,
max_seq_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
if (bad_words && length_bad_words > 0) {
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
ret_xre = ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits,
reinterpret_cast<const XPU_INT64*>(bad_words),
output_padding_offset,
@@ -313,6 +321,7 @@ static int xpu3_wrapper(Context* ctx,
length_bad_words,
token_num,
max_seq_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
}
return api::SUCCESS;
}
@@ -135,7 +135,7 @@ static int xpu3_wrapper(Context *ctx,
const int max_draft_tokens) {
constexpr int BlockSize = 512;
using XPU_TI = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_update<BlockSize>
int32_t ret_xre = xpu3::plugin::speculate_update<BlockSize>
<<<1, 64, ctx->xpu_stream>>>(seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
@@ -150,6 +150,7 @@ static int xpu3_wrapper(Context *ctx,
real_bsz,
max_bsz,
max_draft_tokens);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -139,7 +139,7 @@ static int xpu3_wrapper(Context *ctx,
const int max_draft_tokens) {
constexpr int BlockSize = 512;
using XPU_TI = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_update_v3<BlockSize>
int32_t ret_xre = xpu3::plugin::speculate_update_v3<BlockSize>
<<<1, 64, ctx->xpu_stream>>>(seq_lens_encoder,
seq_lens_decoder,
not_need_stop,
@@ -154,6 +154,7 @@ static int xpu3_wrapper(Context *ctx,
real_bsz,
max_bsz,
max_draft_tokens);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -382,7 +382,7 @@ static int xpu3_wrapper(Context *ctx,
const bool accept_all_drafts,
const bool use_target_sampling) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
int32_t ret_xre = xpu3::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64 *>(sampled_token_ids),
reinterpret_cast<XPU_INT64 *>(accept_tokens),
@@ -413,6 +413,7 @@ static int xpu3_wrapper(Context *ctx,
benchmark_mode,
accept_all_drafts,
use_target_sampling);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
template <bool ENABLE_TOPP, bool USE_TOPK>
@@ -110,7 +110,7 @@ static int xpu3_wrapper(Context* ctx,
int candidate_len,
int max_seq_len) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
xpu3::plugin::top_p_candidates<T, MaxLength, TopPBeamTopK>
int32_t ret_xre = xpu3::plugin::top_p_candidates<T, MaxLength, TopPBeamTopK>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
src,
top_ps,
@@ -122,6 +122,7 @@ static int xpu3_wrapper(Context* ctx,
token_num,
candidate_len,
max_seq_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -116,7 +116,7 @@ static int xpu3_wrapper(Context* ctx,
int real_bsz,
int max_model_len,
int decode_states_len) {
xpu3::plugin::
int32_t ret_xre = xpu3::plugin::
update_attn_mask_offsets<<<ctx->ncluster(), 1, ctx->xpu_stream>>>(
attn_mask_offsets,
seq_lens_this_time,
@@ -131,6 +131,7 @@ static int xpu3_wrapper(Context* ctx,
real_bsz,
max_model_len,
decode_states_len);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -100,16 +100,18 @@ static int xpu3_wrapper(Context *ctx,
using XPU_TID = typename XPUIndexType<T>::type;
auto set_stop_value_multi_ends =
xpu3::plugin::set_stop_value_multi_ends<XPU_TID>;
set_stop_value_multi_ends<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
reinterpret_cast<XPU_TID *>(topk_ids),
reinterpret_cast<XPU_TID *>(next_tokens),
reinterpret_cast<const XPU_TID *>(end_ids),
seq_lens,
bs,
end_length,
beam_search,
prefill_one_step_stop);
int32_t ret_xre =
set_stop_value_multi_ends<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
reinterpret_cast<XPU_TID *>(topk_ids),
reinterpret_cast<XPU_TID *>(next_tokens),
reinterpret_cast<const XPU_TID *>(end_ids),
seq_lens,
bs,
end_length,
beam_search,
prefill_one_step_stop);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -99,7 +99,9 @@ static int xpu3_wrapper(Context *ctx,
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto set_value_by_flags_and_idx_kernel =
xpu3::plugin::set_value_by_flags_and_idx;
set_value_by_flags_and_idx_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
int32_t ret_xre = set_value_by_flags_and_idx_kernel<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
stop_flags,
reinterpret_cast<XPU_INT64 *>(pre_ids_all),
reinterpret_cast<const XPU_INT64 *>(input_ids),
@@ -109,6 +111,7 @@ static int xpu3_wrapper(Context *ctx,
bs,
length,
length_input_ids);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -203,14 +203,18 @@ static int xpu3_wrapper(Context *ctx,
int ret = api::constant<int>(ctx, repeat_times, bs * length, 0);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(cur_len),
repeat_times,
bs,
length,
length_id);
min_length_logits_process_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
int32_t ret_xre =
update_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(cur_len),
repeat_times,
bs,
length,
length_id);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
ret_xre = min_length_logits_process_kernel<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(
logits,
reinterpret_cast<const XPU_INT64 *>(cur_len),
reinterpret_cast<const XPU_INT64 *>(min_len),
@@ -219,23 +223,28 @@ static int xpu3_wrapper(Context *ctx,
length,
length_id,
end_length);
update_value_by_repeat_times_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
repeat_times,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
logits,
bs,
length);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
ret_xre =
update_value_by_repeat_times_kernel<<<ctx->ncluster(),
64,
ctx->xpu_stream>>>(repeat_times,
penalty_scores,
frequency_scores,
presence_scores,
temperatures,
logits,
bs,
length);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
if (bad_words && length_bad_words > 0) {
ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
ret_xre = ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits,
reinterpret_cast<const XPU_INT64 *>(bad_words),
bs,
length,
length_bad_words);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
}
return api::SUCCESS;
}
@@ -193,7 +193,9 @@ int xpu3_wrapper_input_scale(api::Context *ctx,
int64_t m,
int64_t n) {
auto func = xpu3::plugin::quant2d_per_channel_cluster<TX, TSCALE, TY>;
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, scale, y, m, n);
int32_t ret_xre =
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, scale, y, m, n);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -230,7 +232,9 @@ int xpu3_wrapper_output_scale(api::Context *ctx,
func = xpu3::plugin::quant2d_per_channel_cached<TX, TSCALE, TY, 32>;
}
}
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
int32_t ret_xre =
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
template <typename TX,
@@ -245,7 +249,9 @@ int xpu3_wrapper_output_scale(api::Context *ctx,
int64_t m,
int64_t n) {
auto func = xpu3::plugin::quant2d_per_channel_bign<TX, TSCALE, TY>;
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
int32_t ret_xre =
func<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -129,28 +129,30 @@ static int xpu3_wrapper(Context *ctx,
const int pre_id_length) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_block_kernel = xpu3::plugin::recover_block;
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(step_idx),
encoder_block_lens,
used_list_len,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
bsz,
block_num_per_seq,
length,
pre_id_length);
int32_t ret_xre =
recover_block_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
recover_block_list, // [bsz]
recover_len,
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
free_list,
free_list_len,
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<const XPU_INT64 *>(pre_ids),
reinterpret_cast<const XPU_INT64 *>(step_idx),
encoder_block_lens,
used_list_len,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
reinterpret_cast<const XPU_INT64 *>(first_token_ids),
bsz,
block_num_per_seq,
length,
pre_id_length);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -52,17 +52,19 @@ static int xpu3_wrapper(Context *ctx,
const int block_size) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_decode_task = xpu3::plugin::recover_decode_task;
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
bsz,
block_num_per_seq,
block_size);
int32_t ret_xre =
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
bsz,
block_num_per_seq,
block_size);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -116,7 +116,7 @@ static int xpu3_wrapper(Context* ctx,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
xpu3::plugin::text_image_gather_scatter<T>
int32_t ret_xre = xpu3::plugin::text_image_gather_scatter<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(input,
text_input,
image_input,
@@ -128,6 +128,7 @@ static int xpu3_wrapper(Context* ctx,
image_token_num,
hidden_size,
is_scatter);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -55,8 +55,10 @@ static int xpu3_wrapper(Context* ctx,
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
token_type_ids, text_index, image_index, token_num);
int32_t ret_xre =
xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
token_type_ids, text_index, image_index, token_num);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -94,7 +94,7 @@ static int xpu3_wrapper(Context *ctx,
const int input_ids_stride) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto update_inputs = xpu3::plugin::update_inputs;
update_inputs<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
int32_t ret_xre = update_inputs<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
@@ -107,6 +107,7 @@ static int xpu3_wrapper(Context *ctx,
bsz,
max_bsz,
input_ids_stride);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}
@@ -76,7 +76,7 @@ static int xpu3_wrapper(Context* ctx,
}
}
// kernel 内要做 reduce,只能用 1 个 cluster
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
int32_t ret_xre = update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
@@ -96,6 +96,7 @@ static int xpu3_wrapper(Context* ctx,
block_num_per_seq,
block_size,
prefill_one_step_stop);
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
return api::SUCCESS;
}