@staticmethod
def flashmla_baseline(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale, page_size=64):
    """Reference (unfused) MLA decode attention, used to validate flash_mla.

    Computes single-token (s_q == 1) multi-head latent attention per batch
    element by gathering K/V out of the paged latent cache and running a
    plain matmul-softmax-matmul, one request at a time.

    Args:
        decoder_q: query tensor, shape [bsz, 1, num_heads, 576]
            (512 latent/value dims + 64 rope dims — TODO confirm split).
        latent_cache: paged KV cache, shape [num_pages, 1, page_size, 576];
            the first 512 of the last dim double as V.
        block_table: int tensor [bsz, max_pages_per_req] mapping each
            request's logical page index to a physical page id.
        cache_seqlens: int tensor [bsz], per-request KV sequence length.
        attn_softmax_scale: scalar applied to QK^T before softmax.
        page_size: entries per cache page (default 64, matching flash_mla).

    Returns:
        Tensor [bsz, 1, num_heads, 512] with the same dtype as decoder_q.
    """
    q_num_heads = decoder_q.shape[2]
    assert decoder_q.shape[1:] == [1, q_num_heads, 576]
    assert latent_cache.shape[1:] == [1, page_size, 576]

    # Allocate with decoder_q's dtype rather than the global default dtype,
    # so the baseline stays comparable even when the default differs.
    res_baseline = paddle.zeros([decoder_q.shape[0], 1, q_num_heads, 512], dtype=decoder_q.dtype)
    for batch_id in range(decoder_q.shape[0]):
        kv_len = cache_seqlens[batch_id].item()
        # Gather this request's K (all 576 dims) and V (first 512 dims)
        # from the paged cache into contiguous per-request buffers.
        extract_k = paddle.zeros([kv_len, 576], dtype=decoder_q.dtype)
        extract_v = paddle.zeros([kv_len, 512], dtype=decoder_q.dtype)

        for start in range(0, kv_len, page_size):
            end = min(start + page_size, kv_len)
            physical_id = block_table[batch_id, start // page_size].item()
            # Entries actually used in this page; only the final page of a
            # request can be partial (start is always a multiple of page_size).
            valid = end - start
            extract_k[start:end, :] = latent_cache[physical_id, 0, :valid, :]
            extract_v[start:end, :] = latent_cache[physical_id, 0, :valid, :512]

        this_batch_q = decoder_q[batch_id, 0, :, :]  # [num_heads, 576]
        p = paddle.matmul(this_batch_q, extract_k.transpose([1, 0]).contiguous())
        p = p * attn_softmax_scale
        p = paddle.nn.functional.softmax(p, -1)
        res_baseline[batch_id, 0, :, :] = paddle.matmul(p, extract_v).contiguous()

    return res_baseline
+""" + +import unittest + +import paddle + +paddle.set_default_dtype("bfloat16") + +from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( + MLAAttentionBackend, +) + + +class TestFlashMLA(unittest.TestCase): + def setUp(self): + pass + + def test_flashmla(self): + bsz = 128 + kv_len = 1000 + decoder_q = paddle.randn([bsz, 1, 128, 576], dtype="bfloat16") + cache_seqlens = paddle.zeros([bsz], dtype="int32") + kv_len + block_tables = paddle.arange((kv_len // 64 + 1) * bsz, dtype="int32").reshape([bsz, -1]) + latent_cache = paddle.randn([10000, 1, 64, 576], dtype="bfloat16") + # copy from dsv3 + attn_softmax_scale = 0.1352337788608801 + + baseline_out = MLAAttentionBackend.flashmla_baseline( + decoder_q, latent_cache, block_tables, cache_seqlens, attn_softmax_scale + ) + + paddle.enable_compat(scope={"flash_mla"}) # Enable paddle.enable_compat before importing flash_mla + try: + import flash_mla + except ImportError: + print(100 * "Please install flash_mla first") + return + + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata() + + new_cache_shape = latent_cache.shape + assert new_cache_shape[1] == 1 + new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1] + + decoder_res, _ = flash_mla.flash_mla_with_kvcache( + decoder_q, + # 外面的开源仓库的kv cache存储格式和FD的不同 + # 幸好这里缓存的头是1,直接view即可,否则上上下下要改很多! + latent_cache.view(new_cache_shape), + block_tables, + cache_seqlens, + 512, # t.dv, + tile_scheduler_metadata, + num_splits, + softmax_scale=attn_softmax_scale, + causal=True, + ) + + max_diff = (decoder_res - baseline_out).abs().max().item() + self.assertLessEqual(max_diff, 0.1) + + +if __name__ == "__main__": + unittest.main()