[XPU] Support W4A8C8-TP4-300B Model (#4068)

* support w4a8

* delete ep block attn

* delete moe_topk_select

* update note

* update

* delete useless info

* update

* add some notes

* fix some formatting

* update scale info

* add ans baseline

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
yinwei
2025-10-10 15:41:32 +08:00
committed by GitHub
parent c46d5e48f8
commit 20c7b741f4
21 changed files with 2029 additions and 714 deletions
+12
@@ -124,6 +124,16 @@ def xpu_setup_ops():
XDNN_INC_PATH = os.path.join(XDNN_PATH, "include")
XDNN_LIB_DIR = os.path.join(XDNN_PATH, "so")
XFA_PATH = os.getenv("XFA_PATH")
if XFA_PATH is None:
XFA_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xhpc/xfa")
XFA_LIB_DIR = PADDLE_LIB_PATH
XFA_LIB_PATH = os.path.join(XFA_LIB_DIR, "libxpu_flash_attention.so")
else:
XFA_INC_PATH = os.path.join(XFA_PATH, "include")
XFA_LIB_DIR = os.path.join(XFA_PATH, "so")
XFA_LIB_PATH = os.path.join(XFA_LIB_DIR, "libxpu_flash_attention.so")
XVLLM_PATH = os.getenv("XVLLM_PATH")
assert XVLLM_PATH is not None, "XVLLM_PATH is not set."
XVLLM_KERNEL_INC_PATH = os.path.join(XVLLM_PATH, "infer_ops", "include")
@@ -149,6 +159,7 @@ def xpu_setup_ops():
XRE_INC_PATH,
XVLLM_KERNEL_INC_PATH,
XVLLM_OP_INC_PATH,
XFA_INC_PATH,
]
extra_objects = [
os.path.join(base_dir, "./plugin/build/libxpuplugin.a"),
@@ -156,6 +167,7 @@ def xpu_setup_ops():
XRE_LIB_PATH,
XVLLM_KERNEL_LIB_PATH,
XVLLM_OP_LIB_PATH,
XFA_LIB_PATH,
]
setup(
File diff suppressed because it is too large
+372 -212
@@ -1,244 +1,404 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"
namespace api = baidu::xpu::api;
std::vector<paddle::Tensor>
GetInferParam(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const int bsz = seq_lens_encoder.dims()[0];
  // decision logic
std::vector<int32_t> seq_lens_encoder_vec(bsz, 0); // input
std::vector<int32_t> seq_lens_decoder_vec(bsz, 0); // input
std::vector<paddle::Tensor> GetInferParam(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& block_tables,
int block_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int bsz = seq_lens_encoder.dims()[0];
const int block_bs = block_tables.dims()[0];
const int block_num_per_seq = block_tables.dims()[1];
auto all_param = paddle::empty(
{bsz * 3}, seq_lens_encoder.type(), seq_lens_encoder.place());
int ret = api::copy<int32_t>(xpu_ctx->x_context(),
seq_lens_encoder.data<int32_t>(),
reinterpret_cast<int32_t*>(all_param.data()),
bsz);
ret = api::copy<int32_t>(xpu_ctx->x_context(),
seq_lens_decoder.data<int32_t>(),
reinterpret_cast<int32_t*>(all_param.data()) + bsz,
bsz);
ret =
api::copy<int32_t>(xpu_ctx->x_context(),
seq_lens_this_time.data<int32_t>(),
reinterpret_cast<int32_t*>(all_param.data()) + 2 * bsz,
bsz);
std::unique_ptr<int32_t[]> all_param_cpu(new int32_t[bsz * 3]);
// input ex: [100, 0, 0, 0, 300]
int32_t* seq_lens_encoder_vec = all_param_cpu.get();
// input ex: [0, 5, 0, 25, 64] (64 means prefix len)
int32_t* seq_lens_decoder_vec = all_param_cpu.get() + bsz;
int32_t* seq_lens_this_time_vec = all_param_cpu.get() + 2 * bsz;
std::vector<int32_t> encoder_batch_map_vec(bsz, 0); //
std::vector<int32_t> decoder_batch_map_vec(bsz, 0); //
  std::vector<int32_t> encoder_batch_idx_vec(bsz, 0); // batch map with gaps removed
  std::vector<int32_t> decoder_batch_idx_vec(bsz, 0); // batch map with gaps removed
std::vector<int32_t> encoder_seq_lod_vec(bsz + 1, 0);
std::vector<int32_t> decoder_context_len_vec(bsz, 0);
std::vector<int32_t> decoder_context_len_cache_vec(bsz, 0);
  xpu_wait(xpu_ctx->x_context()->xpu_stream); // TODO: check whether this wait is needed
int r = xpu_memcpy(seq_lens_encoder_vec.data(),
seq_lens_encoder.data<int32_t>(), sizeof(int32_t) * bsz,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
r = xpu_memcpy(seq_lens_decoder_vec.data(),
seq_lens_decoder.data<int32_t>(), sizeof(int32_t) * bsz,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
std::vector<int32_t> encoder_batch_map_vec(bsz, 0);
std::vector<int32_t> decoder_batch_map_vec(
bsz, 0); // ex : [1, 3]
  // batch map with gaps removed, e.g. [0, 3]
  std::vector<int32_t> encoder_batch_idx_vec(bsz, 0);
  // batch map with gaps removed, e.g. [1, 2]
std::vector<int32_t> decoder_batch_idx_vec(bsz, 0);
std::vector<int32_t> encoder_seq_lod_vec(bsz + 1, 0); // ex : [0, 100, 400]
std::vector<int32_t> decoder_seq_lod_vec(bsz + 1, 0);
std::vector<int32_t> encoder_kv_lod_vec(bsz + 1, 0); // ex : [0, 100, 464]
std::vector<int32_t> prefix_len_vec(bsz, 0); // ex : [0, 64]
std::vector<int32_t> decoder_context_len_vec(bsz, 0); // ex : [6, 26]
std::vector<int32_t> decoder_context_len_cache_vec(bsz, 0); // ex : [5, 25]
xpu_wait(xpu_ctx->x_context()->xpu_stream);
int r = xpu_memcpy(all_param_cpu.get(),
all_param.data<int32_t>(),
sizeof(int32_t) * 3 * bsz,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
int enc_batch = 0, dec_batch = 0;
int total_enc_len = 0;
int batch_offset = 0;
for (int i = 0; i < bsz; ++i) {
if (seq_lens_encoder_vec[i] > 0) {
enc_batch++;
total_enc_len += seq_lens_encoder_vec[i];
encoder_batch_map_vec[enc_batch - 1] = i;
encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset;
encoder_seq_lod_vec[enc_batch] =
seq_lens_encoder_vec[i] + encoder_seq_lod_vec[enc_batch - 1];
} else if (seq_lens_decoder_vec[i] > 0) {
dec_batch++;
decoder_batch_map_vec[dec_batch - 1] = i;
decoder_batch_idx_vec[dec_batch - 1] = i - batch_offset;
decoder_context_len_vec[dec_batch - 1] =
seq_lens_decoder_vec[i] + 1;
decoder_context_len_cache_vec[dec_batch - 1] =
seq_lens_decoder_vec[i];
} else {
batch_offset++;
}
int enc_batch = 0, dec_batch = 0;
int total_enc_len = 0;
int batch_offset = 0;
int max_seq_len = 0;
int max_prefix_len = 0;
int max_kv_len = 0;
for (int i = 0; i < bsz; ++i) {
if (seq_lens_encoder_vec[i] > 0) {
enc_batch++;
int seq_len = seq_lens_encoder_vec[i];
int prefix_len = seq_lens_decoder_vec[i];
total_enc_len += seq_len;
max_seq_len = std::max(max_seq_len, seq_len);
max_prefix_len = std::max(max_prefix_len, prefix_len);
max_kv_len = std::max(max_kv_len, seq_len + prefix_len);
encoder_batch_map_vec[enc_batch - 1] = i;
encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset;
encoder_seq_lod_vec[enc_batch] =
seq_len + encoder_seq_lod_vec[enc_batch - 1];
encoder_kv_lod_vec[enc_batch] =
seq_len + prefix_len + encoder_kv_lod_vec[enc_batch - 1];
prefix_len_vec[enc_batch - 1] = prefix_len;
} else if (seq_lens_decoder_vec[i] > 0) {
dec_batch++;
decoder_batch_map_vec[dec_batch - 1] = i;
decoder_batch_idx_vec[dec_batch - 1] = i - batch_offset;
decoder_context_len_vec[dec_batch - 1] =
seq_lens_decoder_vec[i] + seq_lens_this_time_vec[i];
decoder_context_len_cache_vec[dec_batch - 1] = seq_lens_decoder_vec[i];
decoder_seq_lod_vec[dec_batch] =
seq_lens_this_time_vec[i] +
          decoder_seq_lod_vec[dec_batch - 1]; // used for MTP
} else {
batch_offset++;
}
}
int prefix_block_num_per_seq = (max_kv_len + block_size - 1) / block_size;
std::vector<int32_t> prefix_block_tables_vec(
enc_batch * prefix_block_num_per_seq, -1);
if (max_prefix_len > 0) {
std::vector<int> block_tables_vec(block_bs * block_num_per_seq, -1);
r = xpu_memcpy(block_tables_vec.data(),
block_tables.data<int32_t>(),
sizeof(int32_t) * block_bs * block_num_per_seq,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
for (int i = 0; i < enc_batch; i++) {
int src_bs = encoder_batch_map_vec[i];
int copy_len =
(encoder_kv_lod_vec[i + 1] - encoder_kv_lod_vec[i] + block_size - 1) /
block_size;
std::memcpy(prefix_block_tables_vec.data() + i * prefix_block_num_per_seq,
block_tables_vec.data() + src_bs * block_num_per_seq,
copy_len * sizeof(int32_t));
}
} else {
prefix_block_num_per_seq = -1;
}
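The prefix block-table gather above only runs when at least one prefill request has a cached prefix (`max_prefix_len > 0`). The following minimal Python sketch, with made-up block ids and lengths, illustrates the per-request slicing (illustrative only, not the kernel code):

```python
import numpy as np

block_size = 64
# block_tables[b] lists the KV-cache block ids owned by request slot b (-1 = unused).
block_tables = np.array([[7, 8, -1, -1],     # slot 0: prefill, 100 tokens, no cached prefix
                         [3, -1, -1, -1],    # slot 1: decode (not gathered here)
                         [1, 2, 4, 6]])      # slot 2: prefill, 136 tokens + 64 cached prefix
encoder_batch_map = [0, 2]                   # prefill slots only
encoder_kv_lod = [0, 100, 300]               # running sum of tokens + prefix per prefill request
max_kv_len = 200
prefix_block_num_per_seq = (max_kv_len + block_size - 1) // block_size   # 4

prefix_block_tables = -np.ones((len(encoder_batch_map), prefix_block_num_per_seq), dtype=np.int32)
for i, src_bs in enumerate(encoder_batch_map):
    kv_len = encoder_kv_lod[i + 1] - encoder_kv_lod[i]
    copy_len = (kv_len + block_size - 1) // block_size   # blocks needed for this request's KV
    prefix_block_tables[i, :copy_len] = block_tables[src_bs, :copy_len]
# prefix_block_tables -> [[7, 8, -1, -1], [1, 2, 4, 6]]
```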
auto encoder_batch_map_xpu =
paddle::full({encoder_batch_map_vec.size()}, 0, seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_batch_map_xpu =
paddle::full({decoder_batch_map_vec.size()}, 0, seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_batch_idx_xpu =
paddle::full({encoder_batch_idx_vec.size()}, 0, seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_batch_idx_xpu =
paddle::full({decoder_batch_idx_vec.size()}, 0, seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_seq_lod_xpu =
paddle::full({encoder_seq_lod_vec.size()}, 0, seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_context_len_xpu =
paddle::full({decoder_context_len_vec.size()}, 0,
seq_lens_encoder.type(), seq_lens_encoder.place());
auto decoder_context_len_cache_xpu =
paddle::full({decoder_context_len_cache_vec.size()}, 0,
seq_lens_encoder.type(), seq_lens_encoder.place());
auto encoder_batch_map = paddle::empty({encoder_batch_map_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_batch_map = paddle::empty({decoder_batch_map_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_batch_idx = paddle::empty({encoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_batch_idx = paddle::empty({decoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_seq_lod = paddle::empty({encoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_seq_lod = paddle::empty({decoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_kv_lod = paddle::empty({encoder_kv_lod_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto prefix_len = paddle::empty({prefix_len_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_context_len = paddle::empty({decoder_context_len_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto decoder_context_len_cache =
paddle::empty({decoder_context_len_cache_vec.size()},
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto prefix_block_tables =
paddle::empty({block_bs, block_num_per_seq}, // full size
seq_lens_encoder.type(),
seq_lens_encoder.place());
auto encoder_batch_map_cpu =
paddle::full({encoder_batch_map_vec.size()}, 0, seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_batch_map_cpu =
paddle::full({decoder_batch_map_vec.size()}, 0, seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_batch_idx_cpu =
paddle::full({encoder_batch_idx_vec.size()}, 0, seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_batch_idx_cpu =
paddle::full({decoder_batch_idx_vec.size()}, 0, seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_seq_lod_cpu =
paddle::full({encoder_seq_lod_vec.size()}, 0, seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_context_len_cpu =
paddle::full({decoder_context_len_vec.size()}, 0,
seq_lens_encoder.type(), paddle::CPUPlace());
auto decoder_context_len_cache_cpu =
paddle::full({decoder_context_len_cache_vec.size()}, 0,
seq_lens_encoder.type(), paddle::CPUPlace());
auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_batch_map_cpu = paddle::empty({decoder_batch_map_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_batch_idx_cpu = paddle::empty({encoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_batch_idx_cpu = paddle::empty({decoder_batch_idx_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto encoder_seq_lod_cpu = paddle::empty({encoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_seq_lod_cpu = paddle::empty({decoder_seq_lod_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
int ret = 0;
ret = xpu_memcpy(reinterpret_cast<int32_t *>(const_cast<int32_t *>(
encoder_batch_map_xpu.data<int32_t>())),
encoder_batch_map_vec.data(),
sizeof(int32_t) * encoder_batch_map_vec.size(),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
ret = xpu_memcpy(reinterpret_cast<int32_t *>(const_cast<int32_t *>(
decoder_batch_map_xpu.data<int32_t>())),
decoder_batch_map_vec.data(),
sizeof(int32_t) * decoder_batch_map_vec.size(),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
ret = xpu_memcpy(reinterpret_cast<int32_t *>(const_cast<int32_t *>(
encoder_batch_idx_xpu.data<int32_t>())),
encoder_batch_idx_vec.data(),
sizeof(int32_t) * encoder_batch_idx_vec.size(),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
ret = xpu_memcpy(reinterpret_cast<int32_t *>(const_cast<int32_t *>(
decoder_batch_idx_xpu.data<int32_t>())),
decoder_batch_idx_vec.data(),
sizeof(int32_t) * decoder_batch_idx_vec.size(),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
ret = xpu_memcpy(reinterpret_cast<int32_t *>(const_cast<int32_t *>(
encoder_seq_lod_xpu.data<int32_t>())),
encoder_seq_lod_vec.data(),
sizeof(int32_t) * encoder_seq_lod_vec.size(),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
ret = xpu_memcpy(reinterpret_cast<int32_t *>(const_cast<int32_t *>(
decoder_context_len_xpu.data<int32_t>())),
decoder_context_len_vec.data(),
sizeof(int32_t) * decoder_context_len_vec.size(),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
ret = xpu_memcpy(reinterpret_cast<int32_t *>(const_cast<int32_t *>(
decoder_context_len_cache_xpu.data<int32_t>())),
decoder_context_len_cache_vec.data(),
sizeof(int32_t) * decoder_context_len_cache_vec.size(),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
auto encoder_kv_lod_cpu = paddle::empty(
{encoder_kv_lod_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
auto prefix_len_cpu = paddle::empty(
{prefix_len_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace());
auto decoder_context_len_cpu = paddle::empty({decoder_context_len_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
auto decoder_context_len_cache_cpu =
paddle::empty({decoder_context_len_cache_vec.size()},
seq_lens_encoder.type(),
paddle::CPUPlace());
std::memcpy(encoder_batch_map_cpu.data<int32_t>(),
encoder_batch_map_vec.data(),
sizeof(int32_t) * encoder_batch_map_vec.size());
std::memcpy(decoder_batch_map_cpu.data<int32_t>(),
decoder_batch_map_vec.data(),
sizeof(int32_t) * decoder_batch_map_vec.size());
std::memcpy(encoder_batch_idx_cpu.data<int32_t>(),
encoder_batch_idx_vec.data(),
sizeof(int32_t) * encoder_batch_idx_vec.size());
std::memcpy(decoder_batch_idx_cpu.data<int32_t>(),
decoder_batch_idx_vec.data(),
sizeof(int32_t) * decoder_batch_idx_vec.size());
std::memcpy(encoder_seq_lod_cpu.data<int32_t>(), encoder_seq_lod_vec.data(),
sizeof(int32_t) * encoder_seq_lod_vec.size());
std::memcpy(decoder_context_len_cpu.data<int32_t>(),
decoder_context_len_vec.data(),
sizeof(int32_t) * decoder_context_len_vec.size());
std::memcpy(decoder_context_len_cache_cpu.data<int32_t>(),
decoder_context_len_cache_vec.data(),
sizeof(int32_t) * decoder_context_len_cache_vec.size());
ret = api::do_host2device(
xpu_ctx->x_context(),
reinterpret_cast<void*>(encoder_batch_map_vec.data()),
reinterpret_cast<void*>(
const_cast<int32_t*>(encoder_batch_map.data<int32_t>())),
sizeof(int32_t) * encoder_batch_map_vec.size());
ret = api::do_host2device(
xpu_ctx->x_context(),
reinterpret_cast<void*>(decoder_batch_map_vec.data()),
reinterpret_cast<void*>(
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())),
sizeof(int32_t) * decoder_batch_map_vec.size());
ret = api::do_host2device(
xpu_ctx->x_context(),
reinterpret_cast<void*>(encoder_batch_idx_vec.data()),
reinterpret_cast<void*>(
const_cast<int32_t*>(encoder_batch_idx.data<int32_t>())),
sizeof(int32_t) * encoder_batch_idx_vec.size());
ret = api::do_host2device(
xpu_ctx->x_context(),
reinterpret_cast<void*>(decoder_batch_idx_vec.data()),
reinterpret_cast<void*>(
const_cast<int32_t*>(decoder_batch_idx.data<int32_t>())),
sizeof(int32_t) * decoder_batch_idx_vec.size());
ret = api::do_host2device(xpu_ctx->x_context(),
reinterpret_cast<void*>(encoder_seq_lod_vec.data()),
reinterpret_cast<void*>(const_cast<int32_t*>(
encoder_seq_lod.data<int32_t>())),
sizeof(int32_t) * encoder_seq_lod_vec.size());
ret = api::do_host2device(xpu_ctx->x_context(),
reinterpret_cast<void*>(decoder_seq_lod_vec.data()),
reinterpret_cast<void*>(const_cast<int32_t*>(
decoder_seq_lod.data<int32_t>())),
sizeof(int32_t) * decoder_seq_lod_vec.size());
ret = api::do_host2device(xpu_ctx->x_context(),
reinterpret_cast<void*>(encoder_kv_lod_vec.data()),
reinterpret_cast<void*>(const_cast<int32_t*>(
encoder_kv_lod.data<int32_t>())),
sizeof(int32_t) * encoder_kv_lod_vec.size());
ret = api::do_host2device(
xpu_ctx->x_context(),
reinterpret_cast<void*>(prefix_len_vec.data()),
reinterpret_cast<void*>(const_cast<int32_t*>(prefix_len.data<int32_t>())),
sizeof(int32_t) * prefix_len_vec.size());
ret = api::do_host2device(
xpu_ctx->x_context(),
reinterpret_cast<void*>(decoder_context_len_vec.data()),
reinterpret_cast<void*>(
const_cast<int32_t*>(decoder_context_len.data<int32_t>())),
sizeof(int32_t) * decoder_context_len_vec.size());
ret = api::do_host2device(
xpu_ctx->x_context(),
reinterpret_cast<void*>(decoder_context_len_cache_vec.data()),
reinterpret_cast<void*>(
const_cast<int32_t*>(decoder_context_len_cache.data<int32_t>())),
sizeof(int32_t) * decoder_context_len_cache_vec.size());
ret = api::do_host2device(
xpu_ctx->x_context(),
reinterpret_cast<void*>(prefix_block_tables_vec.data()),
reinterpret_cast<void*>(
const_cast<int32_t*>(prefix_block_tables.data<int32_t>())),
sizeof(int32_t) * prefix_block_tables_vec.size());
auto enc_batch_tensor = paddle::full(
{1}, enc_batch, seq_lens_encoder.type(), paddle::CPUPlace());
auto dec_batch_tensor = paddle::full(
{1}, dec_batch, seq_lens_encoder.type(), paddle::CPUPlace());
auto total_enc_len_tensor = paddle::full(
{1}, total_enc_len, seq_lens_encoder.type(), paddle::CPUPlace());
std::memcpy(encoder_batch_map_cpu.data<int32_t>(),
encoder_batch_map_vec.data(),
sizeof(int32_t) * encoder_batch_map_vec.size());
std::memcpy(decoder_batch_map_cpu.data<int32_t>(),
decoder_batch_map_vec.data(),
sizeof(int32_t) * decoder_batch_map_vec.size());
std::memcpy(encoder_batch_idx_cpu.data<int32_t>(),
encoder_batch_idx_vec.data(),
sizeof(int32_t) * encoder_batch_idx_vec.size());
std::memcpy(decoder_batch_idx_cpu.data<int32_t>(),
decoder_batch_idx_vec.data(),
sizeof(int32_t) * decoder_batch_idx_vec.size());
std::memcpy(encoder_seq_lod_cpu.data<int32_t>(),
encoder_seq_lod_vec.data(),
sizeof(int32_t) * encoder_seq_lod_vec.size());
std::memcpy(decoder_seq_lod_cpu.data<int32_t>(),
decoder_seq_lod_vec.data(),
sizeof(int32_t) * decoder_seq_lod_vec.size());
std::memcpy(encoder_kv_lod_cpu.data<int32_t>(),
encoder_kv_lod_vec.data(),
sizeof(int32_t) * encoder_kv_lod_vec.size());
std::memcpy(prefix_len_cpu.data<int32_t>(),
prefix_len_vec.data(),
sizeof(int32_t) * prefix_len_vec.size());
std::memcpy(decoder_context_len_cpu.data<int32_t>(),
decoder_context_len_vec.data(),
sizeof(int32_t) * decoder_context_len_vec.size());
std::memcpy(decoder_context_len_cache_cpu.data<int32_t>(),
decoder_context_len_cache_vec.data(),
sizeof(int32_t) * decoder_context_len_cache_vec.size());
return {encoder_batch_map_xpu,
decoder_batch_map_xpu,
encoder_batch_idx_xpu,
decoder_batch_idx_xpu,
encoder_seq_lod_xpu,
decoder_context_len_xpu,
decoder_context_len_cache_xpu,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
enc_batch_tensor,
dec_batch_tensor,
total_enc_len_tensor};
std::vector<int> len_info_vec = {enc_batch,
dec_batch,
total_enc_len,
max_seq_len,
max_kv_len,
prefix_block_num_per_seq};
auto len_info_cpu =
paddle::empty({6}, seq_lens_encoder.type(), paddle::CPUPlace());
std::memcpy(len_info_cpu.data<int32_t>(),
len_info_vec.data(),
sizeof(int32_t) * len_info_vec.size());
return {encoder_batch_map,
decoder_batch_map,
encoder_batch_idx,
decoder_batch_idx,
encoder_seq_lod,
decoder_seq_lod,
encoder_kv_lod,
prefix_len,
decoder_context_len,
decoder_context_len_cache,
prefix_block_tables,
encoder_batch_map_cpu,
decoder_batch_map_cpu,
encoder_batch_idx_cpu,
decoder_batch_idx_cpu,
encoder_seq_lod_cpu,
decoder_seq_lod_cpu,
encoder_kv_lod_cpu,
prefix_len_cpu,
decoder_context_len_cpu,
decoder_context_len_cache_cpu,
len_info_cpu};
}
std::vector<std::vector<int64_t>>
GetInferParamInferShape(const std::vector<int64_t> &seq_lens_encoder_shape,
const std::vector<int64_t> &seq_lens_decoder_shape) {
return {seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{seq_lens_encoder_shape[0] + 1},
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{1},
{1},
{1}};
std::vector<std::vector<int64_t>> GetInferParamInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& block_tables_shape) {
return {seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
block_tables_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
{seq_lens_encoder_shape[0] + 1},
seq_lens_encoder_shape,
seq_lens_encoder_shape,
seq_lens_encoder_shape,
{6}};
}
std::vector<paddle::DataType>
GetInferParamInferDtype(const paddle::DataType &seq_lens_encoder_dtype,
const paddle::DataType &seq_lens_decoder_dtype) {
return {
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype};
std::vector<paddle::DataType> GetInferParamInferDtype(
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& block_tables_dtype) {
return {
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, block_tables_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype,
seq_lens_encoder_dtype};
}
PD_BUILD_OP(get_infer_param)
.Inputs({"seq_lens_encoder", "seq_lens_decoder"})
.Outputs({"encoder_batch_map_xpu", "decoder_batch_map_xpu",
"encoder_batch_idx_xpu", "decoder_batch_idx_xpu",
"encoder_seq_lod_xpu", "decoder_context_len_xpu",
"decoder_context_len_cache_xpu", "encoder_batch_map_cpu",
"decoder_batch_map_cpu", "encoder_batch_idx_cpu",
"decoder_batch_idx_cpu", "encoder_seq_lod_cpu",
"decoder_context_len_cpu", "decoder_context_len_cache_cpu",
"enc_batch_tensor", "dec_batch_tensor", "total_enc_len_tensor"})
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"block_tables"})
.Outputs({"encoder_batch_map",
"decoder_batch_map",
"encoder_batch_idx",
"decoder_batch_idx",
"encoder_seq_lod",
"decoder_seq_lod",
"encoder_kv_lod",
"prefix_len",
"decoder_context_len",
"decoder_context_len_cache",
"prefix_block_tables",
"encoder_batch_map_cpu",
"decoder_batch_map_cpu",
"encoder_batch_idx_cpu",
"decoder_batch_idx_cpu",
"encoder_seq_lod_cpu",
"decoder_seq_lod_cpu",
"encoder_kv_lod_cpu",
"prefix_len_cpu",
"decoder_context_len_cpu",
"decoder_context_len_cache_cpu",
"len_info_cpu"})
.SetKernelFn(PD_KERNEL(GetInferParam))
.Attrs({"block_size: int"})
.SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetInferParamInferDtype));
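For orientation, the host-side bookkeeping performed by `get_infer_param` (whose outputs feed `xpu_pre_process` and `block_attn` later in this change) can be summarized in a short Python reference. This is a sketch only, not the custom op itself; `build_infer_param` and its inputs are illustrative, assuming `seq_lens_encoder[i] > 0` marks a prefill slot, `seq_lens_decoder[i] > 0` alone marks a decode slot, and both zero marks an empty slot:

```python
def build_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time):
    """Host-side reference of the lod / batch-map construction (illustrative)."""
    enc_batch = dec_batch = batch_offset = 0
    encoder_seq_lod = [0]   # running sum of prefill token counts
    encoder_kv_lod = [0]    # running sum of prefill tokens + cached prefix
    decoder_seq_lod = [0]   # running sum of decode tokens (used for MTP)
    encoder_batch_map, encoder_batch_idx, decoder_batch_map = [], [], []
    prefix_len, decoder_context_len, decoder_context_len_cache = [], [], []
    for i in range(len(seq_lens_encoder)):
        if seq_lens_encoder[i] > 0:                      # prefill slot
            enc_batch += 1
            encoder_batch_map.append(i)
            encoder_batch_idx.append(i - batch_offset)   # index with empty slots removed
            encoder_seq_lod.append(encoder_seq_lod[-1] + seq_lens_encoder[i])
            encoder_kv_lod.append(encoder_kv_lod[-1] + seq_lens_encoder[i] + seq_lens_decoder[i])
            prefix_len.append(seq_lens_decoder[i])
        elif seq_lens_decoder[i] > 0:                    # decode slot
            dec_batch += 1
            decoder_batch_map.append(i)
            decoder_context_len.append(seq_lens_decoder[i] + seq_lens_this_time[i])
            decoder_context_len_cache.append(seq_lens_decoder[i])
            decoder_seq_lod.append(decoder_seq_lod[-1] + seq_lens_this_time[i])
        else:                                            # empty slot
            batch_offset += 1
    return encoder_seq_lod, encoder_kv_lod, decoder_seq_lod, prefix_len, decoder_context_len

# With the example values from the comments above,
#   seq_lens_encoder   = [100, 0, 0, 0, 300]
#   seq_lens_decoder   = [0, 5, 0, 25, 64]      (64 is a cached prefix)
#   seq_lens_this_time = [100, 1, 0, 1, 300]
# this yields encoder_seq_lod = [0, 100, 400], encoder_kv_lod = [0, 100, 464],
# prefix_len = [0, 64], and decoder_context_len = [6, 26].
```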
@@ -86,7 +86,7 @@ std::vector<paddle::Tensor> WeightQuantize(const paddle::Tensor &x,
APPLY_WEIGHT_QUANTIZE_KERNEL(float);
} else {
PD_THROW("WeightQuantize not support x_type==%d",
static_cast<int>(x_type));
static_cast<int>(x_type));
return {};
}
}
+9 -10
@@ -15,9 +15,9 @@ The scheduling flow is shown below - users randomly request IP and port, obtain
```python
prompts = [
"Hello, my name is",
"你好,请问今天是星期",
"请写6个以数字开头的成语",
"写一个300字的小说大纲,内容是李白穿越到现代,最后成为公司文职人员的故事",
"你好,请问今天是星期",
"请写6个以数字开头的成语",
"写一个300字的小说大纲,内容是李白穿越到现代,最后成为公司文职人员的故事",
"我要采访一位科幻作家,创建一个包含5个问题的列表"
]
@@ -83,9 +83,9 @@ python -m fastdeploy.entrypoints.openai.multi_api_server \
```
### Parameter Description
- num-servers: Number of API servers to launch
- ports: Ports for API servers
- args: Arguments for API servers
- num-servers: Number of API servers to launch
- ports: Ports for API servers
- args: Arguments for API servers
### Data Parallelism + Disaggregated Deployment
Refer to [Disaggregated Deployment](disaggregated.md#multi-machine-disaggregated-deployment)
@@ -94,9 +94,8 @@ Refer to [Disaggregated Deployment](disaggregated.md#multi-machine-disaggregated
For multi-machine deployment, ensure network cards support RDMA and all cluster nodes are interconnected.
**Note**:
* `KVCACHE_RDMA_NICS` specifies RDMA network cards for the current machine, multiple cards should be separated by commas.
* The repository provides an automatic RDMA network card detection script `bash scripts/get_rdma_nics.sh <device>`, where <device> can be `cpu` or `gpu`.
- `KVCACHE_RDMA_NICS` specifies RDMA network cards for the current machine, multiple cards should be separated by commas.
- The repository provides an automatic RDMA network card detection script `bash scripts/get_rdma_nics.sh <device>`, where <device> can be `cpu` or `gpu`.
**Prefill Instance**
```bash
@@ -148,4 +147,4 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-ttl 9000
--scheduler-topic "test" \
--splitwise-role "decode"
```
```
+3 -3
View File
@@ -73,10 +73,10 @@ Refer to the example code `offline_disaggregated_demo.py` in the `fastdeploy/dem
#### Prerequisite: Redis
> **⚠️ NOTE**
> **Redis requirement: version 6.2.0 or higher**
> **⚠️ NOTE**
> **Redis requirement: version 6.2.0 or higher**
> Versions below this may not support the required commands.
>
>
* Installation via `conda`
```bash
+57 -57
@@ -1,71 +1,71 @@
# Multi-Node Deployment
## Overview
## Overview
Multi-node deployment addresses scenarios where a single machine's GPU memory is insufficient to support deployment of large models by enabling tensor parallelism across multiple machines.
## Environment Preparation
#### Network Requirements
1. All nodes must be within the same local network
2. Ensure bidirectional connectivity between all nodes (test using `ping` and `nc -zv`)
## Environment Preparation
### Network Requirements
1. All nodes must be within the same local network
2. Ensure bidirectional connectivity between all nodes (test using `ping` and `nc -zv`)
#### Software Requirements
1. Install the same version of FastDeploy on all nodes
2. [Recommended] Install and configure MPI (OpenMPI or MPICH)
#### Software Requirements
1. Install the same version of FastDeploy on all nodes
2. [Recommended] Install and configure MPI (OpenMPI or MPICH)
## Tensor Parallel Deployment
## Tensor Parallel Deployment
### Recommended Launch Method
We recommend using mpirun for one-command startup without manually starting each node.
### Recommended Launch Method
We recommend using mpirun for one-command startup without manually starting each node.
### Usage Instructions
1. Execute the same command on all machines
2. The IP order in the `ips` parameter determines the node startup sequence
3. The first IP will be designated as the master node
4. Ensure all nodes can resolve each other's hostnames
### Usage Instructions
1. Execute the same command on all machines
2. The IP order in the `ips` parameter determines the node startup sequence
3. The first IP will be designated as the master node
4. Ensure all nodes can resolve each other's hostnames
* Online inference startup example:
```shell
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8180 \
--metrics-port 8181 \
--engine-worker-queue-port 8182 \
--max-model-len 32768 \
--max-num-seqs 32 \
--tensor-parallel-size 16 \
--ips 192.168.1.101,192.168.1.102
```
* Online inference startup example:
```shell
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8180 \
--metrics-port 8181 \
--engine-worker-queue-port 8182 \
--max-model-len 32768 \
--max-num-seqs 32 \
--tensor-parallel-size 16 \
--ips 192.168.1.101,192.168.1.102
```
* Offline startup example:
```python
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
model_name_or_path = "baidu/ERNIE-4.5-300B-A47B-Paddle"
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
llm = LLM(model=model_name_or_path, tensor_parallel_size=16, ips="192.168.1.101,192.168.1.102")
if llm._check_master():
output = llm.generate(prompts="Who are you?", use_tqdm=True, sampling_params=sampling_params)
print(output)
```
* Offline startup example:
```python
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
* Notes:
- Only the master node can receive completion requests
- Always send requests to the master node (the first IP in the ips list)
- The master node will distribute workloads across all nodes
model_name_or_path = "baidu/ERNIE-4.5-300B-A47B-Paddle"
### Parameter Description
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
llm = LLM(model=model_name_or_path, tensor_parallel_size=16, ips="192.168.1.101,192.168.1.102")
if llm._check_master():
output = llm.generate(prompts="Who are you?", use_tqdm=True, sampling_params=sampling_params)
print(output)
```
#### `ips` Parameter
- **Type**: `string`
- **Format**: Comma-separated IPv4 addresses
- **Description**: Specifies the IP addresses of all nodes in the deployment group
- **Required**: Only for multi-node deployments
- **Example**: `"192.168.1.101,192.168.1.102,192.168.1.103"`
* Notes:
* Only the master node can receive completion requests
* Always send requests to the master node (the first IP in the ips list)
* The master node will distribute workloads across all nodes
#### `tensor_parallel_size` Parameter
- **Type**: `integer`
- **Description**: Total number of GPUs across all nodes
- **Required**: Yes
- **Example**: For 2 nodes with 8 GPUs each, set to 16
### Parameter Description
#### `ips` Parameter
* **Type**: `string`
* **Format**: Comma-separated IPv4 addresses
* **Description**: Specifies the IP addresses of all nodes in the deployment group
* **Required**: Only for multi-node deployments
* **Example**: `"192.168.1.101,192.168.1.102,192.168.1.103"`
#### `tensor_parallel_size` Parameter
* **Type**: `integer`
* **Description**: Total number of GPUs across all nodes
* **Required**: Yes
* **Example**: For 2 nodes with 8 GPUs each, set to 16
+5 -11
@@ -12,15 +12,14 @@ FastDeploy 提供了splitwise scheduler,可以感知各个DP的负载状态,
具体调度流程如下图,用户随机请求ip 与端口,通过redis获取负载状态,将数据分发到负载较低的DP进行推理。
![数据调度架构图](./images/scheduler_img.png)
#### 离线推理
```python
prompts = [
"Hello, my name is",
"你好,请问今天是星期",
"请写6个以数字开头的成语",
"写一个300字的小说大纲,内容是李白穿越到现代,最后成为公司文职人员的故事",
"你好,请问今天是星期",
"请写6个以数字开头的成语",
"写一个300字的小说大纲,内容是李白穿越到现代,最后成为公司文职人员的故事",
"我要采访一位科幻作家,创建一个包含5个问题的列表"
]
@@ -65,11 +64,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-ttl 9000
```
### 用户自行调度
FastDeploy 提供了multi_api_server,用户可以拉起多个api server,用户自行选择dp 进行请求,在该种情况下用户可以自行添加负载均衡模型进行调度。(目前该种方式只支持在线推理)
#### 在线推理
![数据调度架构图](./images/no_scheduler_img.png)
@@ -95,8 +92,6 @@ python -m fastdeploy.entrypoints.openai.multi_api_server \
- ports: 指定拉起的api server 的端口
- args: 指定拉起的api server 的参数
### 数据并行 + 分离式部署
具体可以参考[分离式部署](disaggregated.md#多机分离式部署)
@@ -106,8 +101,8 @@ python -m fastdeploy.entrypoints.openai.multi_api_server \
多机部署时需要确认当前网卡是否支持RDMA,并且需要集群中所有节点网络互通。
**注意**
* `KVCACHE_RDMA_NICS` 指定当前机器的RDMA网卡,多个网卡用逗号隔开。
* 仓库中提供了自动检测RDMA网卡的脚本 `bash scripts/get_rdma_nics.sh <device>`, 其中 <device> 可以是 `cpu` 或 `gpu`
- `KVCACHE_RDMA_NICS` 指定当前机器的RDMA网卡,多个网卡用逗号隔开。
- 仓库中提供了自动检测RDMA网卡的脚本 `bash scripts/get_rdma_nics.sh <device>`, 其中 <device> 可以是 `cpu` 或 `gpu`
**prefill 实例**
@@ -163,4 +158,3 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-topic "test" \
--splitwise-role "decode"
```
+2 -2
@@ -75,8 +75,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
#### 前置依赖 Redis
* 使用`conda`安装
> **⚠️ 注意**
> **Redis 版本要求:6.2.0 及以上**
> **⚠️ 注意**
> **Redis 版本要求:6.2.0 及以上**
> 低于此版本可能不支持所需的命令。
```bash
+13 -15
@@ -4,11 +4,10 @@
多节点部署旨在解决单个机器GPU显存不足时,支持跨多台机器的张量并行执行。
## 环境准备
#### 网络要求
### 网络要求
1. 所有节点必须在同一本地网络中
2. 确保所有节点之间双向连通(可使用 `ping` 和 `nc -zv` 测试)
#### 软件要求
1. 所有节点安装相同版本的FastDeploy
2. [建议安装]安装并配置MPI(OpenMPI 或 MPICH)
@@ -52,22 +51,21 @@
```
* 注意:
- 只有主节点可以接收完成请求
- 请始终将请求发送到主节点(ips列表中的第一个IP)
- 主节点将在所有节点间分配工作负载
* 只有主节点可以接收完成请求
* 请始终将请求发送到主节点(ips列表中的第一个IP)
* 主节点将在所有节点间分配工作负载
### 参数说明
#### `ips`参数
- **类型**: `字符串`
- **格式**: 逗号分隔的IPv4地址
- **描述**: 指定部署组中所有节点的IP地址
- **必填**: 仅多节点部署时需要
- **示例**: `"192.168.1.101,192.168.1.102,192.168.1.103"`
* **类型**: `字符串`
* **格式**: 逗号分隔的IPv4地址
* **描述**: 指定部署组中所有节点的IP地址
* **必填**: 仅多节点部署时需要
* **示例**: `"192.168.1.101,192.168.1.102,192.168.1.103"`
#### `tensor_parallel_size`参数
- **类型**: `整数`
- **描述**: 所有节点上的GPU总数
- **必填**: 是
- **示例**: 对于2个节点各8个GPU,设置为16
* **类型**: `整数`
* **描述**: 所有节点上的GPU总数
* **必填**: 是
* **示例**: 对于2个节点各8个GPU,设置为16
@@ -192,10 +192,12 @@ class EngineWorkerQueue:
"get_finish_request_barrier",
callable=lambda idx: self.finish_request_barrier[idx],
)
QueueManager.register(
"get_finish_add_cache_task_barrier",
callable=lambda idx: self.finish_add_cache_task_barrier[idx],
)
QueueManager.register(
"get_worker_process_tp_barrier",
callable=lambda idx: self.worker_process_tp_barrier[idx],
+17 -3
@@ -193,7 +193,7 @@ class XPUForwardMeta(ForwardMeta):
# Accumulated offset
cum_offsets: Optional[paddle.Tensor] = None
# TODO(wanghaitao): Supplementary notes
# TODO(yinwei): Supplementary notes
#
encoder_batch_map: Optional[paddle.Tensor] = None
#
@@ -205,10 +205,17 @@ class XPUForwardMeta(ForwardMeta):
#
encoder_seq_lod: Optional[paddle.Tensor] = None
#
decoder_seq_lod: Optional[paddle.Tensor] = None
#
encoder_kv_lod: Optional[paddle.Tensor] = None
#
prefix_len: Optional[paddle.Tensor] = None
#
decoder_context_len: Optional[paddle.Tensor] = None
#
decoder_context_len_cache: Optional[paddle.Tensor] = None
#
prefix_block_tables: Optional[paddle.Tensor] = None
#
encoder_batch_map_cpu: Optional[paddle.Tensor] = None
#
@@ -220,10 +227,17 @@ class XPUForwardMeta(ForwardMeta):
#
encoder_seq_lod_cpu: Optional[paddle.Tensor] = None
#
decoder_seq_lod_cpu: Optional[paddle.Tensor] = None
#
encoder_kv_lod_cpu: Optional[paddle.Tensor] = None
#
prefix_len_cpu: Optional[paddle.Tensor] = None
#
decoder_context_len_cpu: Optional[paddle.Tensor] = None
#
decoder_context_len_cache_cpu: Optional[paddle.Tensor] = None
#
len_info_cpu: Optional[paddle.Tensor] = None
#
batch_tensor: Optional[paddle.Tensor] = None
#
@@ -127,7 +127,6 @@ class XPUAttentionBackend(AttentionBackend):
def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
) -> Tuple[int, int, int, int]:
"""
Calculate kv cache shape
@@ -164,6 +163,12 @@ class XPUAttentionBackend(AttentionBackend):
k_quant_scale = getattr(layer, "cache_k_scale", None)
v_quant_scale = getattr(layer, "cache_v_scale", None)
cache_k_out_scale = getattr(layer, "cache_k_out_scale", None)
cache_v_out_scale = getattr(layer, "cache_v_out_scale", None)
k_zp = getattr(self, "cache_k_zp", None)
v_zp = getattr(self, "cache_v_zp", None)
from fastdeploy.model_executor.ops.xpu import block_attn
res = block_attn(
@@ -173,17 +178,25 @@ class XPUAttentionBackend(AttentionBackend):
forward_meta.cum_offsets,
metadata.rotary_embs,
metadata.block_tables,
None,
k_quant_scale,
v_quant_scale,
forward_meta.enc_batch,
forward_meta.dec_batch,
forward_meta.total_enc_len,
forward_meta.prefix_block_tables,
forward_meta.len_info_cpu,
forward_meta.encoder_seq_lod_cpu,
forward_meta.decoder_seq_lod_cpu,
forward_meta.encoder_kv_lod_cpu,
forward_meta.encoder_batch_map_cpu,
forward_meta.decoder_context_len_cpu,
forward_meta.decoder_context_len_cache_cpu,
forward_meta.decoder_batch_map_cpu,
forward_meta.pos_emb_type,
self.rope_3d,
forward_meta.prefix_len_cpu,
k_quant_scale,
v_quant_scale,
cache_k_out_scale,
cache_v_out_scale,
k_zp, # zero_point_quant_scale
v_zp, # zero_point_quant_scale
None, # shift
None, # smooth
None, # kv_signal_data
None, # kv_signal_sender
)
return res
@@ -22,10 +22,12 @@ from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
)
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.xpu import (
ep_moe_expert_combine,
ep_moe_expert_dispatch,
moe_expert_ffn,
moe_topk_select,
weight_quantize_xpu,
)
@@ -196,6 +198,9 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
"""
Paddle xpu load weight process.
"""
# for k, v in state_dict.items():
# print(f"k : {k}, value.shape {v.shape}, value.dtype : {v.dtype}")
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
@@ -215,9 +220,14 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
weight_list = []
weight_scale_list = []
for i in range(layer.num_local_experts):
# print(f"=======================第{i}层=======================")
# print(f" wint4 未量化前权重: {weight_tensor[i]}")
quant_weight, scale = weight_quantize_xpu(
weight_tensor[i], self.moe_quant_type, -1, -1
) # weight is [k,n]
# print(f" wint4 量化后权重: {quant_weight}")
# print(f" wint4 量化后scale: {scale}")
weight_list.append(quant_weight.transpose([1, 0])) # transpose weight to [n,k]
weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0)
@@ -235,31 +245,81 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
"""
XPU compute Fused MoE.
"""
from fastdeploy.model_executor.ops.xpu import xpu_moe_layer
# from fastdeploy.model_executor.ops.xpu import xpu_moe_layer
fused_moe_out = xpu_moe_layer(
x,
gate.weight.transpose([1, 0]),
layer.gate_correction_bias,
# fused_moe_out = xpu_moe_layer(
# x,
# gate.weight.transpose([1, 0]),
# layer.gate_correction_bias,
# layer.up_gate_proj_weight,
# layer.down_proj_weight,
# None, # up_gate_proj bias
# None, # down_proj bias
# (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
# (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
# (layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
# self.moe_quant_type,
# layer.top_k,
# False, # moe group, used in deepseek
# )
# if layer.tp_size > 1:
# from fastdeploy.distributed.communication import (
# tensor_model_parallel_all_reduce,
# )
# tensor_model_parallel_all_reduce(fused_moe_out)
# return fused_moe_out
gate_out = paddle.matmul(x.cast("float32"), gate.weight.transpose([1, 0]), transpose_y=True)
topk_idx, topk_weights = moe_topk_select(gate_out, layer.gate_correction_bias, layer.top_k, True)
token_nums_per_expert_list = list(range(64))  # placeholder padding
permute_input, permute_indices_per_token, token_num_lod, dst_weights, ffn1_act_scale_per_token = (
ep_moe_expert_dispatch(
x,
topk_idx,
topk_weights,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
token_nums_per_expert_list,
x.shape[0] * layer.top_k,
self.moe_quant_type,
)
)
ffn_out = moe_expert_ffn(
permute_input,
token_num_lod,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None, # up_gate_proj bias
None, # down_proj bias
None, # moe_ffn1_bias
None, # moe_ffn2_bias
None, # ffn1 in scale
None, # ffn2 in scale
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
None, # moe_ffn2_shift
None, # moe_ffn2_smooth
self.moe_quant_type,
layer.top_k,
False, # moe group, used in deepseek
-1,
x.shape[0] * layer.top_k, # token_all_num
)
topk_weights_bf16 = topk_weights.astype("bfloat16")
tmp_ffn_out = ep_moe_expert_combine(
ffn_out,
permute_indices_per_token,
topk_weights_bf16,
permute_indices_per_token.shape[0],
ffn_out.shape[0],
ffn_out.shape[1],
permute_indices_per_token.shape[1],
)
if layer.reduce_results and layer.tp_size > 1:
from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce,
)
tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out
tensor_model_parallel_all_reduce(tmp_ffn_out)
return tmp_ffn_out
class XPUWeightOnlyMoeEpMethod(XPUMoEMethod):
@@ -548,3 +608,260 @@ class XPUWeightOnlyMoeEpMethod(XPUMoEMethod):
# 4. EP combine
return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle)
class XPUW4A8MoEMethod(XPUMoEMethod):
"""
XPU w4a8 MoE Method
"""
def __init__(
self,
quant_config: WeightOnlyConfig,
) -> None:
super().__init__(quant_config)
self.quant_config = quant_config
self.moe_quant_type = "w4a8"
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"up_gate_proj_weight_scale",
"down_proj_weight_scale",
]
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
Paddle cutlass create weight process.
"""
self.weight_dtype = "int8"
self.scale_dtype = "float32"
# get weight shape
if self.moe_quant_type in ["weight_only_int4", "w4a8"]:
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size // 2,
]
else:
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size * 2,
layer.hidden_size,
]
if self.moe_quant_type in ["weight_only_int4", "w4a8"]:
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size // 2,
]
else:
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size,
]
# set weight param
setattr(
layer,
self.added_weight_attrs[0],
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_weight_attrs[1],
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scales
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# in_scale
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
setattr(
layer,
in_scale_name,
layer.create_parameter(
shape=[layer.num_local_experts],
dtype=self.scale_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def paddle_swap_int4_pack_int4_0123_to_int8_1032in_int8(self, weight_tensor: paddle.Tensor) -> paddle.Tensor:
"""
Pack the last dimension of a tensor into int8 format by combining adjacent int4 values.
"""
mask = paddle.full_like(weight_tensor, 0x0F, dtype="int8")
high_4bit = (weight_tensor >> 4) & mask
low_4bit = weight_tensor & mask
swapped = (low_4bit << 4) | high_4bit
return swapped
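A quick pure-Python check of the nibble swap above (illustrative values):

```python
# One packed int8 holds two int4 lanes: high nibble 0x1, low nibble 0x2.
packed = 0x12
mask = 0x0F
high_4bit = (packed >> 4) & mask       # 0x1
low_4bit = packed & mask               # 0x2
swapped = (low_4bit << 4) | high_4bit  # 0x21: the two int4 lanes have traded places
assert swapped == 0x21
```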
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
load weight and process.
"""
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.extract_moe_ffn_weights(state_dict)
)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert up_gate_proj_weights[0].shape == [
layer.hidden_size // 2,
layer.moe_intermediate_size * 2,
]
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size // 2,
layer.hidden_size,
]
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
weight_list = []
for i in range(layer.num_local_experts):
weight_list.append(weight_tensor[i].transpose([1, 0]))  # transpose to [n, k]
quanted_weight = paddle.stack(weight_list, axis=0)
getattr(layer, weight_name).set_value(quanted_weight)
self.load_w4a8_scale_weights(
layer,
layer.weight_key_map,
state_dict,
logical_expert_ids,
)
def load_w4a8_scale_weights(
self,
layer: nn.Layer,
weight_key_map: dict,
state_dict: dict,
logical_expert_ids: paddle.Tensor,
):
"""
Get w4a8 weights from state dict and process them.
Args:
layer (nn.Layer): The layer to add parameters to.
weight_key_map (dict): The weight key map.
state_dict (dict): The state dict.
"""
def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx):
return get_tensor(
(
state_dict.pop(key_template.format(expert_idx))
if key_template.format(expert_idx) in state_dict
else key_template.format(expert_idx)
),
layer.fd_config.model_config.model,
)
# 1. Init scale containers and maps
up_gate_proj_weight_scales = []
down_proj_weight_scales = []
up_gate_proj_in_scales = []
down_proj_in_scales = []
scale_weight_map = {
"up_gate_proj_weight_scale": up_gate_proj_weight_scales,
"down_proj_weight_scale": down_proj_weight_scales,
"up_gate_proj_in_scale": up_gate_proj_in_scales,
"down_proj_in_scale": down_proj_in_scales,
}
scale_key_map = {
"up_gate_proj_weight_scale": weight_key_map.get("up_gate_proj_expert_weight_scale_key", None),
"down_proj_weight_scale": weight_key_map.get("down_proj_expert_weight_scale_key", None),
"up_gate_proj_in_scale": weight_key_map.get("up_gate_proj_expert_in_scale_key", None),
"down_proj_in_scale": weight_key_map.get("down_proj_expert_in_scale_key", None),
}
for name, value in scale_key_map.items():
if value is None:
raise ValueError(f"scale {name} should not be none in w4a8 mode.")
for expert_idx in logical_expert_ids:
for name, scale_key_template in scale_key_map.items():
scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
scale_weight_map[name].append(scale_tensor)
# 2. Process scale tensor and set to layer
for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
getattr(layer, in_scale_name).set_value(paddle.concat(scale_weight_map[in_scale_name]))
for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
getattr(layer, weight_scale_name).set_value(paddle.stack(scale_weight_map[weight_scale_name], axis=0))
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
gate_out = paddle.matmul(x.cast("float32"), gate.weight.transpose([1, 0]), transpose_y=True)
topk_idx, topk_weights = moe_topk_select(gate_out, layer.gate_correction_bias, layer.top_k, True)
token_nums_per_expert_list = list(range(64))  # placeholder padding
permute_input, permute_indices_per_token, token_num_lod, dst_weights, ffn1_act_scale_per_token = (
ep_moe_expert_dispatch(
x,
topk_idx,
topk_weights,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
token_nums_per_expert_list,
x.shape[0] * layer.top_k,
self.moe_quant_type,
)
)
ffn_out = moe_expert_ffn(
permute_input,
token_num_lod,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None, # moe_ffn1_bias
None, # moe_ffn2_bias
(ffn1_act_scale_per_token if hasattr(layer, "up_gate_proj_in_scale") else None),
(layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
None, # moe_ffn2_shift
None, # moe_ffn2_smooth
self.moe_quant_type,
getattr(layer.moe_quant_config, "hadamard_block_size", 128),  # hadamard_block_size defaults to 128
x.shape[0] * layer.top_k, # token_all_num
)
topk_weights_bf16 = topk_weights.astype("bfloat16")
tmp_ffn_out = ep_moe_expert_combine(
ffn_out,
permute_indices_per_token,
topk_weights_bf16,
permute_indices_per_token.shape[0],
ffn_out.shape[0],
ffn_out.shape[1],
permute_indices_per_token.shape[1],
)
if layer.tp_size > 1:
from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce,
)
tensor_model_parallel_all_reduce(tmp_ffn_out)
return tmp_ffn_out
@@ -0,0 +1,231 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from paddle import nn
from fastdeploy.model_executor.layers.quantization.kv_cache import (
KvCacheQuantzationTypes,
)
from fastdeploy.model_executor.layers.quantization.quant_base import (
QuantConfigBase,
QuantMethodBase,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs
class XPUKvCacheQuantConfig(QuantConfigBase):
"""
quantization config for the XPU KV cache
"""
def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool) -> None:
"""
__init__
"""
super().__init__()
self.kv_cache_quant_type = kv_cache_quant_type
self.is_channel_wise = is_channel_wise
try:
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
except ValueError:
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
if self.quant_type == KvCacheQuantzationTypes.INT8:
self.max_bound = 127.0
self.is_channel_wise = True
else:
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
def name(self) -> str:
"""
get_name
"""
return "kvcache"
@classmethod
def from_config(
cls, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool
) -> "XPUKvCacheQuantConfig":
"""
from_config
"""
return cls(kv_cache_quant_type, is_channel_wise, has_zero_point)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
"""
get_quant_method
"""
return XPUKVCacheMethodBase(self)
class XPUKVCacheMethodBase(QuantMethodBase):
"""
XPUKVCacheMethodBase: the XPU backend needs scales in fp32, while the GPU path defines all scales in bf16
"""
def __init__(
self,
quant_config: XPUKvCacheQuantConfig,
) -> None:
"""
XPUKVCacheMethodBase __init__
"""
super().__init__()
self.cache_quant_config = quant_config
def load_zp(self, layer: nn.Layer, state_dict):
"""
load_zp
"""
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast("float32")
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast("float32")
layer.cache_k_zp.set_value(cache_k_zeropoint)
layer.cache_v_zp.set_value(cache_v_zeropoint)
def load_scale(self, layer: nn.Layer, state_dict):
"""
load_scale
"""
cache_k_scale_tensor = get_tensor(state_dict.pop(self.cache_k_scale_name)).cast("float32").reshape_([-1])
cache_v_scale_tensor = get_tensor(state_dict.pop(self.cache_v_scale_name)).cast("float32").reshape_([-1])
if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT8:
# cache_k_scale and cache_v_scale are used to quantize the KV cache, while cache_k_out_scale and cache_v_out_scale are used for dequantization
cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
else:
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
# The W4A8 model needs kv_scale in bf16 format
layer.cache_k_scale.set_value(paddle.cast(cache_k_scale, paddle.get_default_dtype()))
layer.cache_v_scale.set_value(paddle.cast(cache_v_scale, paddle.get_default_dtype()))
layer.cache_k_out_scale.set_value(cache_k_out_scale)
layer.cache_v_out_scale.set_value(cache_v_out_scale)
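A small numeric sketch of the scale relationship set up above (the calibration value 0.5 is illustrative; `max_bound` is 127.0 for INT8):

```python
max_bound = 127.0
activation_scale = 0.5                            # per-channel max-abs from calibration (illustrative)

cache_k_scale = max_bound / activation_scale      # quantize:   q = round(x * cache_k_scale)
cache_k_out_scale = activation_scale / max_bound  # dequantize: x ~= q * cache_k_out_scale

x = 0.37
q = round(x * cache_k_scale)                      # 94
x_hat = q * cache_k_out_scale                     # ~0.3701
assert abs(x - x_hat) <= activation_scale / max_bound
```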
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
create_weights
"""
if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT8:
layer.cache_quant_type_str = "cache_int8"
layer.quant_max_bound = 127.0
layer.quant_min_bound = -127.0
else:
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
scale_shape = [layer.fd_config.model_config.num_key_value_heads]
if self.cache_quant_config.is_channel_wise:
scale_shape = [layer.kv_num_heads * layer.head_dim]
layer.cache_k_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.cache_k_scale,
{
**extra_weight_attrs,
},
)
set_weight_attrs(
layer.cache_v_scale,
{
**extra_weight_attrs,
},
)
layer.cache_k_out_scale = layer.create_parameter(
shape=scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_out_scale = layer.create_parameter(
shape=scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.cache_quant_config.has_zero_point:
layer.cache_k_zp = layer.create_parameter(
shape=scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_zp = layer.create_parameter(
shape=scale_shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.cache_k_zp,
{
**extra_weight_attrs,
},
)
set_weight_attrs(
layer.cache_v_zp,
{
**extra_weight_attrs,
},
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
use for loader v0
"""
self.prefix = layer.prefix
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
self.cache_k_zp_name = layer.prefix + ".cachek_matmul.activation_zero_point"
self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"
if "block_wise" not in layer.cache_quant_type_str:
self.load_scale(layer, state_dict)
if self.cache_quant_config.has_zero_point:
self.load_zp(layer, state_dict)
def process_weights_after_loading(self, layer: nn.Layer):
"""
use for loader v1
"""
# cache_k_out_scale is the reciprocal of cache_k_scale
if layer.cache_k_scale._is_initialized():
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale) # cache_k_out_scale
if layer.cache_v_scale._is_initialized():
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
def apply(self, layer):
"""
apply
"""
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
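For reference, wiring this config in could look like the sketch below. This is a hypothetical call site: it assumes the INT8 entry of `KvCacheQuantzationTypes` maps to the string `"int8"` and that `XPUKvCacheQuantConfig` lives in the same `kv_cache` module that is imported later in this change.

```python
from fastdeploy.model_executor.layers.backends.xpu.quantization.kv_cache import (
    XPUKvCacheQuantConfig,
)

# INT8 KV cache with channel-wise scales and no zero point (assumed enum value "int8").
cfg = XPUKvCacheQuantConfig.from_config("int8", is_channel_wise=True, has_zero_point=False)
method = cfg.get_quant_method(layer=None)   # -> XPUKVCacheMethodBase

# On a real attention layer the loader would then call, roughly:
# method.create_weights(attn_layer)                       # registers cache_k/v_scale (+ out_scale / zp)
# method.process_loaded_weights(attn_layer, state_dict)   # loader v0: reads *.cachek_matmul.activation_scale
```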
@@ -22,6 +22,7 @@ from paddle import nn
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.platforms import current_platform
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -94,7 +95,14 @@ class KvCacheQuantConfig(QuantConfigBase):
"""
get_quant_method
"""
return KVCacheMethodBase(self)
if current_platform.is_xpu():
from fastdeploy.model_executor.layers.backends.xpu.quantization.kv_cache import (
XPUKVCacheMethodBase,
)
return XPUKVCacheMethodBase(self)
else:
return KVCacheMethodBase(self)
class KVCacheMethodBase(QuantMethodBase):
@@ -118,6 +126,7 @@ class KVCacheMethodBase(QuantMethodBase):
"""
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())
layer.cache_k_zp.set_value(cache_k_zeropoint)
layer.cache_v_zp.set_value(cache_v_zeropoint)
@@ -125,7 +134,6 @@ class KVCacheMethodBase(QuantMethodBase):
"""
load_scale
"""
cache_k_scale_tensor = (
get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
)
@@ -186,6 +194,7 @@ class KVCacheMethodBase(QuantMethodBase):
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.cache_k_scale,
{
@@ -198,6 +207,7 @@ class KVCacheMethodBase(QuantMethodBase):
**extra_weight_attrs,
},
)
layer.cache_k_out_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
@@ -16,7 +16,8 @@
from typing import Optional
from ..moe import FusedMoE
from fastdeploy.platforms import current_platform
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -40,11 +41,17 @@ class W4A8Config(QuantConfigBase):
return cls(is_permuted, hadamard_block_size)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if isinstance(layer, FusedMoE):
if current_platform.is_cuda():
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
CutlassW4A8MoEMethod,
)
return CutlassW4A8MoEMethod(self)
elif current_platform.is_xpu():
from fastdeploy.model_executor.layers.backends.xpu.moe.fused_moe import (
XPUW4A8MoEMethod,
)
return XPUW4A8MoEMethod(self)
else:
raise ValueError(f"Unsupported layer type {type(layer)} for w4a8")
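As background for the w4a8 branch above: W4A8 stores weights in int4 and activations in int8 (the C8 in W4A8C8 denotes an int8 KV cache). A rough sketch of the underlying symmetric quantization, using hypothetical helpers rather than the real CUTLASS/XPU MoE kernels:
import paddle
def quant_weight_int4(w):
    # Per-output-channel symmetric int4: values land in [-7, 7].
    scale = paddle.max(paddle.abs(w), axis=0, keepdim=True) / 7.0
    return paddle.clip(paddle.round(w / scale), -7.0, 7.0), scale
def quant_act_int8(x, act_scale):
    # Per-tensor symmetric int8 activations with a pre-computed scale.
    return paddle.clip(paddle.round(x / act_scale), -127.0, 127.0)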
+1 -1
@@ -19,7 +19,6 @@ import io
import ipaddress
import mimetypes
import os
import random
import socket
import subprocess
import tempfile
@@ -103,6 +102,7 @@ def http_to_pil_image(url):
return pil_image
def base64_to_pil_image(base64_string):
"""base64_to_pil_image"""
image_bytes = base64.b64decode(base64_string)
+18 -65
@@ -59,6 +59,7 @@ def xpu_pre_process(
seq_lens_this_time: int,
share_inputs: Dict,
use_speculate_method: bool,
block_size: int,
draft_tokens: Optional[paddle.Tensor] = None,
seq_lens_encoder: Optional[paddle.Tensor] = None,
seq_lens_decoder: Optional[paddle.Tensor] = None,
@@ -98,39 +99,35 @@ def xpu_pre_process(
caches=share_inputs["caches"],
)
# Get XPU extra params
(
xpu_forward_meta.encoder_batch_map,
xpu_forward_meta.decoder_batch_map,
xpu_forward_meta.encoder_batch_idx,
xpu_forward_meta.decoder_batch_idx,
xpu_forward_meta.encoder_seq_lod,
xpu_forward_meta.decoder_seq_lod,
xpu_forward_meta.encoder_kv_lod,
xpu_forward_meta.prefix_len,
xpu_forward_meta.decoder_context_len,
xpu_forward_meta.decoder_context_len_cache,
xpu_forward_meta.prefix_block_tables,
xpu_forward_meta.encoder_batch_map_cpu,
xpu_forward_meta.decoder_batch_map_cpu,
xpu_forward_meta.encoder_batch_idx_cpu,
xpu_forward_meta.decoder_batch_idx_cpu,
xpu_forward_meta.encoder_seq_lod_cpu,
xpu_forward_meta.decoder_seq_lod_cpu,
xpu_forward_meta.encoder_kv_lod_cpu,
xpu_forward_meta.prefix_len_cpu,
xpu_forward_meta.decoder_context_len_cpu,
xpu_forward_meta.decoder_context_len_cache_cpu,
xpu_forward_meta.enc_batch,
xpu_forward_meta.dec_batch,
xpu_forward_meta.total_enc_len,
) = get_infer_param(seq_lens_encoder, seq_lens_decoder)
# Adjust batch
# print(f"=========================adjust_batch 更新前=========================")
# print(f"ids_remove_padding : {ids_remove_padding}")
# print(f"cum_offsets : {cum_offsets}")
# print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}")
# print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}")
# print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}")
# print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}")
# print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}")
# print(f"xpu_forward_meta.dec_batch : {xpu_forward_meta.decoder_batch_map}")
xpu_forward_meta.len_info_cpu,
) = get_infer_param(
seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size
)
xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0]
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1]
xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2]
adjusted_input = adjust_batch(
ids_remove_padding.reshape([-1, 1]),
@@ -146,16 +143,6 @@ def xpu_pre_process(
None, # output_padding_offset
-1, # max_input_length
)
# print(f"=========================adjust_batch 更新后=========================")
# print(f"ids_remove_padding : {ids_remove_padding}")
# print(f"cum_offsets : {cum_offsets}")
# print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}")
# print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}")
# print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}")
# print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}")
# print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}")
adjusted_input = adjusted_input.squeeze(1)
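The refactored get_infer_param packs its batch statistics into len_info_cpu: the first three entries are the encoder (prefill) batch count, the decoder batch count, and the total encoder token length, as unpacked above. A simplified host-side sketch of how those three values relate to the per-request sequence lengths (an illustration only, not the actual XPU op):
def summarize_len_info(seq_lens_encoder, seq_lens_decoder):
    # Requests with a non-zero encoder length are in the prefill phase;
    # requests with a non-zero decoder length are decoding.
    enc_batch = sum(1 for n in seq_lens_encoder if n > 0)
    dec_batch = sum(1 for n in seq_lens_decoder if n > 0)
    total_enc_len = sum(n for n in seq_lens_encoder if n > 0)
    return [enc_batch, dec_batch, total_enc_len]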
@@ -268,22 +255,6 @@ def xpu_post_process(
# 2. Update the input buffer of the model
with paddle.framework._no_check_dy2st_diff():
if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output:
# print(f"============================================update_inputs_v1 更新前=========================================")
# print(f"model_output.stop_flags : {model_output.stop_flags}")
# print(f"model_output.not_need_stop : {model_output.not_need_stop}")
# print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}")
# print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}")
# print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}")
# print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}")
# print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}")
# print(f"sampled_token_ids : {sampled_token_ids}")
# print(f"model_output.input_ids : {model_output.input_ids}")
# print(f"model_output.stop_nums : {model_output.stop_nums}")
# print(f"model_output.next_tokens : {model_output.next_tokens}")
# print(f"model_output.is_block_step : {model_output.is_block_step}")
# print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}")
# print(f"block_size : {block_size}")
update_inputs_v1(
model_output.stop_flags,
model_output.not_need_stop,
@@ -300,21 +271,6 @@ def xpu_post_process(
model_output.is_block_step,
block_size,
)
# print(f"============================================update_inputs_v1 更新后=========================================")
# print(f"model_output.stop_flags : {model_output.stop_flags}")
# print(f"model_output.not_need_stop : {model_output.not_need_stop}")
# print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}")
# print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}")
# print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}")
# print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}")
# print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}")
# print(f"sampled_token_ids : {sampled_token_ids}")
# print(f"model_output.input_ids : {model_output.input_ids}")
# print(f"model_output.stop_nums : {model_output.stop_nums}")
# print(f"model_output.next_tokens : {model_output.next_tokens}")
# print(f"model_output.is_block_step : {model_output.is_block_step}")
# print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}")
# print(f"block_size : {block_size}")
else:
update_inputs(
model_output.stop_flags,
@@ -879,6 +835,7 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_this_time"],
self.share_inputs,
use_speculate_method=False,
block_size=self.parallel_config.block_size,
draft_tokens=None,
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
@@ -948,19 +905,15 @@ class XPUModelRunner(ModelRunnerBase):
# Get kv cache dtype
cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
if (
self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None
):
cache_type = "uint8"
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
cache_type = "int8"
# Get kv cache shape
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
)
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
for i in range(self.model_config.num_hidden_layers):
cache_kvs[f"key_caches_{i}"] = paddle.full(
+18 -26
@@ -13,6 +13,7 @@ from core import TEMPLATE, URL, build_request_payload, send_request
COMPLETIONS_URL = URL.replace("/v1/chat/completions", "/v1/completions")
def test_completion_total_tokens():
data = {
"prompt": "你是谁",
@@ -48,7 +49,7 @@ def test_completion_echo_stream_one_prompt_rti():
"max_tokens": 2,
"return_token_ids": True,
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(COMPLETIONS_URL, payload, stream=True)
last_data = None
@@ -60,7 +61,7 @@ def test_completion_echo_stream_one_prompt_rti():
break
if line.strip() == "" or not line.startswith("data: "):
continue
line = line[len("data: "):]
line = line[len("data: ") :]
stream_data = json.loads(line)
counter += 1
if counter == 2:  # when the counter reaches 2, save the second data packet
@@ -81,9 +82,9 @@ def test_completion_echo_stream_one_prompt():
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"echo": True,
"max_tokens": 2
"max_tokens": 2,
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(COMPLETIONS_URL, payload, stream=True)
last_data = None
@@ -95,7 +96,7 @@ def test_completion_echo_stream_one_prompt():
break
if line.strip() == "" or not line.startswith("data: "):
continue
line = line[len("data: "):]
line = line[len("data: ") :]
stream_data = json.loads(line)
counter += 1
if counter == 1:  # when the counter reaches 1, save the first data packet
@@ -112,14 +113,14 @@ def test_completion_echo_stream_more_prompt():
Test the echo parameter in a streaming response with multiple prompts.
"""
data = {
"prompt": ["水果的营养价值是如何的?","水的化学式是什么?"],
"prompt": ["水果的营养价值是如何的?", "水的化学式是什么?"],
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"echo": True,
"max_tokens": 2,
"return_token_ids": True
"return_token_ids": True,
}
payload = build_request_payload(TEMPLATE, data)
resp = send_request(COMPLETIONS_URL, payload, stream=True)
last_data = None
@@ -136,9 +137,9 @@ def test_completion_echo_stream_more_prompt():
break
if line.strip() == "" or not line.startswith("data: "):
continue
line = line[len("data: "):]
line = line[len("data: ") :]
stream_data = json.loads(line)
for choice in stream_data.get("choices", []):
index = choice.get("index")
if index in packet_count_by_index:
@@ -183,13 +184,13 @@ def test_completion_echo_more_prompt():
"""
data = {
"stream": False,
"prompt": ["水果的营养价值是如何的?","水的化学式是什么?"],
"prompt": ["水果的营养价值是如何的?", "水的化学式是什么?"],
"echo": True,
"max_tokens": 100
"max_tokens": 100,
}
payload = build_request_payload(TEMPLATE, data)
response = send_request(COMPLETIONS_URL, payload).json()
text_0 = response["choices"][0]["text"]
text_1 = response["choices"][1]["text"]
assert data["prompt"][0] in text_0, "echo回显不正确"
@@ -204,12 +205,8 @@ def test_completion_finish_length():
"""
In a non-streaming response, check the finish_reason field when the output is truncated by max_tokens.
"""
data = {
"stream": False,
"prompt": "水果的营养价值是如何的?",
"max_tokens": 10
}
data = {"stream": False, "prompt": "水果的营养价值是如何的?", "max_tokens": 10}
payload = build_request_payload(TEMPLATE, data)
response = send_request(COMPLETIONS_URL, payload).json()
@@ -221,15 +218,10 @@ def test_completion_finish_stop():
"""
In a non-streaming response, check the finish_reason field when the model finishes its reply naturally.
"""
data = {
"stream": False,
"prompt": "简短的回答我:苹果是水果吗?"
}
data = {"stream": False, "prompt": "简短的回答我:苹果是水果吗?"}
payload = build_request_payload(TEMPLATE, data)
response = send_request(COMPLETIONS_URL, payload).json()
finish_reason = response["choices"][0]["finish_reason"]
assert finish_reason == "stop", "无任何中介,finish_reason不为stop"
+1 -1
@@ -19,7 +19,7 @@ def test_45t():
ip = "0.0.0.0"
service_http_port = "8188" # 服务配置的
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")
base_response = "你好!我是一个基于人工智能技术构建的助手,可以帮你解答问题、提供建议、辅助创作,或者陪你聊天解闷~😊 无论是学习、工作还是生活中的疑问,都可以随时告诉我哦!你今天有什么想聊的吗?"
base_response = "你好!我是一个基于人工智能技术开发的助手,可以帮你解答问题、提供建议、聊天交流或者完成一些任务。无论是学习、工作还是生活中的疑问,都可以随时告诉我哦~😊 你有什么想聊的吗?"
# non-streaming chat
response = client.chat.completions.create(
model="default",