mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
support dsv3 use flashmla (#6593)
This commit is contained in:
@@ -16,6 +16,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import paddle
|
||||
|
||||
paddle.enable_compat(scope={"flash_mla"}) # Enable torch proxy before importing flash_mla
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
@@ -47,6 +50,9 @@ if current_platform.is_cuda():
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
@@ -56,6 +62,139 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||
|
||||
|
||||
# Triton kernel: for every batch that is currently in the decode phase, copy
# that batch's single query token out of the packed (token-major) `q` buffer
# into a batch-major `output` buffer, and publish the batch's KV-cache length
# (+1 for the token now being generated) into `cache_seqlens`.
#
# One program instance handles one batch: launched with grid = (max_bsz,).
@triton.jit()
def extract_kernel(
    q,                 # packed queries; rows of HIDDEN_DIM elements, indexed by cu_seqlens_q
    cu_seqlens_q,      # cumulative query lengths: start offset of each batch inside `q`
    seq_lens_encoder,  # not read by this kernel; kept so the launch signature matches the caller
    seq_lens_decoder,  # per-batch cached KV length; <= 0 means "batch is not decoding"
    output,            # batch-major destination: row `batch_id` gets the batch's query token
    cache_seqlens,     # per-batch KV length written back for the attention call (cached + 1)
    HIDDEN_DIM: tl.constexpr,  # number of valid elements per row
    BLOCK_SIZE: tl.constexpr,  # power of two >= HIDDEN_DIM; the mask below trims the tail
):

    batch_id = tl.program_id(axis=0)
    cache_kv_len = tl.load(seq_lens_decoder + batch_id)

    # This batch is not in the decode phase, so there is nothing to move.
    if cache_kv_len <= 0:
        return

    cu_len_this_batch = tl.load(cu_seqlens_q + batch_id)

    read_offsets = tl.arange(0, BLOCK_SIZE)
    # Jump to this batch's (single) query row inside the packed buffer.
    q += cu_len_this_batch * HIDDEN_DIM

    row_data = tl.load(q + read_offsets, mask=read_offsets < HIDDEN_DIM)

    # Batch-major destination: one fixed-size row per batch slot.
    output += batch_id * HIDDEN_DIM

    tl.store(output + read_offsets, row_data, mask=read_offsets < HIDDEN_DIM)

    # The token being generated extends the cache by one entry.
    tl.store(cache_seqlens + batch_id, cache_kv_len + 1)
|
||||
|
||||
|
||||
def extract_decoder_token_from_q(
    q: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
):
    """Gather each decode-phase batch's query token out of the packed `q`.

    Returns ``(out, cache_seqlens)``: ``out`` is a ``[max_bsz, hidden_dim]``
    batch-major buffer holding one query row per batch slot, and
    ``cache_seqlens`` holds, for every decoding batch, its cached KV length
    plus one. Slots of non-decoding batches are skipped by the kernel, so
    their ``out`` rows are uninitialized and their ``cache_seqlens`` stay 0.
    """
    assert len(q.shape) == 2
    for lens in (cu_seqlens_q, seq_lens_encoder, seq_lens_decoder):
        assert len(lens.shape) == 1

    batch_capacity = seq_lens_decoder.shape[0]
    row_width = q.shape[-1]

    # One row per batch slot; seqlens start at zero so skipped batches
    # report an empty cache.
    gathered = paddle.empty([batch_capacity, row_width], dtype=q.dtype)
    cache_seqlens = paddle.zeros_like(seq_lens_decoder)

    # One program per batch slot; BLOCK_SIZE is the next power of two so the
    # kernel's mask covers the whole row in a single tile.
    extract_kernel[(batch_capacity,)](
        q,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        gathered,
        cache_seqlens,
        row_width,
        triton.next_power_of_2(row_width),
    )

    return gathered, cache_seqlens
|
||||
|
||||
|
||||
# Triton kernel: inverse of `extract_kernel` — scatter each decode-phase
# batch's result row from the batch-major `decoder_res` buffer back to that
# batch's token position in the packed (token-major) `output` buffer.
#
# One program instance handles one batch: launched with grid = (max_bsz,).
@triton.jit()
def insert_kernel(
    decoder_res,       # batch-major results, read as HIDDEN_DIM flat elements per batch row
    cu_seqlens_q,      # cumulative query lengths: start offset of each batch inside `output`
    seq_lens_encoder,  # not read by this kernel; kept so the launch signature matches the caller
    seq_lens_decoder,  # per-batch cached KV length; <= 0 means "batch is not decoding"
    output,            # packed token-major destination buffer
    HIDDEN_DIM: tl.constexpr,  # number of valid elements per row
    BLOCK_SIZE: tl.constexpr,  # power of two >= HIDDEN_DIM; the mask below trims the tail
):

    batch_id = tl.program_id(axis=0)
    cache_kv_len = tl.load(seq_lens_decoder + batch_id)

    # This batch is not in the decode phase, so there is nothing to move.
    if cache_kv_len <= 0:
        return

    cu_len_this_batch = tl.load(cu_seqlens_q + batch_id)

    read_offsets = tl.arange(0, BLOCK_SIZE)

    # Source: fixed-size row per batch slot.
    decoder_res += batch_id * HIDDEN_DIM

    row_data = tl.load(decoder_res + read_offsets, mask=read_offsets < HIDDEN_DIM)

    # Destination: this batch's token position in the packed buffer.
    output += cu_len_this_batch * HIDDEN_DIM

    tl.store(output + read_offsets, row_data, mask=read_offsets < HIDDEN_DIM)
|
||||
|
||||
|
||||
def insert_decoder_result_back(
    decoder_result: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
    mixed_token_num,
):
    """Scatter per-batch decoder attention results back into a packed buffer.

    `decoder_result` is a 4-D tensor whose last two axes are flattened into
    one row of ``hidden_dim`` elements per batch; each decoding batch's row is
    written to its token position (from `cu_seqlens_q`) inside a
    ``[mixed_token_num, hidden_dim]`` output. Rows belonging to non-decoding
    tokens are left as zeros.
    """
    assert len(decoder_result.shape) == 4
    assert len(cu_seqlens_q.shape) == 1
    assert len(seq_lens_encoder.shape) == 1
    # Fix: the kernel reads seq_lens_decoder per batch, so validate its rank
    # like the other 1-D length tensors (extract_decoder_token_from_q does).
    assert len(seq_lens_decoder.shape) == 1

    max_bsz = seq_lens_encoder.shape[0]

    # Flatten the trailing [heads, head_dim] (presumably — shapes come from
    # the flash_mla result) into one row per batch.
    hidden_dim = decoder_result.shape[-2] * decoder_result.shape[-1]
    # zeros, not empty: positions the kernel skips must read back as 0.
    out = paddle.zeros([mixed_token_num, hidden_dim], dtype=decoder_result.dtype)

    # Power-of-two tile so the kernel covers a row with one masked access.
    BLOCK_SIZE = triton.next_power_of_2(hidden_dim)

    grid = (max_bsz,)

    insert_kernel[grid](
        decoder_result,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        out,
        hidden_dim,
        BLOCK_SIZE,
    )

    return out
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
""" """
|
||||
if scale <= 1:
|
||||
@@ -455,47 +594,90 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
# 多头潜在注意力计算
|
||||
fmha_out = multi_head_latent_attention(
|
||||
q,
|
||||
latent_cache,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.batch_id_per_token,
|
||||
metadata.block_tables,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_device,
|
||||
forward_meta.decoder_chunk_size_device,
|
||||
metadata.max_dec_len_this_time,
|
||||
metadata.max_kv_len_this_time,
|
||||
None, # attn_mask
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
None, # cache_k_quant_scales
|
||||
None, # cache_v_quant_scales
|
||||
None, # cache_k_dequant_scales
|
||||
None, # cache_v_dequant_scales
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None, # out_shifts
|
||||
None, # out_smooths
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
"none", # cache_quant_type
|
||||
self.kv_lora_rank,
|
||||
self.max_seq_len,
|
||||
self.attn_softmax_scale,
|
||||
0.0, # quant_max_bound
|
||||
0.0, # quant_min_bound
|
||||
0.0, # out_linear_in_scale
|
||||
speculate_max_tokens,
|
||||
True, # causal
|
||||
speculate_decoder,
|
||||
)
|
||||
if int(os.getenv("USE_FLASH_MLA", "0")) == 0:
|
||||
assert self.num_heads <= 64, "paddle mla attention support failed"
|
||||
# 多头潜在注意力计算
|
||||
fmha_out = multi_head_latent_attention(
|
||||
q,
|
||||
latent_cache,
|
||||
latent_cache,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.batch_id_per_token,
|
||||
metadata.block_tables,
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
forward_meta.decoder_num_blocks_device,
|
||||
forward_meta.decoder_chunk_size_device,
|
||||
metadata.max_dec_len_this_time,
|
||||
metadata.max_kv_len_this_time,
|
||||
None, # attn_mask
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
None, # cache_k_quant_scales
|
||||
None, # cache_v_quant_scales
|
||||
None, # cache_k_dequant_scales
|
||||
None, # cache_v_dequant_scales
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None, # out_shifts
|
||||
None, # out_smooths
|
||||
metadata._fuse_kernel_compute_dtype,
|
||||
"none", # cache_quant_type
|
||||
self.kv_lora_rank,
|
||||
self.max_seq_len,
|
||||
self.attn_softmax_scale,
|
||||
0.0, # quant_max_bound
|
||||
0.0, # quant_min_bound
|
||||
0.0, # out_linear_in_scale
|
||||
speculate_max_tokens,
|
||||
True, # causal
|
||||
speculate_decoder,
|
||||
)
|
||||
|
||||
return fmha_out
|
||||
return fmha_out
|
||||
else:
|
||||
import flash_mla
|
||||
|
||||
decoder_q, cache_seqlens = extract_decoder_token_from_q(
|
||||
q,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
)
|
||||
|
||||
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
|
||||
token_num = q.shape[0]
|
||||
decoder_q.reshape_([-1, 1, self.num_heads, 576])
|
||||
|
||||
new_cache_shape = latent_cache.shape
|
||||
assert new_cache_shape[1] == 1
|
||||
new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]
|
||||
|
||||
decoder_res, _ = flash_mla.flash_mla_with_kvcache(
|
||||
decoder_q,
|
||||
# 外面的开源仓库的kv cache存储格式和FD的不同
|
||||
# 幸好这里缓存的头是1,直接view即可,否则上上下下要改很多!
|
||||
latent_cache.view(new_cache_shape),
|
||||
metadata.block_tables,
|
||||
cache_seqlens,
|
||||
512, # t.dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
softmax_scale=self.attn_softmax_scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
final_res = insert_decoder_result_back(
|
||||
decoder_res,
|
||||
forward_meta.cu_seqlens_q,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
token_num,
|
||||
)
|
||||
|
||||
return final_res
|
||||
|
||||
Reference in New Issue
Block a user