mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
@@ -77,16 +77,7 @@ void DisPatchW4AFp8Gemm(
|
|||||||
max_tokens,
|
max_tokens,
|
||||||
stream)
|
stream)
|
||||||
} else {
|
} else {
|
||||||
GEMM_SWITCH_FP16(
|
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||||
M, K, batch_size, token_padding_size, kBlockN, TailN,
|
|
||||||
weight,
|
|
||||||
input,
|
|
||||||
out,
|
|
||||||
weight_scale,
|
|
||||||
input_row_sum,
|
|
||||||
tokens,
|
|
||||||
max_tokens,
|
|
||||||
stream)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,22 +119,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
|||||||
input.stream());
|
input.stream());
|
||||||
return {out};
|
return {out};
|
||||||
} else {
|
} else {
|
||||||
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
|
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||||
phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
|
|
||||||
DisPatchW4AFp8Gemm(
|
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
|
||||||
tokens.data<int>(),
|
|
||||||
input_row_sum.data<float>(),
|
|
||||||
weight_scale.data<float>(),
|
|
||||||
reinterpret_cast<cutlass::half_t*>(out_data),
|
|
||||||
token_padding_size,
|
|
||||||
max_tokens,
|
|
||||||
batch_size,
|
|
||||||
M,
|
|
||||||
K,
|
|
||||||
input.stream());
|
|
||||||
return {out};
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (is_bflot16) {
|
if (is_bflot16) {
|
||||||
@@ -164,23 +140,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
|||||||
input.stream());
|
input.stream());
|
||||||
return {out};
|
return {out};
|
||||||
} else {
|
} else {
|
||||||
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
|
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||||
phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
|
|
||||||
|
|
||||||
DisPatchW4AFp8Gemm(
|
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
|
||||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
|
||||||
tokens.data<int>(),
|
|
||||||
input_row_sum.data<float>(),
|
|
||||||
weight_scale.data<float>(),
|
|
||||||
reinterpret_cast<cutlass::half_t*>(out_data),
|
|
||||||
token_padding_size,
|
|
||||||
max_tokens,
|
|
||||||
batch_size,
|
|
||||||
M,
|
|
||||||
K,
|
|
||||||
input.stream());
|
|
||||||
return {out};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,14 +83,9 @@ void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
gemm_case = [
|
gemm_case = [[256, 256, 1, 0]]
|
||||||
[8192, 3584, 8, 0], # eb45T ffn1
|
|
||||||
[8192, 3584, 8, 2048], # eb45T ffn1
|
|
||||||
[7168, 8192, 8, 0], # eb45T ffn2
|
|
||||||
[7168, 8192, 8, 2048], # eb45T ffn2
|
|
||||||
]
|
|
||||||
|
|
||||||
dtype = ["BF16", "FP16"]
|
dtype = ["BF16"]
|
||||||
|
|
||||||
|
|
||||||
def get_cutlass_type(type):
|
def get_cutlass_type(type):
|
||||||
|
|||||||
@@ -44,10 +44,10 @@ def peruate_scale(weight_scale):
|
|||||||
|
|
||||||
|
|
||||||
paddle.seed(0)
|
paddle.seed(0)
|
||||||
tokens_per_group = 32
|
tokens_per_group = 256
|
||||||
N = 8192
|
N = 256
|
||||||
K = 3584
|
K = 256
|
||||||
BATCH = 8
|
BATCH = 1
|
||||||
TokenPadding = 0
|
TokenPadding = 0
|
||||||
|
|
||||||
tokens = [tokens_per_group] * BATCH
|
tokens = [tokens_per_group] * BATCH
|
||||||
|
|||||||
Reference in New Issue
Block a user