From de0c5e68fbf57ce40653a5e2440fdba9f2e42e5f Mon Sep 17 00:00:00 2001 From: RuohengMa <120699764+RuohengMa@users.noreply.github.com> Date: Thu, 16 Apr 2026 14:28:40 +0800 Subject: [PATCH] [XPU] Split the block_attn operator into smaller operators (#6798) * spliced block_attn * adapt to latest vllm * fix unit tests * delete mtp+cudagraph 4 cards test * fix vl model * fix mtp * fix slot mapping --- custom_ops/xpu_ops/src/ops/block_attn.cc | 13 +- .../xpu_ops/src/ops/block_attn_spliced.cc | 1911 +++++++++++++++++ custom_ops/xpu_ops/src/ops/get_infer_param.cc | 113 +- custom_ops/xpu_ops/src/ops/pybind/pybind.cc | 97 +- custom_ops/xpu_ops/test/test_block_attn.py | 653 ++++++ .../test/test_block_attn_prefix_cache.py | 113 +- fastdeploy/model_executor/forward_meta.py | 5 + .../layers/backends/xpu/attention.py | 2 + .../xpu_pre_and_post_process.py | 10 +- fastdeploy/spec_decode/mtp.py | 1 + fastdeploy/worker/xpu_model_runner.py | 1 + .../xpu_ci/4cards_cases/test_mtp_cudagraph.py | 103 - 12 files changed, 2891 insertions(+), 131 deletions(-) create mode 100644 custom_ops/xpu_ops/src/ops/block_attn_spliced.cc create mode 100644 custom_ops/xpu_ops/test/test_block_attn.py delete mode 100644 tests/xpu_ci/4cards_cases/test_mtp_cudagraph.py diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index a9e23c0834..328ab06555 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -159,6 +159,7 @@ std::vector BlockAttnKernel( if (use_neox_rotary_style) { pos_emb_type = "NEOX"; } else if (rope_head_dim == head_dim / 2) { + // vl model use this pos_emb_type = "HALF_HEAD_DIM"; } else { pos_emb_type = "NORMAL"; @@ -984,7 +985,7 @@ std::vector BlockAttnKernel( return {block_attn_out}; } -std::vector BlockAttn( +std::vector BlockAttnFused( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, @@ -1008,6 +1009,8 @@ std::vector BlockAttn( const paddle::Tensor& decoder_context_len_cache, const paddle::Tensor& decoder_batch_map, const paddle::Tensor& prefix_len, + const paddle::Tensor& slot_mapping_enc, + const paddle::Tensor& slot_mapping_dec, const paddle::optional& k_scales, const paddle::optional& v_scales, const paddle::optional& k_scales_inv, @@ -1067,7 +1070,7 @@ std::vector BlockAttn( } else if (cache_dtype == paddle::DataType::INT8) { APPLY_KERNEL(paddle::bfloat16, int8_t, paddle::bfloat16); } else { - PD_THROW("block_attn not support cache_dtype==%d", + PD_THROW("block_attn_fused not support cache_dtype==%d", static_cast(cache_dtype)); return {}; } @@ -1097,7 +1100,7 @@ std::vector BlockAttnInferDtype( return {qkv_dtype}; } -PD_BUILD_STATIC_OP(block_attn) +PD_BUILD_STATIC_OP(block_attn_fused) .Inputs({"qkv", "key_cache", "value_cache", @@ -1121,6 +1124,8 @@ PD_BUILD_STATIC_OP(block_attn) "decoder_context_len_cache", "decoder_batch_map", "prefix_len", + "slot_mapping_enc", + "slot_mapping_dec", paddle::Optional("k_scales"), paddle::Optional("v_scales"), paddle::Optional("k_scales_inv"), @@ -1135,6 +1140,6 @@ PD_BUILD_STATIC_OP(block_attn) paddle::Optional("cachekv_signal_thread_cpu")}) .Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"}) .Outputs({"block_attn_out"}) - .SetKernelFn(PD_KERNEL(BlockAttn)) + .SetKernelFn(PD_KERNEL(BlockAttnFused)) .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(BlockAttnInferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/block_attn_spliced.cc b/custom_ops/xpu_ops/src/ops/block_attn_spliced.cc new file mode 100644 index 0000000000..af1bdec937 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/block_attn_spliced.cc @@ -0,0 +1,1911 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ops/pybind/cachekv_signal_thread_worker.h" +#include "ops/remote_cache_kv_ipc.h" +#include "ops/utility/env.h" +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false); +XPU_DECLARE_BOOL(use_pd_disaggregation_per_chunk, false); +XPU_DECLARE_BOOL(encoder_splice, false); +XPU_DECLARE_BOOL(decoder_splice, false); +XPU_DECLARE_BOOL(use_sdnn_rmsnorm, false); + +namespace xftblock = baidu::xpu::xftblock; +namespace api = baidu::xpu::api; + +template +struct SplitRopeTypeTrait { + using E_Scale = TS; + using D_Scale = TS; +}; +template <> +struct SplitRopeTypeTrait { + using E_Scale = bfloat16; + using D_Scale = float; +}; +template <> +struct SplitRopeTypeTrait { + using E_Scale = bfloat16; + using D_Scale = bfloat16; +}; + +void do_add_zero(api::Context* xpu_ctx, + paddle::Place place, + bfloat16* x, + int64_t token_num, + int64_t kv_head_num, + int64_t head_dim, + const float* cache_zero) { + if (cache_zero == nullptr) { + return; + } + std::vector x_shape = {token_num, kv_head_num * head_dim}; + std::vector cache_zero_shape = {1, kv_head_num * head_dim}; + int64_t x_numel = token_num * kv_head_num * head_dim; + + int ret; + auto x_fp32 = paddle::empty(x_shape, paddle::DataType::FLOAT32, place); + auto x_fp32_ptr = const_cast(x_fp32.data()); + + ret = api::cast(xpu_ctx, x, x_fp32_ptr, x_numel); + PD_CHECK(ret == api::SUCCESS, "api::cast failed."); + ret = api::broadcast_add( + xpu_ctx, x_fp32_ptr, cache_zero, x_fp32_ptr, x_shape, cache_zero_shape); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed."); + ret = api::cast(xpu_ctx, x_fp32_ptr, x, x_numel); + PD_CHECK(ret == api::SUCCESS, "api::cast failed."); +} + +template +void store_paged_kv_cache_wrapper(api::Context* xpu_ctx, + paddle::Place place, + bfloat16* k, + bfloat16* v, + TKV_CACHE* key_cache, + TKV_CACHE* value_cache, + TID* slot_mapping, + int64_t num_blocks, + int64_t token_num, + int64_t kv_head_num, + int64_t head_dim, + int64_t block_size, + const TSCALE* k_cache_scale, + const TSCALE* v_cache_scale, + const TZERO* k_cache_zero, + const TZERO* v_cache_zero) { + std::vector cache_zero_shape = {1, kv_head_num * head_dim}; + int64_t cache_zero_numel = kv_head_num * head_dim; + + int ret; + paddle::Tensor k_cache_scale_fp32, v_cache_scale_fp32, k_cache_zero_fp32, + v_cache_zero_fp32; + float* k_cache_scale_fp32_ptr = nullptr; + float* v_cache_scale_fp32_ptr = nullptr; + float* k_cache_zero_fp32_ptr = nullptr; + float* v_cache_zero_fp32_ptr = nullptr; + + if (k_cache_scale != nullptr) { + if (!std::is_same::value) { + k_cache_scale_fp32 = + paddle::empty(cache_zero_shape, paddle::DataType::FLOAT32, place); + v_cache_scale_fp32 = + paddle::empty(cache_zero_shape, paddle::DataType::FLOAT32, place); + k_cache_scale_fp32_ptr = + const_cast(k_cache_scale_fp32.data()); + v_cache_scale_fp32_ptr = + const_cast(v_cache_scale_fp32.data()); + ret = api::cast( + xpu_ctx, k_cache_scale, k_cache_scale_fp32_ptr, cache_zero_numel); + PD_CHECK(ret == api::SUCCESS, "api::cast failed."); + ret = api::cast( + xpu_ctx, v_cache_scale, v_cache_scale_fp32_ptr, cache_zero_numel); + PD_CHECK(ret == api::SUCCESS, "api::cast failed."); + } else { + k_cache_scale_fp32_ptr = + const_cast(reinterpret_cast(k_cache_scale)); + v_cache_scale_fp32_ptr = + const_cast(reinterpret_cast(v_cache_scale)); + } + } + if (k_cache_zero != nullptr) { + if (!std::is_same::value) { + k_cache_zero_fp32 = + paddle::empty(cache_zero_shape, paddle::DataType::FLOAT32, place); + v_cache_zero_fp32 = + paddle::empty(cache_zero_shape, paddle::DataType::FLOAT32, place); + k_cache_zero_fp32_ptr = + const_cast(k_cache_zero_fp32.data()); + v_cache_zero_fp32_ptr = + const_cast(v_cache_zero_fp32.data()); + ret = api::cast( + xpu_ctx, k_cache_zero, k_cache_zero_fp32_ptr, cache_zero_numel); + PD_CHECK(ret == api::SUCCESS, "api::cast failed."); + ret = api::cast( + xpu_ctx, v_cache_zero, v_cache_zero_fp32_ptr, cache_zero_numel); + PD_CHECK(ret == api::SUCCESS, "api::cast failed."); + } else { + k_cache_zero_fp32_ptr = + const_cast(reinterpret_cast(k_cache_zero)); + v_cache_zero_fp32_ptr = + const_cast(reinterpret_cast(v_cache_zero)); + } + } + + if (k_cache_zero != nullptr) { + do_add_zero(xpu_ctx, + place, + k, + token_num, + kv_head_num, + head_dim, + k_cache_zero_fp32_ptr); + do_add_zero(xpu_ctx, + place, + v, + token_num, + kv_head_num, + head_dim, + v_cache_zero_fp32_ptr); + } + + ret = infer_ops::store_paged_kv_cache( + xpu_ctx, + k, + v, + key_cache, + value_cache, + slot_mapping, + k_cache_scale_fp32_ptr, + v_cache_scale_fp32_ptr, + token_num, + kv_head_num, + head_dim, + num_blocks, + block_size); + PD_CHECK(ret == api::SUCCESS, "store_paged_kv_cache failed."); +} + +template +void split_kvcache_encoder(api::Context* xpu_ctx, + xftblock::XFTContext& xctx, + const paddle::Tensor& qkv, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& block_tables, + const paddle::Tensor& slot_mapping, + int64_t batch_size, + int64_t token_num, + int64_t q_num_heads, + int64_t kv_num_heads, + int64_t head_dim, + int64_t rope_head_dim, + int64_t hidden_dim, + int64_t rope_max_seqlen, + int64_t block_size, + int64_t num_blocks, + int64_t block_batch, + int64_t max_block_per_seq, + const api::VectorParam& seq_lod, + const api::VectorParam& start_tokens, + const api::VectorParam& real_batch, + int64_t qkv_offset, + const float* k_cache_scale_inv, + const float* v_cache_scale_inv, + const TSCALE* intx_k_pc_scale, + const TSCALE* intx_v_pc_scale, + const TSCALE* intx_k_pc_zero, + const TSCALE* intx_v_pc_zero, + const float* q_norm_weight, + const float* k_norm_weight, + std::string pos_emb_type, + bool rope_3d, + bool use_neox_rotary_style) { + int ret; + int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads; + if (FLAGS_encoder_splice) { + if (rope_3d) { + PD_THROW("split_kvcache_encoder does not support rope_3d == true!"); + } + paddle::Place place = qkv.place(); + xftblock::DataType KV_BUF_TYPE = std::is_same::value + ? xftblock::DataType::DT_BFLOAT16 + : xftblock::DataType::DT_FLOAT16; + auto q_split = paddle::empty({token_num, hidden_dim}, qkv.type(), place); + auto k_split = paddle::empty( + {token_num, real_kv_num_heads * head_dim}, qkv.type(), place); + xftblock::Tensor qkv_xft_tensor( + const_cast(qkv.data() + qkv_offset * sizeof(TQKV)), + KV_BUF_TYPE, + {token_num, (q_num_heads + 2 * real_kv_num_heads) * head_dim}); + xftblock::Tensor q_xft_tensor( + q_split.data(), KV_BUF_TYPE, {token_num, hidden_dim}); + xftblock::Tensor k_xft_tensor( + k_split.data(), KV_BUF_TYPE, {token_num, real_kv_num_heads * head_dim}); + xftblock::Tensor v_xft_tensor(const_cast(v.data()), + KV_BUF_TYPE, + {token_num, real_kv_num_heads * head_dim}); + + ret = xftblock::split_qkv_block(&xctx, + &qkv_xft_tensor, + &q_xft_tensor, + &k_xft_tensor, + &v_xft_tensor, + token_num, + q_num_heads, + real_kv_num_heads, + head_dim); + PD_CHECK(ret == api::SUCCESS, "split_qkv_block failed."); + + if (!use_neox_rotary_style) { + ret = infer_ops::vsl_rotary_embedding_gptj( + xpu_ctx, + reinterpret_cast(q_split.data()), + reinterpret_cast(k_split.data()), + reinterpret_cast(rotary_embs.data()), + const_cast(reinterpret_cast(q.data())), + const_cast(reinterpret_cast(k.data())), + seq_lod, + 1, + rope_max_seqlen, + q_num_heads, + head_dim, + "BLHD", + start_tokens, + "NORMAL", + real_kv_num_heads, + false); + PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_gptj failed."); + } else { + ret = infer_ops::vsl_rotary_embedding_neox( + xpu_ctx, + reinterpret_cast(q_split.data()), + reinterpret_cast(k_split.data()), + reinterpret_cast(rotary_embs.data()), + const_cast(reinterpret_cast(q.data())), + const_cast(reinterpret_cast(k.data())), + seq_lod, + 1, + rope_max_seqlen, + q_num_heads, + head_dim, + rope_head_dim, + "BLHD", + start_tokens, + "NORMAL", + real_kv_num_heads, + false); + PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_neox failed."); + } + + if (q_norm_weight) { + ret = infer_ops::qkrmsnorm( + xpu_ctx, + reinterpret_cast(q.data()), + q_num_heads * head_dim, + head_dim, + const_cast(reinterpret_cast(q.data())), + q_num_heads * head_dim, + head_dim, + head_dim, + token_num, + q_num_heads, + 1e-5, + q_norm_weight, + nullptr, // not supported yet + false, // not supported yet + FLAGS_use_sdnn_rmsnorm, + false); + } + if (k_norm_weight) { + ret = infer_ops::qkrmsnorm( + xpu_ctx, + reinterpret_cast(k.data()), + real_kv_num_heads * head_dim, + head_dim, + const_cast(reinterpret_cast(k.data())), + real_kv_num_heads * head_dim, + head_dim, + head_dim, + token_num, + real_kv_num_heads, + 1e-5, + k_norm_weight, + nullptr, // not supported yet + false, // not supported yet + FLAGS_use_sdnn_rmsnorm, + false); + } + + // write to cache + if (std::is_same::value && intx_k_pc_scale && + intx_v_pc_scale) { + store_paged_kv_cache_wrapper( + xpu_ctx, + place, + const_cast(reinterpret_cast(k.data())), + const_cast(reinterpret_cast(v.data())), + const_cast( + reinterpret_cast(key_cache.data())), + const_cast( + reinterpret_cast(value_cache.data())), + const_cast(slot_mapping.data()), + num_blocks, + token_num, + real_kv_num_heads, + head_dim, + block_size, + intx_k_pc_scale, + intx_v_pc_scale, + intx_k_pc_zero, + intx_v_pc_zero); + } else { + float* k_scale_cache_ptr = nullptr; + float* v_scale_cache_ptr = nullptr; + paddle::Tensor k_scale_cache, v_scale_cache; + if (k_cache_scale_inv) { + k_scale_cache = paddle::empty( + {real_kv_num_heads}, paddle::DataType::FLOAT32, place); + k_scale_cache_ptr = const_cast(k_scale_cache.data()); + ret = api::reciprocal( + xpu_ctx, k_cache_scale_inv, k_scale_cache_ptr, real_kv_num_heads); + PD_CHECK(ret == api::SUCCESS, "api::reciprocal failed."); + } + if (v_cache_scale_inv) { + v_scale_cache = paddle::empty( + {real_kv_num_heads}, paddle::DataType::FLOAT32, place); + v_scale_cache_ptr = const_cast(v_scale_cache.data()); + ret = api::reciprocal( + xpu_ctx, v_cache_scale_inv, v_scale_cache_ptr, real_kv_num_heads); + PD_CHECK(ret == api::SUCCESS, "api::reciprocal failed."); + } + store_paged_kv_cache_wrapper( + xpu_ctx, + place, + const_cast(reinterpret_cast(k.data())), + const_cast(reinterpret_cast(v.data())), + const_cast( + reinterpret_cast(key_cache.data())), + const_cast( + reinterpret_cast(value_cache.data())), + const_cast(slot_mapping.data()), + num_blocks, + token_num, + real_kv_num_heads, + head_dim, + block_size, + k_scale_cache_ptr, + v_scale_cache_ptr, + nullptr, + nullptr); + } + } else { + if (use_neox_rotary_style) { + ret = infer_ops::split_neox_cache_kv_encoder( + xpu_ctx, + reinterpret_cast(qkv.data()) + qkv_offset, // qkv + reinterpret_cast( + rotary_embs.data()), // rotary_pos_emb + reinterpret_cast( + block_tables.data()), // block_table + const_cast(reinterpret_cast(q.data())), + const_cast(reinterpret_cast(k.data())), + const_cast(reinterpret_cast(v.data())), + const_cast( + reinterpret_cast(key_cache.data())), + const_cast( + reinterpret_cast(value_cache.data())), + seq_lod, // seq_lod + real_batch, // real_batch + start_tokens, // start_tokens + batch_size, // batch_size + 1, // emb_batch_size + rope_max_seqlen, // max_seqlen + q_num_heads, + real_kv_num_heads, + head_dim, + rope_head_dim, + block_batch, + block_size, + max_block_per_seq, + "BLHD", + "HLD", + pos_emb_type, + nullptr, // k_cache_scale_inv - use for per head + nullptr, // v_cache_scale_inv - use for per head + nullptr, // intx_k_pc_scale + nullptr, // intx_v_pc_scale + nullptr, // intx_k_pc_zero + nullptr, // intx_v_pc_zero + rope_3d); + PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed."); + } else { + ret = infer_ops:: + split_rope_cache_kv_encoder( + xpu_ctx, + reinterpret_cast(qkv.data()) + qkv_offset, // qkv + reinterpret_cast( + rotary_embs.data()), // rotary_pos_emb + reinterpret_cast( + block_tables.data()), // block_table + const_cast(reinterpret_cast(q.data())), + const_cast(reinterpret_cast(k.data())), + const_cast(reinterpret_cast(v.data())), + const_cast( + reinterpret_cast(key_cache.data())), + const_cast( + reinterpret_cast(value_cache.data())), + seq_lod, // seq_lod + real_batch, // real_batch + start_tokens, // start_tokens + batch_size, // batch_size + 1, // emb_batch_size + rope_max_seqlen, // max_seqlen + q_num_heads, + real_kv_num_heads, + head_dim, + block_batch, + block_size, + max_block_per_seq, + "BLHD", + "HLD", + pos_emb_type, + k_cache_scale_inv, // k_cache_scale_inv - use for per head + v_cache_scale_inv, // v_cache_scale_inv - use for per head + intx_k_pc_scale, // intx_k_pc_scale + intx_v_pc_scale, // intx_v_pc_scale + intx_k_pc_zero, // intx_k_pc_zero + intx_v_pc_zero, // intx_v_pc_zero + q_norm_weight, + k_norm_weight, + rope_3d); + PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_encoder failed."); + } + } +} + +template +void split_kvcache_decoder(api::Context* xpu_ctx, + xftblock::XFTContext& xctx, + const paddle::Tensor& qkv, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& block_tables, + const paddle::Tensor& slot_mapping, + int64_t batch_size, + int64_t token_num, + int64_t q_num_heads, + int64_t kv_num_heads, + int64_t head_dim, + int64_t rope_head_dim, + int64_t hidden_dim, + int64_t rope_max_seqlen, + int64_t block_size, + int64_t num_blocks, + int64_t block_batch, + int64_t max_block_per_seq, + const api::VectorParam& seq_lod, + const api::VectorParam& seq_lod_for_fused, + const api::VectorParam& start_tokens, + const api::VectorParam& real_batch, + int64_t qkv_offset, + const TSCALE* k_cache_scale_inv, + const TSCALE* v_cache_scale_inv, + const TSCALE* k_pc_zero, + const TSCALE* v_pc_zero, + const float* q_norm_weight, + const float* k_norm_weight, + std::string pos_emb_type, + bool rope_3d, + bool b_c8_pc, + bool use_neox_rotary_style) { + int64_t real_kv_num_heads = (kv_num_heads == -1) ? q_num_heads : kv_num_heads; + int ret; + if (FLAGS_decoder_splice) { + // not yet supported + if (rope_3d) { + PD_THROW("split_kvcache_decoder does not support rope_3d == true!"); + } + if (std::is_same::value && + (k_cache_scale_inv == nullptr || v_cache_scale_inv == nullptr)) { + PD_THROW( + "split_kvcache_decoder of kv_cache type int8_t does not " + "support nullptr for k_cache_scale_inv or v_cache_scale_inv!"); + } + + xftblock::DataType KV_BUF_TYPE = std::is_same::value + ? xftblock::DataType::DT_BFLOAT16 + : xftblock::DataType::DT_FLOAT16; + + paddle::Place place = qkv.place(); + + auto q_split = paddle::empty({token_num, hidden_dim}, qkv.type(), place); + auto k_split = paddle::empty( + {token_num, real_kv_num_heads * head_dim}, qkv.type(), place); + xftblock::Tensor qkv_xft_tensor( + const_cast(qkv.data() + qkv_offset * sizeof(TQKV)), + KV_BUF_TYPE, + {token_num, (q_num_heads + 2 * real_kv_num_heads) * head_dim}); + xftblock::Tensor q_xft_tensor( + q_split.data(), KV_BUF_TYPE, {token_num, hidden_dim}); + xftblock::Tensor k_xft_tensor( + k_split.data(), KV_BUF_TYPE, {token_num, real_kv_num_heads * head_dim}); + xftblock::Tensor v_xft_tensor(const_cast(v.data()), + KV_BUF_TYPE, + {token_num, real_kv_num_heads * head_dim}); + + ret = xftblock::split_qkv_block(&xctx, + &qkv_xft_tensor, + &q_xft_tensor, + &k_xft_tensor, + &v_xft_tensor, + token_num, + q_num_heads, + real_kv_num_heads, + head_dim); + PD_CHECK(ret == api::SUCCESS, "split_qkv_block failed."); + + if (!use_neox_rotary_style) { + ret = infer_ops::vsl_rotary_embedding_gptj( + xpu_ctx, + reinterpret_cast(q_split.data()), + reinterpret_cast(k_split.data()), + reinterpret_cast(rotary_embs.data()), + const_cast(reinterpret_cast(q.data())), + const_cast(reinterpret_cast(k.data())), + seq_lod, + 1, + rope_max_seqlen, + q_num_heads, + head_dim, + "BLHD", + start_tokens, + "NORMAL", + real_kv_num_heads, + false); + PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_gptj failed."); + } else { + ret = infer_ops::vsl_rotary_embedding_neox( + xpu_ctx, + reinterpret_cast(q_split.data()), + reinterpret_cast(k_split.data()), + reinterpret_cast(rotary_embs.data()), + const_cast(reinterpret_cast(q.data())), + const_cast(reinterpret_cast(k.data())), + seq_lod, + 1, + rope_max_seqlen, + q_num_heads, + head_dim, + rope_head_dim, + "BLHD", + start_tokens, + "NORMAL", + real_kv_num_heads, + false); + PD_CHECK(ret == api::SUCCESS, "vsl_rotary_embedding_neox failed."); + } + + if (q_norm_weight) { + ret = infer_ops::qkrmsnorm( + xpu_ctx, + reinterpret_cast(q.data()), + q_num_heads * head_dim, + head_dim, + const_cast(reinterpret_cast(q.data())), + q_num_heads * head_dim, + head_dim, + head_dim, + token_num, + q_num_heads, + 1e-5, + q_norm_weight, + nullptr, // not supported yet + false, // not supported yet + FLAGS_use_sdnn_rmsnorm, + false); + } + if (k_norm_weight) { + ret = infer_ops::qkrmsnorm( + xpu_ctx, + reinterpret_cast(k.data()), + real_kv_num_heads * head_dim, + head_dim, + const_cast(reinterpret_cast(k.data())), + real_kv_num_heads * head_dim, + head_dim, + head_dim, + token_num, + real_kv_num_heads, + 1e-5, + k_norm_weight, + nullptr, // not supported yet + false, // not supported yet + FLAGS_use_sdnn_rmsnorm, + false); + } + + // write to cache + float* k_cache_scale_fp32_ptr = nullptr; + float* v_cache_scale_fp32_ptr = nullptr; + paddle::Tensor k_scale_cache, v_scale_cache; + int64_t cache_scale_zero_len = + b_c8_pc ? real_kv_num_heads * head_dim : real_kv_num_heads; + if (k_cache_scale_inv) { + k_scale_cache = paddle::empty( + {cache_scale_zero_len}, paddle::DataType::FLOAT32, place); + k_cache_scale_fp32_ptr = const_cast(k_scale_cache.data()); + ret = api::cast(xpu_ctx, + k_cache_scale_inv, + k_cache_scale_fp32_ptr, + cache_scale_zero_len); + if (!b_c8_pc) { + ret = api::reciprocal(xpu_ctx, + k_cache_scale_fp32_ptr, + k_cache_scale_fp32_ptr, + cache_scale_zero_len); + PD_CHECK(ret == api::SUCCESS, "api::reciprocal failed."); + } + } + if (v_cache_scale_inv) { + v_scale_cache = paddle::empty( + {cache_scale_zero_len}, paddle::DataType::FLOAT32, place); + v_cache_scale_fp32_ptr = const_cast(v_scale_cache.data()); + ret = api::cast(xpu_ctx, + v_cache_scale_inv, + v_cache_scale_fp32_ptr, + cache_scale_zero_len); + if (!b_c8_pc) { + ret = api::reciprocal(xpu_ctx, + v_cache_scale_fp32_ptr, + v_cache_scale_fp32_ptr, + cache_scale_zero_len); + PD_CHECK(ret == api::SUCCESS, "api::reciprocal failed."); + } + } + if (std::is_same::value && b_c8_pc) { + store_paged_kv_cache_wrapper( + xpu_ctx, + place, + const_cast(reinterpret_cast(k.data())), + const_cast(reinterpret_cast(v.data())), + const_cast( + reinterpret_cast(key_cache.data())), + const_cast( + reinterpret_cast(value_cache.data())), + const_cast(slot_mapping.data()), + num_blocks, + token_num, + real_kv_num_heads, + head_dim, + block_size, + k_cache_scale_fp32_ptr, + v_cache_scale_fp32_ptr, + k_pc_zero, + v_pc_zero); + } else { + store_paged_kv_cache_wrapper( + xpu_ctx, + place, + const_cast(reinterpret_cast(k.data())), + const_cast(reinterpret_cast(v.data())), + const_cast( + reinterpret_cast(key_cache.data())), + const_cast( + reinterpret_cast(value_cache.data())), + const_cast(slot_mapping.data()), + num_blocks, + token_num, + real_kv_num_heads, + head_dim, + block_size, + k_cache_scale_fp32_ptr, + v_cache_scale_fp32_ptr, + nullptr, + nullptr); + } + } else { + if (use_neox_rotary_style) { + ret = infer_ops:: + split_neox_cache_kv_decoder( + xpu_ctx, + reinterpret_cast(qkv.data()) + qkv_offset, // qkv + reinterpret_cast( + rotary_embs.data()), // rotary_pos_emb + reinterpret_cast( + block_tables.data()), // block_table + const_cast(reinterpret_cast(q.data())), + nullptr, + nullptr, + const_cast( + reinterpret_cast(key_cache.data())), + const_cast( + reinterpret_cast(value_cache.data())), + seq_lod_for_fused, // seq_lod + real_batch, // real_batch + batch_size, // batch_size + 1, // emb_batch_size = rotary_embs.dims()[1] = 1 + rope_max_seqlen, // max_seqlen + q_num_heads, + real_kv_num_heads, + head_dim, + rope_head_dim, + block_batch, + block_size, + max_block_per_seq, + "BLHD", + "HLD", + pos_emb_type, + k_cache_scale_inv, // k_cache_scale_inv + v_cache_scale_inv, // v_cache_scale_inv + k_pc_zero, // k_cache_zp + v_pc_zero, // v_cache_zp + rope_3d); + } else { + ret = infer_ops:: + split_rope_cache_kv_decoder( + xpu_ctx, + reinterpret_cast(qkv.data()) + qkv_offset, // qkv + reinterpret_cast( + rotary_embs.data()), // rotary_pos_emb + reinterpret_cast( + block_tables.data()), // block_table + const_cast(reinterpret_cast(q.data())), + nullptr, + nullptr, + const_cast( + reinterpret_cast(key_cache.data())), + const_cast( + reinterpret_cast(value_cache.data())), + seq_lod_for_fused, // seq_lod + real_batch, // real_batch + batch_size, // batch_size + 1, // emb_batch_size = rotary_embs.dims()[1] = 1 + rope_max_seqlen, // max_seqlen + q_num_heads, + real_kv_num_heads, + head_dim, + block_batch, + block_size, + max_block_per_seq, + "BLHD", + "HLD", + pos_emb_type, + k_cache_scale_inv, // k_cache_scale_inv + v_cache_scale_inv, // v_cache_scale_inv + k_pc_zero, // k_cache_zp + v_pc_zero, // v_cache_zp + q_norm_weight, + k_norm_weight, + b_c8_pc, // bool b_c8_pc + rope_3d); + PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed."); + } + } +} + +/** + * qkv shape: [token_num, (num_heads + 2 * kv_num_heads) * head_dim] + * k_scales/v_scales value: 127 / max (type = TS) + * k_scales_inv/v_scales_inv value: + * 1. perchannel with zp: max / 127 (type = TS) + * 2. perchannel without zp: max (type = float) + **/ +template +std::vector SplitEmbeddingKVCache( + api::Context* xpu_ctx, + xftblock::XFTContext& xctx, + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& block_tables, + const paddle::Tensor& len_info_cpu, + const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, + const paddle::Tensor& encoder_kv_lod_cpu, + const paddle::Tensor& encoder_batch_map_cpu, + const paddle::Tensor& decoder_context_len_cpu, + const paddle::Tensor& decoder_context_len_cache_cpu, + const paddle::Tensor& decoder_batch_map_cpu, + const paddle::Tensor& prefix_len_cpu, + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, + const paddle::Tensor& encoder_kv_lod, + const paddle::Tensor& encoder_batch_map, + const paddle::Tensor& decoder_context_len, + const paddle::Tensor& decoder_context_len_cache, + const paddle::Tensor& decoder_batch_map, + const paddle::Tensor& prefix_len, + const paddle::Tensor& slot_mapping_enc, + const paddle::Tensor& slot_mapping_dec, + const paddle::optional& k_scales, + const paddle::optional& v_scales, + const paddle::optional& k_scales_inv, + const paddle::optional& k_zeros, + const paddle::optional& v_zeros, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& kv_signal_data_cpu, + const paddle::optional& cachekv_signal_thread_cpu, + const bool use_neox_rotary_style, + const bool rope_3d) { + xpu::ctx_guard RAII_GUARD(xpu_ctx); + + using XPU_XType = typename XPUTypeTrait::Type; + using XPU_CType = typename XPUTypeTrait::Type; + using XPU_SType = typename XPUTypeTrait::Type; + using E_Scale = typename SplitRopeTypeTrait::E_Scale; + using D_Scale = typename SplitRopeTypeTrait::D_Scale; + typedef TX data_t; + typedef TC cdata_t; + typedef TS sdata_t; + xftblock::DataType KV_BUF_TYPE = std::is_same::value + ? xftblock::DataType::DT_BFLOAT16 + : xftblock::DataType::DT_FLOAT16; + auto qkv_shape = qkv.dims(); + auto cache_shape = key_cache.dims(); + auto block_table_shape = block_tables.dims(); + const int block_batch = block_table_shape[0]; + const int max_block_per_seq = block_table_shape[1]; + const int num_blocks = cache_shape[0]; + const int kv_num_heads = cache_shape[1]; + const int block_size = cache_shape[2]; + const int head_dim = cache_shape[3]; + const int max_seq_len = block_size * max_block_per_seq; + + const int token_num = qkv_shape[0]; + const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim; + const int num_heads = total_num_head - 2 * kv_num_heads; + const int hidden_dim = num_heads * head_dim; + + int enc_batch = len_info_cpu.data()[0]; + int dec_batch = len_info_cpu.data()[1]; + int total_enc_len = len_info_cpu.data()[2]; + int total_dec_len = token_num - total_enc_len; + int max_enc_len = len_info_cpu.data()[3]; + int max_kv_len = len_info_cpu.data()[4]; + int prefix_block_num_per_seq = len_info_cpu.data()[5]; + + int rope_max_seqlen = 0; + int rope_head_dim = 0; + if (rope_3d) { + PD_CHECK(rotary_embs.dims().size() == 6, + "rotary_embs dim size should be 6 in multi-modal model"); + rope_max_seqlen = rotary_embs.dims()[3]; + rope_head_dim = rotary_embs.dims()[5]; + } else { + PD_CHECK(rotary_embs.dims().size() == 5, + "rotary_embs dim size should be 5 in language model"); + rope_max_seqlen = rotary_embs.dims()[2]; + rope_head_dim = rotary_embs.dims()[4]; + } + std::string pos_emb_type; + if (use_neox_rotary_style) { + pos_emb_type = "NEOX"; + } else if (rope_head_dim == head_dim / 2) { + // vl model use this + pos_emb_type = "HALF_HEAD_DIM"; + } else { + pos_emb_type = "NORMAL"; + } + + // TODO(lizanz03): only support c8 zp per channel + bool is_cache_int8 = std::is_same::value; + bool has_zp = k_zeros && v_zeros; + XPU_SType *quant_k_scale{nullptr}, *quant_v_scale{nullptr}, + *quant_k_scale_inv_zp{nullptr}, *quant_k_zp{nullptr}, + *quant_v_zp{nullptr}; + // maxptr for xfa + float* quant_v_scale_inv{nullptr}; + if (is_cache_int8) { + // only support c8 per channel + quant_k_scale = reinterpret_cast( + const_cast(k_scales.get().data())); + quant_v_scale = reinterpret_cast( + const_cast(v_scales.get().data())); + if (has_zp) { + quant_k_scale_inv_zp = reinterpret_cast( + const_cast(k_scales_inv.get().data())); + quant_k_zp = reinterpret_cast( + const_cast(k_zeros.get().data())); + quant_v_zp = reinterpret_cast( + const_cast(v_zeros.get().data())); + } + } + const float *q_norm_weight_data{nullptr}, *k_norm_weight_data{nullptr}; + if (q_norm_weight) { + q_norm_weight_data = q_norm_weight.get().data(); + } + if (k_norm_weight) { + k_norm_weight_data = k_norm_weight.get().data(); + } + PD_CHECK(!(pos_emb_type == "NEOX" && q_norm_weight_data != nullptr), + "split_neox_cache_kv_encoder not support q/k norm weight"); + + int ret; + auto q_enc_tensor = + paddle::empty({total_enc_len, hidden_dim}, qkv.type(), qkv.place()); + auto k_enc_tensor = paddle::empty( + {total_enc_len, kv_num_heads * head_dim}, qkv.type(), qkv.place()); + auto v_enc_tensor = paddle::empty( + {total_enc_len, kv_num_heads * head_dim}, qkv.type(), qkv.place()); + auto q_dec_tensor = + paddle::empty({total_dec_len, hidden_dim}, qkv.type(), qkv.place()); + auto k_dec_tensor = paddle::empty( + {total_dec_len, kv_num_heads * head_dim}, qkv.type(), qkv.place()); + auto v_dec_tensor = paddle::empty( + {total_dec_len, kv_num_heads * head_dim}, qkv.type(), qkv.place()); + + if (enc_batch > 0) { + xftblock::Tensor q_enc_xft_tensor( + q_enc_tensor.data(), KV_BUF_TYPE, {total_enc_len, hidden_dim}); + xftblock::Tensor k_enc_xft_tensor(k_enc_tensor.data(), + KV_BUF_TYPE, + {total_enc_len, kv_num_heads * head_dim}); + xftblock::Tensor v_enc_xft_tensor(v_enc_tensor.data(), + KV_BUF_TYPE, + {total_enc_len, kv_num_heads * head_dim}); + + api::VectorParam seqlod_vp = { + const_cast(encoder_seq_lod_cpu.data()), + enc_batch + 1, + const_cast(encoder_seq_lod.data())}; + api::VectorParam real_batch_vp = { + const_cast(encoder_batch_map_cpu.data()), + enc_batch, + const_cast(encoder_batch_map.data())}; // real batch + baidu::xpu::api::VectorParam prefix_lens_vp{ + const_cast(prefix_len_cpu.data()), + enc_batch, + const_cast(prefix_len.data())}; + + // split, rotary embedding and write to kv cache + split_kvcache_encoder( + xpu_ctx, + xctx, + qkv, + rotary_embs, + q_enc_tensor, + k_enc_tensor, + v_enc_tensor, + key_cache, + value_cache, + block_tables, + slot_mapping_enc, + enc_batch, + total_enc_len, + num_heads, + kv_num_heads, + head_dim, + rope_head_dim, + hidden_dim, + rope_max_seqlen, + block_size, + num_blocks, + block_batch, + max_block_per_seq, + seqlod_vp, + prefix_lens_vp, + real_batch_vp, + 0, + nullptr, // k_cache_scale_inv - use for per head + nullptr, // v_cache_scale_inv - use for per head + quant_k_scale, // intx_k_pc_scale + quant_v_scale, // intx_v_pc_scale + quant_k_zp, // intx_k_pc_zero + quant_v_zp, // intx_v_pc_zero + q_norm_weight_data, + k_norm_weight_data, + pos_emb_type, + rope_3d, + use_neox_rotary_style); + + // pd split + if (FLAGS_fmt_write_cache_completed_signal) { + XPUEvent write_event = nullptr; + ret = xpu_event_create(&write_event); + PD_CHECK(ret == 0, "xpu_event_create write_event failed."); + + ret = xpu_event_record(write_event, xctx.get_main_stream()); + PD_CHECK(ret == 0, "xpu_event_record failed."); + + PD_CHECK(cachekv_signal_thread_cpu, + "cachekv_signal_thread should not be nullptr"); + auto worker = reinterpret_cast( + cachekv_signal_thread_cpu.get().data()[0]); + PD_CHECK(worker != nullptr, + "cachekv_signal_thread should not be nullptr"); + + if (FLAGS_use_pd_disaggregation_per_chunk) { + worker->push_signal_task_per_query(write_event, nullptr); + } else { + // If use micro batch: + // micro_batch_0 do nothing. + // micro_batch_1 write kv signal. + if (kv_signal_data_cpu) { + worker->push_signal_task( + write_event, + reinterpret_cast((const_cast( + kv_signal_data_cpu.get().data())))); + } + } + } + + bool is_prefix_cache = prefix_block_num_per_seq > 0; + if (is_cache_int8 && has_zp && is_prefix_cache) { + // assume q_layout is BLHD, q = q * k_scales_inv + ret = api::broadcast_mul( + xpu_ctx, + q_enc_xft_tensor.data(), + quant_k_scale_inv_zp, + q_enc_xft_tensor.data(), + {total_enc_len, kv_num_heads, num_heads / kv_num_heads, head_dim}, + {1, kv_num_heads, 1, head_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed."); + } + } + + if (dec_batch > 0) { + xftblock::Tensor q_dec_xft_tensor( + q_dec_tensor.data(), KV_BUF_TYPE, {total_dec_len, hidden_dim}); + xftblock::Tensor k_dec_xft_tensor(k_dec_tensor.data(), + KV_BUF_TYPE, + {total_dec_len, kv_num_heads * head_dim}); + xftblock::Tensor v_dec_xft_tensor(v_dec_tensor.data(), + KV_BUF_TYPE, + {total_dec_len, kv_num_heads * head_dim}); + + api::VectorParam decoder_context_len_cache_vp = { + const_cast(decoder_context_len_cache_cpu.data()), + dec_batch, + const_cast(decoder_context_len_cache.data())}; + api::VectorParam real_batch_vp = { + const_cast(decoder_batch_map_cpu.data()), + dec_batch, + const_cast(decoder_batch_map.data())}; + api::VectorParam seqlod_vp = { + const_cast(decoder_seq_lod_cpu.data()), + dec_batch + 1, + const_cast(decoder_seq_lod.data())}; + api::VectorParam seqlod_for_fused_vp = { + const_cast(decoder_context_len_cpu.data()), + dec_batch, + const_cast(decoder_context_len.data())}; + + // split, rotary embedding and write to kv cache + if (total_dec_len != dec_batch) { + // mtp branch + split_kvcache_encoder( + xpu_ctx, + xctx, + qkv, + rotary_embs, + q_dec_tensor, + k_dec_tensor, + v_dec_tensor, + key_cache, + value_cache, + block_tables, + slot_mapping_dec, + dec_batch, + total_dec_len, + num_heads, + kv_num_heads, + head_dim, + rope_head_dim, + hidden_dim, + rope_max_seqlen, + block_size, + num_blocks, + block_batch, + max_block_per_seq, + seqlod_vp, // seq_lod + decoder_context_len_cache_vp, // start_tokens (prefix len) + real_batch_vp, // real_batch + total_enc_len * qkv_shape[qkv_shape.size() - 1], + nullptr, // k_cache_scale_inv - use for per head + nullptr, // v_cache_scale_inv - use for per head + quant_k_scale, // intx_k_pc_scale + quant_v_scale, // intx_v_pc_scale + quant_k_zp, // intx_k_pc_zero + quant_v_zp, // intx_v_pc_zero + q_norm_weight_data, + k_norm_weight_data, + pos_emb_type, + rope_3d, + use_neox_rotary_style); + } else { + // non mtp branch + split_kvcache_decoder( + xpu_ctx, + xctx, + qkv, + rotary_embs, + q_dec_tensor, + k_dec_tensor, + v_dec_tensor, + key_cache, + value_cache, + block_tables, + slot_mapping_dec, + dec_batch, + total_dec_len, + num_heads, + kv_num_heads, + head_dim, + rope_head_dim, + hidden_dim, + rope_max_seqlen, + block_size, + num_blocks, + block_batch, + max_block_per_seq, + seqlod_vp, + seqlod_for_fused_vp, + decoder_context_len_cache_vp, + real_batch_vp, + total_enc_len * qkv_shape[qkv_shape.size() - 1], + reinterpret_cast(quant_k_scale), // k_cache_scale_inv + reinterpret_cast(quant_v_scale), // v_cache_scale_inv + reinterpret_cast(quant_k_zp), // k_cache_zp + reinterpret_cast(quant_v_zp), // v_cache_zp + q_norm_weight_data, + k_norm_weight_data, + pos_emb_type, + rope_3d, + is_cache_int8, // bool b_c8_pc + use_neox_rotary_style); + } + + if (is_cache_int8 && has_zp) { + // q = q * k_scales_inv + ret = api::broadcast_mul( + xpu_ctx, + q_dec_xft_tensor.data(), + quant_k_scale_inv_zp, + q_dec_xft_tensor.data(), + {total_dec_len, kv_num_heads, num_heads / kv_num_heads, head_dim}, + {1, kv_num_heads, 1, head_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed."); + } + } + return {q_enc_tensor, k_enc_tensor, v_enc_tensor, q_dec_tensor}; +} + +template +std::vector BlockAttn( + api::Context* xpu_ctx, + xftblock::XFTContext& xctx, + const paddle::Tensor& q_enc_tensor, + const paddle::Tensor& k_enc_tensor, + const paddle::Tensor& v_enc_tensor, + const paddle::Tensor& q_dec_tensor, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& block_tables, + const paddle::Tensor& prefix_block_tables, + const paddle::Tensor& len_info_cpu, + const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, + const paddle::Tensor& encoder_kv_lod_cpu, + const paddle::Tensor& encoder_batch_map_cpu, + const paddle::Tensor& decoder_context_len_cpu, + const paddle::Tensor& decoder_context_len_cache_cpu, + const paddle::Tensor& decoder_batch_map_cpu, + const paddle::Tensor& prefix_len_cpu, + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, + const paddle::Tensor& encoder_kv_lod, + const paddle::Tensor& encoder_batch_map, + const paddle::Tensor& decoder_context_len, + const paddle::Tensor& decoder_context_len_cache, + const paddle::Tensor& decoder_batch_map, + const paddle::Tensor& prefix_len, + const paddle::optional& k_scales_inv, + const paddle::optional& v_scales_inv, + const paddle::optional& k_zeros, + const paddle::optional& v_zeros, + const paddle::optional& shift, + const paddle::optional& smooth) { + xpu::ctx_guard RAII_GUARD(xpu_ctx); + + using XPU_XType = typename XPUTypeTrait::Type; + using XPU_CType = typename XPUTypeTrait::Type; + using XPU_SType = typename XPUTypeTrait::Type; + using E_Scale = typename SplitRopeTypeTrait::E_Scale; + using D_Scale = typename SplitRopeTypeTrait::D_Scale; + typedef TX data_t; + typedef TC cdata_t; + typedef TS sdata_t; + xftblock::DataType KV_BUF_TYPE = std::is_same::value + ? xftblock::DataType::DT_BFLOAT16 + : xftblock::DataType::DT_FLOAT16; + + auto cache_shape = key_cache.dims(); + auto block_table_shape = block_tables.dims(); + const int block_batch = block_table_shape[0]; + const int max_block_per_seq = block_table_shape[1]; + const int kv_num_heads = cache_shape[1]; + const int block_size = cache_shape[2]; + const int head_dim = cache_shape[3]; + const int max_seq_len = block_size * max_block_per_seq; + + const int token_num = q_enc_tensor.dims()[0] + q_dec_tensor.dims()[0]; + const int hidden_dim = q_enc_tensor.dims()[q_enc_tensor.dims().size() - 1]; + const int num_heads = hidden_dim / head_dim; + const int total_num_head = num_heads + 2 * kv_num_heads; + + int enc_batch = len_info_cpu.data()[0]; + int dec_batch = len_info_cpu.data()[1]; + int total_enc_len = len_info_cpu.data()[2]; + int total_dec_len = token_num - total_enc_len; + int max_enc_len = len_info_cpu.data()[3]; + int max_kv_len = len_info_cpu.data()[4]; + int prefix_block_num_per_seq = len_info_cpu.data()[5]; + int max_dec_len = len_info_cpu.data()[6]; + + auto block_attn_out = paddle::empty( + {token_num, hidden_dim}, q_enc_tensor.type(), q_enc_tensor.place()); + + // TODO(lizanz03): only support c8 zp per channel + bool is_cache_int8 = std::is_same::value; + bool has_zp = k_zeros && v_zeros; + XPU_SType *quant_v_scale_inv_zp{nullptr}, *quant_v_zp{nullptr}; + // maxptr for xfa + float *quant_k_scale_inv{nullptr}, *quant_v_scale_inv{nullptr}; + XPU_XType *p_shift{nullptr}, *p_smooth{nullptr}; + if (is_cache_int8) { + if (shift) { + p_shift = reinterpret_cast( + const_cast(shift.get().data())); + } + if (smooth) { + p_smooth = reinterpret_cast( + const_cast(smooth.get().data())); + } + if (has_zp) { + quant_v_scale_inv_zp = reinterpret_cast( + const_cast(v_scales_inv.get().data())); + quant_v_zp = reinterpret_cast( + const_cast(v_zeros.get().data())); + } else { + quant_k_scale_inv = reinterpret_cast( + const_cast(k_scales_inv.get().data())); + quant_v_scale_inv = reinterpret_cast( + const_cast(v_scales_inv.get().data())); + } + } + + int ret; + + if (enc_batch > 0) { + xftblock::TransformerParam param; + xftblock::TransformerVsl vsl; + param.batch_size = enc_batch; + param.head_num = num_heads; + param.kv_head_num = kv_num_heads; + param.head_dim = head_dim; + param.max_batch_size = block_batch; + param.max_seq_len = max_seq_len; + param.use_cache_per_channel = + is_cache_int8 && !has_zp; // only support c8 per channel + + vsl.usual_lod_vp = { + const_cast(encoder_seq_lod_cpu.data()), + enc_batch + 1, + const_cast(encoder_seq_lod.data())}; + vsl.kv_lod_vp = {const_cast(encoder_kv_lod_cpu.data()), + enc_batch + 1, + const_cast(encoder_kv_lod.data())}; + vsl.slot_mapping_vp = { + const_cast(encoder_batch_map_cpu.data()), + enc_batch, + const_cast(encoder_batch_map.data())}; // real batch + param.max_valid_seqlen = max_enc_len; + param.max_kv_valid_seqlen = max_kv_len; + // setting for prefix cache + bool is_prefix_cache = prefix_block_num_per_seq > 0; + param.prefill_len = is_prefix_cache ? param.max_valid_seqlen : -1; + param.page_attn.block_size = block_size; + param.page_attn.max_num_blocks_per_seq = prefix_block_num_per_seq; + // prefix_block_tables is a subset of block_tables, which is used for + // prefix cache + xftblock::Tensor prefix_block_tables_tensor( + is_prefix_cache ? reinterpret_cast(const_cast( + prefix_block_tables.data())) + : nullptr, + xftblock::DataType::DT_INT32, + {prefix_block_tables.dims()[0], prefix_block_num_per_seq}); + param.page_attn.block_table = &prefix_block_tables_tensor; + baidu::xpu::api::VectorParam prefix_lens_vp{ + const_cast(prefix_len_cpu.data()), + enc_batch, + const_cast(prefix_len.data())}; + + float* fake_perhead_scale = nullptr; + if (is_cache_int8 && has_zp && is_prefix_cache) { + fake_perhead_scale = RAII_GUARD.alloc(param.kv_head_num); + // set fake_perhead_scale to ones + ret = api::constant( + xpu_ctx, fake_perhead_scale, param.kv_head_num, 127.f); + PD_CHECK(ret == api::SUCCESS, "api::constant failed."); + } + // buf tensor + xftblock::Tensor q_enc_xft_tensor(const_cast(q_enc_tensor.data()), + KV_BUF_TYPE, + {total_enc_len, hidden_dim}); + xftblock::Tensor k_enc_xft_tensor(const_cast(k_enc_tensor.data()), + KV_BUF_TYPE, + {total_enc_len, kv_num_heads * head_dim}); + xftblock::Tensor v_enc_xft_tensor(const_cast(v_enc_tensor.data()), + KV_BUF_TYPE, + {total_enc_len, kv_num_heads * head_dim}); + + // kv cache tensor + xftblock::Tensor key_cache_tensor( + reinterpret_cast( + const_cast(key_cache.data())), // src_data + nullptr, // max_data + has_zp // pc_scale + ? fake_perhead_scale + : quant_k_scale_inv, + is_cache_int8 // cache type + ? xftblock::DataType::DT_INT8 + : KV_BUF_TYPE, + {cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[3]}); + xftblock::Tensor value_cache_tensor( + reinterpret_cast( + const_cast(value_cache.data())), // src_data + nullptr, // max_data + has_zp // pc_scale + ? fake_perhead_scale + : quant_v_scale_inv, + is_cache_int8 // cache type + ? xftblock::DataType::DT_INT8 + : KV_BUF_TYPE, + {cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[3]}); + + xftblock::Tensor encode_output(reinterpret_cast(const_cast( + block_attn_out.data())), + KV_BUF_TYPE, + {total_enc_len, hidden_dim}); + + // attn encode + if (is_prefix_cache) { + ret = + xftblock::xft_context_core_attenion_block(&xctx, + &q_enc_xft_tensor, + &key_cache_tensor, + &value_cache_tensor, + &encode_output, + param, + vsl); + } else { + ret = xftblock::xft_context_core_attenion_block(&xctx, + &q_enc_xft_tensor, + &k_enc_xft_tensor, + &v_enc_xft_tensor, + &encode_output, + param, + vsl); + } + PD_CHECK(ret == api::SUCCESS, + "xftblock::xft_context_core_attenion_block failed."); + + if (is_cache_int8 && has_zp && is_prefix_cache) { + int64_t q_head_num = param.head_num; + int64_t kv_head_num = param.kv_head_num; + // out = (out - v_zeros) * v_scales_inv + ret = api::broadcast_sub(xpu_ctx, + encode_output.data(), + quant_v_zp, + encode_output.data(), + {total_enc_len, + kv_head_num, + q_head_num / kv_head_num, + param.head_dim}, + {1, kv_head_num, 1, param.head_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_sub failed."); + ret = api::broadcast_mul(xpu_ctx, + encode_output.data(), + quant_v_scale_inv_zp, + encode_output.data(), + {total_enc_len, + kv_head_num, + q_head_num / kv_head_num, + param.head_dim}, + {1, kv_head_num, 1, param.head_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed."); + } + if (p_shift != nullptr) { + ret = api::broadcast_add(xpu_ctx, + p_shift, + encode_output.data(), + encode_output.data(), + {1, hidden_dim}, + {total_enc_len, hidden_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_add for shift failed."); + } + if (p_smooth != nullptr) { + ret = api::broadcast_mul(xpu_ctx, + p_smooth, + encode_output.data(), + encode_output.data(), + {1, hidden_dim}, + {total_enc_len, hidden_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul for smooth failed."); + } + } + + if (dec_batch > 0) { + xftblock::TransformerParam param; + xftblock::TransformerVsl vsl; + param.batch_size = dec_batch; + param.head_num = num_heads; + param.kv_head_num = kv_num_heads; + param.head_dim = head_dim; + param.max_batch_size = block_batch; + param.max_seq_len = max_seq_len; + param.use_page_attn = true; + xftblock::Tensor decode_output( + reinterpret_cast( + const_cast(block_attn_out.data()) + + total_enc_len * hidden_dim), + KV_BUF_TYPE, + {total_dec_len, hidden_dim}); + // buf tensor + xftblock::Tensor q_dec_xft_tensor(const_cast(q_dec_tensor.data()), + KV_BUF_TYPE, + {total_dec_len, hidden_dim}); + + float* fake_perhead_scale = nullptr; + if (is_cache_int8 && has_zp) { + int64_t kv_head_num = param.kv_head_num; + fake_perhead_scale = RAII_GUARD.alloc(kv_head_num); + // set fake_perhead_scale to ones + ret = + api::constant(xpu_ctx, fake_perhead_scale, kv_head_num, 127.f); + PD_CHECK(ret == api::SUCCESS, "api::constant failed."); + } + + if (total_dec_len != dec_batch) { + api::VectorParam decoder_context_len_vp = { + const_cast(decoder_context_len_cpu.data()), + dec_batch, + const_cast( + decoder_context_len + .data())}; // use for speculative_attention_decoder + // seq_len in MTP + api::VectorParam decoder_batch_map_vp = { + const_cast(decoder_batch_map_cpu.data()), + dec_batch, + const_cast( + decoder_batch_map.data())}; // real batch + api::VectorParam decoder_seq_lod_vp = { + const_cast(decoder_seq_lod_cpu.data()), + dec_batch + 1, + const_cast( + decoder_seq_lod + .data())}; // use for split rope enc as lod in MTP + + XPU_XType* q_dec_xft_tensor_ptr = q_dec_xft_tensor.data(); + XPU_XType* decode_output_ptr = decode_output.data(); + using TGEMM = std::conditional_t, + tfloat32, + int8_wo_t>; + constexpr int quant_mode = std::is_same_v ? 3 : 0; + ret = baidu::xpu::xfa::speculative_attention_decoder( + xpu_ctx, + decode_output_ptr, // out + q_dec_xft_tensor_ptr, // q + nullptr, // k + nullptr, // v + reinterpret_cast( + key_cache.data()), // k_cache + reinterpret_cast( + value_cache.data()), // v_cache + reinterpret_cast( + block_tables.data()), // block_tables + decoder_context_len_vp, // seq_lengths + decoder_batch_map_vp, // valid_batch + param.max_batch_size, // batch_num + max_dec_len, // qlen + max_seq_len, // max_seq_len + param.head_num, // head_num + param.head_dim, // head_dim + param.kv_head_num, // kv_head_num + nullptr, // attn_mask + 1.0f / + std::sqrt(static_cast(param.head_dim)), // scale 【check】 + block_size, // block_size + max_block_per_seq, // max_blocks_per_seq + -1, // max_window_size + nullptr, // q_maxptr + has_zp // k_cache_maxptr + ? fake_perhead_scale + : quant_k_scale_inv, + has_zp // v_cache_maxptr + ? fake_perhead_scale + : quant_v_scale_inv, + nullptr, // o_maxptr + param.head_dim, // vo_head_dim + decoder_seq_lod_vp); // qlod + PD_CHECK(ret == api::SUCCESS, + "xfa::speculative_attention_decoder failed."); + } else { + vsl.usual_lod_vp = { + const_cast(decoder_context_len_cpu.data()), + dec_batch, + const_cast(decoder_context_len.data())}; + vsl.slot_mapping_vp = { + const_cast(decoder_batch_map_cpu.data()), + dec_batch, + const_cast( + decoder_batch_map.data())}; // real batch + // can not set to nullptr and 0, which will cause inference interrupt + // vsl.slot_mapping_vp = {nullptr, 0, nullptr}; // real batch + + xftblock::Tensor block_table_tensor( + reinterpret_cast( + const_cast(block_tables.data())), + xftblock::DataType::DT_INT32, + {block_table_shape[0], block_table_shape[1]}); + + // normal setting + param.use_cache_per_channel = + is_cache_int8 && !has_zp; // only support c8 per channel + param.prefill_len = -1; + param.page_attn.block_size = block_size; + param.page_attn.max_context_len = max_seq_len; + param.page_attn.max_num_blocks_per_seq = max_block_per_seq; + param.page_attn.block_table = &block_table_tensor; + + // kv cache tensor + xftblock::Tensor key_cache_tensor( + reinterpret_cast( + const_cast(key_cache.data())), // src_data + nullptr, // max_data + has_zp // pc_scale + ? fake_perhead_scale + : quant_k_scale_inv, + is_cache_int8 // cache type + ? xftblock::DataType::DT_INT8 + : KV_BUF_TYPE, + {cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[3]}); + xftblock::Tensor value_cache_tensor( + reinterpret_cast( + const_cast(value_cache.data())), // src_data + nullptr, // max_data + has_zp // pc_scale + ? fake_perhead_scale + : quant_v_scale_inv, + is_cache_int8 // cache type + ? xftblock::DataType::DT_INT8 + : KV_BUF_TYPE, + {cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[3]}); + + // attn decode + ret = xftblock::xft_decoder_core_attenion_block< + XPU_XType, + XPU_CType, + XPU_XType>( // TGEMM = XPU_XType TODOlizan03: used high + // precision + &xctx, + &q_dec_xft_tensor, + &key_cache_tensor, + &value_cache_tensor, + &decode_output, + param, + vsl); + PD_CHECK(ret == api::SUCCESS, + "xftblock::xft_decoder_core_attenion_block failed."); + } + if (is_cache_int8 && has_zp) { + int64_t q_head_num = param.head_num; + int64_t kv_head_num = param.kv_head_num; + // out = (out - v_zeros) * v_scales_inv + if (quant_v_zp) { + ret = + api::broadcast_sub(xpu_ctx, + decode_output.data(), + quant_v_zp, + decode_output.data(), + {total_dec_len, + kv_head_num, + q_head_num / kv_head_num, + param.head_dim}, + {1, kv_head_num, 1, param.head_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_sub failed."); + } + ret = api::broadcast_mul(xpu_ctx, + decode_output.data(), + quant_v_scale_inv_zp, + decode_output.data(), + {total_dec_len, + kv_head_num, + q_head_num / kv_head_num, + param.head_dim}, + {1, kv_head_num, 1, param.head_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed."); + } + if (p_shift != nullptr) { + ret = api::broadcast_add(xpu_ctx, + p_shift, + decode_output.data(), + decode_output.data(), + {1, hidden_dim}, + {total_dec_len, hidden_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_add for shift failed."); + } + if (p_smooth != nullptr) { + ret = api::broadcast_mul(xpu_ctx, + p_smooth, + decode_output.data(), + decode_output.data(), + {1, hidden_dim}, + {total_dec_len, hidden_dim}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul for smooth failed."); + } + } + return {block_attn_out}; +} + +std::vector SplitEmbeddingKVCacheBlockAttn( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& block_tables, + const paddle::Tensor& prefix_block_tables, + const paddle::Tensor& len_info_cpu, + const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, + const paddle::Tensor& encoder_kv_lod_cpu, + const paddle::Tensor& encoder_batch_map_cpu, + const paddle::Tensor& decoder_context_len_cpu, + const paddle::Tensor& decoder_context_len_cache_cpu, + const paddle::Tensor& decoder_batch_map_cpu, + const paddle::Tensor& prefix_len_cpu, + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, + const paddle::Tensor& encoder_kv_lod, + const paddle::Tensor& encoder_batch_map, + const paddle::Tensor& decoder_context_len, + const paddle::Tensor& decoder_context_len_cache, + const paddle::Tensor& decoder_batch_map, + const paddle::Tensor& prefix_len, + const paddle::Tensor& slot_mapping_enc, + const paddle::Tensor& slot_mapping_dec, + const paddle::optional& k_scales, + const paddle::optional& v_scales, + const paddle::optional& k_scales_inv, + const paddle::optional& v_scales_inv, + const paddle::optional& k_zeros, + const paddle::optional& v_zeros, + const paddle::optional& shift, + const paddle::optional& smooth, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& kv_signal_data_cpu, + const paddle::optional& cachekv_signal_thread_cpu, + const bool use_neox_rotary_style, + const bool rope_3d = false) { +#define APPLY_SPLITKVCACHE(TX, TC, TS) \ + std::vector split_qkv = \ + SplitEmbeddingKVCache(xpu_ctx, \ + xctx, \ + qkv, \ + key_cache, \ + value_cache, \ + rotary_embs, \ + block_tables, \ + len_info_cpu, \ + encoder_seq_lod_cpu, \ + decoder_seq_lod_cpu, \ + encoder_kv_lod_cpu, \ + encoder_batch_map_cpu, \ + decoder_context_len_cpu, \ + decoder_context_len_cache_cpu, \ + decoder_batch_map_cpu, \ + prefix_len_cpu, \ + encoder_seq_lod, \ + decoder_seq_lod, \ + encoder_kv_lod, \ + encoder_batch_map, \ + decoder_context_len, \ + decoder_context_len_cache, \ + decoder_batch_map, \ + prefix_len, \ + slot_mapping_enc, \ + slot_mapping_dec, \ + k_scales, \ + v_scales, \ + k_scales_inv, \ + k_zeros, \ + v_zeros, \ + q_norm_weight, \ + k_norm_weight, \ + kv_signal_data_cpu, \ + cachekv_signal_thread_cpu, \ + use_neox_rotary_style, \ + rope_3d); +#define APPLY_BLOCKATTN(TX, TC, TS) \ + return BlockAttn(xpu_ctx, \ + xctx, \ + split_qkv[0], \ + split_qkv[1], \ + split_qkv[2], \ + split_qkv[3], \ + key_cache, \ + value_cache, \ + rotary_embs, \ + block_tables, \ + prefix_block_tables, \ + len_info_cpu, \ + encoder_seq_lod_cpu, \ + decoder_seq_lod_cpu, \ + encoder_kv_lod_cpu, \ + encoder_batch_map_cpu, \ + decoder_context_len_cpu, \ + decoder_context_len_cache_cpu, \ + decoder_batch_map_cpu, \ + prefix_len_cpu, \ + encoder_seq_lod, \ + decoder_seq_lod, \ + encoder_kv_lod, \ + encoder_batch_map, \ + decoder_context_len, \ + decoder_context_len_cache, \ + decoder_batch_map, \ + prefix_len, \ + k_scales_inv, \ + v_scales_inv, \ + k_zeros, \ + v_zeros, \ + shift, \ + smooth); + + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx)->x_context(); + xftblock::XFTContext xctx(xpu_ctx, nullptr); + const auto cache_dtype = key_cache.dtype(); + + if (cache_dtype == paddle::DataType::BFLOAT16) { + APPLY_SPLITKVCACHE(paddle::bfloat16, paddle::bfloat16, paddle::bfloat16); + APPLY_BLOCKATTN(paddle::bfloat16, paddle::bfloat16, paddle::bfloat16); + } else if (cache_dtype == paddle::DataType::INT8) { + APPLY_SPLITKVCACHE(paddle::bfloat16, int8_t, paddle::bfloat16); + APPLY_BLOCKATTN(paddle::bfloat16, int8_t, paddle::bfloat16); + } else { + PD_THROW("block_attn not support cache_dtype==%d", + static_cast(cache_dtype)); + return {}; + } + +#undef APPLY_SPLITKVCACHE +#undef APPLY_BLOCKATTN +} + +std::vector> SplitEmbeddingKVCacheBlockAttnInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape) { + const int token_num = qkv_shape[0]; + const int kv_num_heads = key_cache_shape[1]; + int head_dim = key_cache_shape[3]; + // if (cache_quant_type_str == "cache_int4_zp") { + // head_dim *= 2; + // } + const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim; + const int num_heads = total_num_head - 2 * kv_num_heads; + return {{token_num, num_heads * head_dim}}; +} + +std::vector SplitEmbeddingKVCacheBlockAttnInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype) { + return {qkv_dtype}; +} + +PD_BUILD_STATIC_OP(block_attn) + .Inputs({"qkv", + "key_cache", + "value_cache", + "rotary_embs", + "block_tables", + "prefix_block_tables", + "len_info_cpu", + "encoder_seq_lod_cpu", + "decoder_seq_lod_cpu", + "encoder_kv_lod_cpu", + "encoder_batch_map_cpu", + "decoder_context_len_cpu", + "decoder_context_len_cache_cpu", + "decoder_batch_map_cpu", + "prefix_len_cpu", + "encoder_seq_lod", + "decoder_seq_lod", + "encoder_kv_lod", + "encoder_batch_map", + "decoder_context_len", + "decoder_context_len_cache", + "decoder_batch_map", + "prefix_len", + "slot_mapping_enc", + "slot_mapping_dec", + paddle::Optional("k_scales"), + paddle::Optional("v_scales"), + paddle::Optional("k_scales_inv"), + paddle::Optional("v_scales_inv"), + paddle::Optional("k_zeros"), + paddle::Optional("v_zeros"), + paddle::Optional("shift"), + paddle::Optional("smooth"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight"), + paddle::Optional("kv_signal_data_cpu"), + paddle::Optional("cachekv_signal_thread_cpu")}) + .Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"}) + .Outputs({"block_attn_out"}) + .SetKernelFn(PD_KERNEL(SplitEmbeddingKVCacheBlockAttn)) + .SetInferShapeFn(PD_INFER_SHAPE(SplitEmbeddingKVCacheBlockAttnInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SplitEmbeddingKVCacheBlockAttnInferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/get_infer_param.cc b/custom_ops/xpu_ops/src/ops/get_infer_param.cc index 5a923f0f7e..28a91f81a1 100644 --- a/custom_ops/xpu_ops/src/ops/get_infer_param.cc +++ b/custom_ops/xpu_ops/src/ops/get_infer_param.cc @@ -16,14 +16,60 @@ #include "paddle/extension.h" #include "xpu/internal/infra_op.h" #include "xpu/plugin.h" +#include "ops/utility/env.h" + +XPU_DECLARE_BOOL(encoder_splice, false); +XPU_DECLARE_BOOL(decoder_splice, false); + namespace api = baidu::xpu::api; +void lod_to_slot_mapping(api::Context* xpu_ctx, + paddle::Place place, + const std::vector& block_table, + const std::vector& kv_seq_lod, + const std::vector& start_tokens, + const std::vector& real_batch, + int32_t* slot_mapping, + int32_t token_num, + int32_t block_size, + int32_t batch_size, + int32_t max_num_blocks_per_seq, + int32_t num_speculative_tokens) { + if (token_num <= 0) { + return; + } + std::vector slot_mapping_vec(token_num, -1); + int32_t idx = 0; + // For each Batch + for (auto batch_ = 0; batch_ < batch_size; batch_++) { + int32_t seq_len = kv_seq_lod[batch_ + 1] - kv_seq_lod[batch_]; + int32_t seq_start = start_tokens[batch_]; + int32_t dst_batch_id = real_batch[batch_]; + // for each token + for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) { + int32_t table_id = seq_ / block_size; + int32_t block_id = + block_table[dst_batch_id * max_num_blocks_per_seq + table_id]; + int32_t seq_offset = seq_ % block_size; + int32_t dst_token_offset = block_id * block_size + seq_offset; + slot_mapping_vec[idx] = dst_token_offset; + idx++; + } + } + int ret = api::do_host2device(xpu_ctx, + slot_mapping_vec.data(), + slot_mapping, + token_num * sizeof(int32_t)); + PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed."); +} + std::vector GetInferParam( const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& block_tables, - int block_size) { + int block_size, + int num_speculative_tokens) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); auto xpu_ctx = static_cast(dev_ctx); @@ -109,6 +155,15 @@ std::vector GetInferParam( batch_offset++; } } + // for vsl_rotary_embedding_gptj of cudagraph mode + int prev_val = 0; + for (int i = 0; i < bsz; i++) { + if (decoder_seq_lod_vec[i] > prev_val) { + prev_val = decoder_seq_lod_vec[i]; + } else if (decoder_seq_lod_vec[i] < prev_val) { + decoder_seq_lod_vec[i] = prev_val; + } + } int prefix_block_num_per_seq = (max_kv_len + block_size - 1) / block_size; std::vector prefix_block_tables_vec( enc_batch * prefix_block_num_per_seq, -1); @@ -167,6 +222,52 @@ std::vector GetInferParam( seq_lens_encoder.type(), seq_lens_encoder.place()); + // for store_paged_kv_cache of cudagraph mode + // if slot_mapping is -1, store_paged_kv_cache will not write to kv cache + paddle::Tensor slot_mapping_enc = paddle::full( + {total_enc_len}, -1, paddle::DataType::INT32, seq_lens_encoder.place()); + // TODO: mtp mode not verified yet, need further adaption + paddle::Tensor slot_mapping_dec = + paddle::full({bsz * (1 + num_speculative_tokens)}, + -1, + paddle::DataType::INT32, + seq_lens_decoder.place()); + if (FLAGS_encoder_splice || FLAGS_decoder_splice) { + std::vector block_tables_vec(block_bs * block_num_per_seq); + r = xpu_memcpy(block_tables_vec.data(), + block_tables.data(), + sizeof(int32_t) * block_bs * block_num_per_seq, + XPUMemcpyKind::XPU_DEVICE_TO_HOST); + if (FLAGS_encoder_splice) { + lod_to_slot_mapping(xpu_ctx->x_context(), + seq_lens_encoder.place(), + block_tables_vec, + encoder_seq_lod_vec, + prefix_len_vec, + encoder_batch_map_vec, + slot_mapping_enc.data(), + total_enc_len, + block_size, + enc_batch, + block_num_per_seq, + 0); + } + if (FLAGS_decoder_splice) { + lod_to_slot_mapping(xpu_ctx->x_context(), + seq_lens_decoder.place(), + block_tables_vec, + decoder_seq_lod_vec, + decoder_context_len_cache_vec, + decoder_batch_map_vec, + slot_mapping_dec.data(), + bsz * (1 + num_speculative_tokens), + block_size, + dec_batch, + block_num_per_seq, + num_speculative_tokens); + } + } + auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace()); @@ -326,7 +427,9 @@ std::vector GetInferParam( prefix_len_cpu, decoder_context_len_cpu, decoder_context_len_cache_cpu, - len_info_cpu}; + len_info_cpu, + slot_mapping_enc, + slot_mapping_dec}; } std::vector> GetInferParamInferShape( @@ -400,8 +503,10 @@ PD_BUILD_OP(get_infer_param) "prefix_len_cpu", "decoder_context_len_cpu", "decoder_context_len_cache_cpu", - "len_info_cpu"}) + "len_info_cpu", + "slot_mapping_enc", + "slot_mapping_dec"}) .SetKernelFn(PD_KERNEL(GetInferParam)) - .Attrs({"block_size: int"}) + .Attrs({"block_size: int", "num_speculative_tokens: int"}) .SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetInferParamInferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index b73a3d8f15..e8a66c750a 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -57,7 +57,7 @@ void GetOutputKVSignal(const paddle::Tensor& x, int64_t rank_id, bool wait_flag); -std::vector BlockAttn( +std::vector SplitEmbeddingKVCacheBlockAttn( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, @@ -81,6 +81,50 @@ std::vector BlockAttn( const paddle::Tensor& decoder_context_len_cache_xpu, const paddle::Tensor& decoder_batch_map_xpu, const paddle::Tensor& prefix_len_xpu, + const paddle::Tensor& slot_mapping_enc, + const paddle::Tensor& slot_mapping_dec, + const paddle::optional& k_scales, + const paddle::optional& v_scales, + const paddle::optional& k_scales_inv, + const paddle::optional& v_scales_inv, + const paddle::optional& k_zeros, + const paddle::optional& v_zeros, + const paddle::optional& shift, + const paddle::optional& smooth, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& kv_signal_data_cpu, + const paddle::optional& cachekv_signal_thread_cpu, + const bool use_neox_rotary_style, + const bool rope_3d = false); + +// deprecated, keep for unit test, will be removed in the future +std::vector BlockAttnFused( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& block_tables, + const paddle::Tensor& prefix_block_tables, + const paddle::Tensor& len_info_cpu, + const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, + const paddle::Tensor& encoder_kv_lod_cpu, + const paddle::Tensor& encoder_batch_map_cpu, + const paddle::Tensor& decoder_context_len_cpu, + const paddle::Tensor& decoder_context_len_cache_cpu, + const paddle::Tensor& decoder_batch_map_cpu, + const paddle::Tensor& prefix_len_cpu, + const paddle::Tensor& encoder_seq_lod_xpu, + const paddle::Tensor& decoder_seq_lod_xpu, + const paddle::Tensor& encoder_kv_lod_xpu, + const paddle::Tensor& encoder_batch_map_xpu, + const paddle::Tensor& decoder_context_len_xpu, + const paddle::Tensor& decoder_context_len_cache_xpu, + const paddle::Tensor& decoder_batch_map_xpu, + const paddle::Tensor& prefix_len_xpu, + const paddle::Tensor& slot_mapping_enc, + const paddle::Tensor& slot_mapping_dec, const paddle::optional& k_scales, const paddle::optional& v_scales, const paddle::optional& k_scales_inv, @@ -434,7 +478,8 @@ std::vector GetInferParam( const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& block_tables, - int block_size); + int block_size, + int num_speculative_tokens); void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag); @@ -749,7 +794,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { "adjust batch in XPU"); m.def("block_attn", - &BlockAttn, + &SplitEmbeddingKVCacheBlockAttn, py::arg("qkv"), py::arg("key_cache"), py::arg("value_cache"), @@ -773,6 +818,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("decoder_context_len_cache_xpu"), py::arg("decoder_batch_map_xpu"), py::arg("prefix_len_xpu"), + py::arg("slot_mapping_enc"), + py::arg("slot_mapping_dec"), py::arg("k_scales"), py::arg("v_scales"), py::arg("k_scales_inv"), @@ -789,6 +836,49 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("rope_3d") = false, "block attention in XPU"); + m.def("block_attn_fused", + &BlockAttnFused, + py::arg("qkv"), + py::arg("key_cache"), + py::arg("value_cache"), + py::arg("rotary_embs"), + py::arg("block_tables"), + py::arg("prefix_block_tables"), + py::arg("len_info_cpu"), + py::arg("encoder_seq_lod_cpu"), + py::arg("decoder_seq_lod_cpu"), + py::arg("encoder_kv_lod_cpu"), + py::arg("encoder_batch_map_cpu"), + py::arg("decoder_context_len_cpu"), + py::arg("decoder_context_len_cache_cpu"), + py::arg("decoder_batch_map_cpu"), + py::arg("prefix_len_cpu"), + py::arg("encoder_seq_lod_xpu"), + py::arg("decoder_seq_lod_xpu"), + py::arg("encoder_kv_lod_xpu"), + py::arg("encoder_batch_map_xpu"), + py::arg("decoder_context_len_xpu"), + py::arg("decoder_context_len_cache_xpu"), + py::arg("decoder_batch_map_xpu"), + py::arg("prefix_len_xpu"), + py::arg("slot_mapping_enc"), + py::arg("slot_mapping_dec"), + py::arg("k_scales"), + py::arg("v_scales"), + py::arg("k_scales_inv"), + py::arg("v_scales_inv"), + py::arg("k_zeros"), + py::arg("v_zeros"), + py::arg("shift"), + py::arg("smooth"), + py::arg("q_norm_weight"), + py::arg("k_norm_weight"), + py::arg("kv_signal_data_cpu"), + py::arg("cachekv_signal_thread_cpu"), + py::arg("use_neox_rotary_style"), + py::arg("rope_3d") = false, + "block attention fused in XPU"); + m.def("create_kv_signal_sender", &create_cachekv_signal_thread, "init write cache kv signal thread"); @@ -963,6 +1053,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_this_time"), py::arg("block_tables"), py::arg("block_size"), + py::arg("num_speculative_tokens"), "Get infer parameters for block attention in XPU"); m.def("get_peer_mem_addr", diff --git a/custom_ops/xpu_ops/test/test_block_attn.py b/custom_ops/xpu_ops/test/test_block_attn.py new file mode 100644 index 0000000000..51b6a7b9db --- /dev/null +++ b/custom_ops/xpu_ops/test/test_block_attn.py @@ -0,0 +1,653 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import paddle + +# block_attn_fused is deprecated and should be removed in the future +from fastdeploy.model_executor.ops.xpu import ( + block_attn, + block_attn_fused, + get_infer_param, +) + + +def print_all_not_equal_elements_info(k, x, y): + x_flatten = x.flatten() + y_flatten = y.flatten() + index = paddle.nonzero(x_flatten != y_flatten) + x_not_equal = x_flatten[index] + y_not_equal = y_flatten[index] + print(f"reference not equal element of {k}: {x_not_equal}") + print(f"calculated result not equal element of {k}: {y_not_equal}") + xy_diff = x - y + xy_mean_diff = paddle.mean(xy_diff) + xy_max_abs_diff = paddle.max(paddle.abs(xy_diff)) + xy_min_abs_diff = paddle.min(paddle.abs(xy_diff)) + print(f"{k} mean diff: {xy_mean_diff}, max abs diff: {xy_max_abs_diff}, min abs diff: {xy_min_abs_diff}") + + +def run_prefix_cache_block_attn( + block_attn_func, + qkv, + seq_len, + seq_lens_this_time, + hit_prefix_len, + key_cache, + value_cache, + rotary_embs, + block_tables, + attn_out, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + k_zp, + v_zp, + shift, + smooth, + q_norm_weight, + k_norm_weight, + kv_signal_data_cpu, + cachekv_signal_thread_cpu, + use_neox_rotary_style, + rope_3d, + num_speculative_tokens, +): + if key_cache.dtype == paddle.int8: + rtol = 1e-1 + atol = 1e-2 + else: + rtol = 1e-2 + atol = 1e-3 + # prefix cache block attn + seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32") + seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32") + ( + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + slot_mapping_enc, + slot_mapping_dec, + ) = get_infer_param( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens + ) # block_size + qkv_prefix = qkv[hit_prefix_len:] + attn_out_prefix_cache = block_attn_func( + qkv_prefix, + key_cache, + value_cache, + rotary_embs, + block_tables, + prefix_block_tables, + len_info_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + encoder_batch_map_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + decoder_batch_map_cpu, + prefix_len_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + encoder_batch_map, + decoder_context_len, + decoder_context_len_cache, + decoder_batch_map, + prefix_len, + slot_mapping_enc, + slot_mapping_dec, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + k_zp, + v_zp, + shift, + smooth, + q_norm_weight, + k_norm_weight, + kv_signal_data_cpu, + cachekv_signal_thread_cpu, + use_neox_rotary_style, + rope_3d, + ) + attn_out_np = attn_out[hit_prefix_len:].astype("float32").numpy() + attn_out_prefix_cache_np = attn_out_prefix_cache.astype("float32").numpy() + is_passed = np.allclose(attn_out_np, attn_out_prefix_cache_np, rtol=rtol, atol=atol) + if not is_passed: + print(f"block_attn_func: {block_attn_func}") + print("prefix_cache block_attn check failed!") + print(f"origin block_attn_out: {attn_out[hit_prefix_len:]}") + print(f"prefix_cache block_attn_out: {attn_out_prefix_cache}") + print("not equal elements are listed below:") + print_all_not_equal_elements_info("block_attn_out", attn_out[hit_prefix_len:], attn_out_prefix_cache) + else: + print(f"prefix_cache check of {block_attn_func} PASSED!") + assert is_passed + return attn_out_prefix_cache + + +def run_block_attn( + seed, + is_fused, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + mode, # 1 for split kvcache encoder only, 2 for split kvcache decoder only, 3 for mixed + hit_prefix_len, + kvcache_dtype, + has_zp, + use_neox_rotary_style, + rotary_embs_shape, + num_speculative_tokens, +): + assert mode == 0 or mode == 1, "mixed mode not supported yet!" + if mode == 0: + encoder_seq_len = seq_len + decoder_seq_len = 0 + elif mode == 1: + encoder_seq_len = 0 + decoder_seq_len = seq_len + else: + pass + seq_lens_encoder = paddle.to_tensor([encoder_seq_len, 0, 0, 0, 0], dtype="int32") + seq_lens_decoder = paddle.to_tensor([decoder_seq_len, 0, 0, 0, 0], dtype="int32") + seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32") + block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32") + block_tables = block_tables.reshape((block_batch, max_block_per_seq)) + ( + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + slot_mapping_enc, + slot_mapping_dec, + ) = get_infer_param( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens + ) + qkv = paddle.uniform( + shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed + ) + + rotary_embs = paddle.uniform(shape=rotary_embs_shape, dtype="float32", min=-1.0, max=1.0, seed=seed) + key_cache = paddle.zeros( + shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim], + dtype=kvcache_dtype, + ) + value_cache = paddle.zeros( + shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim], + dtype=kvcache_dtype, + ) + + scale_tensor_k = None + scale_tensor_v = None + k_quant_scale = None + v_quant_scale = None + k_dequant_scale = None + v_dequant_scale = None + k_zp = None + v_zp = None + if kvcache_dtype == "int8": + scale_tensor_k = paddle.uniform( + shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed + ) # max + scale_tensor_v = paddle.uniform( + shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed + ) # max + k_quant_scale = 127.0 / scale_tensor_k # for C8 per channel means 127 / max + v_quant_scale = 127.0 / scale_tensor_v # for C8 per channel means 127 / max + if has_zp: + k_dequant_scale = 1 / k_quant_scale # for C8 per channel zp means max + v_dequant_scale = 1 / v_quant_scale # for C8 per channel zp means max + k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16") + v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16") + else: + k_dequant_scale = paddle.cast(scale_tensor_k, dtype="float32") # for C8 per channel means max + v_dequant_scale = paddle.cast(scale_tensor_v, dtype="float32") # for C8 per channel means max + # variable below are not yet used + shift = None + smooth = None + q_norm_weight = None + k_norm_weight = None + kv_signal_data_cpu = None + cachekv_signal_thread_cpu = None + rope_3d = False + + if is_fused: + block_attn_func = block_attn_fused + else: + block_attn_func = block_attn + attn_out = block_attn_func( + qkv, + key_cache, + value_cache, + rotary_embs, + block_tables, + prefix_block_tables, + len_info_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + encoder_batch_map_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + decoder_batch_map_cpu, + prefix_len_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + encoder_batch_map, + decoder_context_len, + decoder_context_len_cache, + decoder_batch_map, + prefix_len, + slot_mapping_enc, + slot_mapping_dec, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + k_zp, + v_zp, + shift, + smooth, + q_norm_weight, + k_norm_weight, + kv_signal_data_cpu, + cachekv_signal_thread_cpu, + use_neox_rotary_style, + rope_3d, + ) + result = { + "block_attn_out": attn_out, + "key_cache": key_cache, + "value_cache": value_cache, + } + + # prefix cache + if mode == 0 and hit_prefix_len > 0: + assert hit_prefix_len < seq_len + attn_out_prefix_cache = run_prefix_cache_block_attn( + block_attn_func, + qkv, + seq_len, + seq_lens_this_time, + hit_prefix_len, + key_cache, + value_cache, + rotary_embs, + block_tables, + attn_out, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + k_zp, + v_zp, + shift, + smooth, + q_norm_weight, + k_norm_weight, + kv_signal_data_cpu, + cachekv_signal_thread_cpu, + use_neox_rotary_style, + rope_3d, + num_speculative_tokens, + ) + result["prefix_cache_block_attn_out"] = attn_out_prefix_cache + return result + + +def run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + hit_prefix_len=0, + kvcache_dtype="bfloat16", + has_zp=False, + use_neox_rotary_style=False, + only_run_spliced=False, + num_speculative_tokens=0, +): + rtol = 1e-3 + atol = 1e-2 + # 0 for prefill only, 1 for decode only, 2 for mixed + # TODO: mixed mode not supported yet, get_infer_param should be modified first + mode_name = ["prefill only", "decode only", "mixed"] + + if use_neox_rotary_style: + embedding_type = "neox" + else: + embedding_type = "rope" + for mode in [0, 1]: + if mode == 0: + seq_len_list = [seq_len] + elif mode == 1: + # seq_len > 1 goes into mtp branch, which only supports seq_len <= 31 + # TODO: mtp mode need further adaption + # seq_len_list = [1, random.randint(2, 31)] + seq_len_list = [1] + for idx, seqlen in enumerate(seq_len_list): + if idx == 0: + branch_name = "non mtp branch" + elif idx == 1: + branch_name = "mtp branch" + print( + f"runnning block attention of mode {mode_name[mode]} ({branch_name}), is_prefix_cache: {hit_prefix_len > 0}, kvcache type: {kvcache_dtype}, has_zp: {has_zp}, rotary_style: {embedding_type}" + ) + if not only_run_spliced: + fused_result = run_block_attn( + seed, + True, # is_fused + head_num, + kv_head_num, + head_dim, + seqlen, + block_batch, + max_block_per_seq, + block_size, + mode, + hit_prefix_len, + kvcache_dtype, + has_zp, + use_neox_rotary_style, + rotary_embs_shape, + num_speculative_tokens, + ) + spliced_result = run_block_attn( + seed, + False, # is_fused + head_num, + kv_head_num, + head_dim, + seqlen, + block_batch, + max_block_per_seq, + block_size, + mode, + hit_prefix_len, + kvcache_dtype, + has_zp, + use_neox_rotary_style, + rotary_embs_shape, + num_speculative_tokens, + ) + if "fused_result" in locals() and "spliced_result" in locals(): + for k in fused_result.keys(): + if paddle.is_integer(fused_result[k]): + fused_v = fused_result[k].astype("int32") + spliced_v = spliced_result[k].astype("int32") + fused_v_np = fused_v.numpy() + splice_v_np = spliced_v.numpy() + # is_passed = np.allclose(fused_v_np, splice_v_np, rtol=1e-1, atol=1e-1) + is_passed = np.allclose(fused_v_np, splice_v_np, rtol=1e-2, atol=rtol) + else: + fused_v = fused_result[k].astype("float32") + spliced_v = spliced_result[k].astype("float32") + fused_v_np = fused_v.numpy() + splice_v_np = spliced_v.numpy() + is_passed = np.allclose(fused_v_np, splice_v_np, rtol=rtol, atol=atol, equal_nan=True) + if not is_passed: + print(f"{k} in mode {mode_name[mode]} check FAILED!") + print(f"fused {k}: {fused_v}") + print(f"spliced {k}: {spliced_v}") + print("not equal elements are listed below:") + print_all_not_equal_elements_info(k, fused_v, spliced_v) + else: + print(f"{k} in mode {mode_name[mode]} check PASSED!") + assert is_passed + print("") + else: + if "fused_result" not in locals(): + print("fused_result not found.") + if "spliced_result" not in locals(): + print("spliced_result not found.") + print("skip comparison.") + + +seed = random.randint(0, 2026) +paddle.seed(seed) +head_num = 64 +kv_head_num = 8 +head_dim = 128 +rotary_embs_shape = [2, 1, 8192, 1, head_dim] +seq_len = 128 +block_batch = 5 +max_block_per_seq = 128 +block_size = 64 +# TODO: if hit_prefix_len has a small value, e.g. hit_prefix_len == 2, block_attn_out and prefix_cache_block_attn_out will have greater diff +hit_prefix_len = 71 + +# no prefix cache +# block_attn fused vs spliced +use_neox_rotary_style = False +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + 0, + kvcache_dtype="bfloat16", + has_zp=False, + use_neox_rotary_style=use_neox_rotary_style, +) +# c8 quantization block_attn fused vs spliced +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + 0, + kvcache_dtype="int8", + has_zp=False, + use_neox_rotary_style=use_neox_rotary_style, +) +# c8 zp quantization block_attn fused vs spliced +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + 0, + kvcache_dtype="int8", + has_zp=True, + use_neox_rotary_style=use_neox_rotary_style, +) + +# prefix cache +# block_attn fused vs spliced +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + hit_prefix_len, + kvcache_dtype="bfloat16", + has_zp=False, + use_neox_rotary_style=use_neox_rotary_style, +) +# c8 quantization block_attn fused vs spliced +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + hit_prefix_len, + kvcache_dtype="int8", + has_zp=False, + use_neox_rotary_style=use_neox_rotary_style, +) +# c8 zp quantization block_attn fused vs spliced +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + hit_prefix_len, + kvcache_dtype="int8", + has_zp=True, + use_neox_rotary_style=use_neox_rotary_style, +) + +# # neox +# # block_attn fused vs spliced +# # no prefix cache +use_neox_rotary_style = True +only_run_spliced = False +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + 0, + kvcache_dtype="bfloat16", + has_zp=False, + use_neox_rotary_style=use_neox_rotary_style, + only_run_spliced=only_run_spliced, +) +# prefix cache +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + hit_prefix_len, + kvcache_dtype="bfloat16", + has_zp=False, + use_neox_rotary_style=use_neox_rotary_style, + only_run_spliced=only_run_spliced, +) + +# neox glm 4.5 air debug +head_num = 24 +kv_head_num = 2 +head_dim = 128 +seq_len = 128 +block_batch = 64 +max_block_per_seq = 2050 +block_size = 64 +rotary_embs_shape = [2, 1, 131072, 1, head_dim // 2] + +use_neox_rotary_style = True +only_run_spliced = False +run_compare_block_attn( + seed, + head_num, + kv_head_num, + head_dim, + seq_len, + block_batch, + max_block_per_seq, + block_size, + rotary_embs_shape, + 0, + kvcache_dtype="bfloat16", + has_zp=False, + use_neox_rotary_style=use_neox_rotary_style, + only_run_spliced=only_run_spliced, +) + +print("\nALL PASSED!") diff --git a/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py b/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py index 000c0a3592..c15a5f4e4a 100644 --- a/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py +++ b/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py @@ -15,7 +15,7 @@ import numpy as np import paddle -from fastdeploy.model_executor.ops.xpu import block_attn, get_infer_param +from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param head_num = 64 kv_head_num = 8 @@ -53,8 +53,10 @@ block_tables = block_tables.reshape((block_batch, max_block_per_seq)) decoder_context_len_cpu, decoder_context_len_cache_cpu, len_info_cpu, + slot_mapping_enc, + slot_mapping_dec, ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64 + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0 ) # block_size qkv = paddle.uniform( @@ -64,7 +66,6 @@ qkv = paddle.uniform( max=1.0, ) -cum_offsets = paddle.zeros(shape=[block_batch], dtype="bfloat16") rotary_embs = paddle.uniform(shape=[2, 1, 8192, 1, head_dim], dtype="float32", min=-1.0, max=1.0) key_cache = paddle.zeros( shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim], @@ -94,11 +95,10 @@ v_dequant_scale_zp = 1 / v_quant_scale # for C8 per channel zp means max k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16") v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16") -attn_out = block_attn( +attn_out = block_attn_fused( qkv, key_cache, value_cache, - cum_offsets, rotary_embs, block_tables, prefix_block_tables, @@ -111,6 +111,16 @@ attn_out = block_attn( decoder_context_len_cache_cpu, decoder_batch_map_cpu, prefix_len_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + encoder_batch_map, + decoder_context_len, + decoder_context_len_cache, + decoder_batch_map, + prefix_len, + slot_mapping_enc, + slot_mapping_dec, None, None, None, @@ -121,12 +131,15 @@ attn_out = block_attn( None, None, None, + None, + None, + False, + False, ) -attn_out_C8 = block_attn( +attn_out_C8 = block_attn_fused( qkv, key_cache_int8, value_cache_int8, - cum_offsets, rotary_embs, block_tables, prefix_block_tables, @@ -139,6 +152,16 @@ attn_out_C8 = block_attn( decoder_context_len_cache_cpu, decoder_batch_map_cpu, prefix_len_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + encoder_batch_map, + decoder_context_len, + decoder_context_len_cache, + decoder_batch_map, + prefix_len, + slot_mapping_enc, + slot_mapping_dec, k_quant_scale, v_quant_scale, k_dequant_scale, @@ -149,12 +172,15 @@ attn_out_C8 = block_attn( None, None, None, + None, + None, + False, + False, ) -attn_out_C8_zp = block_attn( +attn_out_C8_zp = block_attn_fused( qkv, key_cache_int8, value_cache_int8, - cum_offsets, rotary_embs, block_tables, prefix_block_tables, @@ -167,6 +193,16 @@ attn_out_C8_zp = block_attn( decoder_context_len_cache_cpu, decoder_batch_map_cpu, prefix_len_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + encoder_batch_map, + decoder_context_len, + decoder_context_len_cache, + decoder_batch_map, + prefix_len, + slot_mapping_enc, + slot_mapping_dec, k_quant_scale, v_quant_scale, k_dequant_scale_zp, @@ -177,6 +213,10 @@ attn_out_C8_zp = block_attn( None, None, None, + None, + None, + False, + False, ) # prefix cache : hit 71 tokens @@ -207,16 +247,17 @@ seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32") decoder_context_len_cpu, decoder_context_len_cache_cpu, len_info_cpu, + slot_mapping_enc, + slot_mapping_dec, ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64 + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0 ) # block_size qkv_prefix = qkv[hit_prefix_len:] -attn_out_prefix_cache = block_attn( +attn_out_prefix_cache = block_attn_fused( qkv_prefix, key_cache, value_cache, - cum_offsets, rotary_embs, block_tables, prefix_block_tables, @@ -229,6 +270,16 @@ attn_out_prefix_cache = block_attn( decoder_context_len_cache_cpu, decoder_batch_map_cpu, prefix_len_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + encoder_batch_map, + decoder_context_len, + decoder_context_len_cache, + decoder_batch_map, + prefix_len, + slot_mapping_enc, + slot_mapping_dec, None, None, None, @@ -239,13 +290,16 @@ attn_out_prefix_cache = block_attn( None, None, None, + None, + None, + False, + False, ) -attn_out_C8_prefix_cache = block_attn( +attn_out_C8_prefix_cache = block_attn_fused( qkv_prefix, key_cache_int8, value_cache_int8, - cum_offsets, rotary_embs, block_tables, prefix_block_tables, @@ -258,6 +312,16 @@ attn_out_C8_prefix_cache = block_attn( decoder_context_len_cache_cpu, decoder_batch_map_cpu, prefix_len_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + encoder_batch_map, + decoder_context_len, + decoder_context_len_cache, + decoder_batch_map, + prefix_len, + slot_mapping_enc, + slot_mapping_dec, k_quant_scale, v_quant_scale, k_dequant_scale, @@ -268,13 +332,16 @@ attn_out_C8_prefix_cache = block_attn( None, None, None, + None, + None, + False, + False, ) -attn_out_C8_zp_prefix_cache = block_attn( +attn_out_C8_zp_prefix_cache = block_attn_fused( qkv_prefix, key_cache_int8, value_cache_int8, - cum_offsets, rotary_embs, block_tables, prefix_block_tables, @@ -287,6 +354,16 @@ attn_out_C8_zp_prefix_cache = block_attn( decoder_context_len_cache_cpu, decoder_batch_map_cpu, prefix_len_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + encoder_batch_map, + decoder_context_len, + decoder_context_len_cache, + decoder_batch_map, + prefix_len, + slot_mapping_enc, + slot_mapping_dec, k_quant_scale, v_quant_scale, k_dequant_scale_zp, @@ -297,6 +374,10 @@ attn_out_C8_zp_prefix_cache = block_attn( None, None, None, + None, + None, + False, + False, ) print("-- C16 prefix cache test --") print("attn_out[hit_prefix_len:]'s mean:", attn_out[hit_prefix_len:].mean().item()) diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 44cf528bed..9b36556a6e 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -278,6 +278,11 @@ class XPUForwardMeta(ForwardMeta): # max bs max_num_seqs: int = 0 + # for spliced block_attn + slot_mapping_enc: Optional[paddle.Tensor] = None + # + slot_mapping_dec: Optional[paddle.Tensor] = None + def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None): """ Synchronize attributes from another XPUForwardMeta object diff --git a/fastdeploy/model_executor/layers/backends/xpu/attention.py b/fastdeploy/model_executor/layers/backends/xpu/attention.py index 31fce9bdf5..dbc41d6dd1 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/attention.py +++ b/fastdeploy/model_executor/layers/backends/xpu/attention.py @@ -214,6 +214,8 @@ class XPUAttentionBackend(AttentionBackend): forward_meta.decoder_context_len_cache, forward_meta.decoder_batch_map, forward_meta.prefix_len, + forward_meta.slot_mapping_enc, + forward_meta.slot_mapping_dec, cache_k_scale, cache_v_scale, cache_k_out_scale, diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index 0674f2b6d7..b8c97529dc 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -127,6 +127,7 @@ def xpu_pre_process( is_profiling: bool = False, forward_meta=None, use_cudagraph=False, + num_speculative_tokens=0, ) -> XPUForwardMeta: """ """ @@ -196,8 +197,15 @@ def xpu_pre_process( xpu_forward_meta.decoder_context_len_cpu, xpu_forward_meta.decoder_context_len_cache_cpu, xpu_forward_meta.len_info_cpu, + xpu_forward_meta.slot_mapping_enc, + xpu_forward_meta.slot_mapping_dec, ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + xpu_forward_meta.block_tables, + block_size, + num_speculative_tokens, ) xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0] xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1] diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index a155e5d9ad..b4b2541099 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -1078,6 +1078,7 @@ class MTPProposer(Proposer): self.model_inputs["draft_tokens"], self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], + num_speculative_tokens=self.speculative_config.num_speculative_tokens, ) if self.enable_mm: diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 5a992b90d3..201d503649 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -1129,6 +1129,7 @@ class XPUModelRunner(ModelRunnerBase): is_profiling=is_dummy_run, forward_meta=self.forward_meta, use_cudagraph=self.use_cudagraph, + num_speculative_tokens=self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0, ) if self.use_cudagraph: diff --git a/tests/xpu_ci/4cards_cases/test_mtp_cudagraph.py b/tests/xpu_ci/4cards_cases/test_mtp_cudagraph.py deleted file mode 100644 index 6fcf14335f..0000000000 --- a/tests/xpu_ci/4cards_cases/test_mtp_cudagraph.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -MTP模式测试 - ERNIE-4.5-21B-A3B-Paddle 模型 - -测试配置: -- 模型: ERNIE-4.5-21B-A3B-Paddle -- 量化: wint4 -- Tensor Parallel: 4 -""" - -import json - -import openai -import pytest -from conftest import get_model_path, get_port_num, print_logs_on_failure, start_server - - -def test_mtp_mode(xpu_env): - """mtp模式测试""" - - print("\n============================开始mtp + CudaGraph 模式测试!============================") - - # 获取配置 - port_num = get_port_num() - model_path = get_model_path() - spec_config = {"method": "mtp", "num_speculative_tokens": 1, "model": f"{model_path}/ERNIE-4.5-21B-A3B-Paddle/mtp"} - # 构建服务器启动参数 - server_args = [ - "--model", - f"{model_path}/ERNIE-4.5-21B-A3B-Paddle", - "--port", - str(port_num), - "--engine-worker-queue-port", - str(port_num + 1), - "--metrics-port", - str(port_num + 2), - "--tensor-parallel-size", - "4", - "--num-gpu-blocks-override", - "16384", - "--max-model-len", - "8192", - "--max-num-seqs", - "128", - "--quantization", - "wint4", - "--speculative-config", - f"{json.dumps(spec_config)}", - "--graph-optimization-config", - '{"use_cudagraph":true}', - ] - - # 启动服务器 - if not start_server(server_args): - pytest.fail("mtp模式服务启动失败") - - # 执行测试 - try: - ip = "0.0.0.0" - client = openai.Client(base_url=f"http://{ip}:{port_num}/v1", api_key="EMPTY_API_KEY") - - # 非流式对话 - response = client.chat.completions.create( - model="default", - messages=[ - {"role": "user", "content": "你好,你是谁?"}, - ], - temperature=1, - top_p=0, - max_tokens=64, - stream=False, - ) - - print(f"\n模型回复: {response.choices[0].message.content}") - - # 验证响应 - assert any( - keyword in response.choices[0].message.content for keyword in ["人工智能", "文心一言", "百度", "智能助手"] - ), f"响应内容不符合预期: {response.choices[0].message.content}" - - print("\nmtp + CudaGraph模式测试通过!") - - except Exception as e: - print(f"\nmtp + CudaGraph模式测试失败: {str(e)}") - print_logs_on_failure() - pytest.fail(f"mtp + CudaGraph模式测试失败: {str(e)}") - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"])