[XPU] Split the block_attn operator into smaller operators (#6798)

* spliced block_attn

* adapt to latest vllm

* fix unit tests

* delete mtp+cudagraph 4 cards test

* fix vl model

* fix mtp

* fix slot mapping
This commit is contained in:
RuohengMa
2026-04-16 14:28:40 +08:00
committed by GitHub
parent 6b891da02b
commit de0c5e68fb
12 changed files with 2891 additions and 131 deletions
+9 -4
View File
@@ -159,6 +159,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
if (use_neox_rotary_style) {
pos_emb_type = "NEOX";
} else if (rope_head_dim == head_dim / 2) {
// vl model use this
pos_emb_type = "HALF_HEAD_DIM";
} else {
pos_emb_type = "NORMAL";
@@ -984,7 +985,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
return {block_attn_out};
}
std::vector<paddle::Tensor> BlockAttn(
std::vector<paddle::Tensor> BlockAttnFused(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
@@ -1008,6 +1009,8 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::Tensor& decoder_context_len_cache,
const paddle::Tensor& decoder_batch_map,
const paddle::Tensor& prefix_len,
const paddle::Tensor& slot_mapping_enc,
const paddle::Tensor& slot_mapping_dec,
const paddle::optional<paddle::Tensor>& k_scales,
const paddle::optional<paddle::Tensor>& v_scales,
const paddle::optional<paddle::Tensor>& k_scales_inv,
@@ -1067,7 +1070,7 @@ std::vector<paddle::Tensor> BlockAttn(
} else if (cache_dtype == paddle::DataType::INT8) {
APPLY_KERNEL(paddle::bfloat16, int8_t, paddle::bfloat16);
} else {
PD_THROW("block_attn not support cache_dtype==%d",
PD_THROW("block_attn_fused not support cache_dtype==%d",
static_cast<int>(cache_dtype));
return {};
}
@@ -1097,7 +1100,7 @@ std::vector<paddle::DataType> BlockAttnInferDtype(
return {qkv_dtype};
}
PD_BUILD_STATIC_OP(block_attn)
PD_BUILD_STATIC_OP(block_attn_fused)
.Inputs({"qkv",
"key_cache",
"value_cache",
@@ -1121,6 +1124,8 @@ PD_BUILD_STATIC_OP(block_attn)
"decoder_context_len_cache",
"decoder_batch_map",
"prefix_len",
"slot_mapping_enc",
"slot_mapping_dec",
paddle::Optional("k_scales"),
paddle::Optional("v_scales"),
paddle::Optional("k_scales_inv"),
@@ -1135,6 +1140,6 @@ PD_BUILD_STATIC_OP(block_attn)
paddle::Optional("cachekv_signal_thread_cpu")})
.Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"})
.Outputs({"block_attn_out"})
.SetKernelFn(PD_KERNEL(BlockAttn))
.SetKernelFn(PD_KERNEL(BlockAttnFused))
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(BlockAttnInferDtype));
File diff suppressed because it is too large Load Diff
+109 -4
View File
@@ -16,14 +16,60 @@
#include "paddle/extension.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"
#include "ops/utility/env.h"
XPU_DECLARE_BOOL(encoder_splice, false);
XPU_DECLARE_BOOL(decoder_splice, false);
namespace api = baidu::xpu::api;
// Expand per-batch LoD offsets into a flat slot-mapping table on the host,
// then copy it to device memory with api::do_host2device.
//
// For token position `pos` of batch `b` (pos walks
// [start_tokens[b], start_tokens[b] + lod_len(b)) where
// lod_len(b) = kv_seq_lod[b+1] - kv_seq_lod[b]), the destination slot is
//   block_table[real_batch[b] * max_num_blocks_per_seq + pos / block_size]
//       * block_size + pos % block_size
// Entries never written stay -1; downstream store_paged_kv_cache skips -1
// slots (see the comment at the slot_mapping_enc allocation site).
//
// NOTE(review): `place` and `num_speculative_tokens` are accepted for
// signature symmetry with the callers but are currently unused here.
void lod_to_slot_mapping(api::Context* xpu_ctx,
                         paddle::Place place,
                         const std::vector<int32_t>& block_table,
                         const std::vector<int32_t>& kv_seq_lod,
                         const std::vector<int32_t>& start_tokens,
                         const std::vector<int32_t>& real_batch,
                         int32_t* slot_mapping,
                         int32_t token_num,
                         int32_t block_size,
                         int32_t batch_size,
                         int32_t max_num_blocks_per_seq,
                         int32_t num_speculative_tokens) {
  if (token_num <= 0) {
    return;
  }
  std::vector<int32_t> slot_mapping_vec(token_num, -1);
  int32_t idx = 0;
  // For each batch
  for (int32_t batch_ = 0; batch_ < batch_size; batch_++) {
    int32_t seq_len = kv_seq_lod[batch_ + 1] - kv_seq_lod[batch_];
    int32_t seq_start = start_tokens[batch_];
    int32_t dst_batch_id = real_batch[batch_];
    // Guard against a LoD describing more tokens than the output holds;
    // without this check the write below would run past slot_mapping_vec.
    PD_CHECK(idx + seq_len <= token_num,
             "lod_to_slot_mapping: lod token count exceeds slot mapping size.");
    // For each token of this batch
    for (int32_t seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) {
      int32_t table_id = seq_ / block_size;
      int32_t block_id =
          block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
      int32_t seq_offset = seq_ % block_size;
      int32_t dst_token_offset = block_id * block_size + seq_offset;
      slot_mapping_vec[idx] = dst_token_offset;
      idx++;
    }
  }
  int ret = api::do_host2device(xpu_ctx,
                                slot_mapping_vec.data(),
                                slot_mapping,
                                token_num * sizeof(int32_t));
  PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
}
std::vector<paddle::Tensor> GetInferParam(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& block_tables,
int block_size) {
int block_size,
int num_speculative_tokens) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
@@ -109,6 +155,15 @@ std::vector<paddle::Tensor> GetInferParam(
batch_offset++;
}
}
// for vsl_rotary_embedding_gptj of cudagraph mode
int prev_val = 0;
for (int i = 0; i < bsz; i++) {
if (decoder_seq_lod_vec[i] > prev_val) {
prev_val = decoder_seq_lod_vec[i];
} else if (decoder_seq_lod_vec[i] < prev_val) {
decoder_seq_lod_vec[i] = prev_val;
}
}
int prefix_block_num_per_seq = (max_kv_len + block_size - 1) / block_size;
std::vector<int32_t> prefix_block_tables_vec(
enc_batch * prefix_block_num_per_seq, -1);
@@ -167,6 +222,52 @@ std::vector<paddle::Tensor> GetInferParam(
seq_lens_encoder.type(),
seq_lens_encoder.place());
// for store_paged_kv_cache of cudagraph mode
// if slot_mapping is -1, store_paged_kv_cache will not write to kv cache
paddle::Tensor slot_mapping_enc = paddle::full(
{total_enc_len}, -1, paddle::DataType::INT32, seq_lens_encoder.place());
// TODO: mtp mode not verified yet, need further adaption
paddle::Tensor slot_mapping_dec =
paddle::full({bsz * (1 + num_speculative_tokens)},
-1,
paddle::DataType::INT32,
seq_lens_decoder.place());
if (FLAGS_encoder_splice || FLAGS_decoder_splice) {
std::vector<int32_t> block_tables_vec(block_bs * block_num_per_seq);
r = xpu_memcpy(block_tables_vec.data(),
block_tables.data<int32_t>(),
sizeof(int32_t) * block_bs * block_num_per_seq,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
if (FLAGS_encoder_splice) {
lod_to_slot_mapping(xpu_ctx->x_context(),
seq_lens_encoder.place(),
block_tables_vec,
encoder_seq_lod_vec,
prefix_len_vec,
encoder_batch_map_vec,
slot_mapping_enc.data<int32_t>(),
total_enc_len,
block_size,
enc_batch,
block_num_per_seq,
0);
}
if (FLAGS_decoder_splice) {
lod_to_slot_mapping(xpu_ctx->x_context(),
seq_lens_decoder.place(),
block_tables_vec,
decoder_seq_lod_vec,
decoder_context_len_cache_vec,
decoder_batch_map_vec,
slot_mapping_dec.data<int32_t>(),
bsz * (1 + num_speculative_tokens),
block_size,
dec_batch,
block_num_per_seq,
num_speculative_tokens);
}
}
auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
@@ -326,7 +427,9 @@ std::vector<paddle::Tensor> GetInferParam(
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu};
len_info_cpu,
slot_mapping_enc,
slot_mapping_dec};
}
std::vector<std::vector<int64_t>> GetInferParamInferShape(
@@ -400,8 +503,10 @@ PD_BUILD_OP(get_infer_param)
"prefix_len_cpu",
"decoder_context_len_cpu",
"decoder_context_len_cache_cpu",
"len_info_cpu"})
"len_info_cpu",
"slot_mapping_enc",
"slot_mapping_dec"})
.SetKernelFn(PD_KERNEL(GetInferParam))
.Attrs({"block_size: int"})
.Attrs({"block_size: int", "num_speculative_tokens: int"})
.SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetInferParamInferDtype));
+94 -3
View File
@@ -57,7 +57,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag);
std::vector<paddle::Tensor> BlockAttn(
std::vector<paddle::Tensor> SplitEmbeddingKVCacheBlockAttn(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
@@ -81,6 +81,50 @@ std::vector<paddle::Tensor> BlockAttn(
const paddle::Tensor& decoder_context_len_cache_xpu,
const paddle::Tensor& decoder_batch_map_xpu,
const paddle::Tensor& prefix_len_xpu,
const paddle::Tensor& slot_mapping_enc,
const paddle::Tensor& slot_mapping_dec,
const paddle::optional<paddle::Tensor>& k_scales,
const paddle::optional<paddle::Tensor>& v_scales,
const paddle::optional<paddle::Tensor>& k_scales_inv,
const paddle::optional<paddle::Tensor>& v_scales_inv,
const paddle::optional<paddle::Tensor>& k_zeros,
const paddle::optional<paddle::Tensor>& v_zeros,
const paddle::optional<paddle::Tensor>& shift,
const paddle::optional<paddle::Tensor>& smooth,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
const bool use_neox_rotary_style,
const bool rope_3d = false);
// deprecated, keep for unit test, will be removed in the future
std::vector<paddle::Tensor> BlockAttnFused(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& block_tables,
const paddle::Tensor& prefix_block_tables,
const paddle::Tensor& len_info_cpu,
const paddle::Tensor& encoder_seq_lod_cpu,
const paddle::Tensor& decoder_seq_lod_cpu,
const paddle::Tensor& encoder_kv_lod_cpu,
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_context_len_cpu,
const paddle::Tensor& decoder_context_len_cache_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& prefix_len_cpu,
const paddle::Tensor& encoder_seq_lod_xpu,
const paddle::Tensor& decoder_seq_lod_xpu,
const paddle::Tensor& encoder_kv_lod_xpu,
const paddle::Tensor& encoder_batch_map_xpu,
const paddle::Tensor& decoder_context_len_xpu,
const paddle::Tensor& decoder_context_len_cache_xpu,
const paddle::Tensor& decoder_batch_map_xpu,
const paddle::Tensor& prefix_len_xpu,
const paddle::Tensor& slot_mapping_enc,
const paddle::Tensor& slot_mapping_dec,
const paddle::optional<paddle::Tensor>& k_scales,
const paddle::optional<paddle::Tensor>& v_scales,
const paddle::optional<paddle::Tensor>& k_scales_inv,
@@ -434,7 +478,8 @@ std::vector<paddle::Tensor> GetInferParam(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& block_tables,
int block_size);
int block_size,
int num_speculative_tokens);
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
@@ -749,7 +794,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
"adjust batch in XPU");
m.def("block_attn",
&BlockAttn,
&SplitEmbeddingKVCacheBlockAttn,
py::arg("qkv"),
py::arg("key_cache"),
py::arg("value_cache"),
@@ -773,6 +818,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("decoder_context_len_cache_xpu"),
py::arg("decoder_batch_map_xpu"),
py::arg("prefix_len_xpu"),
py::arg("slot_mapping_enc"),
py::arg("slot_mapping_dec"),
py::arg("k_scales"),
py::arg("v_scales"),
py::arg("k_scales_inv"),
@@ -789,6 +836,49 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("rope_3d") = false,
"block attention in XPU");
m.def("block_attn_fused",
&BlockAttnFused,
py::arg("qkv"),
py::arg("key_cache"),
py::arg("value_cache"),
py::arg("rotary_embs"),
py::arg("block_tables"),
py::arg("prefix_block_tables"),
py::arg("len_info_cpu"),
py::arg("encoder_seq_lod_cpu"),
py::arg("decoder_seq_lod_cpu"),
py::arg("encoder_kv_lod_cpu"),
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_context_len_cpu"),
py::arg("decoder_context_len_cache_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("prefix_len_cpu"),
py::arg("encoder_seq_lod_xpu"),
py::arg("decoder_seq_lod_xpu"),
py::arg("encoder_kv_lod_xpu"),
py::arg("encoder_batch_map_xpu"),
py::arg("decoder_context_len_xpu"),
py::arg("decoder_context_len_cache_xpu"),
py::arg("decoder_batch_map_xpu"),
py::arg("prefix_len_xpu"),
py::arg("slot_mapping_enc"),
py::arg("slot_mapping_dec"),
py::arg("k_scales"),
py::arg("v_scales"),
py::arg("k_scales_inv"),
py::arg("v_scales_inv"),
py::arg("k_zeros"),
py::arg("v_zeros"),
py::arg("shift"),
py::arg("smooth"),
py::arg("q_norm_weight"),
py::arg("k_norm_weight"),
py::arg("kv_signal_data_cpu"),
py::arg("cachekv_signal_thread_cpu"),
py::arg("use_neox_rotary_style"),
py::arg("rope_3d") = false,
"block attention fused in XPU");
m.def("create_kv_signal_sender",
&create_cachekv_signal_thread,
"init write cache kv signal thread");
@@ -963,6 +1053,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("seq_lens_this_time"),
py::arg("block_tables"),
py::arg("block_size"),
py::arg("num_speculative_tokens"),
"Get infer parameters for block attention in XPU");
m.def("get_peer_mem_addr",
+653
View File
@@ -0,0 +1,653 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import paddle
# block_attn_fused is deprecated and should be removed in the future
from fastdeploy.model_executor.ops.xpu import (
block_attn,
block_attn_fused,
get_infer_param,
)
def print_all_not_equal_elements_info(k, x, y):
    # Report every mismatching element between the reference tensor ``x`` and
    # the computed tensor ``y`` (labelled ``k``), then print summary diff
    # statistics (mean diff, max/min absolute diff).
    ref_flat = x.flatten()
    out_flat = y.flatten()
    mismatch_idx = paddle.nonzero(ref_flat != out_flat)
    ref_mismatch = ref_flat[mismatch_idx]
    out_mismatch = out_flat[mismatch_idx]
    print(f"reference not equal element of {k}: {ref_mismatch}")
    print(f"calculated result not equal element of {k}: {out_mismatch}")
    diff = x - y
    abs_diff = paddle.abs(diff)
    mean_diff = paddle.mean(diff)
    max_abs_diff = paddle.max(abs_diff)
    min_abs_diff = paddle.min(abs_diff)
    print(f"{k} mean diff: {mean_diff}, max abs diff: {max_abs_diff}, min abs diff: {min_abs_diff}")
# Re-run block attention as if the first `hit_prefix_len` tokens were served
# from the prefix cache, and check that the output for the remaining suffix
# matches the full-sequence reference `attn_out` within tolerance.
# `block_attn_func` is either the fused or the spliced op; `key_cache` /
# `value_cache` are assumed to already hold the KV entries written by the
# preceding full run. Asserts on mismatch; returns the suffix attention out.
def run_prefix_cache_block_attn(
block_attn_func,
qkv,
seq_len,
seq_lens_this_time,
hit_prefix_len,
key_cache,
value_cache,
rotary_embs,
block_tables,
attn_out,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
v_dequant_scale,
k_zp,
v_zp,
shift,
smooth,
q_norm_weight,
k_norm_weight,
kv_signal_data_cpu,
cachekv_signal_thread_cpu,
use_neox_rotary_style,
rope_3d,
num_speculative_tokens,
):
# int8 (C8) caches lose precision, so compare with looser tolerances.
if key_cache.dtype == paddle.int8:
rtol = 1e-1
atol = 1e-2
else:
rtol = 1e-2
atol = 1e-3
# prefix cache block attn
# Model a prefix hit: the first hit_prefix_len tokens count as already
# decoded context, only the remainder goes through the encoder path.
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
(
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
) # block_size
# Feed only the non-cached suffix of the qkv activations.
qkv_prefix = qkv[hit_prefix_len:]
attn_out_prefix_cache = block_attn_func(
qkv_prefix,
key_cache,
value_cache,
rotary_embs,
block_tables,
prefix_block_tables,
len_info_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
encoder_batch_map_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
encoder_batch_map,
decoder_context_len,
decoder_context_len_cache,
decoder_batch_map,
prefix_len,
slot_mapping_enc,
slot_mapping_dec,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
v_dequant_scale,
k_zp,
v_zp,
shift,
smooth,
q_norm_weight,
k_norm_weight,
kv_signal_data_cpu,
cachekv_signal_thread_cpu,
use_neox_rotary_style,
rope_3d,
)
# Compare the suffix of the reference output against the prefix-cache run.
attn_out_np = attn_out[hit_prefix_len:].astype("float32").numpy()
attn_out_prefix_cache_np = attn_out_prefix_cache.astype("float32").numpy()
is_passed = np.allclose(attn_out_np, attn_out_prefix_cache_np, rtol=rtol, atol=atol)
if not is_passed:
print(f"block_attn_func: {block_attn_func}")
print("prefix_cache block_attn check failed!")
print(f"origin block_attn_out: {attn_out[hit_prefix_len:]}")
print(f"prefix_cache block_attn_out: {attn_out_prefix_cache}")
print("not equal elements are listed below:")
print_all_not_equal_elements_info("block_attn_out", attn_out[hit_prefix_len:], attn_out_prefix_cache)
else:
print(f"prefix_cache check of {block_attn_func} PASSED!")
assert is_passed
return attn_out_prefix_cache
# Build random inputs and run one block-attention pass through either the
# fused op (`block_attn_fused`) or the spliced op (`block_attn`), selected by
# `is_fused`. Returns a dict with the attention output and the (mutated) KV
# caches so the caller can compare fused vs spliced results. When mode == 0
# and hit_prefix_len > 0, additionally runs the prefix-cache check.
def run_block_attn(
seed,
is_fused,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
mode, # 0 for split kvcache encoder (prefill) only, 1 for decoder only; mixed not supported yet
hit_prefix_len,
kvcache_dtype,
has_zp,
use_neox_rotary_style,
rotary_embs_shape,
num_speculative_tokens,
):
assert mode == 0 or mode == 1, "mixed mode not supported yet!"
if mode == 0:
encoder_seq_len = seq_len
decoder_seq_len = 0
elif mode == 1:
encoder_seq_len = 0
decoder_seq_len = seq_len
else:
pass
# Single active batch; remaining 4 slots are padding (len 0).
seq_lens_encoder = paddle.to_tensor([encoder_seq_len, 0, 0, 0, 0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([decoder_seq_len, 0, 0, 0, 0], dtype="int32")
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
# Identity block table: block i of batch b maps to physical block
# b * max_block_per_seq + i.
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
(
encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
)
# NOTE(review): block size is hard-coded to 64 in the call above while a
# `block_size` parameter also exists -- confirm they stay in sync.
qkv = paddle.uniform(
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
)
rotary_embs = paddle.uniform(shape=rotary_embs_shape, dtype="float32", min=-1.0, max=1.0, seed=seed)
key_cache = paddle.zeros(
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
dtype=kvcache_dtype,
)
value_cache = paddle.zeros(
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
dtype=kvcache_dtype,
)
scale_tensor_k = None
scale_tensor_v = None
k_quant_scale = None
v_quant_scale = None
k_dequant_scale = None
v_dequant_scale = None
k_zp = None
v_zp = None
# C8 (int8) per-channel quantization scales; with zero-point (has_zp) the
# dequant scale is the reciprocal of the quant scale and zp tensors are used.
if kvcache_dtype == "int8":
scale_tensor_k = paddle.uniform(
shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed
) # max
scale_tensor_v = paddle.uniform(
shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed
) # max
k_quant_scale = 127.0 / scale_tensor_k # for C8 per channel means 127 / max
v_quant_scale = 127.0 / scale_tensor_v # for C8 per channel means 127 / max
if has_zp:
k_dequant_scale = 1 / k_quant_scale # for C8 per channel zp means max
v_dequant_scale = 1 / v_quant_scale # for C8 per channel zp means max
k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
else:
k_dequant_scale = paddle.cast(scale_tensor_k, dtype="float32") # for C8 per channel means max
v_dequant_scale = paddle.cast(scale_tensor_v, dtype="float32") # for C8 per channel means max
# variable below are not yet used
shift = None
smooth = None
q_norm_weight = None
k_norm_weight = None
kv_signal_data_cpu = None
cachekv_signal_thread_cpu = None
rope_3d = False
# Both ops share the same argument list, so only the callable differs.
if is_fused:
block_attn_func = block_attn_fused
else:
block_attn_func = block_attn
attn_out = block_attn_func(
qkv,
key_cache,
value_cache,
rotary_embs,
block_tables,
prefix_block_tables,
len_info_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
encoder_batch_map_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
encoder_batch_map,
decoder_context_len,
decoder_context_len_cache,
decoder_batch_map,
prefix_len,
slot_mapping_enc,
slot_mapping_dec,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
v_dequant_scale,
k_zp,
v_zp,
shift,
smooth,
q_norm_weight,
k_norm_weight,
kv_signal_data_cpu,
cachekv_signal_thread_cpu,
use_neox_rotary_style,
rope_3d,
)
# key_cache / value_cache were filled in-place by the op above.
result = {
"block_attn_out": attn_out,
"key_cache": key_cache,
"value_cache": value_cache,
}
# prefix cache
# Prefix-cache check only makes sense for the prefill-only mode.
if mode == 0 and hit_prefix_len > 0:
assert hit_prefix_len < seq_len
attn_out_prefix_cache = run_prefix_cache_block_attn(
block_attn_func,
qkv,
seq_len,
seq_lens_this_time,
hit_prefix_len,
key_cache,
value_cache,
rotary_embs,
block_tables,
attn_out,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
v_dequant_scale,
k_zp,
v_zp,
shift,
smooth,
q_norm_weight,
k_norm_weight,
kv_signal_data_cpu,
cachekv_signal_thread_cpu,
use_neox_rotary_style,
rope_3d,
num_speculative_tokens,
)
result["prefix_cache_block_attn_out"] = attn_out_prefix_cache
return result
def run_compare_block_attn(
    seed,
    head_num,
    kv_head_num,
    head_dim,
    seq_len,
    block_batch,
    max_block_per_seq,
    block_size,
    rotary_embs_shape,
    hit_prefix_len=0,
    kvcache_dtype="bfloat16",
    has_zp=False,
    use_neox_rotary_style=False,
    only_run_spliced=False,
    num_speculative_tokens=0,
):
    """Run the fused and spliced block_attn paths and compare their outputs.

    For each supported mode (0 = prefill only, 1 = decode only; mixed mode is
    not supported yet because get_infer_param would need changes first), runs
    ``run_block_attn`` with ``is_fused`` True and False and asserts that every
    returned tensor (attention output, KV caches, and -- when hit_prefix_len
    is set -- the prefix-cache output) matches within tolerance.

    When ``only_run_spliced`` is True, only the spliced path is executed and
    the comparison is skipped.
    """
    rtol = 1e-3
    atol = 1e-2
    # 0 for prefill only, 1 for decode only, 2 for mixed
    # TODO: mixed mode not supported yet, get_infer_param should be modified first
    mode_name = ["prefill only", "decode only", "mixed"]
    if use_neox_rotary_style:
        embedding_type = "neox"
    else:
        embedding_type = "rope"
    for mode in [0, 1]:
        if mode == 0:
            seq_len_list = [seq_len]
        elif mode == 1:
            # seq_len > 1 goes into mtp branch, which only supports seq_len <= 31
            # TODO: mtp mode need further adaption
            # seq_len_list = [1, random.randint(2, 31)]
            seq_len_list = [1]
        for idx, seqlen in enumerate(seq_len_list):
            branch_name = "non mtp branch" if idx == 0 else "mtp branch"
            print(
                f"running block attention of mode {mode_name[mode]} ({branch_name}), is_prefix_cache: {hit_prefix_len > 0}, kvcache type: {kvcache_dtype}, has_zp: {has_zp}, rotary_style: {embedding_type}"
            )
            # Explicit None markers instead of probing locals(): the
            # "result missing" branches below become deterministic per
            # iteration instead of depending on earlier loop iterations.
            fused_result = None
            spliced_result = None
            if not only_run_spliced:
                fused_result = run_block_attn(
                    seed,
                    True,  # is_fused
                    head_num,
                    kv_head_num,
                    head_dim,
                    seqlen,
                    block_batch,
                    max_block_per_seq,
                    block_size,
                    mode,
                    hit_prefix_len,
                    kvcache_dtype,
                    has_zp,
                    use_neox_rotary_style,
                    rotary_embs_shape,
                    num_speculative_tokens,
                )
            spliced_result = run_block_attn(
                seed,
                False,  # is_fused
                head_num,
                kv_head_num,
                head_dim,
                seqlen,
                block_batch,
                max_block_per_seq,
                block_size,
                mode,
                hit_prefix_len,
                kvcache_dtype,
                has_zp,
                use_neox_rotary_style,
                rotary_embs_shape,
                num_speculative_tokens,
            )
            if fused_result is not None and spliced_result is not None:
                for k in fused_result.keys():
                    # Integer tensors (e.g. int8 caches) are compared with a
                    # tighter absolute tolerance than floating-point outputs.
                    if paddle.is_integer(fused_result[k]):
                        fused_v = fused_result[k].astype("int32")
                        spliced_v = spliced_result[k].astype("int32")
                        fused_v_np = fused_v.numpy()
                        splice_v_np = spliced_v.numpy()
                        # is_passed = np.allclose(fused_v_np, splice_v_np, rtol=1e-1, atol=1e-1)
                        is_passed = np.allclose(fused_v_np, splice_v_np, rtol=1e-2, atol=rtol)
                    else:
                        fused_v = fused_result[k].astype("float32")
                        spliced_v = spliced_result[k].astype("float32")
                        fused_v_np = fused_v.numpy()
                        splice_v_np = spliced_v.numpy()
                        is_passed = np.allclose(fused_v_np, splice_v_np, rtol=rtol, atol=atol, equal_nan=True)
                    if not is_passed:
                        print(f"{k} in mode {mode_name[mode]} check FAILED!")
                        print(f"fused {k}: {fused_v}")
                        print(f"spliced {k}: {spliced_v}")
                        print("not equal elements are listed below:")
                        print_all_not_equal_elements_info(k, fused_v, spliced_v)
                    else:
                        print(f"{k} in mode {mode_name[mode]} check PASSED!")
                    assert is_passed
                print("")
            else:
                if fused_result is None:
                    print("fused_result not found.")
                if spliced_result is None:
                    print("spliced_result not found.")
                print("skip comparison.")
# ---- test driver: shared configuration ----
# A random seed is drawn once and reused by every comparison so fused and
# spliced runs see identical random inputs.
seed = random.randint(0, 2026)
paddle.seed(seed)
head_num = 64
kv_head_num = 8
head_dim = 128
rotary_embs_shape = [2, 1, 8192, 1, head_dim]
seq_len = 128
block_batch = 5
max_block_per_seq = 128
block_size = 64
# TODO: if hit_prefix_len has a small value, e.g. hit_prefix_len == 2, block_attn_out and prefix_cache_block_attn_out will have greater diff
hit_prefix_len = 71
# no prefix cache
# block_attn fused vs spliced
use_neox_rotary_style = False
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
0,
kvcache_dtype="bfloat16",
has_zp=False,
use_neox_rotary_style=use_neox_rotary_style,
)
# c8 quantization block_attn fused vs spliced
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
0,
kvcache_dtype="int8",
has_zp=False,
use_neox_rotary_style=use_neox_rotary_style,
)
# c8 zp quantization block_attn fused vs spliced
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
0,
kvcache_dtype="int8",
has_zp=True,
use_neox_rotary_style=use_neox_rotary_style,
)
# prefix cache
# Same three dtype/zp combinations, now with a 71-token prefix-cache hit.
# block_attn fused vs spliced
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
hit_prefix_len,
kvcache_dtype="bfloat16",
has_zp=False,
use_neox_rotary_style=use_neox_rotary_style,
)
# c8 quantization block_attn fused vs spliced
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
hit_prefix_len,
kvcache_dtype="int8",
has_zp=False,
use_neox_rotary_style=use_neox_rotary_style,
)
# c8 zp quantization block_attn fused vs spliced
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
hit_prefix_len,
kvcache_dtype="int8",
has_zp=True,
use_neox_rotary_style=use_neox_rotary_style,
)
# # neox
# # block_attn fused vs spliced
# # no prefix cache
use_neox_rotary_style = True
only_run_spliced = False
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
0,
kvcache_dtype="bfloat16",
has_zp=False,
use_neox_rotary_style=use_neox_rotary_style,
only_run_spliced=only_run_spliced,
)
# prefix cache
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
hit_prefix_len,
kvcache_dtype="bfloat16",
has_zp=False,
use_neox_rotary_style=use_neox_rotary_style,
only_run_spliced=only_run_spliced,
)
# neox glm 4.5 air debug
# GLM-4.5-Air-like geometry: fewer heads, larger block table, half-dim
# rotary embeddings (exercises the HALF_HEAD_DIM pos_emb path).
head_num = 24
kv_head_num = 2
head_dim = 128
seq_len = 128
block_batch = 64
max_block_per_seq = 2050
block_size = 64
rotary_embs_shape = [2, 1, 131072, 1, head_dim // 2]
use_neox_rotary_style = True
only_run_spliced = False
run_compare_block_attn(
seed,
head_num,
kv_head_num,
head_dim,
seq_len,
block_batch,
max_block_per_seq,
block_size,
rotary_embs_shape,
0,
kvcache_dtype="bfloat16",
has_zp=False,
use_neox_rotary_style=use_neox_rotary_style,
only_run_spliced=only_run_spliced,
)
print("\nALL PASSED!")
@@ -15,7 +15,7 @@
import numpy as np
import paddle
from fastdeploy.model_executor.ops.xpu import block_attn, get_infer_param
from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param
head_num = 64
kv_head_num = 8
@@ -53,8 +53,10 @@ block_tables = block_tables.reshape((block_batch, max_block_per_seq))
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
) # block_size
qkv = paddle.uniform(
@@ -64,7 +66,6 @@ qkv = paddle.uniform(
max=1.0,
)
cum_offsets = paddle.zeros(shape=[block_batch], dtype="bfloat16")
rotary_embs = paddle.uniform(shape=[2, 1, 8192, 1, head_dim], dtype="float32", min=-1.0, max=1.0)
key_cache = paddle.zeros(
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
@@ -94,11 +95,10 @@ v_dequant_scale_zp = 1 / v_quant_scale # for C8 per channel zp means max
k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
attn_out = block_attn(
attn_out = block_attn_fused(
qkv,
key_cache,
value_cache,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
@@ -111,6 +111,16 @@ attn_out = block_attn(
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
encoder_batch_map,
decoder_context_len,
decoder_context_len_cache,
decoder_batch_map,
prefix_len,
slot_mapping_enc,
slot_mapping_dec,
None,
None,
None,
@@ -121,12 +131,15 @@ attn_out = block_attn(
None,
None,
None,
None,
None,
False,
False,
)
attn_out_C8 = block_attn(
attn_out_C8 = block_attn_fused(
qkv,
key_cache_int8,
value_cache_int8,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
@@ -139,6 +152,16 @@ attn_out_C8 = block_attn(
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
encoder_batch_map,
decoder_context_len,
decoder_context_len_cache,
decoder_batch_map,
prefix_len,
slot_mapping_enc,
slot_mapping_dec,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
@@ -149,12 +172,15 @@ attn_out_C8 = block_attn(
None,
None,
None,
None,
None,
False,
False,
)
attn_out_C8_zp = block_attn(
attn_out_C8_zp = block_attn_fused(
qkv,
key_cache_int8,
value_cache_int8,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
@@ -167,6 +193,16 @@ attn_out_C8_zp = block_attn(
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
encoder_batch_map,
decoder_context_len,
decoder_context_len_cache,
decoder_batch_map,
prefix_len,
slot_mapping_enc,
slot_mapping_dec,
k_quant_scale,
v_quant_scale,
k_dequant_scale_zp,
@@ -177,6 +213,10 @@ attn_out_C8_zp = block_attn(
None,
None,
None,
None,
None,
False,
False,
)
# prefix cache : hit 71 tokens
@@ -207,16 +247,17 @@ seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu,
slot_mapping_enc,
slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
) # block_size
qkv_prefix = qkv[hit_prefix_len:]
attn_out_prefix_cache = block_attn(
attn_out_prefix_cache = block_attn_fused(
qkv_prefix,
key_cache,
value_cache,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
@@ -229,6 +270,16 @@ attn_out_prefix_cache = block_attn(
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
encoder_batch_map,
decoder_context_len,
decoder_context_len_cache,
decoder_batch_map,
prefix_len,
slot_mapping_enc,
slot_mapping_dec,
None,
None,
None,
@@ -239,13 +290,16 @@ attn_out_prefix_cache = block_attn(
None,
None,
None,
None,
None,
False,
False,
)
attn_out_C8_prefix_cache = block_attn(
attn_out_C8_prefix_cache = block_attn_fused(
qkv_prefix,
key_cache_int8,
value_cache_int8,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
@@ -258,6 +312,16 @@ attn_out_C8_prefix_cache = block_attn(
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
encoder_batch_map,
decoder_context_len,
decoder_context_len_cache,
decoder_batch_map,
prefix_len,
slot_mapping_enc,
slot_mapping_dec,
k_quant_scale,
v_quant_scale,
k_dequant_scale,
@@ -268,13 +332,16 @@ attn_out_C8_prefix_cache = block_attn(
None,
None,
None,
None,
None,
False,
False,
)
attn_out_C8_zp_prefix_cache = block_attn(
attn_out_C8_zp_prefix_cache = block_attn_fused(
qkv_prefix,
key_cache_int8,
value_cache_int8,
cum_offsets,
rotary_embs,
block_tables,
prefix_block_tables,
@@ -287,6 +354,16 @@ attn_out_C8_zp_prefix_cache = block_attn(
decoder_context_len_cache_cpu,
decoder_batch_map_cpu,
prefix_len_cpu,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
encoder_batch_map,
decoder_context_len,
decoder_context_len_cache,
decoder_batch_map,
prefix_len,
slot_mapping_enc,
slot_mapping_dec,
k_quant_scale,
v_quant_scale,
k_dequant_scale_zp,
@@ -297,6 +374,10 @@ attn_out_C8_zp_prefix_cache = block_attn(
None,
None,
None,
None,
None,
False,
False,
)
print("-- C16 prefix cache test --")
print("attn_out[hit_prefix_len:]'s mean:", attn_out[hit_prefix_len:].mean().item())
@@ -278,6 +278,11 @@ class XPUForwardMeta(ForwardMeta):
# max bs
max_num_seqs: int = 0
# for spliced block_attn
slot_mapping_enc: Optional[paddle.Tensor] = None
# decoder-side slot mapping for spliced block_attn
slot_mapping_dec: Optional[paddle.Tensor] = None
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
"""
Synchronize attributes from another XPUForwardMeta object
@@ -214,6 +214,8 @@ class XPUAttentionBackend(AttentionBackend):
forward_meta.decoder_context_len_cache,
forward_meta.decoder_batch_map,
forward_meta.prefix_len,
forward_meta.slot_mapping_enc,
forward_meta.slot_mapping_dec,
cache_k_scale,
cache_v_scale,
cache_k_out_scale,
@@ -127,6 +127,7 @@ def xpu_pre_process(
is_profiling: bool = False,
forward_meta=None,
use_cudagraph=False,
num_speculative_tokens=0,
) -> XPUForwardMeta:
""" """
@@ -196,8 +197,15 @@ def xpu_pre_process(
xpu_forward_meta.decoder_context_len_cpu,
xpu_forward_meta.decoder_context_len_cache_cpu,
xpu_forward_meta.len_info_cpu,
xpu_forward_meta.slot_mapping_enc,
xpu_forward_meta.slot_mapping_dec,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
xpu_forward_meta.block_tables,
block_size,
num_speculative_tokens,
)
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
+1
View File
@@ -1078,6 +1078,7 @@ class MTPProposer(Proposer):
self.model_inputs["draft_tokens"],
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
num_speculative_tokens=self.speculative_config.num_speculative_tokens,
)
if self.enable_mm:
+1
View File
@@ -1129,6 +1129,7 @@ class XPUModelRunner(ModelRunnerBase):
is_profiling=is_dummy_run,
forward_meta=self.forward_meta,
use_cudagraph=self.use_cudagraph,
num_speculative_tokens=self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
)
if self.use_cudagraph:
@@ -1,103 +0,0 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MTP模式测试 - ERNIE-4.5-21B-A3B-Paddle 模型
测试配置:
- 模型: ERNIE-4.5-21B-A3B-Paddle
- 量化: wint4
- Tensor Parallel: 4
"""
import json
import openai
import pytest
from conftest import get_model_path, get_port_num, print_logs_on_failure, start_server
def test_mtp_mode(xpu_env):
    """Integration test: MTP speculative decoding combined with CudaGraph.

    Launches an inference server for ERNIE-4.5-21B-A3B-Paddle (wint4, TP=4)
    with an MTP speculative config and CudaGraph enabled, sends one
    non-streaming chat request, and checks the reply mentions an expected
    identity keyword. Fails the test (and dumps server logs) on any error.
    """
    print("\n============================开始mtp + CudaGraph 模式测试!============================")
    # Resolve test configuration from the shared conftest helpers.
    port = get_port_num()
    model_root = get_model_path()
    speculative_cfg = {
        "method": "mtp",
        "num_speculative_tokens": 1,
        "model": f"{model_root}/ERNIE-4.5-21B-A3B-Paddle/mtp",
    }

    # Assemble the server launch arguments (ports for engine queue and
    # metrics are derived from the base port).
    launch_args = [
        "--model",
        f"{model_root}/ERNIE-4.5-21B-A3B-Paddle",
        "--port",
        str(port),
        "--engine-worker-queue-port",
        str(port + 1),
        "--metrics-port",
        str(port + 2),
        "--tensor-parallel-size",
        "4",
        "--num-gpu-blocks-override",
        "16384",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--quantization",
        "wint4",
        "--speculative-config",
        f"{json.dumps(speculative_cfg)}",
        "--graph-optimization-config",
        '{"use_cudagraph":true}',
    ]

    # Bring the server up; abort the test immediately if startup fails.
    if not start_server(launch_args):
        pytest.fail("mtp模式服务启动失败")

    try:
        host = "0.0.0.0"
        api_client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="EMPTY_API_KEY")

        # Issue a single non-streaming chat completion.
        reply = api_client.chat.completions.create(
            model="default",
            messages=[
                {"role": "user", "content": "你好,你是谁?"},
            ],
            temperature=1,
            top_p=0,
            max_tokens=64,
            stream=False,
        )
        answer = reply.choices[0].message.content
        print(f"\n模型回复: {answer}")

        # The reply must mention at least one expected identity keyword.
        expected_keywords = ["人工智能", "文心一言", "百度", "智能助手"]
        assert any(keyword in answer for keyword in expected_keywords), f"响应内容不符合预期: {answer}"

        print("\nmtp + CudaGraph模式测试通过!")
    except Exception as err:
        # Any failure (request error or assertion) dumps server logs
        # before failing the test with the same message.
        print(f"\nmtp + CudaGraph模式测试失败: {str(err)}")
        print_logs_on_failure()
        pytest.fail(f"mtp + CudaGraph模式测试失败: {str(err)}")
# Allow running this test file directly (verbose, stdout capture disabled).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])