diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 1240ff86a2..f94e8493f7 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -313,7 +313,6 @@ void GetBlockShapeAndSplitKVBlock( // decoder if (max_dec_len_this_time > 0) { if (mla_backend) { - PADDLE_ENFORCE(group_size <= 64, "now only group_size <= 64"); const int set_chunk_size = get_mla_dec_chunk_size(bsz); CUDA_CHECK(cudaMemsetAsync( diff --git a/docs/zh/best_practices/DeepSeek-V3.md b/docs/zh/best_practices/DeepSeek-V3.md index ad0dc38b07..0ca012d9a2 100644 --- a/docs/zh/best_practices/DeepSeek-V3.md +++ b/docs/zh/best_practices/DeepSeek-V3.md @@ -69,3 +69,34 @@ python -m fastdeploy.entrypoints.openai.multi_api_server \ --graph-optimization-config '{"use_cudagraph":true}' \ ``` + +**示例3:** H800上16卡部署 blockwise_fp8 模型16K上下文的服务 + +这个例子中支持使用FlashMLA算子做MLA的计算 + +```shell +MODEL_PATH=/models/DeepSeek-V3.2-Exp-BF16 + +export FD_DISABLE_CHUNKED_PREFILL=1 +export FD_ATTENTION_BACKEND="MLA_ATTN" +export FLAGS_flash_attn_version=3 +export USE_FLASH_MLA=1 + +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FD_ENABLE_MULTI_API_SERVER=1 +python -m fastdeploy.entrypoints.openai.multi_api_server \ + --ports "9811,9812,9813,9814,9815,9816,9817,9818" \ + --num-servers 8 \ + --args --model "$MODEL_PATH" \ + --ips "10.95.246.220,10.95.230.91" \ + --no-enable-prefix-caching \ + --quantization block_wise_fp8 \ + --disable-sequence-parallel-moe \ + --tensor-parallel-size 1 \ + --num-gpu-blocks-override 1024 \ + --data-parallel-size 16 \ + --max-model-len 16384 \ + --enable-expert-parallel \ + --max-num-seqs 20 \ + --graph-optimization-config '{"use_cudagraph":true}' \ +``` diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py 
b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 413a4fe36b..08ed2d8c06 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -16,6 +16,9 @@ from __future__ import annotations +import paddle + +paddle.enable_compat(scope={"flash_mla"}) # Enable torch proxy before importing flash_mla import math import os from dataclasses import dataclass, field @@ -47,6 +50,9 @@ if current_platform.is_cuda(): if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta +import triton +import triton.language as tl + from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @@ -56,6 +62,139 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id +@triton.jit() +def extract_kernel( + q, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + output, + cache_seqlens, + HIDDEN_DIM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + + batch_id = tl.program_id(axis=0) + cache_kv_len = tl.load(seq_lens_decoder + batch_id) + + # 这个batch不是decoder,所以不需要动弹 + if cache_kv_len <= 0: + return + + cu_len_this_batch = tl.load(cu_seqlens_q + batch_id) + + read_offsets = tl.arange(0, BLOCK_SIZE) + q += cu_len_this_batch * HIDDEN_DIM + + row_data = tl.load(q + read_offsets, mask=read_offsets < HIDDEN_DIM) + + output += batch_id * HIDDEN_DIM + + tl.store(output + read_offsets, row_data, mask=read_offsets < HIDDEN_DIM) + + tl.store(cache_seqlens + batch_id, cache_kv_len + 1) + + +def extract_decoder_token_from_q( + q: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, +): + assert len(q.shape) == 2 + assert len(cu_seqlens_q.shape) == 1 + assert 
len(seq_lens_encoder.shape) == 1 + assert len(seq_lens_decoder.shape) == 1 + + max_bsz = seq_lens_decoder.shape[0] + + hidden_dim = q.shape[-1] + out = paddle.empty([max_bsz, hidden_dim], dtype=q.dtype) + + cache_seqlens = paddle.zeros_like(seq_lens_decoder) + + BLOCK_SIZE = triton.next_power_of_2(hidden_dim) + + grid = (max_bsz,) + + extract_kernel[grid]( + q, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + out, + cache_seqlens, + hidden_dim, + BLOCK_SIZE, + ) + + return out, cache_seqlens + + +@triton.jit() +def insert_kernel( + decoder_res, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + output, + HIDDEN_DIM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + + batch_id = tl.program_id(axis=0) + cache_kv_len = tl.load(seq_lens_decoder + batch_id) + + # 这个batch不是decoder,所以不需要动弹 + if cache_kv_len <= 0: + return + + cu_len_this_batch = tl.load(cu_seqlens_q + batch_id) + + read_offsets = tl.arange(0, BLOCK_SIZE) + + decoder_res += batch_id * HIDDEN_DIM + + row_data = tl.load(decoder_res + read_offsets, mask=read_offsets < HIDDEN_DIM) + + output += cu_len_this_batch * HIDDEN_DIM + + tl.store(output + read_offsets, row_data, mask=read_offsets < HIDDEN_DIM) + + +def insert_decoder_result_back( + decoder_result: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + mixed_token_num, +): + assert len(decoder_result.shape) == 4 + assert len(cu_seqlens_q.shape) == 1 + assert len(seq_lens_encoder.shape) == 1 + + max_bsz = seq_lens_encoder.shape[0] + + hidden_dim = decoder_result.shape[-2] * decoder_result.shape[-1] + out = paddle.zeros([mixed_token_num, hidden_dim], dtype=decoder_result.dtype) + + BLOCK_SIZE = triton.next_power_of_2(hidden_dim) + + grid = (max_bsz,) + + insert_kernel[grid]( + decoder_result, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + out, + hidden_dim, + BLOCK_SIZE, + ) + + return out + + def yarn_get_mscale(scale=1, mscale=1): """ """ if scale <= 1: @@ -455,47 
+594,90 @@ class MLAAttentionBackend(AttentionBackend): speculate_decoder, ) - # 多头潜在注意力计算 - fmha_out = multi_head_latent_attention( - q, - latent_cache, - latent_cache, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - forward_meta.cu_seqlens_q, - forward_meta.batch_id_per_token, - metadata.block_tables, - forward_meta.kv_batch_ids, - forward_meta.kv_tile_ids_per_batch, - forward_meta.kv_num_blocks_x_cpu, - forward_meta.decoder_batch_ids, - forward_meta.decoder_tile_ids_per_batch, - forward_meta.decoder_num_blocks_device, - forward_meta.decoder_chunk_size_device, - metadata.max_dec_len_this_time, - metadata.max_kv_len_this_time, - None, # attn_mask - None, # qkv_bias - None, # qkv_out_scales - None, # cache_k_quant_scales - None, # cache_v_quant_scales - None, # cache_k_dequant_scales - None, # cache_v_dequant_scales - None, # cache_k_zp - None, # cache_v_zp - None, # out_shifts - None, # out_smooths - metadata._fuse_kernel_compute_dtype, - "none", # cache_quant_type - self.kv_lora_rank, - self.max_seq_len, - self.attn_softmax_scale, - 0.0, # quant_max_bound - 0.0, # quant_min_bound - 0.0, # out_linear_in_scale - speculate_max_tokens, - True, # causal - speculate_decoder, - ) + if int(os.getenv("USE_FLASH_MLA", "0")) == 0: + assert self.num_heads <= 64, "paddle mla attention support failed" + # 多头潜在注意力计算 + fmha_out = multi_head_latent_attention( + q, + latent_cache, + latent_cache, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + forward_meta.batch_id_per_token, + metadata.block_tables, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, + metadata.max_dec_len_this_time, + metadata.max_kv_len_this_time, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + None, # 
cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # out_shifts + None, # out_smooths + metadata._fuse_kernel_compute_dtype, + "none", # cache_quant_type + self.kv_lora_rank, + self.max_seq_len, + self.attn_softmax_scale, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + 0.0, # out_linear_in_scale + speculate_max_tokens, + True, # causal + speculate_decoder, + ) - return fmha_out + return fmha_out + else: + import flash_mla + + decoder_q, cache_seqlens = extract_decoder_token_from_q( + q, + forward_meta.cu_seqlens_q, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + ) + + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata() + token_num = q.shape[0] + decoder_q.reshape_([-1, 1, self.num_heads, 576]) + + new_cache_shape = latent_cache.shape + assert new_cache_shape[1] == 1 + new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1] + + decoder_res, _ = flash_mla.flash_mla_with_kvcache( + decoder_q, + # 外面的开源仓库的kv cache存储格式和FD的不同 + # 幸好这里缓存的头是1,直接view即可,否则上上下下要改很多! 
+ latent_cache.view(new_cache_shape), + metadata.block_tables, + cache_seqlens, + 512, # t.dv, + tile_scheduler_metadata, + num_splits, + softmax_scale=self.attn_softmax_scale, + causal=True, + ) + + final_res = insert_decoder_result_back( + decoder_res, + forward_meta.cu_seqlens_q, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + token_num, + ) + + return final_res diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 94c7086f9f..e3b1281988 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -362,7 +362,10 @@ class DeepseekV3MLAAttention(nn.Layer): compressed_kv = self.kv_a_layernorm(compressed_kv)[0] - if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time + need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0 + need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0 + + if need_do_prefill: # max_enc_len_this_time key_value = self.kv_b_proj(compressed_kv) key_value.reshape_( [ @@ -393,10 +396,9 @@ class DeepseekV3MLAAttention(nn.Layer): fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim] fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype) - fmha_out = fmha_out_prefill - if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time + if need_do_decode: # max_dec_len_this_time q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) q_input = paddle.concat([q_nope_out, query_pe], axis=-1) @@ -427,10 +429,10 @@ class DeepseekV3MLAAttention(nn.Layer): .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) ) - if fmha_out is None: - fmha_out = fmha_out_decode + if need_do_prefill: + fmha_out += fmha_out_decode else: - fmha_out = fmha_out + fmha_out_decode + fmha_out = fmha_out_decode output = self.o_proj(fmha_out) return output diff --git 
a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index d3ae33b1aa..03d5862f05 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -184,8 +184,8 @@ class InputBatch: dtype="int64", ) self.batch_id_per_token = paddle.full([max_num_seqs * self.max_chunk_tokens, 1], 0, dtype="int32") - self.cu_seqlens_q = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32") - self.cu_seqlens_k = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32") + self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32") + self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32") # Declare AttentionBackend buffers self.decoder_batch_ids = None