mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
@@ -1,4 +1,5 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
|
||||
// adapted from:
|
||||
// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,14 +12,16 @@
|
||||
namespace fastdeploy {
|
||||
|
||||
// Vectorization containers
|
||||
template <typename scalar_t> struct __align__(8) vec4_t {
|
||||
template <typename scalar_t>
|
||||
struct __align__(8) vec4_t {
|
||||
scalar_t x;
|
||||
scalar_t y;
|
||||
scalar_t z;
|
||||
scalar_t w;
|
||||
};
|
||||
|
||||
template <typename quant_type_t> struct __align__(4) q8x4_t {
|
||||
template <typename quant_type_t>
|
||||
struct __align__(4) q8x4_t {
|
||||
static_assert(std::is_same_v<quant_type_t, int8_t> ||
|
||||
std::is_same_v<quant_type_t, phi::dtype::float8_e4m3fn>);
|
||||
quant_type_t x;
|
||||
@@ -94,7 +97,8 @@ __global__ void segmented_max_reduction(float *__restrict__ scale,
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ float thread_max_vec(scalar_t const *__restrict__ input,
|
||||
int64_t const num_elems, int const tid,
|
||||
int64_t const num_elems,
|
||||
int const tid,
|
||||
int const step) {
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const *vectorized_in =
|
||||
@@ -125,7 +129,8 @@ __device__ void scaled_fp8_conversion_vec(fp8_type *__restrict__ out,
|
||||
scalar_t const *__restrict__ input,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
int const tid,
|
||||
int const step) {
|
||||
using float8x4_t = q8x4_t<fp8_type>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
auto const *vectorized_in = reinterpret_cast<vec4_t<scalar_t> const *>(input);
|
||||
@@ -156,4 +161,4 @@ __device__ void scaled_fp8_conversion_vec(fp8_type *__restrict__ out,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
} // namespace fastdeploy
|
||||
|
||||
Reference in New Issue
Block a user