Files
FastDeploy/custom_ops/xpu_ops/test/test_block_attn.py
T
RuohengMa de0c5e68fb [XPU] Split the block_attn operator into smaller operators (#6798)
* spliced block_attn

* adapt to latest vllm

* fix unit tests

* delete mtp+cudagraph 4 cards test

* fix vl model

* fix mtp

* fix slot mapping
2026-04-16 14:28:40 +08:00

654 lines
19 KiB
Python

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import paddle
# block_attn_fused is deprecated and should be removed in the future
from fastdeploy.model_executor.ops.xpu import (
block_attn,
block_attn_fused,
get_infer_param,
)
def print_all_not_equal_elements_info(k, x, y):
    """Print the mismatching elements of two tensors and summary diff statistics.

    Args:
        k: Label inserted into every printed message.
        x: Tensor reported as the "reference" values.
        y: Tensor reported as the "calculated result" values.
    """
    ref_flat, calc_flat = x.flatten(), y.flatten()
    # Indices (into the flattened views) where the two tensors disagree.
    mismatch_idx = paddle.nonzero(ref_flat != calc_flat)
    print(f"reference not equal element of {k}: {ref_flat[mismatch_idx]}")
    print(f"calculated result not equal element of {k}: {calc_flat[mismatch_idx]}")

    # Summary statistics over the full (unflattened) element-wise difference.
    delta = x - y
    abs_delta = paddle.abs(delta)
    mean_diff = paddle.mean(delta)
    max_abs_diff = paddle.max(abs_delta)
    min_abs_diff = paddle.min(abs_delta)
    print(f"{k} mean diff: {mean_diff}, max abs diff: {max_abs_diff}, min abs diff: {min_abs_diff}")
def run_prefix_cache_block_attn(
    block_attn_func,
    qkv,
    seq_len,
    seq_lens_this_time,
    hit_prefix_len,
    key_cache,
    value_cache,
    rotary_embs,
    block_tables,
    attn_out,
    k_quant_scale,
    v_quant_scale,
    k_dequant_scale,
    v_dequant_scale,
    k_zp,
    v_zp,
    shift,
    smooth,
    q_norm_weight,
    k_norm_weight,
    kv_signal_data_cpu,
    cachekv_signal_thread_cpu,
    use_neox_rotary_style,
    rope_3d,
    num_speculative_tokens,
    block_size=None,
):
    """Re-run attention with a prefix-cache hit and check it against the full run.

    The first ``hit_prefix_len`` tokens are declared as already cached
    (``seq_lens_decoder``) and only ``qkv[hit_prefix_len:]`` is fed to
    ``block_attn_func``. The output is compared with ``attn_out[hit_prefix_len:]``
    using ``np.allclose``.

    Args:
        block_attn_func: Attention operator under test.
        qkv: Packed QKV input of the full sequence.
        seq_len: Length of the full sequence.
        seq_lens_this_time: Forwarded unchanged to ``get_infer_param``.
        hit_prefix_len: Number of leading tokens treated as a prefix-cache hit.
        key_cache: Key cache tensor handed to the operator. Presumably already
            populated by the preceding full run -- confirm against the operator.
        value_cache: Value cache tensor, same remark as ``key_cache``.
        rotary_embs: Rotary embedding tensor forwarded to the operator.
        block_tables: Block table forwarded to ``get_infer_param`` and the operator.
        attn_out: Output of the full (no prefix cache) run, used as reference.
        k_quant_scale, v_quant_scale, k_dequant_scale, v_dequant_scale, k_zp, v_zp:
            Cache quantization parameters (may be ``None``), forwarded unchanged.
        shift, smooth, q_norm_weight, k_norm_weight, kv_signal_data_cpu,
        cachekv_signal_thread_cpu, use_neox_rotary_style, rope_3d:
            Forwarded unchanged to ``block_attn_func``.
        num_speculative_tokens: Forwarded to ``get_infer_param``.
        block_size: KV cache block size passed to ``get_infer_param``. When
            ``None`` it is read from ``key_cache.shape[2]``, which assumes the
            ``[num_blocks, kv_head_num, block_size, head_dim]`` layout used when
            the caches are allocated in this file.

    Returns:
        The attention output of the prefix-cache run.

    Raises:
        AssertionError: If the prefix-cache output is not close to the reference.
    """
    # The int8 KV cache is compared with a looser tolerance than other dtypes.
    if key_cache.dtype == paddle.int8:
        rtol, atol = 1e-1, 1e-2
    else:
        rtol, atol = 1e-2, 1e-3

    if block_size is None:
        # Derive the block size from the cache itself instead of hard-coding 64.
        block_size = key_cache.shape[2]

    # Only batch slot 0 is active; the cached prefix is described as "decoder"
    # length and the remaining tokens as "encoder" length.
    seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
    seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
    (
        encoder_batch_map,
        decoder_batch_map,
        _encoder_batch_idx,
        _decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        _encoder_batch_idx_cpu,
        _decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
        slot_mapping_enc,
        slot_mapping_dec,
    ) = get_infer_param(
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        block_tables,
        block_size,
        num_speculative_tokens,
    )

    # Only the tokens after the cached prefix are recomputed.
    qkv_prefix = qkv[hit_prefix_len:]
    attn_out_prefix_cache = block_attn_func(
        qkv_prefix,
        key_cache,
        value_cache,
        rotary_embs,
        block_tables,
        prefix_block_tables,
        len_info_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        encoder_batch_map_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        decoder_batch_map_cpu,
        prefix_len_cpu,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        encoder_batch_map,
        decoder_context_len,
        decoder_context_len_cache,
        decoder_batch_map,
        prefix_len,
        slot_mapping_enc,
        slot_mapping_dec,
        k_quant_scale,
        v_quant_scale,
        k_dequant_scale,
        v_dequant_scale,
        k_zp,
        v_zp,
        shift,
        smooth,
        q_norm_weight,
        k_norm_weight,
        kv_signal_data_cpu,
        cachekv_signal_thread_cpu,
        use_neox_rotary_style,
        rope_3d,
    )

    # Compare in float32 on the host against the matching slice of the full run.
    attn_out_suffix = attn_out[hit_prefix_len:]
    attn_out_np = attn_out_suffix.astype("float32").numpy()
    attn_out_prefix_cache_np = attn_out_prefix_cache.astype("float32").numpy()
    is_passed = np.allclose(attn_out_np, attn_out_prefix_cache_np, rtol=rtol, atol=atol)
    if not is_passed:
        print(f"block_attn_func: {block_attn_func}")
        print("prefix_cache block_attn check failed!")
        print(f"origin block_attn_out: {attn_out_suffix}")
        print(f"prefix_cache block_attn_out: {attn_out_prefix_cache}")
        print("not equal elements are listed below:")
        print_all_not_equal_elements_info("block_attn_out", attn_out_suffix, attn_out_prefix_cache)
    else:
        print(f"prefix_cache check of {block_attn_func} PASSED!")
    assert is_passed
    return attn_out_prefix_cache
def run_block_attn(
    seed,
    is_fused,
    head_num,
    kv_head_num,
    head_dim,
    seq_len,
    block_batch,
    max_block_per_seq,
    block_size,
    mode,  # 0 for prefill (encoder) only, 1 for decode only; mixed is not supported yet
    hit_prefix_len,
    kvcache_dtype,
    has_zp,
    use_neox_rotary_style,
    rotary_embs_shape,
    num_speculative_tokens,
):
    """Run one block attention operator on random inputs and collect its outputs.

    Args:
        seed: Seed passed to every ``paddle.uniform`` call that builds an input.
        is_fused: ``True`` selects ``block_attn_fused``, ``False`` selects ``block_attn``.
        head_num: Number of query heads.
        kv_head_num: Number of key/value heads.
        head_dim: Size of each head.
        seq_len: Number of tokens of the single active request (batch slot 0).
        block_batch: Number of rows of the block table.
        max_block_per_seq: Number of columns of the block table.
        block_size: KV cache block size; used for the cache shape and passed to
            ``get_infer_param``.
        mode: 0 runs the sequence as encoder (prefill) tokens, 1 as decoder tokens.
        hit_prefix_len: When positive and ``mode == 0``, the prefix-cache check
            is run as well.
        kvcache_dtype: Dtype of the KV cache; ``"int8"`` enables the quantization
            scale tensors.
        has_zp: With an int8 cache, whether zero-point tensors are supplied.
        use_neox_rotary_style: Forwarded to the operator.
        rotary_embs_shape: Shape of the random rotary embedding tensor.
        num_speculative_tokens: Forwarded to ``get_infer_param``.

    Returns:
        dict with ``"block_attn_out"``, ``"key_cache"`` and ``"value_cache"``,
        plus ``"prefix_cache_block_attn_out"`` when the prefix-cache check ran.

    Raises:
        AssertionError: If ``mode`` is not 0 or 1, if ``hit_prefix_len`` is not
            smaller than ``seq_len`` when the prefix-cache check runs, or if
            that check fails.
    """
    assert mode in (0, 1), "mixed mode not supported yet!"
    if mode == 0:
        encoder_seq_len, decoder_seq_len = seq_len, 0
    else:
        encoder_seq_len, decoder_seq_len = 0, seq_len

    # Five batch slots are declared, only slot 0 carries a request.
    seq_lens_encoder = paddle.to_tensor([encoder_seq_len, 0, 0, 0, 0], dtype="int32")
    seq_lens_decoder = paddle.to_tensor([decoder_seq_len, 0, 0, 0, 0], dtype="int32")
    seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")

    # Row i of the block table owns the consecutive block ids
    # [i * max_block_per_seq, (i + 1) * max_block_per_seq).
    block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
    block_tables = block_tables.reshape((block_batch, max_block_per_seq))
    (
        encoder_batch_map,
        decoder_batch_map,
        _encoder_batch_idx,
        _decoder_batch_idx,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        prefix_len,
        decoder_context_len,
        decoder_context_len_cache,
        prefix_block_tables,
        encoder_batch_map_cpu,
        decoder_batch_map_cpu,
        _encoder_batch_idx_cpu,
        _decoder_batch_idx_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        prefix_len_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        len_info_cpu,
        slot_mapping_enc,
        slot_mapping_dec,
    ) = get_infer_param(
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        block_tables,
        block_size,  # was a literal 64; keep it in sync with the cache shape below
        num_speculative_tokens,
    )

    qkv = paddle.uniform(
        shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
    )
    rotary_embs = paddle.uniform(shape=rotary_embs_shape, dtype="float32", min=-1.0, max=1.0, seed=seed)

    cache_shape = [block_batch * max_block_per_seq, kv_head_num, block_size, head_dim]
    key_cache = paddle.zeros(shape=cache_shape, dtype=kvcache_dtype)
    value_cache = paddle.zeros(shape=cache_shape, dtype=kvcache_dtype)

    k_quant_scale = None
    v_quant_scale = None
    k_dequant_scale = None
    v_dequant_scale = None
    k_zp = None
    v_zp = None
    if kvcache_dtype == "int8":
        # min == max == 1.0, so the per-channel "max" tensors are constant 1.0.
        scale_tensor_k = paddle.uniform(
            shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed
        )  # max
        scale_tensor_v = paddle.uniform(
            shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed
        )  # max
        k_quant_scale = 127.0 / scale_tensor_k  # for C8 per channel means 127 / max
        v_quant_scale = 127.0 / scale_tensor_v  # for C8 per channel means 127 / max
        if has_zp:
            k_dequant_scale = 1 / k_quant_scale  # for C8 per channel zp means max
            v_dequant_scale = 1 / v_quant_scale  # for C8 per channel zp means max
            k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
            v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
        else:
            k_dequant_scale = paddle.cast(scale_tensor_k, dtype="float32")  # for C8 per channel means max
            v_dequant_scale = paddle.cast(scale_tensor_v, dtype="float32")  # for C8 per channel means max

    # The variables below are not used by this test yet.
    shift = None
    smooth = None
    q_norm_weight = None
    k_norm_weight = None
    kv_signal_data_cpu = None
    cachekv_signal_thread_cpu = None
    rope_3d = False

    block_attn_func = block_attn_fused if is_fused else block_attn
    attn_out = block_attn_func(
        qkv,
        key_cache,
        value_cache,
        rotary_embs,
        block_tables,
        prefix_block_tables,
        len_info_cpu,
        encoder_seq_lod_cpu,
        decoder_seq_lod_cpu,
        encoder_kv_lod_cpu,
        encoder_batch_map_cpu,
        decoder_context_len_cpu,
        decoder_context_len_cache_cpu,
        decoder_batch_map_cpu,
        prefix_len_cpu,
        encoder_seq_lod,
        decoder_seq_lod,
        encoder_kv_lod,
        encoder_batch_map,
        decoder_context_len,
        decoder_context_len_cache,
        decoder_batch_map,
        prefix_len,
        slot_mapping_enc,
        slot_mapping_dec,
        k_quant_scale,
        v_quant_scale,
        k_dequant_scale,
        v_dequant_scale,
        k_zp,
        v_zp,
        shift,
        smooth,
        q_norm_weight,
        k_norm_weight,
        kv_signal_data_cpu,
        cachekv_signal_thread_cpu,
        use_neox_rotary_style,
        rope_3d,
    )
    # The cache tensors are returned so callers can compare them too; presumably
    # the operator fills them in place -- confirm against the operator.
    result = {
        "block_attn_out": attn_out,
        "key_cache": key_cache,
        "value_cache": value_cache,
    }

    # The prefix-cache check only applies to the prefill mode.
    if mode == 0 and hit_prefix_len > 0:
        assert hit_prefix_len < seq_len
        attn_out_prefix_cache = run_prefix_cache_block_attn(
            block_attn_func,
            qkv,
            seq_len,
            seq_lens_this_time,
            hit_prefix_len,
            key_cache,
            value_cache,
            rotary_embs,
            block_tables,
            attn_out,
            k_quant_scale,
            v_quant_scale,
            k_dequant_scale,
            v_dequant_scale,
            k_zp,
            v_zp,
            shift,
            smooth,
            q_norm_weight,
            k_norm_weight,
            kv_signal_data_cpu,
            cachekv_signal_thread_cpu,
            use_neox_rotary_style,
            rope_3d,
            num_speculative_tokens,
        )
        result["prefix_cache_block_attn_out"] = attn_out_prefix_cache
    return result
def run_compare_block_attn(
    seed,
    head_num,
    kv_head_num,
    head_dim,
    seq_len,
    block_batch,
    max_block_per_seq,
    block_size,
    rotary_embs_shape,
    hit_prefix_len=0,
    kvcache_dtype="bfloat16",
    has_zp=False,
    use_neox_rotary_style=False,
    only_run_spliced=False,
    num_speculative_tokens=0,
):
    """Run the fused and the spliced block attention and compare their outputs.

    Both operators are run through ``run_block_attn`` with identical arguments,
    first in prefill-only mode (``seq_len`` tokens) and then in decode-only
    mode (one token). Every entry of the two result dicts is compared with
    ``np.allclose``.

    Args:
        seed: Seed forwarded to ``run_block_attn``.
        head_num, kv_head_num, head_dim: Attention geometry, forwarded.
        seq_len: Sequence length used for the prefill-only mode.
        block_batch, max_block_per_seq, block_size: Block table / cache
            geometry, forwarded.
        rotary_embs_shape: Shape of the rotary embedding tensor, forwarded.
        hit_prefix_len: Prefix-cache hit length, forwarded (0 disables the
            prefix-cache check).
        kvcache_dtype: KV cache dtype, forwarded.
        has_zp: Whether zero points are used with an int8 cache, forwarded.
        use_neox_rotary_style: Rotary style flag, forwarded.
        only_run_spliced: When ``True`` the fused operator is not run and the
            comparison is skipped.
        num_speculative_tokens: Forwarded to ``run_block_attn``.

    Raises:
        AssertionError: If any compared output differs beyond the tolerance.
    """
    rtol = 1e-3
    atol = 1e-2
    # 0 for prefill only, 1 for decode only, 2 for mixed
    # TODO: mixed mode not supported yet, get_infer_param should be modified first
    mode_name = ["prefill only", "decode only", "mixed"]
    embedding_type = "neox" if use_neox_rotary_style else "rope"

    for mode in [0, 1]:
        if mode == 0:
            seq_len_list = [seq_len]
        else:
            # seq_len > 1 goes into mtp branch, which only supports seq_len <= 31
            # TODO: mtp mode need further adaption
            # seq_len_list = [1, random.randint(2, 31)]
            seq_len_list = [1]

        for idx, seqlen in enumerate(seq_len_list):
            branch_name = "non mtp branch" if idx == 0 else "mtp branch"
            print(
                f"runnning block attention of mode {mode_name[mode]} ({branch_name}), is_prefix_cache: {hit_prefix_len > 0}, kvcache type: {kvcache_dtype}, has_zp: {has_zp}, rotary_style: {embedding_type}"
            )

            # Everything after (seed, is_fused) is shared by both operator runs.
            shared_args = (
                head_num,
                kv_head_num,
                head_dim,
                seqlen,
                block_batch,
                max_block_per_seq,
                block_size,
                mode,
                hit_prefix_len,
                kvcache_dtype,
                has_zp,
                use_neox_rotary_style,
                rotary_embs_shape,
                num_speculative_tokens,
            )
            # An explicit None marks "fused run skipped" instead of probing locals().
            fused_result = None
            if not only_run_spliced:
                fused_result = run_block_attn(seed, True, *shared_args)
            spliced_result = run_block_attn(seed, False, *shared_args)

            if fused_result is None:
                print("fused_result not found.")
                print("skip comparison.")
                continue

            for k, fused_raw in fused_result.items():
                spliced_raw = spliced_result[k]
                if paddle.is_integer(fused_raw):
                    fused_v = fused_raw.astype("int32")
                    spliced_v = spliced_raw.astype("int32")
                    # NOTE(review): atol is fed from `rtol` (1e-3) here, which looks
                    # like a typo for `atol`; kept unchanged to preserve the check.
                    is_passed = np.allclose(fused_v.numpy(), spliced_v.numpy(), rtol=1e-2, atol=rtol)
                else:
                    fused_v = fused_raw.astype("float32")
                    spliced_v = spliced_raw.astype("float32")
                    is_passed = np.allclose(
                        fused_v.numpy(), spliced_v.numpy(), rtol=rtol, atol=atol, equal_nan=True
                    )

                if not is_passed:
                    print(f"{k} in mode {mode_name[mode]} check FAILED!")
                    print(f"fused {k}: {fused_v}")
                    print(f"spliced {k}: {spliced_v}")
                    print("not equal elements are listed below:")
                    print_all_not_equal_elements_info(k, fused_v, spliced_v)
                else:
                    print(f"{k} in mode {mode_name[mode]} check PASSED!")
                assert is_passed
            print("")
# Test driver: compares the fused and the spliced block attention operators
# over several KV cache dtypes, rotary styles and prefix-cache settings.

# Seed 0 is excluded on purpose: paddle.uniform(seed=0) draws from the global
# generator, so the fused and the spliced runs would no longer receive
# identical random inputs and the comparison would fail spuriously.
seed = random.randint(1, 2026)
paddle.seed(seed)
# Report the seed so that a failing run can be reproduced.
print(f"random seed: {seed}")

head_num = 64
kv_head_num = 8
head_dim = 128
rotary_embs_shape = [2, 1, 8192, 1, head_dim]
seq_len = 128
block_batch = 5
max_block_per_seq = 128
block_size = 64
# TODO: if hit_prefix_len has a small value, e.g. hit_prefix_len == 2, block_attn_out and prefix_cache_block_attn_out will have greater diff
hit_prefix_len = 71

# rope style, block_attn fused vs spliced:
# first without prefix cache, then with a prefix-cache hit; for each of them
# plain bfloat16 cache, c8 quantization, and c8 quantization with zero point.
use_neox_rotary_style = False
for case_hit_prefix_len in (0, hit_prefix_len):
    for case_kvcache_dtype, case_has_zp in (("bfloat16", False), ("int8", False), ("int8", True)):
        run_compare_block_attn(
            seed,
            head_num,
            kv_head_num,
            head_dim,
            seq_len,
            block_batch,
            max_block_per_seq,
            block_size,
            rotary_embs_shape,
            case_hit_prefix_len,
            kvcache_dtype=case_kvcache_dtype,
            has_zp=case_has_zp,
            use_neox_rotary_style=use_neox_rotary_style,
        )

# neox style, block_attn fused vs spliced:
# without prefix cache, then with a prefix-cache hit.
use_neox_rotary_style = True
only_run_spliced = False
for case_hit_prefix_len in (0, hit_prefix_len):
    run_compare_block_attn(
        seed,
        head_num,
        kv_head_num,
        head_dim,
        seq_len,
        block_batch,
        max_block_per_seq,
        block_size,
        rotary_embs_shape,
        case_hit_prefix_len,
        kvcache_dtype="bfloat16",
        has_zp=False,
        use_neox_rotary_style=use_neox_rotary_style,
        only_run_spliced=only_run_spliced,
    )

# neox glm 4.5 air debug
head_num = 24
kv_head_num = 2
head_dim = 128
seq_len = 128
block_batch = 64
max_block_per_seq = 2050
block_size = 64
rotary_embs_shape = [2, 1, 131072, 1, head_dim // 2]
use_neox_rotary_style = True
only_run_spliced = False
run_compare_block_attn(
    seed,
    head_num,
    kv_head_num,
    head_dim,
    seq_len,
    block_batch,
    max_block_per_seq,
    block_size,
    rotary_embs_shape,
    0,
    kvcache_dtype="bfloat16",
    has_zp=False,
    use_neox_rotary_style=use_neox_rotary_style,
    only_run_spliced=only_run_spliced,
)
print("\nALL PASSED!")