Support FlashMLA for DeepSeek-V3 (dsv3) (#6593)

This commit is contained in:
周周周
2026-03-03 11:09:43 +08:00
committed by GitHub
parent 0f718baaf2
commit 3cc09418f1
5 changed files with 266 additions and 52 deletions
@@ -16,6 +16,9 @@
from __future__ import annotations
import paddle
paddle.enable_compat(scope={"flash_mla"}) # Enable torch proxy before importing flash_mla
import math
import os
from dataclasses import dataclass, field
@@ -47,6 +50,9 @@ if current_platform.is_cuda():
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
import triton
import triton.language as tl
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
@@ -56,6 +62,139 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
@triton.jit()
def extract_kernel(
    q,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    output,
    cache_seqlens,
    HIDDEN_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy each decode-phase batch's token row out of the packed q buffer.

    Launched with one program per batch slot. For batch ``b`` with
    ``seq_lens_decoder[b] > 0``, the ``HIDDEN_DIM``-wide row starting at
    ``q + cu_seqlens_q[b] * HIDDEN_DIM`` is copied to
    ``output + b * HIDDEN_DIM``, and ``cache_seqlens[b]`` is set to
    ``seq_lens_decoder[b] + 1``. Batches with ``seq_lens_decoder[b] <= 0``
    are left untouched.

    ``BLOCK_SIZE`` must be >= ``HIDDEN_DIM`` (loads/stores are masked to
    ``HIDDEN_DIM``); the host wrapper passes ``next_power_of_2(hidden_dim)``.

    NOTE(review): ``seq_lens_encoder`` is accepted but never read here —
    presumably kept so this kernel and ``insert_kernel`` share a launch
    signature; confirm before removing.
    """
    batch_id = tl.program_id(axis=0)
    cache_kv_len = tl.load(seq_lens_decoder + batch_id)
    # This batch is not in decode phase, so there is nothing to move.
    if cache_kv_len <= 0:
        return
    cu_len_this_batch = tl.load(cu_seqlens_q + batch_id)
    read_offsets = tl.arange(0, BLOCK_SIZE)
    q += cu_len_this_batch * HIDDEN_DIM
    row_data = tl.load(q + read_offsets, mask=read_offsets < HIDDEN_DIM)
    output += batch_id * HIDDEN_DIM
    tl.store(output + read_offsets, row_data, mask=read_offsets < HIDDEN_DIM)
    # +1 to count the token being decoded in this step; this tensor is later
    # handed to flash_mla_with_kvcache as its cache_seqlens argument.
    tl.store(cache_seqlens + batch_id, cache_kv_len + 1)
def extract_decoder_token_from_q(
    q: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
):
    """Gather one decode-phase token row per batch from the packed query tensor.

    Returns a tuple ``(out, cache_seqlens)``:
    ``out`` is ``[max_bsz, hidden_dim]`` holding the extracted rows (rows for
    non-decode batches are uninitialized), and ``cache_seqlens`` holds
    ``seq_lens_decoder + 1`` for decode batches and 0 elsewhere.
    """
    assert len(q.shape) == 2
    assert len(cu_seqlens_q.shape) == 1
    assert len(seq_lens_encoder.shape) == 1
    assert len(seq_lens_decoder.shape) == 1

    batch_capacity = seq_lens_decoder.shape[0]
    width = q.shape[-1]
    gathered = paddle.empty([batch_capacity, width], dtype=q.dtype)
    cache_seqlens = paddle.zeros_like(seq_lens_decoder)

    # One program per batch slot; mask width rounded up to a power of two.
    extract_kernel[(batch_capacity,)](
        q,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        gathered,
        cache_seqlens,
        width,
        triton.next_power_of_2(width),
    )
    return gathered, cache_seqlens
@triton.jit()
def insert_kernel(
    decoder_res,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    output,
    HIDDEN_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter per-batch decode results back into the packed token layout.

    Inverse of ``extract_kernel``: one program per batch slot. For batch ``b``
    with ``seq_lens_decoder[b] > 0``, the ``HIDDEN_DIM``-wide row at
    ``decoder_res + b * HIDDEN_DIM`` is written to
    ``output + cu_seqlens_q[b] * HIDDEN_DIM``. Batches with
    ``seq_lens_decoder[b] <= 0`` are skipped, leaving their output rows as-is.

    ``BLOCK_SIZE`` must be >= ``HIDDEN_DIM`` (accesses are masked).

    NOTE(review): ``seq_lens_encoder`` is accepted but never read here —
    presumably kept for signature symmetry with ``extract_kernel``; confirm
    before removing.
    """
    batch_id = tl.program_id(axis=0)
    cache_kv_len = tl.load(seq_lens_decoder + batch_id)
    # This batch is not in decode phase, so there is nothing to move.
    if cache_kv_len <= 0:
        return
    cu_len_this_batch = tl.load(cu_seqlens_q + batch_id)
    read_offsets = tl.arange(0, BLOCK_SIZE)
    decoder_res += batch_id * HIDDEN_DIM
    row_data = tl.load(decoder_res + read_offsets, mask=read_offsets < HIDDEN_DIM)
    output += cu_len_this_batch * HIDDEN_DIM
    tl.store(output + read_offsets, row_data, mask=read_offsets < HIDDEN_DIM)
def insert_decoder_result_back(
    decoder_result: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
    mixed_token_num,
):
    """Write per-batch decode outputs back into a packed ``[token_num, hidden]`` buffer.

    The last two dims of the rank-4 ``decoder_result`` are flattened into one
    hidden dimension; positions belonging to non-decode batches remain zero.
    """
    assert len(decoder_result.shape) == 4
    assert len(cu_seqlens_q.shape) == 1
    assert len(seq_lens_encoder.shape) == 1

    num_programs = seq_lens_encoder.shape[0]
    flat_width = decoder_result.shape[-2] * decoder_result.shape[-1]
    packed_out = paddle.zeros([mixed_token_num, flat_width], dtype=decoder_result.dtype)

    # One program per batch slot; mask width rounded up to a power of two.
    insert_kernel[(num_programs,)](
        decoder_result,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        packed_out,
        flat_width,
        triton.next_power_of_2(flat_width),
    )
    return packed_out
def yarn_get_mscale(scale=1, mscale=1):
""" """
if scale <= 1:
@@ -455,47 +594,90 @@ class MLAAttentionBackend(AttentionBackend):
speculate_decoder,
)
# Multi-head latent attention computation
fmha_out = multi_head_latent_attention(
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
metadata.max_dec_len_this_time,
metadata.max_kv_len_this_time,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # out_shifts
None, # out_smooths
metadata._fuse_kernel_compute_dtype,
"none", # cache_quant_type
self.kv_lora_rank,
self.max_seq_len,
self.attn_softmax_scale,
0.0, # quant_max_bound
0.0, # quant_min_bound
0.0, # out_linear_in_scale
speculate_max_tokens,
True, # causal
speculate_decoder,
)
if int(os.getenv("USE_FLASH_MLA", "0")) == 0:
assert self.num_heads <= 64, "paddle mla attention support failed"
# Multi-head latent attention computation
fmha_out = multi_head_latent_attention(
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
metadata.max_dec_len_this_time,
metadata.max_kv_len_this_time,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # out_shifts
None, # out_smooths
metadata._fuse_kernel_compute_dtype,
"none", # cache_quant_type
self.kv_lora_rank,
self.max_seq_len,
self.attn_softmax_scale,
0.0, # quant_max_bound
0.0, # quant_min_bound
0.0, # out_linear_in_scale
speculate_max_tokens,
True, # causal
speculate_decoder,
)
return fmha_out
return fmha_out
else:
import flash_mla
decoder_q, cache_seqlens = extract_decoder_token_from_q(
q,
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
)
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
token_num = q.shape[0]
decoder_q.reshape_([-1, 1, self.num_heads, 576])
new_cache_shape = latent_cache.shape
assert new_cache_shape[1] == 1
new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]
decoder_res, _ = flash_mla.flash_mla_with_kvcache(
decoder_q,
# The external open-source repo stores the kv cache in a different layout than FD.
# Fortunately the cached head count here is 1, so a plain view suffices; otherwise a lot of code up and down the stack would need changing!
latent_cache.view(new_cache_shape),
metadata.block_tables,
cache_seqlens,
512, # t.dv,
tile_scheduler_metadata,
num_splits,
softmax_scale=self.attn_softmax_scale,
causal=True,
)
final_res = insert_decoder_result_back(
decoder_res,
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
token_num,
)
return final_res