mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[XPU] Refactor pre process (#6993)
* [XPU] support speculate_pre_process * merge develop * fix codestype * fix mtp, support cu_seqlens_q_output * fix mtp, support cu_seqlens_q_output * fix test --------- Co-authored-by: lizan1999 <lizan03@baidu.com>
This commit is contained in:
@@ -58,9 +58,8 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
const int bsz = seq_len.shape()[0];
|
const int bsz = seq_len.shape()[0];
|
||||||
const int seq_length = input_ids_shape[1];
|
const int seq_length = input_ids_shape[1];
|
||||||
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
|
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
|
||||||
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
// token num is cpu tensor
|
||||||
|
const int token_num_data = token_num.data<int64_t>()[0];
|
||||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
|
||||||
auto x_remove_padding = paddle::empty(
|
auto x_remove_padding = paddle::empty(
|
||||||
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||||
auto padding_offset = paddle::empty(
|
auto padding_offset = paddle::empty(
|
||||||
|
|||||||
@@ -24,8 +24,7 @@
|
|||||||
|
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
std::vector<paddle::Tensor> AdjustBatchKernel(
|
std::vector<paddle::Tensor> AdjustBatchKernel(
|
||||||
const paddle::Tensor &x, // [token_num, dim_embed]
|
const paddle::Tensor &x, // [token_num, dim_embed]
|
||||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
|
||||||
const paddle::Tensor &encoder_seq_lod,
|
const paddle::Tensor &encoder_seq_lod,
|
||||||
const paddle::Tensor &decoder_seq_lod,
|
const paddle::Tensor &decoder_seq_lod,
|
||||||
const paddle::Tensor &encoder_batch_idx,
|
const paddle::Tensor &encoder_batch_idx,
|
||||||
@@ -49,7 +48,6 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
|
|||||||
using data_t = typename PDTraits<T>::data_t;
|
using data_t = typename PDTraits<T>::data_t;
|
||||||
const int token_num = x.dims()[0];
|
const int token_num = x.dims()[0];
|
||||||
const int dim = x.dims()[1];
|
const int dim = x.dims()[1];
|
||||||
const int bsz = cum_offsets.shape()[0];
|
|
||||||
int enc_batch = len_info_cpu.data<int32_t>()[0];
|
int enc_batch = len_info_cpu.data<int32_t>()[0];
|
||||||
int dec_batch = len_info_cpu.data<int32_t>()[1];
|
int dec_batch = len_info_cpu.data<int32_t>()[1];
|
||||||
|
|
||||||
@@ -87,8 +85,7 @@ std::vector<paddle::Tensor> AdjustBatchKernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
|
using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
|
||||||
const paddle::Tensor &x, // [token_num, dim_embed]
|
const paddle::Tensor &x, // [token_num, dim_embed]
|
||||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
|
||||||
const paddle::Tensor &encoder_seq_lod,
|
const paddle::Tensor &encoder_seq_lod,
|
||||||
const paddle::Tensor &decoder_seq_lod,
|
const paddle::Tensor &decoder_seq_lod,
|
||||||
const paddle::Tensor &encoder_batch_idx,
|
const paddle::Tensor &encoder_batch_idx,
|
||||||
@@ -102,8 +99,7 @@ using AdjustBatchKernelFuncPtr = std::vector<paddle::Tensor> (*)(
|
|||||||
int max_input_length);
|
int max_input_length);
|
||||||
|
|
||||||
std::vector<paddle::Tensor> AdjustBatch(
|
std::vector<paddle::Tensor> AdjustBatch(
|
||||||
const paddle::Tensor &x, // [token_num, dim_embed]
|
const paddle::Tensor &x, // [token_num, dim_embed]
|
||||||
const paddle::Tensor &cum_offsets, // [bsz, 1]
|
|
||||||
const paddle::Tensor &encoder_seq_lod,
|
const paddle::Tensor &encoder_seq_lod,
|
||||||
const paddle::Tensor &decoder_seq_lod,
|
const paddle::Tensor &decoder_seq_lod,
|
||||||
const paddle::Tensor &encoder_batch_idx,
|
const paddle::Tensor &encoder_batch_idx,
|
||||||
@@ -135,7 +131,6 @@ std::vector<paddle::Tensor> AdjustBatch(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(x,
|
return func(x,
|
||||||
cum_offsets,
|
|
||||||
encoder_seq_lod,
|
encoder_seq_lod,
|
||||||
decoder_seq_lod,
|
decoder_seq_lod,
|
||||||
encoder_batch_idx,
|
encoder_batch_idx,
|
||||||
@@ -151,7 +146,6 @@ std::vector<paddle::Tensor> AdjustBatch(
|
|||||||
|
|
||||||
std::vector<std::vector<int64_t>> AdjustBatchInferShape(
|
std::vector<std::vector<int64_t>> AdjustBatchInferShape(
|
||||||
const std::vector<int64_t> &x_shape,
|
const std::vector<int64_t> &x_shape,
|
||||||
const std::vector<int64_t> &cum_offsets_shape,
|
|
||||||
const std::vector<int64_t> &encoder_seq_lod_shape,
|
const std::vector<int64_t> &encoder_seq_lod_shape,
|
||||||
const std::vector<int64_t> &decoder_seq_lod_shape,
|
const std::vector<int64_t> &decoder_seq_lod_shape,
|
||||||
const std::vector<int64_t> &encoder_batch_idx_shape,
|
const std::vector<int64_t> &encoder_batch_idx_shape,
|
||||||
@@ -172,7 +166,6 @@ std::vector<std::vector<int64_t>> AdjustBatchInferShape(
|
|||||||
|
|
||||||
std::vector<paddle::DataType> AdjustBatchInferDtype(
|
std::vector<paddle::DataType> AdjustBatchInferDtype(
|
||||||
const paddle::DataType &x_dtype,
|
const paddle::DataType &x_dtype,
|
||||||
const paddle::DataType &cum_offsets_dtype,
|
|
||||||
const paddle::DataType &encoder_seq_lod_dtype,
|
const paddle::DataType &encoder_seq_lod_dtype,
|
||||||
const paddle::DataType &decoder_seq_lod_dtype,
|
const paddle::DataType &decoder_seq_lod_dtype,
|
||||||
const paddle::DataType &encoder_batch_idx_dtype,
|
const paddle::DataType &encoder_batch_idx_dtype,
|
||||||
@@ -188,7 +181,6 @@ std::vector<paddle::DataType> AdjustBatchInferDtype(
|
|||||||
|
|
||||||
PD_BUILD_STATIC_OP(adjust_batch)
|
PD_BUILD_STATIC_OP(adjust_batch)
|
||||||
.Inputs({"x",
|
.Inputs({"x",
|
||||||
"cum_offsets",
|
|
||||||
"encoder_seq_lod",
|
"encoder_seq_lod",
|
||||||
"decoder_seq_lod",
|
"decoder_seq_lod",
|
||||||
"encoder_batch_idx",
|
"encoder_batch_idx",
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
const paddle::Tensor& qkv,
|
const paddle::Tensor& qkv,
|
||||||
const paddle::Tensor& key_cache,
|
const paddle::Tensor& key_cache,
|
||||||
const paddle::Tensor& value_cache,
|
const paddle::Tensor& value_cache,
|
||||||
const paddle::Tensor& cum_offsets,
|
|
||||||
const paddle::Tensor& rotary_embs,
|
const paddle::Tensor& rotary_embs,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::Tensor& prefix_block_tables,
|
const paddle::Tensor& prefix_block_tables,
|
||||||
@@ -122,7 +121,6 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
auto qkv_shape = qkv.dims();
|
auto qkv_shape = qkv.dims();
|
||||||
auto cache_shape = key_cache.dims();
|
auto cache_shape = key_cache.dims();
|
||||||
auto block_table_shape = block_tables.dims();
|
auto block_table_shape = block_tables.dims();
|
||||||
const int bsz = cum_offsets.dims()[0];
|
|
||||||
const int block_batch = block_table_shape[0];
|
const int block_batch = block_table_shape[0];
|
||||||
const int max_block_per_seq = block_table_shape[1];
|
const int max_block_per_seq = block_table_shape[1];
|
||||||
const int kv_num_heads = cache_shape[1];
|
const int kv_num_heads = cache_shape[1];
|
||||||
@@ -984,7 +982,6 @@ std::vector<paddle::Tensor> BlockAttn(
|
|||||||
const paddle::Tensor& qkv,
|
const paddle::Tensor& qkv,
|
||||||
const paddle::Tensor& key_cache,
|
const paddle::Tensor& key_cache,
|
||||||
const paddle::Tensor& value_cache,
|
const paddle::Tensor& value_cache,
|
||||||
const paddle::Tensor& cum_offsets,
|
|
||||||
const paddle::Tensor& rotary_embs,
|
const paddle::Tensor& rotary_embs,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::Tensor& prefix_block_tables,
|
const paddle::Tensor& prefix_block_tables,
|
||||||
@@ -1023,7 +1020,6 @@ std::vector<paddle::Tensor> BlockAttn(
|
|||||||
return BlockAttnKernel<TX, TC, TS>(qkv, \
|
return BlockAttnKernel<TX, TC, TS>(qkv, \
|
||||||
key_cache, \
|
key_cache, \
|
||||||
value_cache, \
|
value_cache, \
|
||||||
cum_offsets, \
|
|
||||||
rotary_embs, \
|
rotary_embs, \
|
||||||
block_tables, \
|
block_tables, \
|
||||||
prefix_block_tables, \
|
prefix_block_tables, \
|
||||||
@@ -1099,7 +1095,6 @@ PD_BUILD_STATIC_OP(block_attn)
|
|||||||
.Inputs({"qkv",
|
.Inputs({"qkv",
|
||||||
"key_cache",
|
"key_cache",
|
||||||
"value_cache",
|
"value_cache",
|
||||||
"cum_offsets",
|
|
||||||
"rotary_embs",
|
"rotary_embs",
|
||||||
"block_tables",
|
"block_tables",
|
||||||
"prefix_block_tables",
|
"prefix_block_tables",
|
||||||
|
|||||||
@@ -22,8 +22,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::vector<paddle::Tensor> GatherNextToken(
|
std::vector<paddle::Tensor> GatherNextToken(
|
||||||
const paddle::Tensor& x, // [token_num, dim_embed]
|
const paddle::Tensor& x, // [token_num, dim_embed]
|
||||||
const paddle::Tensor& cum_offsets, // [bsz, 1]
|
|
||||||
const paddle::Tensor& encoder_seq_lod,
|
const paddle::Tensor& encoder_seq_lod,
|
||||||
const paddle::Tensor& decoder_seq_lod,
|
const paddle::Tensor& decoder_seq_lod,
|
||||||
const paddle::Tensor& encoder_batch_map,
|
const paddle::Tensor& encoder_batch_map,
|
||||||
@@ -46,7 +45,7 @@ std::vector<paddle::Tensor> GatherNextToken(
|
|||||||
typedef paddle::bfloat16 data_t;
|
typedef paddle::bfloat16 data_t;
|
||||||
const int dim = x.dims()[1];
|
const int dim = x.dims()[1];
|
||||||
const int token_num = x.shape()[0];
|
const int token_num = x.shape()[0];
|
||||||
int bsz = cum_offsets.shape()[0];
|
int bsz = -1;
|
||||||
int enc_batch = len_info_cpu.data<int32_t>()[0];
|
int enc_batch = len_info_cpu.data<int32_t>()[0];
|
||||||
int dec_batch = len_info_cpu.data<int32_t>()[1];
|
int dec_batch = len_info_cpu.data<int32_t>()[1];
|
||||||
if (max_bsz > 0) {
|
if (max_bsz > 0) {
|
||||||
@@ -116,7 +115,6 @@ std::vector<paddle::Tensor> GatherNextToken(
|
|||||||
|
|
||||||
std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
|
std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
|
||||||
const std::vector<int64_t>& x_shape,
|
const std::vector<int64_t>& x_shape,
|
||||||
const std::vector<int64_t>& cum_offsets_shape,
|
|
||||||
const std::vector<int64_t>& encoder_seq_lod_shape,
|
const std::vector<int64_t>& encoder_seq_lod_shape,
|
||||||
const std::vector<int64_t>& decoder_seq_lod_shape,
|
const std::vector<int64_t>& decoder_seq_lod_shape,
|
||||||
const std::vector<int64_t>& encoder_batch_map_shape,
|
const std::vector<int64_t>& encoder_batch_map_shape,
|
||||||
@@ -130,19 +128,18 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
|
|||||||
// if (output_padding_offset_shape) {
|
// if (output_padding_offset_shape) {
|
||||||
// PD_THROW("speculative decoding is not supported in XPU.");
|
// PD_THROW("speculative decoding is not supported in XPU.");
|
||||||
// }
|
// }
|
||||||
int64_t bsz = cum_offsets_shape[0];
|
// int64_t bsz = cum_offsets_shape[0];
|
||||||
|
int64_t bsz = 0;
|
||||||
int64_t dim_embed = x_shape[1];
|
int64_t dim_embed = x_shape[1];
|
||||||
if (output_padding_offset_shape) {
|
if (output_padding_offset_shape) {
|
||||||
return {{-1, dim_embed}};
|
return {{-1, dim_embed}};
|
||||||
} else {
|
} else {
|
||||||
int64_t bsz = cum_offsets_shape[0];
|
|
||||||
return {{bsz, dim_embed}};
|
return {{bsz, dim_embed}};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> GatherNextTokenInferDtype(
|
std::vector<paddle::DataType> GatherNextTokenInferDtype(
|
||||||
const paddle::DataType& x_dtype,
|
const paddle::DataType& x_dtype,
|
||||||
const paddle::DataType& cum_offsets_dtype,
|
|
||||||
const paddle::DataType& encoder_seq_lod_dtype,
|
const paddle::DataType& encoder_seq_lod_dtype,
|
||||||
const paddle::DataType& decoder_seq_lod_dtype,
|
const paddle::DataType& decoder_seq_lod_dtype,
|
||||||
const paddle::DataType& encoder_batch_map_dtype,
|
const paddle::DataType& encoder_batch_map_dtype,
|
||||||
@@ -158,7 +155,6 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
|
|||||||
|
|
||||||
PD_BUILD_STATIC_OP(gather_next_token)
|
PD_BUILD_STATIC_OP(gather_next_token)
|
||||||
.Inputs({"x",
|
.Inputs({"x",
|
||||||
"cum_offsets",
|
|
||||||
"encoder_seq_lod",
|
"encoder_seq_lod",
|
||||||
"decoder_seq_lod",
|
"decoder_seq_lod",
|
||||||
"encoder_batch_map",
|
"encoder_batch_map",
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
|||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
const paddle::Tensor& seq_lens_decoder,
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
const paddle::Tensor& step_idx,
|
const paddle::Tensor& step_idx,
|
||||||
const paddle::Tensor& output_cum_offsets,
|
const paddle::Tensor& cu_seqlens_q_output,
|
||||||
const paddle::Tensor& stop_flags,
|
const paddle::Tensor& stop_flags,
|
||||||
const paddle::Tensor& not_need_stop,
|
const paddle::Tensor& not_need_stop,
|
||||||
const paddle::Tensor& max_dec_len,
|
const paddle::Tensor& max_dec_len,
|
||||||
@@ -72,7 +72,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
|||||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||||
output_cum_offsets.data<int>(),
|
cu_seqlens_q_output.data<int>(),
|
||||||
const_cast<bool*>(stop_flags.data<bool>()),
|
const_cast<bool*>(stop_flags.data<bool>()),
|
||||||
const_cast<bool*>(not_need_stop_device.data<bool>()),
|
const_cast<bool*>(not_need_stop_device.data<bool>()),
|
||||||
max_dec_len.data<int64_t>(),
|
max_dec_len.data<int64_t>(),
|
||||||
@@ -102,7 +102,7 @@ PD_BUILD_STATIC_OP(draft_model_update)
|
|||||||
"seq_lens_encoder",
|
"seq_lens_encoder",
|
||||||
"seq_lens_decoder",
|
"seq_lens_decoder",
|
||||||
"step_idx",
|
"step_idx",
|
||||||
"output_cum_offsets",
|
"cu_seqlens_q_output",
|
||||||
"stop_flags",
|
"stop_flags",
|
||||||
"not_need_stop",
|
"not_need_stop",
|
||||||
"max_dec_len",
|
"max_dec_len",
|
||||||
|
|||||||
@@ -0,0 +1,133 @@
|
|||||||
|
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||||
|
#include "paddle/extension.h"
|
||||||
|
#include "xpu/plugin.h"
|
||||||
|
|
||||||
|
#ifndef PD_BUILD_STATIC_OP
|
||||||
|
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace api = baidu::xpu::api;
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> SpeculatePreProcess(
|
||||||
|
const int64_t cpu_token_num,
|
||||||
|
const paddle::Tensor &input_ids,
|
||||||
|
const paddle::Tensor &seq_len,
|
||||||
|
const paddle::Tensor &draft_tokens,
|
||||||
|
const paddle::Tensor &seq_lens_encoder,
|
||||||
|
const paddle::Tensor &seq_lens_decoder) {
|
||||||
|
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||||
|
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||||
|
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||||
|
api::Context *ctx = xpu_ctx->x_context();
|
||||||
|
|
||||||
|
// just for ut to run base line
|
||||||
|
std::unique_ptr<baidu::xpu::api::Context> cpu_ctx;
|
||||||
|
if (input_ids.place().GetType() == phi::AllocationType::CPU) {
|
||||||
|
cpu_ctx = std::make_unique<baidu::xpu::api::Context>(baidu::xpu::api::kCPU);
|
||||||
|
ctx = cpu_ctx.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||||
|
const int bsz = seq_len.shape()[0];
|
||||||
|
const int max_seq_len = input_ids_shape[1];
|
||||||
|
const int token_num_data = cpu_token_num;
|
||||||
|
auto ids_remove_padding = paddle::empty(
|
||||||
|
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||||
|
auto batch_id_per_token = paddle::empty(
|
||||||
|
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||||
|
auto cu_seqlens_q =
|
||||||
|
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||||
|
auto cu_seqlens_k =
|
||||||
|
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||||
|
const int max_draft_tokens_per_batch = draft_tokens.shape()[1];
|
||||||
|
|
||||||
|
auto seq_lens_output =
|
||||||
|
paddle::empty({bsz}, paddle::DataType::INT32, input_ids.place());
|
||||||
|
auto cu_seq_lens_q_output =
|
||||||
|
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
|
||||||
|
auto batch_id_per_token_output =
|
||||||
|
paddle::empty({bsz * max_draft_tokens_per_batch},
|
||||||
|
paddle::DataType::INT32,
|
||||||
|
input_ids.place());
|
||||||
|
auto real_output_token_num =
|
||||||
|
paddle::empty({1}, paddle::DataType::INT32, input_ids.place());
|
||||||
|
if (token_num_data == 0) {
|
||||||
|
return {ids_remove_padding,
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
cu_seq_lens_q_output,
|
||||||
|
batch_id_per_token_output,
|
||||||
|
real_output_token_num};
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t *ids_remove_padding_ptr = ids_remove_padding.data<int64_t>();
|
||||||
|
int *batch_id_per_token_ptr = batch_id_per_token.data<int>();
|
||||||
|
int *cu_seqlens_q_ptr = cu_seqlens_q.data<int>();
|
||||||
|
int *cu_seqlens_k_ptr = cu_seqlens_k.data<int>();
|
||||||
|
int *seq_lens_output_ptr = seq_lens_output.data<int>();
|
||||||
|
int *cu_seq_lens_q_output_ptr = cu_seq_lens_q_output.data<int>();
|
||||||
|
int *batch_id_per_token_output_ptr = batch_id_per_token_output.data<int>();
|
||||||
|
int *real_output_token_num_ptr = real_output_token_num.data<int>();
|
||||||
|
const int64_t *input_data_ptr = input_ids.data<int64_t>();
|
||||||
|
const int *seq_len_ptr = seq_len.data<int>();
|
||||||
|
const int64_t *draft_tokens_ptr = draft_tokens.data<int64_t>();
|
||||||
|
const int *seq_lens_encoder_ptr = seq_lens_encoder.data<int>();
|
||||||
|
|
||||||
|
int r =
|
||||||
|
fastdeploy::plugin::speculate_preprocess(ctx,
|
||||||
|
ids_remove_padding_ptr,
|
||||||
|
batch_id_per_token_ptr,
|
||||||
|
cu_seqlens_q_ptr,
|
||||||
|
cu_seqlens_k_ptr,
|
||||||
|
seq_lens_output_ptr,
|
||||||
|
cu_seq_lens_q_output_ptr,
|
||||||
|
batch_id_per_token_output_ptr,
|
||||||
|
real_output_token_num_ptr,
|
||||||
|
input_data_ptr,
|
||||||
|
seq_len_ptr,
|
||||||
|
draft_tokens_ptr,
|
||||||
|
seq_lens_encoder_ptr,
|
||||||
|
max_seq_len,
|
||||||
|
max_draft_tokens_per_batch,
|
||||||
|
token_num_data,
|
||||||
|
bsz);
|
||||||
|
|
||||||
|
return {ids_remove_padding,
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
cu_seq_lens_q_output,
|
||||||
|
batch_id_per_token_output,
|
||||||
|
real_output_token_num};
|
||||||
|
}
|
||||||
|
|
||||||
|
PD_BUILD_STATIC_OP(speculate_pre_process)
|
||||||
|
.Inputs({"input_ids",
|
||||||
|
"seq_len",
|
||||||
|
"draft_tokens",
|
||||||
|
"seq_lens_encoder",
|
||||||
|
"seq_lens_decoder"})
|
||||||
|
.Outputs({"ids_remove_padding",
|
||||||
|
"batch_id_per_token",
|
||||||
|
"cu_seqlens_q",
|
||||||
|
"cu_seqlens_k",
|
||||||
|
"cu_seq_lens_q_output",
|
||||||
|
"batch_id_per_token_output",
|
||||||
|
"real_output_token_num"})
|
||||||
|
.Attrs({"cpu_token_num: int64_t"})
|
||||||
|
.SetKernelFn(PD_KERNEL(SpeculatePreProcess));
|
||||||
@@ -33,8 +33,8 @@ void SpeculateTokenPenaltyMultiScores(
|
|||||||
const paddle::Tensor& min_len,
|
const paddle::Tensor& min_len,
|
||||||
const paddle::Tensor& eos_token_id,
|
const paddle::Tensor& eos_token_id,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor& output_padding_offset,
|
const paddle::Tensor& batch_id_per_token_output,
|
||||||
const paddle::Tensor& output_cum_offsets,
|
const paddle::Tensor& cu_seqlens_q_output,
|
||||||
const int max_seq_len) {
|
const int max_seq_len) {
|
||||||
namespace api = baidu::xpu::api;
|
namespace api = baidu::xpu::api;
|
||||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||||
@@ -72,8 +72,8 @@ void SpeculateTokenPenaltyMultiScores(
|
|||||||
min_len.data<int64_t>(),
|
min_len.data<int64_t>(),
|
||||||
eos_token_id.data<int64_t>(),
|
eos_token_id.data<int64_t>(),
|
||||||
bad_tokens.data<int64_t>(),
|
bad_tokens.data<int64_t>(),
|
||||||
output_padding_offset.data<int>(),
|
batch_id_per_token_output.data<int>(),
|
||||||
output_cum_offsets.data<int>(),
|
cu_seqlens_q_output.data<int>(),
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
@@ -100,8 +100,8 @@ void SpeculateTokenPenaltyMultiScores(
|
|||||||
min_len.data<int64_t>(),
|
min_len.data<int64_t>(),
|
||||||
eos_token_id.data<int64_t>(),
|
eos_token_id.data<int64_t>(),
|
||||||
bad_tokens.data<int64_t>(),
|
bad_tokens.data<int64_t>(),
|
||||||
output_padding_offset.data<int>(),
|
batch_id_per_token_output.data<int>(),
|
||||||
output_cum_offsets.data<int>(),
|
cu_seqlens_q_output.data<int>(),
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
@@ -125,8 +125,8 @@ void SpeculateTokenPenaltyMultiScores(
|
|||||||
min_len.data<int64_t>(),
|
min_len.data<int64_t>(),
|
||||||
eos_token_id.data<int64_t>(),
|
eos_token_id.data<int64_t>(),
|
||||||
bad_tokens.data<int64_t>(),
|
bad_tokens.data<int64_t>(),
|
||||||
output_padding_offset.data<int>(),
|
batch_id_per_token_output.data<int>(),
|
||||||
output_cum_offsets.data<int>(),
|
cu_seqlens_q_output.data<int>(),
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
@@ -157,8 +157,8 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
|
|||||||
"min_len",
|
"min_len",
|
||||||
"eos_token_id",
|
"eos_token_id",
|
||||||
"seq_lens_this_time",
|
"seq_lens_this_time",
|
||||||
"output_padding_offset",
|
"batch_id_per_token_output",
|
||||||
"output_cum_offsets"})
|
"cu_seqlens_q_output"})
|
||||||
.Outputs({"logits_out"})
|
.Outputs({"logits_out"})
|
||||||
.Attrs({"max_seq_len: int"})
|
.Attrs({"max_seq_len: int"})
|
||||||
.SetInplaceMap({{"logits", "logits_out"}})
|
.SetInplaceMap({{"logits", "logits_out"}})
|
||||||
|
|||||||
@@ -26,24 +26,24 @@
|
|||||||
|
|
||||||
namespace api = baidu::xpu::api;
|
namespace api = baidu::xpu::api;
|
||||||
|
|
||||||
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
void SpeculateVerify(const paddle::Tensor &sampled_token_ids,
|
||||||
const paddle::Tensor& accept_tokens,
|
const paddle::Tensor &accept_tokens,
|
||||||
const paddle::Tensor& accept_num,
|
const paddle::Tensor &accept_num,
|
||||||
const paddle::Tensor& step_idx,
|
const paddle::Tensor &step_idx,
|
||||||
const paddle::Tensor& stop_flags,
|
const paddle::Tensor &stop_flags,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor &seq_lens_encoder,
|
||||||
const paddle::Tensor& seq_lens_decoder,
|
const paddle::Tensor &seq_lens_decoder,
|
||||||
const paddle::Tensor& draft_tokens,
|
const paddle::Tensor &draft_tokens,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor &seq_lens_this_time,
|
||||||
const paddle::Tensor& verify_tokens,
|
const paddle::Tensor &verify_tokens,
|
||||||
const paddle::Tensor& verify_scores,
|
const paddle::Tensor &verify_scores,
|
||||||
const paddle::Tensor& max_dec_len,
|
const paddle::Tensor &max_dec_len,
|
||||||
const paddle::Tensor& end_tokens,
|
const paddle::Tensor &end_tokens,
|
||||||
const paddle::Tensor& is_block_step,
|
const paddle::Tensor &is_block_step,
|
||||||
const paddle::Tensor& output_cum_offsets,
|
const paddle::Tensor &cu_seqlens_q_output,
|
||||||
const paddle::Tensor& actual_candidate_len,
|
const paddle::Tensor &actual_candidate_len,
|
||||||
const paddle::Tensor& actual_draft_token_nums,
|
const paddle::Tensor &actual_draft_token_nums,
|
||||||
const paddle::Tensor& topp,
|
const paddle::Tensor &topp,
|
||||||
int max_seq_len,
|
int max_seq_len,
|
||||||
int verify_window,
|
int verify_window,
|
||||||
bool enable_topp,
|
bool enable_topp,
|
||||||
@@ -57,7 +57,8 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
|
|
||||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||||
api::Context* ctx = static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
|
api::Context *ctx =
|
||||||
|
static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
|
||||||
bool xpu_ctx_flag = true;
|
bool xpu_ctx_flag = true;
|
||||||
if (draft_tokens.is_cpu()) {
|
if (draft_tokens.is_cpu()) {
|
||||||
ctx = new api::Context(api::kCPU);
|
ctx = new api::Context(api::kCPU);
|
||||||
@@ -65,17 +66,17 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool use_topk = false;
|
bool use_topk = false;
|
||||||
char* env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
|
char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
|
||||||
if (env_var) {
|
if (env_var) {
|
||||||
use_topk = static_cast<bool>(std::stoi(env_var));
|
use_topk = static_cast<bool>(std::stoi(env_var));
|
||||||
}
|
}
|
||||||
bool use_target_sampling = false;
|
bool use_target_sampling = false;
|
||||||
char* env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
|
char *env_var_1 = getenv("SPECULATE_VERIFY_USE_TARGET_SAMPLING");
|
||||||
if (env_var_1) {
|
if (env_var_1) {
|
||||||
use_target_sampling = static_cast<bool>(std::stoi(env_var_1));
|
use_target_sampling = static_cast<bool>(std::stoi(env_var_1));
|
||||||
}
|
}
|
||||||
bool prefill_one_step_stop = false;
|
bool prefill_one_step_stop = false;
|
||||||
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
|
||||||
// std::cout << "Your PATH is: " << env_p << '\n';
|
// std::cout << "Your PATH is: " << env_p << '\n';
|
||||||
if (env_p[0] == '1') {
|
if (env_p[0] == '1') {
|
||||||
prefill_one_step_stop = true;
|
prefill_one_step_stop = true;
|
||||||
@@ -90,7 +91,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
std::mt19937_64 engine(infer_seed[i]);
|
std::mt19937_64 engine(infer_seed[i]);
|
||||||
dev_curand_states_cpu.push_back(dist(engine));
|
dev_curand_states_cpu.push_back(dist(engine));
|
||||||
}
|
}
|
||||||
float* dev_curand_states = dev_curand_states_cpu.data();
|
float *dev_curand_states = dev_curand_states_cpu.data();
|
||||||
auto dev_curand_states_tensor =
|
auto dev_curand_states_tensor =
|
||||||
paddle::empty({static_cast<int64_t>(dev_curand_states_cpu.size())},
|
paddle::empty({static_cast<int64_t>(dev_curand_states_cpu.size())},
|
||||||
paddle::DataType::FLOAT32,
|
paddle::DataType::FLOAT32,
|
||||||
@@ -110,10 +111,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
ret = fastdeploy::plugin::speculate_verify<true, true>(
|
ret = fastdeploy::plugin::speculate_verify<true, true>(
|
||||||
ctx,
|
ctx,
|
||||||
sampled_token_ids.data<int64_t>(),
|
sampled_token_ids.data<int64_t>(),
|
||||||
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||||
const_cast<int*>(accept_num.data<int>()),
|
const_cast<int *>(accept_num.data<int>()),
|
||||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
const_cast<bool*>(stop_flags.data<bool>()),
|
const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
seq_lens_decoder.data<int>(),
|
seq_lens_decoder.data<int>(),
|
||||||
draft_tokens.data<int64_t>(),
|
draft_tokens.data<int64_t>(),
|
||||||
@@ -126,7 +127,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
max_dec_len.data<int64_t>(),
|
max_dec_len.data<int64_t>(),
|
||||||
end_tokens.data<int64_t>(),
|
end_tokens.data<int64_t>(),
|
||||||
is_block_step.data<bool>(),
|
is_block_step.data<bool>(),
|
||||||
output_cum_offsets.data<int>(),
|
cu_seqlens_q_output.data<int>(),
|
||||||
actual_candidate_len.data<int>(),
|
actual_candidate_len.data<int>(),
|
||||||
real_bsz,
|
real_bsz,
|
||||||
max_draft_tokens,
|
max_draft_tokens,
|
||||||
@@ -143,10 +144,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
ret = fastdeploy::plugin::speculate_verify<false, true>(
|
ret = fastdeploy::plugin::speculate_verify<false, true>(
|
||||||
ctx,
|
ctx,
|
||||||
sampled_token_ids.data<int64_t>(),
|
sampled_token_ids.data<int64_t>(),
|
||||||
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||||
const_cast<int*>(accept_num.data<int>()),
|
const_cast<int *>(accept_num.data<int>()),
|
||||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
const_cast<bool*>(stop_flags.data<bool>()),
|
const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
seq_lens_decoder.data<int>(),
|
seq_lens_decoder.data<int>(),
|
||||||
draft_tokens.data<int64_t>(),
|
draft_tokens.data<int64_t>(),
|
||||||
@@ -159,7 +160,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
max_dec_len.data<int64_t>(),
|
max_dec_len.data<int64_t>(),
|
||||||
end_tokens.data<int64_t>(),
|
end_tokens.data<int64_t>(),
|
||||||
is_block_step.data<bool>(),
|
is_block_step.data<bool>(),
|
||||||
output_cum_offsets.data<int>(),
|
cu_seqlens_q_output.data<int>(),
|
||||||
actual_candidate_len.data<int>(),
|
actual_candidate_len.data<int>(),
|
||||||
real_bsz,
|
real_bsz,
|
||||||
max_draft_tokens,
|
max_draft_tokens,
|
||||||
@@ -178,10 +179,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
ret = fastdeploy::plugin::speculate_verify<true, false>(
|
ret = fastdeploy::plugin::speculate_verify<true, false>(
|
||||||
ctx,
|
ctx,
|
||||||
sampled_token_ids.data<int64_t>(),
|
sampled_token_ids.data<int64_t>(),
|
||||||
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||||
const_cast<int*>(accept_num.data<int>()),
|
const_cast<int *>(accept_num.data<int>()),
|
||||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
const_cast<bool*>(stop_flags.data<bool>()),
|
const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
seq_lens_decoder.data<int>(),
|
seq_lens_decoder.data<int>(),
|
||||||
draft_tokens.data<int64_t>(),
|
draft_tokens.data<int64_t>(),
|
||||||
@@ -194,7 +195,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
max_dec_len.data<int64_t>(),
|
max_dec_len.data<int64_t>(),
|
||||||
end_tokens.data<int64_t>(),
|
end_tokens.data<int64_t>(),
|
||||||
is_block_step.data<bool>(),
|
is_block_step.data<bool>(),
|
||||||
output_cum_offsets.data<int>(),
|
cu_seqlens_q_output.data<int>(),
|
||||||
actual_candidate_len.data<int>(),
|
actual_candidate_len.data<int>(),
|
||||||
real_bsz,
|
real_bsz,
|
||||||
max_draft_tokens,
|
max_draft_tokens,
|
||||||
@@ -211,10 +212,10 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
ret = fastdeploy::plugin::speculate_verify<false, false>(
|
ret = fastdeploy::plugin::speculate_verify<false, false>(
|
||||||
ctx,
|
ctx,
|
||||||
sampled_token_ids.data<int64_t>(),
|
sampled_token_ids.data<int64_t>(),
|
||||||
const_cast<int64_t*>(accept_tokens.data<int64_t>()),
|
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||||
const_cast<int*>(accept_num.data<int>()),
|
const_cast<int *>(accept_num.data<int>()),
|
||||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
const_cast<bool*>(stop_flags.data<bool>()),
|
const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
seq_lens_decoder.data<int>(),
|
seq_lens_decoder.data<int>(),
|
||||||
draft_tokens.data<int64_t>(),
|
draft_tokens.data<int64_t>(),
|
||||||
@@ -227,7 +228,7 @@ void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
|||||||
max_dec_len.data<int64_t>(),
|
max_dec_len.data<int64_t>(),
|
||||||
end_tokens.data<int64_t>(),
|
end_tokens.data<int64_t>(),
|
||||||
is_block_step.data<bool>(),
|
is_block_step.data<bool>(),
|
||||||
output_cum_offsets.data<int>(),
|
cu_seqlens_q_output.data<int>(),
|
||||||
actual_candidate_len.data<int>(),
|
actual_candidate_len.data<int>(),
|
||||||
real_bsz,
|
real_bsz,
|
||||||
max_draft_tokens,
|
max_draft_tokens,
|
||||||
@@ -262,7 +263,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
|
|||||||
"max_dec_len",
|
"max_dec_len",
|
||||||
"end_tokens",
|
"end_tokens",
|
||||||
"is_block_step",
|
"is_block_step",
|
||||||
"output_cum_offsets",
|
"cu_seqlens_q_output",
|
||||||
"actual_candidate_len",
|
"actual_candidate_len",
|
||||||
"actual_draft_token_nums",
|
"actual_draft_token_nums",
|
||||||
"topp"})
|
"topp"})
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ namespace api = baidu::xpu::api;
|
|||||||
std::vector<paddle::Tensor> TopPCandidates(
|
std::vector<paddle::Tensor> TopPCandidates(
|
||||||
const paddle::Tensor& probs,
|
const paddle::Tensor& probs,
|
||||||
const paddle::Tensor& top_p,
|
const paddle::Tensor& top_p,
|
||||||
const paddle::Tensor& output_padding_offset,
|
const paddle::Tensor& batch_id_per_token_output,
|
||||||
int candidates_len,
|
int candidates_len,
|
||||||
int max_seq_len) {
|
int max_seq_len) {
|
||||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||||
@@ -77,7 +77,7 @@ std::vector<paddle::Tensor> TopPCandidates(
|
|||||||
ctx,
|
ctx,
|
||||||
reinterpret_cast<const XPUTypeBF16*>(probs.data<bf16_data_t>()),
|
reinterpret_cast<const XPUTypeBF16*>(probs.data<bf16_data_t>()),
|
||||||
reinterpret_cast<const XPUTypeBF16*>(top_p.data<bf16_data_t>()),
|
reinterpret_cast<const XPUTypeBF16*>(top_p.data<bf16_data_t>()),
|
||||||
output_padding_offset.data<int>(),
|
batch_id_per_token_output.data<int>(),
|
||||||
verify_tokens.data<int64_t>(),
|
verify_tokens.data<int64_t>(),
|
||||||
reinterpret_cast<XPUTypeBF16*>(
|
reinterpret_cast<XPUTypeBF16*>(
|
||||||
verify_scores.data<bf16_data_t>()),
|
verify_scores.data<bf16_data_t>()),
|
||||||
@@ -100,7 +100,7 @@ std::vector<paddle::Tensor> TopPCandidates(
|
|||||||
ctx,
|
ctx,
|
||||||
reinterpret_cast<const XPUTypeFP16*>(probs.data<fp16_data_t>()),
|
reinterpret_cast<const XPUTypeFP16*>(probs.data<fp16_data_t>()),
|
||||||
reinterpret_cast<const XPUTypeFP16*>(top_p.data<fp16_data_t>()),
|
reinterpret_cast<const XPUTypeFP16*>(top_p.data<fp16_data_t>()),
|
||||||
output_padding_offset.data<int>(),
|
batch_id_per_token_output.data<int>(),
|
||||||
verify_tokens.data<int64_t>(),
|
verify_tokens.data<int64_t>(),
|
||||||
reinterpret_cast<XPUTypeFP16*>(
|
reinterpret_cast<XPUTypeFP16*>(
|
||||||
verify_scores.data<fp16_data_t>()),
|
verify_scores.data<fp16_data_t>()),
|
||||||
@@ -120,7 +120,7 @@ std::vector<paddle::Tensor> TopPCandidates(
|
|||||||
ctx,
|
ctx,
|
||||||
probs.data<float>(),
|
probs.data<float>(),
|
||||||
top_p.data<float>(),
|
top_p.data<float>(),
|
||||||
output_padding_offset.data<int>(),
|
batch_id_per_token_output.data<int>(),
|
||||||
verify_tokens.data<int64_t>(),
|
verify_tokens.data<int64_t>(),
|
||||||
verify_scores.data<float>(),
|
verify_scores.data<float>(),
|
||||||
actual_candidate_lens.data<int>(),
|
actual_candidate_lens.data<int>(),
|
||||||
@@ -139,7 +139,7 @@ std::vector<paddle::Tensor> TopPCandidates(
|
|||||||
std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
|
std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
|
||||||
const std::vector<int64_t>& probs_shape,
|
const std::vector<int64_t>& probs_shape,
|
||||||
const std::vector<int64_t>& top_p_shape,
|
const std::vector<int64_t>& top_p_shape,
|
||||||
const std::vector<int64_t>& output_padding_offset_shape,
|
const std::vector<int64_t>& batch_id_per_token_output_shape,
|
||||||
int max_candidates_len) {
|
int max_candidates_len) {
|
||||||
int token_num = probs_shape[0];
|
int token_num = probs_shape[0];
|
||||||
return {{token_num, max_candidates_len},
|
return {{token_num, max_candidates_len},
|
||||||
@@ -150,12 +150,12 @@ std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
|
|||||||
std::vector<paddle::DataType> TopPCandidatesInferDtype(
|
std::vector<paddle::DataType> TopPCandidatesInferDtype(
|
||||||
const paddle::DataType& probs_dtype,
|
const paddle::DataType& probs_dtype,
|
||||||
const paddle::DataType& top_p_dtype,
|
const paddle::DataType& top_p_dtype,
|
||||||
const paddle::DataType& output_padding_offset_dtype) {
|
const paddle::DataType& batch_id_per_token_output_dtype) {
|
||||||
return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
|
return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(top_p_candidates)
|
PD_BUILD_STATIC_OP(top_p_candidates)
|
||||||
.Inputs({"probs", "top_p", "output_padding_offset"})
|
.Inputs({"probs", "top_p", "batch_id_per_token_output"})
|
||||||
.Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
|
.Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"})
|
||||||
.Attrs({"candidates_len: int", "max_seq_len: int"})
|
.Attrs({"candidates_len: int", "max_seq_len: int"})
|
||||||
.SetKernelFn(PD_KERNEL(TopPCandidates))
|
.SetKernelFn(PD_KERNEL(TopPCandidates))
|
||||||
|
|||||||
@@ -0,0 +1,149 @@
|
|||||||
|
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <paddle/phi/backends/xpu/xpu_context.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include "paddle/common/flags.h"
|
||||||
|
#include "paddle/extension.h"
|
||||||
|
#include "paddle/phi/backends/xpu/enforce_xpu.h"
|
||||||
|
#include "xpu/internal/infra_op.h"
|
||||||
|
#include "xpu/plugin.h"
|
||||||
|
|
||||||
|
#ifndef PD_BUILD_STATIC_OP
|
||||||
|
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace api = baidu::xpu::api;
|
||||||
|
void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
|
||||||
|
const paddle::Tensor &seq_lens_decoder,
|
||||||
|
const paddle::Tensor &has_running_seqs,
|
||||||
|
const paddle::Tensor &step_input_ids,
|
||||||
|
const paddle::Tensor &adaptive_step_input_len,
|
||||||
|
const paddle::Tensor &step_output_ids,
|
||||||
|
const paddle::Tensor &step_output_len,
|
||||||
|
const paddle::Tensor &stop_flags,
|
||||||
|
const paddle::Tensor &seq_lens_this_time,
|
||||||
|
const paddle::Tensor &is_paused,
|
||||||
|
const paddle::Tensor &mask_rollback,
|
||||||
|
const paddle::Tensor &token_ids_all,
|
||||||
|
const paddle::Tensor &prompt_lens,
|
||||||
|
const paddle::Tensor &step_idx,
|
||||||
|
const paddle::Tensor &end_tokens,
|
||||||
|
const paddle::Tensor &max_dec_len,
|
||||||
|
const bool is_naive_mode,
|
||||||
|
const bool prefill_one_step_stop) {
|
||||||
|
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||||
|
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||||
|
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
|
||||||
|
api::Context *ctx = xpu_ctx->x_context();
|
||||||
|
|
||||||
|
// just for ut to run base line
|
||||||
|
std::unique_ptr<baidu::xpu::api::Context> cpu_ctx;
|
||||||
|
if (seq_lens_encoder.place().GetType() == phi::AllocationType::CPU) {
|
||||||
|
cpu_ctx = std::make_unique<baidu::xpu::api::Context>(baidu::xpu::api::kCPU);
|
||||||
|
ctx = cpu_ctx.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||||
|
const int max_bsz = stop_flags.shape()[0];
|
||||||
|
PADDLE_ENFORCE_LE(
|
||||||
|
max_bsz,
|
||||||
|
1024,
|
||||||
|
phi::errors::InvalidArgument(
|
||||||
|
"unified_update_model_status: max_bsz (%d) must be <= 1024 "
|
||||||
|
"(single-block launch limit).",
|
||||||
|
max_bsz));
|
||||||
|
const int max_step_tokens = step_input_ids.shape()[1];
|
||||||
|
const int max_model_len = token_ids_all.shape()[1];
|
||||||
|
const int num_end_tokens = end_tokens.shape()[0];
|
||||||
|
|
||||||
|
// has_running_seqs is CPU tensor, need to copy to GPU first
|
||||||
|
auto has_running_seqs_xpu =
|
||||||
|
has_running_seqs.copy_to(seq_lens_this_time.place(), false);
|
||||||
|
int r = fastdeploy::plugin::unified_update_model_status(
|
||||||
|
ctx,
|
||||||
|
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||||
|
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||||
|
const_cast<bool *>(has_running_seqs_xpu.data<bool>()),
|
||||||
|
const_cast<int *>(mask_rollback.data<int>()),
|
||||||
|
const_cast<int64_t *>(step_input_ids.data<int64_t>()),
|
||||||
|
const_cast<int *>(adaptive_step_input_len.data<int>()),
|
||||||
|
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
|
||||||
|
const_cast<int *>(step_output_len.data<int>()),
|
||||||
|
const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
|
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||||
|
const_cast<bool *>(is_paused.data<bool>()),
|
||||||
|
const_cast<int64_t *>(token_ids_all.data<int64_t>()),
|
||||||
|
prompt_lens.data<int64_t>(),
|
||||||
|
const_cast<int64_t *>(step_idx.data<int64_t>()),
|
||||||
|
end_tokens.data<int64_t>(),
|
||||||
|
max_dec_len.data<int64_t>(),
|
||||||
|
real_bsz,
|
||||||
|
max_bsz,
|
||||||
|
max_step_tokens,
|
||||||
|
max_model_len,
|
||||||
|
num_end_tokens,
|
||||||
|
is_naive_mode,
|
||||||
|
prefill_one_step_stop);
|
||||||
|
PADDLE_ENFORCE_XDNN_SUCCESS(r, "unified_update_model_status");
|
||||||
|
// Copy result back to CPU
|
||||||
|
auto has_running_seqs_cpu =
|
||||||
|
has_running_seqs_xpu.copy_to(has_running_seqs.place(), false);
|
||||||
|
bool *out_data = const_cast<bool *>(has_running_seqs.data<bool>());
|
||||||
|
out_data[0] = has_running_seqs_cpu.data<bool>()[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
PD_BUILD_STATIC_OP(unified_update_model_status)
|
||||||
|
.Inputs({"seq_lens_encoder",
|
||||||
|
"seq_lens_decoder",
|
||||||
|
"has_running_seqs",
|
||||||
|
"step_input_ids",
|
||||||
|
"adaptive_step_input_len",
|
||||||
|
"step_output_ids",
|
||||||
|
"step_output_len",
|
||||||
|
"stop_flags",
|
||||||
|
"seq_lens_this_time",
|
||||||
|
"is_paused",
|
||||||
|
"mask_rollback",
|
||||||
|
"token_ids_all",
|
||||||
|
"prompt_lens",
|
||||||
|
"step_idx",
|
||||||
|
"end_tokens",
|
||||||
|
"max_dec_len"})
|
||||||
|
.Attrs({"is_naive_mode: bool", "prefill_one_step_stop: bool"})
|
||||||
|
.Outputs({"seq_lens_encoder_out",
|
||||||
|
"seq_lens_decoder_out",
|
||||||
|
"has_running_seqs_out",
|
||||||
|
"step_input_ids_out",
|
||||||
|
"adaptive_step_input_len_out",
|
||||||
|
"step_output_ids_out",
|
||||||
|
"step_output_len_out",
|
||||||
|
"stop_flags_out",
|
||||||
|
"seq_lens_this_time_out",
|
||||||
|
"mask_rollback_out",
|
||||||
|
"token_ids_all_out",
|
||||||
|
"step_idx_out"})
|
||||||
|
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||||
|
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||||
|
{"has_running_seqs", "has_running_seqs_out"},
|
||||||
|
{"step_input_ids", "step_input_ids_out"},
|
||||||
|
{"adaptive_step_input_len", "adaptive_step_input_len_out"},
|
||||||
|
{"step_output_ids", "step_output_ids_out"},
|
||||||
|
{"step_output_len", "step_output_len_out"},
|
||||||
|
{"stop_flags", "stop_flags_out"},
|
||||||
|
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||||
|
{"mask_rollback", "mask_rollback_out"},
|
||||||
|
{"token_ids_all", "token_ids_all_out"},
|
||||||
|
{"step_idx", "step_idx_out"}})
|
||||||
|
.SetKernelFn(PD_KERNEL(UnifiedUpdateModelStatus));
|
||||||
@@ -34,8 +34,7 @@ void prof_start();
|
|||||||
void prof_stop();
|
void prof_stop();
|
||||||
|
|
||||||
std::vector<paddle::Tensor> AdjustBatch(
|
std::vector<paddle::Tensor> AdjustBatch(
|
||||||
const paddle::Tensor& x, // [token_num, dim_embed]
|
const paddle::Tensor& x, // [token_num, dim_embed]
|
||||||
const paddle::Tensor& cum_offsets, // [bsz, 1]
|
|
||||||
const paddle::Tensor& encoder_seq_lod,
|
const paddle::Tensor& encoder_seq_lod,
|
||||||
const paddle::Tensor& decoder_seq_lod,
|
const paddle::Tensor& decoder_seq_lod,
|
||||||
const paddle::Tensor& encoder_batch_idx,
|
const paddle::Tensor& encoder_batch_idx,
|
||||||
@@ -62,7 +61,6 @@ std::vector<paddle::Tensor> BlockAttn(
|
|||||||
const paddle::Tensor& qkv,
|
const paddle::Tensor& qkv,
|
||||||
const paddle::Tensor& key_cache,
|
const paddle::Tensor& key_cache,
|
||||||
const paddle::Tensor& value_cache,
|
const paddle::Tensor& value_cache,
|
||||||
const paddle::Tensor& cum_offsets,
|
|
||||||
const paddle::Tensor& rotary_embs,
|
const paddle::Tensor& rotary_embs,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::Tensor& prefix_block_tables,
|
const paddle::Tensor& prefix_block_tables,
|
||||||
@@ -210,7 +208,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
|||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
const paddle::Tensor& seq_lens_decoder,
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
const paddle::Tensor& step_idx,
|
const paddle::Tensor& step_idx,
|
||||||
const paddle::Tensor& output_cum_offsets,
|
const paddle::Tensor& cu_seqlens_q_output,
|
||||||
const paddle::Tensor& stop_flags,
|
const paddle::Tensor& stop_flags,
|
||||||
const paddle::Tensor& not_need_stop,
|
const paddle::Tensor& not_need_stop,
|
||||||
const paddle::Tensor& max_dec_len,
|
const paddle::Tensor& max_dec_len,
|
||||||
@@ -254,8 +252,8 @@ void SpeculateTokenPenaltyMultiScores(
|
|||||||
const paddle::Tensor& min_len,
|
const paddle::Tensor& min_len,
|
||||||
const paddle::Tensor& eos_token_id,
|
const paddle::Tensor& eos_token_id,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor& output_padding_offset,
|
const paddle::Tensor& batch_id_per_token_output,
|
||||||
const paddle::Tensor& output_cum_offsets,
|
const paddle::Tensor& cu_seqlens_q_output,
|
||||||
const int max_seq_len);
|
const int max_seq_len);
|
||||||
|
|
||||||
void SpeculateUpdateV3(const paddle::Tensor& seq_lens_encoder,
|
void SpeculateUpdateV3(const paddle::Tensor& seq_lens_encoder,
|
||||||
@@ -413,8 +411,7 @@ std::vector<paddle::Tensor> EagleGetSelfHiddenStates(
|
|||||||
const paddle::Tensor& step_idx);
|
const paddle::Tensor& step_idx);
|
||||||
|
|
||||||
std::vector<paddle::Tensor> GatherNextToken(
|
std::vector<paddle::Tensor> GatherNextToken(
|
||||||
const paddle::Tensor& x, // [token_num, dim_embed]
|
const paddle::Tensor& x, // [token_num, dim_embed]
|
||||||
const paddle::Tensor& cum_offsets, // [bsz, 1]
|
|
||||||
const paddle::Tensor& encoder_seq_lod,
|
const paddle::Tensor& encoder_seq_lod,
|
||||||
const paddle::Tensor& decoder_seq_lod,
|
const paddle::Tensor& decoder_seq_lod,
|
||||||
const paddle::Tensor& encoder_batch_map,
|
const paddle::Tensor& encoder_batch_map,
|
||||||
@@ -500,6 +497,14 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
|||||||
const paddle::Tensor& seq_len,
|
const paddle::Tensor& seq_len,
|
||||||
const paddle::Tensor& seq_lens_encoder);
|
const paddle::Tensor& seq_lens_encoder);
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> SpeculatePreProcess(
|
||||||
|
const int64_t cpu_token_num,
|
||||||
|
const paddle::Tensor& input_ids,
|
||||||
|
const paddle::Tensor& seq_len,
|
||||||
|
const paddle::Tensor& draft_tokens,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& seq_lens_decoder);
|
||||||
|
|
||||||
void StepPaddle(const paddle::Tensor& stop_flags,
|
void StepPaddle(const paddle::Tensor& stop_flags,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor& ori_seq_lens_encoder,
|
const paddle::Tensor& ori_seq_lens_encoder,
|
||||||
@@ -540,6 +545,25 @@ void MTPStepPaddle(
|
|||||||
const int block_size,
|
const int block_size,
|
||||||
const int max_draft_tokens);
|
const int max_draft_tokens);
|
||||||
|
|
||||||
|
void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& seq_lens_decoder,
|
||||||
|
const paddle::Tensor& has_running_seqs,
|
||||||
|
const paddle::Tensor& step_input_ids,
|
||||||
|
const paddle::Tensor& adaptive_step_input_len,
|
||||||
|
const paddle::Tensor& step_output_ids,
|
||||||
|
const paddle::Tensor& step_output_len,
|
||||||
|
const paddle::Tensor& stop_flags,
|
||||||
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
|
const paddle::Tensor& is_paused,
|
||||||
|
const paddle::Tensor& mask_rollback,
|
||||||
|
const paddle::Tensor& token_ids_all,
|
||||||
|
const paddle::Tensor& prompt_lens,
|
||||||
|
const paddle::Tensor& step_idx,
|
||||||
|
const paddle::Tensor& end_tokens,
|
||||||
|
const paddle::Tensor& max_dec_len,
|
||||||
|
const bool is_naive_mode,
|
||||||
|
const bool prefill_one_step_stop);
|
||||||
|
|
||||||
void SpeculateStepPaddle(
|
void SpeculateStepPaddle(
|
||||||
const paddle::Tensor& stop_flags,
|
const paddle::Tensor& stop_flags,
|
||||||
const paddle::Tensor& seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
@@ -682,7 +706,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
m.def("adjust_batch",
|
m.def("adjust_batch",
|
||||||
&AdjustBatch,
|
&AdjustBatch,
|
||||||
py::arg("x"),
|
py::arg("x"),
|
||||||
py::arg("cum_offsets"),
|
|
||||||
py::arg("encoder_seq_lod"),
|
py::arg("encoder_seq_lod"),
|
||||||
py::arg("decoder_seq_lod"),
|
py::arg("decoder_seq_lod"),
|
||||||
py::arg("encoder_batch_idx"),
|
py::arg("encoder_batch_idx"),
|
||||||
@@ -701,7 +724,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("qkv"),
|
py::arg("qkv"),
|
||||||
py::arg("key_cache"),
|
py::arg("key_cache"),
|
||||||
py::arg("value_cache"),
|
py::arg("value_cache"),
|
||||||
py::arg("cum_offsets"),
|
|
||||||
py::arg("rotary_embs"),
|
py::arg("rotary_embs"),
|
||||||
py::arg("block_tables"),
|
py::arg("block_tables"),
|
||||||
py::arg("prefix_block_tables"),
|
py::arg("prefix_block_tables"),
|
||||||
@@ -812,7 +834,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("seq_lens_encoder"), // 编码器序列长度张量
|
py::arg("seq_lens_encoder"), // 编码器序列长度张量
|
||||||
py::arg("seq_lens_decoder"), // 解码器序列长度张量
|
py::arg("seq_lens_decoder"), // 解码器序列长度张量
|
||||||
py::arg("step_idx"), // 步骤索引张量
|
py::arg("step_idx"), // 步骤索引张量
|
||||||
py::arg("output_cum_offsets"), // 输出累积偏移量张量
|
py::arg("cu_seqlens_q_output"), // 输出累积偏移量张量
|
||||||
py::arg("stop_flags"), // 停止标志张量
|
py::arg("stop_flags"), // 停止标志张量
|
||||||
py::arg("not_need_stop"), // 无需停止标志张量
|
py::arg("not_need_stop"), // 无需停止标志张量
|
||||||
py::arg("max_dec_len"), // 最大解码长度张量
|
py::arg("max_dec_len"), // 最大解码长度张量
|
||||||
@@ -885,7 +907,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
m.def("gather_next_token",
|
m.def("gather_next_token",
|
||||||
&GatherNextToken,
|
&GatherNextToken,
|
||||||
py::arg("x"),
|
py::arg("x"),
|
||||||
py::arg("cum_offsets"),
|
|
||||||
py::arg("encoder_seq_lod"),
|
py::arg("encoder_seq_lod"),
|
||||||
py::arg("decoder_seq_lod"),
|
py::arg("decoder_seq_lod"),
|
||||||
py::arg("encoder_batch_map"),
|
py::arg("encoder_batch_map"),
|
||||||
@@ -1002,6 +1023,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("redundant_ep_rank_num_plus_one"),
|
py::arg("redundant_ep_rank_num_plus_one"),
|
||||||
"moe export RedundantTopKSelect function");
|
"moe export RedundantTopKSelect function");
|
||||||
|
|
||||||
|
m.def("unified_update_model_status",
|
||||||
|
&UnifiedUpdateModelStatus,
|
||||||
|
py::arg("seq_lens_encoder"),
|
||||||
|
py::arg("seq_lens_decoder"),
|
||||||
|
py::arg("has_running_seqs"),
|
||||||
|
py::arg("step_input_ids"),
|
||||||
|
py::arg("adaptive_step_input_len"),
|
||||||
|
py::arg("step_output_ids"),
|
||||||
|
py::arg("step_output_len"),
|
||||||
|
py::arg("stop_flags"),
|
||||||
|
py::arg("seq_lens_this_time"),
|
||||||
|
py::arg("is_paused"),
|
||||||
|
py::arg("mask_rollback"),
|
||||||
|
py::arg("token_ids_all"),
|
||||||
|
py::arg("prompt_lens"),
|
||||||
|
py::arg("step_idx"),
|
||||||
|
py::arg("end_tokens"),
|
||||||
|
py::arg("max_dec_len"),
|
||||||
|
py::arg("is_naive_mode"),
|
||||||
|
py::arg("max_draft_tokens"),
|
||||||
|
"Unified update model status");
|
||||||
|
|
||||||
m.def("mtp_step_paddle",
|
m.def("mtp_step_paddle",
|
||||||
&MTPStepPaddle,
|
&MTPStepPaddle,
|
||||||
py::arg("base_model_stop_flags"),
|
py::arg("base_model_stop_flags"),
|
||||||
@@ -1117,8 +1160,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("min_len"),
|
py::arg("min_len"),
|
||||||
py::arg("eos_token_id"),
|
py::arg("eos_token_id"),
|
||||||
py::arg("seq_lens_this_time"),
|
py::arg("seq_lens_this_time"),
|
||||||
py::arg("output_padding_offset"),
|
py::arg("batch_id_per_token_output"),
|
||||||
py::arg("output_cum_offsets"),
|
py::arg("cu_seqlens_q_output"),
|
||||||
py::arg("max_seq_len"),
|
py::arg("max_seq_len"),
|
||||||
"Applies token penalty with multiple scores");
|
"Applies token penalty with multiple scores");
|
||||||
|
|
||||||
@@ -1182,7 +1225,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("max_dec_len"),
|
py::arg("max_dec_len"),
|
||||||
py::arg("end_tokens"),
|
py::arg("end_tokens"),
|
||||||
py::arg("is_block_step"),
|
py::arg("is_block_step"),
|
||||||
py::arg("output_cum_offsets"),
|
py::arg("cu_seqlens_q_output"),
|
||||||
py::arg("actual_candidate_len"),
|
py::arg("actual_candidate_len"),
|
||||||
py::arg("actual_draft_token_nums"),
|
py::arg("actual_draft_token_nums"),
|
||||||
py::arg("topp"),
|
py::arg("topp"),
|
||||||
@@ -1246,6 +1289,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("max_seq_len"),
|
py::arg("max_seq_len"),
|
||||||
"Get output padding offset");
|
"Get output padding offset");
|
||||||
|
|
||||||
|
m.def("speculate_pre_process",
|
||||||
|
&SpeculatePreProcess,
|
||||||
|
py::arg("cpu_token_num"),
|
||||||
|
py::arg("input_ids"),
|
||||||
|
py::arg("seq_len"),
|
||||||
|
py::arg("draft_tokens"),
|
||||||
|
py::arg("seq_lens_encoder"),
|
||||||
|
py::arg("seq_lens_decoder"),
|
||||||
|
"speculate pre process to remove padding and to acquire cu_seq_len");
|
||||||
|
|
||||||
m.def("speculate_get_padding_offset",
|
m.def("speculate_get_padding_offset",
|
||||||
&SpeculateGetPaddingOffset,
|
&SpeculateGetPaddingOffset,
|
||||||
py::arg("input_ids"),
|
py::arg("input_ids"),
|
||||||
@@ -1419,7 +1472,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
&TopPCandidates,
|
&TopPCandidates,
|
||||||
py::arg("probs"),
|
py::arg("probs"),
|
||||||
py::arg("top_p"),
|
py::arg("top_p"),
|
||||||
py::arg("output_padding_offset"),
|
py::arg("batch_id_per_token_output"),
|
||||||
py::arg("candidates_len"),
|
py::arg("candidates_len"),
|
||||||
py::arg("max_seq_len"),
|
py::arg("max_seq_len"),
|
||||||
"Generate top-p candidates based on probability distributions");
|
"Generate top-p candidates based on probability distributions");
|
||||||
|
|||||||
@@ -388,8 +388,8 @@ DLL_EXPORT int speculate_token_penalty_multi_scores(
|
|||||||
const int64_t* min_len,
|
const int64_t* min_len,
|
||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int64_t* bad_words,
|
const int64_t* bad_words,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -432,7 +432,7 @@ DLL_EXPORT int speculate_verify(api::Context* ctx,
|
|||||||
const int64_t* max_dec_len,
|
const int64_t* max_dec_len,
|
||||||
const int64_t* end_tokens,
|
const int64_t* end_tokens,
|
||||||
const bool* is_block_step,
|
const bool* is_block_step,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int* actual_candidate_len,
|
const int* actual_candidate_len,
|
||||||
const int real_bsz,
|
const int real_bsz,
|
||||||
const int max_draft_tokens,
|
const int max_draft_tokens,
|
||||||
@@ -465,7 +465,7 @@ DLL_EXPORT int draft_model_update(api::Context* ctx,
|
|||||||
int* seq_lens_encoder,
|
int* seq_lens_encoder,
|
||||||
int* seq_lens_decoder,
|
int* seq_lens_decoder,
|
||||||
int64_t* step_idx,
|
int64_t* step_idx,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
bool* stop_flags,
|
bool* stop_flags,
|
||||||
bool* not_need_stop,
|
bool* not_need_stop,
|
||||||
const int64_t* max_dec_len,
|
const int64_t* max_dec_len,
|
||||||
@@ -574,7 +574,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
|
|||||||
DLL_EXPORT int top_p_candidates(api::Context* ctx,
|
DLL_EXPORT int top_p_candidates(api::Context* ctx,
|
||||||
const T* src,
|
const T* src,
|
||||||
const T* top_ps,
|
const T* top_ps,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
int64_t* out_id,
|
int64_t* out_id,
|
||||||
T* out_val,
|
T* out_val,
|
||||||
int* actual_candidates_lens,
|
int* actual_candidates_lens,
|
||||||
@@ -630,6 +630,24 @@ DLL_EXPORT int speculate_schedule_cache(api::Context* ctx,
|
|||||||
const int block_num_per_seq,
|
const int block_num_per_seq,
|
||||||
const bool prefill_one_step_stop);
|
const bool prefill_one_step_stop);
|
||||||
|
|
||||||
|
DLL_EXPORT int speculate_preprocess(api::Context* ctx,
|
||||||
|
int64_t* ids_remove_padding,
|
||||||
|
int* batch_id_per_token,
|
||||||
|
int* cu_seqlens_q,
|
||||||
|
int* cu_seqlens_k,
|
||||||
|
int* seq_lens_output,
|
||||||
|
int* cu_seq_lens_q_output,
|
||||||
|
int* batch_id_per_token_output,
|
||||||
|
int* real_output_token_num,
|
||||||
|
const int64_t* input_data,
|
||||||
|
const int* seq_lens,
|
||||||
|
const int64_t* draft_tokens,
|
||||||
|
const int* seq_lens_encoder,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_draft_tokens_per_batch,
|
||||||
|
const int token_num_data,
|
||||||
|
const int real_bs);
|
||||||
|
|
||||||
DLL_EXPORT int speculate_update_v3(api::Context* ctx,
|
DLL_EXPORT int speculate_update_v3(api::Context* ctx,
|
||||||
int* seq_lens_encoder,
|
int* seq_lens_encoder,
|
||||||
int* seq_lens_decoder,
|
int* seq_lens_decoder,
|
||||||
@@ -662,6 +680,31 @@ DLL_EXPORT int speculate_update(api::Context* ctx,
|
|||||||
const int max_bsz,
|
const int max_bsz,
|
||||||
const int max_draft_tokens);
|
const int max_draft_tokens);
|
||||||
|
|
||||||
|
DLL_EXPORT int unified_update_model_status(api::Context* ctx,
|
||||||
|
int* seq_lens_encoder,
|
||||||
|
int* seq_lens_decoder,
|
||||||
|
bool* has_running_seqs,
|
||||||
|
int* mask_rollback,
|
||||||
|
int64_t* step_input_ids,
|
||||||
|
int* adaptive_step_input_len,
|
||||||
|
int64_t* step_output_ids,
|
||||||
|
int* step_output_len,
|
||||||
|
bool* stop_flags,
|
||||||
|
int* seq_lens_this_time,
|
||||||
|
const bool* is_paused,
|
||||||
|
int64_t* token_ids_all,
|
||||||
|
const int64_t* prompt_lens,
|
||||||
|
int64_t* step_idx,
|
||||||
|
const int64_t* end_tokens,
|
||||||
|
const int64_t* max_dec_len,
|
||||||
|
int real_bsz,
|
||||||
|
int max_bsz,
|
||||||
|
int max_step_tokens,
|
||||||
|
int max_model_len,
|
||||||
|
int num_end_tokens,
|
||||||
|
bool is_naive_mode,
|
||||||
|
bool prefill_one_step_stop);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
DLL_EXPORT int rebuild_hidden_states(api::Context* ctx,
|
DLL_EXPORT int rebuild_hidden_states(api::Context* ctx,
|
||||||
const T* input,
|
const T* input,
|
||||||
|
|||||||
+3
-5
@@ -22,7 +22,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens,
|
|||||||
int* seq_lens_encoder,
|
int* seq_lens_encoder,
|
||||||
int* seq_lens_decoder,
|
int* seq_lens_decoder,
|
||||||
int64_t* step_idx,
|
int64_t* step_idx,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
bool* stop_flags,
|
bool* stop_flags,
|
||||||
bool* not_need_stop,
|
bool* not_need_stop,
|
||||||
const int64_t* max_dec_len,
|
const int64_t* max_dec_len,
|
||||||
@@ -45,8 +45,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens,
|
|||||||
auto* pre_ids_now = pre_ids + tid * pre_id_length;
|
auto* pre_ids_now = pre_ids + tid * pre_id_length;
|
||||||
auto* base_model_draft_tokens_now =
|
auto* base_model_draft_tokens_now =
|
||||||
base_model_draft_tokens + tid * max_base_model_draft_token;
|
base_model_draft_tokens + tid * max_base_model_draft_token;
|
||||||
const int next_tokens_start_id =
|
const int next_tokens_start_id = cu_seqlens_q_output[tid];
|
||||||
tid * max_seq_len - output_cum_offsets[tid];
|
|
||||||
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
|
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
|
||||||
auto seq_len_this_time = seq_lens_this_time[tid];
|
auto seq_len_this_time = seq_lens_this_time[tid];
|
||||||
auto seq_len_encoder = seq_lens_encoder[tid];
|
auto seq_len_encoder = seq_lens_encoder[tid];
|
||||||
@@ -72,8 +71,7 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens,
|
|||||||
base_model_draft_tokens_now[substep + 1] = token_this_time;
|
base_model_draft_tokens_now[substep + 1] = token_this_time;
|
||||||
}
|
}
|
||||||
// multi_end
|
// multi_end
|
||||||
if (is_in_end(token_this_time, end_ids, end_ids_len) ||
|
if (is_in_end(token_this_time, end_ids, end_ids_len)) {
|
||||||
prefill_one_step_stop) {
|
|
||||||
stop_flags[tid] = true;
|
stop_flags[tid] = true;
|
||||||
stop_flag_now_int_sm[cid] += 1;
|
stop_flag_now_int_sm[cid] += 1;
|
||||||
// max_dec_len
|
// max_dec_len
|
||||||
|
|||||||
+6
-6
@@ -22,7 +22,7 @@ inline __device__ void update_bad_words_logit<float16>(
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void speculate_ban_bad_words(T* logits,
|
__global__ void speculate_ban_bad_words(T* logits,
|
||||||
const int64_t* bad_words_list,
|
const int64_t* bad_words_list,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t bad_words_length,
|
const int64_t bad_words_length,
|
||||||
@@ -32,7 +32,7 @@ __global__ void speculate_ban_bad_words(T* logits,
|
|||||||
int nthreads = cluster_num() * core_num();
|
int nthreads = cluster_num() * core_num();
|
||||||
int start = -1;
|
int start = -1;
|
||||||
int end = -1;
|
int end = -1;
|
||||||
int output_padding_offset_lm;
|
int batch_id_per_token_output_lm;
|
||||||
partition(tid,
|
partition(tid,
|
||||||
nthreads,
|
nthreads,
|
||||||
static_cast<int>(token_num * bad_words_length),
|
static_cast<int>(token_num * bad_words_length),
|
||||||
@@ -41,10 +41,10 @@ __global__ void speculate_ban_bad_words(T* logits,
|
|||||||
&end);
|
&end);
|
||||||
for (int i = start; i < end; i++) {
|
for (int i = start; i < end; i++) {
|
||||||
int token_idx = i / bad_words_length;
|
int token_idx = i / bad_words_length;
|
||||||
GM2LM(output_padding_offset + token_idx,
|
GM2LM(batch_id_per_token_output + token_idx,
|
||||||
&output_padding_offset_lm,
|
&batch_id_per_token_output_lm,
|
||||||
sizeof(int));
|
sizeof(int));
|
||||||
int bs_idx = (token_idx + output_padding_offset_lm) / max_seq_len;
|
int bs_idx = batch_id_per_token_output_lm;
|
||||||
if (bs_idx >= bs) {
|
if (bs_idx >= bs) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -63,7 +63,7 @@ __global__ void speculate_ban_bad_words(T* logits,
|
|||||||
template __global__ void speculate_ban_bad_words( \
|
template __global__ void speculate_ban_bad_words( \
|
||||||
DATA_TYPE* logits, \
|
DATA_TYPE* logits, \
|
||||||
const int64_t* bad_words_list, \
|
const int64_t* bad_words_list, \
|
||||||
const int* output_padding_offset, \
|
const int* batch_id_per_token_output_lm, \
|
||||||
const int64_t bs, \
|
const int64_t bs, \
|
||||||
const int64_t length, \
|
const int64_t length, \
|
||||||
const int64_t bad_words_length, \
|
const int64_t bad_words_length, \
|
||||||
|
|||||||
+11
-11
@@ -11,8 +11,8 @@ __global__ void speculate_min_length_logits_process(
|
|||||||
const int64_t* cur_len,
|
const int64_t* cur_len,
|
||||||
const int64_t* min_len,
|
const int64_t* min_len,
|
||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -29,26 +29,26 @@ __global__ void speculate_min_length_logits_process(
|
|||||||
int64_t eos_token_id_now;
|
int64_t eos_token_id_now;
|
||||||
int64_t bi;
|
int64_t bi;
|
||||||
int64_t end_num;
|
int64_t end_num;
|
||||||
int output_padding_offset_now;
|
int batch_id_per_token_output_now;
|
||||||
int output_cum_offsets_now;
|
int cu_seqlens_q_output_now;
|
||||||
__simd__ float float32logits_now[32];
|
__simd__ float float32logits_now[32];
|
||||||
|
|
||||||
for (int64_t i = tid; i < token_num * end_length; i += nthreads) {
|
for (int64_t i = tid; i < token_num * end_length; i += nthreads) {
|
||||||
int64_t token_idx = i / end_length;
|
int64_t token_idx = i / end_length;
|
||||||
GM2LM(output_padding_offset + token_idx,
|
GM2LM(batch_id_per_token_output + token_idx,
|
||||||
&output_padding_offset_now,
|
&batch_id_per_token_output_now,
|
||||||
sizeof(int));
|
sizeof(int));
|
||||||
bi = (token_idx + output_padding_offset_now) / max_seq_len;
|
bi = batch_id_per_token_output[token_idx];
|
||||||
if (bi >= bs) {
|
if (bi >= bs) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
end_num = i % end_length;
|
end_num = i % end_length;
|
||||||
GM2LM_ASYNC(
|
GM2LM_ASYNC(
|
||||||
output_cum_offsets + bi, (void*)&output_cum_offsets_now, sizeof(int));
|
cu_seqlens_q_output + bi, (void*)&cu_seqlens_q_output_now, sizeof(int));
|
||||||
GM2LM_ASYNC(cur_len + bi, (void*)&(cur_len_now), sizeof(int64_t));
|
GM2LM_ASYNC(cur_len + bi, (void*)&(cur_len_now), sizeof(int64_t));
|
||||||
GM2LM_ASYNC(min_len + bi, (void*)&(min_len_now), sizeof(int64_t));
|
GM2LM_ASYNC(min_len + bi, (void*)&(min_len_now), sizeof(int64_t));
|
||||||
mfence();
|
mfence();
|
||||||
int query_start_token_idx = bi * max_seq_len - output_cum_offsets_now;
|
int query_start_token_idx = cu_seqlens_q_output_now;
|
||||||
if (cur_len_now >= 0 &&
|
if (cur_len_now >= 0 &&
|
||||||
(cur_len_now + (token_idx - query_start_token_idx) < min_len_now)) {
|
(cur_len_now + (token_idx - query_start_token_idx) < min_len_now)) {
|
||||||
GM2LM(
|
GM2LM(
|
||||||
@@ -74,8 +74,8 @@ __global__ void speculate_min_length_logits_process(
|
|||||||
const int64_t* cur_len, \
|
const int64_t* cur_len, \
|
||||||
const int64_t* min_len, \
|
const int64_t* min_len, \
|
||||||
const int64_t* eos_token_id, \
|
const int64_t* eos_token_id, \
|
||||||
const int* output_padding_offset, \
|
const int* batch_id_per_token_output, \
|
||||||
const int* output_cum_offsets, \
|
const int* cu_seqlens_q_output, \
|
||||||
const int64_t bs, \
|
const int64_t bs, \
|
||||||
const int64_t length, \
|
const int64_t length, \
|
||||||
const int64_t length_id, \
|
const int64_t length_id, \
|
||||||
|
|||||||
+157
@@ -0,0 +1,157 @@
|
|||||||
|
#include "xpu/kernel/cluster.h"
|
||||||
|
#include "xpu/kernel/cluster_partition.h"
|
||||||
|
#include "xpu/kernel/cluster_primitive.h"
|
||||||
|
|
||||||
|
#include "xpu/kernel/cluster_debug.h"
|
||||||
|
|
||||||
|
namespace fd_xpu3 {
|
||||||
|
|
||||||
|
#define MAX_BATCH_SIZE 1024
|
||||||
|
|
||||||
|
static inline __device__ int v_reduce_sum_int32(int32x16_t& v0) {
|
||||||
|
auto v1 = vsrlp_int32x16(1 << 8, v0);
|
||||||
|
v0 = vvadd_int32x16(v0, v1);
|
||||||
|
v1 = vsrlp_int32x16(1 << 7, v0);
|
||||||
|
v0 = vvadd_int32x16(v0, v1);
|
||||||
|
v1 = vsrlp_int32x16(1 << 6, v0);
|
||||||
|
v0 = vvadd_int32x16(v0, v1);
|
||||||
|
v1 = vsrlp_int32x16(1 << 5, v0);
|
||||||
|
v0 = vvadd_int32x16(v0, v1);
|
||||||
|
return vextract_int32x16(v0, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int* x,
|
||||||
|
int64_t len) {
|
||||||
|
int32x16_t x_l, x_h;
|
||||||
|
int32x16_t sum = vset_zero_int();
|
||||||
|
const auto rounddown_len = rounddown32(len);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < rounddown_len; i += 32) {
|
||||||
|
vload2_sm(x + i, x_l, x_h);
|
||||||
|
sum = vvadd_int32x16(sum, x_l);
|
||||||
|
sum = vvadd_int32x16(sum, x_h);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rounddown_len < len) {
|
||||||
|
const auto mask = ~(-1 << (len - rounddown_len));
|
||||||
|
vload2_sm_mz(x + rounddown_len, x_l, x_h, mask);
|
||||||
|
sum = vvadd_int32x16(sum, x_l);
|
||||||
|
sum = vvadd_int32x16(sum, x_h);
|
||||||
|
}
|
||||||
|
return v_reduce_sum_int32(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void speculate_preprocess_kernel(
|
||||||
|
int64_t* ids_remove_padding,
|
||||||
|
int* batch_id_per_token,
|
||||||
|
int* cu_seqlens_q,
|
||||||
|
int* cu_seqlens_k,
|
||||||
|
int* seq_lens_output,
|
||||||
|
int* cu_seq_lens_q_output,
|
||||||
|
int* batch_id_per_token_output,
|
||||||
|
int* real_output_token_num,
|
||||||
|
const int64_t* input_data,
|
||||||
|
const int* seq_lens,
|
||||||
|
const int64_t* draft_tokens,
|
||||||
|
const int* seq_lens_encoder,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_draft_tokens_per_batch,
|
||||||
|
const int real_bs) {
|
||||||
|
int cid = core_id();
|
||||||
|
int ncores = core_num();
|
||||||
|
int clusterid = cluster_id();
|
||||||
|
int nclusters = cluster_num();
|
||||||
|
__shared__ int sm_seq_lens[MAX_BATCH_SIZE];
|
||||||
|
__shared__ int sm_seq_lens_output[MAX_BATCH_SIZE];
|
||||||
|
__shared__ int sm_seq_lens_encoder[MAX_BATCH_SIZE];
|
||||||
|
__shared__ int sm_cum_seq_len, sm_cum_seq_len_output;
|
||||||
|
__simd__ __shared__ int buffer_cu_seqlens[64];
|
||||||
|
__simd__ __shared__ int buffer_cu_seqlens_output[64];
|
||||||
|
|
||||||
|
if (cid == 0) {
|
||||||
|
GM2SM_ASYNC(seq_lens, sm_seq_lens, sizeof(int) * real_bs);
|
||||||
|
GM2SM(seq_lens_encoder, sm_seq_lens_encoder, sizeof(int) * real_bs);
|
||||||
|
}
|
||||||
|
sync_all();
|
||||||
|
for (int bid = cid; bid < real_bs; bid += ncores) {
|
||||||
|
if (sm_seq_lens[bid] == 0) {
|
||||||
|
sm_seq_lens_output[bid] = 0;
|
||||||
|
} else if (sm_seq_lens[bid] == 1) {
|
||||||
|
sm_seq_lens_output[bid] = 1;
|
||||||
|
} else if (sm_seq_lens_encoder[bid] != 0) {
|
||||||
|
sm_seq_lens_output[bid] = 1;
|
||||||
|
} else {
|
||||||
|
sm_seq_lens_output[bid] = sm_seq_lens[bid];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mfence_sm();
|
||||||
|
sync_all();
|
||||||
|
|
||||||
|
for (int bi = clusterid; bi < real_bs; bi += nclusters) {
|
||||||
|
int cum_seq_len = 0;
|
||||||
|
int cum_seq_len_output = 0;
|
||||||
|
for (int i = cid; i < bi + 1; i += ncores) {
|
||||||
|
cum_seq_len += sm_seq_lens[i];
|
||||||
|
cum_seq_len_output += sm_seq_lens_output[i];
|
||||||
|
}
|
||||||
|
buffer_cu_seqlens[cid] = cum_seq_len;
|
||||||
|
buffer_cu_seqlens_output[cid] = cum_seq_len_output;
|
||||||
|
mfence();
|
||||||
|
sync_all();
|
||||||
|
if (cid == 0) {
|
||||||
|
cum_seq_len =
|
||||||
|
primitive_reduce_sum_sm(buffer_cu_seqlens, min(bi + 1, ncores));
|
||||||
|
cum_seq_len_output = primitive_reduce_sum_sm(buffer_cu_seqlens_output,
|
||||||
|
min(bi + 1, ncores));
|
||||||
|
LM2GM_ASYNC(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int));
|
||||||
|
LM2GM_ASYNC(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int));
|
||||||
|
LM2GM_ASYNC(
|
||||||
|
&cum_seq_len_output, cu_seq_lens_q_output + bi + 1, sizeof(int));
|
||||||
|
if (bi == real_bs - 1) {
|
||||||
|
LM2GM_ASYNC(&cum_seq_len_output, real_output_token_num, sizeof(int));
|
||||||
|
}
|
||||||
|
sm_cum_seq_len = cum_seq_len;
|
||||||
|
sm_cum_seq_len_output = cum_seq_len_output;
|
||||||
|
}
|
||||||
|
mfence();
|
||||||
|
sync_all();
|
||||||
|
|
||||||
|
const int lm_seq_lens = sm_seq_lens[bi];
|
||||||
|
const int lm_seq_lens_encoder = sm_seq_lens_encoder[bi];
|
||||||
|
for (int i = cid; i < lm_seq_lens; i += ncores) {
|
||||||
|
const int tgt_seq_id = sm_cum_seq_len - lm_seq_lens + i;
|
||||||
|
if (max_draft_tokens_per_batch > 0 && lm_seq_lens_encoder <= 0) {
|
||||||
|
// speculative decoding
|
||||||
|
const int src_seq_id = bi * max_draft_tokens_per_batch + i;
|
||||||
|
int64_t lm_draft_tokens;
|
||||||
|
GM2LM(draft_tokens + src_seq_id, &lm_draft_tokens, sizeof(int64_t));
|
||||||
|
LM2GM(
|
||||||
|
&lm_draft_tokens, ids_remove_padding + tgt_seq_id, sizeof(int64_t));
|
||||||
|
} else {
|
||||||
|
// Non-speculative decoding
|
||||||
|
const int src_seq_id = bi * max_seq_len + i;
|
||||||
|
int64_t lm_input_data;
|
||||||
|
GM2LM(input_data + src_seq_id, &lm_input_data, sizeof(int64_t));
|
||||||
|
LM2GM(&lm_input_data, ids_remove_padding + tgt_seq_id, sizeof(int64_t));
|
||||||
|
}
|
||||||
|
LM2GM(&bi, batch_id_per_token + tgt_seq_id, sizeof(int));
|
||||||
|
}
|
||||||
|
|
||||||
|
const int lm_seq_lens_output = sm_seq_lens_output[bi];
|
||||||
|
for (int i = cid; i < lm_seq_lens_output; i += ncores) {
|
||||||
|
const int tgt_seq_id_output =
|
||||||
|
sm_cum_seq_len_output - lm_seq_lens_output + i;
|
||||||
|
LM2GM(&bi, batch_id_per_token_output + tgt_seq_id_output, sizeof(int));
|
||||||
|
}
|
||||||
|
mfence();
|
||||||
|
sync_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cid == 0 && clusterid == 0) {
|
||||||
|
const int lm_zero = 0;
|
||||||
|
LM2GM_ASYNC(&lm_zero, cu_seqlens_q, sizeof(int));
|
||||||
|
LM2GM_ASYNC(&lm_zero, cu_seqlens_k, sizeof(int));
|
||||||
|
LM2GM(&lm_zero, cu_seq_lens_q_output, sizeof(int));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace fd_xpu3
|
||||||
+31
-28
@@ -22,7 +22,7 @@ __device__ void speculate_update_repeat_times_normal(
|
|||||||
__global_ptr__ const int64_t *pre_ids,
|
__global_ptr__ const int64_t *pre_ids,
|
||||||
__global_ptr__ const int64_t *cur_len,
|
__global_ptr__ const int64_t *cur_len,
|
||||||
__global_ptr__ int *repeat_times,
|
__global_ptr__ int *repeat_times,
|
||||||
__global_ptr__ const int *output_padding_offset,
|
__global_ptr__ const int *batch_id_per_token_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -40,15 +40,17 @@ __device__ void speculate_update_repeat_times_normal(
|
|||||||
int n_length = (length + max_sm_len - 1) / max_sm_len;
|
int n_length = (length + max_sm_len - 1) / max_sm_len;
|
||||||
|
|
||||||
int64_t *cur_len_lm = (int64_t *)lm;
|
int64_t *cur_len_lm = (int64_t *)lm;
|
||||||
int output_padding_offset_now;
|
int batch_id_per_token_output_now;
|
||||||
GM2LM(cur_len, cur_len_lm, bs * sizeof(int64_t));
|
GM2LM(cur_len, cur_len_lm, bs * sizeof(int64_t));
|
||||||
|
|
||||||
for (int nli = 0; nli < n_length; nli++) {
|
for (int nli = 0; nli < n_length; nli++) {
|
||||||
int step = nli * max_sm_len;
|
int step = nli * max_sm_len;
|
||||||
int cur_length = min(max_sm_len, length - step);
|
int cur_length = min(max_sm_len, length - step);
|
||||||
for (int64_t i = clusterid; i < token_num; i += nclusters) {
|
for (int64_t i = clusterid; i < token_num; i += nclusters) {
|
||||||
GM2LM(output_padding_offset + i, &output_padding_offset_now, sizeof(int));
|
GM2LM(batch_id_per_token_output + i,
|
||||||
int64_t bi = (i + output_padding_offset_now) / max_seq_len;
|
&batch_id_per_token_output_now,
|
||||||
|
sizeof(int));
|
||||||
|
int64_t bi = batch_id_per_token_output_now;
|
||||||
if (bi >= bs || cur_len_lm[bi] < 0) {
|
if (bi >= bs || cur_len_lm[bi] < 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -86,10 +88,10 @@ __device__ void speculate_update_repeat_times_normal(
|
|||||||
__device__ void speculate_update_repeat_times_optimized(
|
__device__ void speculate_update_repeat_times_optimized(
|
||||||
char *lm,
|
char *lm,
|
||||||
__shared_ptr__ char *sm,
|
__shared_ptr__ char *sm,
|
||||||
__global_ptr__ const int64_t *pre_ids, // {bs, length_id}
|
__global_ptr__ const int64_t *pre_ids, // {bs, length_id}
|
||||||
__global_ptr__ const int64_t *cur_len, // {bs}
|
__global_ptr__ const int64_t *cur_len, // {bs}
|
||||||
__global_ptr__ int *repeat_times, // {token_num, length}
|
__global_ptr__ int *repeat_times, // {token_num, length}
|
||||||
__global_ptr__ const int *output_padding_offset, // {token_num}
|
__global_ptr__ const int *batch_id_per_token_output, // {token_num}
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -108,10 +110,10 @@ __device__ void speculate_update_repeat_times_optimized(
|
|||||||
int cur_len_sm_len = 640;
|
int cur_len_sm_len = 640;
|
||||||
__shared_ptr__ int64_t *cur_len_sm =
|
__shared_ptr__ int64_t *cur_len_sm =
|
||||||
(__shared_ptr__ int64_t *)(repeat_times_sm + repeat_times_sm_len);
|
(__shared_ptr__ int64_t *)(repeat_times_sm + repeat_times_sm_len);
|
||||||
__shared_ptr__ int *output_padding_offset_sm =
|
__shared_ptr__ int *batch_id_per_token_output_sm =
|
||||||
(__shared_ptr__ int *)(cur_len_sm + cur_len_sm_len);
|
(__shared_ptr__ int *)(cur_len_sm + cur_len_sm_len);
|
||||||
DoublePtr<1, SmPtr<int>> buffer_ptr_output_padding_offset(
|
DoublePtr<1, SmPtr<int>> buffer_ptr_batch_id_per_token_output(
|
||||||
(SmPtr<int>(output_padding_offset_sm)));
|
(SmPtr<int>(batch_id_per_token_output_sm)));
|
||||||
int pre_ids_lm_len = 4;
|
int pre_ids_lm_len = 4;
|
||||||
int64_t *pre_ids_lm = (int64_t *)lm;
|
int64_t *pre_ids_lm = (int64_t *)lm;
|
||||||
DoublePtr<4, LmPtr<int64_t>> buffer_ptr_pre_ids((LmPtr<int64_t>(pre_ids_lm)));
|
DoublePtr<4, LmPtr<int64_t>> buffer_ptr_pre_ids((LmPtr<int64_t>(pre_ids_lm)));
|
||||||
@@ -119,18 +121,18 @@ __device__ void speculate_update_repeat_times_optimized(
|
|||||||
int64_t i = clusterid;
|
int64_t i = clusterid;
|
||||||
if (i < token_num && cid == 0) {
|
if (i < token_num && cid == 0) {
|
||||||
GM2SM_ASYNC(cur_len, cur_len_sm, bs * sizeof(int64_t));
|
GM2SM_ASYNC(cur_len, cur_len_sm, bs * sizeof(int64_t));
|
||||||
buffer_ptr_output_padding_offset.gm_load_async(output_padding_offset + i,
|
buffer_ptr_batch_id_per_token_output.gm_load_async(
|
||||||
1);
|
batch_id_per_token_output + i, 1);
|
||||||
mfence_sm();
|
mfence_sm();
|
||||||
}
|
}
|
||||||
sync_all();
|
sync_all();
|
||||||
for (; i < token_num; i += nclusters) {
|
for (; i < token_num; i += nclusters) {
|
||||||
if (cid == 0 && i + nclusters < token_num) {
|
if (cid == 0 && i + nclusters < token_num) {
|
||||||
buffer_ptr_output_padding_offset.next().gm_load_async(
|
buffer_ptr_batch_id_per_token_output.next().gm_load_async(
|
||||||
output_padding_offset + i + nclusters, 1);
|
batch_id_per_token_output + i + nclusters, 1);
|
||||||
}
|
}
|
||||||
int64_t bi = (i + (buffer_ptr_output_padding_offset.ptr[0])) / max_seq_len;
|
int64_t bi = buffer_ptr_batch_id_per_token_output.ptr[0];
|
||||||
buffer_ptr_output_padding_offset.toggle();
|
buffer_ptr_batch_id_per_token_output.toggle();
|
||||||
if (bi >= bs || cur_len_sm[bi] < 0) {
|
if (bi >= bs || cur_len_sm[bi] < 0) {
|
||||||
mfence_sm();
|
mfence_sm();
|
||||||
sync_all();
|
sync_all();
|
||||||
@@ -224,15 +226,16 @@ __device__ void speculate_update_repeat_times_optimized(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void speculate_update_repeat_times(const int64_t *pre_ids,
|
__global__ void speculate_update_repeat_times(
|
||||||
const int64_t *cur_len,
|
const int64_t *pre_ids,
|
||||||
int *repeat_times,
|
const int64_t *cur_len,
|
||||||
const int *output_padding_offset,
|
int *repeat_times,
|
||||||
const int64_t bs,
|
const int *batch_id_per_token_output,
|
||||||
const int64_t length,
|
const int64_t bs,
|
||||||
const int64_t length_id,
|
const int64_t length,
|
||||||
const int64_t token_num,
|
const int64_t length_id,
|
||||||
const int64_t max_seq_len) {
|
const int64_t token_num,
|
||||||
|
const int64_t max_seq_len) {
|
||||||
char lm[6 * 1024];
|
char lm[6 * 1024];
|
||||||
__shared__ char sm[256 * 1024];
|
__shared__ char sm[256 * 1024];
|
||||||
|
|
||||||
@@ -242,7 +245,7 @@ __global__ void speculate_update_repeat_times(const int64_t *pre_ids,
|
|||||||
pre_ids,
|
pre_ids,
|
||||||
cur_len,
|
cur_len,
|
||||||
repeat_times,
|
repeat_times,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
@@ -254,7 +257,7 @@ __global__ void speculate_update_repeat_times(const int64_t *pre_ids,
|
|||||||
pre_ids,
|
pre_ids,
|
||||||
cur_len,
|
cur_len,
|
||||||
repeat_times,
|
repeat_times,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
|
|||||||
+20
-23
@@ -25,7 +25,7 @@ __global__ void speculate_update_value_by_repeat_times(
|
|||||||
const T *presence_score,
|
const T *presence_score,
|
||||||
const float *temperatures,
|
const float *temperatures,
|
||||||
T *logits,
|
T *logits,
|
||||||
const int *output_padding_offset,
|
const int *batch_id_per_token_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t token_num,
|
const int64_t token_num,
|
||||||
@@ -46,17 +46,16 @@ __global__ void speculate_update_value_by_repeat_times(
|
|||||||
if (token_end >= token_num) {
|
if (token_end >= token_num) {
|
||||||
token_end = token_num - 1;
|
token_end = token_num - 1;
|
||||||
}
|
}
|
||||||
int output_padding_offset_start_lm;
|
int batch_id_per_token_output_start_lm;
|
||||||
int output_padding_offset_end_lm;
|
int batch_id_per_token_output_end_lm;
|
||||||
GM2LM_ASYNC(output_padding_offset + token_start,
|
GM2LM_ASYNC(batch_id_per_token_output + token_start,
|
||||||
(void *)&output_padding_offset_start_lm,
|
(void *)&batch_id_per_token_output_start_lm,
|
||||||
sizeof(int));
|
sizeof(int));
|
||||||
GM2LM(output_padding_offset + token_end,
|
GM2LM(batch_id_per_token_output + token_end,
|
||||||
(void *)&output_padding_offset_end_lm,
|
(void *)&batch_id_per_token_output_end_lm,
|
||||||
sizeof(int));
|
sizeof(int));
|
||||||
int64_t bs_start =
|
int64_t bs_start = batch_id_per_token_output_start_lm;
|
||||||
(token_start + output_padding_offset_start_lm) / max_seq_len;
|
int64_t bs_end = batch_id_per_token_output_end_lm;
|
||||||
int64_t bs_end = (token_end + output_padding_offset_end_lm) / max_seq_len;
|
|
||||||
const int param_len = 256;
|
const int param_len = 256;
|
||||||
// ncores = 64 for xpu2
|
// ncores = 64 for xpu2
|
||||||
__shared__ __simd__ float alpha_buf[param_len * 64];
|
__shared__ __simd__ float alpha_buf[param_len * 64];
|
||||||
@@ -89,13 +88,13 @@ __global__ void speculate_update_value_by_repeat_times(
|
|||||||
const int buffer_len = 512;
|
const int buffer_len = 512;
|
||||||
__simd__ float logits_lm[buffer_len];
|
__simd__ float logits_lm[buffer_len];
|
||||||
int times_lm[buffer_len];
|
int times_lm[buffer_len];
|
||||||
int output_padding_offset_lm[buffer_len];
|
int batch_id_per_token_output_lm[buffer_len];
|
||||||
|
|
||||||
for (int64_t i = start; i < end; i += buffer_len) {
|
for (int64_t i = start; i < end; i += buffer_len) {
|
||||||
int read_len = min(end - i, buffer_len);
|
int read_len = min(end - i, buffer_len);
|
||||||
GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T));
|
GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T));
|
||||||
GM2LM_ASYNC(output_padding_offset + i / length,
|
GM2LM_ASYNC(batch_id_per_token_output + i / length,
|
||||||
output_padding_offset_lm,
|
batch_id_per_token_output_lm,
|
||||||
((read_len + length - 1) / length + 1) * sizeof(int));
|
((read_len + length - 1) / length + 1) * sizeof(int));
|
||||||
GM2LM(repeat_times + i, times_lm, read_len * sizeof(int));
|
GM2LM(repeat_times + i, times_lm, read_len * sizeof(int));
|
||||||
primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len);
|
primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len);
|
||||||
@@ -104,7 +103,7 @@ __global__ void speculate_update_value_by_repeat_times(
|
|||||||
logit_now = logits_lm[j];
|
logit_now = logits_lm[j];
|
||||||
int token_idx = (i + j) / length;
|
int token_idx = (i + j) / length;
|
||||||
int bs_idx =
|
int bs_idx =
|
||||||
(token_idx + output_padding_offset_lm[token_idx - i / length]) /
|
(token_idx + batch_id_per_token_output_lm[token_idx - i / length]) /
|
||||||
max_seq_len;
|
max_seq_len;
|
||||||
if (bs_idx >= bs) {
|
if (bs_idx >= bs) {
|
||||||
continue;
|
continue;
|
||||||
@@ -134,7 +133,7 @@ __global__ void speculate_update_value_by_repeat_times(
|
|||||||
const DATA_TYPE *presence_score, \
|
const DATA_TYPE *presence_score, \
|
||||||
const float *temperatures, \
|
const float *temperatures, \
|
||||||
DATA_TYPE *logits, \
|
DATA_TYPE *logits, \
|
||||||
const int *output_padding_offset, \
|
const int *batch_id_per_token_output, \
|
||||||
const int64_t bs, \
|
const int64_t bs, \
|
||||||
const int64_t length, \
|
const int64_t length, \
|
||||||
const int64_t token_num, \
|
const int64_t token_num, \
|
||||||
@@ -151,7 +150,7 @@ __global__ void speculate_update_value_by_repeat_times_simd(
|
|||||||
const T *presence_score, // [bs]
|
const T *presence_score, // [bs]
|
||||||
const float *temperatures, // [bs]
|
const float *temperatures, // [bs]
|
||||||
T *logits, // [bs * length]
|
T *logits, // [bs * length]
|
||||||
const int *output_padding_offset,
|
const int *batch_id_per_token_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t token_num,
|
const int64_t token_num,
|
||||||
@@ -198,7 +197,7 @@ __global__ void speculate_update_value_by_repeat_times_simd(
|
|||||||
const int buffer_len = 512;
|
const int buffer_len = 512;
|
||||||
__simd__ float logits_lm[buffer_len];
|
__simd__ float logits_lm[buffer_len];
|
||||||
__simd__ float times_lm[buffer_len];
|
__simd__ float times_lm[buffer_len];
|
||||||
int output_padding_offset_lm[buffer_len];
|
int batch_id_per_token_output_lm[buffer_len];
|
||||||
|
|
||||||
float32x16_t logits_;
|
float32x16_t logits_;
|
||||||
float32x16_t logits_tmp_0;
|
float32x16_t logits_tmp_0;
|
||||||
@@ -208,8 +207,8 @@ __global__ void speculate_update_value_by_repeat_times_simd(
|
|||||||
for (int64_t i = start; i < end; i += buffer_len) {
|
for (int64_t i = start; i < end; i += buffer_len) {
|
||||||
int read_len = min(end - i, buffer_len);
|
int read_len = min(end - i, buffer_len);
|
||||||
GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T));
|
GM2LM_ASYNC(logits + i, logits_lm, read_len * sizeof(T));
|
||||||
GM2LM_ASYNC(output_padding_offset + i / length,
|
GM2LM_ASYNC(batch_id_per_token_output + i / length,
|
||||||
output_padding_offset_lm,
|
batch_id_per_token_output_lm,
|
||||||
((read_len + length - 1) / length + 1) * sizeof(int));
|
((read_len + length - 1) / length + 1) * sizeof(int));
|
||||||
GM2LM(repeat_times + i, times_lm, read_len * sizeof(int));
|
GM2LM(repeat_times + i, times_lm, read_len * sizeof(int));
|
||||||
primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len);
|
primitive_cast<T, float>((const T *)(logits_lm), logits_lm, read_len);
|
||||||
@@ -220,9 +219,7 @@ __global__ void speculate_update_value_by_repeat_times_simd(
|
|||||||
time_ = vload_lm_float32x16(times_lm + j);
|
time_ = vload_lm_float32x16(times_lm + j);
|
||||||
logits_ = vload_lm_float32x16(logits_lm + j);
|
logits_ = vload_lm_float32x16(logits_lm + j);
|
||||||
int token_idx = (i + j) / length;
|
int token_idx = (i + j) / length;
|
||||||
int bs_idx =
|
int bs_idx = batch_id_per_token_output_lm[token_idx - i / length];
|
||||||
(token_idx + output_padding_offset_lm[token_idx - i / length]) /
|
|
||||||
max_seq_len;
|
|
||||||
if (bs_idx >= bs) {
|
if (bs_idx >= bs) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -269,7 +266,7 @@ __global__ void speculate_update_value_by_repeat_times_simd(
|
|||||||
const DATA_TYPE *presence_score, \
|
const DATA_TYPE *presence_score, \
|
||||||
const float *temperatures, \
|
const float *temperatures, \
|
||||||
DATA_TYPE *logits, \
|
DATA_TYPE *logits, \
|
||||||
const int *output_padding_offset, \
|
const int *batch_id_per_token_output, \
|
||||||
const int64_t bs, \
|
const int64_t bs, \
|
||||||
const int64_t length, \
|
const int64_t length, \
|
||||||
const int64_t token_num, \
|
const int64_t token_num, \
|
||||||
|
|||||||
+22
-41
@@ -95,50 +95,31 @@ topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids,
|
|||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
__global__ void speculate_verify(
|
__global__ void speculate_verify(
|
||||||
const int64_t *sampled_token_ids,
|
const int64_t *sampled_token_ids,
|
||||||
int64_t *accept_tokens, // out [real_bsz, max_draft_tokens], 输出最终接收的
|
int64_t *accept_tokens, // out [real_bsz, max_draft_tokens]
|
||||||
// token(通过验证或采样)
|
int *accept_num, // out [real_bsz],
|
||||||
int *accept_num, // out [real_bsz], 每个序列最终接受的 token
|
int64_t *step_idx, // out [real_bsz],
|
||||||
// 数量(只统计通过验证的)
|
bool *stop_flags, // out [real_bsz],
|
||||||
int64_t
|
const int *seq_lens_encoder, // [real_bsz]
|
||||||
*step_idx, // out [real_bsz], 记录每个bid序列已经生成或接受的token数
|
const int *seq_lens_decoder, // [real_bsz]
|
||||||
bool *stop_flags, // out [real_bsz], 每个序列的停止标志,遇到 <eos>
|
const int64_t *draft_tokens, // [real_bsz, max_draft_tokens],
|
||||||
// 或长度超限时置 true
|
const int *actual_draft_token_nums, // [real_bsz], 实际有效的 token 数量
|
||||||
const int *seq_lens_encoder, // [real_bsz], 每个样本 encoder
|
|
||||||
// 输入长度,用于判断 prefill 阶段
|
|
||||||
const int *seq_lens_decoder, // [real_bsz], 每个样本 decoder 输出的 token
|
|
||||||
// 数(即 draft token 数)
|
|
||||||
const int64_t *
|
|
||||||
draft_tokens, // [real_bsz, max_draft_tokens], draft model 输出的 token
|
|
||||||
const int *actual_draft_token_nums, // [real_bsz], draft_tokens
|
|
||||||
// 中实际有效的 token 数量
|
|
||||||
const float *dev_curand_states, // used for random
|
const float *dev_curand_states, // used for random
|
||||||
const float *topp, // [real_bsz],TopP 阈值(如
|
const float *topp, // [real_bsz],
|
||||||
// 0.9),用于控制核采样截断概率和候选数
|
const int *seq_lens_this_time, // [real_bsz],
|
||||||
const int *seq_lens_this_time, // [real_bsz], 本轮 verify
|
|
||||||
// 阶段每个样本实际参与验证的 token 数
|
|
||||||
const int64_t
|
const int64_t
|
||||||
*verify_tokens, // [sum(seq_lens_this_time), max_candidate_len], verify
|
*verify_tokens, // [sum(seq_lens_this_time), max_candidate_len]
|
||||||
// decoder 输出的候选 token
|
const float *verify_scores,
|
||||||
const float
|
|
||||||
*verify_scores, // 同上, 每个 verify token 对应的概率分布,用于采样
|
|
||||||
const int64_t *max_dec_len, // [real_bsz],
|
const int64_t *max_dec_len, // [real_bsz],
|
||||||
// 每个样本允许生成的最大长度(超过则触发终止)
|
const int64_t *end_tokens, // [end_length]
|
||||||
const int64_t
|
const bool *is_block_step, // [real_bsz],
|
||||||
*end_tokens, // [end_length], 终止 token 列表(如 <eos>),命中即终止
|
const int *cu_seqlens_q_output,
|
||||||
const bool *is_block_step, // [real_bsz], 指示是否当前为 block step(为
|
const int *actual_candidate_len, // [sum(seq_lens_this_time)],
|
||||||
// true 时跳过 verify)
|
const int real_bsz, // batch size
|
||||||
const int
|
const int max_draft_tokens,
|
||||||
*output_cum_offsets, // [real_bsz], verify_tokens 的起始偏移,用于定位
|
|
||||||
// token 所在 verify 索引
|
|
||||||
const int *actual_candidate_len, // [sum(seq_lens_this_time)], 每个 verify
|
|
||||||
// token 实际可用候选数(用于 TopP 截断)
|
|
||||||
const int real_bsz, // batch size
|
|
||||||
const int max_draft_tokens, // scalar, 每个样本最多允许的 draft token 数
|
|
||||||
const int end_length,
|
const int end_length,
|
||||||
const int max_seq_len, // scalar, 每个序列的最大 token 数(用于偏移计算)
|
const int max_seq_len,
|
||||||
const int max_candidate_len, // scalar, 每个 verify token
|
const int max_candidate_len,
|
||||||
// 的最大候选数(用于验证或采样)
|
const int verify_window,
|
||||||
const int verify_window, // scalar, TopK 验证窗口(允许连续 top1 匹配次数)
|
|
||||||
const bool prefill_one_step_stop,
|
const bool prefill_one_step_stop,
|
||||||
const bool benchmark_mode,
|
const bool benchmark_mode,
|
||||||
const bool accept_all_drafts,
|
const bool accept_all_drafts,
|
||||||
@@ -151,7 +132,7 @@ __global__ void speculate_verify(
|
|||||||
if (is_block_step[bid]) {
|
if (is_block_step[bid]) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
|
const int start_token_id = cu_seqlens_q_output[bid];
|
||||||
if (stop_flags[bid]) {
|
if (stop_flags[bid]) {
|
||||||
stop_flag_now_int = 1;
|
stop_flag_now_int = 1;
|
||||||
} else { // 这里prefill阶段也会进入,但是因为draft
|
} else { // 这里prefill阶段也会进入,但是因为draft
|
||||||
|
|||||||
+38
-34
@@ -10,7 +10,7 @@ __device__ void top_p_candidates_big_n(
|
|||||||
char* lm,
|
char* lm,
|
||||||
__global_ptr__ const T* src,
|
__global_ptr__ const T* src,
|
||||||
__global_ptr__ const T* top_ps,
|
__global_ptr__ const T* top_ps,
|
||||||
__global_ptr__ const int* output_padding_offset,
|
__global_ptr__ const int* batch_id_per_token_output,
|
||||||
__global_ptr__ int64_t* out_id,
|
__global_ptr__ int64_t* out_id,
|
||||||
__global_ptr__ T* out_val,
|
__global_ptr__ T* out_val,
|
||||||
__global_ptr__ int* actual_candidates_lens,
|
__global_ptr__ int* actual_candidates_lens,
|
||||||
@@ -32,11 +32,13 @@ __device__ void top_p_candidates_big_n(
|
|||||||
__shared__ T sm_out_val[64 * TopPBeamTopK];
|
__shared__ T sm_out_val[64 * TopPBeamTopK];
|
||||||
|
|
||||||
// only used in core 0
|
// only used in core 0
|
||||||
int lm_output_padding_offset;
|
int lm_batch_id_per_token_output;
|
||||||
|
|
||||||
for (int64_t i = cluster_id(); i < token_num; i += cluster_num()) {
|
for (int64_t i = cluster_id(); i < token_num; i += cluster_num()) {
|
||||||
if (cid == 0) {
|
if (cid == 0) {
|
||||||
GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int));
|
GM2LM(batch_id_per_token_output + i,
|
||||||
|
&lm_batch_id_per_token_output,
|
||||||
|
sizeof(int));
|
||||||
}
|
}
|
||||||
for (int64_t j = 0; j < TopPBeamTopK; j++) {
|
for (int64_t j = 0; j < TopPBeamTopK; j++) {
|
||||||
lm_out_id[j] = -1;
|
lm_out_id[j] = -1;
|
||||||
@@ -142,8 +144,7 @@ __device__ void top_p_candidates_big_n(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int ori_token_id = i + lm_output_padding_offset;
|
int bid = lm_batch_id_per_token_output;
|
||||||
int bid = ori_token_id / max_seq_len;
|
|
||||||
T lm_top_p;
|
T lm_top_p;
|
||||||
GM2LM(top_ps + bid, &lm_top_p, sizeof(T));
|
GM2LM(top_ps + bid, &lm_top_p, sizeof(T));
|
||||||
float top_p_value = static_cast<float>(lm_top_p);
|
float top_p_value = static_cast<float>(lm_top_p);
|
||||||
@@ -182,7 +183,7 @@ __device__ void top_p_candidates_normal(
|
|||||||
char* lm,
|
char* lm,
|
||||||
__global_ptr__ const T* src,
|
__global_ptr__ const T* src,
|
||||||
__global_ptr__ const T* top_ps,
|
__global_ptr__ const T* top_ps,
|
||||||
__global_ptr__ const int* output_padding_offset,
|
__global_ptr__ const int* batch_id_per_token_output,
|
||||||
__global_ptr__ int64_t* out_id,
|
__global_ptr__ int64_t* out_id,
|
||||||
__global_ptr__ T* out_val,
|
__global_ptr__ T* out_val,
|
||||||
__global_ptr__ int* actual_candidates_lens,
|
__global_ptr__ int* actual_candidates_lens,
|
||||||
@@ -200,7 +201,7 @@ __device__ void top_p_candidates_normal(
|
|||||||
int64_t lm_out_id[TopPBeamTopK];
|
int64_t lm_out_id[TopPBeamTopK];
|
||||||
T lm_out_val[TopPBeamTopK];
|
T lm_out_val[TopPBeamTopK];
|
||||||
|
|
||||||
int lm_output_padding_offset;
|
int lm_batch_id_per_token_output;
|
||||||
T lm_top_p;
|
T lm_top_p;
|
||||||
int64_t default_id = 0;
|
int64_t default_id = 0;
|
||||||
T default_val = static_cast<T>(0.f);
|
T default_val = static_cast<T>(0.f);
|
||||||
@@ -236,9 +237,10 @@ __device__ void top_p_candidates_normal(
|
|||||||
}
|
}
|
||||||
mfence_lm();
|
mfence_lm();
|
||||||
}
|
}
|
||||||
GM2LM(output_padding_offset + i, &lm_output_padding_offset, sizeof(int));
|
GM2LM(batch_id_per_token_output + i,
|
||||||
int ori_token_id = i + lm_output_padding_offset;
|
&lm_batch_id_per_token_output,
|
||||||
int bid = ori_token_id / max_seq_len;
|
sizeof(int));
|
||||||
|
int bid = lm_batch_id_per_token_output;
|
||||||
GM2LM(top_ps + bid, &lm_top_p, sizeof(T));
|
GM2LM(top_ps + bid, &lm_top_p, sizeof(T));
|
||||||
float top_p_value = static_cast<float>(lm_top_p);
|
float top_p_value = static_cast<float>(lm_top_p);
|
||||||
bool set_to_default_val = false;
|
bool set_to_default_val = false;
|
||||||
@@ -272,7 +274,7 @@ __device__ void top_p_candidates_normal(
|
|||||||
template <typename T, int MaxLength, int TopPBeamTopK>
|
template <typename T, int MaxLength, int TopPBeamTopK>
|
||||||
__global__ void top_p_candidates(const T* src,
|
__global__ void top_p_candidates(const T* src,
|
||||||
const T* top_ps,
|
const T* top_ps,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
int64_t* out_id,
|
int64_t* out_id,
|
||||||
T* out_val,
|
T* out_val,
|
||||||
int* actual_candidates_lens,
|
int* actual_candidates_lens,
|
||||||
@@ -284,29 +286,31 @@ __global__ void top_p_candidates(const T* src,
|
|||||||
if (token_num % (core_num() * cluster_num()) != 0 &&
|
if (token_num % (core_num() * cluster_num()) != 0 &&
|
||||||
vocab_size >= core_num() * (6 * 1024 / sizeof(T)) &&
|
vocab_size >= core_num() * (6 * 1024 / sizeof(T)) &&
|
||||||
vocab_size >= core_num() * TopPBeamTopK) {
|
vocab_size >= core_num() * TopPBeamTopK) {
|
||||||
top_p_candidates_big_n<T, MaxLength, TopPBeamTopK>(lm,
|
top_p_candidates_big_n<T, MaxLength, TopPBeamTopK>(
|
||||||
src,
|
lm,
|
||||||
top_ps,
|
src,
|
||||||
output_padding_offset,
|
top_ps,
|
||||||
out_id,
|
batch_id_per_token_output,
|
||||||
out_val,
|
out_id,
|
||||||
actual_candidates_lens,
|
out_val,
|
||||||
vocab_size,
|
actual_candidates_lens,
|
||||||
token_num,
|
vocab_size,
|
||||||
max_cadidate_len,
|
token_num,
|
||||||
max_seq_len);
|
max_cadidate_len,
|
||||||
|
max_seq_len);
|
||||||
} else {
|
} else {
|
||||||
top_p_candidates_normal<T, MaxLength, TopPBeamTopK>(lm,
|
top_p_candidates_normal<T, MaxLength, TopPBeamTopK>(
|
||||||
src,
|
lm,
|
||||||
top_ps,
|
src,
|
||||||
output_padding_offset,
|
top_ps,
|
||||||
out_id,
|
batch_id_per_token_output,
|
||||||
out_val,
|
out_id,
|
||||||
actual_candidates_lens,
|
out_val,
|
||||||
vocab_size,
|
actual_candidates_lens,
|
||||||
token_num,
|
vocab_size,
|
||||||
max_cadidate_len,
|
token_num,
|
||||||
max_seq_len);
|
max_cadidate_len,
|
||||||
|
max_seq_len);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,7 +318,7 @@ __global__ void top_p_candidates(const T* src,
|
|||||||
template __global__ void top_p_candidates<T, MaxLength, TopPBeamTopK>( \
|
template __global__ void top_p_candidates<T, MaxLength, TopPBeamTopK>( \
|
||||||
const T* src, \
|
const T* src, \
|
||||||
const T* top_ps, \
|
const T* top_ps, \
|
||||||
const int* output_padding_offset, \
|
const int* batch_id_per_token_output, \
|
||||||
int64_t* out_id, \
|
int64_t* out_id, \
|
||||||
T* out_val, \
|
T* out_val, \
|
||||||
int* actual_candidates_lens, \
|
int* actual_candidates_lens, \
|
||||||
|
|||||||
+217
@@ -0,0 +1,217 @@
|
|||||||
|
#include "xpu/kernel/cluster.h"
|
||||||
|
#include "xpu/kernel/cluster_partition.h"
|
||||||
|
#include "xpu/kernel/cluster_primitive.h"
|
||||||
|
|
||||||
|
#include "xpu/kernel/cluster_debug.h"
|
||||||
|
|
||||||
|
namespace fd_xpu3 {
|
||||||
|
|
||||||
|
#define MAX_BATCH_SIZE 1024
|
||||||
|
|
||||||
|
static inline __device__ int v_reduce_sum_int32(int32x16_t &v0) {
|
||||||
|
auto v1 = vsrlp_int32x16(1 << 8, v0);
|
||||||
|
v0 = vvadd_int32x16(v0, v1);
|
||||||
|
v1 = vsrlp_int32x16(1 << 7, v0);
|
||||||
|
v0 = vvadd_int32x16(v0, v1);
|
||||||
|
v1 = vsrlp_int32x16(1 << 6, v0);
|
||||||
|
v0 = vvadd_int32x16(v0, v1);
|
||||||
|
v1 = vsrlp_int32x16(1 << 5, v0);
|
||||||
|
v0 = vvadd_int32x16(v0, v1);
|
||||||
|
return vextract_int32x16(v0, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ int primitive_reduce_sum_sm(__shared_ptr__ const int *x,
|
||||||
|
int64_t len) {
|
||||||
|
int32x16_t x_l, x_h;
|
||||||
|
int32x16_t sum = vset_zero_int();
|
||||||
|
const auto rounddown_len = rounddown32(len);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < rounddown_len; i += 32) {
|
||||||
|
vload2_sm(x + i, x_l, x_h);
|
||||||
|
sum = vvadd_int32x16(sum, x_l);
|
||||||
|
sum = vvadd_int32x16(sum, x_h);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rounddown_len < len) {
|
||||||
|
const auto mask = ~(-1 << (len - rounddown_len));
|
||||||
|
vload2_sm_mz(x + rounddown_len, x_l, x_h, mask);
|
||||||
|
sum = vvadd_int32x16(sum, x_l);
|
||||||
|
sum = vvadd_int32x16(sum, x_h);
|
||||||
|
}
|
||||||
|
return v_reduce_sum_int32(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ bool is_end_token(int64_t token,
|
||||||
|
__shared_ptr__ const int64_t *end_tokens,
|
||||||
|
int num_end_tokens) {
|
||||||
|
#pragma unroll 4
|
||||||
|
for (int i = 0; i < num_end_tokens; i++) {
|
||||||
|
if (token == end_tokens[i]) return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void unified_update_model_status_kernel(int *seq_lens_encoder,
|
||||||
|
int *seq_lens_decoder,
|
||||||
|
bool *has_running_seqs,
|
||||||
|
int *mask_rollback,
|
||||||
|
int64_t *step_input_ids,
|
||||||
|
int *adaptive_step_input_len,
|
||||||
|
int64_t *step_output_ids,
|
||||||
|
int *step_output_len,
|
||||||
|
bool *stop_flags,
|
||||||
|
int *seq_lens_this_time,
|
||||||
|
const bool *is_paused,
|
||||||
|
int64_t *token_ids_all,
|
||||||
|
const int64_t *prompt_lens,
|
||||||
|
int64_t *step_idx,
|
||||||
|
const int64_t *end_tokens,
|
||||||
|
const int64_t *max_dec_len,
|
||||||
|
int real_bsz,
|
||||||
|
int max_bsz,
|
||||||
|
int max_step_tokens,
|
||||||
|
int max_model_len,
|
||||||
|
int num_end_tokens,
|
||||||
|
bool is_naive_mode,
|
||||||
|
bool prefill_one_step_stop) {
|
||||||
|
int cid = core_id();
|
||||||
|
int ncores = core_num();
|
||||||
|
int clusterid = cluster_id();
|
||||||
|
if (clusterid > 0) return;
|
||||||
|
__shared__ int sm_seq_lens_encoder[MAX_BATCH_SIZE];
|
||||||
|
__shared__ int sm_seq_lens_decoder[MAX_BATCH_SIZE];
|
||||||
|
__shared__ bool sm_stop_flags[MAX_BATCH_SIZE];
|
||||||
|
__shared__ int64_t sm_step_idx[MAX_BATCH_SIZE];
|
||||||
|
__shared__ bool sm_is_paused[MAX_BATCH_SIZE];
|
||||||
|
__shared__ int64_t sm_end_tokens[MAX_BATCH_SIZE];
|
||||||
|
|
||||||
|
__shared__ int sm_cum_seq_len, sm_cum_seq_len_output;
|
||||||
|
__shared__ int buffer_stop_flag_int[64];
|
||||||
|
if (cid == 0) {
|
||||||
|
GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, sizeof(int) * max_bsz);
|
||||||
|
GM2SM_ASYNC(seq_lens_decoder, sm_seq_lens_decoder, sizeof(int) * max_bsz);
|
||||||
|
GM2SM_ASYNC(stop_flags, sm_stop_flags, sizeof(bool) * max_bsz);
|
||||||
|
GM2SM_ASYNC(step_idx, sm_step_idx, sizeof(int64_t) * max_bsz);
|
||||||
|
GM2SM_ASYNC(is_paused, sm_is_paused, sizeof(bool) * max_bsz);
|
||||||
|
GM2SM_ASYNC(end_tokens, sm_end_tokens, sizeof(int64_t) * num_end_tokens);
|
||||||
|
}
|
||||||
|
buffer_stop_flag_int[cid] = 0;
|
||||||
|
mfence_sm();
|
||||||
|
sync_all();
|
||||||
|
for (int batch_id = cid; batch_id < max_bsz; batch_id += ncores) {
|
||||||
|
// Read state
|
||||||
|
int cur_seq_len_encoder = sm_seq_lens_encoder[batch_id];
|
||||||
|
int cur_seq_len_decoder = sm_seq_lens_decoder[batch_id];
|
||||||
|
bool cur_stop_flag = sm_stop_flags[batch_id];
|
||||||
|
int output_len = 0;
|
||||||
|
int64_t cur_step_idx = sm_step_idx[batch_id];
|
||||||
|
bool cur_is_paused = sm_is_paused[batch_id];
|
||||||
|
|
||||||
|
bool is_running = !cur_stop_flag && !cur_is_paused;
|
||||||
|
|
||||||
|
// Compute output length
|
||||||
|
if (is_running) {
|
||||||
|
if (is_naive_mode) {
|
||||||
|
output_len = 1;
|
||||||
|
} else {
|
||||||
|
output_len = step_output_len[batch_id];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EOS detection
|
||||||
|
if (is_running && output_len > 0) {
|
||||||
|
bool hit_stop = false;
|
||||||
|
__global_ptr__ int64_t *output_ids =
|
||||||
|
&step_output_ids[batch_id * max_step_tokens];
|
||||||
|
|
||||||
|
for (int i = 0; i < output_len; i++) {
|
||||||
|
cur_step_idx++;
|
||||||
|
int64_t token = output_ids[i];
|
||||||
|
bool is_eos = is_end_token(token, sm_end_tokens, num_end_tokens);
|
||||||
|
bool max_len_hit = (cur_step_idx >= max_dec_len[batch_id]);
|
||||||
|
|
||||||
|
if (is_eos || max_len_hit) {
|
||||||
|
if (!is_eos) output_ids[i] = sm_end_tokens[0];
|
||||||
|
output_len = i + 1;
|
||||||
|
cur_stop_flag = true;
|
||||||
|
hit_stop = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) {
|
||||||
|
cur_stop_flag = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update state and write back
|
||||||
|
if (is_running) {
|
||||||
|
if (cur_stop_flag) {
|
||||||
|
buffer_stop_flag_int[cid] += 1;
|
||||||
|
if (output_len == 0) cur_seq_len_decoder = 0;
|
||||||
|
stop_flags[batch_id] = true;
|
||||||
|
mask_rollback[batch_id] = 0;
|
||||||
|
} else if (cur_seq_len_encoder == 0) {
|
||||||
|
cur_seq_len_decoder += output_len;
|
||||||
|
mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
|
||||||
|
} else {
|
||||||
|
mask_rollback[batch_id] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cur_seq_len_encoder > 0) {
|
||||||
|
cur_seq_len_decoder += cur_seq_len_encoder;
|
||||||
|
cur_seq_len_encoder = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
seq_lens_encoder[batch_id] = cur_seq_len_encoder;
|
||||||
|
seq_lens_decoder[batch_id] = cur_seq_len_decoder;
|
||||||
|
step_output_len[batch_id] = output_len;
|
||||||
|
step_idx[batch_id] = cur_step_idx;
|
||||||
|
|
||||||
|
// Write history to token_ids_all
|
||||||
|
if (cur_step_idx > 0 && output_len > 0) {
|
||||||
|
// Bounds check: highest write index is prompt_lens + cur_step_idx
|
||||||
|
if (prompt_lens[batch_id] + cur_step_idx < max_model_len) {
|
||||||
|
__global_ptr__ int64_t *token_ids_all_now =
|
||||||
|
&token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]];
|
||||||
|
__global_ptr__ int64_t *output_ids =
|
||||||
|
&step_output_ids[batch_id * max_step_tokens];
|
||||||
|
for (int i = 0; i < output_len; i++) {
|
||||||
|
token_ids_all_now[cur_step_idx - i] =
|
||||||
|
output_ids[output_len - 1 - i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup next input
|
||||||
|
if (output_len > 0) {
|
||||||
|
step_input_ids[batch_id * max_step_tokens] =
|
||||||
|
step_output_ids[batch_id * max_step_tokens + output_len - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_naive_mode) {
|
||||||
|
seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1;
|
||||||
|
}
|
||||||
|
} else if (batch_id >= real_bsz) {
|
||||||
|
// Padding slot: just count as stopped, don't modify state
|
||||||
|
buffer_stop_flag_int[cid] += 1;
|
||||||
|
} else {
|
||||||
|
// Stopped or paused slot (batch_id < real_bsz)
|
||||||
|
buffer_stop_flag_int[cid] += 1;
|
||||||
|
stop_flags[batch_id] = true;
|
||||||
|
seq_lens_decoder[batch_id] = 0;
|
||||||
|
seq_lens_this_time[batch_id] = 0;
|
||||||
|
step_output_len[batch_id] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mfence_sm();
|
||||||
|
sync_all();
|
||||||
|
int stop_flag_int = 0;
|
||||||
|
if (cid == 0) {
|
||||||
|
for (int i = 0; i < ncores; i++) {
|
||||||
|
stop_flag_int += buffer_stop_flag_int[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
has_running_seqs[0] = stop_flag_int < max_bsz;
|
||||||
|
}
|
||||||
|
} // namespace fd_xpu3
|
||||||
@@ -24,7 +24,7 @@ __attribute__((global)) void draft_model_update(
|
|||||||
int* seq_lens_encoder,
|
int* seq_lens_encoder,
|
||||||
int* seq_lens_decoder,
|
int* seq_lens_decoder,
|
||||||
int64_t* step_idx,
|
int64_t* step_idx,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
bool* stop_flags,
|
bool* stop_flags,
|
||||||
bool* not_need_stop,
|
bool* not_need_stop,
|
||||||
const int64_t* max_dec_len,
|
const int64_t* max_dec_len,
|
||||||
@@ -60,7 +60,7 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
int* seq_lens_encoder,
|
int* seq_lens_encoder,
|
||||||
int* seq_lens_decoder,
|
int* seq_lens_decoder,
|
||||||
int64_t* step_idx,
|
int64_t* step_idx,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
bool* stop_flags,
|
bool* stop_flags,
|
||||||
bool* not_need_stop,
|
bool* not_need_stop,
|
||||||
const int64_t* max_dec_len,
|
const int64_t* max_dec_len,
|
||||||
@@ -82,8 +82,7 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
auto* pre_ids_now = pre_ids + tid * pre_id_length;
|
auto* pre_ids_now = pre_ids + tid * pre_id_length;
|
||||||
auto* base_model_draft_tokens_now =
|
auto* base_model_draft_tokens_now =
|
||||||
base_model_draft_tokens + tid * max_base_model_draft_token;
|
base_model_draft_tokens + tid * max_base_model_draft_token;
|
||||||
const int next_tokens_start_id =
|
const int next_tokens_start_id = cu_seqlens_q_output[tid];
|
||||||
tid * max_seq_len - output_cum_offsets[tid];
|
|
||||||
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
|
auto* next_tokens_start = inter_next_tokens + next_tokens_start_id;
|
||||||
auto seq_len_this_time = seq_lens_this_time[tid];
|
auto seq_len_this_time = seq_lens_this_time[tid];
|
||||||
auto seq_len_encoder = seq_lens_encoder[tid];
|
auto seq_len_encoder = seq_lens_encoder[tid];
|
||||||
@@ -158,7 +157,7 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
int* seq_lens_encoder,
|
int* seq_lens_encoder,
|
||||||
int* seq_lens_decoder,
|
int* seq_lens_decoder,
|
||||||
int64_t* step_idx,
|
int64_t* step_idx,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
bool* stop_flags,
|
bool* stop_flags,
|
||||||
bool* not_need_stop,
|
bool* not_need_stop,
|
||||||
const int64_t* max_dec_len,
|
const int64_t* max_dec_len,
|
||||||
@@ -182,7 +181,7 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
reinterpret_cast<XPU_INT64*>(step_idx),
|
reinterpret_cast<XPU_INT64*>(step_idx),
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
stop_flags,
|
stop_flags,
|
||||||
not_need_stop,
|
not_need_stop,
|
||||||
reinterpret_cast<const XPU_INT64*>(max_dec_len),
|
reinterpret_cast<const XPU_INT64*>(max_dec_len),
|
||||||
@@ -209,7 +208,7 @@ int draft_model_update(api::Context* ctx,
|
|||||||
int* seq_lens_encoder,
|
int* seq_lens_encoder,
|
||||||
int* seq_lens_decoder,
|
int* seq_lens_decoder,
|
||||||
int64_t* step_idx,
|
int64_t* step_idx,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
bool* stop_flags,
|
bool* stop_flags,
|
||||||
bool* not_need_stop,
|
bool* not_need_stop,
|
||||||
const int64_t* max_dec_len,
|
const int64_t* max_dec_len,
|
||||||
@@ -234,7 +233,7 @@ int draft_model_update(api::Context* ctx,
|
|||||||
seq_lens_decoder);
|
seq_lens_decoder);
|
||||||
WRAPPER_DUMP_PARAM6(ctx,
|
WRAPPER_DUMP_PARAM6(ctx,
|
||||||
step_idx,
|
step_idx,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
stop_flags,
|
stop_flags,
|
||||||
not_need_stop,
|
not_need_stop,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
@@ -255,7 +254,7 @@ int draft_model_update(api::Context* ctx,
|
|||||||
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
|
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_encoder);
|
||||||
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder);
|
WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_decoder);
|
||||||
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx);
|
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, step_idx);
|
||||||
WRAPPER_CHECK_PTR(ctx, int, bsz, output_cum_offsets);
|
WRAPPER_CHECK_PTR(ctx, int, bsz, cu_seqlens_q_output);
|
||||||
WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags);
|
WRAPPER_CHECK_PTR(ctx, bool, bsz, stop_flags);
|
||||||
WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop);
|
WRAPPER_CHECK_PTR(ctx, bool, 1, not_need_stop);
|
||||||
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, max_dec_len);
|
WRAPPER_CHECK_PTR(ctx, int64_t, bsz, max_dec_len);
|
||||||
@@ -272,7 +271,7 @@ int draft_model_update(api::Context* ctx,
|
|||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
step_idx,
|
step_idx,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
stop_flags,
|
stop_flags,
|
||||||
not_need_stop,
|
not_need_stop,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
@@ -296,7 +295,7 @@ int draft_model_update(api::Context* ctx,
|
|||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
step_idx,
|
step_idx,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
stop_flags,
|
stop_flags,
|
||||||
not_need_stop,
|
not_need_stop,
|
||||||
max_dec_len,
|
max_dec_len,
|
||||||
|
|||||||
@@ -0,0 +1,244 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "xpu/plugin.h"
|
||||||
|
#include "xpu/refactor/impl/xdnn_impl.h"
|
||||||
|
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||||
|
|
||||||
|
namespace fd_xpu3 {
|
||||||
|
|
||||||
|
__attribute__((global)) void speculate_preprocess_kernel(
|
||||||
|
int64_t* ids_remove_padding,
|
||||||
|
int* batch_id_per_token,
|
||||||
|
int* cu_seqlens_q,
|
||||||
|
int* cu_seqlens_k,
|
||||||
|
int* seq_lens_output,
|
||||||
|
int* cu_seq_lens_q_output,
|
||||||
|
int* batch_id_per_token_output,
|
||||||
|
int* real_output_token_num,
|
||||||
|
const int64_t* input_data,
|
||||||
|
const int* seq_lens,
|
||||||
|
const int64_t* draft_tokens,
|
||||||
|
const int* seq_lens_encoder,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_draft_tokens_per_batch,
|
||||||
|
const int real_bs);
|
||||||
|
} // namespace fd_xpu3
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace plugin {
|
||||||
|
|
||||||
|
static int cpu_wrapper(api::Context* ctx,
|
||||||
|
int64_t* ids_remove_padding,
|
||||||
|
int* batch_id_per_token,
|
||||||
|
int* cu_seqlens_q,
|
||||||
|
int* cu_seqlens_k,
|
||||||
|
int* seq_lens_output,
|
||||||
|
int* cu_seq_lens_q_output,
|
||||||
|
int* batch_id_per_token_output,
|
||||||
|
int* real_output_token_num,
|
||||||
|
const int64_t* input_data,
|
||||||
|
const int* seq_lens,
|
||||||
|
const int64_t* draft_tokens,
|
||||||
|
const int* seq_lens_encoder,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_draft_tokens_per_batch,
|
||||||
|
const int token_num_data,
|
||||||
|
const int real_bs) {
|
||||||
|
cu_seqlens_q[0] = 0;
|
||||||
|
cu_seqlens_k[0] = 0;
|
||||||
|
for (int i = 0; i < real_bs; ++i) {
|
||||||
|
const int seq_len = seq_lens[i];
|
||||||
|
cu_seqlens_q[i + 1] = cu_seqlens_q[i] + seq_len;
|
||||||
|
cu_seqlens_k[i + 1] = cu_seqlens_k[i] + seq_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int bi = 0; bi < real_bs; ++bi) {
|
||||||
|
for (int i = 0; i < seq_lens[bi]; ++i) {
|
||||||
|
const int tgt_seq_id = cu_seqlens_q[bi + 1] - seq_lens[bi] + i;
|
||||||
|
if (max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0) {
|
||||||
|
// speculative decoding
|
||||||
|
const int src_seq_id = bi * max_draft_tokens_per_batch + i;
|
||||||
|
ids_remove_padding[tgt_seq_id] = draft_tokens[src_seq_id];
|
||||||
|
} else {
|
||||||
|
// Non-speculative decoding
|
||||||
|
const int src_seq_id = bi * max_seq_len + i;
|
||||||
|
ids_remove_padding[tgt_seq_id] = input_data[src_seq_id];
|
||||||
|
}
|
||||||
|
batch_id_per_token[tgt_seq_id] = bi;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int bid = 0; bid < real_bs; ++bid) {
|
||||||
|
if (seq_lens[bid] == 0) {
|
||||||
|
seq_lens_output[bid] = 0;
|
||||||
|
} else if (seq_lens[bid] == 1) {
|
||||||
|
seq_lens_output[bid] = 1;
|
||||||
|
} else if (seq_lens_encoder[bid] != 0) {
|
||||||
|
seq_lens_output[bid] = 1;
|
||||||
|
} else {
|
||||||
|
seq_lens_output[bid] = seq_lens[bid];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cu_seq_lens_q_output[0] = 0;
|
||||||
|
for (int i = 0; i < real_bs; ++i) {
|
||||||
|
cu_seq_lens_q_output[i + 1] = cu_seq_lens_q_output[i] + seq_lens_output[i];
|
||||||
|
}
|
||||||
|
real_output_token_num[0] = cu_seq_lens_q_output[real_bs];
|
||||||
|
|
||||||
|
for (int bi = 0; bi < real_bs; ++bi) {
|
||||||
|
for (int i = 0; i < seq_lens_output[bi]; ++i) {
|
||||||
|
const int tgt_seq_id_output =
|
||||||
|
cu_seq_lens_q_output[bi + 1] - seq_lens_output[bi] + i;
|
||||||
|
batch_id_per_token_output[tgt_seq_id_output] = bi;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return api::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int xpu3_wrapper(api::Context* ctx,
|
||||||
|
int64_t* ids_remove_padding,
|
||||||
|
int* batch_id_per_token,
|
||||||
|
int* cu_seqlens_q,
|
||||||
|
int* cu_seqlens_k,
|
||||||
|
int* seq_lens_output,
|
||||||
|
int* cu_seq_lens_q_output,
|
||||||
|
int* batch_id_per_token_output,
|
||||||
|
int* real_output_token_num,
|
||||||
|
const int64_t* input_data,
|
||||||
|
const int* seq_lens,
|
||||||
|
const int64_t* draft_tokens,
|
||||||
|
const int* seq_lens_encoder,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_draft_tokens_per_batch,
|
||||||
|
const int token_num_data,
|
||||||
|
const int real_bs) {
|
||||||
|
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
|
||||||
|
int32_t ret_xre = fd_xpu3::
|
||||||
|
speculate_preprocess_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||||
|
reinterpret_cast<XPU_INT64*>(ids_remove_padding),
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
seq_lens_output,
|
||||||
|
cu_seq_lens_q_output,
|
||||||
|
batch_id_per_token_output,
|
||||||
|
real_output_token_num,
|
||||||
|
reinterpret_cast<const XPU_INT64*>(input_data),
|
||||||
|
seq_lens,
|
||||||
|
reinterpret_cast<const XPU_INT64*>(draft_tokens),
|
||||||
|
seq_lens_encoder,
|
||||||
|
max_seq_len,
|
||||||
|
max_draft_tokens_per_batch,
|
||||||
|
real_bs);
|
||||||
|
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||||
|
return api::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
int speculate_preprocess(api::Context* ctx,
|
||||||
|
int64_t* ids_remove_padding,
|
||||||
|
int* batch_id_per_token,
|
||||||
|
int* cu_seqlens_q,
|
||||||
|
int* cu_seqlens_k,
|
||||||
|
int* seq_lens_output,
|
||||||
|
int* cu_seq_lens_q_output,
|
||||||
|
int* batch_id_per_token_output,
|
||||||
|
int* real_output_token_num,
|
||||||
|
const int64_t* input_data,
|
||||||
|
const int* seq_lens,
|
||||||
|
const int64_t* draft_tokens,
|
||||||
|
const int* seq_lens_encoder,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_draft_tokens_per_batch,
|
||||||
|
const int token_num_data,
|
||||||
|
const int real_bs) {
|
||||||
|
WRAPPER_CHECK_CTX(ctx);
|
||||||
|
WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_preprocess", int);
|
||||||
|
WRAPPER_DUMP_PARAM6(ctx,
|
||||||
|
ids_remove_padding,
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
seq_lens_output,
|
||||||
|
cu_seq_lens_q_output);
|
||||||
|
WRAPPER_DUMP_PARAM6(ctx,
|
||||||
|
batch_id_per_token_output,
|
||||||
|
real_output_token_num,
|
||||||
|
input_data,
|
||||||
|
seq_lens,
|
||||||
|
draft_tokens,
|
||||||
|
seq_lens_encoder);
|
||||||
|
WRAPPER_DUMP_PARAM3(ctx, max_seq_len, max_draft_tokens_per_batch, real_bs);
|
||||||
|
WRAPPER_DUMP(ctx);
|
||||||
|
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, token_num_data, ids_remove_padding);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, token_num_data, batch_id_per_token);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seqlens_q);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seqlens_k);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens_output);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, real_bs + 1, cu_seq_lens_q_output);
|
||||||
|
WRAPPER_CHECK_PTR(
|
||||||
|
ctx, int, real_bs* max_draft_tokens_per_batch, batch_id_per_token_output);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, 1, real_output_token_num);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, real_bs * max_seq_len, input_data);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens);
|
||||||
|
WRAPPER_CHECK_PTR(
|
||||||
|
ctx, int, real_bs* max_draft_tokens_per_batch, draft_tokens);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, real_bs, seq_lens_encoder);
|
||||||
|
|
||||||
|
if (ctx->dev().type() == api::kCPU) {
|
||||||
|
return cpu_wrapper(ctx,
|
||||||
|
ids_remove_padding,
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
seq_lens_output,
|
||||||
|
cu_seq_lens_q_output,
|
||||||
|
batch_id_per_token_output,
|
||||||
|
real_output_token_num,
|
||||||
|
input_data,
|
||||||
|
seq_lens,
|
||||||
|
draft_tokens,
|
||||||
|
seq_lens_encoder,
|
||||||
|
max_seq_len,
|
||||||
|
max_draft_tokens_per_batch,
|
||||||
|
token_num_data,
|
||||||
|
real_bs);
|
||||||
|
}
|
||||||
|
if (ctx->dev().type() == api::kXPU3) {
|
||||||
|
return xpu3_wrapper(ctx,
|
||||||
|
ids_remove_padding,
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
seq_lens_output,
|
||||||
|
cu_seq_lens_q_output,
|
||||||
|
batch_id_per_token_output,
|
||||||
|
real_output_token_num,
|
||||||
|
input_data,
|
||||||
|
seq_lens,
|
||||||
|
draft_tokens,
|
||||||
|
seq_lens_encoder,
|
||||||
|
max_seq_len,
|
||||||
|
max_draft_tokens_per_batch,
|
||||||
|
token_num_data,
|
||||||
|
real_bs);
|
||||||
|
}
|
||||||
|
WRAPPER_UNIMPLEMENTED(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace plugin
|
||||||
|
} // namespace fastdeploy
|
||||||
+33
-32
@@ -25,8 +25,8 @@ __attribute__((global)) void speculate_min_length_logits_process(
|
|||||||
const int64_t* cur_len,
|
const int64_t* cur_len,
|
||||||
const int64_t* min_len,
|
const int64_t* min_len,
|
||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -37,7 +37,7 @@ __attribute__((global)) void speculate_update_repeat_times(
|
|||||||
const int64_t* pre_ids,
|
const int64_t* pre_ids,
|
||||||
const int64_t* cur_len,
|
const int64_t* cur_len,
|
||||||
int* repeat_times,
|
int* repeat_times,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -146,7 +146,7 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int64_t* bad_words,
|
const int64_t* bad_words,
|
||||||
const int* output_padding_offset,
|
const int* output_padding_offset,
|
||||||
const int* output_cum_offsets,
|
const int* batch_id_per_token_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -172,7 +172,7 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
WRAPPER_ASSERT_SUCCESS(ctx, ret);
|
||||||
for (int64_t i = 0; i < token_num; i++) {
|
for (int64_t i = 0; i < token_num; i++) {
|
||||||
int64_t bi = (i + output_padding_offset[i]) / max_seq_len;
|
int64_t bi = (i + output_padding_offset[i]) / max_seq_len;
|
||||||
int64_t query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi];
|
int64_t query_start_token_idx = batch_id_per_token_output[bi];
|
||||||
if (bi < bs && cur_len[bi] >= 0 &&
|
if (bi < bs && cur_len[bi] >= 0 &&
|
||||||
(cur_len[bi] + (i - query_start_token_idx) < min_len[bi])) {
|
(cur_len[bi] + (i - query_start_token_idx) < min_len[bi])) {
|
||||||
for (int64_t j = 0; j < end_length; j++) {
|
for (int64_t j = 0; j < end_length; j++) {
|
||||||
@@ -236,8 +236,8 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
const int64_t* min_len,
|
const int64_t* min_len,
|
||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int64_t* bad_words,
|
const int64_t* bad_words,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -268,7 +268,7 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
reinterpret_cast<const XPU_INT64*>(pre_ids),
|
reinterpret_cast<const XPU_INT64*>(pre_ids),
|
||||||
reinterpret_cast<const XPU_INT64*>(cur_len),
|
reinterpret_cast<const XPU_INT64*>(cur_len),
|
||||||
repeat_times,
|
repeat_times,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
@@ -282,8 +282,8 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
reinterpret_cast<const XPU_INT64*>(cur_len),
|
reinterpret_cast<const XPU_INT64*>(cur_len),
|
||||||
reinterpret_cast<const XPU_INT64*>(min_len),
|
reinterpret_cast<const XPU_INT64*>(min_len),
|
||||||
reinterpret_cast<const XPU_INT64*>(eos_token_id),
|
reinterpret_cast<const XPU_INT64*>(eos_token_id),
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
@@ -300,7 +300,7 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
presence_scores,
|
presence_scores,
|
||||||
temperatures,
|
temperatures,
|
||||||
logits,
|
logits,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
token_num,
|
token_num,
|
||||||
@@ -311,7 +311,7 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
ret_xre = ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
ret_xre = ban_bad_words_kernel<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||||
logits,
|
logits,
|
||||||
reinterpret_cast<const XPU_INT64*>(bad_words),
|
reinterpret_cast<const XPU_INT64*>(bad_words),
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_bad_words,
|
length_bad_words,
|
||||||
@@ -334,8 +334,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
|
|||||||
const int64_t* min_len,
|
const int64_t* min_len,
|
||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int64_t* bad_words,
|
const int64_t* bad_words,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -357,8 +357,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
|
|||||||
min_len,
|
min_len,
|
||||||
eos_token_id,
|
eos_token_id,
|
||||||
bad_words,
|
bad_words,
|
||||||
output_padding_offset,
|
cu_seqlens_q_output,
|
||||||
output_cum_offsets);
|
batch_id_per_token_output);
|
||||||
WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length);
|
WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length);
|
||||||
WRAPPER_DUMP_PARAM3(ctx, length_bad_words, token_num, max_seq_len);
|
WRAPPER_DUMP_PARAM3(ctx, length_bad_words, token_num, max_seq_len);
|
||||||
WRAPPER_DUMP(ctx);
|
WRAPPER_DUMP(ctx);
|
||||||
@@ -373,8 +373,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
|
|||||||
int64_t min_len_len = -1;
|
int64_t min_len_len = -1;
|
||||||
int64_t eos_token_id_len = -1;
|
int64_t eos_token_id_len = -1;
|
||||||
int64_t bad_words_len = -1;
|
int64_t bad_words_len = -1;
|
||||||
int64_t output_padding_offset_len = -1;
|
// int64_t output_padding_offset_len = -1;
|
||||||
int64_t output_cum_offsets_len = -1;
|
// int64_t output_cum_offsets_len = -1;
|
||||||
WRAPPER_ASSERT_LE(ctx, bs, 640);
|
WRAPPER_ASSERT_LE(ctx, bs, 640);
|
||||||
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id});
|
WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id});
|
||||||
WRAPPER_CHECK_SHAPE(ctx, &logits_len, {token_num, length});
|
WRAPPER_CHECK_SHAPE(ctx, &logits_len, {token_num, length});
|
||||||
@@ -386,8 +386,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
|
|||||||
WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs});
|
WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs});
|
||||||
WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length});
|
WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length});
|
||||||
WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words});
|
WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words});
|
||||||
WRAPPER_CHECK_SHAPE(ctx, &output_padding_offset_len, {token_num});
|
// WRAPPER_CHECK_SHAPE(ctx, &output_padding_offset_len, {token_num});
|
||||||
WRAPPER_CHECK_SHAPE(ctx, &output_cum_offsets_len, {bs});
|
// WRAPPER_CHECK_SHAPE(ctx, &output_cum_offsets_len, {bs});
|
||||||
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids);
|
WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids);
|
||||||
WRAPPER_CHECK_PTR(ctx, T, logits_len, logits);
|
WRAPPER_CHECK_PTR(ctx, T, logits_len, logits);
|
||||||
WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores);
|
WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores);
|
||||||
@@ -398,8 +398,9 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
|
|||||||
WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len);
|
WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len);
|
||||||
WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id);
|
WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id);
|
||||||
WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words);
|
WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words);
|
||||||
WRAPPER_CHECK_PTR(ctx, int, output_padding_offset_len, output_padding_offset);
|
// WRAPPER_CHECK_PTR(ctx, int, output_padding_offset_len,
|
||||||
WRAPPER_CHECK_PTR(ctx, int, output_cum_offsets_len, output_cum_offsets);
|
// output_padding_offset); WRAPPER_CHECK_PTR(ctx, int, output_cum_offsets_len,
|
||||||
|
// output_cum_offsets);
|
||||||
if (ctx->dev().type() == api::kCPU) {
|
if (ctx->dev().type() == api::kCPU) {
|
||||||
return cpu_wrapper<T>(ctx,
|
return cpu_wrapper<T>(ctx,
|
||||||
pre_ids,
|
pre_ids,
|
||||||
@@ -412,8 +413,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
|
|||||||
min_len,
|
min_len,
|
||||||
eos_token_id,
|
eos_token_id,
|
||||||
bad_words,
|
bad_words,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
@@ -434,8 +435,8 @@ int speculate_token_penalty_multi_scores(api::Context* ctx,
|
|||||||
min_len,
|
min_len,
|
||||||
eos_token_id,
|
eos_token_id,
|
||||||
bad_words,
|
bad_words,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id,
|
length_id,
|
||||||
@@ -459,8 +460,8 @@ template int speculate_token_penalty_multi_scores<float>(
|
|||||||
const int64_t* min_len,
|
const int64_t* min_len,
|
||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int64_t* bad_words,
|
const int64_t* bad_words,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -480,8 +481,8 @@ template int speculate_token_penalty_multi_scores<float16>(
|
|||||||
const int64_t* min_len,
|
const int64_t* min_len,
|
||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int64_t* bad_words,
|
const int64_t* bad_words,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
@@ -501,8 +502,8 @@ template int speculate_token_penalty_multi_scores<bfloat16>(
|
|||||||
const int64_t* min_len,
|
const int64_t* min_len,
|
||||||
const int64_t* eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int64_t* bad_words,
|
const int64_t* bad_words,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
const int* output_cum_offsets,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id,
|
const int64_t length_id,
|
||||||
|
|||||||
@@ -23,25 +23,25 @@ typedef uint32_t curandStatePhilox4_32_10_t;
|
|||||||
|
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
__attribute__((global)) void speculate_verify(
|
__attribute__((global)) void speculate_verify(
|
||||||
const int64_t* sampled_token_ids,
|
const int64_t *sampled_token_ids,
|
||||||
int64_t* accept_tokens,
|
int64_t *accept_tokens,
|
||||||
int* accept_num,
|
int *accept_num,
|
||||||
int64_t* step_idx,
|
int64_t *step_idx,
|
||||||
bool* stop_flags,
|
bool *stop_flags,
|
||||||
const int* seq_lens_encoder,
|
const int *seq_lens_encoder,
|
||||||
const int* seq_lens_decoder,
|
const int *seq_lens_decoder,
|
||||||
const int64_t* draft_tokens,
|
const int64_t *draft_tokens,
|
||||||
const int* actual_draft_token_nums,
|
const int *actual_draft_token_nums,
|
||||||
const float* dev_curand_states,
|
const float *dev_curand_states,
|
||||||
const float* topp,
|
const float *topp,
|
||||||
const int* seq_lens_this_time,
|
const int *seq_lens_this_time,
|
||||||
const int64_t* verify_tokens,
|
const int64_t *verify_tokens,
|
||||||
const float* verify_scores,
|
const float *verify_scores,
|
||||||
const int64_t* max_dec_len,
|
const int64_t *max_dec_len,
|
||||||
const int64_t* end_tokens,
|
const int64_t *end_tokens,
|
||||||
const bool* is_block_step,
|
const bool *is_block_step,
|
||||||
const int* output_cum_offsets,
|
const int *cu_seqlens_q_output,
|
||||||
const int* actual_candidate_len,
|
const int *actual_candidate_len,
|
||||||
const int real_bsz,
|
const int real_bsz,
|
||||||
const int max_draft_tokens,
|
const int max_draft_tokens,
|
||||||
const int end_length,
|
const int end_length,
|
||||||
@@ -58,7 +58,7 @@ namespace fastdeploy {
|
|||||||
namespace plugin {
|
namespace plugin {
|
||||||
|
|
||||||
static inline bool is_in_end(const int64_t id,
|
static inline bool is_in_end(const int64_t id,
|
||||||
const int64_t* end_ids,
|
const int64_t *end_ids,
|
||||||
int length) {
|
int length) {
|
||||||
bool flag = false;
|
bool flag = false;
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
@@ -69,7 +69,7 @@ static inline bool is_in_end(const int64_t id,
|
|||||||
return flag;
|
return flag;
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline bool is_in(const int64_t* candidates,
|
static inline bool is_in(const int64_t *candidates,
|
||||||
const int64_t draft,
|
const int64_t draft,
|
||||||
const int candidate_len) {
|
const int candidate_len) {
|
||||||
for (int i = 0; i < candidate_len; i++) {
|
for (int i = 0; i < candidate_len; i++) {
|
||||||
@@ -80,7 +80,7 @@ static inline bool is_in(const int64_t* candidates,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline unsigned int xorwow(unsigned int& state) { // NOLINT
|
static inline unsigned int xorwow(unsigned int &state) { // NOLINT
|
||||||
state ^= state >> 7;
|
state ^= state >> 7;
|
||||||
state ^= state << 9;
|
state ^= state << 9;
|
||||||
state ^= state >> 13;
|
state ^= state >> 13;
|
||||||
@@ -89,9 +89,9 @@ static inline unsigned int xorwow(unsigned int& state) { // NOLINT
|
|||||||
|
|
||||||
typedef uint32_t curandStatePhilox4_32_10_t;
|
typedef uint32_t curandStatePhilox4_32_10_t;
|
||||||
|
|
||||||
static int64_t topp_sampling_kernel(const int64_t* candidate_ids,
|
static int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||||
const float* candidate_scores,
|
const float *candidate_scores,
|
||||||
const float* dev_curand_states,
|
const float *dev_curand_states,
|
||||||
const int candidate_len,
|
const int candidate_len,
|
||||||
const float topp,
|
const float topp,
|
||||||
int tid) {
|
int tid) {
|
||||||
@@ -111,26 +111,26 @@ static int64_t topp_sampling_kernel(const int64_t* candidate_ids,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
static int cpu_wrapper(api::Context* ctx,
|
static int cpu_wrapper(api::Context *ctx,
|
||||||
const int64_t* sampled_token_ids,
|
const int64_t *sampled_token_ids,
|
||||||
int64_t* accept_tokens,
|
int64_t *accept_tokens,
|
||||||
int* accept_num,
|
int *accept_num,
|
||||||
int64_t* step_idx,
|
int64_t *step_idx,
|
||||||
bool* stop_flags,
|
bool *stop_flags,
|
||||||
const int* seq_lens_encoder,
|
const int *seq_lens_encoder,
|
||||||
const int* seq_lens_decoder,
|
const int *seq_lens_decoder,
|
||||||
const int64_t* draft_tokens,
|
const int64_t *draft_tokens,
|
||||||
const int* actual_draft_token_nums,
|
const int *actual_draft_token_nums,
|
||||||
const float* dev_curand_states,
|
const float *dev_curand_states,
|
||||||
const float* topp,
|
const float *topp,
|
||||||
const int* seq_lens_this_time,
|
const int *seq_lens_this_time,
|
||||||
const int64_t* verify_tokens,
|
const int64_t *verify_tokens,
|
||||||
const float* verify_scores,
|
const float *verify_scores,
|
||||||
const int64_t* max_dec_len,
|
const int64_t *max_dec_len,
|
||||||
const int64_t* end_tokens,
|
const int64_t *end_tokens,
|
||||||
const bool* is_block_step,
|
const bool *is_block_step,
|
||||||
const int* output_cum_offsets,
|
const int *cu_seqlens_q_output,
|
||||||
const int* actual_candidate_len,
|
const int *actual_candidate_len,
|
||||||
const int real_bsz,
|
const int real_bsz,
|
||||||
const int max_draft_tokens,
|
const int max_draft_tokens,
|
||||||
const int end_length,
|
const int end_length,
|
||||||
@@ -147,7 +147,7 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
int stop_flag_now_int = 0;
|
int stop_flag_now_int = 0;
|
||||||
|
|
||||||
if (!(is_block_step[bid] || bid >= real_bsz)) {
|
if (!(is_block_step[bid] || bid >= real_bsz)) {
|
||||||
const int start_token_id = bid * max_seq_len - output_cum_offsets[bid];
|
const int start_token_id = cu_seqlens_q_output[bid];
|
||||||
// printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id);
|
// printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id);
|
||||||
// printf("bid %d\n", bid);
|
// printf("bid %d\n", bid);
|
||||||
|
|
||||||
@@ -155,11 +155,11 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
stop_flag_now_int = 1;
|
stop_flag_now_int = 1;
|
||||||
} else { // 这里prefill阶段也会进入,但是因为draft
|
} else { // 这里prefill阶段也会进入,但是因为draft
|
||||||
// tokens会置零,因此会直接到最后的采样阶段
|
// tokens会置零,因此会直接到最后的采样阶段
|
||||||
auto* verify_tokens_now =
|
auto *verify_tokens_now =
|
||||||
verify_tokens + start_token_id * max_candidate_len;
|
verify_tokens + start_token_id * max_candidate_len;
|
||||||
auto* draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens;
|
||||||
auto* actual_candidate_len_now = actual_candidate_len + start_token_id;
|
auto *actual_candidate_len_now = actual_candidate_len + start_token_id;
|
||||||
auto* sampled_token_id_now = sampled_token_ids + start_token_id;
|
auto *sampled_token_id_now = sampled_token_ids + start_token_id;
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
|
||||||
@@ -306,7 +306,7 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
// 也是从verify_tokens_now[i]中选一个 但是停止的情况不算
|
// 也是从verify_tokens_now[i]中选一个 但是停止的情况不算
|
||||||
if (!stop_flag_now_int) {
|
if (!stop_flag_now_int) {
|
||||||
int64_t accept_token;
|
int64_t accept_token;
|
||||||
const float* verify_scores_now =
|
const float *verify_scores_now =
|
||||||
verify_scores + start_token_id * max_candidate_len;
|
verify_scores + start_token_id * max_candidate_len;
|
||||||
step_idx[bid]++;
|
step_idx[bid]++;
|
||||||
if (use_target_sampling) {
|
if (use_target_sampling) {
|
||||||
@@ -347,26 +347,26 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
static int xpu3_wrapper(api::Context* ctx,
|
static int xpu3_wrapper(api::Context *ctx,
|
||||||
const int64_t* sampled_token_ids,
|
const int64_t *sampled_token_ids,
|
||||||
int64_t* accept_tokens,
|
int64_t *accept_tokens,
|
||||||
int* accept_num,
|
int *accept_num,
|
||||||
int64_t* step_idx,
|
int64_t *step_idx,
|
||||||
bool* stop_flags,
|
bool *stop_flags,
|
||||||
const int* seq_lens_encoder,
|
const int *seq_lens_encoder,
|
||||||
const int* seq_lens_decoder,
|
const int *seq_lens_decoder,
|
||||||
const int64_t* draft_tokens,
|
const int64_t *draft_tokens,
|
||||||
const int* actual_draft_token_nums,
|
const int *actual_draft_token_nums,
|
||||||
const float* dev_curand_states,
|
const float *dev_curand_states,
|
||||||
const float* topp,
|
const float *topp,
|
||||||
const int* seq_lens_this_time,
|
const int *seq_lens_this_time,
|
||||||
const int64_t* verify_tokens,
|
const int64_t *verify_tokens,
|
||||||
const float* verify_scores,
|
const float *verify_scores,
|
||||||
const int64_t* max_dec_len,
|
const int64_t *max_dec_len,
|
||||||
const int64_t* end_tokens,
|
const int64_t *end_tokens,
|
||||||
const bool* is_block_step,
|
const bool *is_block_step,
|
||||||
const int* output_cum_offsets,
|
const int *cu_seqlens_q_output,
|
||||||
const int* actual_candidate_len,
|
const int *actual_candidate_len,
|
||||||
const int real_bsz,
|
const int real_bsz,
|
||||||
const int max_draft_tokens,
|
const int max_draft_tokens,
|
||||||
const int end_length,
|
const int end_length,
|
||||||
@@ -380,24 +380,24 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
|
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
|
||||||
int32_t ret_xre = fd_xpu3::speculate_verify<ENABLE_TOPP, USE_TOPK>
|
int32_t ret_xre = fd_xpu3::speculate_verify<ENABLE_TOPP, USE_TOPK>
|
||||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||||
reinterpret_cast<const XPU_INT64*>(sampled_token_ids),
|
reinterpret_cast<const XPU_INT64 *>(sampled_token_ids),
|
||||||
reinterpret_cast<XPU_INT64*>(accept_tokens),
|
reinterpret_cast<XPU_INT64 *>(accept_tokens),
|
||||||
accept_num,
|
accept_num,
|
||||||
reinterpret_cast<XPU_INT64*>(step_idx),
|
reinterpret_cast<XPU_INT64 *>(step_idx),
|
||||||
stop_flags,
|
stop_flags,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
reinterpret_cast<const XPU_INT64*>(draft_tokens),
|
reinterpret_cast<const XPU_INT64 *>(draft_tokens),
|
||||||
actual_draft_token_nums,
|
actual_draft_token_nums,
|
||||||
dev_curand_states,
|
dev_curand_states,
|
||||||
topp,
|
topp,
|
||||||
seq_lens_this_time,
|
seq_lens_this_time,
|
||||||
reinterpret_cast<const XPU_INT64*>(verify_tokens),
|
reinterpret_cast<const XPU_INT64 *>(verify_tokens),
|
||||||
verify_scores,
|
verify_scores,
|
||||||
reinterpret_cast<const XPU_INT64*>(max_dec_len),
|
reinterpret_cast<const XPU_INT64 *>(max_dec_len),
|
||||||
reinterpret_cast<const XPU_INT64*>(end_tokens),
|
reinterpret_cast<const XPU_INT64 *>(end_tokens),
|
||||||
is_block_step,
|
is_block_step,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
actual_candidate_len,
|
actual_candidate_len,
|
||||||
real_bsz,
|
real_bsz,
|
||||||
max_draft_tokens,
|
max_draft_tokens,
|
||||||
@@ -413,26 +413,26 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
return api::SUCCESS;
|
return api::SUCCESS;
|
||||||
}
|
}
|
||||||
template <bool ENABLE_TOPP, bool USE_TOPK>
|
template <bool ENABLE_TOPP, bool USE_TOPK>
|
||||||
int speculate_verify(api::Context* ctx,
|
int speculate_verify(api::Context *ctx,
|
||||||
const int64_t* sampled_token_ids,
|
const int64_t *sampled_token_ids,
|
||||||
int64_t* accept_tokens,
|
int64_t *accept_tokens,
|
||||||
int* accept_num,
|
int *accept_num,
|
||||||
int64_t* step_idx,
|
int64_t *step_idx,
|
||||||
bool* stop_flags,
|
bool *stop_flags,
|
||||||
const int* seq_lens_encoder,
|
const int *seq_lens_encoder,
|
||||||
const int* seq_lens_decoder,
|
const int *seq_lens_decoder,
|
||||||
const int64_t* draft_tokens,
|
const int64_t *draft_tokens,
|
||||||
const int* actual_draft_token_nums,
|
const int *actual_draft_token_nums,
|
||||||
const float* dev_curand_states,
|
const float *dev_curand_states,
|
||||||
const float* topp,
|
const float *topp,
|
||||||
const int* seq_lens_this_time,
|
const int *seq_lens_this_time,
|
||||||
const int64_t* verify_tokens,
|
const int64_t *verify_tokens,
|
||||||
const float* verify_scores,
|
const float *verify_scores,
|
||||||
const int64_t* max_dec_len,
|
const int64_t *max_dec_len,
|
||||||
const int64_t* end_tokens,
|
const int64_t *end_tokens,
|
||||||
const bool* is_block_step,
|
const bool *is_block_step,
|
||||||
const int* output_cum_offsets,
|
const int *cu_seqlens_q_output,
|
||||||
const int* actual_candidate_len,
|
const int *actual_candidate_len,
|
||||||
const int real_bsz,
|
const int real_bsz,
|
||||||
const int max_draft_tokens,
|
const int max_draft_tokens,
|
||||||
const int end_length,
|
const int end_length,
|
||||||
@@ -462,7 +462,7 @@ int speculate_verify(api::Context* ctx,
|
|||||||
end_tokens);
|
end_tokens);
|
||||||
WRAPPER_DUMP_PARAM5(ctx,
|
WRAPPER_DUMP_PARAM5(ctx,
|
||||||
is_block_step,
|
is_block_step,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
actual_candidate_len,
|
actual_candidate_len,
|
||||||
real_bsz,
|
real_bsz,
|
||||||
max_draft_tokens);
|
max_draft_tokens);
|
||||||
@@ -492,7 +492,7 @@ int speculate_verify(api::Context* ctx,
|
|||||||
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len);
|
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len);
|
||||||
WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens);
|
WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens);
|
||||||
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step);
|
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step);
|
||||||
WRAPPER_CHECK_PTR(ctx, int, real_bsz, output_cum_offsets);
|
WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_seqlens_q_output);
|
||||||
// WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_candidate_len);
|
// WRAPPER_CHECK_PTR(ctx, int, real_bsz, actual_candidate_len);
|
||||||
|
|
||||||
// param check sm size limit
|
// param check sm size limit
|
||||||
@@ -525,7 +525,7 @@ int speculate_verify(api::Context* ctx,
|
|||||||
max_dec_len,
|
max_dec_len,
|
||||||
end_tokens,
|
end_tokens,
|
||||||
is_block_step,
|
is_block_step,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
actual_candidate_len,
|
actual_candidate_len,
|
||||||
real_bsz,
|
real_bsz,
|
||||||
max_draft_tokens,
|
max_draft_tokens,
|
||||||
@@ -557,7 +557,7 @@ int speculate_verify(api::Context* ctx,
|
|||||||
max_dec_len,
|
max_dec_len,
|
||||||
end_tokens,
|
end_tokens,
|
||||||
is_block_step,
|
is_block_step,
|
||||||
output_cum_offsets,
|
cu_seqlens_q_output,
|
||||||
actual_candidate_len,
|
actual_candidate_len,
|
||||||
real_bsz,
|
real_bsz,
|
||||||
max_draft_tokens,
|
max_draft_tokens,
|
||||||
@@ -575,36 +575,36 @@ int speculate_verify(api::Context* ctx,
|
|||||||
|
|
||||||
#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \
|
#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \
|
||||||
template int fastdeploy::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \
|
template int fastdeploy::plugin::speculate_verify<ENABLE_TOPP, USE_TOPK>( \
|
||||||
fastdeploy::plugin::api::Context*, /* xpu_ctx */ \
|
fastdeploy::plugin::api::Context *, /* xpu_ctx */ \
|
||||||
const int64_t*, /* sampled_token_ids */ \
|
const int64_t *, /* sampled_token_ids */ \
|
||||||
int64_t*, /* accept_tokens */ \
|
int64_t *, /* accept_tokens */ \
|
||||||
int*, /* accept_num */ \
|
int *, /* accept_num */ \
|
||||||
int64_t*, /* step_idx */ \
|
int64_t *, /* step_idx */ \
|
||||||
bool*, /* stop_flags */ \
|
bool *, /* stop_flags */ \
|
||||||
const int*, /* seq_lens_encoder */ \
|
const int *, /* seq_lens_encoder */ \
|
||||||
const int*, /* seq_lens_decoder */ \
|
const int *, /* seq_lens_decoder */ \
|
||||||
const int64_t*, /* draft_tokens */ \
|
const int64_t *, /* draft_tokens */ \
|
||||||
const int*, /* actual_draft_token_nums */ \
|
const int *, /* actual_draft_token_nums */ \
|
||||||
const float*, /* dev_curand_states or topp */ \
|
const float *, /* dev_curand_states or topp */ \
|
||||||
const float*, /* topp or nullptr */ \
|
const float *, /* topp or nullptr */ \
|
||||||
const int*, /* seq_lens_this_time */ \
|
const int *, /* seq_lens_this_time */ \
|
||||||
const int64_t*, /* verify_tokens */ \
|
const int64_t *, /* verify_tokens */ \
|
||||||
const float*, /* verify_scores */ \
|
const float *, /* verify_scores */ \
|
||||||
const int64_t*, /* max_dec_len */ \
|
const int64_t *, /* max_dec_len */ \
|
||||||
const int64_t*, /* end_tokens */ \
|
const int64_t *, /* end_tokens */ \
|
||||||
const bool*, /* is_block_step */ \
|
const bool *, /* is_block_step */ \
|
||||||
const int*, /* output_cum_offsets */ \
|
const int *, /* cu_seqlens_q_output */ \
|
||||||
const int*, /* actual_candidate_len */ \
|
const int *, /* actual_candidate_len */ \
|
||||||
int, /* real_bsz */ \
|
int, /* real_bsz */ \
|
||||||
int, /* max_draft_tokens */ \
|
int, /* max_draft_tokens */ \
|
||||||
int, /* end_length */ \
|
int, /* end_length */ \
|
||||||
int, /* max_seq_len */ \
|
int, /* max_seq_len */ \
|
||||||
int, /* max_candidate_len */ \
|
int, /* max_candidate_len */ \
|
||||||
int, /* verify_window */ \
|
int, /* verify_window */ \
|
||||||
bool, /* prefill_one_step_stop */ \
|
bool, /* prefill_one_step_stop */ \
|
||||||
bool, /* benchmark_mode */ \
|
bool, /* benchmark_mode */ \
|
||||||
bool, /* accept_all_drafts */ \
|
bool, /* accept_all_drafts */ \
|
||||||
bool /* use_target_sampling */ \
|
bool /* use_target_sampling */ \
|
||||||
);
|
);
|
||||||
|
|
||||||
INSTANTIATE_SPECULATE_VERIFY(false, false)
|
INSTANTIATE_SPECULATE_VERIFY(false, false)
|
||||||
|
|||||||
@@ -17,16 +17,17 @@
|
|||||||
|
|
||||||
namespace fd_xpu3 {
|
namespace fd_xpu3 {
|
||||||
template <typename T, int MaxLength, int TopPBeamTopK>
|
template <typename T, int MaxLength, int TopPBeamTopK>
|
||||||
__attribute__((global)) void top_p_candidates(const T* src,
|
__attribute__((global)) void top_p_candidates(
|
||||||
const T* top_ps,
|
const T* src,
|
||||||
const int* output_padding_offset,
|
const T* top_ps,
|
||||||
int64_t* out_id,
|
const int* batch_id_per_token_output,
|
||||||
T* out_val,
|
int64_t* out_id,
|
||||||
int* actual_candidates_lens,
|
T* out_val,
|
||||||
int vocab_size,
|
int* actual_candidates_lens,
|
||||||
int token_num,
|
int vocab_size,
|
||||||
int max_candidate_len,
|
int token_num,
|
||||||
int max_seq_len);
|
int max_candidate_len,
|
||||||
|
int max_seq_len);
|
||||||
} // namespace fd_xpu3
|
} // namespace fd_xpu3
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
@@ -36,7 +37,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
|
|||||||
static int cpu_wrapper(api::Context* ctx,
|
static int cpu_wrapper(api::Context* ctx,
|
||||||
const T* src,
|
const T* src,
|
||||||
const T* top_ps,
|
const T* top_ps,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
int64_t* out_id,
|
int64_t* out_id,
|
||||||
T* out_val,
|
T* out_val,
|
||||||
int* actual_candidates_lens,
|
int* actual_candidates_lens,
|
||||||
@@ -70,8 +71,7 @@ static int cpu_wrapper(api::Context* ctx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int ori_token_id = i + output_padding_offset[i];
|
int bid = batch_id_per_token_output[i];
|
||||||
int bid = ori_token_id / max_seq_len;
|
|
||||||
float top_p_value = static_cast<float>(top_ps[bid]);
|
float top_p_value = static_cast<float>(top_ps[bid]);
|
||||||
bool set_to_default_val = false;
|
bool set_to_default_val = false;
|
||||||
for (int j = 0; j < TopPBeamTopK; j++) {
|
for (int j = 0; j < TopPBeamTopK; j++) {
|
||||||
@@ -97,7 +97,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
|
|||||||
static int xpu3_wrapper(api::Context* ctx,
|
static int xpu3_wrapper(api::Context* ctx,
|
||||||
const T* src,
|
const T* src,
|
||||||
const T* top_ps,
|
const T* top_ps,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
int64_t* out_id,
|
int64_t* out_id,
|
||||||
T* out_val,
|
T* out_val,
|
||||||
int* actual_candidates_lens,
|
int* actual_candidates_lens,
|
||||||
@@ -110,7 +110,7 @@ static int xpu3_wrapper(api::Context* ctx,
|
|||||||
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
|
||||||
src,
|
src,
|
||||||
top_ps,
|
top_ps,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
reinterpret_cast<XPU_INT64*>(out_id),
|
reinterpret_cast<XPU_INT64*>(out_id),
|
||||||
out_val,
|
out_val,
|
||||||
actual_candidates_lens,
|
actual_candidates_lens,
|
||||||
@@ -126,7 +126,7 @@ template <typename T, int MaxLength, int TopPBeamTopK>
|
|||||||
int top_p_candidates(api::Context* ctx,
|
int top_p_candidates(api::Context* ctx,
|
||||||
const T* src,
|
const T* src,
|
||||||
const T* top_ps,
|
const T* top_ps,
|
||||||
const int* output_padding_offset,
|
const int* batch_id_per_token_output,
|
||||||
int64_t* out_id,
|
int64_t* out_id,
|
||||||
T* out_val,
|
T* out_val,
|
||||||
int* actual_candidates_lens,
|
int* actual_candidates_lens,
|
||||||
@@ -136,7 +136,8 @@ int top_p_candidates(api::Context* ctx,
|
|||||||
int max_seq_len) {
|
int max_seq_len) {
|
||||||
WRAPPER_CHECK_CTX(ctx);
|
WRAPPER_CHECK_CTX(ctx);
|
||||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "top_p_candidates", T);
|
WRAPPER_DUMP_FUNCTION_T1(ctx, "top_p_candidates", T);
|
||||||
WRAPPER_DUMP_PARAM5(ctx, src, top_ps, output_padding_offset, out_id, out_val);
|
WRAPPER_DUMP_PARAM5(
|
||||||
|
ctx, src, top_ps, batch_id_per_token_output, out_id, out_val);
|
||||||
WRAPPER_DUMP_PARAM5(ctx,
|
WRAPPER_DUMP_PARAM5(ctx,
|
||||||
actual_candidates_lens,
|
actual_candidates_lens,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
@@ -146,7 +147,7 @@ int top_p_candidates(api::Context* ctx,
|
|||||||
WRAPPER_DUMP(ctx);
|
WRAPPER_DUMP(ctx);
|
||||||
|
|
||||||
WRAPPER_CHECK_PTR(ctx, T, token_num * vocab_size, src);
|
WRAPPER_CHECK_PTR(ctx, T, token_num * vocab_size, src);
|
||||||
WRAPPER_CHECK_PTR(ctx, T, token_num, output_padding_offset);
|
WRAPPER_CHECK_PTR(ctx, T, token_num, batch_id_per_token_output);
|
||||||
WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_id);
|
WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_id);
|
||||||
WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_val);
|
WRAPPER_CHECK_PTR(ctx, T, token_num * candidate_len, out_val);
|
||||||
|
|
||||||
@@ -161,7 +162,7 @@ int top_p_candidates(api::Context* ctx,
|
|||||||
return cpu_wrapper<T, MaxLength, TopPBeamTopK>(ctx,
|
return cpu_wrapper<T, MaxLength, TopPBeamTopK>(ctx,
|
||||||
src,
|
src,
|
||||||
top_ps,
|
top_ps,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
out_id,
|
out_id,
|
||||||
out_val,
|
out_val,
|
||||||
actual_candidates_lens,
|
actual_candidates_lens,
|
||||||
@@ -173,7 +174,7 @@ int top_p_candidates(api::Context* ctx,
|
|||||||
return xpu3_wrapper<T, MaxLength, TopPBeamTopK>(ctx,
|
return xpu3_wrapper<T, MaxLength, TopPBeamTopK>(ctx,
|
||||||
src,
|
src,
|
||||||
top_ps,
|
top_ps,
|
||||||
output_padding_offset,
|
batch_id_per_token_output,
|
||||||
out_id,
|
out_id,
|
||||||
out_val,
|
out_val,
|
||||||
actual_candidates_lens,
|
actual_candidates_lens,
|
||||||
|
|||||||
+376
@@ -0,0 +1,376 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "xpu/plugin.h"
|
||||||
|
#include "xpu/refactor/impl/xdnn_impl.h"
|
||||||
|
#include "xpu/refactor/impl_public/wrapper_check.h"
|
||||||
|
|
||||||
|
namespace fd_xpu3 {
|
||||||
|
|
||||||
|
__attribute__((global)) void unified_update_model_status_kernel(
|
||||||
|
int *seq_lens_encoder,
|
||||||
|
int *seq_lens_decoder,
|
||||||
|
bool *has_running_seqs,
|
||||||
|
int *mask_rollback,
|
||||||
|
int64_t *step_input_ids,
|
||||||
|
int *adaptive_step_input_len,
|
||||||
|
int64_t *step_output_ids,
|
||||||
|
int *step_output_len,
|
||||||
|
bool *stop_flags,
|
||||||
|
int *seq_lens_this_time,
|
||||||
|
const bool *is_paused,
|
||||||
|
int64_t *token_ids_all,
|
||||||
|
const int64_t *prompt_lens,
|
||||||
|
int64_t *step_idx,
|
||||||
|
const int64_t *end_tokens,
|
||||||
|
const int64_t *max_dec_len,
|
||||||
|
int real_bsz,
|
||||||
|
int max_bsz,
|
||||||
|
int max_step_tokens,
|
||||||
|
int max_model_len,
|
||||||
|
int num_end_tokens,
|
||||||
|
bool is_naive_mode,
|
||||||
|
bool prefill_one_step_stop);
|
||||||
|
} // namespace fd_xpu3
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
namespace plugin {
|
||||||
|
|
||||||
|
bool is_end_token(int64_t token,
|
||||||
|
const int64_t *end_tokens,
|
||||||
|
int num_end_tokens) {
|
||||||
|
#pragma unroll 4
|
||||||
|
for (int i = 0; i < num_end_tokens; i++) {
|
||||||
|
if (token == end_tokens[i]) return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int cpu_wrapper(api::Context *ctx,
|
||||||
|
int *seq_lens_encoder,
|
||||||
|
int *seq_lens_decoder,
|
||||||
|
bool *has_running_seqs,
|
||||||
|
int *mask_rollback,
|
||||||
|
int64_t *step_input_ids,
|
||||||
|
int *adaptive_step_input_len,
|
||||||
|
int64_t *step_output_ids,
|
||||||
|
int *step_output_len,
|
||||||
|
bool *stop_flags,
|
||||||
|
int *seq_lens_this_time,
|
||||||
|
const bool *is_paused,
|
||||||
|
int64_t *token_ids_all,
|
||||||
|
const int64_t *prompt_lens,
|
||||||
|
int64_t *step_idx,
|
||||||
|
const int64_t *end_tokens,
|
||||||
|
const int64_t *max_dec_len,
|
||||||
|
int real_bsz,
|
||||||
|
int max_bsz,
|
||||||
|
int max_step_tokens,
|
||||||
|
int max_model_len,
|
||||||
|
int num_end_tokens,
|
||||||
|
bool is_naive_mode,
|
||||||
|
bool prefill_one_step_stop) {
|
||||||
|
int stop_flag_int = 0;
|
||||||
|
|
||||||
|
for (int batch_id = 0; batch_id < max_bsz; batch_id++) {
|
||||||
|
// Read state
|
||||||
|
int cur_seq_len_encoder = seq_lens_encoder[batch_id];
|
||||||
|
int cur_seq_len_decoder = seq_lens_decoder[batch_id];
|
||||||
|
bool cur_stop_flag = stop_flags[batch_id];
|
||||||
|
int output_len = 0;
|
||||||
|
int64_t cur_step_idx = step_idx[batch_id];
|
||||||
|
bool cur_is_paused = is_paused[batch_id];
|
||||||
|
|
||||||
|
bool is_running = !cur_stop_flag && !cur_is_paused;
|
||||||
|
|
||||||
|
// Compute output length
|
||||||
|
if (is_running) {
|
||||||
|
if (is_naive_mode) {
|
||||||
|
output_len = 1;
|
||||||
|
} else {
|
||||||
|
output_len = step_output_len[batch_id];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EOS detection
|
||||||
|
if (is_running && output_len > 0) {
|
||||||
|
bool hit_stop = false;
|
||||||
|
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||||
|
|
||||||
|
for (int i = 0; i < output_len; i++) {
|
||||||
|
cur_step_idx++;
|
||||||
|
int64_t token = output_ids[i];
|
||||||
|
bool is_eos = is_end_token(token, end_tokens, num_end_tokens);
|
||||||
|
bool max_len_hit = (cur_step_idx >= max_dec_len[batch_id]);
|
||||||
|
|
||||||
|
if (is_eos || max_len_hit) {
|
||||||
|
if (!is_eos) output_ids[i] = end_tokens[0];
|
||||||
|
output_len = i + 1;
|
||||||
|
cur_stop_flag = true;
|
||||||
|
hit_stop = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hit_stop && prefill_one_step_stop && cur_seq_len_encoder > 0) {
|
||||||
|
cur_stop_flag = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update state and write back
|
||||||
|
if (is_running) {
|
||||||
|
if (cur_stop_flag) {
|
||||||
|
stop_flag_int += 1;
|
||||||
|
if (output_len == 0) cur_seq_len_decoder = 0;
|
||||||
|
stop_flags[batch_id] = true;
|
||||||
|
mask_rollback[batch_id] = 0;
|
||||||
|
} else if (cur_seq_len_encoder == 0) {
|
||||||
|
cur_seq_len_decoder += output_len;
|
||||||
|
mask_rollback[batch_id] = seq_lens_this_time[batch_id] - output_len;
|
||||||
|
} else {
|
||||||
|
mask_rollback[batch_id] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cur_seq_len_encoder > 0) {
|
||||||
|
cur_seq_len_decoder += cur_seq_len_encoder;
|
||||||
|
cur_seq_len_encoder = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
seq_lens_encoder[batch_id] = cur_seq_len_encoder;
|
||||||
|
seq_lens_decoder[batch_id] = cur_seq_len_decoder;
|
||||||
|
step_output_len[batch_id] = output_len;
|
||||||
|
step_idx[batch_id] = cur_step_idx;
|
||||||
|
|
||||||
|
// Write history to token_ids_all
|
||||||
|
if (cur_step_idx > 0 && output_len > 0) {
|
||||||
|
// Bounds check: highest write index is prompt_lens + cur_step_idx
|
||||||
|
if (prompt_lens[batch_id] + cur_step_idx < max_model_len) {
|
||||||
|
int64_t *token_ids_all_now =
|
||||||
|
&token_ids_all[batch_id * max_model_len + prompt_lens[batch_id]];
|
||||||
|
int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens];
|
||||||
|
for (int i = 0; i < output_len; i++) {
|
||||||
|
token_ids_all_now[cur_step_idx - i] =
|
||||||
|
output_ids[output_len - 1 - i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup next input
|
||||||
|
if (output_len > 0) {
|
||||||
|
step_input_ids[batch_id * max_step_tokens] =
|
||||||
|
step_output_ids[batch_id * max_step_tokens + output_len - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_naive_mode) {
|
||||||
|
seq_lens_this_time[batch_id] = cur_stop_flag ? 0 : 1;
|
||||||
|
}
|
||||||
|
} else if (batch_id >= real_bsz) {
|
||||||
|
// Padding slot: just count as stopped, don't modify state
|
||||||
|
stop_flag_int += 1;
|
||||||
|
} else {
|
||||||
|
// Stopped or paused slot (batch_id < real_bsz)
|
||||||
|
stop_flag_int += 1;
|
||||||
|
stop_flags[batch_id] = true;
|
||||||
|
seq_lens_decoder[batch_id] = 0;
|
||||||
|
seq_lens_this_time[batch_id] = 0;
|
||||||
|
step_output_len[batch_id] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
has_running_seqs[0] = stop_flag_int < max_bsz;
|
||||||
|
return api::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int xpu3_wrapper(api::Context *ctx,
|
||||||
|
int *seq_lens_encoder,
|
||||||
|
int *seq_lens_decoder,
|
||||||
|
bool *has_running_seqs,
|
||||||
|
int *mask_rollback,
|
||||||
|
int64_t *step_input_ids,
|
||||||
|
int *adaptive_step_input_len,
|
||||||
|
int64_t *step_output_ids,
|
||||||
|
int *step_output_len,
|
||||||
|
bool *stop_flags,
|
||||||
|
int *seq_lens_this_time,
|
||||||
|
const bool *is_paused,
|
||||||
|
int64_t *token_ids_all,
|
||||||
|
const int64_t *prompt_lens,
|
||||||
|
int64_t *step_idx,
|
||||||
|
const int64_t *end_tokens,
|
||||||
|
const int64_t *max_dec_len,
|
||||||
|
int real_bsz,
|
||||||
|
int max_bsz,
|
||||||
|
int max_step_tokens,
|
||||||
|
int max_model_len,
|
||||||
|
int num_end_tokens,
|
||||||
|
bool is_naive_mode,
|
||||||
|
bool prefill_one_step_stop) {
|
||||||
|
using XPU_INT64 = typename api::XPUIndexType<int64_t>::type;
|
||||||
|
int32_t ret_xre =
|
||||||
|
fd_xpu3::unified_update_model_status_kernel<<<ctx->ncluster(),
|
||||||
|
64,
|
||||||
|
ctx->xpu_stream>>>(
|
||||||
|
seq_lens_encoder,
|
||||||
|
seq_lens_decoder,
|
||||||
|
has_running_seqs,
|
||||||
|
mask_rollback,
|
||||||
|
reinterpret_cast<XPU_INT64 *>(step_input_ids),
|
||||||
|
adaptive_step_input_len,
|
||||||
|
reinterpret_cast<XPU_INT64 *>(step_output_ids),
|
||||||
|
step_output_len,
|
||||||
|
stop_flags,
|
||||||
|
seq_lens_this_time,
|
||||||
|
is_paused,
|
||||||
|
reinterpret_cast<XPU_INT64 *>(token_ids_all),
|
||||||
|
reinterpret_cast<const XPU_INT64 *>(prompt_lens),
|
||||||
|
reinterpret_cast<XPU_INT64 *>(step_idx),
|
||||||
|
reinterpret_cast<const XPU_INT64 *>(end_tokens),
|
||||||
|
reinterpret_cast<const XPU_INT64 *>(max_dec_len),
|
||||||
|
real_bsz,
|
||||||
|
max_bsz,
|
||||||
|
max_step_tokens,
|
||||||
|
max_model_len,
|
||||||
|
num_end_tokens,
|
||||||
|
is_naive_mode,
|
||||||
|
prefill_one_step_stop);
|
||||||
|
KERNEL_ASSERT_SUCCESS(ctx, ret_xre);
|
||||||
|
return api::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
int unified_update_model_status(api::Context *ctx,
|
||||||
|
int *seq_lens_encoder,
|
||||||
|
int *seq_lens_decoder,
|
||||||
|
bool *has_running_seqs,
|
||||||
|
int *mask_rollback,
|
||||||
|
int64_t *step_input_ids,
|
||||||
|
int *adaptive_step_input_len,
|
||||||
|
int64_t *step_output_ids,
|
||||||
|
int *step_output_len,
|
||||||
|
bool *stop_flags,
|
||||||
|
int *seq_lens_this_time,
|
||||||
|
const bool *is_paused,
|
||||||
|
int64_t *token_ids_all,
|
||||||
|
const int64_t *prompt_lens,
|
||||||
|
int64_t *step_idx,
|
||||||
|
const int64_t *end_tokens,
|
||||||
|
const int64_t *max_dec_len,
|
||||||
|
int real_bsz,
|
||||||
|
int max_bsz,
|
||||||
|
int max_step_tokens,
|
||||||
|
int max_model_len,
|
||||||
|
int num_end_tokens,
|
||||||
|
bool is_naive_mode,
|
||||||
|
bool prefill_one_step_stop) {
|
||||||
|
WRAPPER_CHECK_CTX(ctx);
|
||||||
|
WRAPPER_DUMP_FUNCTION_T1(ctx, "unified_update_model_status", int);
|
||||||
|
WRAPPER_DUMP_PARAM6(ctx,
|
||||||
|
seq_lens_encoder,
|
||||||
|
seq_lens_decoder,
|
||||||
|
has_running_seqs,
|
||||||
|
mask_rollback,
|
||||||
|
step_input_ids,
|
||||||
|
adaptive_step_input_len);
|
||||||
|
WRAPPER_DUMP_PARAM6(ctx,
|
||||||
|
step_output_ids,
|
||||||
|
step_output_len,
|
||||||
|
stop_flags,
|
||||||
|
seq_lens_this_time,
|
||||||
|
is_paused,
|
||||||
|
token_ids_all);
|
||||||
|
WRAPPER_DUMP_PARAM6(
|
||||||
|
ctx, prompt_lens, step_idx, end_tokens, max_dec_len, real_bsz, max_bsz);
|
||||||
|
WRAPPER_DUMP_PARAM5(ctx,
|
||||||
|
max_step_tokens,
|
||||||
|
max_model_len,
|
||||||
|
num_end_tokens,
|
||||||
|
is_naive_mode,
|
||||||
|
prefill_one_step_stop);
|
||||||
|
WRAPPER_DUMP(ctx);
|
||||||
|
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_encoder);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, max_bsz, seq_lens_decoder);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, bool, 1, has_running_seqs);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, max_bsz, mask_rollback);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_step_tokens, step_input_ids);
|
||||||
|
// WRAPPER_CHECK_PTR(ctx, int, 0, adaptive_step_input_len); // Temporarily
|
||||||
|
// unused
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_step_tokens, step_output_ids);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, max_bsz, step_output_len);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, bool, max_bsz, stop_flags);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, bool, max_bsz, is_paused);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz * max_model_len, token_ids_all);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, prompt_lens);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, step_idx);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, num_end_tokens, end_tokens);
|
||||||
|
WRAPPER_CHECK_PTR(ctx, int64_t, max_bsz, max_dec_len);
|
||||||
|
WRAPPER_ASSERT_GE(ctx, max_bsz, real_bsz);
|
||||||
|
WRAPPER_ASSERT_GE(ctx, 1024, num_end_tokens);
|
||||||
|
|
||||||
|
if (ctx->dev().type() == api::kCPU) {
|
||||||
|
return cpu_wrapper(ctx,
|
||||||
|
seq_lens_encoder,
|
||||||
|
seq_lens_decoder,
|
||||||
|
has_running_seqs,
|
||||||
|
mask_rollback,
|
||||||
|
step_input_ids,
|
||||||
|
adaptive_step_input_len,
|
||||||
|
step_output_ids,
|
||||||
|
step_output_len,
|
||||||
|
stop_flags,
|
||||||
|
seq_lens_this_time,
|
||||||
|
is_paused,
|
||||||
|
token_ids_all,
|
||||||
|
prompt_lens,
|
||||||
|
step_idx,
|
||||||
|
end_tokens,
|
||||||
|
max_dec_len,
|
||||||
|
real_bsz,
|
||||||
|
max_bsz,
|
||||||
|
max_step_tokens,
|
||||||
|
max_model_len,
|
||||||
|
num_end_tokens,
|
||||||
|
is_naive_mode,
|
||||||
|
prefill_one_step_stop);
|
||||||
|
}
|
||||||
|
if (ctx->dev().type() == api::kXPU3) {
|
||||||
|
return xpu3_wrapper(ctx,
|
||||||
|
seq_lens_encoder,
|
||||||
|
seq_lens_decoder,
|
||||||
|
has_running_seqs,
|
||||||
|
mask_rollback,
|
||||||
|
step_input_ids,
|
||||||
|
adaptive_step_input_len,
|
||||||
|
step_output_ids,
|
||||||
|
step_output_len,
|
||||||
|
stop_flags,
|
||||||
|
seq_lens_this_time,
|
||||||
|
is_paused,
|
||||||
|
token_ids_all,
|
||||||
|
prompt_lens,
|
||||||
|
step_idx,
|
||||||
|
end_tokens,
|
||||||
|
max_dec_len,
|
||||||
|
real_bsz,
|
||||||
|
max_bsz,
|
||||||
|
max_step_tokens,
|
||||||
|
max_model_len,
|
||||||
|
num_end_tokens,
|
||||||
|
is_naive_mode,
|
||||||
|
prefill_one_step_stop);
|
||||||
|
}
|
||||||
|
WRAPPER_UNIMPLEMENTED(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace plugin
|
||||||
|
} // namespace fastdeploy
|
||||||
@@ -0,0 +1,328 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.ops.xpu import speculate_pre_process
|
||||||
|
|
||||||
|
|
||||||
|
def speculate_pre_process_ref(
|
||||||
|
input_ids,
|
||||||
|
seq_lens,
|
||||||
|
draft_tokens,
|
||||||
|
seq_lens_encoder,
|
||||||
|
max_seq_len,
|
||||||
|
max_draft_tokens_per_batch,
|
||||||
|
real_bsz,
|
||||||
|
token_num,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Python reference implementation for SpeculatePreProcessKernel.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ids_remove_padding: int64[token_num]
|
||||||
|
batch_id_per_token: int32[token_num]
|
||||||
|
cu_seqlens_q: int32[real_bsz + 1]
|
||||||
|
cu_seqlens_k: int32[real_bsz + 1]
|
||||||
|
seq_lens_output: int32[real_bsz]
|
||||||
|
cu_seq_lens_q_output: int32[real_bsz + 1]
|
||||||
|
batch_id_per_token_output: int32[real_bsz * max_draft_tokens_per_batch]
|
||||||
|
real_output_token_num: int32[1]
|
||||||
|
"""
|
||||||
|
# --- Part 1: ids_remove_padding, batch_id_per_token, cu_seqlens_q/k ---
|
||||||
|
ids_remove_padding = np.zeros(token_num, dtype=np.int64)
|
||||||
|
batch_id_per_token = np.zeros(token_num, dtype=np.int32)
|
||||||
|
cu_seqlens_q = np.zeros(real_bsz + 1, dtype=np.int32)
|
||||||
|
cu_seqlens_k = np.zeros(real_bsz + 1, dtype=np.int32)
|
||||||
|
|
||||||
|
cum = 0
|
||||||
|
for bi in range(real_bsz):
|
||||||
|
cum += seq_lens[bi]
|
||||||
|
cu_seqlens_q[bi + 1] = cum
|
||||||
|
cu_seqlens_k[bi + 1] = cum
|
||||||
|
|
||||||
|
start = cum - seq_lens[bi]
|
||||||
|
for i in range(seq_lens[bi]):
|
||||||
|
tgt = start + i
|
||||||
|
if max_draft_tokens_per_batch > 0 and seq_lens_encoder[bi] <= 0:
|
||||||
|
src = bi * max_draft_tokens_per_batch + i
|
||||||
|
ids_remove_padding[tgt] = draft_tokens[src]
|
||||||
|
else:
|
||||||
|
src = bi * max_seq_len + i
|
||||||
|
ids_remove_padding[tgt] = input_ids[src]
|
||||||
|
batch_id_per_token[tgt] = bi
|
||||||
|
|
||||||
|
# --- Part 2: seq_lens_output ---
|
||||||
|
seq_lens_output = np.zeros(real_bsz, dtype=np.int32)
|
||||||
|
for bid in range(real_bsz):
|
||||||
|
if seq_lens[bid] == 0:
|
||||||
|
seq_lens_output[bid] = 0
|
||||||
|
elif seq_lens[bid] == 1:
|
||||||
|
seq_lens_output[bid] = 1
|
||||||
|
elif seq_lens_encoder[bid] != 0:
|
||||||
|
seq_lens_output[bid] = 1
|
||||||
|
else:
|
||||||
|
seq_lens_output[bid] = seq_lens[bid]
|
||||||
|
|
||||||
|
# --- Part 3: cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num ---
|
||||||
|
cu_seq_lens_q_output = np.zeros(real_bsz + 1, dtype=np.int32)
|
||||||
|
batch_id_per_token_output = np.zeros(real_bsz * max_draft_tokens_per_batch, dtype=np.int32)
|
||||||
|
|
||||||
|
cum_output = 0
|
||||||
|
for bi in range(real_bsz):
|
||||||
|
cum_output += seq_lens_output[bi]
|
||||||
|
cu_seq_lens_q_output[bi + 1] = cum_output
|
||||||
|
|
||||||
|
start_out = cum_output - seq_lens_output[bi]
|
||||||
|
for i in range(seq_lens_output[bi]):
|
||||||
|
batch_id_per_token_output[start_out + i] = bi
|
||||||
|
|
||||||
|
real_output_token_num = np.array([cum_output], dtype=np.int32)
|
||||||
|
|
||||||
|
return (
|
||||||
|
ids_remove_padding,
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
seq_lens_output,
|
||||||
|
cu_seq_lens_q_output,
|
||||||
|
batch_id_per_token_output,
|
||||||
|
real_output_token_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_inputs(
|
||||||
|
real_bsz,
|
||||||
|
max_seq_len,
|
||||||
|
max_draft_tokens,
|
||||||
|
seq_lens_list,
|
||||||
|
seq_lens_encoder_list,
|
||||||
|
draft_tokens_data=None,
|
||||||
|
input_ids_data=None,
|
||||||
|
seed=42,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper to build test inputs from explicit seq_lens and seq_lens_encoder lists.
|
||||||
|
draft_tokens_data and input_ids_data are optional; if None, random data is used.
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
seq_lens = np.array(seq_lens_list, dtype=np.int32)
|
||||||
|
seq_lens_encoder = np.array(seq_lens_encoder_list, dtype=np.int32)
|
||||||
|
seq_lens_decoder = np.zeros(real_bsz, dtype=np.int32) # not used in kernel logic
|
||||||
|
|
||||||
|
token_num = int(np.sum(seq_lens))
|
||||||
|
|
||||||
|
if input_ids_data is not None:
|
||||||
|
input_ids = np.array(input_ids_data, dtype=np.int64).reshape(real_bsz, max_seq_len)
|
||||||
|
else:
|
||||||
|
input_ids = rng.integers(1, 1000, size=(real_bsz, max_seq_len), dtype=np.int64)
|
||||||
|
|
||||||
|
if draft_tokens_data is not None:
|
||||||
|
draft_tokens = np.array(draft_tokens_data, dtype=np.int64).reshape(real_bsz, max_draft_tokens)
|
||||||
|
else:
|
||||||
|
draft_tokens = rng.integers(1, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"seq_lens": seq_lens,
|
||||||
|
"draft_tokens": draft_tokens,
|
||||||
|
"seq_lens_encoder": seq_lens_encoder,
|
||||||
|
"seq_lens_decoder": seq_lens_decoder,
|
||||||
|
"max_seq_len": max_seq_len,
|
||||||
|
"max_draft_tokens": max_draft_tokens,
|
||||||
|
"token_num": token_num,
|
||||||
|
"real_bsz": real_bsz,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_and_compare(tc, inputs):
|
||||||
|
"""
|
||||||
|
Call GPU op and Python reference, compare all outputs.
|
||||||
|
tc: unittest.TestCase instance (for assertion messages).
|
||||||
|
"""
|
||||||
|
real_bsz = inputs["real_bsz"]
|
||||||
|
max_seq_len = inputs["max_seq_len"]
|
||||||
|
max_draft_tokens = inputs["max_draft_tokens"]
|
||||||
|
token_num = inputs["token_num"]
|
||||||
|
|
||||||
|
t_input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
|
||||||
|
t_seq_lens = paddle.to_tensor(inputs["seq_lens"], dtype="int32")
|
||||||
|
t_draft_tokens = paddle.to_tensor(inputs["draft_tokens"], dtype="int64")
|
||||||
|
t_seq_lens_encoder = paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32")
|
||||||
|
t_seq_lens_decoder = paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32")
|
||||||
|
|
||||||
|
gpu_outs = speculate_pre_process(
|
||||||
|
token_num, t_input_ids, t_seq_lens, t_draft_tokens, t_seq_lens_encoder, t_seq_lens_decoder
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_outs = speculate_pre_process_ref(
|
||||||
|
input_ids=inputs["input_ids"].reshape(-1),
|
||||||
|
seq_lens=inputs["seq_lens"],
|
||||||
|
draft_tokens=inputs["draft_tokens"].reshape(-1),
|
||||||
|
seq_lens_encoder=inputs["seq_lens_encoder"],
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_draft_tokens_per_batch=max_draft_tokens,
|
||||||
|
real_bsz=real_bsz,
|
||||||
|
token_num=token_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_names = [
|
||||||
|
"ids_remove_padding",
|
||||||
|
"batch_id_per_token",
|
||||||
|
"cu_seqlens_q",
|
||||||
|
"cu_seqlens_k",
|
||||||
|
"cu_seq_lens_q_output",
|
||||||
|
"batch_id_per_token_output",
|
||||||
|
"real_output_token_num",
|
||||||
|
]
|
||||||
|
# GPU op returns 7 tensors; ref returns 8 (with seq_lens_output at index 4).
|
||||||
|
# GPU output order: ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k,
|
||||||
|
# cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num
|
||||||
|
# Ref output order: ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k,
|
||||||
|
# seq_lens_output, cu_seq_lens_q_output, batch_id_per_token_output, real_output_token_num
|
||||||
|
ref_indices = [0, 1, 2, 3, 5, 6, 7] # skip seq_lens_output (index 4) for direct comparison
|
||||||
|
for name, gpu_idx, ref_idx in zip(output_names, range(7), ref_indices):
|
||||||
|
gpu_val = gpu_outs[gpu_idx].numpy()
|
||||||
|
ref_val = ref_outs[ref_idx]
|
||||||
|
# Trim batch_id_per_token_output to the valid portion (real_output_token_num)
|
||||||
|
# The kernel only writes valid positions; beyond that the content is undefined.
|
||||||
|
if name == "batch_id_per_token_output":
|
||||||
|
valid_len = int(ref_outs[7][0]) # real_output_token_num
|
||||||
|
gpu_val = gpu_val[:valid_len]
|
||||||
|
ref_val = ref_val[:valid_len]
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
gpu_val,
|
||||||
|
ref_val,
|
||||||
|
err_msg=f"Mismatch in output '{name}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpeculatePreProcess(unittest.TestCase):
|
||||||
|
"""Unit tests for speculate_pre_process custom operator."""
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# Test 1: mixed batch covering all 4 seq_lens_output branches
|
||||||
|
# bid=0: seq_lens=0 => output=0 (skip)
|
||||||
|
# bid=1: seq_lens=1, encoder=0 => output=1, read draft_tokens
|
||||||
|
# bid=2: seq_lens=5, encoder=3 => output=1, read input_ids (prefill)
|
||||||
|
# bid=3: seq_lens=4, encoder=0 => output=4, read draft_tokens (decode)
|
||||||
|
# bid=4: seq_lens=1, encoder=2 => output=1, read input_ids (prefill single)
|
||||||
|
# bid=5: seq_lens=8, encoder=0 => output=8, read draft_tokens (decode saturated)
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
def test_mixed_batch_all_branches(self):
|
||||||
|
inputs = build_inputs(
|
||||||
|
real_bsz=6,
|
||||||
|
max_seq_len=16,
|
||||||
|
max_draft_tokens=8,
|
||||||
|
seq_lens_list=[0, 1, 5, 4, 1, 8],
|
||||||
|
seq_lens_encoder_list=[0, 0, 3, 0, 2, 0],
|
||||||
|
)
|
||||||
|
run_and_compare(self, inputs)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# Test 2: token_num=0 early return — verify no crash, 7 outputs
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
def test_all_zero_seq_lens(self):
|
||||||
|
real_bsz = 3
|
||||||
|
t_input_ids = paddle.zeros([real_bsz, 8], dtype="int64")
|
||||||
|
t_seq_lens = paddle.zeros([real_bsz], dtype="int32")
|
||||||
|
t_draft_tokens = paddle.zeros([real_bsz, 4], dtype="int64")
|
||||||
|
t_seq_lens_encoder = paddle.zeros([real_bsz], dtype="int32")
|
||||||
|
t_seq_lens_decoder = paddle.zeros([real_bsz], dtype="int32")
|
||||||
|
|
||||||
|
gpu_outs = speculate_pre_process(
|
||||||
|
0, t_input_ids, t_seq_lens, t_draft_tokens, t_seq_lens_encoder, t_seq_lens_decoder
|
||||||
|
)
|
||||||
|
self.assertEqual(len(gpu_outs), 7)
|
||||||
|
self.assertIsNotNone(gpu_outs[-3])
|
||||||
|
self.assertIsNotNone(gpu_outs[-2])
|
||||||
|
self.assertIsNotNone(gpu_outs[-1])
|
||||||
|
# test copy
|
||||||
|
fake_cu_seqlens_q_output = paddle.empty([real_bsz + 1], dtype="int32")
|
||||||
|
fake_batch_id_per_token_output = paddle.empty([real_bsz], dtype="int32")
|
||||||
|
fake_cu_seqlens_q_output.copy_(gpu_outs[-3])
|
||||||
|
fake_batch_id_per_token_output.copy_(gpu_outs[-2])
|
||||||
|
# test slice
|
||||||
|
fake_batch_id_per_token_output[: gpu_outs[-1].item()]
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# Test 3: exact token values — manually verify ids_remove_padding
|
||||||
|
# bid=0: encoder=0 (decode) => draft_tokens[0][0:3] = [10,11,12]
|
||||||
|
# bid=1: encoder=5 (prefill) => input_ids[1][0:2] = [200,201]
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
def test_exact_token_values(self):
|
||||||
|
inputs = build_inputs(
|
||||||
|
real_bsz=2,
|
||||||
|
max_seq_len=4,
|
||||||
|
max_draft_tokens=4,
|
||||||
|
seq_lens_list=[3, 2],
|
||||||
|
seq_lens_encoder_list=[0, 5],
|
||||||
|
draft_tokens_data=[[10, 11, 12, 13], [20, 21, 22, 23]],
|
||||||
|
input_ids_data=[[100, 101, 102, 103], [200, 201, 202, 203]],
|
||||||
|
)
|
||||||
|
|
||||||
|
t_input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
|
||||||
|
t_seq_lens = paddle.to_tensor(inputs["seq_lens"], dtype="int32")
|
||||||
|
t_draft_tokens = paddle.to_tensor(inputs["draft_tokens"], dtype="int64")
|
||||||
|
t_seq_lens_encoder = paddle.to_tensor(inputs["seq_lens_encoder"], dtype="int32")
|
||||||
|
t_seq_lens_decoder = paddle.to_tensor(inputs["seq_lens_decoder"], dtype="int32")
|
||||||
|
|
||||||
|
gpu_outs = speculate_pre_process(
|
||||||
|
int(np.sum(inputs["seq_lens"])),
|
||||||
|
t_input_ids,
|
||||||
|
t_seq_lens,
|
||||||
|
t_draft_tokens,
|
||||||
|
t_seq_lens_encoder,
|
||||||
|
t_seq_lens_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(gpu_outs[0].numpy(), [10, 11, 12, 200, 201])
|
||||||
|
np.testing.assert_allclose(gpu_outs[1].numpy(), [0, 0, 0, 1, 1])
|
||||||
|
np.testing.assert_allclose(gpu_outs[2].numpy(), [0, 3, 5])
|
||||||
|
np.testing.assert_allclose(gpu_outs[6].numpy(), [4]) # real_output_token_num = 3+1
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# Test 4: random stress test (2 configs covering small & medium batch)
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
def test_random_configs(self):
|
||||||
|
configs = [
|
||||||
|
{"real_bsz": 7, "max_seq_len": 32, "max_draft_tokens": 8, "seed": 200},
|
||||||
|
{"real_bsz": 32, "max_seq_len": 128, "max_draft_tokens": 16, "seed": 400},
|
||||||
|
]
|
||||||
|
for cfg in configs:
|
||||||
|
with self.subTest(**cfg):
|
||||||
|
rng = np.random.default_rng(cfg["seed"])
|
||||||
|
real_bsz = cfg["real_bsz"]
|
||||||
|
max_draft = cfg["max_draft_tokens"]
|
||||||
|
seq_lens_list = rng.integers(0, max_draft + 1, size=real_bsz).tolist()
|
||||||
|
seq_lens_encoder_list = rng.integers(0, 3, size=real_bsz).tolist()
|
||||||
|
|
||||||
|
inputs = build_inputs(
|
||||||
|
real_bsz=real_bsz,
|
||||||
|
max_seq_len=cfg["max_seq_len"],
|
||||||
|
max_draft_tokens=max_draft,
|
||||||
|
seq_lens_list=seq_lens_list,
|
||||||
|
seq_lens_encoder_list=seq_lens_encoder_list,
|
||||||
|
seed=cfg["seed"],
|
||||||
|
)
|
||||||
|
if inputs["token_num"] == 0:
|
||||||
|
continue
|
||||||
|
run_and_compare(self, inputs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -0,0 +1,574 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for unified_update_model_status kernel.
|
||||||
|
|
||||||
|
Kernel semantics (from unified_update_model_status.cu):
|
||||||
|
- Launched as <<<1, 1024>>>, one thread per batch slot (max_bsz <= 1024).
|
||||||
|
- real_bsz = seq_lens_this_time.shape[0], max_bsz = stop_flags.shape[0].
|
||||||
|
- has_running_seqs is a CPU tensor (copied to GPU, kernel writes, copied back).
|
||||||
|
- Padding slots (batch_id >= real_bsz): only counted as stopped, NO state modified.
|
||||||
|
- Stopped/paused real slots: set stop_flags=true, seq_lens_decoder=0,
|
||||||
|
seq_lens_this_time=0, step_output_len=0.
|
||||||
|
- Running slots: EOS detection → state update → token_ids_all write → next input setup.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.ops.xpu import unified_update_model_status
|
||||||
|
|
||||||
|
CUDA_PLACE = paddle.XPUPlace(0)
|
||||||
|
CPU_PLACE = paddle.CPUPlace()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Layer 1: Helpers — tensor creation / kernel invocation / output extraction
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Convert numpy dict → paddle tensors. has_running_seqs goes to CPU."""
|
||||||
|
paddle_inputs = {}
|
||||||
|
for k, v in inputs.items():
|
||||||
|
if isinstance(v, (int, bool, float, str)):
|
||||||
|
paddle_inputs[k] = v
|
||||||
|
elif k == "has_running_seqs":
|
||||||
|
# Kernel host function: has_running_seqs.copy_to(GPU) → kernel → copy_to(CPU)
|
||||||
|
paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE)
|
||||||
|
elif v is not None:
|
||||||
|
paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE)
|
||||||
|
else:
|
||||||
|
paddle_inputs[k] = None
|
||||||
|
return paddle_inputs
|
||||||
|
|
||||||
|
|
||||||
|
def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]):
    """Launch unified_update_model_status once (mutates tensors in place).

    The op takes 16 positional tensors followed by two host-side scalar
    attributes; the tuple below pins the exact positional order required
    by the op signature — do not reorder.
    """
    tensor_order = (
        "seq_lens_encoder",
        "seq_lens_decoder",
        "has_running_seqs",
        "step_input_ids",
        "adaptive_step_input_len",
        "step_output_ids",
        "step_output_len",
        "stop_flags",
        "seq_lens_this_time",
        "is_paused",
        "mask_rollback",
        "token_ids_all",
        "prompt_lens",
        "step_idx",
        "end_tokens",
        "max_dec_len",
    )
    tensor_args = [paddle_inputs[key] for key in tensor_order]
    unified_update_model_status(
        *tensor_args,
        inputs["is_naive_mode"],
        inputs["prefill_one_step_stop"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
# In-place outputs verified after each launch. The kernel's SetInplaceMap
# (in the .cu file) lists 12 tensors, but ``adaptive_step_input_len`` is
# never written by the kernel, so only these 11 are compared.
OUTPUT_KEYS = [
    "seq_lens_encoder",
    "seq_lens_decoder",
    "has_running_seqs",
    "step_input_ids",
    "step_output_ids",
    "step_output_len",
    "stop_flags",
    "seq_lens_this_time",
    "mask_rollback",
    "token_ids_all",
    "step_idx",
]


def get_outputs(paddle_inputs: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Copy every in-place-modified tensor back to host as numpy arrays."""
    host_copies: Dict[str, np.ndarray] = {}
    for key in OUTPUT_KEYS:
        host_copies[key] = paddle_inputs[key].numpy()
    return host_copies
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Layer 2: Input generation
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
def gen_inputs(
    real_bsz: int = 8,
    max_step_tokens: int = 16,
    max_model_len: int = 256,
    seed: int = 42,
    is_naive_mode: bool = False,
    prefill_one_step_stop: bool = False,
) -> Dict[str, Any]:
    """Build a randomized, seed-reproducible input set for the kernel.

    Shapes follow the kernel contract:
      - real_bsz = seq_lens_this_time.shape[0]
      - max_bsz  = stop_flags.shape[0] (= real_bsz + 4 padding slots)
      - is_paused.shape[0] = max_bsz
    Padding slots are forced to stopped because the kernel only reads
    seq_lens_this_time (sized real_bsz) for non-stopped slots.
    """
    gen = np.random.default_rng(seed)
    max_bsz = real_bsz + 4  # four trailing padding slots

    # --- per-slot state, length max_bsz ---
    enc_lens = gen.integers(0, 5, size=max_bsz, dtype=np.int32)
    dec_lens = gen.integers(10, 100, size=max_bsz, dtype=np.int32)
    in_ids = gen.integers(0, 1000, size=(max_bsz, max_step_tokens), dtype=np.int64)
    adaptive_len = gen.integers(1, max_step_tokens + 1, size=max_bsz, dtype=np.int32)
    out_ids = gen.integers(0, 1000, size=(max_bsz, max_step_tokens), dtype=np.int64)
    out_lens = gen.integers(1, max_step_tokens + 1, size=max_bsz, dtype=np.int32)

    stop = np.zeros(max_bsz, dtype=bool)
    # Randomly stop up to two real slots for branch coverage.
    stop[gen.choice(real_bsz, size=min(2, real_bsz), replace=False)] = True
    # Padding slots (batch_id >= real_bsz) must be stopped — the kernel
    # indexes seq_lens_this_time[batch_id], which is only sized real_bsz.
    stop[real_bsz:] = True

    paused = np.zeros(max_bsz, dtype=bool)
    rollback = np.zeros(max_bsz, dtype=np.int32)
    prompt = gen.integers(10, 50, size=max_bsz, dtype=np.int64)
    all_ids = gen.integers(0, 1000, size=(max_bsz, max_model_len), dtype=np.int64)
    steps = gen.integers(0, 50, size=max_bsz, dtype=np.int64)
    dec_limits = gen.integers(100, 200, size=max_bsz, dtype=np.int64)

    # --- per-real-batch state, length real_bsz ---
    this_time = gen.integers(1, max_step_tokens + 1, size=real_bsz, dtype=np.int32)

    # --- small / scalar tensors ---
    running = np.array([True], dtype=bool)
    eos = gen.integers(1, 1000, size=4, dtype=np.int64)

    return {
        "seq_lens_encoder": enc_lens,
        "seq_lens_decoder": dec_lens,
        "has_running_seqs": running,
        "step_input_ids": in_ids,
        "adaptive_step_input_len": adaptive_len,
        "step_output_ids": out_ids,
        "step_output_len": out_lens,
        "stop_flags": stop,
        "seq_lens_this_time": this_time,
        "is_paused": paused,
        "mask_rollback": rollback,
        "token_ids_all": all_ids,
        "prompt_lens": prompt,
        "step_idx": steps,
        "end_tokens": eos,
        "max_dec_len": dec_limits,
        # Scalar configs
        "real_bsz": real_bsz,
        "max_bsz": max_bsz,
        "max_step_tokens": max_step_tokens,
        "max_model_len": max_model_len,
        "is_naive_mode": is_naive_mode,
        "prefill_one_step_stop": prefill_one_step_stop,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Layer 3: Reference implementation (1:1 with CUDA kernel)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Python reference of unified_update_model_status_kernel.

    Mirrors the CUDA/XPU kernel statement-for-statement (one loop
    iteration here == one thread in the <<<1, 1024>>> launch); numbered
    ``line N`` comments refer to unified_update_model_status.cu.
    Returns fresh copies of every in-place output — ``inputs`` is not
    mutated except via the copies made below.
    """
    # Deep-copy all mutable in-place tensors so the caller's inputs
    # remain pristine for diagnostics.
    seq_lens_encoder = inputs["seq_lens_encoder"].copy()
    seq_lens_decoder = inputs["seq_lens_decoder"].copy()
    step_output_len = inputs["step_output_len"].copy()
    stop_flags = inputs["stop_flags"].copy()
    seq_lens_this_time = inputs["seq_lens_this_time"].copy()
    mask_rollback = inputs["mask_rollback"].copy()
    token_ids_all = inputs["token_ids_all"].copy()
    step_idx = inputs["step_idx"].copy()
    step_input_ids = inputs["step_input_ids"].copy()
    step_output_ids = inputs["step_output_ids"].copy()

    # Read-only inputs
    real_bsz = inputs["real_bsz"]
    max_bsz = inputs["max_bsz"]
    max_model_len = inputs["max_model_len"]
    is_naive_mode = inputs["is_naive_mode"]
    prefill_one_step_stop = inputs["prefill_one_step_stop"]
    end_tokens = inputs["end_tokens"]
    num_end_tokens = len(end_tokens)
    max_dec_len = inputs["max_dec_len"]
    prompt_lens = inputs["prompt_lens"]
    is_paused = inputs["is_paused"]

    # Block-level stop count for has_running_seqs reduction (line 175)
    stop_count = 0

    for batch_id in range(max_bsz):
        # --- line 68-75: Read state ---
        cur_seq_len_encoder = int(seq_lens_encoder[batch_id])
        cur_seq_len_decoder = int(seq_lens_decoder[batch_id])
        cur_stop_flag = bool(stop_flags[batch_id])
        output_len = 0
        cur_step_idx = int(step_idx[batch_id])
        cur_is_paused = bool(is_paused[batch_id])

        # line 77: a slot is "running" only if neither stopped nor paused
        is_running = not cur_stop_flag and not cur_is_paused

        # --- line 80-86: Compute output length ---
        # naive mode always emits exactly one token per step
        if is_running:
            output_len = 1 if is_naive_mode else int(step_output_len[batch_id])

        # --- line 89-110: EOS detection ---
        if is_running and output_len > 0:
            hit_stop = False
            for i in range(output_len):
                cur_step_idx += 1  # line 94
                token = int(step_output_ids[batch_id, i])  # line 95
                is_eos = any(token == end_tokens[j] for j in range(num_end_tokens))  # line 96
                max_len_hit = cur_step_idx >= int(max_dec_len[batch_id])  # line 97

                if is_eos or max_len_hit:  # line 99
                    if not is_eos:
                        # max-length stop: overwrite with canonical EOS token
                        step_output_ids[batch_id, i] = end_tokens[0]  # line 100
                    output_len = i + 1  # line 101: truncate at the stop token
                    cur_stop_flag = True  # line 102
                    hit_stop = True  # line 103
                    break  # line 104

            # line 108-110: prefill-only mode stops after the first step
            # of any slot that still had encoder (prefill) tokens
            if not hit_stop and prefill_one_step_stop and cur_seq_len_encoder > 0:
                cur_stop_flag = True

        # --- line 114-166: Update state and write back ---
        if is_running:
            if cur_stop_flag:
                # line 115-119: slot stopped this step
                stop_count += 1
                if output_len == 0:
                    cur_seq_len_decoder = 0  # line 117
                stop_flags[batch_id] = True  # line 118
                mask_rollback[batch_id] = 0  # line 119
            elif cur_seq_len_encoder == 0:
                # line 120-122: decode slot keeps running — rollback is the
                # number of speculated tokens that were NOT accepted
                cur_seq_len_decoder += output_len  # line 121
                mask_rollback[batch_id] = int(seq_lens_this_time[batch_id]) - output_len  # line 122
            else:
                # line 123-124 (encoder > 0, not stopped)
                mask_rollback[batch_id] = 0

            # line 127-130: Fold encoder into decoder (prefill → decode)
            if cur_seq_len_encoder > 0:
                cur_seq_len_decoder += cur_seq_len_encoder  # line 128
                cur_seq_len_encoder = 0  # line 129

            # line 132-135: Write back scalar state
            seq_lens_encoder[batch_id] = cur_seq_len_encoder
            seq_lens_decoder[batch_id] = cur_seq_len_decoder
            step_output_len[batch_id] = output_len
            step_idx[batch_id] = cur_step_idx

            # line 138-145: Write history to token_ids_all, newest first
            # (kernel writes backwards from the updated step_idx)
            if cur_step_idx > 0 and output_len > 0:
                base = int(prompt_lens[batch_id])
                for i in range(output_len):
                    # token_ids_all_now[cur_step_idx - i] = output_ids[output_len - 1 - i]
                    write_idx = base + cur_step_idx - i
                    # NOTE(review): the kernel does NOT bounds-check this
                    # index; the guard here only mirrors safe test setups.
                    if 0 <= write_idx < max_model_len:
                        token_ids_all[batch_id, write_idx] = step_output_ids[batch_id, output_len - 1 - i]

            # line 148-151: last accepted token becomes the next step's input
            if output_len > 0:
                step_input_ids[batch_id, 0] = step_output_ids[batch_id, output_len - 1]

            # line 153-155: naive_mode → seq_lens_this_time (1 token or done)
            if is_naive_mode:
                seq_lens_this_time[batch_id] = 0 if cur_stop_flag else 1

        elif batch_id >= real_bsz:
            # line 156-158: Padding slot — only count, don't modify state
            stop_count += 1
        else:
            # line 159-166: Stopped or paused real slot — force quiescent state
            stop_count += 1
            stop_flags[batch_id] = True  # line 162
            seq_lens_decoder[batch_id] = 0  # line 163
            seq_lens_this_time[batch_id] = 0  # line 164
            step_output_len[batch_id] = 0  # line 165

    # line 177-179: has_running_seqs = stop_sum < max_bsz
    has_running_seqs = np.array([stop_count < max_bsz], dtype=bool)

    return {
        "seq_lens_encoder": seq_lens_encoder,
        "seq_lens_decoder": seq_lens_decoder,
        "has_running_seqs": has_running_seqs,
        "step_input_ids": step_input_ids,
        "step_output_ids": step_output_ids,
        "step_output_len": step_output_len,
        "stop_flags": stop_flags,
        "seq_lens_this_time": seq_lens_this_time,
        "mask_rollback": mask_rollback,
        "token_ids_all": token_ids_all,
        "step_idx": step_idx,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Layer 4a: TEST_CONFIGS
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def _make_config(
    name: str,
    real_bsz: int,
    max_step_tokens: int,
    max_model_len: int,
    seed: int,
    is_naive_mode: bool,
    **extra: Any,
) -> Dict[str, Any]:
    """Build one TEST_CONFIGS entry; extra kwargs (e.g. prefill_one_step_stop)
    pass straight through to gen_inputs."""
    config: Dict[str, Any] = {
        "name": name,
        "real_bsz": real_bsz,
        "max_step_tokens": max_step_tokens,
        "max_model_len": max_model_len,
        "seed": seed,
        "is_naive_mode": is_naive_mode,
    }
    config.update(extra)
    return config


TEST_CONFIGS = [
    # --- basic mode coverage ---
    _make_config("mtp_mode", 8, 16, 256, 42, False),
    _make_config("naive_mode", 8, 16, 256, 42, True),
    # --- batch size ---
    _make_config("small_batch", 1, 8, 128, 42, False),
    _make_config("large_batch", 32, 16, 512, 42, False),
    # --- prefill_one_step_stop ---
    _make_config("prefill_one_step_stop", 8, 8, 128, 42, False, prefill_one_step_stop=True),
    # --- different seeds for randomized coverage ---
    _make_config("seed_100", 8, 16, 256, 100, False),
    _make_config("seed_200_naive", 8, 16, 256, 200, True),
]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Layer 4b: Test suite
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnifiedUpdateModelStatus(unittest.TestCase):
    """End-to-end checks of the XPU kernel against the Python reference.

    Every test follows the same shape: build inputs (randomized or
    hand-crafted), launch the kernel once, then compare all in-place
    outputs element-for-element against reference_impl.
    """

    def setUp(self):
        # Whole suite targets the XPU build of the custom op; skip elsewhere.
        if not paddle.is_compiled_with_xpu():
            self.skipTest("Requires XPU")

    # ------ shared helpers ------

    def _run_and_get(self, inputs: Dict[str, Any]) -> Dict[str, np.ndarray]:
        # Convert to device tensors, launch the kernel (in-place), read back.
        paddle_inputs = to_paddle_inputs(inputs)
        run_kernel(paddle_inputs, inputs)
        return get_outputs(paddle_inputs)

    def _check_all_outputs(self, inputs: Dict[str, Any], outputs: Dict[str, np.ndarray]):
        """Compare ALL output tensors against reference + sanity checks."""
        ref = reference_impl(inputs)
        for key in OUTPUT_KEYS:
            if not np.array_equal(outputs[key], ref[key]):
                # Dump a bounded amount of diagnostic context before failing.
                diff_mask = outputs[key] != ref[key]
                diff_indices = np.argwhere(diff_mask)
                for idx in diff_indices[:10]:  # print first 10 mismatches
                    idx_tuple = tuple(idx)
                    print(
                        f"  [{key}] mismatch at {idx_tuple}: "
                        f"gpu={outputs[key][idx_tuple]} ref={ref[key][idx_tuple]}"
                    )
                    if key == "token_ids_all":
                        # token_ids_all writes depend on several per-slot
                        # scalars; print them all so the log is debuggable.
                        bid = idx_tuple[0]
                        print(
                            f"    batch_id={bid}, prompt_lens={inputs['prompt_lens'][bid]}, "
                            f"step_idx(input)={inputs['step_idx'][bid]}, "
                            f"step_idx(gpu)={outputs['step_idx'][bid]}, "
                            f"step_idx(ref)={ref['step_idx'][bid]}, "
                            f"step_output_len(gpu)={outputs['step_output_len'][bid]}, "
                            f"step_output_len(ref)={ref['step_output_len'][bid]}, "
                            f"stop_flags(input)={inputs['stop_flags'][bid]}, "
                            f"is_paused={inputs['is_paused'][bid]}, "
                            f"seq_lens_encoder={inputs['seq_lens_encoder'][bid]}"
                        )
                np.testing.assert_array_equal(outputs[key], ref[key], err_msg=f"{key} mismatch")

        # Sanity: running slots must have encoder zeroed
        for i in range(inputs["real_bsz"]):
            if not inputs["stop_flags"][i] and not inputs["is_paused"][i]:
                self.assertEqual(outputs["seq_lens_encoder"][i], 0, f"Running slot {i} should have encoder=0")
        self.assertTrue(np.all(outputs["seq_lens_decoder"] >= 0), "negative seq_lens_decoder")
        self.assertTrue(np.all(outputs["step_output_len"] >= 0), "negative step_output_len")
        self.assertTrue(np.all(outputs["step_idx"] >= 0), "negative step_idx")

    def _run_full_test(self, config: Dict[str, Any]) -> Dict[str, np.ndarray]:
        # Generate → run → verify; returns outputs for optional extra checks.
        inputs = gen_inputs(**config)
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)
        return outputs

    # ------ test cases ------

    def test_configs(self):
        """Run all TEST_CONFIGS via subTest."""
        for cfg in TEST_CONFIGS:
            with self.subTest(name=cfg["name"]):
                # "name" is only for subTest reporting; gen_inputs rejects it.
                test_cfg = {k: v for k, v in cfg.items() if k != "name"}
                self._run_full_test(test_cfg)

    def test_eos_detection(self):
        """EOS token at position 2 should truncate output_len to 3."""
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42)
        eos_token = int(inputs["end_tokens"][0])
        inputs["step_output_ids"][0, 2] = eos_token
        # Lists are sized max_bsz = real_bsz + 4 = 6.
        inputs["step_output_len"][:] = [5, 3, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)

    def test_max_dec_len_stop(self):
        """step_idx near max_dec_len should trigger stop and replace with end_tokens[0]."""
        # Use large max_model_len to avoid token_ids_all overflow:
        # kernel doesn't bounds-check prompt_lens + step_idx < max_model_len
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=512, seed=42)
        inputs["step_idx"][:] = [95, 50, 0, 0, 0, 0]
        inputs["max_dec_len"][:] = 100
        inputs["step_output_len"][:] = [10, 5, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)

    def test_paused_slots(self):
        """Paused slots should be treated as stopped/paused (decoder=0, output_len=0)."""
        inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42)
        # max_bsz = 8: pause the first two real slots.
        inputs["is_paused"][:] = [True, True, False, False, False, False, False, False]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)

    def test_all_stopped(self):
        """All slots stopped → has_running_seqs should be False."""
        inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["stop_flags"][:] = True
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)

    def test_encoder_to_decoder(self):
        """Encoder length should fold into decoder: decoder += encoder, encoder → 0."""
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["seq_lens_encoder"][:] = [10, 0, 0, 0, 0, 0]
        inputs["seq_lens_decoder"][:] = [20, 30, 0, 0, 0, 0]
        inputs["step_output_len"][:] = [5, 3, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)

    def test_token_ids_all_writing(self):
        """token_ids_all should be written at prompt_lens + step_idx positions."""
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["step_idx"][:] = [10, 20, 0, 0, 0, 0]
        inputs["prompt_lens"][:] = [5, 5, 0, 0, 0, 0]
        inputs["step_output_len"][:] = [3, 2, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        inputs["seq_lens_encoder"][:] = 0
        # Use end_tokens that won't collide with output_ids
        inputs["end_tokens"][:] = [9990, 9991, 9992, 9993]
        inputs["max_dec_len"][:] = 10000
        inputs["step_output_ids"][0, :3] = [100, 200, 300]
        inputs["step_output_ids"][1, :2] = [400, 500]
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)

    def test_zero_output_len(self):
        """Running slot with output_len=0 in MTP mode: output_len stays 0."""
        inputs = gen_inputs(real_bsz=2, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["step_output_len"][:] = [0, 5, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)

    def test_prefill_one_step_stop_with_encoder(self):
        """prefill_one_step_stop + encoder>0 should stop even without EOS."""
        inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42, prefill_one_step_stop=True)
        inputs["seq_lens_encoder"][:] = [5, 0, 0, 0, 0, 0, 0, 0]
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        # Ensure no accidental EOS hit
        inputs["end_tokens"][:] = [9990, 9991, 9992, 9993]
        inputs["max_dec_len"][:] = 10000
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)

    def test_mask_rollback(self):
        """mask_rollback = seq_lens_this_time - output_len for running decode slots."""
        inputs = gen_inputs(real_bsz=4, max_step_tokens=8, max_model_len=128, seed=42)
        inputs["stop_flags"][: inputs["real_bsz"]] = False
        inputs["is_paused"][:] = False
        inputs["seq_lens_encoder"][:] = 0  # All decode slots
        inputs["seq_lens_this_time"][:] = [6, 4, 8, 3]
        inputs["step_output_len"][:] = [3, 2, 5, 1, 0, 0, 0, 0]
        # Avoid EOS/max_dec_len hits
        inputs["end_tokens"][:] = [9990, 9991, 9992, 9993]
        inputs["max_dec_len"][:] = 10000
        outputs = self._run_and_get(inputs)
        self._check_all_outputs(inputs, outputs)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow direct invocation of this test module.
if __name__ == "__main__":
    unittest.main()
|
||||||
@@ -273,6 +273,8 @@ class XPUForwardMeta(ForwardMeta):
|
|||||||
hidden_states: Optional[paddle.Tensor] = None
|
hidden_states: Optional[paddle.Tensor] = None
|
||||||
|
|
||||||
is_draft: bool = False
|
is_draft: bool = False
|
||||||
|
# max bs
|
||||||
|
max_num_seqs: int = 0
|
||||||
|
|
||||||
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
|
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -196,7 +196,6 @@ class XPUAttentionBackend(AttentionBackend):
|
|||||||
qkv,
|
qkv,
|
||||||
forward_meta.caches[2 * layer.layer_id],
|
forward_meta.caches[2 * layer.layer_id],
|
||||||
forward_meta.caches[2 * layer.layer_id + 1],
|
forward_meta.caches[2 * layer.layer_id + 1],
|
||||||
forward_meta.cum_offsets,
|
|
||||||
metadata.rotary_embs,
|
metadata.rotary_embs,
|
||||||
metadata.block_tables,
|
metadata.block_tables,
|
||||||
forward_meta.prefix_block_tables,
|
forward_meta.prefix_block_tables,
|
||||||
|
|||||||
@@ -1069,8 +1069,8 @@ class SpeculativeSampler(nn.Layer):
|
|||||||
sampling_metadata.min_dec_lens,
|
sampling_metadata.min_dec_lens,
|
||||||
sampling_metadata.eos_token_ids,
|
sampling_metadata.eos_token_ids,
|
||||||
share_inputs["seq_lens_this_time"],
|
share_inputs["seq_lens_this_time"],
|
||||||
share_inputs["output_padding_offset"],
|
share_inputs["batch_id_per_token_output"],
|
||||||
share_inputs["output_cum_offsets"],
|
share_inputs["cu_seqlens_q_output"],
|
||||||
max_model_len,
|
max_model_len,
|
||||||
sampling_metadata.pre_token_ids,
|
sampling_metadata.pre_token_ids,
|
||||||
)
|
)
|
||||||
@@ -1091,7 +1091,7 @@ class SpeculativeSampler(nn.Layer):
|
|||||||
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
||||||
probs,
|
probs,
|
||||||
sampling_metadata.top_p,
|
sampling_metadata.top_p,
|
||||||
share_inputs["output_padding_offset"],
|
share_inputs["batch_id_per_token_output"],
|
||||||
self.speculative_max_candidate_len,
|
self.speculative_max_candidate_len,
|
||||||
max_model_len,
|
max_model_len,
|
||||||
)
|
)
|
||||||
@@ -1113,7 +1113,7 @@ class SpeculativeSampler(nn.Layer):
|
|||||||
share_inputs["max_dec_len"],
|
share_inputs["max_dec_len"],
|
||||||
sampling_metadata.eos_token_ids,
|
sampling_metadata.eos_token_ids,
|
||||||
share_inputs["is_block_step"],
|
share_inputs["is_block_step"],
|
||||||
share_inputs["output_cum_offsets"],
|
share_inputs["cu_seqlens_q_output"],
|
||||||
actual_candidate_len,
|
actual_candidate_len,
|
||||||
share_inputs["actual_draft_token_num"],
|
share_inputs["actual_draft_token_num"],
|
||||||
sampling_metadata.top_p,
|
sampling_metadata.top_p,
|
||||||
@@ -1338,8 +1338,8 @@ class MTPSampler(nn.Layer):
|
|||||||
sampling_metadata.min_dec_lens,
|
sampling_metadata.min_dec_lens,
|
||||||
sampling_metadata.eos_token_ids,
|
sampling_metadata.eos_token_ids,
|
||||||
share_inputs["seq_lens_this_time"],
|
share_inputs["seq_lens_this_time"],
|
||||||
share_inputs["output_padding_offset"],
|
share_inputs["batch_id_per_token_output"],
|
||||||
share_inputs["output_cum_offsets"],
|
share_inputs["cu_seqlens_q_output"],
|
||||||
max_model_len,
|
max_model_len,
|
||||||
sampling_metadata.pre_token_ids,
|
sampling_metadata.pre_token_ids,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -40,9 +40,7 @@ if current_platform.is_xpu():
|
|||||||
save_output_topk,
|
save_output_topk,
|
||||||
set_stop_value_multi_ends,
|
set_stop_value_multi_ends,
|
||||||
speculate_clear_accept_nums,
|
speculate_clear_accept_nums,
|
||||||
speculate_get_output_padding_offset,
|
speculate_pre_process,
|
||||||
speculate_get_padding_offset,
|
|
||||||
speculate_get_seq_lens_output,
|
|
||||||
speculate_save_output,
|
speculate_save_output,
|
||||||
speculate_set_stop_value_multi_seqs,
|
speculate_set_stop_value_multi_seqs,
|
||||||
speculate_set_value_by_flags_and_idx,
|
speculate_set_value_by_flags_and_idx,
|
||||||
@@ -109,51 +107,32 @@ def xpu_pre_process(
|
|||||||
) -> XPUForwardMeta:
|
) -> XPUForwardMeta:
|
||||||
""" """
|
""" """
|
||||||
max_len = input_ids.shape[1]
|
max_len = input_ids.shape[1]
|
||||||
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
|
|
||||||
token_num = paddle.sum(seq_lens_this_time)
|
|
||||||
|
|
||||||
|
token_num_cpu = paddle.sum(seq_lens_this_time).cpu()
|
||||||
if use_speculate_method:
|
if use_speculate_method:
|
||||||
(
|
(
|
||||||
ids_remove_padding,
|
ids_remove_padding,
|
||||||
cum_offsets,
|
|
||||||
batch_id_per_token,
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
cu_seqlens_k,
|
cu_seqlens_k,
|
||||||
) = speculate_get_padding_offset(
|
cu_seqlens_q_output,
|
||||||
input_ids,
|
batch_id_per_token_output,
|
||||||
draft_tokens,
|
real_output_token_num,
|
||||||
cum_offsets_now,
|
) = speculate_pre_process(
|
||||||
token_num,
|
token_num_cpu, input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, seq_lens_decoder
|
||||||
seq_lens_this_time,
|
|
||||||
seq_lens_encoder,
|
|
||||||
)
|
)
|
||||||
seq_lens_output = speculate_get_seq_lens_output(
|
share_inputs["cu_seqlens_q_output"] = cu_seqlens_q_output
|
||||||
seq_lens_this_time,
|
share_inputs["batch_id_per_token_output"] = batch_id_per_token_output
|
||||||
seq_lens_encoder,
|
|
||||||
seq_lens_decoder,
|
|
||||||
)
|
|
||||||
if isinstance(seq_lens_output, list):
|
|
||||||
seq_lens_output = seq_lens_output[0]
|
|
||||||
output_token_num = paddle.sum(seq_lens_output)
|
|
||||||
output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32")
|
|
||||||
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
|
|
||||||
output_cum_offsets_tmp,
|
|
||||||
output_token_num,
|
|
||||||
seq_lens_output,
|
|
||||||
max_len,
|
|
||||||
)
|
|
||||||
share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
|
|
||||||
share_inputs["output_padding_offset"].copy_(output_padding_offset, False)
|
|
||||||
else:
|
else:
|
||||||
|
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
|
||||||
(
|
(
|
||||||
ids_remove_padding,
|
ids_remove_padding,
|
||||||
cum_offsets,
|
cum_offsets,
|
||||||
batch_id_per_token,
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
cu_seqlens_k,
|
cu_seqlens_k,
|
||||||
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
|
) = get_padding_offset(input_ids, cum_offsets_now, token_num_cpu, seq_lens_this_time)
|
||||||
|
|
||||||
share_inputs["cum_offsets"] = cum_offsets
|
|
||||||
share_inputs["batch_id_per_token"] = batch_id_per_token
|
share_inputs["batch_id_per_token"] = batch_id_per_token
|
||||||
share_inputs["cu_seqlens_q"] = cu_seqlens_q
|
share_inputs["cu_seqlens_q"] = cu_seqlens_q
|
||||||
share_inputs["cu_seqlens_k"] = cu_seqlens_k
|
share_inputs["cu_seqlens_k"] = cu_seqlens_k
|
||||||
@@ -165,12 +144,12 @@ def xpu_pre_process(
|
|||||||
seq_lens_encoder=share_inputs["seq_lens_encoder"],
|
seq_lens_encoder=share_inputs["seq_lens_encoder"],
|
||||||
seq_lens_decoder=share_inputs["seq_lens_decoder"],
|
seq_lens_decoder=share_inputs["seq_lens_decoder"],
|
||||||
seq_lens_this_time=share_inputs["seq_lens_this_time"],
|
seq_lens_this_time=share_inputs["seq_lens_this_time"],
|
||||||
cum_offsets=share_inputs["cum_offsets"],
|
|
||||||
batch_id_per_token=share_inputs["batch_id_per_token"],
|
batch_id_per_token=share_inputs["batch_id_per_token"],
|
||||||
cu_seqlens_q=share_inputs["cu_seqlens_q"],
|
cu_seqlens_q=share_inputs["cu_seqlens_q"],
|
||||||
cu_seqlens_k=share_inputs["cu_seqlens_k"],
|
cu_seqlens_k=share_inputs["cu_seqlens_k"],
|
||||||
block_tables=share_inputs["block_tables"],
|
block_tables=share_inputs["block_tables"],
|
||||||
caches=share_inputs["caches"],
|
caches=share_inputs["caches"],
|
||||||
|
max_num_seqs=share_inputs["seq_lens_this_time"].shape[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -205,7 +184,6 @@ def xpu_pre_process(
|
|||||||
|
|
||||||
adjusted_input = adjust_batch(
|
adjusted_input = adjust_batch(
|
||||||
ids_remove_padding.reshape([-1, 1]),
|
ids_remove_padding.reshape([-1, 1]),
|
||||||
cum_offsets,
|
|
||||||
xpu_forward_meta.encoder_seq_lod,
|
xpu_forward_meta.encoder_seq_lod,
|
||||||
xpu_forward_meta.decoder_seq_lod,
|
xpu_forward_meta.decoder_seq_lod,
|
||||||
xpu_forward_meta.encoder_batch_idx,
|
xpu_forward_meta.encoder_batch_idx,
|
||||||
@@ -237,7 +215,6 @@ def xpu_pre_process(
|
|||||||
|
|
||||||
def xpu_process_output(
|
def xpu_process_output(
|
||||||
forward_output,
|
forward_output,
|
||||||
cum_offsets: paddle.Tensor,
|
|
||||||
xpu_forward_meta: XPUForwardMeta,
|
xpu_forward_meta: XPUForwardMeta,
|
||||||
share_inputs,
|
share_inputs,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
@@ -250,7 +227,6 @@ def xpu_process_output(
|
|||||||
|
|
||||||
hiddden_states = gather_next_token(
|
hiddden_states = gather_next_token(
|
||||||
forward_output,
|
forward_output,
|
||||||
cum_offsets,
|
|
||||||
xpu_forward_meta.encoder_seq_lod,
|
xpu_forward_meta.encoder_seq_lod,
|
||||||
xpu_forward_meta.decoder_seq_lod,
|
xpu_forward_meta.decoder_seq_lod,
|
||||||
xpu_forward_meta.encoder_batch_map,
|
xpu_forward_meta.encoder_batch_map,
|
||||||
@@ -261,7 +237,7 @@ def xpu_process_output(
|
|||||||
xpu_forward_meta.decoder_batch_map_cpu,
|
xpu_forward_meta.decoder_batch_map_cpu,
|
||||||
xpu_forward_meta.len_info_cpu,
|
xpu_forward_meta.len_info_cpu,
|
||||||
output_padding_offset, # output_padding_offset
|
output_padding_offset, # output_padding_offset
|
||||||
-1, # max_input_length
|
xpu_forward_meta.max_num_seqs,
|
||||||
)
|
)
|
||||||
return hiddden_states
|
return hiddden_states
|
||||||
|
|
||||||
|
|||||||
@@ -820,11 +820,7 @@ class MTPProposer(Proposer):
|
|||||||
# Note(ZKK):
|
# Note(ZKK):
|
||||||
# I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend
|
# I strongly advise xpu student delete the fuck `output_cum_offsets` name in XPU backend
|
||||||
# like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358
|
# like my pr https://github.com/PaddlePaddle/FastDeploy/pull/6358
|
||||||
(
|
self.model_inputs["cu_seqlens_q_output"],
|
||||||
self.model_inputs["cu_seqlens_q_output"]
|
|
||||||
if current_platform.is_cuda()
|
|
||||||
else self.model_inputs["output_cum_offsets"]
|
|
||||||
),
|
|
||||||
self.model_inputs["stop_flags"],
|
self.model_inputs["stop_flags"],
|
||||||
(
|
(
|
||||||
self.model_inputs["not_need_stop_device"]
|
self.model_inputs["not_need_stop_device"]
|
||||||
@@ -1125,9 +1121,7 @@ class MTPProposer(Proposer):
|
|||||||
previous_hidden_states=self.model_inputs["target_hidden_states"],
|
previous_hidden_states=self.model_inputs["target_hidden_states"],
|
||||||
forward_meta=self.forward_meta,
|
forward_meta=self.forward_meta,
|
||||||
)
|
)
|
||||||
hidden_states = xpu_process_output(
|
hidden_states = xpu_process_output(model_output, self.forward_meta, self.model_inputs)
|
||||||
model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs
|
|
||||||
)
|
|
||||||
# 4. Compute logits, Sample
|
# 4. Compute logits, Sample
|
||||||
logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
|
logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
|
||||||
sampled_token_ids, sampler_output = self.sampler(
|
sampled_token_ids, sampler_output = self.sampler(
|
||||||
|
|||||||
@@ -1114,6 +1114,8 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
self.cache_config.block_size,
|
self.cache_config.block_size,
|
||||||
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
|
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO(chenhuan): support cached_token_num
|
||||||
self.forward_meta = xpu_pre_process(
|
self.forward_meta = xpu_pre_process(
|
||||||
self.share_inputs["input_ids"],
|
self.share_inputs["input_ids"],
|
||||||
self.share_inputs["seq_lens_this_time"],
|
self.share_inputs["seq_lens_this_time"],
|
||||||
@@ -1577,9 +1579,7 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
)
|
)
|
||||||
if self.use_cudagraph:
|
if self.use_cudagraph:
|
||||||
model_output = model_output[: self.real_token_num]
|
model_output = model_output[: self.real_token_num]
|
||||||
hidden_states = xpu_process_output(
|
hidden_states = xpu_process_output(model_output, self.forward_meta, self.share_inputs)
|
||||||
model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
|
|
||||||
)
|
|
||||||
# 4. Compute logits, Sample
|
# 4. Compute logits, Sample
|
||||||
logits = self.model.compute_logits(hidden_states)
|
logits = self.model.compute_logits(hidden_states)
|
||||||
sampler_output = None
|
sampler_output = None
|
||||||
|
|||||||
@@ -346,6 +346,8 @@ def test_mtp_sampler_xpu_and_compute(mock_ops, monkeypatch):
|
|||||||
"batch_token_num": paddle.to_tensor([[1]], dtype="int64"),
|
"batch_token_num": paddle.to_tensor([[1]], dtype="int64"),
|
||||||
"output_padding_offset": paddle.zeros([1, 1], dtype="int64"),
|
"output_padding_offset": paddle.zeros([1, 1], dtype="int64"),
|
||||||
"output_cum_offsets": paddle.zeros([1, 1], dtype="int64"),
|
"output_cum_offsets": paddle.zeros([1, 1], dtype="int64"),
|
||||||
|
"batch_id_per_token_output": paddle.to_tensor([0], dtype="int32"),
|
||||||
|
"cu_seqlens_q_output": paddle.to_tensor([0, 1], dtype="int32"),
|
||||||
}
|
}
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"fastdeploy.model_executor.layers.sample.sampler.top_k_top_p_sampling",
|
"fastdeploy.model_executor.layers.sample.sampler.top_k_top_p_sampling",
|
||||||
|
|||||||
Reference in New Issue
Block a user