mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -77,16 +77,7 @@ void DisPatchW4AFp8Gemm(
|
||||
max_tokens,
|
||||
stream)
|
||||
} else {
|
||||
GEMM_SWITCH_FP16(
|
||||
M, K, batch_size, token_padding_size, kBlockN, TailN,
|
||||
weight,
|
||||
input,
|
||||
out,
|
||||
weight_scale,
|
||||
input_row_sum,
|
||||
tokens,
|
||||
max_tokens,
|
||||
stream)
|
||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,22 +119,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
||||
input.stream());
|
||||
return {out};
|
||||
} else {
|
||||
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
|
||||
phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
|
||||
DisPatchW4AFp8Gemm(
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||
tokens.data<int>(),
|
||||
input_row_sum.data<float>(),
|
||||
weight_scale.data<float>(),
|
||||
reinterpret_cast<cutlass::half_t*>(out_data),
|
||||
token_padding_size,
|
||||
max_tokens,
|
||||
batch_size,
|
||||
M,
|
||||
K,
|
||||
input.stream());
|
||||
return {out};
|
||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||
}
|
||||
} else {
|
||||
if (is_bflot16) {
|
||||
@@ -164,23 +140,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
||||
input.stream());
|
||||
return {out};
|
||||
} else {
|
||||
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
|
||||
phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
|
||||
|
||||
DisPatchW4AFp8Gemm(
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||
tokens.data<int>(),
|
||||
input_row_sum.data<float>(),
|
||||
weight_scale.data<float>(),
|
||||
reinterpret_cast<cutlass::half_t*>(out_data),
|
||||
token_padding_size,
|
||||
max_tokens,
|
||||
batch_size,
|
||||
M,
|
||||
K,
|
||||
input.stream());
|
||||
return {out};
|
||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,14 +83,9 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
||||
}}
|
||||
"""
|
||||
|
||||
gemm_case = [
|
||||
[8192, 3584, 8, 0], # eb45T ffn1
|
||||
[8192, 3584, 8, 2048], # eb45T ffn1
|
||||
[7168, 8192, 8, 0], # eb45T ffn2
|
||||
[7168, 8192, 8, 2048], # eb45T ffn2
|
||||
]
|
||||
gemm_case = [[256, 256, 1, 0]]
|
||||
|
||||
dtype = ["BF16", "FP16"]
|
||||
dtype = ["BF16"]
|
||||
|
||||
|
||||
def get_cutlass_type(type):
|
||||
|
||||
@@ -44,10 +44,10 @@ def peruate_scale(weight_scale):
|
||||
|
||||
|
||||
paddle.seed(0)
|
||||
tokens_per_group = 32
|
||||
N = 8192
|
||||
K = 3584
|
||||
BATCH = 8
|
||||
tokens_per_group = 256
|
||||
N = 256
|
||||
K = 256
|
||||
BATCH = 1
|
||||
TokenPadding = 0
|
||||
|
||||
tokens = [tokens_per_group] * BATCH
|
||||
|
||||
Reference in New Issue
Block a user