mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[XPU] add more type for recover batch sequence (#6142)
This commit is contained in:
@@ -289,6 +289,17 @@ DLL_EXPORT int eb_mtp_gather_next_token(
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TY>
|
||||
DLL_EXPORT int eb_recover_batch_sequence(
|
||||
Context* ctx,
|
||||
const TX* x,
|
||||
TY* y,
|
||||
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
|
||||
VectorParam<int32_t>& encoder_batch_map, // NOLINT
|
||||
VectorParam<int32_t>& decoder_batch_map, // NOLINT
|
||||
int64_t hidden_dim);
|
||||
|
||||
template <typename TX, typename TSCALE = float, typename TY = int8_t>
|
||||
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
|
||||
const TX* x,
|
||||
|
||||
Reference in New Issue
Block a user