mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
fix eb5 mtp(mix) (#6800)
This commit is contained in:
@@ -277,6 +277,18 @@ std::vector<paddle::Tensor> TopPCandidates(
|
||||
int candidates_len,
|
||||
int max_seq_len);
|
||||
|
||||
std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
|
||||
const paddle::Tensor& ids_remove_padding,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& attn_mask_offsets_full,
|
||||
const paddle::Tensor& attn_mask_offsets_decoder,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& decode_states,
|
||||
const paddle::Tensor& mask_rollback);
|
||||
|
||||
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
@@ -1140,6 +1152,20 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("mask_rollback"),
|
||||
"Update speculative decoding states");
|
||||
|
||||
m.def("update_attn_mask_offsets",
|
||||
&UpdateAttnMaskOffsets,
|
||||
py::arg("ids_remove_padding"),
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("seq_lens_encoder"),
|
||||
py::arg("seq_lens_decoder"),
|
||||
py::arg("cu_seqlens_q"),
|
||||
py::arg("attn_mask_offsets_full"),
|
||||
py::arg("attn_mask_offsets_decoder"),
|
||||
py::arg("is_block_step"),
|
||||
py::arg("decode_states"),
|
||||
py::arg("mask_rollback"),
|
||||
"Update attn mask offset");
|
||||
|
||||
m.def("speculate_verify",
|
||||
&SpeculateVerify,
|
||||
py::arg("sampled_token_ids"),
|
||||
|
||||
Reference in New Issue
Block a user