[XPU] Refactor pre process (#6993)

* [XPU] support speculate_pre_process

* merge develop

* fix codestype

* fix mtp, support cu_seqlens_q_output

* fix mtp, support cu_seqlens_q_output

* fix test

---------

Co-authored-by: lizan1999 <lizan03@baidu.com>
This commit is contained in:
cmcamdy
2026-04-01 20:29:55 +08:00
committed by GitHub
parent fba8a51ad1
commit 7a2e33098f
36 changed files with 2725 additions and 511 deletions
+2 -3
View File
@@ -58,9 +58,8 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const int bsz = seq_len.shape()[0]; const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1]; const int seq_length = input_ids_shape[1];
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false); auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); // token num is cpu tensor
const int token_num_data = token_num.data<int64_t>()[0];
const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::empty( auto x_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place()); {token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::empty( auto padding_offset = paddle::empty(
+3 -11
View File
@@ -24,8 +24,7 @@
template <paddle::DataType T> template <paddle::DataType T>
std::vector<paddle::Tensor> AdjustBatchKernel( std::vector<paddle::Tensor> AdjustBatchKernel(
const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod, const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &encoder_batch_idx,
@@ -49,7 +48,6 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
using data_t = typename PDTraits<T>::data_t; using data_t = typename PDTraits<T>::data_t;
const int token_num = x.dims()[0]; const int token_num = x.dims()[0];
const int dim = x.dims()[1]; const int dim = x.dims()[1];
const int bsz = cum_offsets.shape()[0];
int enc_batch = len_info_cpu.data<int32_t>()[0]; int enc_batch = len_info_cpu.data<int32_t>()[0];
int dec_batch = len_info_cpu.data<int32_t>()[1]; int dec_batch = len_info_cpu.data<int32_t>()[1];
@@ -87,8 +85,7 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
} }
using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)( using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod, const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &encoder_batch_idx,
@@ -102,8 +99,7 @@ using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
int max_input_length); int max_input_length);
std::vector<paddle::Tensor> AdjustBatch( std::vector<paddle::Tensor> AdjustBatch(
const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &x, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1]
const paddle::Tensor &encoder_seq_lod, const paddle::Tensor &encoder_seq_lod,
const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &decoder_seq_lod,
const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &encoder_batch_idx,
@@ -135,7 +131,6 @@ std::vector<paddle::Tensor> AdjustBatch(
} }
return func(x, return func(x,
cum_offsets,
encoder_seq_lod, encoder_seq_lod,
decoder_seq_lod, decoder_seq_lod,
encoder_batch_idx, encoder_batch_idx,
@@ -151,7 +146,6 @@ std::vector<paddle::Tensor> AdjustBatch(
std::vector<std::vector<int64_t>> AdjustBatchInferShape( std::vector<std::vector<int64_t>> AdjustBatchInferShape(
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &encoder_seq_lod_shape, const std::vector<int64_t> &encoder_seq_lod_shape,
const std::vector<int64_t> &decoder_seq_lod_shape, const std::vector<int64_t> &decoder_seq_lod_shape,
const std::vector<int64_t> &encoder_batch_idx_shape, const std::vector<int64_t> &encoder_batch_idx_shape,
@@ -172,7 +166,6 @@ std::vector<std::vector<int64_t>> AdjustBatchInferShape(
std::vector<paddle::DataType> AdjustBatchInferDtype( std::vector<paddle::DataType> AdjustBatchInferDtype(
const paddle::DataType &x_dtype, const paddle::DataType &x_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &encoder_seq_lod_dtype, const paddle::DataType &encoder_seq_lod_dtype,
const paddle::DataType &decoder_seq_lod_dtype, const paddle::DataType &decoder_seq_lod_dtype,
const paddle::DataType &encoder_batch_idx_dtype, const paddle::DataType &encoder_batch_idx_dtype,
@@ -188,7 +181,6 @@ std::vector<paddle::DataType> AdjustBatchInferDtype(
PD_BUILD_STATIC_OP(adjust_batch) PD_BUILD_STATIC_OP(adjust_batch)
.Inputs({"x", .Inputs({"x",
"cum_offsets",
"encoder_seq_lod", "encoder_seq_lod",
"decoder_seq_lod", "decoder_seq_lod",
"encoder_batch_idx", "encoder_batch_idx",
-5
View File
@@ -66,7 +66,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
const paddle::Tensor& qkv, const paddle::Tensor& qkv,
const paddle::Tensor& key_cache, const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache, const paddle::Tensor& value_cache,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& rotary_embs, const paddle::Tensor& rotary_embs,
const paddle::Tensor& block_tables, const paddle::Tensor& block_tables,
const paddle::Tensor& prefix_block_tables, const paddle::Tensor& prefix_block_tables,
@@ -122,7 +121,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
auto qkv_shape = qkv.dims(); auto qkv_shape = qkv.dims();
auto cache_shape = key_cache.dims(); auto cache_shape = key_cache.dims();
auto block_table_shape = block_tables.dims(); auto block_table_shape = block_tables.dims();
const int bsz = cum_offsets.dims()[0];
const int block_batch = block_table_shape[0]; const int block_batch = block_table_shape[0];
const int max_block_per_seq = block_table_shape[1]; const int max_block_per_seq = block_table_shape[1];
const int kv_num_heads = cache_shape[1]; const int kv_num_heads = cache_shape[1];
@@ -984,7 +982,6 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::Tensor& qkv, const paddle::Tensor& qkv,
const paddle::Tensor& key_cache, const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache, const paddle::Tensor& value_cache,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& rotary_embs, const paddle::Tensor& rotary_embs,
const paddle::Tensor& block_tables, const paddle::Tensor& block_tables,
const paddle::Tensor& prefix_block_tables, const paddle::Tensor& prefix_block_tables,
@@ -1023,7 +1020,6 @@ std::vector<paddle::Tensor> BlockAttn(
return BlockAttnKernel<TX, TC, TS>(qkv, \ return BlockAttnKernel<TX, TC, TS>(qkv, \
key_cache, \ key_cache, \
value_cache, \ value_cache, \
cum_offsets, \
rotary_embs, \ rotary_embs, \
block_tables, \ block_tables, \
prefix_block_tables, \ prefix_block_tables, \
@@ -1099,7 +1095,6 @@ PD_BUILD_STATIC_OP(block_attn)
.Inputs({"qkv", .Inputs({"qkv",
"key_cache", "key_cache",
"value_cache", "value_cache",
"cum_offsets",
"rotary_embs", "rotary_embs",
"block_tables", "block_tables",
"prefix_block_tables", "prefix_block_tables",
@@ -22,8 +22,7 @@
#endif #endif
std::vector<paddle::Tensor> GatherNextToken( std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod, const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_map, const paddle::Tensor& encoder_batch_map,
@@ -46,7 +45,7 @@ std::vector<paddle::Tensor> GatherNextToken(
typedef paddle::bfloat16 data_t; typedef paddle::bfloat16 data_t;
const int dim = x.dims()[1]; const int dim = x.dims()[1];
const int token_num = x.shape()[0]; const int token_num = x.shape()[0];
int bsz = cum_offsets.shape()[0]; int bsz = -1;
int enc_batch = len_info_cpu.data<int32_t>()[0]; int enc_batch = len_info_cpu.data<int32_t>()[0];
int dec_batch = len_info_cpu.data<int32_t>()[1]; int dec_batch = len_info_cpu.data<int32_t>()[1];
if (max_bsz > 0) { if (max_bsz > 0) {
@@ -116,7 +115,6 @@ std::vector<paddle::Tensor> GatherNextToken(
std::vector<std::vector<int64_t>> GatherNextTokenInferShape( std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
const std::vector<int64_t>& x_shape, const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& encoder_seq_lod_shape, const std::vector<int64_t>& encoder_seq_lod_shape,
const std::vector<int64_t>& decoder_seq_lod_shape, const std::vector<int64_t>& decoder_seq_lod_shape,
const std::vector<int64_t>& encoder_batch_map_shape, const std::vector<int64_t>& encoder_batch_map_shape,
@@ -130,19 +128,18 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
// if (output_padding_offset_shape) { // if (output_padding_offset_shape) {
// PD_THROW("speculative decoding is not supported in XPU."); // PD_THROW("speculative decoding is not supported in XPU.");
// } // }
int64_t bsz = cum_offsets_shape[0]; // int64_t bsz = cum_offsets_shape[0];
int64_t bsz = 0;
int64_t dim_embed = x_shape[1]; int64_t dim_embed = x_shape[1];
if (output_padding_offset_shape) { if (output_padding_offset_shape) {
return {{-1, dim_embed}}; return {{-1, dim_embed}};
} else { } else {
int64_t bsz = cum_offsets_shape[0];
return {{bsz, dim_embed}}; return {{bsz, dim_embed}};
} }
} }
std::vector<paddle::DataType> GatherNextTokenInferDtype( std::vector<paddle::DataType> GatherNextTokenInferDtype(
const paddle::DataType& x_dtype, const paddle::DataType& x_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& encoder_seq_lod_dtype, const paddle::DataType& encoder_seq_lod_dtype,
const paddle::DataType& decoder_seq_lod_dtype, const paddle::DataType& decoder_seq_lod_dtype,
const paddle::DataType& encoder_batch_map_dtype, const paddle::DataType& encoder_batch_map_dtype,
@@ -158,7 +155,6 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
PD_BUILD_STATIC_OP(gather_next_token) PD_BUILD_STATIC_OP(gather_next_token)
.Inputs({"x", .Inputs({"x",
"cum_offsets",
"encoder_seq_lod", "encoder_seq_lod",
"decoder_seq_lod", "decoder_seq_lod",
"encoder_batch_map", "encoder_batch_map",
@@ -28,7 +28,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx, const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets, const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& stop_flags, const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len, const paddle::Tensor& max_dec_len,
@@ -72,7 +72,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const_cast<int*>(seq_lens_encoder.data<int>()), const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()), const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()), const_cast<int64_t*>(step_idx.data<int64_t>()),
output_cum_offsets.data<int>(), cu_seqlens_q_output.data<int>(),
const_cast<bool*>(stop_flags.data<bool>()), const_cast<bool*>(stop_flags.data<bool>()),
const_cast<bool*>(not_need_stop_device.data<bool>()), const_cast<bool*>(not_need_stop_device.data<bool>()),
max_dec_len.data<int64_t>(), max_dec_len.data<int64_t>(),
@@ -102,7 +102,7 @@ PD_BUILD_STATIC_OP(draft_model_update)
"seq_lens_encoder", "seq_lens_encoder",
"seq_lens_decoder", "seq_lens_decoder",
"step_idx", "step_idx",
"output_cum_offsets", "cu_seqlens_q_output",
"stop_flags", "stop_flags",
"not_need_stop", "not_need_stop",
"max_dec_len", "max_dec_len",
@@ -0,0 +1,133 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
// Pre-processes padded speculative-decoding inputs on XPU.
//
// Removes padding from `input_ids` (producing a flat [token_num] id tensor),
// and builds the per-token batch index plus cumulative sequence-length
// prefix arrays (cu_seqlens_q / cu_seqlens_k, each [bsz + 1]) needed by the
// attention kernels, along with the output-side equivalents sized by
// `draft_tokens` ([bsz * max_draft_tokens_per_batch] for the per-token batch
// map). The heavy lifting is done by the XPU plugin kernel
// `fastdeploy::plugin::speculate_preprocess`.
//
// Args:
//   cpu_token_num:    total number of real (non-padding) tokens; passed as a
//                     host-side attr so no device->host copy is needed here.
//   input_ids:        [bsz, max_seq_len] padded input ids (int64).
//   seq_len:          [bsz] per-batch sequence lengths (int32).
//   draft_tokens:     [bsz, max_draft_tokens_per_batch] draft token ids (int64).
//   seq_lens_encoder: [bsz] encoder lengths (int32).
//   seq_lens_decoder: accepted for op-signature compatibility.
//                     NOTE(review): never read in this function — confirm the
//                     plugin does not need it before removing.
//
// Returns: {ids_remove_padding, batch_id_per_token, cu_seqlens_q,
//           cu_seqlens_k, cu_seq_lens_q_output, batch_id_per_token_output,
//           real_output_token_num}.
std::vector<paddle::Tensor> SpeculatePreProcess(
    const int64_t cpu_token_num,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &seq_len,
    const paddle::Tensor &draft_tokens,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
  api::Context *ctx = xpu_ctx->x_context();
  // When inputs live on CPU (unit-test baseline path), run the plugin with a
  // CPU api::Context instead of the XPU one.
  std::unique_ptr<baidu::xpu::api::Context> cpu_ctx;
  if (input_ids.place().GetType() == phi::AllocationType::CPU) {
    cpu_ctx = std::make_unique<baidu::xpu::api::Context>(baidu::xpu::api::kCPU);
    ctx = cpu_ctx.get();
  }
  std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int max_seq_len = input_ids_shape[1];
  const int token_num_data = cpu_token_num;
  auto ids_remove_padding = paddle::empty(
      {token_num_data}, paddle::DataType::INT64, input_ids.place());
  auto batch_id_per_token = paddle::empty(
      {token_num_data}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  const int max_draft_tokens_per_batch = draft_tokens.shape()[1];
  auto seq_lens_output =
      paddle::empty({bsz}, paddle::DataType::INT32, input_ids.place());
  auto cu_seq_lens_q_output =
      paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
  auto batch_id_per_token_output =
      paddle::empty({bsz * max_draft_tokens_per_batch},
                    paddle::DataType::INT32,
                    input_ids.place());
  auto real_output_token_num =
      paddle::empty({1}, paddle::DataType::INT32, input_ids.place());
  // Nothing to do for an empty batch; return the (empty/uninitialized)
  // buffers without launching the kernel.
  if (token_num_data == 0) {
    return {ids_remove_padding,
            batch_id_per_token,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seq_lens_q_output,
            batch_id_per_token_output,
            real_output_token_num};
  }
  int64_t *ids_remove_padding_ptr = ids_remove_padding.data<int64_t>();
  int *batch_id_per_token_ptr = batch_id_per_token.data<int>();
  int *cu_seqlens_q_ptr = cu_seqlens_q.data<int>();
  int *cu_seqlens_k_ptr = cu_seqlens_k.data<int>();
  int *seq_lens_output_ptr = seq_lens_output.data<int>();
  int *cu_seq_lens_q_output_ptr = cu_seq_lens_q_output.data<int>();
  int *batch_id_per_token_output_ptr = batch_id_per_token_output.data<int>();
  int *real_output_token_num_ptr = real_output_token_num.data<int>();
  const int64_t *input_data_ptr = input_ids.data<int64_t>();
  const int *seq_len_ptr = seq_len.data<int>();
  const int64_t *draft_tokens_ptr = draft_tokens.data<int64_t>();
  const int *seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
  int r =
      fastdeploy::plugin::speculate_preprocess(ctx,
                                               ids_remove_padding_ptr,
                                               batch_id_per_token_ptr,
                                               cu_seqlens_q_ptr,
                                               cu_seqlens_k_ptr,
                                               seq_lens_output_ptr,
                                               cu_seq_lens_q_output_ptr,
                                               batch_id_per_token_output_ptr,
                                               real_output_token_num_ptr,
                                               input_data_ptr,
                                               seq_len_ptr,
                                               draft_tokens_ptr,
                                               seq_lens_encoder_ptr,
                                               max_seq_len,
                                               max_draft_tokens_per_batch,
                                               token_num_data,
                                               bsz);
  // Fix: the plugin's return code was previously assigned to `r` and never
  // inspected, so a failed launch returned garbage buffers silently.
  PD_CHECK(r == 0, "fastdeploy::plugin::speculate_preprocess failed");
  return {ids_remove_padding,
          batch_id_per_token,
          cu_seqlens_q,
          cu_seqlens_k,
          cu_seq_lens_q_output,
          batch_id_per_token_output,
          real_output_token_num};
}
// Register the `speculate_pre_process` custom op.
// Tensor inputs map onto SpeculatePreProcess's tensor parameters in order;
// the host-side token count is supplied as the `cpu_token_num` attribute.
PD_BUILD_STATIC_OP(speculate_pre_process)
    .Inputs({"input_ids",
             "seq_len",
             "draft_tokens",
             "seq_lens_encoder",
             "seq_lens_decoder"})
    .Outputs({"ids_remove_padding",
              "batch_id_per_token",
              "cu_seqlens_q",
              "cu_seqlens_k",
              "cu_seq_lens_q_output",
              "batch_id_per_token_output",
              "real_output_token_num"})
    .Attrs({"cpu_token_num: int64_t"})
    .SetKernelFn(PD_KERNEL(SpeculatePreProcess));
@@ -33,8 +33,8 @@ void SpeculateTokenPenaltyMultiScores(
const paddle::Tensor& min_len, const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id, const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset, const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& output_cum_offsets, const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len) { const int max_seq_len) {
namespace api = baidu::xpu::api; namespace api = baidu::xpu::api;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -72,8 +72,8 @@ void SpeculateTokenPenaltyMultiScores(
min_len.data<int64_t>(), min_len.data<int64_t>(),
eos_token_id.data<int64_t>(), eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(), bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(), batch_id_per_token_output.data<int>(),
output_cum_offsets.data<int>(), cu_seqlens_q_output.data<int>(),
bs, bs,
length, length,
length_id, length_id,
@@ -100,8 +100,8 @@ void SpeculateTokenPenaltyMultiScores(
min_len.data<int64_t>(), min_len.data<int64_t>(),
eos_token_id.data<int64_t>(), eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(), bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(), batch_id_per_token_output.data<int>(),
output_cum_offsets.data<int>(), cu_seqlens_q_output.data<int>(),
bs, bs,
length, length,
length_id, length_id,
@@ -125,8 +125,8 @@ void SpeculateTokenPenaltyMultiScores(
min_len.data<int64_t>(), min_len.data<int64_t>(),
eos_token_id.data<int64_t>(), eos_token_id.data<int64_t>(),
bad_tokens.data<int64_t>(), bad_tokens.data<int64_t>(),
output_padding_offset.data<int>(), batch_id_per_token_output.data<int>(),
output_cum_offsets.data<int>(), cu_seqlens_q_output.data<int>(),
bs, bs,
length, length,
length_id, length_id,
@@ -157,8 +157,8 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
"min_len", "min_len",
"eos_token_id", "eos_token_id",
"seq_lens_this_time", "seq_lens_this_time",
"output_padding_offset", "batch_id_per_token_output",
"output_cum_offsets"}) "cu_seqlens_q_output"})
.Outputs({"logits_out"}) .Outputs({"logits_out"})
.Attrs({"max_seq_len: int"}) .Attrs({"max_seq_len: int"})
.SetInplaceMap({{"logits", "logits_out"}}) .SetInplaceMap({{"logits", "logits_out"}})
@@ -26,24 +26,24 @@
namespace api = baidu::xpu::api; namespace api = baidu::xpu::api;
void SpeculateVerify(const paddle::Tensor& sampled_token_ids, void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
const paddle::Tensor& accept_tokens, const paddle::Tensor &accept_tokens,
const paddle::Tensor& accept_num, const paddle::Tensor &accept_num,
const paddle::Tensor& step_idx, const paddle::Tensor &step_idx,
const paddle::Tensor& stop_flags, const paddle::Tensor &stop_flags,
const paddle::Tensor& seq_lens_encoder, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder, const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor& draft_tokens, const paddle::Tensor &draft_tokens,
const paddle::Tensor& seq_lens_this_time, const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor& verify_tokens, const paddle::Tensor &verify_tokens,
const paddle::Tensor& verify_scores, const paddle::Tensor &verify_scores,
const paddle::Tensor& max_dec_len, const paddle::Tensor &max_dec_len,
const paddle::Tensor& end_tokens, const paddle::Tensor &end_tokens,
const paddle::Tensor& is_block_step, const paddle::Tensor &is_block_step,
const paddle::Tensor& output_cum_offsets, const paddle::Tensor &cu_seqlens_q_output,
const paddle::Tensor& actual_candidate_len, const paddle::Tensor &actual_candidate_len,
const paddle::Tensor& actual_draft_token_nums, const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor& topp, const paddle::Tensor &topp,
int max_seq_len, int max_seq_len,
int verify_window, int verify_window,
bool enable_topp, bool enable_topp,
@@ -57,7 +57,8 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
api::Context* ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context(); api::Context *ctx =
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
bool xpu_ctx_flag = true; bool xpu_ctx_flag = true;
if (draft_tokens.is_cpu()) { if (draft_tokens.is_cpu()) {
ctx = new api::Context(api::kCPU); ctx = new api::Context(api::kCPU);
@@ -65,17 +66,17 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
} }
bool use_topk = false; bool use_topk = false;
char* env_var = getenv("SPECULATE_VERIFY_USE_TOPK"); char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
if (env_var) { if (env_var) {
use_topk = static_cast<bool>(std::stoi(env_var)); use_topk = static_cast<bool>(std::stoi(env_var));
} }
bool use_target_sampling = false; bool use_target_sampling = false;
char* env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING"); char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
if (env_var_1) { if (env_var_1) {
use_target_sampling = static_cast<bool>(std::stoi(env_var_1)); use_target_sampling = static_cast<bool>(std::stoi(env_var_1));
} }
bool prefill_one_step_stop = false; bool prefill_one_step_stop = false;
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n'; // std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') { if (env_p[0] == '1') {
prefill_one_step_stop = true; prefill_one_step_stop = true;
@@ -90,7 +91,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
std::mt19937_64 engine(infer_seed[i]); std::mt19937_64 engine(infer_seed[i]);
dev_curand_states_cpu.push_back(dist(engine)); dev_curand_states_cpu.push_back(dist(engine));
} }
float* dev_curand_states = dev_curand_states_cpu.data(); float *dev_curand_states = dev_curand_states_cpu.data();
auto dev_curand_states_tensor = auto dev_curand_states_tensor =
paddle::empty({static_cast<int64_t>(dev_curand_states_cpu.size())}, paddle::empty({static_cast<int64_t>(dev_curand_states_cpu.size())},
paddle::DataType::FLOAT32, paddle::DataType::FLOAT32,
@@ -110,10 +111,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
ret = fastdeploy::plugin::speculate_verify<true, true>( ret = fastdeploy::plugin::speculate_verify<true, true>(
ctx, ctx,
sampled_token_ids.data<int64_t>(), sampled_token_ids.data<int64_t>(),
const_cast<int64_t*>(accept_tokens.data<int64_t>()), const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int*>(accept_num.data<int>()), const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()), const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()), const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), draft_tokens.data<int64_t>(),
@@ -126,7 +127,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
max_dec_len.data<int64_t>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), end_tokens.data<int64_t>(),
is_block_step.data<bool>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), cu_seqlens_q_output.data<int>(),
actual_candidate_len.data<int>(), actual_candidate_len.data<int>(),
real_bsz, real_bsz,
max_draft_tokens, max_draft_tokens,
@@ -143,10 +144,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
ret = fastdeploy::plugin::speculate_verify<false, true>( ret = fastdeploy::plugin::speculate_verify<false, true>(
ctx, ctx,
sampled_token_ids.data<int64_t>(), sampled_token_ids.data<int64_t>(),
const_cast<int64_t*>(accept_tokens.data<int64_t>()), const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int*>(accept_num.data<int>()), const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()), const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()), const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), draft_tokens.data<int64_t>(),
@@ -159,7 +160,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
max_dec_len.data<int64_t>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), end_tokens.data<int64_t>(),
is_block_step.data<bool>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), cu_seqlens_q_output.data<int>(),
actual_candidate_len.data<int>(), actual_candidate_len.data<int>(),
real_bsz, real_bsz,
max_draft_tokens, max_draft_tokens,
@@ -178,10 +179,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
ret = fastdeploy::plugin::speculate_verify<true, false>( ret = fastdeploy::plugin::speculate_verify<true, false>(
ctx, ctx,
sampled_token_ids.data<int64_t>(), sampled_token_ids.data<int64_t>(),
const_cast<int64_t*>(accept_tokens.data<int64_t>()), const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int*>(accept_num.data<int>()), const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()), const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()), const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), draft_tokens.data<int64_t>(),
@@ -194,7 +195,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
max_dec_len.data<int64_t>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), end_tokens.data<int64_t>(),
is_block_step.data<bool>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), cu_seqlens_q_output.data<int>(),
actual_candidate_len.data<int>(), actual_candidate_len.data<int>(),
real_bsz, real_bsz,
max_draft_tokens, max_draft_tokens,
@@ -211,10 +212,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
ret = fastdeploy::plugin::speculate_verify<false, false>( ret = fastdeploy::plugin::speculate_verify<false, false>(
ctx, ctx,
sampled_token_ids.data<int64_t>(), sampled_token_ids.data<int64_t>(),
const_cast<int64_t*>(accept_tokens.data<int64_t>()), const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<int*>(accept_num.data<int>()), const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()), const_cast<int64_t *>(step_idx.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()), const_cast<bool *>(stop_flags.data<bool>()),
seq_lens_encoder.data<int>(), seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(), seq_lens_decoder.data<int>(),
draft_tokens.data<int64_t>(), draft_tokens.data<int64_t>(),
@@ -227,7 +228,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
max_dec_len.data<int64_t>(), max_dec_len.data<int64_t>(),
end_tokens.data<int64_t>(), end_tokens.data<int64_t>(),
is_block_step.data<bool>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), cu_seqlens_q_output.data<int>(),
actual_candidate_len.data<int>(), actual_candidate_len.data<int>(),
real_bsz, real_bsz,
max_draft_tokens, max_draft_tokens,
@@ -262,7 +263,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
"max_dec_len", "max_dec_len",
"end_tokens", "end_tokens",
"is_block_step", "is_block_step",
"output_cum_offsets", "cu_seqlens_q_output",
"actual_candidate_len", "actual_candidate_len",
"actual_draft_token_nums", "actual_draft_token_nums",
"topp"}) "topp"})
@@ -41,7 +41,7 @@ namespace api = baidu::xpu::api;
std::vector<paddle::Tensor> TopPCandidates( std::vector<paddle::Tensor> TopPCandidates(
const paddle::Tensor& probs, const paddle::Tensor& probs,
const paddle::Tensor& top_p, const paddle::Tensor& top_p,
const paddle::Tensor& output_padding_offset, const paddle::Tensor& batch_id_per_token_output,
int candidates_len, int candidates_len,
int max_seq_len) { int max_seq_len) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -77,7 +77,7 @@ std::vector<paddle::Tensor> TopPCandidates(
ctx, ctx,
reinterpret_cast<const XPUTypeBF16*>(probs.data<bf16_data_t>()), reinterpret_cast<const XPUTypeBF16*>(probs.data<bf16_data_t>()),
reinterpret_cast<const XPUTypeBF16*>(top_p.data<bf16_data_t>()), reinterpret_cast<const XPUTypeBF16*>(top_p.data<bf16_data_t>()),
output_padding_offset.data<int>(), batch_id_per_token_output.data<int>(),
verify_tokens.data<int64_t>(), verify_tokens.data<int64_t>(),
reinterpret_cast<XPUTypeBF16*>( reinterpret_cast<XPUTypeBF16*>(
verify_scores.data<bf16_data_t>()), verify_scores.data<bf16_data_t>()),
@@ -100,7 +100,7 @@ std::vector<paddle::Tensor> TopPCandidates(
ctx, ctx,
reinterpret_cast<const XPUTypeFP16*>(probs.data<fp16_data_t>()), reinterpret_cast<const XPUTypeFP16*>(probs.data<fp16_data_t>()),
reinterpret_cast<const XPUTypeFP16*>(top_p.data<fp16_data_t>()), reinterpret_cast<const XPUTypeFP16*>(top_p.data<fp16_data_t>()),
output_padding_offset.data<int>(), batch_id_per_token_output.data<int>(),
verify_tokens.data<int64_t>(), verify_tokens.data<int64_t>(),
reinterpret_cast<XPUTypeFP16*>( reinterpret_cast<XPUTypeFP16*>(
verify_scores.data<fp16_data_t>()), verify_scores.data<fp16_data_t>()),
@@ -120,7 +120,7 @@ std::vector<paddle::Tensor> TopPCandidates(
ctx, ctx,
probs.data<float>(), probs.data<float>(),
top_p.data<float>(), top_p.data<float>(),
output_padding_offset.data<int>(), batch_id_per_token_output.data<int>(),
verify_tokens.data<int64_t>(), verify_tokens.data<int64_t>(),
verify_scores.data<float>(), verify_scores.data<float>(),
actual_candidate_lens.data<int>(), actual_candidate_lens.data<int>(),
@@ -139,7 +139,7 @@ std::vector<paddle::Tensor> TopPCandidates(
std::vector<std::vector<int64_t>> TopPCandidatesInferShape( std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
const std::vector<int64_t>& probs_shape, const std::vector<int64_t>& probs_shape,
const std::vector<int64_t>& top_p_shape, const std::vector<int64_t>& top_p_shape,
const std::vector<int64_t>& output_padding_offset_shape, const std::vector<int64_t>& batch_id_per_token_output_shape,
int max_candidates_len) { int max_candidates_len) {
int token_num = probs_shape[0]; int token_num = probs_shape[0];
return {{token_num, max_candidates_len}, return {{token_num, max_candidates_len},
@@ -150,12 +150,12 @@ std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
std::vector<paddle::DataType> TopPCandidatesInferDtype( std::vector<paddle::DataType> TopPCandidatesInferDtype(
const paddle::DataType& probs_dtype, const paddle::DataType& probs_dtype,
const paddle::DataType& top_p_dtype, const paddle::DataType& top_p_dtype,
const paddle::DataType& output_padding_offset_dtype) { const paddle::DataType& batch_id_per_token_output_dtype) {
return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32}; return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
} }
PD_BUILD_STATIC_OP(top_p_candidates) PD_BUILD_STATIC_OP(top_p_candidates)
.Inputs({"probs", "top_p", "output_padding_offset"}) .Inputs({"probs", "top_p", "batch_id_per_token_output"})
.Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"}) .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
.Attrs({"candidates_len: int", "max_seq_len: int"}) .Attrs({"candidates_len: int", "max_seq_len: int"})
.SetKernelFn(PD_KERNEL(TopPCandidates)) .SetKernelFn(PD_KERNEL(TopPCandidates))
@@ -0,0 +1,149 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>
#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
namespace api = baidu::xpu::api;
// Host-side launcher for the fused XPU plugin kernel
// `unified_update_model_status`. In a single pass the kernel updates
// per-request scheduling state in place: sequence lengths
// (encoder/decoder/this_time), stop flags, step indices, the staged
// step input/output id buffers, rollback masks, and the aggregated
// `has_running_seqs` flag.
//
// Mutated in place (via const_cast, matching the op's SetInplaceMap):
//   seq_lens_encoder, seq_lens_decoder, has_running_seqs, step_input_ids,
//   adaptive_step_input_len, step_output_ids, step_output_len, stop_flags,
//   seq_lens_this_time, mask_rollback, token_ids_all, step_idx.
// Read-only: is_paused, prompt_lens, end_tokens, max_dec_len.
//
// `has_running_seqs` is expected to live on the host; it is round-tripped
// host -> device -> host around the kernel launch (see below).
void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
                              const paddle::Tensor &seq_lens_decoder,
                              const paddle::Tensor &has_running_seqs,
                              const paddle::Tensor &step_input_ids,
                              const paddle::Tensor &adaptive_step_input_len,
                              const paddle::Tensor &step_output_ids,
                              const paddle::Tensor &step_output_len,
                              const paddle::Tensor &stop_flags,
                              const paddle::Tensor &seq_lens_this_time,
                              const paddle::Tensor &is_paused,
                              const paddle::Tensor &mask_rollback,
                              const paddle::Tensor &token_ids_all,
                              const paddle::Tensor &prompt_lens,
                              const paddle::Tensor &step_idx,
                              const paddle::Tensor &end_tokens,
                              const paddle::Tensor &max_dec_len,
                              const bool is_naive_mode,
                              const bool prefill_one_step_stop) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
  api::Context *ctx = xpu_ctx->x_context();
  // just for ut to run base line: unit tests feed CPU tensors, so fall back
  // to a CPU api::Context and let the plugin's reference path run without an
  // actual XPU device.
  std::unique_ptr<baidu::xpu::api::Context> cpu_ctx;
  if (seq_lens_encoder.place().GetType() == phi::AllocationType::CPU) {
    cpu_ctx = std::make_unique<baidu::xpu::api::Context>(baidu::xpu::api::kCPU);
    ctx = cpu_ctx.get();
  }
  // real_bsz: number of currently scheduled requests; max_bsz: full batch
  // capacity (stop_flags is sized to capacity).
  const int real_bsz = seq_lens_this_time.shape()[0];
  const int max_bsz = stop_flags.shape()[0];
  // The device kernel is launched as a single block with fixed-size shared
  // buffers; 1024 is its hard batch capacity (NOTE(review): keep in sync
  // with the kernel-side MAX_BATCH_SIZE constant).
  PADDLE_ENFORCE_LE(
      max_bsz,
      1024,
      phi::errors::InvalidArgument(
          "unified_update_model_status: max_bsz (%d) must be <= 1024 "
          "(single-block launch limit).",
          max_bsz));
  const int max_step_tokens = step_input_ids.shape()[1];
  const int max_model_len = token_ids_all.shape()[1];
  const int num_end_tokens = end_tokens.shape()[0];
  // has_running_seqs is CPU tensor, need to copy to GPU first so the kernel
  // can read/write it alongside the other device buffers.
  auto has_running_seqs_xpu =
      has_running_seqs.copy_to(seq_lens_this_time.place(), false);
  int r = fastdeploy::plugin::unified_update_model_status(
      ctx,
      const_cast<int *>(seq_lens_encoder.data<int>()),
      const_cast<int *>(seq_lens_decoder.data<int>()),
      const_cast<bool *>(has_running_seqs_xpu.data<bool>()),
      const_cast<int *>(mask_rollback.data<int>()),
      const_cast<int64_t *>(step_input_ids.data<int64_t>()),
      const_cast<int *>(adaptive_step_input_len.data<int>()),
      const_cast<int64_t *>(step_output_ids.data<int64_t>()),
      const_cast<int *>(step_output_len.data<int>()),
      const_cast<bool *>(stop_flags.data<bool>()),
      const_cast<int *>(seq_lens_this_time.data<int>()),
      const_cast<bool *>(is_paused.data<bool>()),
      const_cast<int64_t *>(token_ids_all.data<int64_t>()),
      prompt_lens.data<int64_t>(),
      const_cast<int64_t *>(step_idx.data<int64_t>()),
      end_tokens.data<int64_t>(),
      max_dec_len.data<int64_t>(),
      real_bsz,
      max_bsz,
      max_step_tokens,
      max_model_len,
      num_end_tokens,
      is_naive_mode,
      prefill_one_step_stop);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "unified_update_model_status");
  // Copy result back to CPU: stage the device flag into a host tensor, then
  // write the single element through into the caller-visible input tensor
  // (assumes has_running_seqs is host-resident — TODO confirm callers).
  auto has_running_seqs_cpu =
      has_running_seqs_xpu.copy_to(has_running_seqs.place(), false);
  bool *out_data = const_cast<bool *>(has_running_seqs.data<bool>());
  out_data[0] = has_running_seqs_cpu.data<bool>()[0];
}
// Static registration of the `unified_update_model_status` custom op.
// The Inputs list must match the parameter order of
// UnifiedUpdateModelStatus; every mutated input is declared both as an
// output and in the inplace map so the framework treats it as in-place.
// `is_paused`, `prompt_lens`, `end_tokens` and `max_dec_len` are read-only
// and therefore have no corresponding *_out entry.
PD_BUILD_STATIC_OP(unified_update_model_status)
    .Inputs({"seq_lens_encoder",
             "seq_lens_decoder",
             "has_running_seqs",
             "step_input_ids",
             "adaptive_step_input_len",
             "step_output_ids",
             "step_output_len",
             "stop_flags",
             "seq_lens_this_time",
             "is_paused",
             "mask_rollback",
             "token_ids_all",
             "prompt_lens",
             "step_idx",
             "end_tokens",
             "max_dec_len"})
    .Attrs({"is_naive_mode: bool", "prefill_one_step_stop: bool"})
    .Outputs({"seq_lens_encoder_out",
              "seq_lens_decoder_out",
              "has_running_seqs_out",
              "step_input_ids_out",
              "adaptive_step_input_len_out",
              "step_output_ids_out",
              "step_output_len_out",
              "stop_flags_out",
              "seq_lens_this_time_out",
              "mask_rollback_out",
              "token_ids_all_out",
              "step_idx_out"})
    .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"seq_lens_decoder", "seq_lens_decoder_out"},
                    {"has_running_seqs", "has_running_seqs_out"},
                    {"step_input_ids", "step_input_ids_out"},
                    {"adaptive_step_input_len", "adaptive_step_input_len_out"},
                    {"step_output_ids", "step_output_ids_out"},
                    {"step_output_len", "step_output_len_out"},
                    {"stop_flags", "stop_flags_out"},
                    {"seq_lens_this_time", "seq_lens_this_time_out"},
                    {"mask_rollback", "mask_rollback_out"},
                    {"token_ids_all", "token_ids_all_out"},
                    {"step_idx", "step_idx_out"}})
    .SetKernelFn(PD_KERNEL(UnifiedUpdateModelStatus));
+69 -16
View File
@@ -34,8 +34,7 @@ void prof_start();
void prof_stop(); void prof_stop();
std::vector<paddle::Tensor> AdjustBatch( std::vector<paddle::Tensor> AdjustBatch(
const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod, const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_idx, const paddle::Tensor& encoder_batch_idx,
@@ -62,7 +61,6 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::Tensor& qkv, const paddle::Tensor& qkv,
const paddle::Tensor& key_cache, const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache, const paddle::Tensor& value_cache,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& rotary_embs, const paddle::Tensor& rotary_embs,
const paddle::Tensor& block_tables, const paddle::Tensor& block_tables,
const paddle::Tensor& prefix_block_tables, const paddle::Tensor& prefix_block_tables,
@@ -210,7 +208,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx, const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets, const paddle::Tensor& cu_seqlens_q_output,
const paddle::Tensor& stop_flags, const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len, const paddle::Tensor& max_dec_len,
@@ -254,8 +252,8 @@ void SpeculateTokenPenaltyMultiScores(
const paddle::Tensor& min_len, const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id, const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& output_padding_offset, const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& output_cum_offsets, const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len); const int max_seq_len);
void SpeculateUpdateV3(const paddle::Tensor& seq_lens_encoder, void SpeculateUpdateV3(const paddle::Tensor& seq_lens_encoder,
@@ -413,8 +411,7 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
const paddle::Tensor& step_idx); const paddle::Tensor& step_idx);
std::vector<paddle::Tensor> GatherNextToken( std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& x, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& encoder_seq_lod, const paddle::Tensor& encoder_seq_lod,
const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& decoder_seq_lod,
const paddle::Tensor& encoder_batch_map, const paddle::Tensor& encoder_batch_map,
@@ -500,6 +497,14 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& seq_len, const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder); const paddle::Tensor& seq_lens_encoder);
std::vector<paddle::Tensor> SpeculatePreProcess(
const int64_t cpu_token_num,
const paddle::Tensor& input_ids,
const paddle::Tensor& seq_len,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
void StepPaddle(const paddle::Tensor& stop_flags, void StepPaddle(const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& ori_seq_lens_encoder, const paddle::Tensor& ori_seq_lens_encoder,
@@ -540,6 +545,25 @@ void MTPStepPaddle(
const int block_size, const int block_size,
const int max_draft_tokens); const int max_draft_tokens);
void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& has_running_seqs,
const paddle::Tensor& step_input_ids,
const paddle::Tensor& adaptive_step_input_len,
const paddle::Tensor& step_output_ids,
const paddle::Tensor& step_output_len,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& is_paused,
const paddle::Tensor& mask_rollback,
const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& end_tokens,
const paddle::Tensor& max_dec_len,
const bool is_naive_mode,
const bool prefill_one_step_stop);
void SpeculateStepPaddle( void SpeculateStepPaddle(
const paddle::Tensor& stop_flags, const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_this_time,
@@ -682,7 +706,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("adjust_batch", m.def("adjust_batch",
&AdjustBatch, &AdjustBatch,
py::arg("x"), py::arg("x"),
py::arg("cum_offsets"),
py::arg("encoder_seq_lod"), py::arg("encoder_seq_lod"),
py::arg("decoder_seq_lod"), py::arg("decoder_seq_lod"),
py::arg("encoder_batch_idx"), py::arg("encoder_batch_idx"),
@@ -701,7 +724,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("qkv"), py::arg("qkv"),
py::arg("key_cache"), py::arg("key_cache"),
py::arg("value_cache"), py::arg("value_cache"),
py::arg("cum_offsets"),
py::arg("rotary_embs"), py::arg("rotary_embs"),
py::arg("block_tables"), py::arg("block_tables"),
py::arg("prefix_block_tables"), py::arg("prefix_block_tables"),
@@ -812,7 +834,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_encoder"), // 编码器序列长度张量 py::arg("seq_lens_encoder"), // 编码器序列长度张量
py::arg("seq_lens_decoder"), // 解码器序列长度张量 py::arg("seq_lens_decoder"), // 解码器序列长度张量
py::arg("step_idx"), // 步骤索引张量 py::arg("step_idx"), // 步骤索引张量
py::arg("output_cum_offsets"), // 输出累积偏移量张量 py::arg("cu_seqlens_q_output"), // 输出累积偏移量张量
py::arg("stop_flags"), // 停止标志张量 py::arg("stop_flags"), // 停止标志张量
py::arg("not_need_stop"), // 无需停止标志张量 py::arg("not_need_stop"), // 无需停止标志张量
py::arg("max_dec_len"), // 最大解码长度张量 py::arg("max_dec_len"), // 最大解码长度张量
@@ -885,7 +907,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("gather_next_token", m.def("gather_next_token",
&GatherNextToken, &GatherNextToken,
py::arg("x"), py::arg("x"),
py::arg("cum_offsets"),
py::arg("encoder_seq_lod"), py::arg("encoder_seq_lod"),
py::arg("decoder_seq_lod"), py::arg("decoder_seq_lod"),
py::arg("encoder_batch_map"), py::arg("encoder_batch_map"),
@@ -1002,6 +1023,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("redundant_ep_rank_num_plus_one"), py::arg("redundant_ep_rank_num_plus_one"),
"moe export RedundantTopKSelect function"); "moe export RedundantTopKSelect function");
m.def("unified_update_model_status",
&UnifiedUpdateModelStatus,
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("has_running_seqs"),
py::arg("step_input_ids"),
py::arg("adaptive_step_input_len"),
py::arg("step_output_ids"),
py::arg("step_output_len"),
py::arg("stop_flags"),
py::arg("seq_lens_this_time"),
py::arg("is_paused"),
py::arg("mask_rollback"),
py::arg("token_ids_all"),
py::arg("prompt_lens"),
py::arg("step_idx"),
py::arg("end_tokens"),
py::arg("max_dec_len"),
py::arg("is_naive_mode"),
py::arg("max_draft_tokens"),
"Unified update model status");
m.def("mtp_step_paddle", m.def("mtp_step_paddle",
&MTPStepPaddle, &MTPStepPaddle,
py::arg("base_model_stop_flags"), py::arg("base_model_stop_flags"),
@@ -1117,8 +1160,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("min_len"), py::arg("min_len"),
py::arg("eos_token_id"), py::arg("eos_token_id"),
py::arg("seq_lens_this_time"), py::arg("seq_lens_this_time"),
py::arg("output_padding_offset"), py::arg("batch_id_per_token_output"),
py::arg("output_cum_offsets"), py::arg("cu_seqlens_q_output"),
py::arg("max_seq_len"), py::arg("max_seq_len"),
"Applies token penalty with multiple scores"); "Applies token penalty with multiple scores");
@@ -1182,7 +1225,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_dec_len"), py::arg("max_dec_len"),
py::arg("end_tokens"), py::arg("end_tokens"),
py::arg("is_block_step"), py::arg("is_block_step"),
py::arg("output_cum_offsets"), py::arg("cu_seqlens_q_output"),
py::arg("actual_candidate_len"), py::arg("actual_candidate_len"),
py::arg("actual_draft_token_nums"), py::arg("actual_draft_token_nums"),
py::arg("topp"), py::arg("topp"),
@@ -1246,6 +1289,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_seq_len"), py::arg("max_seq_len"),
"Get output padding offset"); "Get output padding offset");
m.def("speculate_pre_process",
&SpeculatePreProcess,
py::arg("cpu_token_num"),
py::arg("input_ids"),
py::arg("seq_len"),
py::arg("draft_tokens"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
"speculate pre process to remove padding and to acquire cu_seq_len");
m.def("speculate_get_padding_offset", m.def("speculate_get_padding_offset",
&SpeculateGetPaddingOffset, &SpeculateGetPaddingOffset,
py::arg("input_ids"), py::arg("input_ids"),
@@ -1419,7 +1472,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&TopPCandidates, &TopPCandidates,
py::arg("probs"), py::arg("probs"),
py::arg("top_p"), py::arg("top_p"),
py::arg("output_padding_offset"), py::arg("batch_id_per_token_output"),
py::arg("candidates_len"), py::arg("candidates_len"),
py::arg("max_seq_len"), py::arg("max_seq_len"),
"Generate top-p candidates based on probability distributions"); "Generate top-p candidates based on probability distributions");
@@ -388,8 +388,8 @@ DLL_EXPORT int speculate_token_penalty_multi_scores(
const int64_t* min_len, const int64_t* min_len,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int64_t* bad_words, const int64_t* bad_words,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -432,7 +432,7 @@ DLL_EXPORT int speculate_verify(api::Context* ctx,
const int64_t* max_dec_len, const int64_t* max_dec_len,
const int64_t* end_tokens, const int64_t* end_tokens,
const bool* is_block_step, const bool* is_block_step,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int* actual_candidate_len, const int* actual_candidate_len,
const int real_bsz, const int real_bsz,
const int max_draft_tokens, const int max_draft_tokens,
@@ -465,7 +465,7 @@ DLL_EXPORT int draft_model_update(api::Context* ctx,
int* seq_lens_encoder, int* seq_lens_encoder,
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
bool* stop_flags, bool* stop_flags,
bool* not_need_stop, bool* not_need_stop,
const int64_t* max_dec_len, const int64_t* max_dec_len,
@@ -574,7 +574,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
DLL_EXPORT int top_p_candidates(api::Context* ctx, DLL_EXPORT int top_p_candidates(api::Context* ctx,
const T* src, const T* src,
const T* top_ps, const T* top_ps,
const int* output_padding_offset, const int* batch_id_per_token_output,
int64_t* out_id, int64_t* out_id,
T* out_val, T* out_val,
int* actual_candidates_lens, int* actual_candidates_lens,
@@ -630,6 +630,24 @@ DLL_EXPORT int speculate_schedule_cache(api::Context* ctx,
const int block_num_per_seq, const int block_num_per_seq,
const bool prefill_one_step_stop); const bool prefill_one_step_stop);
DLL_EXPORT int speculate_preprocess(api::Context* ctx,
int64_t* ids_remove_padding,
int* batch_id_per_token,
int* cu_seqlens_q,
int* cu_seqlens_k,
int* seq_lens_output,
int* cu_seq_lens_q_output,
int* batch_id_per_token_output,
int* real_output_token_num,
const int64_t* input_data,
const int* seq_lens,
const int64_t* draft_tokens,
const int* seq_lens_encoder,
const int max_seq_len,
const int max_draft_tokens_per_batch,
const int token_num_data,
const int real_bs);
DLL_EXPORT int speculate_update_v3(api::Context* ctx, DLL_EXPORT int speculate_update_v3(api::Context* ctx,
int* seq_lens_encoder, int* seq_lens_encoder,
int* seq_lens_decoder, int* seq_lens_decoder,
@@ -662,6 +680,31 @@ DLL_EXPORT int speculate_update(api::Context* ctx,
const int max_bsz, const int max_bsz,
const int max_draft_tokens); const int max_draft_tokens);
DLL_EXPORT int unified_update_model_status(api::Context* ctx,
int* seq_lens_encoder,
int* seq_lens_decoder,
bool* has_running_seqs,
int* mask_rollback,
int64_t* step_input_ids,
int* adaptive_step_input_len,
int64_t* step_output_ids,
int* step_output_len,
bool* stop_flags,
int* seq_lens_this_time,
const bool* is_paused,
int64_t* token_ids_all,
const int64_t* prompt_lens,
int64_t* step_idx,
const int64_t* end_tokens,
const int64_t* max_dec_len,
int real_bsz,
int max_bsz,
int max_step_tokens,
int max_model_len,
int num_end_tokens,
bool is_naive_mode,
bool prefill_one_step_stop);
template <typename T> template <typename T>
DLL_EXPORT int rebuild_hidden_states(api::Context* ctx, DLL_EXPORT int rebuild_hidden_states(api::Context* ctx,
const T* input, const T* input,
@@ -22,7 +22,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens,
int* seq_lens_encoder, int* seq_lens_encoder,
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
bool* stop_flags, bool* stop_flags,
bool* not_need_stop, bool* not_need_stop,
const int64_t* max_dec_len, const int64_t* max_dec_len,
@@ -45,8 +45,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens,
auto* pre_ids_now = pre_ids + tid * pre_id_length; auto* pre_ids_now = pre_ids + tid * pre_id_length;
auto* base_model_draft_tokens_now = auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * max_base_model_draft_token; base_model_draft_tokens + tid * max_base_model_draft_token;
const int next_tokens_start_id = const int next_tokens_start_id = cu_seqlens_q_output[tid];
tid * max_seq_len - output_cum_offsets[tid];
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id; auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
auto seq_len_this_time = seq_lens_this_time[tid]; auto seq_len_this_time = seq_lens_this_time[tid];
auto seq_len_encoder = seq_lens_encoder[tid]; auto seq_len_encoder = seq_lens_encoder[tid];
@@ -72,8 +71,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens,
base_model_draft_tokens_now[substep + 1] = token_this_time; base_model_draft_tokens_now[substep + 1] = token_this_time;
} }
// multi_end // multi_end
if (is_in_end(token_this_time, end_ids, end_ids_len) || if (is_in_end(token_this_time, end_ids, end_ids_len)) {
prefill_one_step_stop) {
stop_flags[tid] = true; stop_flags[tid] = true;
stop_flag_now_int_sm[cid] += 1; stop_flag_now_int_sm[cid] += 1;
// max_dec_len // max_dec_len
@@ -22,7 +22,7 @@ inline __device__ void update_bad_words_logit<float16>(
template <typename T> template <typename T>
__global__ void speculate_ban_bad_words(T* logits, __global__ void speculate_ban_bad_words(T* logits,
const int64_t* bad_words_list, const int64_t* bad_words_list,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t bad_words_length, const int64_t bad_words_length,
@@ -32,7 +32,7 @@ __global__ void speculate_ban_bad_words(T* logits,
int nthreads = cluster_num() * core_num(); int nthreads = cluster_num() * core_num();
int start = -1; int start = -1;
int end = -1; int end = -1;
int output_padding_offset_lm; int batch_id_per_token_output_lm;
partition(tid, partition(tid,
nthreads, nthreads,
static_cast<int>(token_num * bad_words_length), static_cast<int>(token_num * bad_words_length),
@@ -41,10 +41,10 @@ __global__ void speculate_ban_bad_words(T* logits,
&end); &end);
for (int i = start; i < end; i++) { for (int i = start; i < end; i++) {
int token_idx = i / bad_words_length; int token_idx = i / bad_words_length;
GM2LM(output_padding_offset + token_idx, GM2LM(batch_id_per_token_output + token_idx,
&output_padding_offset_lm, &batch_id_per_token_output_lm,
sizeof(int)); sizeof(int));
int bs_idx = (token_idx + output_padding_offset_lm) / max_seq_len; int bs_idx = batch_id_per_token_output_lm;
if (bs_idx >= bs) { if (bs_idx >= bs) {
continue; continue;
} }
@@ -63,7 +63,7 @@ __global__ void speculate_ban_bad_words(T* logits,
template __global__ void speculate_ban_bad_words( \ template __global__ void speculate_ban_bad_words( \
DATA_TYPE* logits, \ DATA_TYPE* logits, \
const int64_t* bad_words_list, \ const int64_t* bad_words_list, \
const int* output_padding_offset, \ const int* batch_id_per_token_output_lm, \
const int64_t bs, \ const int64_t bs, \
const int64_t length, \ const int64_t length, \
const int64_t bad_words_length, \ const int64_t bad_words_length, \
@@ -11,8 +11,8 @@ __global__ void speculate_min_length_logits_process(
const int64_t* cur_len, const int64_t* cur_len,
const int64_t* min_len, const int64_t* min_len,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -29,26 +29,26 @@ __global__ void speculate_min_length_logits_process(
int64_t eos_token_id_now; int64_t eos_token_id_now;
int64_t bi; int64_t bi;
int64_t end_num; int64_t end_num;
int output_padding_offset_now; int batch_id_per_token_output_now;
int output_cum_offsets_now; int cu_seqlens_q_output_now;
__simd__ float float32logits_now[32]; __simd__ float float32logits_now[32];
for (int64_t i = tid; i < token_num * end_length; i += nthreads) { for (int64_t i = tid; i < token_num * end_length; i += nthreads) {
int64_t token_idx = i / end_length; int64_t token_idx = i / end_length;
GM2LM(output_padding_offset + token_idx, GM2LM(batch_id_per_token_output + token_idx,
&output_padding_offset_now, &batch_id_per_token_output_now,
sizeof(int)); sizeof(int));
bi = (token_idx + output_padding_offset_now) / max_seq_len; bi = batch_id_per_token_output[token_idx];
if (bi >= bs) { if (bi >= bs) {
continue; continue;
} }
end_num = i % end_length; end_num = i % end_length;
GM2LM_ASYNC( GM2LM_ASYNC(
output_cum_offsets + bi, (void*)&output_cum_offsets_now, sizeof(int)); cu_seqlens_q_output + bi, (void*)&cu_seqlens_q_output_now, sizeof(int));
GM2LM_ASYNC(cur_len + bi, (void*)&(cur_len_now), sizeof(int64_t)); GM2LM_ASYNC(cur_len + bi, (void*)&(cur_len_now), sizeof(int64_t));
GM2LM_ASYNC(min_len + bi, (void*)&(min_len_now), sizeof(int64_t)); GM2LM_ASYNC(min_len + bi, (void*)&(min_len_now), sizeof(int64_t));
mfence(); mfence();
int query_start_token_idx = bi * max_seq_len - output_cum_offsets_now; int query_start_token_idx = cu_seqlens_q_output_now;
if (cur_len_now >= 0 && if (cur_len_now >= 0 &&
(cur_len_now + (token_idx - query_start_token_idx) < min_len_now)) { (cur_len_now + (token_idx - query_start_token_idx) < min_len_now)) {
GM2LM( GM2LM(
@@ -74,8 +74,8 @@ __global__ void speculate_min_length_logits_process(
const int64_t* cur_len, \ const int64_t* cur_len, \
const int64_t* min_len, \ const int64_t* min_len, \
const int64_t* eos_token_id, \ const int64_t* eos_token_id, \
const int* output_padding_offset, \ const int* batch_id_per_token_output, \
const int* output_cum_offsets, \ const int* cu_seqlens_q_output, \
const int64_t bs, \ const int64_t bs, \
const int64_t length, \ const int64_t length, \
const int64_t length_id, \ const int64_t length_id, \
@@ -0,0 +1,157 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_debug.h"
namespace fd_xpu3 {
#define MAX_BATCH_SIZE 1024
// Horizontal sum of a 16-lane int32 SIMD register.
// Folds lanes with a shift-and-add tree (strides 8, 4, 2, 1 expressed via
// the vsrlp permute-shift intrinsic), then extracts the accumulated lane.
// NOTE: mutates v0 in place (it is the reduction accumulator).
// NOTE(review): exact vsrlp/vextract lane semantics are assumed from the
// XPU ISA — verify the `1 << k` stride encoding against the intrinsic docs.
static inline __device__ int v_reduce_sum_int32(int32x16_t& v0) {
  auto v1 = vsrlp_int32x16(1 << 8, v0);
  v0 = vvadd_int32x16(v0, v1);
  v1 = vsrlp_int32x16(1 << 7, v0);
  v0 = vvadd_int32x16(v0, v1);
  v1 = vsrlp_int32x16(1 << 6, v0);
  v0 = vvadd_int32x16(v0, v1);
  v1 = vsrlp_int32x16(1 << 5, v0);
  v0 = vvadd_int32x16(v0, v1);
  return vextract_int32x16(v0, 1);
}
// Sums `len` int32 values read from shared memory.
// Processes 32 elements per iteration (two 16-lane vector loads via
// vload2_sm) into a vector accumulator; the sub-32 tail uses a masked load
// (vload2_sm_mz zero-fills unselected lanes). The final scalar comes from
// v_reduce_sum_int32.
// NOTE(review): the tail mask `~(-1 << n)` left-shifts a negative signed
// value — well-defined only since C++20; consider `~(~0u << n)` if this
// TU may build under an older standard.
inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int* x,
                                              int64_t len) {
  int32x16_t x_l, x_h;
  int32x16_t sum = vset_zero_int();
  const auto rounddown_len = rounddown32(len);
  for (int64_t i = 0; i < rounddown_len; i += 32) {
    vload2_sm(x + i, x_l, x_h);
    sum = vvadd_int32x16(sum, x_l);
    sum = vvadd_int32x16(sum, x_h);
  }
  if (rounddown_len < len) {
    // Mask selects the remaining (len - rounddown_len) < 32 lanes.
    const auto mask = ~(-1 << (len - rounddown_len));
    vload2_sm_mz(x + rounddown_len, x_l, x_h, mask);
    sum = vvadd_int32x16(sum, x_l);
    sum = vvadd_int32x16(sum, x_h);
  }
  return v_reduce_sum_int32(sum);
}
// Speculative-decoding pre-process kernel.
// In one launch it (a) removes padding from input ids (gathering draft
// tokens for speculative decode steps), (b) builds the batch-id-per-token
// maps, and (c) computes the cumulative sequence-length arrays
// (cu_seqlens_q/k and their "output" counterparts) plus the total output
// token count.
//
// Outputs:
//   ids_remove_padding        packed token ids, one entry per real token
//   batch_id_per_token        batch index for each packed input token
//   cu_seqlens_q/cu_seqlens_k inclusive prefix sums of seq_lens (entry 0 = 0)
//   cu_seq_lens_q_output      prefix sums of per-batch OUTPUT lengths
//   batch_id_per_token_output batch index for each output token
//   real_output_token_num     total number of output tokens
//   NOTE(review): the `seq_lens_output` parameter is never written by this
//   kernel — confirm callers do not rely on it being filled here.
//
// Assumes real_bs <= MAX_BATCH_SIZE (shared arrays are fixed-size; the
// host side enforces <= 1024 — TODO confirm all launch sites do).
__global__ void speculate_preprocess_kernel(
    int64_t* ids_remove_padding,
    int* batch_id_per_token,
    int* cu_seqlens_q,
    int* cu_seqlens_k,
    int* seq_lens_output,
    int* cu_seq_lens_q_output,
    int* batch_id_per_token_output,
    int* real_output_token_num,
    const int64_t* input_data,
    const int* seq_lens,
    const int64_t* draft_tokens,
    const int* seq_lens_encoder,
    const int max_seq_len,
    const int max_draft_tokens_per_batch,
    const int real_bs) {
  int cid = core_id();
  int ncores = core_num();
  int clusterid = cluster_id();
  int nclusters = cluster_num();
  // Per-cluster shared staging of the per-batch length arrays.
  __shared__ int sm_seq_lens[MAX_BATCH_SIZE];
  __shared__ int sm_seq_lens_output[MAX_BATCH_SIZE];
  __shared__ int sm_seq_lens_encoder[MAX_BATCH_SIZE];
  __shared__ int sm_cum_seq_len, sm_cum_seq_len_output;
  // Scratch for the cross-core partial sums of the prefix-sum phase.
  __simd__ __shared__ int buffer_cu_seqlens[64];
  __simd__ __shared__ int buffer_cu_seqlens_output[64];
  // Phase 1: core 0 of each cluster loads seq_lens / seq_lens_encoder
  // from global to shared memory.
  if (cid == 0) {
    GM2SM_ASYNC(seq_lens, sm_seq_lens, sizeof(int) * real_bs);
    GM2SM(seq_lens_encoder, sm_seq_lens_encoder, sizeof(int) * real_bs);
  }
  sync_all();
  // Phase 2: per-batch OUTPUT length — 0 for idle slots, 1 for
  // single-token or prefill (encoder) steps, otherwise the full
  // seq_len (speculative decode emits one output per draft token).
  for (int bid = cid; bid < real_bs; bid += ncores) {
    if (sm_seq_lens[bid] == 0) {
      sm_seq_lens_output[bid] = 0;
    } else if (sm_seq_lens[bid] == 1) {
      sm_seq_lens_output[bid] = 1;
    } else if (sm_seq_lens_encoder[bid] != 0) {
      sm_seq_lens_output[bid] = 1;
    } else {
      sm_seq_lens_output[bid] = sm_seq_lens[bid];
    }
  }
  mfence_sm();
  sync_all();
  // Phase 3: batches are distributed across clusters; for each batch bi the
  // cluster cooperatively computes the INCLUSIVE prefix sums over [0, bi]
  // (cores accumulate strided partials, core 0 reduces them) and writes
  // cu_seqlens_*[bi + 1]; then cores gather that batch's tokens.
  for (int bi = clusterid; bi < real_bs; bi += nclusters) {
    int cum_seq_len = 0;
    int cum_seq_len_output = 0;
    for (int i = cid; i < bi + 1; i += ncores) {
      cum_seq_len += sm_seq_lens[i];
      cum_seq_len_output += sm_seq_lens_output[i];
    }
    buffer_cu_seqlens[cid] = cum_seq_len;
    buffer_cu_seqlens_output[cid] = cum_seq_len_output;
    mfence();
    sync_all();
    if (cid == 0) {
      // Only the first min(bi + 1, ncores) cores produced non-zero partials.
      cum_seq_len =
          primitive_reduce_sum_sm(buffer_cu_seqlens, min(bi + 1, ncores));
      cum_seq_len_output = primitive_reduce_sum_sm(buffer_cu_seqlens_output,
                                                   min(bi + 1, ncores));
      LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int));
      LM2GM_ASYNC(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int));
      LM2GM_ASYNC(
          &cum_seq_len_output, cu_seq_lens_q_output + bi + 1, sizeof(int));
      // The last batch's inclusive sum is the total output token count.
      if (bi == real_bs - 1) {
        LM2GM_ASYNC(&cum_seq_len_output, real_output_token_num, sizeof(int));
      }
      // Publish the totals so all cores can compute token offsets below.
      sm_cum_seq_len = cum_seq_len;
      sm_cum_seq_len_output = cum_seq_len_output;
    }
    mfence();
    sync_all();
    const int lm_seq_lens = sm_seq_lens[bi];
    const int lm_seq_lens_encoder = sm_seq_lens_encoder[bi];
    // Phase 4: gather this batch's tokens into the packed array; the batch's
    // slot starts at (inclusive prefix sum) - (this batch's length).
    for (int i = cid; i < lm_seq_lens; i += ncores) {
      const int tgt_seq_id = sm_cum_seq_len - lm_seq_lens + i;
      if (max_draft_tokens_per_batch > 0 && lm_seq_lens_encoder <= 0) {
        // speculative decoding: tokens come from the draft_tokens buffer,
        // laid out [bs, max_draft_tokens_per_batch]
        const int src_seq_id = bi * max_draft_tokens_per_batch + i;
        int64_t lm_draft_tokens;
        GM2LM(draft_tokens + src_seq_id, &lm_draft_tokens, sizeof(int64_t));
        LM2GM(
            &lm_draft_tokens, ids_remove_padding + tgt_seq_id, sizeof(int64_t));
      } else {
        // Non-speculative decoding: tokens come from padded input_data,
        // laid out [bs, max_seq_len]
        const int src_seq_id = bi * max_seq_len + i;
        int64_t lm_input_data;
        GM2LM(input_data + src_seq_id, &lm_input_data, sizeof(int64_t));
        LM2GM(&lm_input_data, ids_remove_padding + tgt_seq_id, sizeof(int64_t));
      }
      LM2GM(&bi, batch_id_per_token + tgt_seq_id, sizeof(int));
    }
    // Phase 5: same mapping for the output-token index space.
    const int lm_seq_lens_output = sm_seq_lens_output[bi];
    for (int i = cid; i < lm_seq_lens_output; i += ncores) {
      const int tgt_seq_id_output =
          sm_cum_seq_len_output - lm_seq_lens_output + i;
      LM2GM(&bi, batch_id_per_token_output + tgt_seq_id_output, sizeof(int));
    }
    mfence();
    sync_all();
  }
  // Prefix-sum arrays start at 0; a single core writes the leading entries.
  if (cid == 0 && clusterid == 0) {
    const int lm_zero = 0;
    LM2GM_ASYNC(&lm_zero, cu_seqlens_q, sizeof(int));
    LM2GM_ASYNC(&lm_zero, cu_seqlens_k, sizeof(int));
    LM2GM(&lm_zero, cu_seq_lens_q_output, sizeof(int));
  }
}
} // namespace fd_xpu3
@@ -22,7 +22,7 @@ __device__ void speculate_update_repeat_times_normal(
__global_ptr__ const int64_t *pre_ids, __global_ptr__ const int64_t *pre_ids,
__global_ptr__ const int64_t *cur_len, __global_ptr__ const int64_t *cur_len,
__global_ptr__ int *repeat_times, __global_ptr__ int *repeat_times,
__global_ptr__ const int *output_padding_offset, __global_ptr__ const int *batch_id_per_token_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -40,15 +40,17 @@ __device__ void speculate_update_repeat_times_normal(
int n_length = (length + max_sm_len - 1) / max_sm_len; int n_length = (length + max_sm_len - 1) / max_sm_len;
int64_t *cur_len_lm = (int64_t *)lm; int64_t *cur_len_lm = (int64_t *)lm;
int output_padding_offset_now; int batch_id_per_token_output_now;
GM2LM(cur_len, cur_len_lm, bs * sizeof(int64_t)); GM2LM(cur_len, cur_len_lm, bs * sizeof(int64_t));
for (int nli = 0; nli < n_length; nli++) { for (int nli = 0; nli < n_length; nli++) {
int step = nli * max_sm_len; int step = nli * max_sm_len;
int cur_length = min(max_sm_len, length - step); int cur_length = min(max_sm_len, length - step);
for (int64_t i = clusterid; i < token_num; i += nclusters) { for (int64_t i = clusterid; i < token_num; i += nclusters) {
GM2LM(output_padding_offset + i, &output_padding_offset_now, sizeof(int)); GM2LM(batch_id_per_token_output + i,
int64_t bi = (i + output_padding_offset_now) / max_seq_len; &batch_id_per_token_output_now,
sizeof(int));
int64_t bi = batch_id_per_token_output_now;
if (bi >= bs || cur_len_lm[bi] < 0) { if (bi >= bs || cur_len_lm[bi] < 0) {
continue; continue;
} }
@@ -86,10 +88,10 @@ __device__ void speculate_update_repeat_times_normal(
__device__ void speculate_update_repeat_times_optimized( __device__ void speculate_update_repeat_times_optimized(
char *lm, char *lm,
__shared_ptr__ char *sm, __shared_ptr__ char *sm,
__global_ptr__ const int64_t *pre_ids, // {bs, length_id} __global_ptr__ const int64_t *pre_ids, // {bs, length_id}
__global_ptr__ const int64_t *cur_len, // {bs} __global_ptr__ const int64_t *cur_len, // {bs}
__global_ptr__ int *repeat_times, // {token_num, length} __global_ptr__ int *repeat_times, // {token_num, length}
__global_ptr__ const int *output_padding_offset, // {token_num} __global_ptr__ const int *batch_id_per_token_output, // {token_num}
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -108,10 +110,10 @@ __device__ void speculate_update_repeat_times_optimized(
int cur_len_sm_len = 640; int cur_len_sm_len = 640;
__shared_ptr__ int64_t *cur_len_sm = __shared_ptr__ int64_t *cur_len_sm =
(__shared_ptr__ int64_t *)(repeat_times_sm + repeat_times_sm_len); (__shared_ptr__ int64_t *)(repeat_times_sm + repeat_times_sm_len);
__shared_ptr__ int *output_padding_offset_sm = __shared_ptr__ int *batch_id_per_token_output_sm =
(__shared_ptr__ int *)(cur_len_sm + cur_len_sm_len); (__shared_ptr__ int *)(cur_len_sm + cur_len_sm_len);
DoublePtr<1, SmPtr<int>> buffer_ptr_output_padding_offset( DoublePtr<1, SmPtr<int>> buffer_ptr_batch_id_per_token_output(
(SmPtr<int>(output_padding_offset_sm))); (SmPtr<int>(batch_id_per_token_output_sm)));
int pre_ids_lm_len = 4; int pre_ids_lm_len = 4;
int64_t *pre_ids_lm = (int64_t *)lm; int64_t *pre_ids_lm = (int64_t *)lm;
DoublePtr<4, LmPtr<int64_t>> buffer_ptr_pre_ids((LmPtr<int64_t>(pre_ids_lm))); DoublePtr<4, LmPtr<int64_t>> buffer_ptr_pre_ids((LmPtr<int64_t>(pre_ids_lm)));
@@ -119,18 +121,18 @@ __device__ void speculate_update_repeat_times_optimized(
int64_t i = clusterid; int64_t i = clusterid;
if (i < token_num && cid == 0) { if (i < token_num && cid == 0) {
GM2SM_ASYNC(cur_len, cur_len_sm, bs * sizeof(int64_t)); GM2SM_ASYNC(cur_len, cur_len_sm, bs * sizeof(int64_t));
buffer_ptr_output_padding_offset.gm_load_async(output_padding_offset + i, buffer_ptr_batch_id_per_token_output.gm_load_async(
1); batch_id_per_token_output + i, 1);
mfence_sm(); mfence_sm();
} }
sync_all(); sync_all();
for (; i < token_num; i += nclusters) { for (; i < token_num; i += nclusters) {
if (cid == 0 && i + nclusters < token_num) { if (cid == 0 && i + nclusters < token_num) {
buffer_ptr_output_padding_offset.next().gm_load_async( buffer_ptr_batch_id_per_token_output.next().gm_load_async(
output_padding_offset + i + nclusters, 1); batch_id_per_token_output + i + nclusters, 1);
} }
int64_t bi = (i + (buffer_ptr_output_padding_offset.ptr[0])) / max_seq_len; int64_t bi = buffer_ptr_batch_id_per_token_output.ptr[0];
buffer_ptr_output_padding_offset.toggle(); buffer_ptr_batch_id_per_token_output.toggle();
if (bi >= bs || cur_len_sm[bi] < 0) { if (bi >= bs || cur_len_sm[bi] < 0) {
mfence_sm(); mfence_sm();
sync_all(); sync_all();
@@ -224,15 +226,16 @@ __device__ void speculate_update_repeat_times_optimized(
} }
} }
__global__ void speculate_update_repeat_times(const int64_t *pre_ids, __global__ void speculate_update_repeat_times(
const int64_t *cur_len, const int64_t *pre_ids,
int *repeat_times, const int64_t *cur_len,
const int *output_padding_offset, int *repeat_times,
const int64_t bs, const int *batch_id_per_token_output,
const int64_t length, const int64_t bs,
const int64_t length_id, const int64_t length,
const int64_t token_num, const int64_t length_id,
const int64_t max_seq_len) { const int64_t token_num,
const int64_t max_seq_len) {
char lm[6 * 1024]; char lm[6 * 1024];
__shared__ char sm[256 * 1024]; __shared__ char sm[256 * 1024];
@@ -242,7 +245,7 @@ __global__ void speculate_update_repeat_times(const int64_t *pre_ids,
pre_ids, pre_ids,
cur_len, cur_len,
repeat_times, repeat_times,
output_padding_offset, batch_id_per_token_output,
bs, bs,
length, length,
length_id, length_id,
@@ -254,7 +257,7 @@ __global__ void speculate_update_repeat_times(const int64_t *pre_ids,
pre_ids, pre_ids,
cur_len, cur_len,
repeat_times, repeat_times,
output_padding_offset, batch_id_per_token_output,
bs, bs,
length, length,
length_id, length_id,
@@ -25,7 +25,7 @@ __global__ void speculate_update_value_by_repeat_times(
const T *presence_score, const T *presence_score,
const float *temperatures, const float *temperatures,
T *logits, T *logits,
const int *output_padding_offset, const int *batch_id_per_token_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t token_num, const int64_t token_num,
@@ -46,17 +46,16 @@ __global__ void speculate_update_value_by_repeat_times(
if (token_end >= token_num) { if (token_end >= token_num) {
token_end = token_num - 1; token_end = token_num - 1;
} }
int output_padding_offset_start_lm; int batch_id_per_token_output_start_lm;
int output_padding_offset_end_lm; int batch_id_per_token_output_end_lm;
GM2LM_ASYNC(output_padding_offset + token_start, GM2LM_ASYNC(batch_id_per_token_output + token_start,
(void *)&output_padding_offset_start_lm, (void *)&batch_id_per_token_output_start_lm,
sizeof(int)); sizeof(int));
GM2LM(output_padding_offset + token_end, GM2LM(batch_id_per_token_output + token_end,
(void *)&output_padding_offset_end_lm, (void *)&batch_id_per_token_output_end_lm,
sizeof(int)); sizeof(int));
int64_t bs_start = int64_t bs_start = batch_id_per_token_output_start_lm;
(token_start + output_padding_offset_start_lm) / max_seq_len; int64_t bs_end = batch_id_per_token_output_end_lm;
int64_t bs_end = (token_end + output_padding_offset_end_lm) / max_seq_len;
const int param_len = 256; const int param_len = 256;
// ncores = 64 for xpu2 // ncores = 64 for xpu2
__shared__ __simd__ float alpha_buf[param_len * 64]; __shared__ __simd__ float alpha_buf[param_len * 64];
@@ -89,13 +88,13 @@ __global__ void speculate_update_value_by_repeat_times(
const int buffer_len = 512; const int buffer_len = 512;
__simd__ float logits_lm[buffer_len]; __simd__ float logits_lm[buffer_len];
int times_lm[buffer_len]; int times_lm[buffer_len];
int output_padding_offset_lm[buffer_len]; int batch_id_per_token_output_lm[buffer_len];
for (int64_t i = start; i < end; i += buffer_len) { for (int64_t i = start; i < end; i += buffer_len) {
int read_len = min(end - i, buffer_len); int read_len = min(end - i, buffer_len);
GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T)); GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T));
GM2LM_ASYNC(output_padding_offset + i / length, GM2LM_ASYNC(batch_id_per_token_output + i / length,
output_padding_offset_lm, batch_id_per_token_output_lm,
((read_len + length - 1) / length + 1) * sizeof(int)); ((read_len + length - 1) / length + 1) * sizeof(int));
GM2LM(repeat_times + i, times_lm, read_len * sizeof(int)); GM2LM(repeat_times + i, times_lm, read_len * sizeof(int));
primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len); primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len);
@@ -104,7 +103,7 @@ __global__ void speculate_update_value_by_repeat_times(
logit_now = logits_lm[j]; logit_now = logits_lm[j];
int token_idx = (i + j) / length; int token_idx = (i + j) / length;
int bs_idx = int bs_idx =
(token_idx + output_padding_offset_lm[token_idx - i / length]) / (token_idx + batch_id_per_token_output_lm[token_idx - i / length]) /
max_seq_len; max_seq_len;
if (bs_idx >= bs) { if (bs_idx >= bs) {
continue; continue;
@@ -134,7 +133,7 @@ __global__ void speculate_update_value_by_repeat_times(
const DATA_TYPE *presence_score, \ const DATA_TYPE *presence_score, \
const float *temperatures, \ const float *temperatures, \
DATA_TYPE *logits, \ DATA_TYPE *logits, \
const int *output_padding_offset, \ const int *batch_id_per_token_output, \
const int64_t bs, \ const int64_t bs, \
const int64_t length, \ const int64_t length, \
const int64_t token_num, \ const int64_t token_num, \
@@ -151,7 +150,7 @@ __global__ void speculate_update_value_by_repeat_times_simd(
const T *presence_score, // [bs] const T *presence_score, // [bs]
const float *temperatures, // [bs] const float *temperatures, // [bs]
T *logits, // [bs * length] T *logits, // [bs * length]
const int *output_padding_offset, const int *batch_id_per_token_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t token_num, const int64_t token_num,
@@ -198,7 +197,7 @@ __global__ void speculate_update_value_by_repeat_times_simd(
const int buffer_len = 512; const int buffer_len = 512;
__simd__ float logits_lm[buffer_len]; __simd__ float logits_lm[buffer_len];
__simd__ float times_lm[buffer_len]; __simd__ float times_lm[buffer_len];
int output_padding_offset_lm[buffer_len]; int batch_id_per_token_output_lm[buffer_len];
float32x16_t logits_; float32x16_t logits_;
float32x16_t logits_tmp_0; float32x16_t logits_tmp_0;
@@ -208,8 +207,8 @@ __global__ void speculate_update_value_by_repeat_times_simd(
for (int64_t i = start; i < end; i += buffer_len) { for (int64_t i = start; i < end; i += buffer_len) {
int read_len = min(end - i, buffer_len); int read_len = min(end - i, buffer_len);
GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T)); GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T));
GM2LM_ASYNC(output_padding_offset + i / length, GM2LM_ASYNC(batch_id_per_token_output + i / length,
output_padding_offset_lm, batch_id_per_token_output_lm,
((read_len + length - 1) / length + 1) * sizeof(int)); ((read_len + length - 1) / length + 1) * sizeof(int));
GM2LM(repeat_times + i, times_lm, read_len * sizeof(int)); GM2LM(repeat_times + i, times_lm, read_len * sizeof(int));
primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len); primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len);
@@ -220,9 +219,7 @@ __global__ void speculate_update_value_by_repeat_times_simd(
time_ = vload_lm_float32x16(times_lm + j); time_ = vload_lm_float32x16(times_lm + j);
logits_ = vload_lm_float32x16(logits_lm + j); logits_ = vload_lm_float32x16(logits_lm + j);
int token_idx = (i + j) / length; int token_idx = (i + j) / length;
int bs_idx = int bs_idx = batch_id_per_token_output_lm[token_idx - i / length];
(token_idx + output_padding_offset_lm[token_idx - i / length]) /
max_seq_len;
if (bs_idx >= bs) { if (bs_idx >= bs) {
continue; continue;
} }
@@ -269,7 +266,7 @@ __global__ void speculate_update_value_by_repeat_times_simd(
const DATA_TYPE *presence_score, \ const DATA_TYPE *presence_score, \
const float *temperatures, \ const float *temperatures, \
DATA_TYPE *logits, \ DATA_TYPE *logits, \
const int *output_padding_offset, \ const int *batch_id_per_token_output, \
const int64_t bs, \ const int64_t bs, \
const int64_t length, \ const int64_t length, \
const int64_t token_num, \ const int64_t token_num, \
@@ -95,50 +95,31 @@ topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids,
template <bool ENABLE_TOPP, bool USE_TOPK> template <bool ENABLE_TOPP, bool USE_TOPK>
__global__ void speculate_verify( __global__ void speculate_verify(
const int64_t *sampled_token_ids, const int64_t *sampled_token_ids,
int64_t *accept_tokens, // out [real_bsz, max_draft_tokens], 输出最终接收的 int64_t *accept_tokens, // out [real_bsz, max_draft_tokens]
// token(通过验证或采样) int *accept_num, // out [real_bsz],
int *accept_num, // out [real_bsz], 每个序列最终接受的 token int64_t *step_idx, // out [real_bsz],
// 数量(只统计通过验证的) bool *stop_flags, // out [real_bsz],
int64_t const int *seq_lens_encoder, // [real_bsz]
*step_idx, // out [real_bsz], 记录每个bid序列已经生成或接受的token数 const int *seq_lens_decoder, // [real_bsz]
bool *stop_flags, // out [real_bsz], 每个序列的停止标志,遇到 <eos> const int64_t *draft_tokens, // [real_bsz, max_draft_tokens],
// 或长度超限时置 true const int *actual_draft_token_nums, // [real_bsz], 实际有效的 token 数量
const int *seq_lens_encoder, // [real_bsz], 每个样本 encoder
// 输入长度,用于判断 prefill 阶段
const int *seq_lens_decoder, // [real_bsz], 每个样本 decoder 输出的 token
// 数(即 draft token 数)
const int64_t *
draft_tokens, // [real_bsz, max_draft_tokens], draft model 输出的 token
const int *actual_draft_token_nums, // [real_bsz], draft_tokens
// 中实际有效的 token 数量
const float *dev_curand_states, // used for random const float *dev_curand_states, // used for random
const float *topp, // [real_bsz]TopP 阈值(如 const float *topp, // [real_bsz]
// 0.9),用于控制核采样截断概率和候选数 const int *seq_lens_this_time, // [real_bsz],
const int *seq_lens_this_time, // [real_bsz], 本轮 verify
// 阶段每个样本实际参与验证的 token 数
const int64_t const int64_t
*verify_tokens, // [sum(seq_lens_this_time), max_candidate_len], verify *verify_tokens, // [sum(seq_lens_this_time), max_candidate_len]
// decoder 输出的候选 token const float *verify_scores,
const float
*verify_scores, // 同上, 每个 verify token 对应的概率分布,用于采样
const int64_t *max_dec_len, // [real_bsz], const int64_t *max_dec_len, // [real_bsz],
// 每个样本允许生成的最大长度(超过则触发终止) const int64_t *end_tokens, // [end_length]
const int64_t const bool *is_block_step, // [real_bsz],
*end_tokens, // [end_length], 终止 token 列表(如 <eos>),命中即终止 const int *cu_seqlens_q_output,
const bool *is_block_step, // [real_bsz], 指示是否当前为 block step(为 const int *actual_candidate_len, // [sum(seq_lens_this_time)],
// true 时跳过 verify const int real_bsz, // batch size
const int const int max_draft_tokens,
*output_cum_offsets, // [real_bsz], verify_tokens 的起始偏移,用于定位
// token 所在 verify 索引
const int *actual_candidate_len, // [sum(seq_lens_this_time)], 每个 verify
// token 实际可用候选数(用于 TopP 截断)
const int real_bsz, // batch size
const int max_draft_tokens, // scalar, 每个样本最多允许的 draft token 数
const int end_length, const int end_length,
const int max_seq_len, // scalar, 每个序列的最大 token 数(用于偏移计算) const int max_seq_len,
const int max_candidate_len, // scalar, 每个 verify token const int max_candidate_len,
// 的最大候选数(用于验证或采样) const int verify_window,
const int verify_window, // scalar, TopK 验证窗口(允许连续 top1 匹配次数)
const bool prefill_one_step_stop, const bool prefill_one_step_stop,
const bool benchmark_mode, const bool benchmark_mode,
const bool accept_all_drafts, const bool accept_all_drafts,
@@ -151,7 +132,7 @@ __global__ void speculate_verify(
if (is_block_step[bid]) { if (is_block_step[bid]) {
continue; continue;
} }
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; const int start_token_id = cu_seqlens_q_output[bid];
if (stop_flags[bid]) { if (stop_flags[bid]) {
stop_flag_now_int = 1; stop_flag_now_int = 1;
} else { // 这里prefill阶段也会进入,但是因为draft } else { // 这里prefill阶段也会进入,但是因为draft
@@ -10,7 +10,7 @@ __device__ void top_p_candidates_big_n(
char* lm, char* lm,
__global_ptr__ const T* src, __global_ptr__ const T* src,
__global_ptr__ const T* top_ps, __global_ptr__ const T* top_ps,
__global_ptr__ const int* output_padding_offset, __global_ptr__ const int* batch_id_per_token_output,
__global_ptr__ int64_t* out_id, __global_ptr__ int64_t* out_id,
__global_ptr__ T* out_val, __global_ptr__ T* out_val,
__global_ptr__ int* actual_candidates_lens, __global_ptr__ int* actual_candidates_lens,
@@ -32,11 +32,13 @@ __device__ void top_p_candidates_big_n(
__shared__ T sm_out_val[64 * TopPBeamTopK]; __shared__ T sm_out_val[64 * TopPBeamTopK];
// only used in core 0 // only used in core 0
int lm_output_padding_offset; int lm_batch_id_per_token_output;
for (int64_t i = cluster_id(); i < token_num; i += cluster_num()) { for (int64_t i = cluster_id(); i < token_num; i += cluster_num()) {
if (cid == 0) { if (cid == 0) {
GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int)); GM2LM(batch_id_per_token_output + i,
&lm_batch_id_per_token_output,
sizeof(int));
} }
for (int64_t j = 0; j < TopPBeamTopK; j++) { for (int64_t j = 0; j < TopPBeamTopK; j++) {
lm_out_id[j] = -1; lm_out_id[j] = -1;
@@ -142,8 +144,7 @@ __device__ void top_p_candidates_big_n(
} }
} }
int ori_token_id = i + lm_output_padding_offset; int bid = lm_batch_id_per_token_output;
int bid = ori_token_id / max_seq_len;
T lm_top_p; T lm_top_p;
GM2LM(top_ps + bid, &lm_top_p, sizeof(T)); GM2LM(top_ps + bid, &lm_top_p, sizeof(T));
float top_p_value = static_cast<float>(lm_top_p); float top_p_value = static_cast<float>(lm_top_p);
@@ -182,7 +183,7 @@ __device__ void top_p_candidates_normal(
char* lm, char* lm,
__global_ptr__ const T* src, __global_ptr__ const T* src,
__global_ptr__ const T* top_ps, __global_ptr__ const T* top_ps,
__global_ptr__ const int* output_padding_offset, __global_ptr__ const int* batch_id_per_token_output,
__global_ptr__ int64_t* out_id, __global_ptr__ int64_t* out_id,
__global_ptr__ T* out_val, __global_ptr__ T* out_val,
__global_ptr__ int* actual_candidates_lens, __global_ptr__ int* actual_candidates_lens,
@@ -200,7 +201,7 @@ __device__ void top_p_candidates_normal(
int64_t lm_out_id[TopPBeamTopK]; int64_t lm_out_id[TopPBeamTopK];
T lm_out_val[TopPBeamTopK]; T lm_out_val[TopPBeamTopK];
int lm_output_padding_offset; int lm_batch_id_per_token_output;
T lm_top_p; T lm_top_p;
int64_t default_id = 0; int64_t default_id = 0;
T default_val = static_cast<T>(0.f); T default_val = static_cast<T>(0.f);
@@ -236,9 +237,10 @@ __device__ void top_p_candidates_normal(
} }
mfence_lm(); mfence_lm();
} }
GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int)); GM2LM(batch_id_per_token_output + i,
int ori_token_id = i + lm_output_padding_offset; &lm_batch_id_per_token_output,
int bid = ori_token_id / max_seq_len; sizeof(int));
int bid = lm_batch_id_per_token_output;
GM2LM(top_ps + bid, &lm_top_p, sizeof(T)); GM2LM(top_ps + bid, &lm_top_p, sizeof(T));
float top_p_value = static_cast<float>(lm_top_p); float top_p_value = static_cast<float>(lm_top_p);
bool set_to_default_val = false; bool set_to_default_val = false;
@@ -272,7 +274,7 @@ __device__ void top_p_candidates_normal(
template <typename T, int MaxLength, int TopPBeamTopK> template <typename T, int MaxLength, int TopPBeamTopK>
__global__ void top_p_candidates(const T* src, __global__ void top_p_candidates(const T* src,
const T* top_ps, const T* top_ps,
const int* output_padding_offset, const int* batch_id_per_token_output,
int64_t* out_id, int64_t* out_id,
T* out_val, T* out_val,
int* actual_candidates_lens, int* actual_candidates_lens,
@@ -284,29 +286,31 @@ __global__ void top_p_candidates(const T* src,
if (token_num % (core_num() * cluster_num()) != 0 && if (token_num % (core_num() * cluster_num()) != 0 &&
vocab_size >= core_num() * (6 * 1024 / sizeof(T)) && vocab_size >= core_num() * (6 * 1024 / sizeof(T)) &&
vocab_size >= core_num() * TopPBeamTopK) { vocab_size >= core_num() * TopPBeamTopK) {
top_p_candidates_big_n<T, MaxLength, TopPBeamTopK>(lm, top_p_candidates_big_n<T, MaxLength, TopPBeamTopK>(
src, lm,
top_ps, src,
output_padding_offset, top_ps,
out_id, batch_id_per_token_output,
out_val, out_id,
actual_candidates_lens, out_val,
vocab_size, actual_candidates_lens,
token_num, vocab_size,
max_cadidate_len, token_num,
max_seq_len); max_cadidate_len,
max_seq_len);
} else { } else {
top_p_candidates_normal<T, MaxLength, TopPBeamTopK>(lm, top_p_candidates_normal<T, MaxLength, TopPBeamTopK>(
src, lm,
top_ps, src,
output_padding_offset, top_ps,
out_id, batch_id_per_token_output,
out_val, out_id,
actual_candidates_lens, out_val,
vocab_size, actual_candidates_lens,
token_num, vocab_size,
max_cadidate_len, token_num,
max_seq_len); max_cadidate_len,
max_seq_len);
} }
} }
@@ -314,7 +318,7 @@ __global__ void top_p_candidates(const T* src,
template __global__ void top_p_candidates<T, MaxLength, TopPBeamTopK>( \ template __global__ void top_p_candidates<T, MaxLength, TopPBeamTopK>( \
const T* src, \ const T* src, \
const T* top_ps, \ const T* top_ps, \
const int* output_padding_offset, \ const int* batch_id_per_token_output, \
int64_t* out_id, \ int64_t* out_id, \
T* out_val, \ T* out_val, \
int* actual_candidates_lens, \ int* actual_candidates_lens, \
@@ -0,0 +1,217 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_debug.h"
namespace fd_xpu3 {
#define MAX_BATCH_SIZE 1024
// Fold the 16 int32 lanes of v0 into a single scalar sum.
// Each step shifts the vector (presumably a lane-wise shift — confirm
// vsrlp semantics against the XPU intrinsic reference) and adds it back,
// halving the number of distinct partial sums; the total is then read
// from lane 1. NOTE: v0 is clobbered (passed by reference).
static inline __device__ int v_reduce_sum_int32(int32x16_t &v0) {
  for (int shift = 1 << 8; shift >= (1 << 5); shift >>= 1) {
    auto shifted = vsrlp_int32x16(shift, v0);
    v0 = vvadd_int32x16(v0, shifted);
  }
  return vextract_int32x16(v0, 1);
}
// Sum `len` int32 elements that reside in shared memory (SM).
// Main loop consumes 32 elements per iteration via a paired vector load;
// the tail is handled with a masked load (the _mz variant presumably
// zero-fills masked-out lanes — confirm against the XPU intrinsic docs),
// and the accumulator vector is finally folded into a scalar.
inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int *x,
                                              int64_t len) {
  int32x16_t x_l, x_h;
  int32x16_t sum = vset_zero_int();
  const auto rounddown_len = rounddown32(len);
  for (int64_t i = 0; i < rounddown_len; i += 32) {
    vload2_sm(x + i, x_l, x_h);
    sum = vvadd_int32x16(sum, x_l);
    sum = vvadd_int32x16(sum, x_h);
  }
  if (rounddown_len < len) {
    // Keep only the (len - rounddown_len) valid trailing lanes.
    const auto mask = ~(-1 << (len - rounddown_len));
    vload2_sm_mz(x + rounddown_len, x_l, x_h, mask);
    sum = vvadd_int32x16(sum, x_l);
    sum = vvadd_int32x16(sum, x_h);
  }
  return v_reduce_sum_int32(sum);
}
// Return true iff `token` matches one of the end-of-sequence tokens.
inline __device__ bool is_end_token(int64_t token,
                                    __shared_ptr__ const int64_t *end_tokens,
                                    int num_end_tokens) {
  bool found = false;
#pragma unroll 4
  for (int idx = 0; idx < num_end_tokens; ++idx) {
    found = found || (end_tokens[idx] == token);
  }
  return found;
}
// Unified post-step model-status update. Only cluster 0 participates;
// batch slots are strided across its cores.
//
// Per running batch slot (neither stopped nor paused):
//  * scan the tokens produced this step for an EOS token or a
//    max_dec_len hit, truncating output_len and raising the stop flag;
//  * advance step_idx / seq_lens_decoder, fold a pending encoder
//    (prefill) length into the decoder length, record mask_rollback;
//  * append the emitted tokens to token_ids_all (generation history)
//    and seed step_input_ids with the last emitted token;
//  * in naive mode, set seq_lens_this_time to 1 (or 0 once stopped).
// Stopped / paused / padding slots are only counted so that
// has_running_seqs[0] can report whether any slot is still active.
//
// NOTE(review): shared-buffer sizes assume max_bsz <= MAX_BATCH_SIZE,
// num_end_tokens <= MAX_BATCH_SIZE and core_num() <= 64 — confirm
// against the launch configuration.
__global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
                                                   int *seq_lens_decoder,
                                                   bool *has_running_seqs,
                                                   int *mask_rollback,
                                                   int64_t *step_input_ids,
                                                   int *adaptive_step_input_len,
                                                   int64_t *step_output_ids,
                                                   int *step_output_len,
                                                   bool *stop_flags,
                                                   int *seq_lens_this_time,
                                                   const bool *is_paused,
                                                   int64_t *token_ids_all,
                                                   const int64_t *prompt_lens,
                                                   int64_t *step_idx,
                                                   const int64_t *end_tokens,
                                                   const int64_t *max_dec_len,
                                                   int real_bsz,
                                                   int max_bsz,
                                                   int max_step_tokens,
                                                   int max_model_len,
                                                   int num_end_tokens,
                                                   bool is_naive_mode,
                                                   bool prefill_one_step_stop) {
  int cid = core_id();
  int ncores = core_num();
  int clusterid = cluster_id();
  if (clusterid > 0) return;
  // Stage frequently-read per-batch state in shared memory.
  __shared__ int sm_seq_lens_encoder[MAX_BATCH_SIZE];
  __shared__ int sm_seq_lens_decoder[MAX_BATCH_SIZE];
  __shared__ bool sm_stop_flags[MAX_BATCH_SIZE];
  __shared__ int64_t sm_step_idx[MAX_BATCH_SIZE];
  __shared__ bool sm_is_paused[MAX_BATCH_SIZE];
  __shared__ int64_t sm_end_tokens[MAX_BATCH_SIZE];
  __shared__ int sm_cum_seq_len, sm_cum_seq_len_output;  // NOTE(review): unused
  // Per-core count of slots observed as stopped this step.
  __shared__ int buffer_stop_flag_int[64];
  if (cid == 0) {
    GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, sizeof(int) * max_bsz);
    GM2SM_ASYNC(seq_lens_decoder, sm_seq_lens_decoder, sizeof(int) * max_bsz);
    GM2SM_ASYNC(stop_flags, sm_stop_flags, sizeof(bool) * max_bsz);
    GM2SM_ASYNC(step_idx, sm_step_idx, sizeof(int64_t) * max_bsz);
    GM2SM_ASYNC(is_paused, sm_is_paused, sizeof(bool) * max_bsz);
    GM2SM_ASYNC(end_tokens, sm_end_tokens, sizeof(int64_t) * num_end_tokens);
  }
  buffer_stop_flag_int[cid] = 0;
  mfence_sm();
  sync_all();
  for (int batch_id = cid; batch_id < max_bsz; batch_id += ncores) {
    // Read state
    int cur_seq_len_encoder = sm_seq_lens_encoder[batch_id];
    int cur_seq_len_decoder = sm_seq_lens_decoder[batch_id];
    bool cur_stop_flag = sm_stop_flags[batch_id];
    int output_len = 0;
    int64_t cur_step_idx = sm_step_idx[batch_id];
    bool cur_is_paused = sm_is_paused[batch_id];
    bool is_running = !cur_stop_flag && !cur_is_paused;
    // Compute output length: naive mode always emits exactly one token.
    if (is_running) {
      if (is_naive_mode) {
        output_len = 1;
      } else {
        output_len = step_output_len[batch_id];
      }
    }
    // EOS / max-length detection: truncate output at the first stop
    // condition; a pure length hit overwrites that token with EOS.
    if (is_running && output_len > 0) {
      bool hit_stop = false;
      __global_ptr__ int64_t *output_ids =
          &step_output_ids[batch_id * max_step_tokens];
      for (int i = 0; i < output_len; i++) {
        cur_step_idx++;
        int64_t token = output_ids[i];
        bool is_eos = is_end_token(token, sm_end_tokens, num_end_tokens);
        bool max_len_hit = (cur_step_idx >= max_dec_len[batch_id]);
        if (is_eos || max_len_hit) {
          if (!is_eos) output_ids[i] = sm_end_tokens[0];
          output_len = i + 1;
          cur_stop_flag = true;
          hit_stop = true;
          break;
        }
      }
      if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) {
        cur_stop_flag = true;
      }
    }
    // Update state and write back
    if (is_running) {
      if (cur_stop_flag) {
        buffer_stop_flag_int[cid] += 1;
        if (output_len == 0) cur_seq_len_decoder = 0;
        stop_flags[batch_id] = true;
        mask_rollback[batch_id] = 0;
      } else if (cur_seq_len_encoder == 0) {
        cur_seq_len_decoder += output_len;
        mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
      } else {
        mask_rollback[batch_id] = 0;
      }
      // Prefill just finished: fold encoder length into decoder length.
      if (cur_seq_len_encoder > 0) {
        cur_seq_len_decoder += cur_seq_len_encoder;
        cur_seq_len_encoder = 0;
      }
      seq_lens_encoder[batch_id] = cur_seq_len_encoder;
      seq_lens_decoder[batch_id] = cur_seq_len_decoder;
      step_output_len[batch_id] = output_len;
      step_idx[batch_id] = cur_step_idx;
      // Write history to token_ids_all
      if (cur_step_idx > 0 && output_len > 0) {
        // Bounds check: highest write index is prompt_lens + cur_step_idx
        if (prompt_lens[batch_id] + cur_step_idx < max_model_len) {
          __global_ptr__ int64_t *token_ids_all_now =
              &token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]];
          __global_ptr__ int64_t *output_ids =
              &step_output_ids[batch_id * max_step_tokens];
          for (int i = 0; i < output_len; i++) {
            token_ids_all_now[cur_step_idx - i] =
                output_ids[output_len - 1 - i];
          }
        }
      }
      // Setup next input: last emitted token feeds the next step.
      if (output_len > 0) {
        step_input_ids[batch_id * max_step_tokens] =
            step_output_ids[batch_id * max_step_tokens + output_len - 1];
      }
      if (is_naive_mode) {
        seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1;
      }
    } else if (batch_id >= real_bsz) {
      // Padding slot: just count as stopped, don't modify state
      buffer_stop_flag_int[cid] += 1;
    } else {
      // Stopped or paused slot (batch_id < real_bsz)
      buffer_stop_flag_int[cid] += 1;
      stop_flags[batch_id] = true;
      seq_lens_decoder[batch_id] = 0;
      seq_lens_this_time[batch_id] = 0;
      step_output_len[batch_id] = 0;
    }
  }
  mfence_sm();
  sync_all();
  // FIX: aggregate and publish from core 0 only. Previously every core
  // executed the final store to has_running_seqs[0], but only core 0
  // computed the sum — all other cores had a zero local count and so
  // raced the correct value with an unconditional `true`.
  if (cid == 0) {
    int stop_flag_int = 0;
    for (int i = 0; i < ncores; i++) {
      stop_flag_int += buffer_stop_flag_int[i];
    }
    has_running_seqs[0] = stop_flag_int < max_bsz;
  }
}
} // namespace fd_xpu3
@@ -24,7 +24,7 @@ __attribute__((global)) void draft_model_update(
int* seq_lens_encoder, int* seq_lens_encoder,
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
bool* stop_flags, bool* stop_flags,
bool* not_need_stop, bool* not_need_stop,
const int64_t* max_dec_len, const int64_t* max_dec_len,
@@ -60,7 +60,7 @@ static int cpu_wrapper(api::Context* ctx,
int* seq_lens_encoder, int* seq_lens_encoder,
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
bool* stop_flags, bool* stop_flags,
bool* not_need_stop, bool* not_need_stop,
const int64_t* max_dec_len, const int64_t* max_dec_len,
@@ -82,8 +82,7 @@ static int cpu_wrapper(api::Context* ctx,
auto* pre_ids_now = pre_ids + tid * pre_id_length; auto* pre_ids_now = pre_ids + tid * pre_id_length;
auto* base_model_draft_tokens_now = auto* base_model_draft_tokens_now =
base_model_draft_tokens + tid * max_base_model_draft_token; base_model_draft_tokens + tid * max_base_model_draft_token;
const int next_tokens_start_id = const int next_tokens_start_id = cu_seqlens_q_output[tid];
tid * max_seq_len - output_cum_offsets[tid];
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id; auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
auto seq_len_this_time = seq_lens_this_time[tid]; auto seq_len_this_time = seq_lens_this_time[tid];
auto seq_len_encoder = seq_lens_encoder[tid]; auto seq_len_encoder = seq_lens_encoder[tid];
@@ -158,7 +157,7 @@ static int xpu3_wrapper(api::Context* ctx,
int* seq_lens_encoder, int* seq_lens_encoder,
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
bool* stop_flags, bool* stop_flags,
bool* not_need_stop, bool* not_need_stop,
const int64_t* max_dec_len, const int64_t* max_dec_len,
@@ -182,7 +181,7 @@ static int xpu3_wrapper(api::Context* ctx,
seq_lens_encoder, seq_lens_encoder,
seq_lens_decoder, seq_lens_decoder,
reinterpret_cast<XPU_INT64*>(step_idx), reinterpret_cast<XPU_INT64*>(step_idx),
output_cum_offsets, cu_seqlens_q_output,
stop_flags, stop_flags,
not_need_stop, not_need_stop,
reinterpret_cast<const XPU_INT64*>(max_dec_len), reinterpret_cast<const XPU_INT64*>(max_dec_len),
@@ -209,7 +208,7 @@ int draft_model_update(api::Context* ctx,
int* seq_lens_encoder, int* seq_lens_encoder,
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
bool* stop_flags, bool* stop_flags,
bool* not_need_stop, bool* not_need_stop,
const int64_t* max_dec_len, const int64_t* max_dec_len,
@@ -234,7 +233,7 @@ int draft_model_update(api::Context* ctx,
seq_lens_decoder); seq_lens_decoder);
WRAPPER_DUMP_PARAM6(ctx, WRAPPER_DUMP_PARAM6(ctx,
step_idx, step_idx,
output_cum_offsets, cu_seqlens_q_output,
stop_flags, stop_flags,
not_need_stop, not_need_stop,
max_dec_len, max_dec_len,
@@ -255,7 +254,7 @@ int draft_model_update(api::Context* ctx,
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder); WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder); WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx); WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx);
WRAPPER_CHECK_PTR(ctx, int, bsz, output_cum_offsets); WRAPPER_CHECK_PTR(ctx, int, bsz, cu_seqlens_q_output);
WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags); WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop); WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop);
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, max_dec_len); WRAPPER_CHECK_PTR(ctx, int64_t, bsz, max_dec_len);
@@ -272,7 +271,7 @@ int draft_model_update(api::Context* ctx,
seq_lens_encoder, seq_lens_encoder,
seq_lens_decoder, seq_lens_decoder,
step_idx, step_idx,
output_cum_offsets, cu_seqlens_q_output,
stop_flags, stop_flags,
not_need_stop, not_need_stop,
max_dec_len, max_dec_len,
@@ -296,7 +295,7 @@ int draft_model_update(api::Context* ctx,
seq_lens_encoder, seq_lens_encoder,
seq_lens_decoder, seq_lens_decoder,
step_idx, step_idx,
output_cum_offsets, cu_seqlens_q_output,
stop_flags, stop_flags,
not_need_stop, not_need_stop,
max_dec_len, max_dec_len,
@@ -0,0 +1,244 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl/xdnn_impl.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace fd_xpu3 {
// XPU3 device kernel: fused pre-process for speculative decoding.
// Declaration only — presumably defined in the matching .xpu kernel
// source (TODO confirm). For the reference semantics (prefix sums into
// cu_seqlens_q/k, padding removal from input ids or draft tokens,
// per-token batch ids and output-side length bookkeeping) see the
// cpu_wrapper implementation below in this file.
__attribute__((global)) void speculate_preprocess_kernel(
    int64_t* ids_remove_padding,
    int* batch_id_per_token,
    int* cu_seqlens_q,
    int* cu_seqlens_k,
    int* seq_lens_output,
    int* cu_seq_lens_q_output,
    int* batch_id_per_token_output,
    int* real_output_token_num,
    const int64_t* input_data,
    const int* seq_lens,
    const int64_t* draft_tokens,
    const int* seq_lens_encoder,
    const int max_seq_len,
    const int max_draft_tokens_per_batch,
    const int real_bs);
} // namespace fd_xpu3
namespace fastdeploy {
namespace plugin {
// CPU reference implementation of the speculative-decoding pre-process:
//  1. Prefix-sum the per-batch sequence lengths into cu_seqlens_q / _k
//     (identical here since q and k lengths match).
//  2. Gather tokens into ids_remove_padding — from draft_tokens when a
//     batch is in speculative decode (no pending encoder length),
//     otherwise from input_data — and tag each token with its batch id.
//  3. Derive per-batch output lengths (prefill emits exactly one token,
//     decode emits all), their prefix sum, the total output token
//     count, and the batch id of every output token.
// `ctx` and `token_num_data` are accepted for signature parity with the
// XPU wrapper and are not used here.
static int cpu_wrapper(api::Context* ctx,
                       int64_t* ids_remove_padding,
                       int* batch_id_per_token,
                       int* cu_seqlens_q,
                       int* cu_seqlens_k,
                       int* seq_lens_output,
                       int* cu_seq_lens_q_output,
                       int* batch_id_per_token_output,
                       int* real_output_token_num,
                       const int64_t* input_data,
                       const int* seq_lens,
                       const int64_t* draft_tokens,
                       const int* seq_lens_encoder,
                       const int max_seq_len,
                       const int max_draft_tokens_per_batch,
                       const int token_num_data,
                       const int real_bs) {
  // Step 1: inclusive prefix sums over the input sequence lengths.
  cu_seqlens_q[0] = 0;
  cu_seqlens_k[0] = 0;
  for (int bi = 0; bi < real_bs; ++bi) {
    const int len = seq_lens[bi];
    cu_seqlens_q[bi + 1] = cu_seqlens_q[bi] + len;
    cu_seqlens_k[bi + 1] = cu_seqlens_k[bi] + len;
  }
  // Step 2: compact tokens (drop padding) and record batch ownership.
  for (int bi = 0; bi < real_bs; ++bi) {
    const int base = cu_seqlens_q[bi];
    const bool use_draft =
        max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0;
    for (int i = 0; i < seq_lens[bi]; ++i) {
      const int dst = base + i;
      ids_remove_padding[dst] =
          use_draft ? draft_tokens[bi * max_draft_tokens_per_batch + i]
                    : input_data[bi * max_seq_len + i];
      batch_id_per_token[dst] = bi;
    }
  }
  // Step 3: output-side lengths, prefix sum and per-token batch ids.
  cu_seq_lens_q_output[0] = 0;
  for (int bi = 0; bi < real_bs; ++bi) {
    int out_len;
    if (seq_lens[bi] == 0) {
      out_len = 0;  // inactive slot
    } else if (seq_lens[bi] == 1 || seq_lens_encoder[bi] != 0) {
      out_len = 1;  // single-token decode step or prefill: one output
    } else {
      out_len = seq_lens[bi];  // speculative decode: all tokens output
    }
    seq_lens_output[bi] = out_len;
    cu_seq_lens_q_output[bi + 1] = cu_seq_lens_q_output[bi] + out_len;
    for (int i = 0; i < out_len; ++i) {
      batch_id_per_token_output[cu_seq_lens_q_output[bi] + i] = bi;
    }
  }
  real_output_token_num[0] = cu_seq_lens_q_output[real_bs];
  return api::SUCCESS;
}
// Launches the XPU3 device kernel. All pointer arguments are device buffers;
// semantics match cpu_wrapper above. token_num_data is accepted for signature
// parity with cpu_wrapper but is not forwarded — the kernel does not need it.
static int xpu3_wrapper(api::Context* ctx,
                        int64_t* ids_remove_padding,
                        int* batch_id_per_token,
                        int* cu_seqlens_q,
                        int* cu_seqlens_k,
                        int* seq_lens_output,
                        int* cu_seq_lens_q_output,
                        int* batch_id_per_token_output,
                        int* real_output_token_num,
                        const int64_t* input_data,
                        const int* seq_lens,
                        const int64_t* draft_tokens,
                        const int* seq_lens_encoder,
                        const int max_seq_len,
                        const int max_draft_tokens_per_batch,
                        const int token_num_data,
                        const int real_bs) {
  // XPUIndexType maps int64_t to the representation the device kernel expects
  // for 64-bit data.
  using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
  // Launch config: ctx->ncluster() clusters x 64 cores on the context's
  // stream. NOTE(review): presumably the kernel partitions work across cores
  // internally — confirm against the kernel source.
  int32_t ret_xre = fd_xpu3::
      speculate_preprocess_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
          reinterpret_cast<XPU_INT64*>(ids_remove_padding),
          batch_id_per_token,
          cu_seqlens_q,
          cu_seqlens_k,
          seq_lens_output,
          cu_seq_lens_q_output,
          batch_id_per_token_output,
          real_output_token_num,
          reinterpret_cast<const XPU_INT64*>(input_data),
          seq_lens,
          reinterpret_cast<const XPU_INT64*>(draft_tokens),
          seq_lens_encoder,
          max_seq_len,
          max_draft_tokens_per_batch,
          real_bs);
  KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
  return api::SUCCESS;
}
// Host entry point for the speculative-decoding pre-process step.
// Validates the context and buffer sizes, then dispatches to the CPU
// reference implementation or the XPU3 kernel wrapper depending on the
// device bound to `ctx`.
//
// Outputs:
//   ids_remove_padding        [token_num_data]  packed (padding-free) ids
//   batch_id_per_token        [token_num_data]  owning batch per packed token
//   cu_seqlens_q/cu_seqlens_k [real_bs + 1]     prefix sums of seq_lens
//   seq_lens_output           [real_bs]         per-batch output token counts
//   cu_seq_lens_q_output      [real_bs + 1]     prefix sums of output counts
//   batch_id_per_token_output                   owning batch per output token
//   real_output_token_num     [1]               total output token count
// Inputs:
//   input_data       [real_bs, max_seq_len]                padded input ids
//   seq_lens         [real_bs]                             current lengths
//   draft_tokens     [real_bs, max_draft_tokens_per_batch] speculative drafts
//   seq_lens_encoder [real_bs]                             prefill lengths
// Returns api::SUCCESS, or the error produced by the WRAPPER_* checks.
int speculate_preprocess(api::Context* ctx,
                         int64_t* ids_remove_padding,
                         int* batch_id_per_token,
                         int* cu_seqlens_q,
                         int* cu_seqlens_k,
                         int* seq_lens_output,
                         int* cu_seq_lens_q_output,
                         int* batch_id_per_token_output,
                         int* real_output_token_num,
                         const int64_t* input_data,
                         const int* seq_lens,
                         const int64_t* draft_tokens,
                         const int* seq_lens_encoder,
                         const int max_seq_len,
                         const int max_draft_tokens_per_batch,
                         const int token_num_data,
                         const int real_bs) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_preprocess", int);
  WRAPPER_DUMP_PARAM6(ctx,
                      ids_remove_padding,
                      batch_id_per_token,
                      cu_seqlens_q,
                      cu_seqlens_k,
                      seq_lens_output,
                      cu_seq_lens_q_output);
  WRAPPER_DUMP_PARAM6(ctx,
                      batch_id_per_token_output,
                      real_output_token_num,
                      input_data,
                      seq_lens,
                      draft_tokens,
                      seq_lens_encoder);
  WRAPPER_DUMP_PARAM3(ctx, max_seq_len, max_draft_tokens_per_batch, real_bs);
  WRAPPER_DUMP(ctx);
  // Buffer-size checks: the element type must match each buffer's real type.
  WRAPPER_CHECK_PTR(ctx, int64_t, token_num_data, ids_remove_padding);
  WRAPPER_CHECK_PTR(ctx, int, token_num_data, batch_id_per_token);
  WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seqlens_q);
  WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seqlens_k);
  WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens_output);
  WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seq_lens_q_output);
  WRAPPER_CHECK_PTR(
      ctx, int, real_bs * max_draft_tokens_per_batch, batch_id_per_token_output);
  WRAPPER_CHECK_PTR(ctx, int, 1, real_output_token_num);
  WRAPPER_CHECK_PTR(ctx, int64_t, real_bs * max_seq_len, input_data);
  WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens);
  // BUGFIX: draft_tokens holds int64_t elements; the previous check used
  // element type `int`, validating only half of the buffer's byte size.
  WRAPPER_CHECK_PTR(
      ctx, int64_t, real_bs * max_draft_tokens_per_batch, draft_tokens);
  WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens_encoder);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(ctx,
                       ids_remove_padding,
                       batch_id_per_token,
                       cu_seqlens_q,
                       cu_seqlens_k,
                       seq_lens_output,
                       cu_seq_lens_q_output,
                       batch_id_per_token_output,
                       real_output_token_num,
                       input_data,
                       seq_lens,
                       draft_tokens,
                       seq_lens_encoder,
                       max_seq_len,
                       max_draft_tokens_per_batch,
                       token_num_data,
                       real_bs);
  }
  if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        ids_remove_padding,
                        batch_id_per_token,
                        cu_seqlens_q,
                        cu_seqlens_k,
                        seq_lens_output,
                        cu_seq_lens_q_output,
                        batch_id_per_token_output,
                        real_output_token_num,
                        input_data,
                        seq_lens,
                        draft_tokens,
                        seq_lens_encoder,
                        max_seq_len,
                        max_draft_tokens_per_batch,
                        token_num_data,
                        real_bs);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace fastdeploy
@@ -25,8 +25,8 @@ __attribute__((global)) void speculate_min_length_logits_process(
const int64_t* cur_len, const int64_t* cur_len,
const int64_t* min_len, const int64_t* min_len,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -37,7 +37,7 @@ __attribute__((global)) void speculate_update_repeat_times(
const int64_t* pre_ids, const int64_t* pre_ids,
const int64_t* cur_len, const int64_t* cur_len,
int* repeat_times, int* repeat_times,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -146,7 +146,7 @@ static int cpu_wrapper(api::Context* ctx,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int64_t* bad_words, const int64_t* bad_words,
const int* output_padding_offset, const int* output_padding_offset,
const int* output_cum_offsets, const int* batch_id_per_token_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -172,7 +172,7 @@ static int cpu_wrapper(api::Context* ctx,
WRAPPER_ASSERT_SUCCESS(ctx, ret); WRAPPER_ASSERT_SUCCESS(ctx, ret);
for (int64_t i = 0; i < token_num; i++) { for (int64_t i = 0; i < token_num; i++) {
int64_t bi = (i + output_padding_offset[i]) / max_seq_len; int64_t bi = (i + output_padding_offset[i]) / max_seq_len;
int64_t query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi]; int64_t query_start_token_idx = batch_id_per_token_output[bi];
if (bi < bs && cur_len[bi] >= 0 && if (bi < bs && cur_len[bi] >= 0 &&
(cur_len[bi] + (i - query_start_token_idx) < min_len[bi])) { (cur_len[bi] + (i - query_start_token_idx) < min_len[bi])) {
for (int64_t j = 0; j < end_length; j++) { for (int64_t j = 0; j < end_length; j++) {
@@ -236,8 +236,8 @@ static int xpu3_wrapper(api::Context* ctx,
const int64_t* min_len, const int64_t* min_len,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int64_t* bad_words, const int64_t* bad_words,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -268,7 +268,7 @@ static int xpu3_wrapper(api::Context* ctx,
reinterpret_cast<const XPU_INT64*>(pre_ids), reinterpret_cast<const XPU_INT64*>(pre_ids),
reinterpret_cast<const XPU_INT64*>(cur_len), reinterpret_cast<const XPU_INT64*>(cur_len),
repeat_times, repeat_times,
output_padding_offset, batch_id_per_token_output,
bs, bs,
length, length,
length_id, length_id,
@@ -282,8 +282,8 @@ static int xpu3_wrapper(api::Context* ctx,
reinterpret_cast<const XPU_INT64*>(cur_len), reinterpret_cast<const XPU_INT64*>(cur_len),
reinterpret_cast<const XPU_INT64*>(min_len), reinterpret_cast<const XPU_INT64*>(min_len),
reinterpret_cast<const XPU_INT64*>(eos_token_id), reinterpret_cast<const XPU_INT64*>(eos_token_id),
output_padding_offset, batch_id_per_token_output,
output_cum_offsets, cu_seqlens_q_output,
bs, bs,
length, length,
length_id, length_id,
@@ -300,7 +300,7 @@ static int xpu3_wrapper(api::Context* ctx,
presence_scores, presence_scores,
temperatures, temperatures,
logits, logits,
output_padding_offset, batch_id_per_token_output,
bs, bs,
length, length,
token_num, token_num,
@@ -311,7 +311,7 @@ static int xpu3_wrapper(api::Context* ctx,
ret_xre = ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>( ret_xre = ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
logits, logits,
reinterpret_cast<const XPU_INT64*>(bad_words), reinterpret_cast<const XPU_INT64*>(bad_words),
output_padding_offset, batch_id_per_token_output,
bs, bs,
length, length,
length_bad_words, length_bad_words,
@@ -334,8 +334,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
const int64_t* min_len, const int64_t* min_len,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int64_t* bad_words, const int64_t* bad_words,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -357,8 +357,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
min_len, min_len,
eos_token_id, eos_token_id,
bad_words, bad_words,
output_padding_offset, cu_seqlens_q_output,
output_cum_offsets); batch_id_per_token_output);
WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length); WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length);
WRAPPER_DUMP_PARAM3(ctx, length_bad_words, token_num, max_seq_len); WRAPPER_DUMP_PARAM3(ctx, length_bad_words, token_num, max_seq_len);
WRAPPER_DUMP(ctx); WRAPPER_DUMP(ctx);
@@ -373,8 +373,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
int64_t min_len_len = -1; int64_t min_len_len = -1;
int64_t eos_token_id_len = -1; int64_t eos_token_id_len = -1;
int64_t bad_words_len = -1; int64_t bad_words_len = -1;
int64_t output_padding_offset_len = -1; // int64_t output_padding_offset_len = -1;
int64_t output_cum_offsets_len = -1; // int64_t output_cum_offsets_len = -1;
WRAPPER_ASSERT_LE(ctx, bs, 640); WRAPPER_ASSERT_LE(ctx, bs, 640);
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id}); WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id});
WRAPPER_CHECK_SHAPE(ctx, &logits_len, {token_num, length}); WRAPPER_CHECK_SHAPE(ctx, &logits_len, {token_num, length});
@@ -386,8 +386,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs}); WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs});
WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length}); WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length});
WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words}); WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words});
WRAPPER_CHECK_SHAPE(ctx, &output_padding_offset_len, {token_num}); // WRAPPER_CHECK_SHAPE(ctx, &output_padding_offset_len, {token_num});
WRAPPER_CHECK_SHAPE(ctx, &output_cum_offsets_len, {bs}); // WRAPPER_CHECK_SHAPE(ctx, &output_cum_offsets_len, {bs});
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids); WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids);
WRAPPER_CHECK_PTR(ctx, T, logits_len, logits); WRAPPER_CHECK_PTR(ctx, T, logits_len, logits);
WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores); WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores);
@@ -398,8 +398,9 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len); WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len);
WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id); WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id);
WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words); WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words);
WRAPPER_CHECK_PTR(ctx, int, output_padding_offset_len, output_padding_offset); // WRAPPER_CHECK_PTR(ctx, int, output_padding_offset_len,
WRAPPER_CHECK_PTR(ctx, int, output_cum_offsets_len, output_cum_offsets); // output_padding_offset); WRAPPER_CHECK_PTR(ctx, int, output_cum_offsets_len,
// output_cum_offsets);
if (ctx->dev().type() == api::kCPU) { if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(ctx, return cpu_wrapper<T>(ctx,
pre_ids, pre_ids,
@@ -412,8 +413,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
min_len, min_len,
eos_token_id, eos_token_id,
bad_words, bad_words,
output_padding_offset, batch_id_per_token_output,
output_cum_offsets, cu_seqlens_q_output,
bs, bs,
length, length,
length_id, length_id,
@@ -434,8 +435,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
min_len, min_len,
eos_token_id, eos_token_id,
bad_words, bad_words,
output_padding_offset, batch_id_per_token_output,
output_cum_offsets, cu_seqlens_q_output,
bs, bs,
length, length,
length_id, length_id,
@@ -459,8 +460,8 @@ template int speculate_token_penalty_multi_scores<float>(
const int64_t* min_len, const int64_t* min_len,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int64_t* bad_words, const int64_t* bad_words,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -480,8 +481,8 @@ template int speculate_token_penalty_multi_scores<float16>(
const int64_t* min_len, const int64_t* min_len,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int64_t* bad_words, const int64_t* bad_words,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -501,8 +502,8 @@ template int speculate_token_penalty_multi_scores<bfloat16>(
const int64_t* min_len, const int64_t* min_len,
const int64_t* eos_token_id, const int64_t* eos_token_id,
const int64_t* bad_words, const int64_t* bad_words,
const int* output_padding_offset, const int* batch_id_per_token_output,
const int* output_cum_offsets, const int* cu_seqlens_q_output,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
const int64_t length_id, const int64_t length_id,
@@ -23,25 +23,25 @@ typedef uint32_t curandStatePhilox4_32_10_t;
template <bool ENABLE_TOPP, bool USE_TOPK> template <bool ENABLE_TOPP, bool USE_TOPK>
__attribute__((global)) void speculate_verify( __attribute__((global)) void speculate_verify(
const int64_t* sampled_token_ids, const int64_t *sampled_token_ids,
int64_t* accept_tokens, int64_t *accept_tokens,
int* accept_num, int *accept_num,
int64_t* step_idx, int64_t *step_idx,
bool* stop_flags, bool *stop_flags,
const int* seq_lens_encoder, const int *seq_lens_encoder,
const int* seq_lens_decoder, const int *seq_lens_decoder,
const int64_t* draft_tokens, const int64_t *draft_tokens,
const int* actual_draft_token_nums, const int *actual_draft_token_nums,
const float* dev_curand_states, const float *dev_curand_states,
const float* topp, const float *topp,
const int* seq_lens_this_time, const int *seq_lens_this_time,
const int64_t* verify_tokens, const int64_t *verify_tokens,
const float* verify_scores, const float *verify_scores,
const int64_t* max_dec_len, const int64_t *max_dec_len,
const int64_t* end_tokens, const int64_t *end_tokens,
const bool* is_block_step, const bool *is_block_step,
const int* output_cum_offsets, const int *cu_seqlens_q_output,
const int* actual_candidate_len, const int *actual_candidate_len,
const int real_bsz, const int real_bsz,
const int max_draft_tokens, const int max_draft_tokens,
const int end_length, const int end_length,
@@ -58,7 +58,7 @@ namespace fastdeploy {
namespace plugin { namespace plugin {
static inline bool is_in_end(const int64_t id, static inline bool is_in_end(const int64_t id,
const int64_t* end_ids, const int64_t *end_ids,
int length) { int length) {
bool flag = false; bool flag = false;
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
@@ -69,7 +69,7 @@ static inline bool is_in_end(const int64_t id,
return flag; return flag;
} }
static inline bool is_in(const int64_t* candidates, static inline bool is_in(const int64_t *candidates,
const int64_t draft, const int64_t draft,
const int candidate_len) { const int candidate_len) {
for (int i = 0; i < candidate_len; i++) { for (int i = 0; i < candidate_len; i++) {
@@ -80,7 +80,7 @@ static inline bool is_in(const int64_t* candidates,
return false; return false;
} }
static inline unsigned int xorwow(unsigned int& state) { // NOLINT static inline unsigned int xorwow(unsigned int &state) { // NOLINT
state ^= state >> 7; state ^= state >> 7;
state ^= state << 9; state ^= state << 9;
state ^= state >> 13; state ^= state >> 13;
@@ -89,9 +89,9 @@ static inline unsigned int xorwow(unsigned int& state) { // NOLINT
typedef uint32_t curandStatePhilox4_32_10_t; typedef uint32_t curandStatePhilox4_32_10_t;
static int64_t topp_sampling_kernel(const int64_t* candidate_ids, static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
const float* candidate_scores, const float *candidate_scores,
const float* dev_curand_states, const float *dev_curand_states,
const int candidate_len, const int candidate_len,
const float topp, const float topp,
int tid) { int tid) {
@@ -111,26 +111,26 @@ static int64_t topp_sampling_kernel(const int64_t* candidate_ids,
} }
template <bool ENABLE_TOPP, bool USE_TOPK> template <bool ENABLE_TOPP, bool USE_TOPK>
static int cpu_wrapper(api::Context* ctx, static int cpu_wrapper(api::Context *ctx,
const int64_t* sampled_token_ids, const int64_t *sampled_token_ids,
int64_t* accept_tokens, int64_t *accept_tokens,
int* accept_num, int *accept_num,
int64_t* step_idx, int64_t *step_idx,
bool* stop_flags, bool *stop_flags,
const int* seq_lens_encoder, const int *seq_lens_encoder,
const int* seq_lens_decoder, const int *seq_lens_decoder,
const int64_t* draft_tokens, const int64_t *draft_tokens,
const int* actual_draft_token_nums, const int *actual_draft_token_nums,
const float* dev_curand_states, const float *dev_curand_states,
const float* topp, const float *topp,
const int* seq_lens_this_time, const int *seq_lens_this_time,
const int64_t* verify_tokens, const int64_t *verify_tokens,
const float* verify_scores, const float *verify_scores,
const int64_t* max_dec_len, const int64_t *max_dec_len,
const int64_t* end_tokens, const int64_t *end_tokens,
const bool* is_block_step, const bool *is_block_step,
const int* output_cum_offsets, const int *cu_seqlens_q_output,
const int* actual_candidate_len, const int *actual_candidate_len,
const int real_bsz, const int real_bsz,
const int max_draft_tokens, const int max_draft_tokens,
const int end_length, const int end_length,
@@ -147,7 +147,7 @@ static int cpu_wrapper(api::Context* ctx,
int stop_flag_now_int = 0; int stop_flag_now_int = 0;
if (!(is_block_step[bid] || bid >= real_bsz)) { if (!(is_block_step[bid] || bid >= real_bsz)) {
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; const int start_token_id = cu_seqlens_q_output[bid];
// printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id); // printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id);
// printf("bid %d\n", bid); // printf("bid %d\n", bid);
@@ -155,11 +155,11 @@ static int cpu_wrapper(api::Context* ctx,
stop_flag_now_int = 1; stop_flag_now_int = 1;
} else { // 这里prefill阶段也会进入,但是因为draft } else { // 这里prefill阶段也会进入,但是因为draft
// tokens会置零,因此会直接到最后的采样阶段 // tokens会置零,因此会直接到最后的采样阶段
auto* verify_tokens_now = auto *verify_tokens_now =
verify_tokens + start_token_id * max_candidate_len; verify_tokens + start_token_id * max_candidate_len;
auto* draft_tokens_now = draft_tokens + bid * max_draft_tokens; auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
auto* actual_candidate_len_now = actual_candidate_len + start_token_id; auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
auto* sampled_token_id_now = sampled_token_ids + start_token_id; auto *sampled_token_id_now = sampled_token_ids + start_token_id;
int i = 0; int i = 0;
// printf("seq_lens_this_time[%d]-1: %d \n",bid, // printf("seq_lens_this_time[%d]-1: %d \n",bid,
@@ -306,7 +306,7 @@ static int cpu_wrapper(api::Context* ctx,
// 也是从verify_tokens_now[i]中选一个 但是停止的情况不算 // 也是从verify_tokens_now[i]中选一个 但是停止的情况不算
if (!stop_flag_now_int) { if (!stop_flag_now_int) {
int64_t accept_token; int64_t accept_token;
const float* verify_scores_now = const float *verify_scores_now =
verify_scores + start_token_id * max_candidate_len; verify_scores + start_token_id * max_candidate_len;
step_idx[bid]++; step_idx[bid]++;
if (use_target_sampling) { if (use_target_sampling) {
@@ -347,26 +347,26 @@ static int cpu_wrapper(api::Context* ctx,
} }
template <bool ENABLE_TOPP, bool USE_TOPK> template <bool ENABLE_TOPP, bool USE_TOPK>
static int xpu3_wrapper(api::Context* ctx, static int xpu3_wrapper(api::Context *ctx,
const int64_t* sampled_token_ids, const int64_t *sampled_token_ids,
int64_t* accept_tokens, int64_t *accept_tokens,
int* accept_num, int *accept_num,
int64_t* step_idx, int64_t *step_idx,
bool* stop_flags, bool *stop_flags,
const int* seq_lens_encoder, const int *seq_lens_encoder,
const int* seq_lens_decoder, const int *seq_lens_decoder,
const int64_t* draft_tokens, const int64_t *draft_tokens,
const int* actual_draft_token_nums, const int *actual_draft_token_nums,
const float* dev_curand_states, const float *dev_curand_states,
const float* topp, const float *topp,
const int* seq_lens_this_time, const int *seq_lens_this_time,
const int64_t* verify_tokens, const int64_t *verify_tokens,
const float* verify_scores, const float *verify_scores,
const int64_t* max_dec_len, const int64_t *max_dec_len,
const int64_t* end_tokens, const int64_t *end_tokens,
const bool* is_block_step, const bool *is_block_step,
const int* output_cum_offsets, const int *cu_seqlens_q_output,
const int* actual_candidate_len, const int *actual_candidate_len,
const int real_bsz, const int real_bsz,
const int max_draft_tokens, const int max_draft_tokens,
const int end_length, const int end_length,
@@ -380,24 +380,24 @@ static int xpu3_wrapper(api::Context* ctx,
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type; using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
int32_t ret_xre = fd_xpu3::speculate_verify<ENABLE_TOPP, USE_TOPK> int32_t ret_xre = fd_xpu3::speculate_verify<ENABLE_TOPP, USE_TOPK>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>( <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const XPU_INT64*>(sampled_token_ids), reinterpret_cast<const XPU_INT64 *>(sampled_token_ids),
reinterpret_cast<XPU_INT64*>(accept_tokens), reinterpret_cast<XPU_INT64 *>(accept_tokens),
accept_num, accept_num,
reinterpret_cast<XPU_INT64*>(step_idx), reinterpret_cast<XPU_INT64 *>(step_idx),
stop_flags, stop_flags,
seq_lens_encoder, seq_lens_encoder,
seq_lens_decoder, seq_lens_decoder,
reinterpret_cast<const XPU_INT64*>(draft_tokens), reinterpret_cast<const XPU_INT64 *>(draft_tokens),
actual_draft_token_nums, actual_draft_token_nums,
dev_curand_states, dev_curand_states,
topp, topp,
seq_lens_this_time, seq_lens_this_time,
reinterpret_cast<const XPU_INT64*>(verify_tokens), reinterpret_cast<const XPU_INT64 *>(verify_tokens),
verify_scores, verify_scores,
reinterpret_cast<const XPU_INT64*>(max_dec_len), reinterpret_cast<const XPU_INT64 *>(max_dec_len),
reinterpret_cast<const XPU_INT64*>(end_tokens), reinterpret_cast<const XPU_INT64 *>(end_tokens),
is_block_step, is_block_step,
output_cum_offsets, cu_seqlens_q_output,
actual_candidate_len, actual_candidate_len,
real_bsz, real_bsz,
max_draft_tokens, max_draft_tokens,
@@ -413,26 +413,26 @@ static int xpu3_wrapper(api::Context* ctx,
return api::SUCCESS; return api::SUCCESS;
} }
template <bool ENABLE_TOPP, bool USE_TOPK> template <bool ENABLE_TOPP, bool USE_TOPK>
int speculate_verify(api::Context* ctx, int speculate_verify(api::Context *ctx,
const int64_t* sampled_token_ids, const int64_t *sampled_token_ids,
int64_t* accept_tokens, int64_t *accept_tokens,
int* accept_num, int *accept_num,
int64_t* step_idx, int64_t *step_idx,
bool* stop_flags, bool *stop_flags,
const int* seq_lens_encoder, const int *seq_lens_encoder,
const int* seq_lens_decoder, const int *seq_lens_decoder,
const int64_t* draft_tokens, const int64_t *draft_tokens,
const int* actual_draft_token_nums, const int *actual_draft_token_nums,
const float* dev_curand_states, const float *dev_curand_states,
const float* topp, const float *topp,
const int* seq_lens_this_time, const int *seq_lens_this_time,
const int64_t* verify_tokens, const int64_t *verify_tokens,
const float* verify_scores, const float *verify_scores,
const int64_t* max_dec_len, const int64_t *max_dec_len,
const int64_t* end_tokens, const int64_t *end_tokens,
const bool* is_block_step, const bool *is_block_step,
const int* output_cum_offsets, const int *cu_seqlens_q_output,
const int* actual_candidate_len, const int *actual_candidate_len,
const int real_bsz, const int real_bsz,
const int max_draft_tokens, const int max_draft_tokens,
const int end_length, const int end_length,
@@ -462,7 +462,7 @@ int speculate_verify(api::Context* ctx,
end_tokens); end_tokens);
WRAPPER_DUMP_PARAM5(ctx, WRAPPER_DUMP_PARAM5(ctx,
is_block_step, is_block_step,
output_cum_offsets, cu_seqlens_q_output,
actual_candidate_len, actual_candidate_len,
real_bsz, real_bsz,
max_draft_tokens); max_draft_tokens);
@@ -492,7 +492,7 @@ int speculate_verify(api::Context* ctx,
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len); WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len);
WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens); WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step); WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, output_cum_offsets); WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_seqlens_q_output);
// WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_candidate_len); // WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_candidate_len);
// param check sm size limit // param check sm size limit
@@ -525,7 +525,7 @@ int speculate_verify(api::Context* ctx,
max_dec_len, max_dec_len,
end_tokens, end_tokens,
is_block_step, is_block_step,
output_cum_offsets, cu_seqlens_q_output,
actual_candidate_len, actual_candidate_len,
real_bsz, real_bsz,
max_draft_tokens, max_draft_tokens,
@@ -557,7 +557,7 @@ int speculate_verify(api::Context* ctx,
max_dec_len, max_dec_len,
end_tokens, end_tokens,
is_block_step, is_block_step,
output_cum_offsets, cu_seqlens_q_output,
actual_candidate_len, actual_candidate_len,
real_bsz, real_bsz,
max_draft_tokens, max_draft_tokens,
@@ -575,36 +575,36 @@ int speculate_verify(api::Context* ctx,
#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \ #define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \
template int fastdeploy::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \ template int fastdeploy::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \
fastdeploy::plugin::api::Context*, /* xpu_ctx */ \ fastdeploy::plugin::api::Context *, /* xpu_ctx */ \
const int64_t*, /* sampled_token_ids */ \ const int64_t *, /* sampled_token_ids */ \
int64_t*, /* accept_tokens */ \ int64_t *, /* accept_tokens */ \
int*, /* accept_num */ \ int *, /* accept_num */ \
int64_t*, /* step_idx */ \ int64_t *, /* step_idx */ \
bool*, /* stop_flags */ \ bool *, /* stop_flags */ \
const int*, /* seq_lens_encoder */ \ const int *, /* seq_lens_encoder */ \
const int*, /* seq_lens_decoder */ \ const int *, /* seq_lens_decoder */ \
const int64_t*, /* draft_tokens */ \ const int64_t *, /* draft_tokens */ \
const int*, /* actual_draft_token_nums */ \ const int *, /* actual_draft_token_nums */ \
const float*, /* dev_curand_states or topp */ \ const float *, /* dev_curand_states or topp */ \
const float*, /* topp or nullptr */ \ const float *, /* topp or nullptr */ \
const int*, /* seq_lens_this_time */ \ const int *, /* seq_lens_this_time */ \
const int64_t*, /* verify_tokens */ \ const int64_t *, /* verify_tokens */ \
const float*, /* verify_scores */ \ const float *, /* verify_scores */ \
const int64_t*, /* max_dec_len */ \ const int64_t *, /* max_dec_len */ \
const int64_t*, /* end_tokens */ \ const int64_t *, /* end_tokens */ \
const bool*, /* is_block_step */ \ const bool *, /* is_block_step */ \
const int*, /* output_cum_offsets */ \ const int *, /* cu_seqlens_q_output */ \
const int*, /* actual_candidate_len */ \ const int *, /* actual_candidate_len */ \
int, /* real_bsz */ \ int, /* real_bsz */ \
int, /* max_draft_tokens */ \ int, /* max_draft_tokens */ \
int, /* end_length */ \ int, /* end_length */ \
int, /* max_seq_len */ \ int, /* max_seq_len */ \
int, /* max_candidate_len */ \ int, /* max_candidate_len */ \
int, /* verify_window */ \ int, /* verify_window */ \
bool, /* prefill_one_step_stop */ \ bool, /* prefill_one_step_stop */ \
bool, /* benchmark_mode */ \ bool, /* benchmark_mode */ \
bool, /* accept_all_drafts */ \ bool, /* accept_all_drafts */ \
bool /* use_target_sampling */ \ bool /* use_target_sampling */ \
); );
INSTANTIATE_SPECULATE_VERIFY(false, false) INSTANTIATE_SPECULATE_VERIFY(false, false)
@@ -17,16 +17,17 @@
namespace fd_xpu3 { namespace fd_xpu3 {
template <typename T, int MaxLength, int TopPBeamTopK> template <typename T, int MaxLength, int TopPBeamTopK>
__attribute__((global)) void top_p_candidates(const T* src, __attribute__((global)) void top_p_candidates(
const T* top_ps, const T* src,
const int* output_padding_offset, const T* top_ps,
int64_t* out_id, const int* batch_id_per_token_output,
T* out_val, int64_t* out_id,
int* actual_candidates_lens, T* out_val,
int vocab_size, int* actual_candidates_lens,
int token_num, int vocab_size,
int max_candidate_len, int token_num,
int max_seq_len); int max_candidate_len,
int max_seq_len);
} // namespace fd_xpu3 } // namespace fd_xpu3
namespace fastdeploy { namespace fastdeploy {
@@ -36,7 +37,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
static int cpu_wrapper(api::Context* ctx, static int cpu_wrapper(api::Context* ctx,
const T* src, const T* src,
const T* top_ps, const T* top_ps,
const int* output_padding_offset, const int* batch_id_per_token_output,
int64_t* out_id, int64_t* out_id,
T* out_val, T* out_val,
int* actual_candidates_lens, int* actual_candidates_lens,
@@ -70,8 +71,7 @@ static int cpu_wrapper(api::Context* ctx,
} }
} }
} }
int ori_token_id = i + output_padding_offset[i]; int bid = batch_id_per_token_output[i];
int bid = ori_token_id / max_seq_len;
float top_p_value = static_cast<float>(top_ps[bid]); float top_p_value = static_cast<float>(top_ps[bid]);
bool set_to_default_val = false; bool set_to_default_val = false;
for (int j = 0; j < TopPBeamTopK; j++) { for (int j = 0; j < TopPBeamTopK; j++) {
@@ -97,7 +97,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
static int xpu3_wrapper(api::Context* ctx, static int xpu3_wrapper(api::Context* ctx,
const T* src, const T* src,
const T* top_ps, const T* top_ps,
const int* output_padding_offset, const int* batch_id_per_token_output,
int64_t* out_id, int64_t* out_id,
T* out_val, T* out_val,
int* actual_candidates_lens, int* actual_candidates_lens,
@@ -110,7 +110,7 @@ static int xpu3_wrapper(api::Context* ctx,
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>( <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
src, src,
top_ps, top_ps,
output_padding_offset, batch_id_per_token_output,
reinterpret_cast<XPU_INT64*>(out_id), reinterpret_cast<XPU_INT64*>(out_id),
out_val, out_val,
actual_candidates_lens, actual_candidates_lens,
@@ -126,7 +126,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
int top_p_candidates(api::Context* ctx, int top_p_candidates(api::Context* ctx,
const T* src, const T* src,
const T* top_ps, const T* top_ps,
const int* output_padding_offset, const int* batch_id_per_token_output,
int64_t* out_id, int64_t* out_id,
T* out_val, T* out_val,
int* actual_candidates_lens, int* actual_candidates_lens,
@@ -136,7 +136,8 @@ int top_p_candidates(api::Context* ctx,
int max_seq_len) { int max_seq_len) {
WRAPPER_CHECK_CTX(ctx); WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "top_p_candidates", T); WRAPPER_DUMP_FUNCTION_T1(ctx, "top_p_candidates", T);
WRAPPER_DUMP_PARAM5(ctx, src, top_ps, output_padding_offset, out_id, out_val); WRAPPER_DUMP_PARAM5(
ctx, src, top_ps, batch_id_per_token_output, out_id, out_val);
WRAPPER_DUMP_PARAM5(ctx, WRAPPER_DUMP_PARAM5(ctx,
actual_candidates_lens, actual_candidates_lens,
vocab_size, vocab_size,
@@ -146,7 +147,7 @@ int top_p_candidates(api::Context* ctx,
WRAPPER_DUMP(ctx); WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, T, token_num * vocab_size, src); WRAPPER_CHECK_PTR(ctx, T, token_num * vocab_size, src);
WRAPPER_CHECK_PTR(ctx, T, token_num, output_padding_offset); WRAPPER_CHECK_PTR(ctx, T, token_num, batch_id_per_token_output);
WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_id); WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_id);
WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_val); WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_val);
@@ -161,7 +162,7 @@ int top_p_candidates(api::Context* ctx,
return cpu_wrapper<T, MaxLength, TopPBeamTopK>(ctx, return cpu_wrapper<T, MaxLength, TopPBeamTopK>(ctx,
src, src,
top_ps, top_ps,
output_padding_offset, batch_id_per_token_output,
out_id, out_id,
out_val, out_val,
actual_candidates_lens, actual_candidates_lens,
@@ -173,7 +174,7 @@ int top_p_candidates(api::Context* ctx,
return xpu3_wrapper<T, MaxLength, TopPBeamTopK>(ctx, return xpu3_wrapper<T, MaxLength, TopPBeamTopK>(ctx,
src, src,
top_ps, top_ps,
output_padding_offset, batch_id_per_token_output,
out_id, out_id,
out_val, out_val,
actual_candidates_lens, actual_candidates_lens,
@@ -0,0 +1,376 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl/xdnn_impl.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace fd_xpu3 {
// Declaration of the XPU3 device kernel for the unified model-status update;
// the definition lives elsewhere (device-side plugin sources). See
// fastdeploy::plugin::cpu_wrapper in this file for the reference semantics of
// every parameter. All mutable buffers are updated in place.
__attribute__((global)) void unified_update_model_status_kernel(
    int *seq_lens_encoder,        // [max_bsz] pending prefill lengths
    int *seq_lens_decoder,        // [max_bsz] decoded lengths
    bool *has_running_seqs,       // [1] out: any slot still running
    int *mask_rollback,           // [max_bsz] rejected speculative tokens
    int64_t *step_input_ids,      // [max_bsz, max_step_tokens]
    int *adaptive_step_input_len, // currently unused (see wrapper below)
    int64_t *step_output_ids,     // [max_bsz, max_step_tokens]
    int *step_output_len,         // [max_bsz]
    bool *stop_flags,             // [max_bsz]
    int *seq_lens_this_time,      // [real_bsz]
    const bool *is_paused,        // [max_bsz]
    int64_t *token_ids_all,       // [max_bsz, max_model_len] token history
    const int64_t *prompt_lens,   // [max_bsz]
    int64_t *step_idx,            // [max_bsz]
    const int64_t *end_tokens,    // [num_end_tokens]
    const int64_t *max_dec_len,   // [max_bsz]
    int real_bsz,
    int max_bsz,
    int max_step_tokens,
    int max_model_len,
    int num_end_tokens,
    bool is_naive_mode,
    bool prefill_one_step_stop);
} // namespace fd_xpu3
namespace fastdeploy {
namespace plugin {
// Returns true iff `token` equals one of the `num_end_tokens` entries of
// `end_tokens`. Host-side helper for the CPU reference path.
//
// Marked `static`: this is a file-local helper, and giving it internal
// linkage prevents ODR/symbol collisions with an identically named helper in
// another plugin translation unit. The former `#pragma unroll 4` was a
// device-style hint on host code and is dropped; the wrapper below already
// bounds num_end_tokens (<= 1024), so the plain loop is fine.
static bool is_end_token(int64_t token,
                         const int64_t *end_tokens,
                         int num_end_tokens) {
  for (int i = 0; i < num_end_tokens; i++) {
    if (token == end_tokens[i]) return true;
  }
  return false;
}
// Host (CPU) reference implementation of the unified model-status update.
//
// For every batch slot in [0, max_bsz):
//   1. read the slot's encoder/decoder lengths, stop flag and step index;
//   2. for running slots, scan this step's output tokens for an end token or
//      the per-request max_dec_len limit and truncate output_len there;
//   3. fold a finished prefill into the decoder length, record how many
//      speculative tokens must be rolled back (mask_rollback), append the
//      accepted tokens to token_ids_all, and stage the last accepted token as
//      the next step's input;
//   4. for stopped/paused/padding slots, normalize state and count them so
//      has_running_seqs[0] can report whether any slot is still active.
//
// Notes:
//   - `adaptive_step_input_len` is currently unused (see the commented-out
//     pointer check in unified_update_model_status below).
//   - `seq_lens_this_time` is only sized real_bsz; running slots are assumed
//     to satisfy batch_id < real_bsz (padding slots must arrive with
//     stop_flags already set), otherwise the accesses below go out of
//     bounds. NOTE(review): confirm callers guarantee this.
//   - A paused slot (is_paused && !stop_flags) is converted to stopped here.
// Returns api::SUCCESS.
static int cpu_wrapper(api::Context *ctx,
                       int *seq_lens_encoder,
                       int *seq_lens_decoder,
                       bool *has_running_seqs,
                       int *mask_rollback,
                       int64_t *step_input_ids,
                       int *adaptive_step_input_len,
                       int64_t *step_output_ids,
                       int *step_output_len,
                       bool *stop_flags,
                       int *seq_lens_this_time,
                       const bool *is_paused,
                       int64_t *token_ids_all,
                       const int64_t *prompt_lens,
                       int64_t *step_idx,
                       const int64_t *end_tokens,
                       const int64_t *max_dec_len,
                       int real_bsz,
                       int max_bsz,
                       int max_step_tokens,
                       int max_model_len,
                       int num_end_tokens,
                       bool is_naive_mode,
                       bool prefill_one_step_stop) {
  int stop_flag_int = 0;  // slots NOT running (stopped / paused / padding)
  for (int batch_id = 0; batch_id < max_bsz; batch_id++) {
    // Read state
    int cur_seq_len_encoder = seq_lens_encoder[batch_id];
    int cur_seq_len_decoder = seq_lens_decoder[batch_id];
    bool cur_stop_flag = stop_flags[batch_id];
    int output_len = 0;
    int64_t cur_step_idx = step_idx[batch_id];
    bool cur_is_paused = is_paused[batch_id];
    bool is_running = !cur_stop_flag && !cur_is_paused;
    // Compute output length: naive mode always emits exactly one token,
    // otherwise the count reported in step_output_len is used.
    if (is_running) {
      if (is_naive_mode) {
        output_len = 1;
      } else {
        output_len = step_output_len[batch_id];
      }
    }
    // EOS detection: advance step_idx per token; stop at the first end token
    // or when max_dec_len is reached (the offending token is then
    // overwritten with end_tokens[0]).
    if (is_running && output_len > 0) {
      bool hit_stop = false;
      int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
      for (int i = 0; i < output_len; i++) {
        cur_step_idx++;
        int64_t token = output_ids[i];
        bool is_eos = is_end_token(token, end_tokens, num_end_tokens);
        bool max_len_hit = (cur_step_idx >= max_dec_len[batch_id]);
        if (is_eos || max_len_hit) {
          if (!is_eos) output_ids[i] = end_tokens[0];
          output_len = i + 1;  // truncate to the accepted prefix
          cur_stop_flag = true;
          hit_stop = true;
          break;
        }
      }
      // Prefill shortcut: stop immediately after a one-step prefill.
      if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) {
        cur_stop_flag = true;
      }
    }
    // Update state and write back
    if (is_running) {
      if (cur_stop_flag) {
        stop_flag_int += 1;
        // output_len == 0 means nothing was produced; reset decoder length.
        if (output_len == 0) cur_seq_len_decoder = 0;
        stop_flags[batch_id] = true;
        mask_rollback[batch_id] = 0;
      } else if (cur_seq_len_encoder == 0) {
        // Decode step: advance by accepted tokens; rollback is the number of
        // speculative tokens proposed this step but not accepted.
        cur_seq_len_decoder += output_len;
        mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
      } else {
        mask_rollback[batch_id] = 0;
      }
      // Fold a finished prefill into the decoder length.
      if (cur_seq_len_encoder > 0) {
        cur_seq_len_decoder += cur_seq_len_encoder;
        cur_seq_len_encoder = 0;
      }
      seq_lens_encoder[batch_id] = cur_seq_len_encoder;
      seq_lens_decoder[batch_id] = cur_seq_len_decoder;
      step_output_len[batch_id] = output_len;
      step_idx[batch_id] = cur_step_idx;
      // Write history to token_ids_all
      if (cur_step_idx > 0 && output_len > 0) {
        // Bounds check: highest write index is prompt_lens + cur_step_idx
        if (prompt_lens[batch_id] + cur_step_idx < max_model_len) {
          int64_t *token_ids_all_now =
              &token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]];
          int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
          // Copy newest-to-oldest so output_ids[output_len - 1] lands at
          // offset cur_step_idx.
          for (int i = 0; i < output_len; i++) {
            token_ids_all_now[cur_step_idx - i] =
                output_ids[output_len - 1 - i];
          }
        }
      }
      // Setup next input: the last accepted token seeds the next step.
      if (output_len > 0) {
        step_input_ids[batch_id * max_step_tokens] =
            step_output_ids[batch_id * max_step_tokens + output_len - 1];
      }
      if (is_naive_mode) {
        seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1;
      }
    } else if (batch_id >= real_bsz) {
      // Padding slot: just count as stopped, don't modify state
      stop_flag_int += 1;
    } else {
      // Stopped or paused slot (batch_id < real_bsz): normalize to a fully
      // stopped state (a paused slot becomes stopped here).
      stop_flag_int += 1;
      stop_flags[batch_id] = true;
      seq_lens_decoder[batch_id] = 0;
      seq_lens_this_time[batch_id] = 0;
      step_output_len[batch_id] = 0;
    }
  }
  // True iff at least one slot is still running.
  has_running_seqs[0] = stop_flag_int < max_bsz;
  return api::SUCCESS;
}
// XPU3 dispatch: launches the plugin kernel on the context's stream with
// <<<ctx->ncluster(), 64>>>. int64_t pointers are reinterpreted through
// api::XPUIndexType as the device ABI expects. Reference semantics are
// defined by cpu_wrapper above; argument order must match the kernel
// declaration exactly.
static int xpu3_wrapper(api::Context *ctx,
                        int *seq_lens_encoder,
                        int *seq_lens_decoder,
                        bool *has_running_seqs,
                        int *mask_rollback,
                        int64_t *step_input_ids,
                        int *adaptive_step_input_len,
                        int64_t *step_output_ids,
                        int *step_output_len,
                        bool *stop_flags,
                        int *seq_lens_this_time,
                        const bool *is_paused,
                        int64_t *token_ids_all,
                        const int64_t *prompt_lens,
                        int64_t *step_idx,
                        const int64_t *end_tokens,
                        const int64_t *max_dec_len,
                        int real_bsz,
                        int max_bsz,
                        int max_step_tokens,
                        int max_model_len,
                        int num_end_tokens,
                        bool is_naive_mode,
                        bool prefill_one_step_stop) {
  using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
  int32_t ret_xre =
      fd_xpu3::unified_update_model_status_kernel<<<ctx->ncluster(),
                                                    64,
                                                    ctx->xpu_stream>>>(
          seq_lens_encoder,
          seq_lens_decoder,
          has_running_seqs,
          mask_rollback,
          reinterpret_cast<XPU_INT64 *>(step_input_ids),
          adaptive_step_input_len,
          reinterpret_cast<XPU_INT64 *>(step_output_ids),
          step_output_len,
          stop_flags,
          seq_lens_this_time,
          is_paused,
          reinterpret_cast<XPU_INT64 *>(token_ids_all),
          reinterpret_cast<const XPU_INT64 *>(prompt_lens),
          reinterpret_cast<XPU_INT64 *>(step_idx),
          reinterpret_cast<const XPU_INT64 *>(end_tokens),
          reinterpret_cast<const XPU_INT64 *>(max_dec_len),
          real_bsz,
          max_bsz,
          max_step_tokens,
          max_model_len,
          num_end_tokens,
          is_naive_mode,
          prefill_one_step_stop);
  KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
  return api::SUCCESS;
}
// Public entry point: validates the context and buffer extents, dumps the
// parameters, then dispatches to the CPU reference (cpu_wrapper) or the XPU3
// kernel (xpu3_wrapper) depending on the context's device type. All mutable
// tensors are updated in place. Returns api::SUCCESS on success, otherwise a
// wrapper error code / WRAPPER_UNIMPLEMENTED for unsupported devices.
int unified_update_model_status(api::Context *ctx,
                                int *seq_lens_encoder,
                                int *seq_lens_decoder,
                                bool *has_running_seqs,
                                int *mask_rollback,
                                int64_t *step_input_ids,
                                int *adaptive_step_input_len,
                                int64_t *step_output_ids,
                                int *step_output_len,
                                bool *stop_flags,
                                int *seq_lens_this_time,
                                const bool *is_paused,
                                int64_t *token_ids_all,
                                const int64_t *prompt_lens,
                                int64_t *step_idx,
                                const int64_t *end_tokens,
                                const int64_t *max_dec_len,
                                int real_bsz,
                                int max_bsz,
                                int max_step_tokens,
                                int max_model_len,
                                int num_end_tokens,
                                bool is_naive_mode,
                                bool prefill_one_step_stop) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "unified_update_model_status", int);
  // Dump parameters in groups (the dump macros carry a fixed arity).
  WRAPPER_DUMP_PARAM6(ctx,
                      seq_lens_encoder,
                      seq_lens_decoder,
                      has_running_seqs,
                      mask_rollback,
                      step_input_ids,
                      adaptive_step_input_len);
  WRAPPER_DUMP_PARAM6(ctx,
                      step_output_ids,
                      step_output_len,
                      stop_flags,
                      seq_lens_this_time,
                      is_paused,
                      token_ids_all);
  WRAPPER_DUMP_PARAM6(
      ctx, prompt_lens, step_idx, end_tokens, max_dec_len, real_bsz, max_bsz);
  WRAPPER_DUMP_PARAM5(ctx,
                      max_step_tokens,
                      max_model_len,
                      num_end_tokens,
                      is_naive_mode,
                      prefill_one_step_stop);
  WRAPPER_DUMP(ctx);
  // Extent checks. Per-slot buffers are sized max_bsz; note that
  // seq_lens_this_time is only sized real_bsz (padding slots must therefore
  // never be in the running state).
  WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_encoder);
  WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_decoder);
  WRAPPER_CHECK_PTR(ctx, bool, 1, has_running_seqs);
  WRAPPER_CHECK_PTR(ctx, int, max_bsz, mask_rollback);
  WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_step_tokens, step_input_ids);
  // WRAPPER_CHECK_PTR(ctx, int, 0, adaptive_step_input_len); // Temporarily
  // unused
  WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_step_tokens, step_output_ids);
  WRAPPER_CHECK_PTR(ctx, int, max_bsz, step_output_len);
  WRAPPER_CHECK_PTR(ctx, bool, max_bsz, stop_flags);
  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
  WRAPPER_CHECK_PTR(ctx, bool, max_bsz, is_paused);
  WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_model_len, token_ids_all);
  WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, prompt_lens);
  WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, step_idx);
  WRAPPER_CHECK_PTR(ctx, int64_t, num_end_tokens, end_tokens);
  WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, max_dec_len);
  // Contract: real batch fits in the padded batch, and at most 1024 end
  // tokens are supported.
  WRAPPER_ASSERT_GE(ctx, max_bsz, real_bsz);
  WRAPPER_ASSERT_GE(ctx, 1024, num_end_tokens);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper(ctx,
                       seq_lens_encoder,
                       seq_lens_decoder,
                       has_running_seqs,
                       mask_rollback,
                       step_input_ids,
                       adaptive_step_input_len,
                       step_output_ids,
                       step_output_len,
                       stop_flags,
                       seq_lens_this_time,
                       is_paused,
                       token_ids_all,
                       prompt_lens,
                       step_idx,
                       end_tokens,
                       max_dec_len,
                       real_bsz,
                       max_bsz,
                       max_step_tokens,
                       max_model_len,
                       num_end_tokens,
                       is_naive_mode,
                       prefill_one_step_stop);
  }
  if (ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        seq_lens_encoder,
                        seq_lens_decoder,
                        has_running_seqs,
                        mask_rollback,
                        step_input_ids,
                        adaptive_step_input_len,
                        step_output_ids,
                        step_output_len,
                        stop_flags,
                        seq_lens_this_time,
                        is_paused,
                        token_ids_all,
                        prompt_lens,
                        step_idx,
                        end_tokens,
                        max_dec_len,
                        real_bsz,
                        max_bsz,
                        max_step_tokens,
                        max_model_len,
                        num_end_tokens,
                        is_naive_mode,
                        prefill_one_step_stop);
  }
  // Any other device type is not supported.
  WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace fastdeploy
@@ -0,0 +1,328 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import speculate_pre_process
def speculate_pre_process_ref(
    input_ids,
    seq_lens,
    draft_tokens,
    seq_lens_encoder,
    max_seq_len,
    max_draft_tokens_per_batch,
    real_bsz,
    token_num,
):
    """
    Python reference implementation for SpeculatePreProcessKernel.
    Returns:
        ids_remove_padding: int64[token_num]
        batch_id_per_token: int32[token_num]
        cu_seqlens_q: int32[real_bsz + 1]
        cu_seqlens_k: int32[real_bsz + 1]
        seq_lens_output: int32[real_bsz]
        cu_seq_lens_q_output: int32[real_bsz + 1]
        batch_id_per_token_output: int32[real_bsz * max_draft_tokens_per_batch]
        real_output_token_num: int32[1]
    """
    # --- Part 1: flatten token ids, record batch ownership and prefix sums ---
    ids_remove_padding = np.zeros(token_num, dtype=np.int64)
    batch_id_per_token = np.zeros(token_num, dtype=np.int32)
    cu_seqlens_q = np.zeros(real_bsz + 1, dtype=np.int32)
    cu_seqlens_k = np.zeros(real_bsz + 1, dtype=np.int32)
    offset = 0
    for bid in range(real_bsz):
        count = int(seq_lens[bid])
        # Decode steps (no pending encoder tokens) read from draft_tokens;
        # prefill steps read from input_ids.
        use_draft = max_draft_tokens_per_batch > 0 and seq_lens_encoder[bid] <= 0
        for pos in range(count):
            if use_draft:
                ids_remove_padding[offset + pos] = draft_tokens[bid * max_draft_tokens_per_batch + pos]
            else:
                ids_remove_padding[offset + pos] = input_ids[bid * max_seq_len + pos]
            batch_id_per_token[offset + pos] = bid
        offset += count
        cu_seqlens_q[bid + 1] = offset
        cu_seqlens_k[bid + 1] = offset
    # --- Part 2: per-batch output token counts ---
    seq_lens_output = np.zeros(real_bsz, dtype=np.int32)
    for bid in range(real_bsz):
        count = int(seq_lens[bid])
        if count in (0, 1):
            seq_lens_output[bid] = count
        elif seq_lens_encoder[bid] != 0:
            # Prefill emits exactly one output token.
            seq_lens_output[bid] = 1
        else:
            seq_lens_output[bid] = count
    # --- Part 3: output prefix sums, ownership and total count ---
    cu_seq_lens_q_output = np.zeros(real_bsz + 1, dtype=np.int32)
    batch_id_per_token_output = np.zeros(real_bsz * max_draft_tokens_per_batch, dtype=np.int32)
    out_offset = 0
    for bid in range(real_bsz):
        out_count = int(seq_lens_output[bid])
        batch_id_per_token_output[out_offset : out_offset + out_count] = bid
        out_offset += out_count
        cu_seq_lens_q_output[bid + 1] = out_offset
    real_output_token_num = np.array([out_offset], dtype=np.int32)
    return (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
        seq_lens_output,
        cu_seq_lens_q_output,
        batch_id_per_token_output,
        real_output_token_num,
    )
def build_inputs(
    real_bsz,
    max_seq_len,
    max_draft_tokens,
    seq_lens_list,
    seq_lens_encoder_list,
    draft_tokens_data=None,
    input_ids_data=None,
    seed=42,
):
    """
    Assemble a test-input dict from explicit seq_lens / seq_lens_encoder lists.
    When draft_tokens_data / input_ids_data are omitted, deterministic seeded
    random token ids are generated in their place.
    """
    rng = np.random.default_rng(seed)
    seq_lens = np.asarray(seq_lens_list, dtype=np.int32)
    seq_lens_encoder = np.asarray(seq_lens_encoder_list, dtype=np.int32)

    def _tokens(explicit, cols):
        # One-line purpose: shape explicit data or draw random ids (rng is
        # consumed only when no explicit data is given, preserving draw order).
        if explicit is not None:
            return np.array(explicit, dtype=np.int64).reshape(real_bsz, cols)
        return rng.integers(1, 1000, size=(real_bsz, cols), dtype=np.int64)

    input_ids = _tokens(input_ids_data, max_seq_len)
    draft_tokens = _tokens(draft_tokens_data, max_draft_tokens)
    return {
        "input_ids": input_ids,
        "seq_lens": seq_lens,
        "draft_tokens": draft_tokens,
        "seq_lens_encoder": seq_lens_encoder,
        # Not consumed by the kernel logic; present to satisfy the op signature.
        "seq_lens_decoder": np.zeros(real_bsz, dtype=np.int32),
        "max_seq_len": max_seq_len,
        "max_draft_tokens": max_draft_tokens,
        "token_num": int(np.sum(seq_lens)),
        "real_bsz": real_bsz,
    }
def run_and_compare(tc, inputs):
    """
    Run the XPU op and the Python reference on the same inputs and assert
    that every op output matches.
    tc: unittest.TestCase instance (for assertion messages).
    """
    op_outs = speculate_pre_process(
        inputs["token_num"],
        paddle.to_tensor(inputs["input_ids"], dtype="int64"),
        paddle.to_tensor(inputs["seq_lens"], dtype="int32"),
        paddle.to_tensor(inputs["draft_tokens"], dtype="int64"),
        paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32"),
        paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32"),
    )
    ref_outs = speculate_pre_process_ref(
        input_ids=inputs["input_ids"].reshape(-1),
        seq_lens=inputs["seq_lens"],
        draft_tokens=inputs["draft_tokens"].reshape(-1),
        seq_lens_encoder=inputs["seq_lens_encoder"],
        max_seq_len=inputs["max_seq_len"],
        max_draft_tokens_per_batch=inputs["max_draft_tokens"],
        real_bsz=inputs["real_bsz"],
        token_num=inputs["token_num"],
    )
    # The op returns 7 tensors; the reference returns 8. seq_lens_output
    # (reference index 4) is an intermediate with no op counterpart, so the
    # mapping below skips it: (name, reference index), op index is positional.
    name_and_ref_idx = [
        ("ids_remove_padding", 0),
        ("batch_id_per_token", 1),
        ("cu_seqlens_q", 2),
        ("cu_seqlens_k", 3),
        ("cu_seq_lens_q_output", 5),
        ("batch_id_per_token_output", 6),
        ("real_output_token_num", 7),
    ]
    valid_len = int(ref_outs[7][0])  # real_output_token_num
    for op_idx, (name, ref_idx) in enumerate(name_and_ref_idx):
        actual = op_outs[op_idx].numpy()
        expected = ref_outs[ref_idx]
        if name == "batch_id_per_token_output":
            # Only the first real_output_token_num entries are defined; the
            # kernel leaves the tail of this buffer untouched.
            actual = actual[:valid_len]
            expected = expected[:valid_len]
        np.testing.assert_allclose(
            actual,
            expected,
            err_msg=f"Mismatch in output '{name}'",
        )
class TestSpeculatePreProcess(unittest.TestCase):
    """Unit tests for speculate_pre_process custom operator."""

    # ----------------------------------------------------------------
    # Test 1: mixed batch covering all 4 seq_lens_output branches
    # bid=0: seq_lens=0 => output=0 (skip)
    # bid=1: seq_lens=1, encoder=0 => output=1, read draft_tokens
    # bid=2: seq_lens=5, encoder=3 => output=1, read input_ids (prefill)
    # bid=3: seq_lens=4, encoder=0 => output=4, read draft_tokens (decode)
    # bid=4: seq_lens=1, encoder=2 => output=1, read input_ids (prefill single)
    # bid=5: seq_lens=8, encoder=0 => output=8, read draft_tokens (decode saturated)
    # ----------------------------------------------------------------
    def test_mixed_batch_all_branches(self):
        """Mixed prefill/decode batch; compare every op output to the reference."""
        inputs = build_inputs(
            real_bsz=6,
            max_seq_len=16,
            max_draft_tokens=8,
            seq_lens_list=[0, 1, 5, 4, 1, 8],
            seq_lens_encoder_list=[0, 0, 3, 0, 2, 0],
        )
        run_and_compare(self, inputs)

    # ----------------------------------------------------------------
    # Test 2: token_num=0 early return — verify no crash, 7 outputs
    # ----------------------------------------------------------------
    def test_all_zero_seq_lens(self):
        """All-empty batch: the op must still return 7 usable tensors."""
        real_bsz = 3
        t_input_ids = paddle.zeros([real_bsz, 8], dtype="int64")
        t_seq_lens = paddle.zeros([real_bsz], dtype="int32")
        t_draft_tokens = paddle.zeros([real_bsz, 4], dtype="int64")
        t_seq_lens_encoder = paddle.zeros([real_bsz], dtype="int32")
        t_seq_lens_decoder = paddle.zeros([real_bsz], dtype="int32")
        gpu_outs = speculate_pre_process(
            0, t_input_ids, t_seq_lens, t_draft_tokens, t_seq_lens_encoder, t_seq_lens_decoder
        )
        self.assertEqual(len(gpu_outs), 7)
        # The three trailing outputs (cu_seq_lens_q_output,
        # batch_id_per_token_output, real_output_token_num) must exist even
        # when no tokens were processed.
        self.assertIsNotNone(gpu_outs[-3])
        self.assertIsNotNone(gpu_outs[-2])
        self.assertIsNotNone(gpu_outs[-1])
        # test copy
        fake_cu_seqlens_q_output = paddle.empty([real_bsz + 1], dtype="int32")
        fake_batch_id_per_token_output = paddle.empty([real_bsz], dtype="int32")
        fake_cu_seqlens_q_output.copy_(gpu_outs[-3])
        fake_batch_id_per_token_output.copy_(gpu_outs[-2])
        # test slice
        fake_batch_id_per_token_output[: gpu_outs[-1].item()]

    # ----------------------------------------------------------------
    # Test 3: exact token values — manually verify ids_remove_padding
    # bid=0: encoder=0 (decode) => draft_tokens[0][0:3] = [10,11,12]
    # bid=1: encoder=5 (prefill) => input_ids[1][0:2] = [200,201]
    # ----------------------------------------------------------------
    def test_exact_token_values(self):
        """Hand-computed expectations for the token-gather path."""
        inputs = build_inputs(
            real_bsz=2,
            max_seq_len=4,
            max_draft_tokens=4,
            seq_lens_list=[3, 2],
            seq_lens_encoder_list=[0, 5],
            draft_tokens_data=[[10, 11, 12, 13], [20, 21, 22, 23]],
            input_ids_data=[[100, 101, 102, 103], [200, 201, 202, 203]],
        )
        t_input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
        t_seq_lens = paddle.to_tensor(inputs["seq_lens"], dtype="int32")
        t_draft_tokens = paddle.to_tensor(inputs["draft_tokens"], dtype="int64")
        t_seq_lens_encoder = paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32")
        t_seq_lens_decoder = paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32")
        gpu_outs = speculate_pre_process(
            int(np.sum(inputs["seq_lens"])),
            t_input_ids,
            t_seq_lens,
            t_draft_tokens,
            t_seq_lens_encoder,
            t_seq_lens_decoder,
        )
        np.testing.assert_allclose(gpu_outs[0].numpy(), [10, 11, 12, 200, 201])
        np.testing.assert_allclose(gpu_outs[1].numpy(), [0, 0, 0, 1, 1])
        np.testing.assert_allclose(gpu_outs[2].numpy(), [0, 3, 5])
        np.testing.assert_allclose(gpu_outs[6].numpy(), [4])  # real_output_token_num = 3+1

    # ----------------------------------------------------------------
    # Test 4: random stress test (2 configs covering small & medium batch)
    # ----------------------------------------------------------------
    def test_random_configs(self):
        """Seeded random batches of two sizes, each checked against the reference."""
        configs = [
            {"real_bsz": 7, "max_seq_len": 32, "max_draft_tokens": 8, "seed": 200},
            {"real_bsz": 32, "max_seq_len": 128, "max_draft_tokens": 16, "seed": 400},
        ]
        for cfg in configs:
            with self.subTest(**cfg):
                rng = np.random.default_rng(cfg["seed"])
                real_bsz = cfg["real_bsz"]
                max_draft = cfg["max_draft_tokens"]
                seq_lens_list = rng.integers(0, max_draft + 1, size=real_bsz).tolist()
                seq_lens_encoder_list = rng.integers(0, 3, size=real_bsz).tolist()
                inputs = build_inputs(
                    real_bsz=real_bsz,
                    max_seq_len=cfg["max_seq_len"],
                    max_draft_tokens=max_draft,
                    seq_lens_list=seq_lens_list,
                    seq_lens_encoder_list=seq_lens_encoder_list,
                    seed=cfg["seed"],
                )
                # An all-zero random draw has nothing to compare; skip it.
                if inputs["token_num"] == 0:
                    continue
                run_and_compare(self, inputs)
if __name__ == "__main__":
    # Allow running this file directly: discover and run the TestCase above.
    unittest.main()
@@ -0,0 +1,574 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for unified_update_model_status kernel.
Kernel semantics (from unified_update_model_status.cu):
- Launched as <<<1, 1024>>>, one thread per batch slot (max_bsz <= 1024).
- real_bsz = seq_lens_this_time.shape[0], max_bsz = stop_flags.shape[0].
- has_running_seqs is a CPU tensor (copied to GPU, kernel writes, copied back).
- Padding slots (batch_id >= real_bsz): only counted as stopped, NO state modified.
- Stopped/paused real slots: set stop_flags=true, seq_lens_decoder=0,
seq_lens_this_time=0, step_output_len=0.
- Running slots: EOS detection → state update → token_ids_all write → next input setup.
"""
import unittest
from typing import Any, Dict
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import unified_update_model_status
# NOTE(review): despite the name, this is an XPU device place — presumably the
# name was kept for symmetry with a CUDA version of this test; confirm and
# consider renaming to XPU_PLACE.
CUDA_PLACE = paddle.XPUPlace(0)
CPU_PLACE = paddle.CPUPlace()
# ============================================================
# Layer 1: Helpers — tensor creation / kernel invocation / output extraction
# ============================================================
def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a numpy input dict into paddle tensors.

    Python scalars pass through unchanged; has_running_seqs is materialized on
    the CPU place (the kernel host function copies it to the device and back);
    every other non-None value goes to the device place; None stays None.
    """

    def _convert(key, value):
        # One-line purpose: place-aware conversion of a single entry.
        if isinstance(value, (int, bool, float, str)):
            return value
        if key == "has_running_seqs":
            # Kernel host function: has_running_seqs.copy_to(GPU) → kernel → copy_to(CPU)
            return paddle.to_tensor(value, place=CPU_PLACE)
        if value is not None:
            return paddle.to_tensor(value, place=CUDA_PLACE)
        return None

    return {key: _convert(key, value) for key, value in inputs.items()}
def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]):
    """Call unified_update_model_status kernel.

    All tensor arguments are mutated in place by the op; scalar configs come
    from the raw `inputs` dict. Positional order must match the op's
    registered signature exactly.
    NOTE(review): mask_rollback is passed here after is_paused, while the
    C++ plugin wrapper lists it right after has_running_seqs — presumably
    the custom-op host function reorders before calling the plugin; confirm
    against the op registration.
    """
    unified_update_model_status(
        paddle_inputs["seq_lens_encoder"],
        paddle_inputs["seq_lens_decoder"],
        paddle_inputs["has_running_seqs"],
        paddle_inputs["step_input_ids"],
        paddle_inputs["adaptive_step_input_len"],
        paddle_inputs["step_output_ids"],
        paddle_inputs["step_output_len"],
        paddle_inputs["stop_flags"],
        paddle_inputs["seq_lens_this_time"],
        paddle_inputs["is_paused"],
        paddle_inputs["mask_rollback"],
        paddle_inputs["token_ids_all"],
        paddle_inputs["prompt_lens"],
        paddle_inputs["step_idx"],
        paddle_inputs["end_tokens"],
        paddle_inputs["max_dec_len"],
        inputs["is_naive_mode"],
        inputs["prefill_one_step_stop"],
    )
# The in-place output keys tests must read back (SetInplaceMap in the .cu
# lists 12 entries; adaptive_step_input_len is in the map but the kernel
# never writes it, so it is intentionally excluded here).
OUTPUT_KEYS = [
    "seq_lens_encoder",
    "seq_lens_decoder",
    "has_running_seqs",
    "step_input_ids",
    "step_output_ids",
    "step_output_len",
    "stop_flags",
    "seq_lens_this_time",
    "mask_rollback",
    "token_ids_all",
    "step_idx",
]


def get_outputs(paddle_inputs: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Pull every in-place-modified tensor back to host numpy arrays."""
    extracted = {}
    for key in OUTPUT_KEYS:
        extracted[key] = paddle_inputs[key].numpy()
    return extracted
# ============================================================
# Layer 2: Input generation
# ============================================================
def gen_inputs(
    real_bsz: int = 8,
    max_step_tokens: int = 16,
    max_model_len: int = 256,
    seed: int = 42,
    is_naive_mode: bool = False,
    prefill_one_step_stop: bool = False,
) -> Dict[str, Any]:
    """Build seeded random inputs for the unified_update_model_status kernel.

    Shape contract:
      - real_bsz = seq_lens_this_time.shape[0]
      - max_bsz  = stop_flags.shape[0] = real_bsz + 4 padding slots
      - is_paused.shape[0] = max_bsz
    The rng draw order below is fixed so a given seed always produces the
    same batch.
    """
    rand = np.random.default_rng(seed)
    max_bsz = real_bsz + 4  # four trailing padding slots

    # Per-slot arrays (size=max_bsz), drawn in a fixed order.
    seq_lens_encoder = rand.integers(0, 5, size=max_bsz, dtype=np.int32)
    seq_lens_decoder = rand.integers(10, 100, size=max_bsz, dtype=np.int32)
    step_input_ids = rand.integers(0, 1000, size=(max_bsz, max_step_tokens), dtype=np.int64)
    adaptive_step_input_len = rand.integers(1, max_step_tokens + 1, size=max_bsz, dtype=np.int32)
    step_output_ids = rand.integers(0, 1000, size=(max_bsz, max_step_tokens), dtype=np.int64)
    step_output_len = rand.integers(1, max_step_tokens + 1, size=max_bsz, dtype=np.int32)

    stop_flags = np.zeros(max_bsz, dtype=bool)
    # Pre-stop a couple of real slots at random...
    stop_flags[rand.choice(real_bsz, size=min(2, real_bsz), replace=False)] = True
    # ...and every padding slot: the kernel indexes seq_lens_this_time
    # (sized real_bsz only) for running slots.
    stop_flags[real_bsz:] = True

    is_paused = np.zeros(max_bsz, dtype=bool)
    mask_rollback = np.zeros(max_bsz, dtype=np.int32)
    prompt_lens = rand.integers(10, 50, size=max_bsz, dtype=np.int64)
    token_ids_all = rand.integers(0, 1000, size=(max_bsz, max_model_len), dtype=np.int64)
    step_idx = rand.integers(0, 50, size=max_bsz, dtype=np.int64)
    max_dec_len = rand.integers(100, 200, size=max_bsz, dtype=np.int64)

    # Per-real-batch array (size=real_bsz).
    seq_lens_this_time = rand.integers(1, max_step_tokens + 1, size=real_bsz, dtype=np.int32)

    # Scalar / small tensors.
    has_running_seqs = np.array([True], dtype=bool)
    end_tokens = rand.integers(1, 1000, size=4, dtype=np.int64)

    return {
        "seq_lens_encoder": seq_lens_encoder,
        "seq_lens_decoder": seq_lens_decoder,
        "has_running_seqs": has_running_seqs,
        "step_input_ids": step_input_ids,
        "adaptive_step_input_len": adaptive_step_input_len,
        "step_output_ids": step_output_ids,
        "step_output_len": step_output_len,
        "stop_flags": stop_flags,
        "seq_lens_this_time": seq_lens_this_time,
        "is_paused": is_paused,
        "mask_rollback": mask_rollback,
        "token_ids_all": token_ids_all,
        "prompt_lens": prompt_lens,
        "step_idx": step_idx,
        "end_tokens": end_tokens,
        "max_dec_len": max_dec_len,
        # Scalar configs
        "real_bsz": real_bsz,
        "max_bsz": max_bsz,
        "max_step_tokens": max_step_tokens,
        "max_model_len": max_model_len,
        "is_naive_mode": is_naive_mode,
        "prefill_one_step_stop": prefill_one_step_stop,
    }
# ============================================================
# Layer 3: Reference implementation (1:1 with CUDA kernel)
# ============================================================
def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Python reference of unified_update_model_status_kernel.
    Line references are to unified_update_model_status.cu.

    The function mirrors the kernel statement-for-statement: one
    iteration of the outer loop below corresponds to one ``batch_id``
    slot processed on the device. All tensors the kernel updates
    in-place are deep-copied first, so the caller's ``inputs`` dict is
    never mutated and can be reused for the device run.

    Args:
        inputs: Dict produced by ``gen_inputs`` holding numpy arrays
            plus the scalar configs (``real_bsz``, ``max_bsz``,
            ``max_model_len``, ``is_naive_mode``,
            ``prefill_one_step_stop``).

    Returns:
        Dict with the expected post-kernel value of every output
        tensor (the test suite compares these via OUTPUT_KEYS).
    """
    # Deep-copy all mutable in-place tensors
    seq_lens_encoder = inputs["seq_lens_encoder"].copy()
    seq_lens_decoder = inputs["seq_lens_decoder"].copy()
    step_output_len = inputs["step_output_len"].copy()
    stop_flags = inputs["stop_flags"].copy()
    seq_lens_this_time = inputs["seq_lens_this_time"].copy()
    mask_rollback = inputs["mask_rollback"].copy()
    token_ids_all = inputs["token_ids_all"].copy()
    step_idx = inputs["step_idx"].copy()
    step_input_ids = inputs["step_input_ids"].copy()
    step_output_ids = inputs["step_output_ids"].copy()
    # Read-only inputs
    real_bsz = inputs["real_bsz"]
    max_bsz = inputs["max_bsz"]
    max_model_len = inputs["max_model_len"]
    is_naive_mode = inputs["is_naive_mode"]
    prefill_one_step_stop = inputs["prefill_one_step_stop"]
    end_tokens = inputs["end_tokens"]
    num_end_tokens = len(end_tokens)
    max_dec_len = inputs["max_dec_len"]
    prompt_lens = inputs["prompt_lens"]
    is_paused = inputs["is_paused"]
    # Block-level stop count for has_running_seqs reduction (line 175).
    # Every slot that is (or becomes) stopped/paused/padding adds 1;
    # at the end, "any slot still running" <=> stop_count < max_bsz.
    stop_count = 0
    for batch_id in range(max_bsz):
        # --- line 68-75: Read state ---
        cur_seq_len_encoder = int(seq_lens_encoder[batch_id])
        cur_seq_len_decoder = int(seq_lens_decoder[batch_id])
        cur_stop_flag = bool(stop_flags[batch_id])
        output_len = 0
        cur_step_idx = int(step_idx[batch_id])
        cur_is_paused = bool(is_paused[batch_id])
        # line 77
        is_running = not cur_stop_flag and not cur_is_paused
        # --- line 80-86: Compute output length ---
        # naive mode always emits exactly one token per running slot
        if is_running:
            output_len = 1 if is_naive_mode else int(step_output_len[batch_id])
        # --- line 89-110: EOS detection ---
        if is_running and output_len > 0:
            hit_stop = False
            for i in range(output_len):
                cur_step_idx += 1  # line 94
                token = int(step_output_ids[batch_id, i])  # line 95
                is_eos = any(token == end_tokens[j] for j in range(num_end_tokens))  # line 96
                max_len_hit = cur_step_idx >= int(max_dec_len[batch_id])  # line 97
                if is_eos or max_len_hit:  # line 99
                    if not is_eos:
                        # length cap hit mid-step: overwrite with EOS
                        step_output_ids[batch_id, i] = end_tokens[0]  # line 100
                    output_len = i + 1  # line 101
                    cur_stop_flag = True  # line 102
                    hit_stop = True  # line 103
                    break  # line 104
            # line 108-110
            if not hit_stop and prefill_one_step_stop and cur_seq_len_encoder > 0:
                cur_stop_flag = True
        # --- line 114-166: Update state and write back ---
        if is_running:
            if cur_stop_flag:
                # line 115-119
                stop_count += 1
                if output_len == 0:
                    cur_seq_len_decoder = 0  # line 117
                stop_flags[batch_id] = True  # line 118
                mask_rollback[batch_id] = 0  # line 119
            elif cur_seq_len_encoder == 0:
                # line 120-122
                cur_seq_len_decoder += output_len  # line 121
                mask_rollback[batch_id] = int(seq_lens_this_time[batch_id]) - output_len  # line 122
            else:
                # line 123-124 (encoder > 0, not stopped)
                mask_rollback[batch_id] = 0
            # line 127-130: Fold encoder into decoder
            if cur_seq_len_encoder > 0:
                cur_seq_len_decoder += cur_seq_len_encoder  # line 128
                cur_seq_len_encoder = 0  # line 129
            # line 132-135: Write back scalar state
            seq_lens_encoder[batch_id] = cur_seq_len_encoder
            seq_lens_decoder[batch_id] = cur_seq_len_decoder
            step_output_len[batch_id] = output_len
            step_idx[batch_id] = cur_step_idx
            # line 138-145: Write history to token_ids_all
            if cur_step_idx > 0 and output_len > 0:
                base = int(prompt_lens[batch_id])
                for i in range(output_len):
                    # token_ids_all_now[cur_step_idx - i] = output_ids[output_len - 1 - i]
                    write_idx = base + cur_step_idx - i
                    # NOTE(review): the CUDA kernel does not bounds-check this
                    # write (see test_max_dec_len_stop); the guard below only
                    # protects the reference from overflowing its own array.
                    if 0 <= write_idx < max_model_len:
                        token_ids_all[batch_id, write_idx] = step_output_ids[batch_id, output_len - 1 - i]
            # line 148-151: Setup next step_input_ids
            if output_len > 0:
                step_input_ids[batch_id, 0] = step_output_ids[batch_id, output_len - 1]
            # line 153-155: naive_mode → seq_lens_this_time
            if is_naive_mode:
                seq_lens_this_time[batch_id] = 0 if cur_stop_flag else 1
        elif batch_id >= real_bsz:
            # line 156-158: Padding slot — only count, don't modify state
            stop_count += 1
        else:
            # line 159-166: Stopped or paused real slot
            stop_count += 1
            stop_flags[batch_id] = True  # line 162
            seq_lens_decoder[batch_id] = 0  # line 163
            seq_lens_this_time[batch_id] = 0  # line 164
            step_output_len[batch_id] = 0  # line 165
    # line 177-179: has_running_seqs = stop_sum < max_bsz
    has_running_seqs = np.array([stop_count < max_bsz], dtype=bool)
    return {
        "seq_lens_encoder": seq_lens_encoder,
        "seq_lens_decoder": seq_lens_decoder,
        "has_running_seqs": has_running_seqs,
        "step_input_ids": step_input_ids,
        "step_output_ids": step_output_ids,
        "step_output_len": step_output_len,
        "stop_flags": stop_flags,
        "seq_lens_this_time": seq_lens_this_time,
        "mask_rollback": mask_rollback,
        "token_ids_all": token_ids_all,
        "step_idx": step_idx,
    }
# ============================================================
# Layer 4a: TEST_CONFIGS
# ============================================================
# Each config is a kwargs dict for gen_inputs (plus a "name" used only
# for subTest labelling). Knobs: real_bsz (active slots), max_step_tokens
# (per-step output capacity), max_model_len (token_ids_all width), seed
# (RNG), is_naive_mode (1 token/step vs MTP multi-token), and optionally
# prefill_one_step_stop.
TEST_CONFIGS = [
    # --- basic mode coverage ---
    {
        "name": "mtp_mode",
        "real_bsz": 8,
        "max_step_tokens": 16,
        "max_model_len": 256,
        "seed": 42,
        "is_naive_mode": False,
    },
    {
        "name": "naive_mode",
        "real_bsz": 8,
        "max_step_tokens": 16,
        "max_model_len": 256,
        "seed": 42,
        "is_naive_mode": True,
    },
    # --- batch size ---
    {
        "name": "small_batch",
        "real_bsz": 1,
        "max_step_tokens": 8,
        "max_model_len": 128,
        "seed": 42,
        "is_naive_mode": False,
    },
    {
        "name": "large_batch",
        "real_bsz": 32,
        "max_step_tokens": 16,
        "max_model_len": 512,
        "seed": 42,
        "is_naive_mode": False,
    },
    # --- prefill_one_step_stop ---
    {
        "name": "prefill_one_step_stop",
        "real_bsz": 8,
        "max_step_tokens": 8,
        "max_model_len": 128,
        "seed": 42,
        "is_naive_mode": False,
        "prefill_one_step_stop": True,
    },
    # --- different seeds for randomized coverage ---
    {
        "name": "seed_100",
        "real_bsz": 8,
        "max_step_tokens": 16,
        "max_model_len": 256,
        "seed": 100,
        "is_naive_mode": False,
    },
    {
        "name": "seed_200_naive",
        "real_bsz": 8,
        "max_step_tokens": 16,
        "max_model_len": 256,
        "seed": 200,
        "is_naive_mode": True,
    },
]
# ============================================================
# Layer 4b: Test suite
# ============================================================
class TestUnifiedUpdateModelStatus(unittest.TestCase):
    """Runs the device kernel and checks it against reference_impl."""

    def setUp(self):
        # The kernel under test only exists in XPU builds.
        if not paddle.is_compiled_with_xpu():
            self.skipTest("Requires XPU")

    # ------ shared helpers ------
    def _run_and_get(self, inputs: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """Convert inputs to paddle, launch the kernel, pull results back."""
        device_inputs = to_paddle_inputs(inputs)
        run_kernel(device_inputs, inputs)
        return get_outputs(device_inputs)

    def _check_all_outputs(self, inputs: Dict[str, Any], outputs: Dict[str, np.ndarray]):
        """Compare ALL output tensors against reference + sanity checks."""
        ref = reference_impl(inputs)
        for key in OUTPUT_KEYS:
            if not np.array_equal(outputs[key], ref[key]):
                # Dump up to 10 mismatching positions before failing.
                mismatch_positions = np.argwhere(outputs[key] != ref[key])
                for idx in mismatch_positions[:10]:
                    idx_tuple = tuple(idx)
                    print(
                        f" [{key}] mismatch at {idx_tuple}: "
                        f"gpu={outputs[key][idx_tuple]} ref={ref[key][idx_tuple]}"
                    )
                    if key == "token_ids_all":
                        # Extra per-slot context for history-write mismatches.
                        bid = idx_tuple[0]
                        print(
                            f" batch_id={bid}, prompt_lens={inputs['prompt_lens'][bid]}, "
                            f"step_idx(input)={inputs['step_idx'][bid]}, "
                            f"step_idx(gpu)={outputs['step_idx'][bid]}, "
                            f"step_idx(ref)={ref['step_idx'][bid]}, "
                            f"step_output_len(gpu)={outputs['step_output_len'][bid]}, "
                            f"step_output_len(ref)={ref['step_output_len'][bid]}, "
                            f"stop_flags(input)={inputs['stop_flags'][bid]}, "
                            f"is_paused={inputs['is_paused'][bid]}, "
                            f"seq_lens_encoder={inputs['seq_lens_encoder'][bid]}"
                        )
            np.testing.assert_array_equal(outputs[key], ref[key], err_msg=f"{key} mismatch")
        # Sanity: running slots must have encoder zeroed
        for i in range(inputs["real_bsz"]):
            was_running = not inputs["stop_flags"][i] and not inputs["is_paused"][i]
            if was_running:
                self.assertEqual(outputs["seq_lens_encoder"][i], 0, f"Running slot {i} should have encoder=0")
        self.assertTrue(bool((outputs["seq_lens_decoder"] >= 0).all()), "negative seq_lens_decoder")
        self.assertTrue(bool((outputs["step_output_len"] >= 0).all()), "negative step_output_len")
        self.assertTrue(bool((outputs["step_idx"] >= 0).all()), "negative step_idx")

    def _run_full_test(self, config: Dict[str, Any]) -> Dict[str, np.ndarray]:
        case_inputs = gen_inputs(**config)
        case_outputs = self._run_and_get(case_inputs)
        self._check_all_outputs(case_inputs, case_outputs)
        return case_outputs

    # ------ test cases ------
    def test_configs(self):
        """Run all TEST_CONFIGS via subTest."""
        for cfg in TEST_CONFIGS:
            with self.subTest(name=cfg["name"]):
                self._run_full_test({k: v for k, v in cfg.items() if k != "name"})

    def test_eos_detection(self):
        """EOS token at position 2 should truncate output_len to 3."""
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["step_output_ids"][0, 2] = int(inputs["end_tokens"][0])
        inputs["step_output_len"][:] = [5, 3, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        self._check_all_outputs(inputs, self._run_and_get(inputs))

    def test_max_dec_len_stop(self):
        """step_idx near max_dec_len should trigger stop and replace with end_tokens[0]."""
        # Large max_model_len keeps token_ids_all writes in range: the
        # kernel doesn't bounds-check prompt_lens + step_idx < max_model_len.
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=512, seed=42)
        inputs["max_dec_len"][:] = 100
        inputs["step_idx"][:] = [95, 50, 0, 0, 0, 0]
        inputs["step_output_len"][:] = [10, 5, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        self._check_all_outputs(inputs, self._run_and_get(inputs))

    def test_paused_slots(self):
        """Paused slots should be treated as stopped/paused (decoder=0, output_len=0)."""
        inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = [True, True, False, False, False, False, False, False]
        self._check_all_outputs(inputs, self._run_and_get(inputs))

    def test_all_stopped(self):
        """All slots stopped → has_running_seqs should be False."""
        inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["stop_flags"][:] = True
        self._check_all_outputs(inputs, self._run_and_get(inputs))

    def test_encoder_to_decoder(self):
        """Encoder length should fold into decoder: decoder += encoder, encoder → 0."""
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["seq_lens_decoder"][:] = [20, 30, 0, 0, 0, 0]
        inputs["seq_lens_encoder"][:] = [10, 0, 0, 0, 0, 0]
        inputs["step_output_len"][:] = [5, 3, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        self._check_all_outputs(inputs, self._run_and_get(inputs))

    def test_token_ids_all_writing(self):
        """token_ids_all should be written at prompt_lens + step_idx positions."""
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["seq_lens_encoder"][:] = 0
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        inputs["prompt_lens"][:] = [5, 5, 0, 0, 0, 0]
        inputs["step_idx"][:] = [10, 20, 0, 0, 0, 0]
        inputs["step_output_len"][:] = [3, 2, 0, 0, 0, 0]
        # Keep end_tokens disjoint from the planted output tokens so no
        # accidental EOS truncates the history write.
        inputs["end_tokens"][:] = [9990, 9991, 9992, 9993]
        inputs["max_dec_len"][:] = 10000
        inputs["step_output_ids"][0, :3] = [100, 200, 300]
        inputs["step_output_ids"][1, :2] = [400, 500]
        self._check_all_outputs(inputs, self._run_and_get(inputs))

    def test_zero_output_len(self):
        """Running slot with output_len=0 in MTP mode: output_len stays 0."""
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        inputs["step_output_len"][:] = [0, 5, 0, 0, 0, 0]
        self._check_all_outputs(inputs, self._run_and_get(inputs))

    def test_prefill_one_step_stop_with_encoder(self):
        """prefill_one_step_stop + encoder>0 should stop even without EOS."""
        inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42, prefill_one_step_stop=True)
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        inputs["seq_lens_encoder"][:] = [5, 0, 0, 0, 0, 0, 0, 0]
        # Ensure no accidental EOS hit
        inputs["end_tokens"][:] = [9990, 9991, 9992, 9993]
        inputs["max_dec_len"][:] = 10000
        self._check_all_outputs(inputs, self._run_and_get(inputs))

    def test_mask_rollback(self):
        """mask_rollback = seq_lens_this_time - output_len for running decode slots."""
        inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["seq_lens_encoder"][:] = 0  # All decode slots
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        inputs["seq_lens_this_time"][:] = [6, 4, 8, 3]
        inputs["step_output_len"][:] = [3, 2, 5, 1, 0, 0, 0, 0]
        # Avoid EOS/max_dec_len hits
        inputs["end_tokens"][:] = [9990, 9991, 9992, 9993]
        inputs["max_dec_len"][:] = 10000
        self._check_all_outputs(inputs, self._run_and_get(inputs))
if __name__ == "__main__":
    # Allow running this test file directly (outside a test runner).
    unittest.main()
@@ -273,6 +273,8 @@ class XPUForwardMeta(ForwardMeta):
hidden_states: Optional[paddle.Tensor] = None hidden_states: Optional[paddle.Tensor] = None
is_draft: bool = False is_draft: bool = False
# max bs
max_num_seqs: int = 0
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None): def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
""" """
@@ -196,7 +196,6 @@ class XPUAttentionBackend(AttentionBackend):
qkv, qkv,
forward_meta.caches[2 * layer.layer_id], forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1], forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.cum_offsets,
metadata.rotary_embs, metadata.rotary_embs,
metadata.block_tables, metadata.block_tables,
forward_meta.prefix_block_tables, forward_meta.prefix_block_tables,
@@ -1069,8 +1069,8 @@ class SpeculativeSampler(nn.Layer):
sampling_metadata.min_dec_lens, sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids, sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"], share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"], share_inputs["batch_id_per_token_output"],
share_inputs["output_cum_offsets"], share_inputs["cu_seqlens_q_output"],
max_model_len, max_model_len,
sampling_metadata.pre_token_ids, sampling_metadata.pre_token_ids,
) )
@@ -1091,7 +1091,7 @@ class SpeculativeSampler(nn.Layer):
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs, probs,
sampling_metadata.top_p, sampling_metadata.top_p,
share_inputs["output_padding_offset"], share_inputs["batch_id_per_token_output"],
self.speculative_max_candidate_len, self.speculative_max_candidate_len,
max_model_len, max_model_len,
) )
@@ -1113,7 +1113,7 @@ class SpeculativeSampler(nn.Layer):
share_inputs["max_dec_len"], share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids, sampling_metadata.eos_token_ids,
share_inputs["is_block_step"], share_inputs["is_block_step"],
share_inputs["output_cum_offsets"], share_inputs["cu_seqlens_q_output"],
actual_candidate_len, actual_candidate_len,
share_inputs["actual_draft_token_num"], share_inputs["actual_draft_token_num"],
sampling_metadata.top_p, sampling_metadata.top_p,
@@ -1338,8 +1338,8 @@ class MTPSampler(nn.Layer):
sampling_metadata.min_dec_lens, sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids, sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"], share_inputs["seq_lens_this_time"],
share_inputs["output_padding_offset"], share_inputs["batch_id_per_token_output"],
share_inputs["output_cum_offsets"], share_inputs["cu_seqlens_q_output"],
max_model_len, max_model_len,
sampling_metadata.pre_token_ids, sampling_metadata.pre_token_ids,
) )
@@ -40,9 +40,7 @@ if current_platform.is_xpu():
save_output_topk, save_output_topk,
set_stop_value_multi_ends, set_stop_value_multi_ends,
speculate_clear_accept_nums, speculate_clear_accept_nums,
speculate_get_output_padding_offset, speculate_pre_process,
speculate_get_padding_offset,
speculate_get_seq_lens_output,
speculate_save_output, speculate_save_output,
speculate_set_stop_value_multi_seqs, speculate_set_stop_value_multi_seqs,
speculate_set_value_by_flags_and_idx, speculate_set_value_by_flags_and_idx,
@@ -109,51 +107,32 @@ def xpu_pre_process(
) -> XPUForwardMeta: ) -> XPUForwardMeta:
""" """ """ """
max_len = input_ids.shape[1] max_len = input_ids.shape[1]
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
token_num = paddle.sum(seq_lens_this_time)
token_num_cpu = paddle.sum(seq_lens_this_time).cpu()
if use_speculate_method: if use_speculate_method:
( (
ids_remove_padding, ids_remove_padding,
cum_offsets,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
) = speculate_get_padding_offset( cu_seqlens_q_output,
input_ids, batch_id_per_token_output,
draft_tokens, real_output_token_num,
cum_offsets_now, ) = speculate_pre_process(
token_num, token_num_cpu, input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, seq_lens_decoder
seq_lens_this_time,
seq_lens_encoder,
) )
seq_lens_output = speculate_get_seq_lens_output( share_inputs["cu_seqlens_q_output"] = cu_seqlens_q_output
seq_lens_this_time, share_inputs["batch_id_per_token_output"] = batch_id_per_token_output
seq_lens_encoder,
seq_lens_decoder,
)
if isinstance(seq_lens_output, list):
seq_lens_output = seq_lens_output[0]
output_token_num = paddle.sum(seq_lens_output)
output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32")
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
output_cum_offsets_tmp,
output_token_num,
seq_lens_output,
max_len,
)
share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
else: else:
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
( (
ids_remove_padding, ids_remove_padding,
cum_offsets, cum_offsets,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) ) = get_padding_offset(input_ids, cum_offsets_now, token_num_cpu, seq_lens_this_time)
share_inputs["cum_offsets"] = cum_offsets
share_inputs["batch_id_per_token"] = batch_id_per_token share_inputs["batch_id_per_token"] = batch_id_per_token
share_inputs["cu_seqlens_q"] = cu_seqlens_q share_inputs["cu_seqlens_q"] = cu_seqlens_q
share_inputs["cu_seqlens_k"] = cu_seqlens_k share_inputs["cu_seqlens_k"] = cu_seqlens_k
@@ -165,12 +144,12 @@ def xpu_pre_process(
seq_lens_encoder=share_inputs["seq_lens_encoder"], seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"], seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"], seq_lens_this_time=share_inputs["seq_lens_this_time"],
cum_offsets=share_inputs["cum_offsets"],
batch_id_per_token=share_inputs["batch_id_per_token"], batch_id_per_token=share_inputs["batch_id_per_token"],
cu_seqlens_q=share_inputs["cu_seqlens_q"], cu_seqlens_q=share_inputs["cu_seqlens_q"],
cu_seqlens_k=share_inputs["cu_seqlens_k"], cu_seqlens_k=share_inputs["cu_seqlens_k"],
block_tables=share_inputs["block_tables"], block_tables=share_inputs["block_tables"],
caches=share_inputs["caches"], caches=share_inputs["caches"],
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
) )
( (
@@ -205,7 +184,6 @@ def xpu_pre_process(
adjusted_input = adjust_batch( adjusted_input = adjust_batch(
ids_remove_padding.reshape([-1, 1]), ids_remove_padding.reshape([-1, 1]),
cum_offsets,
xpu_forward_meta.encoder_seq_lod, xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod, xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_batch_idx, xpu_forward_meta.encoder_batch_idx,
@@ -237,7 +215,6 @@ def xpu_pre_process(
def xpu_process_output( def xpu_process_output(
forward_output, forward_output,
cum_offsets: paddle.Tensor,
xpu_forward_meta: XPUForwardMeta, xpu_forward_meta: XPUForwardMeta,
share_inputs, share_inputs,
) -> paddle.Tensor: ) -> paddle.Tensor:
@@ -250,7 +227,6 @@ def xpu_process_output(
hiddden_states = gather_next_token( hiddden_states = gather_next_token(
forward_output, forward_output,
cum_offsets,
xpu_forward_meta.encoder_seq_lod, xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod, xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_batch_map, xpu_forward_meta.encoder_batch_map,
@@ -261,7 +237,7 @@ def xpu_process_output(
xpu_forward_meta.decoder_batch_map_cpu, xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.len_info_cpu, xpu_forward_meta.len_info_cpu,
output_padding_offset, # output_padding_offset output_padding_offset, # output_padding_offset
-1, # max_input_length xpu_forward_meta.max_num_seqs,
) )
return hiddden_states return hiddden_states
+2 -8
View File
@@ -820,11 +820,7 @@ class MTPProposer(Proposer):
# Note(ZKK): # Note(ZKK):
# I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend # I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend
# like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358 # like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358
( self.model_inputs["cu_seqlens_q_output"],
self.model_inputs["cu_seqlens_q_output"]
if current_platform.is_cuda()
else self.model_inputs["output_cum_offsets"]
),
self.model_inputs["stop_flags"], self.model_inputs["stop_flags"],
( (
self.model_inputs["not_need_stop_device"] self.model_inputs["not_need_stop_device"]
@@ -1125,9 +1121,7 @@ class MTPProposer(Proposer):
previous_hidden_states=self.model_inputs["target_hidden_states"], previous_hidden_states=self.model_inputs["target_hidden_states"],
forward_meta=self.forward_meta, forward_meta=self.forward_meta,
) )
hidden_states = xpu_process_output( hidden_states = xpu_process_output(model_output, self.forward_meta, self.model_inputs)
model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs
)
# 4. Compute logits, Sample # 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
sampled_token_ids, sampler_output = self.sampler( sampled_token_ids, sampler_output = self.sampler(
+3 -3
View File
@@ -1114,6 +1114,8 @@ class XPUModelRunner(ModelRunnerBase):
self.cache_config.block_size, self.cache_config.block_size,
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0, self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
) )
# TODO(chenhuan): support cached_token_num
self.forward_meta = xpu_pre_process( self.forward_meta = xpu_pre_process(
self.share_inputs["input_ids"], self.share_inputs["input_ids"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_this_time"],
@@ -1577,9 +1579,7 @@ class XPUModelRunner(ModelRunnerBase):
) )
if self.use_cudagraph: if self.use_cudagraph:
model_output = model_output[: self.real_token_num] model_output = model_output[: self.real_token_num]
hidden_states = xpu_process_output( hidden_states = xpu_process_output(model_output, self.forward_meta, self.share_inputs)
model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
)
# 4. Compute logits, Sample # 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states) logits = self.model.compute_logits(hidden_states)
sampler_output = None sampler_output = None
+2
View File
@@ -346,6 +346,8 @@ def test_mtp_sampler_xpu_and_compute(mock_ops, monkeypatch):
"batch_token_num": paddle.to_tensor([[1]], dtype="int64"), "batch_token_num": paddle.to_tensor([[1]], dtype="int64"),
"output_padding_offset": paddle.zeros([1, 1], dtype="int64"), "output_padding_offset": paddle.zeros([1, 1], dtype="int64"),
"output_cum_offsets": paddle.zeros([1, 1], dtype="int64"), "output_cum_offsets": paddle.zeros([1, 1], dtype="int64"),
"batch_id_per_token_output": paddle.to_tensor([0], dtype="int32"),
"cu_seqlens_q_output": paddle.to_tensor([0, 1], dtype="int32"),
} }
monkeypatch.setattr( monkeypatch.setattr(
"fastdeploy.model_executor.layers.sample.sampler.top_k_top_p_sampling", "fastdeploy.model_executor.layers.sample.sampler.top_k_top_p_sampling",