fix eb5 mtp(mix) (#6800)

This commit is contained in:
cmcamdy
2026-03-13 17:36:57 +08:00
committed by GitHub
parent 8c1a2827d3
commit 7591e0d6bc
5 changed files with 55 additions and 5 deletions
@@ -277,6 +277,18 @@ std::vector<paddle::Tensor> TopPCandidates(
int candidates_len,
int max_seq_len);
// Forward declaration of the UpdateAttnMaskOffsets op so it can be registered
// in the Python module; the definition is not part of this chunk.
// All ten inputs are passed as const tensor references and the op returns a
// vector of tensors.
// NOTE(review): judging by the name, this recomputes attention-mask offsets
// from the sequence-length / decode-state inputs -- confirm against the
// implementation. Tensor shapes, dtypes, and whether any input is updated in
// place are not visible from this declaration.
std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
const paddle::Tensor& ids_remove_padding,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& attn_mask_offsets_full,
const paddle::Tensor& attn_mask_offsets_decoder,
const paddle::Tensor& is_block_step,
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback);
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
@@ -1140,6 +1152,20 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("mask_rollback"),
"Update speculative decoding states");
// Register UpdateAttnMaskOffsets with the Python module under the name
// `update_attn_mask_offsets`. The py::arg entries give each parameter a
// Python keyword name; they are listed in the same order as the parameters
// of the C++ declaration, so positional and keyword calls agree.
// The trailing string literal is the docstring attached to the binding.
m.def("update_attn_mask_offsets",
&UpdateAttnMaskOffsets,
py::arg("ids_remove_padding"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_decoder"),
py::arg("cu_seqlens_q"),
py::arg("attn_mask_offsets_full"),
py::arg("attn_mask_offsets_decoder"),
py::arg("is_block_step"),
py::arg("decode_states"),
py::arg("mask_rollback"),
"Update attn mask offset");
m.def("speculate_verify",
&SpeculateVerify,
py::arg("sampled_token_ids"),