[XPU] add more type for recover batch sequence (#6142)

This commit is contained in:
lizan1999
2026-01-23 15:16:05 +08:00
committed by GitHub
parent bef6293552
commit b3a48529ab
7 changed files with 695 additions and 4 deletions
@@ -289,6 +289,17 @@ DLL_EXPORT int eb_mtp_gather_next_token(
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TY>
DLL_EXPORT int eb_recover_batch_sequence(
Context* ctx,
const TX* x,
TY* y,
VectorParam<int32_t>& encoder_seqs_lods, // NOLINT
VectorParam<int32_t>& decoder_seqs_lods, // NOLINT
VectorParam<int32_t>& encoder_batch_map, // NOLINT
VectorParam<int32_t>& decoder_batch_map, // NOLINT
int64_t hidden_dim);
template <typename TX, typename TSCALE = float, typename TY = int8_t>
DLL_EXPORT int quant2d_per_channel(api::Context* ctx,
const TX* x,