[Cherry-Pick][Others] enhance deep_ep import and support mixed mode flash_mask_attn #6238 (#6232)

* flash_mask_attn: support mixed mode

* enhance deep_ep import and fix a bug

* update

* fix
Author: Yuanle Liu
Date: 2026-01-27 20:02:34 +08:00
Committed by: GitHub
Parent: 1d519b9a13
Commit: fb7ec62341
5 changed files with 197 additions and 110 deletions
@@ -49,6 +49,21 @@ void cuda_host_free(uintptr_t ptr) {
   check_cuda_error(cudaFreeHost(reinterpret_cast<void*>(ptr)));
 }
 
+void FlashAttentionMask(const paddle::Tensor& q_input,
+                        const paddle::Tensor& k_input,
+                        const paddle::Tensor& v_input,
+                        const paddle::Tensor& cu_seq_q,
+                        const paddle::Tensor& cu_seq_k,
+                        const paddle::Tensor& seq_len_encoder,
+                        const paddle::Tensor& attn_out,
+                        const paddle::optional<paddle::Tensor>& mask,
+                        const int head_num,
+                        const int kv_head_num,
+                        const int head_dim,
+                        const int max_seq_len,
+                        const int q_token_num,
+                        const int k_token_num);
+
 std::vector<paddle::Tensor> AppendAttention(
     const paddle::Tensor& qkv,
     const paddle::Tensor& key_cache,
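The new declaration takes packed variable-length inputs, so cu_seq_q and
cu_seq_k are presumably cumulative sequence-length offsets into the flattened
q/k token dimensions. A minimal sketch of how such offsets are commonly built,
assuming an int32 dtype and a leading zero entry (neither is confirmed by this
diff):

# Hypothetical illustration of the cu_seq_q bookkeeping tensor.
# Dtype and layout are assumptions, not taken from this commit.
import numpy as np
import paddle

seq_lens = [7, 3, 5]  # per-request query lengths in a mixed batch
cu_seq_q = paddle.to_tensor(np.cumsum([0] + seq_lens).astype("int32"))
# cu_seq_q -> [0, 7, 10, 15]: request i occupies rows
# cu_seq_q[i]:cu_seq_q[i+1] of the packed q_input, and
# q_token_num == int(cu_seq_q[-1]) == 15.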
@@ -1158,6 +1173,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("append_attention_with_output",
         &AppendAttentionWithOutput,
         "append attention with output function");
+  m.def("flash_mask_attention", &FlashAttentionMask, "flash_mask_attention");
   /**
    * gqa_rope_write_cache.cu
    * gqa_rope_write_cache
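With the binding registered, the kernel should be reachable from Python as
fastdeploy_ops.flash_mask_attention (the module name comes from the
PYBIND11_MODULE line above; the installed import path may differ). A hedged
call sketch, assuming packed [token_num, num_heads * head_dim] float16 inputs,
attn_out used as an in-place output buffer, and None accepted for the optional
mask; none of these conventions are confirmed by the diff:

# Sketch only: the argument order follows the FlashAttentionMask declaration,
# but shapes, dtypes, and mask semantics are assumptions.
import numpy as np
import paddle
import fastdeploy_ops

head_num = kv_head_num = 8
head_dim, max_seq_len = 128, 2048

seq_lens = [7, 3, 5]
cu = np.cumsum([0] + seq_lens).astype("int32")
cu_seq_q = paddle.to_tensor(cu)
cu_seq_k = paddle.to_tensor(cu)  # self-attention: same offsets for k
seq_len_encoder = paddle.to_tensor(np.asarray(seq_lens, dtype="int32"))
q_token_num = k_token_num = int(cu[-1])

q = paddle.randn([q_token_num, head_num * head_dim], dtype="float16")
k = paddle.randn([k_token_num, kv_head_num * head_dim], dtype="float16")
v = paddle.randn([k_token_num, kv_head_num * head_dim], dtype="float16")
attn_out = paddle.empty_like(q)  # written in place by the op (assumed)

fastdeploy_ops.flash_mask_attention(
    q, k, v, cu_seq_q, cu_seq_k, seq_len_encoder, attn_out,
    None,  # optional mask; None assumed to select the unmasked path
    head_num, kv_head_num, head_dim, max_seq_len,
    q_token_num, k_token_num)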