mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Split the block_attn operator into smaller operators (#6798)
* spliced block_attn * adapt to latest vllm * fix unit tests * delete mtp+cudagraph 4 cards test * fix vl model * fix mtp * fix slot mapping
This commit is contained in:
@@ -159,6 +159,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
if (use_neox_rotary_style) {
|
||||
pos_emb_type = "NEOX";
|
||||
} else if (rope_head_dim == head_dim / 2) {
|
||||
// vl model use this
|
||||
pos_emb_type = "HALF_HEAD_DIM";
|
||||
} else {
|
||||
pos_emb_type = "NORMAL";
|
||||
@@ -984,7 +985,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
return {block_attn_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> BlockAttn(
|
||||
std::vector<paddle::Tensor> BlockAttnFused(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
@@ -1008,6 +1009,8 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::Tensor& decoder_context_len_cache,
|
||||
const paddle::Tensor& decoder_batch_map,
|
||||
const paddle::Tensor& prefix_len,
|
||||
const paddle::Tensor& slot_mapping_enc,
|
||||
const paddle::Tensor& slot_mapping_dec,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
@@ -1067,7 +1070,7 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
} else if (cache_dtype == paddle::DataType::INT8) {
|
||||
APPLY_KERNEL(paddle::bfloat16, int8_t, paddle::bfloat16);
|
||||
} else {
|
||||
PD_THROW("block_attn not support cache_dtype==%d",
|
||||
PD_THROW("block_attn_fused not support cache_dtype==%d",
|
||||
static_cast<int>(cache_dtype));
|
||||
return {};
|
||||
}
|
||||
@@ -1097,7 +1100,7 @@ std::vector<paddle::DataType> BlockAttnInferDtype(
|
||||
return {qkv_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(block_attn)
|
||||
PD_BUILD_STATIC_OP(block_attn_fused)
|
||||
.Inputs({"qkv",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
@@ -1121,6 +1124,8 @@ PD_BUILD_STATIC_OP(block_attn)
|
||||
"decoder_context_len_cache",
|
||||
"decoder_batch_map",
|
||||
"prefix_len",
|
||||
"slot_mapping_enc",
|
||||
"slot_mapping_dec",
|
||||
paddle::Optional("k_scales"),
|
||||
paddle::Optional("v_scales"),
|
||||
paddle::Optional("k_scales_inv"),
|
||||
@@ -1135,6 +1140,6 @@ PD_BUILD_STATIC_OP(block_attn)
|
||||
paddle::Optional("cachekv_signal_thread_cpu")})
|
||||
.Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"})
|
||||
.Outputs({"block_attn_out"})
|
||||
.SetKernelFn(PD_KERNEL(BlockAttn))
|
||||
.SetKernelFn(PD_KERNEL(BlockAttnFused))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(BlockAttnInferDtype));
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,14 +16,60 @@
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/internal/infra_op.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "ops/utility/env.h"
|
||||
|
||||
XPU_DECLARE_BOOL(encoder_splice, false);
|
||||
XPU_DECLARE_BOOL(decoder_splice, false);
|
||||
|
||||
namespace api = baidu::xpu::api;
|
||||
|
||||
void lod_to_slot_mapping(api::Context* xpu_ctx,
|
||||
paddle::Place place,
|
||||
const std::vector<int32_t>& block_table,
|
||||
const std::vector<int32_t>& kv_seq_lod,
|
||||
const std::vector<int32_t>& start_tokens,
|
||||
const std::vector<int32_t>& real_batch,
|
||||
int32_t* slot_mapping,
|
||||
int32_t token_num,
|
||||
int32_t block_size,
|
||||
int32_t batch_size,
|
||||
int32_t max_num_blocks_per_seq,
|
||||
int32_t num_speculative_tokens) {
|
||||
if (token_num <= 0) {
|
||||
return;
|
||||
}
|
||||
std::vector<int32_t> slot_mapping_vec(token_num, -1);
|
||||
int32_t idx = 0;
|
||||
// For each Batch
|
||||
for (auto batch_ = 0; batch_ < batch_size; batch_++) {
|
||||
int32_t seq_len = kv_seq_lod[batch_ + 1] - kv_seq_lod[batch_];
|
||||
int32_t seq_start = start_tokens[batch_];
|
||||
int32_t dst_batch_id = real_batch[batch_];
|
||||
// for each token
|
||||
for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) {
|
||||
int32_t table_id = seq_ / block_size;
|
||||
int32_t block_id =
|
||||
block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
|
||||
int32_t seq_offset = seq_ % block_size;
|
||||
int32_t dst_token_offset = block_id * block_size + seq_offset;
|
||||
slot_mapping_vec[idx] = dst_token_offset;
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
int ret = api::do_host2device(xpu_ctx,
|
||||
slot_mapping_vec.data(),
|
||||
slot_mapping,
|
||||
token_num * sizeof(int32_t));
|
||||
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> GetInferParam(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& block_tables,
|
||||
int block_size) {
|
||||
int block_size,
|
||||
int num_speculative_tokens) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
@@ -109,6 +155,15 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
batch_offset++;
|
||||
}
|
||||
}
|
||||
// for vsl_rotary_embedding_gptj of cudagraph mode
|
||||
int prev_val = 0;
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
if (decoder_seq_lod_vec[i] > prev_val) {
|
||||
prev_val = decoder_seq_lod_vec[i];
|
||||
} else if (decoder_seq_lod_vec[i] < prev_val) {
|
||||
decoder_seq_lod_vec[i] = prev_val;
|
||||
}
|
||||
}
|
||||
int prefix_block_num_per_seq = (max_kv_len + block_size - 1) / block_size;
|
||||
std::vector<int32_t> prefix_block_tables_vec(
|
||||
enc_batch * prefix_block_num_per_seq, -1);
|
||||
@@ -167,6 +222,52 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
seq_lens_encoder.type(),
|
||||
seq_lens_encoder.place());
|
||||
|
||||
// for store_paged_kv_cache of cudagraph mode
|
||||
// if slot_mapping is -1, store_paged_kv_cache will not write to kv cache
|
||||
paddle::Tensor slot_mapping_enc = paddle::full(
|
||||
{total_enc_len}, -1, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
// TODO: mtp mode not verified yet, need further adaption
|
||||
paddle::Tensor slot_mapping_dec =
|
||||
paddle::full({bsz * (1 + num_speculative_tokens)},
|
||||
-1,
|
||||
paddle::DataType::INT32,
|
||||
seq_lens_decoder.place());
|
||||
if (FLAGS_encoder_splice || FLAGS_decoder_splice) {
|
||||
std::vector<int32_t> block_tables_vec(block_bs * block_num_per_seq);
|
||||
r = xpu_memcpy(block_tables_vec.data(),
|
||||
block_tables.data<int32_t>(),
|
||||
sizeof(int32_t) * block_bs * block_num_per_seq,
|
||||
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
|
||||
if (FLAGS_encoder_splice) {
|
||||
lod_to_slot_mapping(xpu_ctx->x_context(),
|
||||
seq_lens_encoder.place(),
|
||||
block_tables_vec,
|
||||
encoder_seq_lod_vec,
|
||||
prefix_len_vec,
|
||||
encoder_batch_map_vec,
|
||||
slot_mapping_enc.data<int32_t>(),
|
||||
total_enc_len,
|
||||
block_size,
|
||||
enc_batch,
|
||||
block_num_per_seq,
|
||||
0);
|
||||
}
|
||||
if (FLAGS_decoder_splice) {
|
||||
lod_to_slot_mapping(xpu_ctx->x_context(),
|
||||
seq_lens_decoder.place(),
|
||||
block_tables_vec,
|
||||
decoder_seq_lod_vec,
|
||||
decoder_context_len_cache_vec,
|
||||
decoder_batch_map_vec,
|
||||
slot_mapping_dec.data<int32_t>(),
|
||||
bsz * (1 + num_speculative_tokens),
|
||||
block_size,
|
||||
dec_batch,
|
||||
block_num_per_seq,
|
||||
num_speculative_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()},
|
||||
seq_lens_encoder.type(),
|
||||
paddle::CPUPlace());
|
||||
@@ -326,7 +427,9 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu};
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetInferParamInferShape(
|
||||
@@ -400,8 +503,10 @@ PD_BUILD_OP(get_infer_param)
|
||||
"prefix_len_cpu",
|
||||
"decoder_context_len_cpu",
|
||||
"decoder_context_len_cache_cpu",
|
||||
"len_info_cpu"})
|
||||
"len_info_cpu",
|
||||
"slot_mapping_enc",
|
||||
"slot_mapping_dec"})
|
||||
.SetKernelFn(PD_KERNEL(GetInferParam))
|
||||
.Attrs({"block_size: int"})
|
||||
.Attrs({"block_size: int", "num_speculative_tokens: int"})
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(GetInferParamInferDtype));
|
||||
|
||||
@@ -57,7 +57,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag);
|
||||
|
||||
std::vector<paddle::Tensor> BlockAttn(
|
||||
std::vector<paddle::Tensor> SplitEmbeddingKVCacheBlockAttn(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
@@ -81,6 +81,50 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::Tensor& decoder_context_len_cache_xpu,
|
||||
const paddle::Tensor& decoder_batch_map_xpu,
|
||||
const paddle::Tensor& prefix_len_xpu,
|
||||
const paddle::Tensor& slot_mapping_enc,
|
||||
const paddle::Tensor& slot_mapping_dec,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
const paddle::optional<paddle::Tensor>& v_scales_inv,
|
||||
const paddle::optional<paddle::Tensor>& k_zeros,
|
||||
const paddle::optional<paddle::Tensor>& v_zeros,
|
||||
const paddle::optional<paddle::Tensor>& shift,
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d = false);
|
||||
|
||||
// deprecated, keep for unit test, will be removed in the future
|
||||
std::vector<paddle::Tensor> BlockAttnFused(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& rotary_embs,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& prefix_block_tables,
|
||||
const paddle::Tensor& len_info_cpu,
|
||||
const paddle::Tensor& encoder_seq_lod_cpu,
|
||||
const paddle::Tensor& decoder_seq_lod_cpu,
|
||||
const paddle::Tensor& encoder_kv_lod_cpu,
|
||||
const paddle::Tensor& encoder_batch_map_cpu,
|
||||
const paddle::Tensor& decoder_context_len_cpu,
|
||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||
const paddle::Tensor& decoder_batch_map_cpu,
|
||||
const paddle::Tensor& prefix_len_cpu,
|
||||
const paddle::Tensor& encoder_seq_lod_xpu,
|
||||
const paddle::Tensor& decoder_seq_lod_xpu,
|
||||
const paddle::Tensor& encoder_kv_lod_xpu,
|
||||
const paddle::Tensor& encoder_batch_map_xpu,
|
||||
const paddle::Tensor& decoder_context_len_xpu,
|
||||
const paddle::Tensor& decoder_context_len_cache_xpu,
|
||||
const paddle::Tensor& decoder_batch_map_xpu,
|
||||
const paddle::Tensor& prefix_len_xpu,
|
||||
const paddle::Tensor& slot_mapping_enc,
|
||||
const paddle::Tensor& slot_mapping_dec,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
@@ -434,7 +478,8 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& block_tables,
|
||||
int block_size);
|
||||
int block_size,
|
||||
int num_speculative_tokens);
|
||||
|
||||
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
|
||||
|
||||
@@ -749,7 +794,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
"adjust batch in XPU");
|
||||
|
||||
m.def("block_attn",
|
||||
&BlockAttn,
|
||||
&SplitEmbeddingKVCacheBlockAttn,
|
||||
py::arg("qkv"),
|
||||
py::arg("key_cache"),
|
||||
py::arg("value_cache"),
|
||||
@@ -773,6 +818,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("decoder_context_len_cache_xpu"),
|
||||
py::arg("decoder_batch_map_xpu"),
|
||||
py::arg("prefix_len_xpu"),
|
||||
py::arg("slot_mapping_enc"),
|
||||
py::arg("slot_mapping_dec"),
|
||||
py::arg("k_scales"),
|
||||
py::arg("v_scales"),
|
||||
py::arg("k_scales_inv"),
|
||||
@@ -789,6 +836,49 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("rope_3d") = false,
|
||||
"block attention in XPU");
|
||||
|
||||
m.def("block_attn_fused",
|
||||
&BlockAttnFused,
|
||||
py::arg("qkv"),
|
||||
py::arg("key_cache"),
|
||||
py::arg("value_cache"),
|
||||
py::arg("rotary_embs"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("prefix_block_tables"),
|
||||
py::arg("len_info_cpu"),
|
||||
py::arg("encoder_seq_lod_cpu"),
|
||||
py::arg("decoder_seq_lod_cpu"),
|
||||
py::arg("encoder_kv_lod_cpu"),
|
||||
py::arg("encoder_batch_map_cpu"),
|
||||
py::arg("decoder_context_len_cpu"),
|
||||
py::arg("decoder_context_len_cache_cpu"),
|
||||
py::arg("decoder_batch_map_cpu"),
|
||||
py::arg("prefix_len_cpu"),
|
||||
py::arg("encoder_seq_lod_xpu"),
|
||||
py::arg("decoder_seq_lod_xpu"),
|
||||
py::arg("encoder_kv_lod_xpu"),
|
||||
py::arg("encoder_batch_map_xpu"),
|
||||
py::arg("decoder_context_len_xpu"),
|
||||
py::arg("decoder_context_len_cache_xpu"),
|
||||
py::arg("decoder_batch_map_xpu"),
|
||||
py::arg("prefix_len_xpu"),
|
||||
py::arg("slot_mapping_enc"),
|
||||
py::arg("slot_mapping_dec"),
|
||||
py::arg("k_scales"),
|
||||
py::arg("v_scales"),
|
||||
py::arg("k_scales_inv"),
|
||||
py::arg("v_scales_inv"),
|
||||
py::arg("k_zeros"),
|
||||
py::arg("v_zeros"),
|
||||
py::arg("shift"),
|
||||
py::arg("smooth"),
|
||||
py::arg("q_norm_weight"),
|
||||
py::arg("k_norm_weight"),
|
||||
py::arg("kv_signal_data_cpu"),
|
||||
py::arg("cachekv_signal_thread_cpu"),
|
||||
py::arg("use_neox_rotary_style"),
|
||||
py::arg("rope_3d") = false,
|
||||
"block attention fused in XPU");
|
||||
|
||||
m.def("create_kv_signal_sender",
|
||||
&create_cachekv_signal_thread,
|
||||
"init write cache kv signal thread");
|
||||
@@ -963,6 +1053,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("block_size"),
|
||||
py::arg("num_speculative_tokens"),
|
||||
"Get infer parameters for block attention in XPU");
|
||||
|
||||
m.def("get_peer_mem_addr",
|
||||
|
||||
@@ -0,0 +1,653 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
# block_attn_fused is deprecated and should be removed in the future
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
block_attn,
|
||||
block_attn_fused,
|
||||
get_infer_param,
|
||||
)
|
||||
|
||||
|
||||
def print_all_not_equal_elements_info(k, x, y):
|
||||
x_flatten = x.flatten()
|
||||
y_flatten = y.flatten()
|
||||
index = paddle.nonzero(x_flatten != y_flatten)
|
||||
x_not_equal = x_flatten[index]
|
||||
y_not_equal = y_flatten[index]
|
||||
print(f"reference not equal element of {k}: {x_not_equal}")
|
||||
print(f"calculated result not equal element of {k}: {y_not_equal}")
|
||||
xy_diff = x - y
|
||||
xy_mean_diff = paddle.mean(xy_diff)
|
||||
xy_max_abs_diff = paddle.max(paddle.abs(xy_diff))
|
||||
xy_min_abs_diff = paddle.min(paddle.abs(xy_diff))
|
||||
print(f"{k} mean diff: {xy_mean_diff}, max abs diff: {xy_max_abs_diff}, min abs diff: {xy_min_abs_diff}")
|
||||
|
||||
|
||||
def run_prefix_cache_block_attn(
|
||||
block_attn_func,
|
||||
qkv,
|
||||
seq_len,
|
||||
seq_lens_this_time,
|
||||
hit_prefix_len,
|
||||
key_cache,
|
||||
value_cache,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
attn_out,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
k_zp,
|
||||
v_zp,
|
||||
shift,
|
||||
smooth,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
kv_signal_data_cpu,
|
||||
cachekv_signal_thread_cpu,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
num_speculative_tokens,
|
||||
):
|
||||
if key_cache.dtype == paddle.int8:
|
||||
rtol = 1e-1
|
||||
atol = 1e-2
|
||||
else:
|
||||
rtol = 1e-2
|
||||
atol = 1e-3
|
||||
# prefix cache block attn
|
||||
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||
(
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
prefix_len,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
prefix_block_tables,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
|
||||
) # block_size
|
||||
qkv_prefix = qkv[hit_prefix_len:]
|
||||
attn_out_prefix_cache = block_attn_func(
|
||||
qkv_prefix,
|
||||
key_cache,
|
||||
value_cache,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
len_info_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
k_zp,
|
||||
v_zp,
|
||||
shift,
|
||||
smooth,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
kv_signal_data_cpu,
|
||||
cachekv_signal_thread_cpu,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
)
|
||||
attn_out_np = attn_out[hit_prefix_len:].astype("float32").numpy()
|
||||
attn_out_prefix_cache_np = attn_out_prefix_cache.astype("float32").numpy()
|
||||
is_passed = np.allclose(attn_out_np, attn_out_prefix_cache_np, rtol=rtol, atol=atol)
|
||||
if not is_passed:
|
||||
print(f"block_attn_func: {block_attn_func}")
|
||||
print("prefix_cache block_attn check failed!")
|
||||
print(f"origin block_attn_out: {attn_out[hit_prefix_len:]}")
|
||||
print(f"prefix_cache block_attn_out: {attn_out_prefix_cache}")
|
||||
print("not equal elements are listed below:")
|
||||
print_all_not_equal_elements_info("block_attn_out", attn_out[hit_prefix_len:], attn_out_prefix_cache)
|
||||
else:
|
||||
print(f"prefix_cache check of {block_attn_func} PASSED!")
|
||||
assert is_passed
|
||||
return attn_out_prefix_cache
|
||||
|
||||
|
||||
def run_block_attn(
|
||||
seed,
|
||||
is_fused,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
mode, # 1 for split kvcache encoder only, 2 for split kvcache decoder only, 3 for mixed
|
||||
hit_prefix_len,
|
||||
kvcache_dtype,
|
||||
has_zp,
|
||||
use_neox_rotary_style,
|
||||
rotary_embs_shape,
|
||||
num_speculative_tokens,
|
||||
):
|
||||
assert mode == 0 or mode == 1, "mixed mode not supported yet!"
|
||||
if mode == 0:
|
||||
encoder_seq_len = seq_len
|
||||
decoder_seq_len = 0
|
||||
elif mode == 1:
|
||||
encoder_seq_len = 0
|
||||
decoder_seq_len = seq_len
|
||||
else:
|
||||
pass
|
||||
seq_lens_encoder = paddle.to_tensor([encoder_seq_len, 0, 0, 0, 0], dtype="int32")
|
||||
seq_lens_decoder = paddle.to_tensor([decoder_seq_len, 0, 0, 0, 0], dtype="int32")
|
||||
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
|
||||
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
|
||||
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
|
||||
(
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
prefix_len,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
prefix_block_tables,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
|
||||
)
|
||||
qkv = paddle.uniform(
|
||||
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
|
||||
)
|
||||
|
||||
rotary_embs = paddle.uniform(shape=rotary_embs_shape, dtype="float32", min=-1.0, max=1.0, seed=seed)
|
||||
key_cache = paddle.zeros(
|
||||
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
|
||||
dtype=kvcache_dtype,
|
||||
)
|
||||
value_cache = paddle.zeros(
|
||||
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
|
||||
dtype=kvcache_dtype,
|
||||
)
|
||||
|
||||
scale_tensor_k = None
|
||||
scale_tensor_v = None
|
||||
k_quant_scale = None
|
||||
v_quant_scale = None
|
||||
k_dequant_scale = None
|
||||
v_dequant_scale = None
|
||||
k_zp = None
|
||||
v_zp = None
|
||||
if kvcache_dtype == "int8":
|
||||
scale_tensor_k = paddle.uniform(
|
||||
shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed
|
||||
) # max
|
||||
scale_tensor_v = paddle.uniform(
|
||||
shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed
|
||||
) # max
|
||||
k_quant_scale = 127.0 / scale_tensor_k # for C8 per channel means 127 / max
|
||||
v_quant_scale = 127.0 / scale_tensor_v # for C8 per channel means 127 / max
|
||||
if has_zp:
|
||||
k_dequant_scale = 1 / k_quant_scale # for C8 per channel zp means max
|
||||
v_dequant_scale = 1 / v_quant_scale # for C8 per channel zp means max
|
||||
k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
|
||||
v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
|
||||
else:
|
||||
k_dequant_scale = paddle.cast(scale_tensor_k, dtype="float32") # for C8 per channel means max
|
||||
v_dequant_scale = paddle.cast(scale_tensor_v, dtype="float32") # for C8 per channel means max
|
||||
# variable below are not yet used
|
||||
shift = None
|
||||
smooth = None
|
||||
q_norm_weight = None
|
||||
k_norm_weight = None
|
||||
kv_signal_data_cpu = None
|
||||
cachekv_signal_thread_cpu = None
|
||||
rope_3d = False
|
||||
|
||||
if is_fused:
|
||||
block_attn_func = block_attn_fused
|
||||
else:
|
||||
block_attn_func = block_attn
|
||||
attn_out = block_attn_func(
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
len_info_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
k_zp,
|
||||
v_zp,
|
||||
shift,
|
||||
smooth,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
kv_signal_data_cpu,
|
||||
cachekv_signal_thread_cpu,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
)
|
||||
result = {
|
||||
"block_attn_out": attn_out,
|
||||
"key_cache": key_cache,
|
||||
"value_cache": value_cache,
|
||||
}
|
||||
|
||||
# prefix cache
|
||||
if mode == 0 and hit_prefix_len > 0:
|
||||
assert hit_prefix_len < seq_len
|
||||
attn_out_prefix_cache = run_prefix_cache_block_attn(
|
||||
block_attn_func,
|
||||
qkv,
|
||||
seq_len,
|
||||
seq_lens_this_time,
|
||||
hit_prefix_len,
|
||||
key_cache,
|
||||
value_cache,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
attn_out,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
k_zp,
|
||||
v_zp,
|
||||
shift,
|
||||
smooth,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
kv_signal_data_cpu,
|
||||
cachekv_signal_thread_cpu,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
result["prefix_cache_block_attn_out"] = attn_out_prefix_cache
|
||||
return result
|
||||
|
||||
|
||||
def run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len=0,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=False,
|
||||
only_run_spliced=False,
|
||||
num_speculative_tokens=0,
|
||||
):
|
||||
rtol = 1e-3
|
||||
atol = 1e-2
|
||||
# 0 for prefill only, 1 for decode only, 2 for mixed
|
||||
# TODO: mixed mode not supported yet, get_infer_param should be modified first
|
||||
mode_name = ["prefill only", "decode only", "mixed"]
|
||||
|
||||
if use_neox_rotary_style:
|
||||
embedding_type = "neox"
|
||||
else:
|
||||
embedding_type = "rope"
|
||||
for mode in [0, 1]:
|
||||
if mode == 0:
|
||||
seq_len_list = [seq_len]
|
||||
elif mode == 1:
|
||||
# seq_len > 1 goes into mtp branch, which only supports seq_len <= 31
|
||||
# TODO: mtp mode need further adaption
|
||||
# seq_len_list = [1, random.randint(2, 31)]
|
||||
seq_len_list = [1]
|
||||
for idx, seqlen in enumerate(seq_len_list):
|
||||
if idx == 0:
|
||||
branch_name = "non mtp branch"
|
||||
elif idx == 1:
|
||||
branch_name = "mtp branch"
|
||||
print(
|
||||
f"runnning block attention of mode {mode_name[mode]} ({branch_name}), is_prefix_cache: {hit_prefix_len > 0}, kvcache type: {kvcache_dtype}, has_zp: {has_zp}, rotary_style: {embedding_type}"
|
||||
)
|
||||
if not only_run_spliced:
|
||||
fused_result = run_block_attn(
|
||||
seed,
|
||||
True, # is_fused
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seqlen,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
mode,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype,
|
||||
has_zp,
|
||||
use_neox_rotary_style,
|
||||
rotary_embs_shape,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
spliced_result = run_block_attn(
|
||||
seed,
|
||||
False, # is_fused
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seqlen,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
mode,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype,
|
||||
has_zp,
|
||||
use_neox_rotary_style,
|
||||
rotary_embs_shape,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
if "fused_result" in locals() and "spliced_result" in locals():
|
||||
for k in fused_result.keys():
|
||||
if paddle.is_integer(fused_result[k]):
|
||||
fused_v = fused_result[k].astype("int32")
|
||||
spliced_v = spliced_result[k].astype("int32")
|
||||
fused_v_np = fused_v.numpy()
|
||||
splice_v_np = spliced_v.numpy()
|
||||
# is_passed = np.allclose(fused_v_np, splice_v_np, rtol=1e-1, atol=1e-1)
|
||||
is_passed = np.allclose(fused_v_np, splice_v_np, rtol=1e-2, atol=rtol)
|
||||
else:
|
||||
fused_v = fused_result[k].astype("float32")
|
||||
spliced_v = spliced_result[k].astype("float32")
|
||||
fused_v_np = fused_v.numpy()
|
||||
splice_v_np = spliced_v.numpy()
|
||||
is_passed = np.allclose(fused_v_np, splice_v_np, rtol=rtol, atol=atol, equal_nan=True)
|
||||
if not is_passed:
|
||||
print(f"{k} in mode {mode_name[mode]} check FAILED!")
|
||||
print(f"fused {k}: {fused_v}")
|
||||
print(f"spliced {k}: {spliced_v}")
|
||||
print("not equal elements are listed below:")
|
||||
print_all_not_equal_elements_info(k, fused_v, spliced_v)
|
||||
else:
|
||||
print(f"{k} in mode {mode_name[mode]} check PASSED!")
|
||||
assert is_passed
|
||||
print("")
|
||||
else:
|
||||
if "fused_result" not in locals():
|
||||
print("fused_result not found.")
|
||||
if "spliced_result" not in locals():
|
||||
print("spliced_result not found.")
|
||||
print("skip comparison.")
|
||||
|
||||
|
||||
seed = random.randint(0, 2026)
|
||||
paddle.seed(seed)
|
||||
head_num = 64
|
||||
kv_head_num = 8
|
||||
head_dim = 128
|
||||
rotary_embs_shape = [2, 1, 8192, 1, head_dim]
|
||||
seq_len = 128
|
||||
block_batch = 5
|
||||
max_block_per_seq = 128
|
||||
block_size = 64
|
||||
# TODO: if hit_prefix_len has a small value, e.g. hit_prefix_len == 2, block_attn_out and prefix_cache_block_attn_out will have greater diff
|
||||
hit_prefix_len = 71
|
||||
|
||||
# no prefix cache
|
||||
# block_attn fused vs spliced
|
||||
use_neox_rotary_style = False
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
# c8 quantization block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="int8",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
# c8 zp quantization block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="int8",
|
||||
has_zp=True,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
|
||||
# prefix cache
|
||||
# block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
# c8 quantization block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype="int8",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
# c8 zp quantization block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype="int8",
|
||||
has_zp=True,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
|
||||
# # neox
|
||||
# # block_attn fused vs spliced
|
||||
# # no prefix cache
|
||||
use_neox_rotary_style = True
|
||||
only_run_spliced = False
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
only_run_spliced=only_run_spliced,
|
||||
)
|
||||
# prefix cache
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
only_run_spliced=only_run_spliced,
|
||||
)
|
||||
|
||||
# neox glm 4.5 air debug
|
||||
head_num = 24
|
||||
kv_head_num = 2
|
||||
head_dim = 128
|
||||
seq_len = 128
|
||||
block_batch = 64
|
||||
max_block_per_seq = 2050
|
||||
block_size = 64
|
||||
rotary_embs_shape = [2, 1, 131072, 1, head_dim // 2]
|
||||
|
||||
use_neox_rotary_style = True
|
||||
only_run_spliced = False
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
only_run_spliced=only_run_spliced,
|
||||
)
|
||||
|
||||
print("\nALL PASSED!")
|
||||
@@ -15,7 +15,7 @@
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.xpu import block_attn, get_infer_param
|
||||
from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param
|
||||
|
||||
head_num = 64
|
||||
kv_head_num = 8
|
||||
@@ -53,8 +53,10 @@ block_tables = block_tables.reshape((block_batch, max_block_per_seq))
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
|
||||
) # block_size
|
||||
|
||||
qkv = paddle.uniform(
|
||||
@@ -64,7 +66,6 @@ qkv = paddle.uniform(
|
||||
max=1.0,
|
||||
)
|
||||
|
||||
cum_offsets = paddle.zeros(shape=[block_batch], dtype="bfloat16")
|
||||
rotary_embs = paddle.uniform(shape=[2, 1, 8192, 1, head_dim], dtype="float32", min=-1.0, max=1.0)
|
||||
key_cache = paddle.zeros(
|
||||
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
|
||||
@@ -94,11 +95,10 @@ v_dequant_scale_zp = 1 / v_quant_scale # for C8 per channel zp means max
|
||||
|
||||
k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
|
||||
v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
|
||||
attn_out = block_attn(
|
||||
attn_out = block_attn_fused(
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -111,6 +111,16 @@ attn_out = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -121,12 +131,15 @@ attn_out = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
attn_out_C8 = block_attn(
|
||||
attn_out_C8 = block_attn_fused(
|
||||
qkv,
|
||||
key_cache_int8,
|
||||
value_cache_int8,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -139,6 +152,16 @@ attn_out_C8 = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
@@ -149,12 +172,15 @@ attn_out_C8 = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
attn_out_C8_zp = block_attn(
|
||||
attn_out_C8_zp = block_attn_fused(
|
||||
qkv,
|
||||
key_cache_int8,
|
||||
value_cache_int8,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -167,6 +193,16 @@ attn_out_C8_zp = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale_zp,
|
||||
@@ -177,6 +213,10 @@ attn_out_C8_zp = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
# prefix cache : hit 71 tokens
|
||||
@@ -207,16 +247,17 @@ seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
|
||||
) # block_size
|
||||
qkv_prefix = qkv[hit_prefix_len:]
|
||||
|
||||
attn_out_prefix_cache = block_attn(
|
||||
attn_out_prefix_cache = block_attn_fused(
|
||||
qkv_prefix,
|
||||
key_cache,
|
||||
value_cache,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -229,6 +270,16 @@ attn_out_prefix_cache = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -239,13 +290,16 @@ attn_out_prefix_cache = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
attn_out_C8_prefix_cache = block_attn(
|
||||
attn_out_C8_prefix_cache = block_attn_fused(
|
||||
qkv_prefix,
|
||||
key_cache_int8,
|
||||
value_cache_int8,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -258,6 +312,16 @@ attn_out_C8_prefix_cache = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
@@ -268,13 +332,16 @@ attn_out_C8_prefix_cache = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
attn_out_C8_zp_prefix_cache = block_attn(
|
||||
attn_out_C8_zp_prefix_cache = block_attn_fused(
|
||||
qkv_prefix,
|
||||
key_cache_int8,
|
||||
value_cache_int8,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -287,6 +354,16 @@ attn_out_C8_zp_prefix_cache = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale_zp,
|
||||
@@ -297,6 +374,10 @@ attn_out_C8_zp_prefix_cache = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
print("-- C16 prefix cache test --")
|
||||
print("attn_out[hit_prefix_len:]'s mean:", attn_out[hit_prefix_len:].mean().item())
|
||||
|
||||
@@ -278,6 +278,11 @@ class XPUForwardMeta(ForwardMeta):
|
||||
# max bs
|
||||
max_num_seqs: int = 0
|
||||
|
||||
# for spliced block_attn
|
||||
slot_mapping_enc: Optional[paddle.Tensor] = None
|
||||
#
|
||||
slot_mapping_dec: Optional[paddle.Tensor] = None
|
||||
|
||||
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
|
||||
"""
|
||||
Synchronize attributes from another XPUForwardMeta object
|
||||
|
||||
@@ -214,6 +214,8 @@ class XPUAttentionBackend(AttentionBackend):
|
||||
forward_meta.decoder_context_len_cache,
|
||||
forward_meta.decoder_batch_map,
|
||||
forward_meta.prefix_len,
|
||||
forward_meta.slot_mapping_enc,
|
||||
forward_meta.slot_mapping_dec,
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_out_scale,
|
||||
|
||||
@@ -127,6 +127,7 @@ def xpu_pre_process(
|
||||
is_profiling: bool = False,
|
||||
forward_meta=None,
|
||||
use_cudagraph=False,
|
||||
num_speculative_tokens=0,
|
||||
) -> XPUForwardMeta:
|
||||
""" """
|
||||
|
||||
@@ -196,8 +197,15 @@ def xpu_pre_process(
|
||||
xpu_forward_meta.decoder_context_len_cpu,
|
||||
xpu_forward_meta.decoder_context_len_cache_cpu,
|
||||
xpu_forward_meta.len_info_cpu,
|
||||
xpu_forward_meta.slot_mapping_enc,
|
||||
xpu_forward_meta.slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
xpu_forward_meta.block_tables,
|
||||
block_size,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
|
||||
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
|
||||
|
||||
@@ -1078,6 +1078,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["draft_tokens"],
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
num_speculative_tokens=self.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
if self.enable_mm:
|
||||
|
||||
@@ -1129,6 +1129,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
is_profiling=is_dummy_run,
|
||||
forward_meta=self.forward_meta,
|
||||
use_cudagraph=self.use_cudagraph,
|
||||
num_speculative_tokens=self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
|
||||
)
|
||||
|
||||
if self.use_cudagraph:
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
MTP模式测试 - ERNIE-4.5-21B-A3B-Paddle 模型
|
||||
|
||||
测试配置:
|
||||
- 模型: ERNIE-4.5-21B-A3B-Paddle
|
||||
- 量化: wint4
|
||||
- Tensor Parallel: 4
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from conftest import get_model_path, get_port_num, print_logs_on_failure, start_server
|
||||
|
||||
|
||||
def test_mtp_mode(xpu_env):
|
||||
"""mtp模式测试"""
|
||||
|
||||
print("\n============================开始mtp + CudaGraph 模式测试!============================")
|
||||
|
||||
# 获取配置
|
||||
port_num = get_port_num()
|
||||
model_path = get_model_path()
|
||||
spec_config = {"method": "mtp", "num_speculative_tokens": 1, "model": f"{model_path}/ERNIE-4.5-21B-A3B-Paddle/mtp"}
|
||||
# 构建服务器启动参数
|
||||
server_args = [
|
||||
"--model",
|
||||
f"{model_path}/ERNIE-4.5-21B-A3B-Paddle",
|
||||
"--port",
|
||||
str(port_num),
|
||||
"--engine-worker-queue-port",
|
||||
str(port_num + 1),
|
||||
"--metrics-port",
|
||||
str(port_num + 2),
|
||||
"--tensor-parallel-size",
|
||||
"4",
|
||||
"--num-gpu-blocks-override",
|
||||
"16384",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--quantization",
|
||||
"wint4",
|
||||
"--speculative-config",
|
||||
f"{json.dumps(spec_config)}",
|
||||
"--graph-optimization-config",
|
||||
'{"use_cudagraph":true}',
|
||||
]
|
||||
|
||||
# 启动服务器
|
||||
if not start_server(server_args):
|
||||
pytest.fail("mtp模式服务启动失败")
|
||||
|
||||
# 执行测试
|
||||
try:
|
||||
ip = "0.0.0.0"
|
||||
client = openai.Client(base_url=f"http://{ip}:{port_num}/v1", api_key="EMPTY_API_KEY")
|
||||
|
||||
# 非流式对话
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "user", "content": "你好,你是谁?"},
|
||||
],
|
||||
temperature=1,
|
||||
top_p=0,
|
||||
max_tokens=64,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
print(f"\n模型回复: {response.choices[0].message.content}")
|
||||
|
||||
# 验证响应
|
||||
assert any(
|
||||
keyword in response.choices[0].message.content for keyword in ["人工智能", "文心一言", "百度", "智能助手"]
|
||||
), f"响应内容不符合预期: {response.choices[0].message.content}"
|
||||
|
||||
print("\nmtp + CudaGraph模式测试通过!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nmtp + CudaGraph模式测试失败: {str(e)}")
|
||||
print_logs_on_failure()
|
||||
pytest.fail(f"mtp + CudaGraph模式测试失败: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
Reference in New Issue
Block a user