mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Split the block_attn operator into smaller operators (#6798)
* spliced block_attn * adapt to latest vllm * fix unit tests * delete mtp+cudagraph 4 cards test * fix vl model * fix mtp * fix slot mapping
This commit is contained in:
@@ -159,6 +159,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
if (use_neox_rotary_style) {
|
||||
pos_emb_type = "NEOX";
|
||||
} else if (rope_head_dim == head_dim / 2) {
|
||||
// vl model use this
|
||||
pos_emb_type = "HALF_HEAD_DIM";
|
||||
} else {
|
||||
pos_emb_type = "NORMAL";
|
||||
@@ -984,7 +985,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
return {block_attn_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> BlockAttn(
|
||||
std::vector<paddle::Tensor> BlockAttnFused(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
@@ -1008,6 +1009,8 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::Tensor& decoder_context_len_cache,
|
||||
const paddle::Tensor& decoder_batch_map,
|
||||
const paddle::Tensor& prefix_len,
|
||||
const paddle::Tensor& slot_mapping_enc,
|
||||
const paddle::Tensor& slot_mapping_dec,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
@@ -1067,7 +1070,7 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
} else if (cache_dtype == paddle::DataType::INT8) {
|
||||
APPLY_KERNEL(paddle::bfloat16, int8_t, paddle::bfloat16);
|
||||
} else {
|
||||
PD_THROW("block_attn not support cache_dtype==%d",
|
||||
PD_THROW("block_attn_fused not support cache_dtype==%d",
|
||||
static_cast<int>(cache_dtype));
|
||||
return {};
|
||||
}
|
||||
@@ -1097,7 +1100,7 @@ std::vector<paddle::DataType> BlockAttnInferDtype(
|
||||
return {qkv_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(block_attn)
|
||||
PD_BUILD_STATIC_OP(block_attn_fused)
|
||||
.Inputs({"qkv",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
@@ -1121,6 +1124,8 @@ PD_BUILD_STATIC_OP(block_attn)
|
||||
"decoder_context_len_cache",
|
||||
"decoder_batch_map",
|
||||
"prefix_len",
|
||||
"slot_mapping_enc",
|
||||
"slot_mapping_dec",
|
||||
paddle::Optional("k_scales"),
|
||||
paddle::Optional("v_scales"),
|
||||
paddle::Optional("k_scales_inv"),
|
||||
@@ -1135,6 +1140,6 @@ PD_BUILD_STATIC_OP(block_attn)
|
||||
paddle::Optional("cachekv_signal_thread_cpu")})
|
||||
.Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"})
|
||||
.Outputs({"block_attn_out"})
|
||||
.SetKernelFn(PD_KERNEL(BlockAttn))
|
||||
.SetKernelFn(PD_KERNEL(BlockAttnFused))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(BlockAttnInferDtype));
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,14 +16,60 @@
|
||||
#include "paddle/extension.h"
|
||||
#include "xpu/internal/infra_op.h"
|
||||
#include "xpu/plugin.h"
|
||||
#include "ops/utility/env.h"
|
||||
|
||||
XPU_DECLARE_BOOL(encoder_splice, false);
|
||||
XPU_DECLARE_BOOL(decoder_splice, false);
|
||||
|
||||
namespace api = baidu::xpu::api;
|
||||
|
||||
void lod_to_slot_mapping(api::Context* xpu_ctx,
|
||||
paddle::Place place,
|
||||
const std::vector<int32_t>& block_table,
|
||||
const std::vector<int32_t>& kv_seq_lod,
|
||||
const std::vector<int32_t>& start_tokens,
|
||||
const std::vector<int32_t>& real_batch,
|
||||
int32_t* slot_mapping,
|
||||
int32_t token_num,
|
||||
int32_t block_size,
|
||||
int32_t batch_size,
|
||||
int32_t max_num_blocks_per_seq,
|
||||
int32_t num_speculative_tokens) {
|
||||
if (token_num <= 0) {
|
||||
return;
|
||||
}
|
||||
std::vector<int32_t> slot_mapping_vec(token_num, -1);
|
||||
int32_t idx = 0;
|
||||
// For each Batch
|
||||
for (auto batch_ = 0; batch_ < batch_size; batch_++) {
|
||||
int32_t seq_len = kv_seq_lod[batch_ + 1] - kv_seq_lod[batch_];
|
||||
int32_t seq_start = start_tokens[batch_];
|
||||
int32_t dst_batch_id = real_batch[batch_];
|
||||
// for each token
|
||||
for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) {
|
||||
int32_t table_id = seq_ / block_size;
|
||||
int32_t block_id =
|
||||
block_table[dst_batch_id * max_num_blocks_per_seq + table_id];
|
||||
int32_t seq_offset = seq_ % block_size;
|
||||
int32_t dst_token_offset = block_id * block_size + seq_offset;
|
||||
slot_mapping_vec[idx] = dst_token_offset;
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
int ret = api::do_host2device(xpu_ctx,
|
||||
slot_mapping_vec.data(),
|
||||
slot_mapping,
|
||||
token_num * sizeof(int32_t));
|
||||
PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed.");
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> GetInferParam(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& block_tables,
|
||||
int block_size) {
|
||||
int block_size,
|
||||
int num_speculative_tokens) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
@@ -109,6 +155,15 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
batch_offset++;
|
||||
}
|
||||
}
|
||||
// for vsl_rotary_embedding_gptj of cudagraph mode
|
||||
int prev_val = 0;
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
if (decoder_seq_lod_vec[i] > prev_val) {
|
||||
prev_val = decoder_seq_lod_vec[i];
|
||||
} else if (decoder_seq_lod_vec[i] < prev_val) {
|
||||
decoder_seq_lod_vec[i] = prev_val;
|
||||
}
|
||||
}
|
||||
int prefix_block_num_per_seq = (max_kv_len + block_size - 1) / block_size;
|
||||
std::vector<int32_t> prefix_block_tables_vec(
|
||||
enc_batch * prefix_block_num_per_seq, -1);
|
||||
@@ -167,6 +222,52 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
seq_lens_encoder.type(),
|
||||
seq_lens_encoder.place());
|
||||
|
||||
// for store_paged_kv_cache of cudagraph mode
|
||||
// if slot_mapping is -1, store_paged_kv_cache will not write to kv cache
|
||||
paddle::Tensor slot_mapping_enc = paddle::full(
|
||||
{total_enc_len}, -1, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
// TODO: mtp mode not verified yet, need further adaption
|
||||
paddle::Tensor slot_mapping_dec =
|
||||
paddle::full({bsz * (1 + num_speculative_tokens)},
|
||||
-1,
|
||||
paddle::DataType::INT32,
|
||||
seq_lens_decoder.place());
|
||||
if (FLAGS_encoder_splice || FLAGS_decoder_splice) {
|
||||
std::vector<int32_t> block_tables_vec(block_bs * block_num_per_seq);
|
||||
r = xpu_memcpy(block_tables_vec.data(),
|
||||
block_tables.data<int32_t>(),
|
||||
sizeof(int32_t) * block_bs * block_num_per_seq,
|
||||
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
|
||||
if (FLAGS_encoder_splice) {
|
||||
lod_to_slot_mapping(xpu_ctx->x_context(),
|
||||
seq_lens_encoder.place(),
|
||||
block_tables_vec,
|
||||
encoder_seq_lod_vec,
|
||||
prefix_len_vec,
|
||||
encoder_batch_map_vec,
|
||||
slot_mapping_enc.data<int32_t>(),
|
||||
total_enc_len,
|
||||
block_size,
|
||||
enc_batch,
|
||||
block_num_per_seq,
|
||||
0);
|
||||
}
|
||||
if (FLAGS_decoder_splice) {
|
||||
lod_to_slot_mapping(xpu_ctx->x_context(),
|
||||
seq_lens_decoder.place(),
|
||||
block_tables_vec,
|
||||
decoder_seq_lod_vec,
|
||||
decoder_context_len_cache_vec,
|
||||
decoder_batch_map_vec,
|
||||
slot_mapping_dec.data<int32_t>(),
|
||||
bsz * (1 + num_speculative_tokens),
|
||||
block_size,
|
||||
dec_batch,
|
||||
block_num_per_seq,
|
||||
num_speculative_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()},
|
||||
seq_lens_encoder.type(),
|
||||
paddle::CPUPlace());
|
||||
@@ -326,7 +427,9 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu};
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetInferParamInferShape(
|
||||
@@ -400,8 +503,10 @@ PD_BUILD_OP(get_infer_param)
|
||||
"prefix_len_cpu",
|
||||
"decoder_context_len_cpu",
|
||||
"decoder_context_len_cache_cpu",
|
||||
"len_info_cpu"})
|
||||
"len_info_cpu",
|
||||
"slot_mapping_enc",
|
||||
"slot_mapping_dec"})
|
||||
.SetKernelFn(PD_KERNEL(GetInferParam))
|
||||
.Attrs({"block_size: int"})
|
||||
.Attrs({"block_size: int", "num_speculative_tokens: int"})
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(GetInferParamInferDtype));
|
||||
|
||||
@@ -57,7 +57,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag);
|
||||
|
||||
std::vector<paddle::Tensor> BlockAttn(
|
||||
std::vector<paddle::Tensor> SplitEmbeddingKVCacheBlockAttn(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
@@ -81,6 +81,50 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::Tensor& decoder_context_len_cache_xpu,
|
||||
const paddle::Tensor& decoder_batch_map_xpu,
|
||||
const paddle::Tensor& prefix_len_xpu,
|
||||
const paddle::Tensor& slot_mapping_enc,
|
||||
const paddle::Tensor& slot_mapping_dec,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
const paddle::optional<paddle::Tensor>& v_scales_inv,
|
||||
const paddle::optional<paddle::Tensor>& k_zeros,
|
||||
const paddle::optional<paddle::Tensor>& v_zeros,
|
||||
const paddle::optional<paddle::Tensor>& shift,
|
||||
const paddle::optional<paddle::Tensor>& smooth,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
|
||||
const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d = false);
|
||||
|
||||
// deprecated, keep for unit test, will be removed in the future
|
||||
std::vector<paddle::Tensor> BlockAttnFused(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& rotary_embs,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& prefix_block_tables,
|
||||
const paddle::Tensor& len_info_cpu,
|
||||
const paddle::Tensor& encoder_seq_lod_cpu,
|
||||
const paddle::Tensor& decoder_seq_lod_cpu,
|
||||
const paddle::Tensor& encoder_kv_lod_cpu,
|
||||
const paddle::Tensor& encoder_batch_map_cpu,
|
||||
const paddle::Tensor& decoder_context_len_cpu,
|
||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||
const paddle::Tensor& decoder_batch_map_cpu,
|
||||
const paddle::Tensor& prefix_len_cpu,
|
||||
const paddle::Tensor& encoder_seq_lod_xpu,
|
||||
const paddle::Tensor& decoder_seq_lod_xpu,
|
||||
const paddle::Tensor& encoder_kv_lod_xpu,
|
||||
const paddle::Tensor& encoder_batch_map_xpu,
|
||||
const paddle::Tensor& decoder_context_len_xpu,
|
||||
const paddle::Tensor& decoder_context_len_cache_xpu,
|
||||
const paddle::Tensor& decoder_batch_map_xpu,
|
||||
const paddle::Tensor& prefix_len_xpu,
|
||||
const paddle::Tensor& slot_mapping_enc,
|
||||
const paddle::Tensor& slot_mapping_dec,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
@@ -434,7 +478,8 @@ std::vector<paddle::Tensor> GetInferParam(
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& block_tables,
|
||||
int block_size);
|
||||
int block_size,
|
||||
int num_speculative_tokens);
|
||||
|
||||
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag);
|
||||
|
||||
@@ -749,7 +794,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
"adjust batch in XPU");
|
||||
|
||||
m.def("block_attn",
|
||||
&BlockAttn,
|
||||
&SplitEmbeddingKVCacheBlockAttn,
|
||||
py::arg("qkv"),
|
||||
py::arg("key_cache"),
|
||||
py::arg("value_cache"),
|
||||
@@ -773,6 +818,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("decoder_context_len_cache_xpu"),
|
||||
py::arg("decoder_batch_map_xpu"),
|
||||
py::arg("prefix_len_xpu"),
|
||||
py::arg("slot_mapping_enc"),
|
||||
py::arg("slot_mapping_dec"),
|
||||
py::arg("k_scales"),
|
||||
py::arg("v_scales"),
|
||||
py::arg("k_scales_inv"),
|
||||
@@ -789,6 +836,49 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("rope_3d") = false,
|
||||
"block attention in XPU");
|
||||
|
||||
m.def("block_attn_fused",
|
||||
&BlockAttnFused,
|
||||
py::arg("qkv"),
|
||||
py::arg("key_cache"),
|
||||
py::arg("value_cache"),
|
||||
py::arg("rotary_embs"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("prefix_block_tables"),
|
||||
py::arg("len_info_cpu"),
|
||||
py::arg("encoder_seq_lod_cpu"),
|
||||
py::arg("decoder_seq_lod_cpu"),
|
||||
py::arg("encoder_kv_lod_cpu"),
|
||||
py::arg("encoder_batch_map_cpu"),
|
||||
py::arg("decoder_context_len_cpu"),
|
||||
py::arg("decoder_context_len_cache_cpu"),
|
||||
py::arg("decoder_batch_map_cpu"),
|
||||
py::arg("prefix_len_cpu"),
|
||||
py::arg("encoder_seq_lod_xpu"),
|
||||
py::arg("decoder_seq_lod_xpu"),
|
||||
py::arg("encoder_kv_lod_xpu"),
|
||||
py::arg("encoder_batch_map_xpu"),
|
||||
py::arg("decoder_context_len_xpu"),
|
||||
py::arg("decoder_context_len_cache_xpu"),
|
||||
py::arg("decoder_batch_map_xpu"),
|
||||
py::arg("prefix_len_xpu"),
|
||||
py::arg("slot_mapping_enc"),
|
||||
py::arg("slot_mapping_dec"),
|
||||
py::arg("k_scales"),
|
||||
py::arg("v_scales"),
|
||||
py::arg("k_scales_inv"),
|
||||
py::arg("v_scales_inv"),
|
||||
py::arg("k_zeros"),
|
||||
py::arg("v_zeros"),
|
||||
py::arg("shift"),
|
||||
py::arg("smooth"),
|
||||
py::arg("q_norm_weight"),
|
||||
py::arg("k_norm_weight"),
|
||||
py::arg("kv_signal_data_cpu"),
|
||||
py::arg("cachekv_signal_thread_cpu"),
|
||||
py::arg("use_neox_rotary_style"),
|
||||
py::arg("rope_3d") = false,
|
||||
"block attention fused in XPU");
|
||||
|
||||
m.def("create_kv_signal_sender",
|
||||
&create_cachekv_signal_thread,
|
||||
"init write cache kv signal thread");
|
||||
@@ -963,6 +1053,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("seq_lens_this_time"),
|
||||
py::arg("block_tables"),
|
||||
py::arg("block_size"),
|
||||
py::arg("num_speculative_tokens"),
|
||||
"Get infer parameters for block attention in XPU");
|
||||
|
||||
m.def("get_peer_mem_addr",
|
||||
|
||||
@@ -0,0 +1,653 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
# block_attn_fused is deprecated and should be removed in the future
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
block_attn,
|
||||
block_attn_fused,
|
||||
get_infer_param,
|
||||
)
|
||||
|
||||
|
||||
def print_all_not_equal_elements_info(k, x, y):
|
||||
x_flatten = x.flatten()
|
||||
y_flatten = y.flatten()
|
||||
index = paddle.nonzero(x_flatten != y_flatten)
|
||||
x_not_equal = x_flatten[index]
|
||||
y_not_equal = y_flatten[index]
|
||||
print(f"reference not equal element of {k}: {x_not_equal}")
|
||||
print(f"calculated result not equal element of {k}: {y_not_equal}")
|
||||
xy_diff = x - y
|
||||
xy_mean_diff = paddle.mean(xy_diff)
|
||||
xy_max_abs_diff = paddle.max(paddle.abs(xy_diff))
|
||||
xy_min_abs_diff = paddle.min(paddle.abs(xy_diff))
|
||||
print(f"{k} mean diff: {xy_mean_diff}, max abs diff: {xy_max_abs_diff}, min abs diff: {xy_min_abs_diff}")
|
||||
|
||||
|
||||
def run_prefix_cache_block_attn(
|
||||
block_attn_func,
|
||||
qkv,
|
||||
seq_len,
|
||||
seq_lens_this_time,
|
||||
hit_prefix_len,
|
||||
key_cache,
|
||||
value_cache,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
attn_out,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
k_zp,
|
||||
v_zp,
|
||||
shift,
|
||||
smooth,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
kv_signal_data_cpu,
|
||||
cachekv_signal_thread_cpu,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
num_speculative_tokens,
|
||||
):
|
||||
if key_cache.dtype == paddle.int8:
|
||||
rtol = 1e-1
|
||||
atol = 1e-2
|
||||
else:
|
||||
rtol = 1e-2
|
||||
atol = 1e-3
|
||||
# prefix cache block attn
|
||||
seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||
seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||
(
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
prefix_len,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
prefix_block_tables,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
|
||||
) # block_size
|
||||
qkv_prefix = qkv[hit_prefix_len:]
|
||||
attn_out_prefix_cache = block_attn_func(
|
||||
qkv_prefix,
|
||||
key_cache,
|
||||
value_cache,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
len_info_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
k_zp,
|
||||
v_zp,
|
||||
shift,
|
||||
smooth,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
kv_signal_data_cpu,
|
||||
cachekv_signal_thread_cpu,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
)
|
||||
attn_out_np = attn_out[hit_prefix_len:].astype("float32").numpy()
|
||||
attn_out_prefix_cache_np = attn_out_prefix_cache.astype("float32").numpy()
|
||||
is_passed = np.allclose(attn_out_np, attn_out_prefix_cache_np, rtol=rtol, atol=atol)
|
||||
if not is_passed:
|
||||
print(f"block_attn_func: {block_attn_func}")
|
||||
print("prefix_cache block_attn check failed!")
|
||||
print(f"origin block_attn_out: {attn_out[hit_prefix_len:]}")
|
||||
print(f"prefix_cache block_attn_out: {attn_out_prefix_cache}")
|
||||
print("not equal elements are listed below:")
|
||||
print_all_not_equal_elements_info("block_attn_out", attn_out[hit_prefix_len:], attn_out_prefix_cache)
|
||||
else:
|
||||
print(f"prefix_cache check of {block_attn_func} PASSED!")
|
||||
assert is_passed
|
||||
return attn_out_prefix_cache
|
||||
|
||||
|
||||
def run_block_attn(
|
||||
seed,
|
||||
is_fused,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
mode, # 1 for split kvcache encoder only, 2 for split kvcache decoder only, 3 for mixed
|
||||
hit_prefix_len,
|
||||
kvcache_dtype,
|
||||
has_zp,
|
||||
use_neox_rotary_style,
|
||||
rotary_embs_shape,
|
||||
num_speculative_tokens,
|
||||
):
|
||||
assert mode == 0 or mode == 1, "mixed mode not supported yet!"
|
||||
if mode == 0:
|
||||
encoder_seq_len = seq_len
|
||||
decoder_seq_len = 0
|
||||
elif mode == 1:
|
||||
encoder_seq_len = 0
|
||||
decoder_seq_len = seq_len
|
||||
else:
|
||||
pass
|
||||
seq_lens_encoder = paddle.to_tensor([encoder_seq_len, 0, 0, 0, 0], dtype="int32")
|
||||
seq_lens_decoder = paddle.to_tensor([decoder_seq_len, 0, 0, 0, 0], dtype="int32")
|
||||
seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32")
|
||||
block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32")
|
||||
block_tables = block_tables.reshape((block_batch, max_block_per_seq))
|
||||
(
|
||||
encoder_batch_map,
|
||||
decoder_batch_map,
|
||||
encoder_batch_idx,
|
||||
decoder_batch_idx,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
prefix_len,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
prefix_block_tables,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
encoder_batch_idx_cpu,
|
||||
decoder_batch_idx_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
prefix_len_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens
|
||||
)
|
||||
qkv = paddle.uniform(
|
||||
shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed
|
||||
)
|
||||
|
||||
rotary_embs = paddle.uniform(shape=rotary_embs_shape, dtype="float32", min=-1.0, max=1.0, seed=seed)
|
||||
key_cache = paddle.zeros(
|
||||
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
|
||||
dtype=kvcache_dtype,
|
||||
)
|
||||
value_cache = paddle.zeros(
|
||||
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
|
||||
dtype=kvcache_dtype,
|
||||
)
|
||||
|
||||
scale_tensor_k = None
|
||||
scale_tensor_v = None
|
||||
k_quant_scale = None
|
||||
v_quant_scale = None
|
||||
k_dequant_scale = None
|
||||
v_dequant_scale = None
|
||||
k_zp = None
|
||||
v_zp = None
|
||||
if kvcache_dtype == "int8":
|
||||
scale_tensor_k = paddle.uniform(
|
||||
shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed
|
||||
) # max
|
||||
scale_tensor_v = paddle.uniform(
|
||||
shape=[kv_head_num * head_dim], dtype="bfloat16", min=1.0, max=1.0, seed=seed
|
||||
) # max
|
||||
k_quant_scale = 127.0 / scale_tensor_k # for C8 per channel means 127 / max
|
||||
v_quant_scale = 127.0 / scale_tensor_v # for C8 per channel means 127 / max
|
||||
if has_zp:
|
||||
k_dequant_scale = 1 / k_quant_scale # for C8 per channel zp means max
|
||||
v_dequant_scale = 1 / v_quant_scale # for C8 per channel zp means max
|
||||
k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
|
||||
v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
|
||||
else:
|
||||
k_dequant_scale = paddle.cast(scale_tensor_k, dtype="float32") # for C8 per channel means max
|
||||
v_dequant_scale = paddle.cast(scale_tensor_v, dtype="float32") # for C8 per channel means max
|
||||
# variable below are not yet used
|
||||
shift = None
|
||||
smooth = None
|
||||
q_norm_weight = None
|
||||
k_norm_weight = None
|
||||
kv_signal_data_cpu = None
|
||||
cachekv_signal_thread_cpu = None
|
||||
rope_3d = False
|
||||
|
||||
if is_fused:
|
||||
block_attn_func = block_attn_fused
|
||||
else:
|
||||
block_attn_func = block_attn
|
||||
attn_out = block_attn_func(
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
len_info_cpu,
|
||||
encoder_seq_lod_cpu,
|
||||
decoder_seq_lod_cpu,
|
||||
encoder_kv_lod_cpu,
|
||||
encoder_batch_map_cpu,
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
k_zp,
|
||||
v_zp,
|
||||
shift,
|
||||
smooth,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
kv_signal_data_cpu,
|
||||
cachekv_signal_thread_cpu,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
)
|
||||
result = {
|
||||
"block_attn_out": attn_out,
|
||||
"key_cache": key_cache,
|
||||
"value_cache": value_cache,
|
||||
}
|
||||
|
||||
# prefix cache
|
||||
if mode == 0 and hit_prefix_len > 0:
|
||||
assert hit_prefix_len < seq_len
|
||||
attn_out_prefix_cache = run_prefix_cache_block_attn(
|
||||
block_attn_func,
|
||||
qkv,
|
||||
seq_len,
|
||||
seq_lens_this_time,
|
||||
hit_prefix_len,
|
||||
key_cache,
|
||||
value_cache,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
attn_out,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
v_dequant_scale,
|
||||
k_zp,
|
||||
v_zp,
|
||||
shift,
|
||||
smooth,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
kv_signal_data_cpu,
|
||||
cachekv_signal_thread_cpu,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
result["prefix_cache_block_attn_out"] = attn_out_prefix_cache
|
||||
return result
|
||||
|
||||
|
||||
def run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len=0,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=False,
|
||||
only_run_spliced=False,
|
||||
num_speculative_tokens=0,
|
||||
):
|
||||
rtol = 1e-3
|
||||
atol = 1e-2
|
||||
# 0 for prefill only, 1 for decode only, 2 for mixed
|
||||
# TODO: mixed mode not supported yet, get_infer_param should be modified first
|
||||
mode_name = ["prefill only", "decode only", "mixed"]
|
||||
|
||||
if use_neox_rotary_style:
|
||||
embedding_type = "neox"
|
||||
else:
|
||||
embedding_type = "rope"
|
||||
for mode in [0, 1]:
|
||||
if mode == 0:
|
||||
seq_len_list = [seq_len]
|
||||
elif mode == 1:
|
||||
# seq_len > 1 goes into mtp branch, which only supports seq_len <= 31
|
||||
# TODO: mtp mode need further adaption
|
||||
# seq_len_list = [1, random.randint(2, 31)]
|
||||
seq_len_list = [1]
|
||||
for idx, seqlen in enumerate(seq_len_list):
|
||||
if idx == 0:
|
||||
branch_name = "non mtp branch"
|
||||
elif idx == 1:
|
||||
branch_name = "mtp branch"
|
||||
print(
|
||||
f"runnning block attention of mode {mode_name[mode]} ({branch_name}), is_prefix_cache: {hit_prefix_len > 0}, kvcache type: {kvcache_dtype}, has_zp: {has_zp}, rotary_style: {embedding_type}"
|
||||
)
|
||||
if not only_run_spliced:
|
||||
fused_result = run_block_attn(
|
||||
seed,
|
||||
True, # is_fused
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seqlen,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
mode,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype,
|
||||
has_zp,
|
||||
use_neox_rotary_style,
|
||||
rotary_embs_shape,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
spliced_result = run_block_attn(
|
||||
seed,
|
||||
False, # is_fused
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seqlen,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
mode,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype,
|
||||
has_zp,
|
||||
use_neox_rotary_style,
|
||||
rotary_embs_shape,
|
||||
num_speculative_tokens,
|
||||
)
|
||||
if "fused_result" in locals() and "spliced_result" in locals():
|
||||
for k in fused_result.keys():
|
||||
if paddle.is_integer(fused_result[k]):
|
||||
fused_v = fused_result[k].astype("int32")
|
||||
spliced_v = spliced_result[k].astype("int32")
|
||||
fused_v_np = fused_v.numpy()
|
||||
splice_v_np = spliced_v.numpy()
|
||||
# is_passed = np.allclose(fused_v_np, splice_v_np, rtol=1e-1, atol=1e-1)
|
||||
is_passed = np.allclose(fused_v_np, splice_v_np, rtol=1e-2, atol=rtol)
|
||||
else:
|
||||
fused_v = fused_result[k].astype("float32")
|
||||
spliced_v = spliced_result[k].astype("float32")
|
||||
fused_v_np = fused_v.numpy()
|
||||
splice_v_np = spliced_v.numpy()
|
||||
is_passed = np.allclose(fused_v_np, splice_v_np, rtol=rtol, atol=atol, equal_nan=True)
|
||||
if not is_passed:
|
||||
print(f"{k} in mode {mode_name[mode]} check FAILED!")
|
||||
print(f"fused {k}: {fused_v}")
|
||||
print(f"spliced {k}: {spliced_v}")
|
||||
print("not equal elements are listed below:")
|
||||
print_all_not_equal_elements_info(k, fused_v, spliced_v)
|
||||
else:
|
||||
print(f"{k} in mode {mode_name[mode]} check PASSED!")
|
||||
assert is_passed
|
||||
print("")
|
||||
else:
|
||||
if "fused_result" not in locals():
|
||||
print("fused_result not found.")
|
||||
if "spliced_result" not in locals():
|
||||
print("spliced_result not found.")
|
||||
print("skip comparison.")
|
||||
|
||||
|
||||
seed = random.randint(0, 2026)
|
||||
paddle.seed(seed)
|
||||
head_num = 64
|
||||
kv_head_num = 8
|
||||
head_dim = 128
|
||||
rotary_embs_shape = [2, 1, 8192, 1, head_dim]
|
||||
seq_len = 128
|
||||
block_batch = 5
|
||||
max_block_per_seq = 128
|
||||
block_size = 64
|
||||
# TODO: if hit_prefix_len has a small value, e.g. hit_prefix_len == 2, block_attn_out and prefix_cache_block_attn_out will have greater diff
|
||||
hit_prefix_len = 71
|
||||
|
||||
# no prefix cache
|
||||
# block_attn fused vs spliced
|
||||
use_neox_rotary_style = False
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
# c8 quantization block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="int8",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
# c8 zp quantization block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="int8",
|
||||
has_zp=True,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
|
||||
# prefix cache
|
||||
# block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
# c8 quantization block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype="int8",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
# c8 zp quantization block_attn fused vs spliced
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype="int8",
|
||||
has_zp=True,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
)
|
||||
|
||||
# # neox
|
||||
# # block_attn fused vs spliced
|
||||
# # no prefix cache
|
||||
use_neox_rotary_style = True
|
||||
only_run_spliced = False
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
only_run_spliced=only_run_spliced,
|
||||
)
|
||||
# prefix cache
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
hit_prefix_len,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
only_run_spliced=only_run_spliced,
|
||||
)
|
||||
|
||||
# neox glm 4.5 air debug
|
||||
head_num = 24
|
||||
kv_head_num = 2
|
||||
head_dim = 128
|
||||
seq_len = 128
|
||||
block_batch = 64
|
||||
max_block_per_seq = 2050
|
||||
block_size = 64
|
||||
rotary_embs_shape = [2, 1, 131072, 1, head_dim // 2]
|
||||
|
||||
use_neox_rotary_style = True
|
||||
only_run_spliced = False
|
||||
run_compare_block_attn(
|
||||
seed,
|
||||
head_num,
|
||||
kv_head_num,
|
||||
head_dim,
|
||||
seq_len,
|
||||
block_batch,
|
||||
max_block_per_seq,
|
||||
block_size,
|
||||
rotary_embs_shape,
|
||||
0,
|
||||
kvcache_dtype="bfloat16",
|
||||
has_zp=False,
|
||||
use_neox_rotary_style=use_neox_rotary_style,
|
||||
only_run_spliced=only_run_spliced,
|
||||
)
|
||||
|
||||
print("\nALL PASSED!")
|
||||
@@ -15,7 +15,7 @@
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.xpu import block_attn, get_infer_param
|
||||
from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param
|
||||
|
||||
head_num = 64
|
||||
kv_head_num = 8
|
||||
@@ -53,8 +53,10 @@ block_tables = block_tables.reshape((block_batch, max_block_per_seq))
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
|
||||
) # block_size
|
||||
|
||||
qkv = paddle.uniform(
|
||||
@@ -64,7 +66,6 @@ qkv = paddle.uniform(
|
||||
max=1.0,
|
||||
)
|
||||
|
||||
cum_offsets = paddle.zeros(shape=[block_batch], dtype="bfloat16")
|
||||
rotary_embs = paddle.uniform(shape=[2, 1, 8192, 1, head_dim], dtype="float32", min=-1.0, max=1.0)
|
||||
key_cache = paddle.zeros(
|
||||
shape=[block_batch * max_block_per_seq, kv_head_num, block_size, head_dim],
|
||||
@@ -94,11 +95,10 @@ v_dequant_scale_zp = 1 / v_quant_scale # for C8 per channel zp means max
|
||||
|
||||
k_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
|
||||
v_zp = paddle.zeros(shape=[kv_head_num * head_dim], dtype="bfloat16")
|
||||
attn_out = block_attn(
|
||||
attn_out = block_attn_fused(
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -111,6 +111,16 @@ attn_out = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -121,12 +131,15 @@ attn_out = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
attn_out_C8 = block_attn(
|
||||
attn_out_C8 = block_attn_fused(
|
||||
qkv,
|
||||
key_cache_int8,
|
||||
value_cache_int8,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -139,6 +152,16 @@ attn_out_C8 = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
@@ -149,12 +172,15 @@ attn_out_C8 = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
attn_out_C8_zp = block_attn(
|
||||
attn_out_C8_zp = block_attn_fused(
|
||||
qkv,
|
||||
key_cache_int8,
|
||||
value_cache_int8,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -167,6 +193,16 @@ attn_out_C8_zp = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale_zp,
|
||||
@@ -177,6 +213,10 @@ attn_out_C8_zp = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
# prefix cache : hit 71 tokens
|
||||
@@ -207,16 +247,17 @@ seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32")
|
||||
decoder_context_len_cpu,
|
||||
decoder_context_len_cache_cpu,
|
||||
len_info_cpu,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
) = get_infer_param(
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64
|
||||
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0
|
||||
) # block_size
|
||||
qkv_prefix = qkv[hit_prefix_len:]
|
||||
|
||||
attn_out_prefix_cache = block_attn(
|
||||
attn_out_prefix_cache = block_attn_fused(
|
||||
qkv_prefix,
|
||||
key_cache,
|
||||
value_cache,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -229,6 +270,16 @@ attn_out_prefix_cache = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -239,13 +290,16 @@ attn_out_prefix_cache = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
attn_out_C8_prefix_cache = block_attn(
|
||||
attn_out_C8_prefix_cache = block_attn_fused(
|
||||
qkv_prefix,
|
||||
key_cache_int8,
|
||||
value_cache_int8,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -258,6 +312,16 @@ attn_out_C8_prefix_cache = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale,
|
||||
@@ -268,13 +332,16 @@ attn_out_C8_prefix_cache = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
attn_out_C8_zp_prefix_cache = block_attn(
|
||||
attn_out_C8_zp_prefix_cache = block_attn_fused(
|
||||
qkv_prefix,
|
||||
key_cache_int8,
|
||||
value_cache_int8,
|
||||
cum_offsets,
|
||||
rotary_embs,
|
||||
block_tables,
|
||||
prefix_block_tables,
|
||||
@@ -287,6 +354,16 @@ attn_out_C8_zp_prefix_cache = block_attn(
|
||||
decoder_context_len_cache_cpu,
|
||||
decoder_batch_map_cpu,
|
||||
prefix_len_cpu,
|
||||
encoder_seq_lod,
|
||||
decoder_seq_lod,
|
||||
encoder_kv_lod,
|
||||
encoder_batch_map,
|
||||
decoder_context_len,
|
||||
decoder_context_len_cache,
|
||||
decoder_batch_map,
|
||||
prefix_len,
|
||||
slot_mapping_enc,
|
||||
slot_mapping_dec,
|
||||
k_quant_scale,
|
||||
v_quant_scale,
|
||||
k_dequant_scale_zp,
|
||||
@@ -297,6 +374,10 @@ attn_out_C8_zp_prefix_cache = block_attn(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
print("-- C16 prefix cache test --")
|
||||
print("attn_out[hit_prefix_len:]'s mean:", attn_out[hit_prefix_len:].mean().item())
|
||||
|
||||
Reference in New Issue
Block a user