# FastDeploy/fastdeploy/model_executor/layers/attention/dsa_helper.py
# (Python, ~90 lines / 2.8 KiB)
import paddle
def convert_float32_uint8(tensor):
    """Reinterpret the raw bytes of a float32 tensor as uint8 (bit view, no copy)."""
    assert tensor.dtype == paddle.float32
    as_bytes = tensor.view("uint8")
    return as_bytes
def convert_bfloat16_uint8(tensor):
    """Reinterpret the raw bytes of a bfloat16 tensor as uint8 (bit view, no copy)."""
    assert tensor.dtype == paddle.bfloat16
    as_bytes = tensor.view("uint8")
    return as_bytes
def convert_uint8_float32(tensor):
    """Reassemble a 2-D uint8 byte tensor into float32, little-endian.

    Every 4 consecutive bytes along the last axis form one float32 value,
    so the output's last dimension is a quarter of the input's.
    """
    assert tensor.dtype == paddle.uint8
    width = tensor.shape[-1]
    assert width % 4 == 0
    # Pull out the four byte lanes; int32 so the shifted bytes can be OR-ed
    # into a full 32-bit word (wraparound of the top byte is irrelevant for
    # a pure bit reinterpretation).
    lanes = [tensor[:, k::4].contiguous().numpy().astype("int32") for k in range(4)]
    packed = (lanes[3] << 24) | (lanes[2] << 16) | (lanes[1] << 8) | lanes[0]
    return paddle.to_tensor(packed).view(paddle.float32)
def convert_uint8_bfloat16(tensor):
    """Reassemble a 2-D uint8 byte tensor into bfloat16, little-endian.

    Every 2 consecutive bytes along the last axis form one bfloat16 value,
    so the output's last dimension is half of the input's.
    """
    assert tensor.dtype == paddle.uint8
    width = tensor.shape[-1]
    assert width % 2 == 0
    low = tensor[:, 0::2].contiguous().numpy().astype("uint16")
    high = tensor[:, 1::2].contiguous().numpy().astype("uint16")
    packed = (high << 8) | low
    return paddle.to_tensor(packed).view(paddle.bfloat16)
def dsk_attn_write_cache(
    compressed_kv,
    kv_pe,
    kv_cache,
    slot_mapping,
    scale,
    cache_quant_type_str,
):
    """Quantize per-token compressed KV to FP8-e4m3 and write it, plus scales
    and the RoPE part, into a paged uint8 KV cache.

    Per-token cache-row byte layout (656 bytes total):
      [0:512)   — 512 FP8-e4m3 values (4 groups of 128), stored as raw bytes
      [512:528) — 4 float32 group scales (4 x 4 bytes)
      [528:656) — 64 bfloat16 RoPE values (64 x 2 bytes)

    Args:
        compressed_kv: [token_num, 512] latent KV to be quantized.
        kv_pe: [token_num, 1, 64] RoPE portion; assumed bfloat16 (enforced
            inside convert_bfloat16_uint8 — confirm with callers).
        kv_cache: [num_blocks, 1, 64, 656] uint8 paged cache, mutated in place.
        slot_mapping: [token_num] flat physical slot index per token.
        scale: unused here — NOTE(review): presumably kept for signature parity
            with other cache writers; confirm.
        cache_quant_type_str: unused here — same note as `scale`.

    Returns:
        None; writes into kv_cache in place.
    """
    token_num = slot_mapping.shape[0]
    page_size = 64
    assert compressed_kv.shape == [token_num, 512]
    assert kv_pe.shape == [token_num, 1, 64]
    # Drop the singleton middle axis: [token_num, 64].
    zkk_kv_pe = kv_pe.reshape([token_num, 64])
    assert len(kv_cache.shape) == 4
    assert kv_cache.shape[1:] == [1, page_size, 656]
    assert kv_cache.dtype == paddle.uint8
    # Quantize in float32, in 4 groups of 128 values per token.
    compressed_kv = compressed_kv.cast("float32").reshape([token_num, 4, 128])
    # Per-group absmax; epsilon avoids division by zero for all-zero groups.
    zkk_scale_max = compressed_kv.abs().max(axis=-1) + 0.00001
    assert zkk_scale_max.shape == [token_num, 4]
    # Map each group into [-448, 448], the e4m3 representable max, then cast.
    zkk_quant_compressed_kv = compressed_kv / zkk_scale_max[:, :, None] * 448
    zkk_quant_compressed_kv = zkk_quant_compressed_kv.cast(paddle.float8_e4m3fn)
    # In-place flatten back to [token_num, 512] (0 keeps the leading dim).
    zkk_quant_compressed_kv.reshape_([0, -1])
    # Store the dequantization scale (absmax / 448) rather than the raw absmax.
    zkk_scale_max = zkk_scale_max / 448.0
    for token_id in range(token_num):
        # Flat slot -> (block, offset-within-block) with 64-slot pages.
        dst_physical_pos = slot_mapping[token_id].item()
        dst_block = dst_physical_pos // page_size
        dst_offset = dst_physical_pos % page_size
        # write quant uint8
        baseline = zkk_quant_compressed_kv[token_id, :].contiguous().view(paddle.uint8)
        kv_cache[dst_block, 0, dst_offset, :512] = baseline
        # Write the 4 float32 scales as 16 raw bytes.
        baseline = zkk_scale_max[token_id, :]
        # `+ 0` forces a fresh (contiguous) copy before the bit view —
        # NOTE(review): presumably needed because the row slice is a view.
        baseline = baseline + 0
        baseline = convert_float32_uint8(baseline).cast("uint8")
        kv_cache[dst_block, 0, dst_offset, 512:528] = baseline
        # Write the 64 bfloat16 RoPE values as 128 raw bytes.
        baseline = zkk_kv_pe[token_id, :]
        baseline = baseline + 0
        baseline = convert_bfloat16_uint8(baseline).contiguous()
        kv_cache[dst_block, 0, dst_offset, 528:656] = baseline