mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[Optimize]support machete weight only gemm (#3561)
* support machete weight only gemm * add generate * update * fix * change file location * add sm_version limit * fix * fix * fix ci * fix coverage * fix xpu
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/cute_utils.cuh
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// layout utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Permute layout based on indices, example:
|
||||
// permute_layout<1, 0>(layout) will swap the two dimensions
|
||||
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
|
||||
template <size_t... I, typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
|
||||
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
|
||||
return cute::make_layout(cute::get<I>(l)...);
|
||||
}
|
||||
|
||||
// is the layout f(x) = x
|
||||
template <typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
|
||||
if constexpr (std::is_same_v<Layout, void>) {
|
||||
return true;
|
||||
} else {
|
||||
constexpr auto coalesced_layout = coalesce(Layout{});
|
||||
if constexpr (rank(coalesced_layout) == 1 &&
|
||||
stride<0>(coalesced_layout) == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// Pointer utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class PointerType>
|
||||
static constexpr auto get_logical_ptr(PointerType* ptr) {
|
||||
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
|
||||
return cute::subbyte_iterator<PointerType>(ptr);
|
||||
} else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// Misc utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename Elements>
|
||||
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
|
||||
constexpr auto bits = sizeof_bits_v<T> * Elements{};
|
||||
if constexpr (bits % 128 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<128>{};
|
||||
} else if constexpr (bits % 64 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<64>{};
|
||||
} else if constexpr (bits % 32 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<32>{};
|
||||
} else if constexpr (bits % 16 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<16>{};
|
||||
} else {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<8>{};
|
||||
}
|
||||
}
|
||||
|
||||
}; // namespace cute
|
||||
@@ -0,0 +1,44 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_collective_builder.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
//
|
||||
// MacheteCollectiveBuilder is a wrapper around CollectiveBuilder that allows for
|
||||
// for custom kernel tags, allowing you to build custom collectives. Without
|
||||
// touching the cutlass library headers, using `CutlassKernelTag` will mean it
|
||||
// will resort to using the standard cutlass collective builder.
|
||||
//
|
||||
|
||||
// Use the default Cutlass collective builder, i.e. use an unmodified cutless
|
||||
// collective
|
||||
struct CutlassKernelTag {};
|
||||
|
||||
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
|
||||
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
|
||||
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
|
||||
class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType, class Enable = void>
|
||||
struct MacheteCollectiveBuilder {
|
||||
static_assert(sizeof(ElementA) == 0,
|
||||
"Could not build a collective for given parameters.");
|
||||
};
|
||||
|
||||
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
|
||||
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
|
||||
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
|
||||
class StageCountType, class KernelScheduleType>
|
||||
struct MacheteCollectiveBuilder<
|
||||
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
|
||||
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType> {
|
||||
using CollectiveOp = typename CollectiveBuilder<
|
||||
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
|
||||
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
|
||||
};
|
||||
|
||||
}; // namespace cutlass::gemm::collective
|
||||
@@ -0,0 +1,51 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_custom_types.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, int Bias, bool Signed = false>
|
||||
struct machete_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
|
||||
using Base = integer_subbyte<Bits, Signed>;
|
||||
|
||||
using Storage = typename Base::Storage;
|
||||
using xint_t = typename Base::xint_t;
|
||||
|
||||
using Base::bits_mask_;
|
||||
using Base::sign_mask_;
|
||||
using Base::storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// No operation
|
||||
machete_biased_integer_subbyte() = default;
|
||||
|
||||
/// Conversion from integer type
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(int value)
|
||||
: Base(value) {}
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(unsigned value)
|
||||
: Base(value) {}
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(double value)
|
||||
: Base(value) {}
|
||||
};
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// "GPTQ" types, i.e. symmetric quantization
|
||||
using machete_uint4b8_t = machete_biased_integer_subbyte<4, 8>; // u4b8
|
||||
using machete_uint8b128_t = machete_biased_integer_subbyte<8, 128>; // u8b128
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, int Bias, bool Signed>
|
||||
struct sizeof_bits<machete_biased_integer_subbyte<Bits, Bias, Signed>> {
|
||||
static constexpr int value = Bits;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,993 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "machete_custom_types.cuh"
|
||||
#include "cute_utils.cuh"
|
||||
#include "machete_type_utils.cuh"
|
||||
|
||||
// this file extends:
|
||||
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
|
||||
// with vllm specific type conversions, namely: machete_uint4b8_t, machete_uint8b128_t
|
||||
// as well as adds interleaved numeric array converters for specific types.
|
||||
// (interleaved numeric array converters can be more efficient for subbyte
|
||||
// types)
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
|
||||
// deinterleaves converted elements based on IlvBlkLayout, interleaving can
|
||||
// make subbyte converts more efficient by allowing for efficient extraction
|
||||
// of subbyte elements from a 32bit register.
|
||||
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
|
||||
class Enable = void>
|
||||
struct InterleavedNumericArrayConverter {
|
||||
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||
|
||||
using result_type = typename Converter::result_type;
|
||||
using source_type = typename Converter::source_type;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
|
||||
printf(
|
||||
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
|
||||
nameof_v<T>, nameof_v<S>, N);
|
||||
} else {
|
||||
printf(
|
||||
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
|
||||
"implemented\n",
|
||||
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
|
||||
}
|
||||
__brkpt();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||
FloatRoundStyle Round>
|
||||
struct InterleavedNumericArrayConverter<
|
||||
IlvBlkLayout, T, S, N, Round,
|
||||
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
|
||||
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||
|
||||
using result_type = typename Converter::result_type;
|
||||
using source_type = typename Converter::source_type;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return Converter::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
template <typename RegConvert32bit, typename T, typename S, int N>
|
||||
struct ArrayConverterPacked32Bit {
|
||||
using result_type = Array<T, N>;
|
||||
using source_type = Array<S, N>;
|
||||
|
||||
using result_packed_8_t = Array<T, 8>;
|
||||
using result_packed_4_t = Array<T, 4>;
|
||||
using result_packed_2_t = Array<T, 2>;
|
||||
using src_packed_8_t = Array<S, 8>;
|
||||
using src_packed_4_t = Array<S, 4>;
|
||||
using src_packed_2_t = Array<S, 2>;
|
||||
|
||||
static_assert(N % 2 == 0, "N must be a multiple of 2");
|
||||
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
|
||||
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
|
||||
static constexpr auto src_elems_per_32bit_reg =
|
||||
32 / cutlass::sizeof_bits_v<S>;
|
||||
|
||||
// Maybe not Valid. ScalarConverter will not actually work unless
|
||||
// NumericConverter<T, S, Round> is implemented. However it won't be used
|
||||
// anyways since we assert N % 2 == 0, just here for compliance with
|
||||
// VectorizedConverter.
|
||||
using ScalarConverter = NumericConverter<T, S>;
|
||||
|
||||
template <typename PackedSrc>
|
||||
CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
|
||||
if constexpr (sizeof(PackedSrc) == 1) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
|
||||
} else if constexpr (sizeof(PackedSrc) == 2) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
|
||||
} else if constexpr (sizeof(PackedSrc) == 4) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
|
||||
} else {
|
||||
static_assert(sizeof(PackedSrc) == 8);
|
||||
return reinterpret_cast<Array<uint32_t, 2> const&>(src);
|
||||
}
|
||||
}
|
||||
|
||||
// The core converter uses bit tricks to construct a known FP16 number, then
|
||||
// does a subtraction in FP16 for the final result.
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||
PackedSrcType const& source) {
|
||||
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
|
||||
static_assert(PackedResultType::kElements == 2 ||
|
||||
PackedResultType::kElements == 4 ||
|
||||
PackedResultType::kElements == 8,
|
||||
"Invalid PackedResultType must be 2, 4 or 8.");
|
||||
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
|
||||
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
|
||||
|
||||
return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE static result_type convert(source_type const& source) {
|
||||
result_type result;
|
||||
using ConverterType =
|
||||
ArrayConverterPacked32Bit<RegConvert32bit,
|
||||
typename result_type::Element,
|
||||
typename source_type::Element, N>;
|
||||
|
||||
if constexpr (src_elems_per_32bit_reg >= 8) {
|
||||
detail::VectorizedConverter::convert<
|
||||
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
|
||||
} else if constexpr (src_elems_per_32bit_reg >= 4) {
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
} else {
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
|
||||
// into 2 32bit register.
|
||||
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3, //
|
||||
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7, //
|
||||
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11, //
|
||||
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
|
||||
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
|
||||
uint32_t src) {
|
||||
cutlass::AlignedArray<uint32_t, 2> r;
|
||||
// Determines if the value is in the top half of the LUT if set or
|
||||
// (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move
|
||||
// into bit position 0x4 of each nibble so when or'd with final_prmt_base it
|
||||
// selects the correct candidate. When elements in final_prmt_base
|
||||
// are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
|
||||
// are < 0x4, the low candidate is selected (i.e. LUT[0:7])
|
||||
uint32_t high_bit = (src & 0x88888888) >> 1;
|
||||
|
||||
// `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT
|
||||
// (selects correct high or low candidate)
|
||||
const uint32_t final_prmt_base = 0x32103210;
|
||||
|
||||
// Ignore the high bit when indexing into LUT, for each 4bit value
|
||||
// we index into both the high and low candidates then use
|
||||
// high_bit | final_prmt_base to select the correct candidate
|
||||
uint32_t lut_idx = (src & 0x77777777);
|
||||
|
||||
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
|
||||
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
|
||||
(uint32_t(d) << 24);
|
||||
};
|
||||
|
||||
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
|
||||
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
|
||||
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
|
||||
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
|
||||
uint32_t final_prmt_idx = final_prmt_base | high_bit;
|
||||
|
||||
// This uses a look up table to convert packed int4s to packed int8s,
|
||||
// using the int4 value as the index to prmt. It first select both the
|
||||
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
|
||||
// select the correct candidate.
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .b32 low, high;\n"
|
||||
" prmt.b32 low, %1, %2, %5;\n"
|
||||
" prmt.b32 high, %3, %4, %5;\n"
|
||||
" prmt.b32 %0, low, high, %6;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
|
||||
"r"(final_prmt_idx));
|
||||
}
|
||||
|
||||
return r;
|
||||
};
|
||||
|
||||
// for Array<int8_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<int8_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
|
||||
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
|
||||
0xFC, 0xFD, 0xFE, 0xFF, //
|
||||
0x00, 0x01, 0x02, 0x03, //
|
||||
0x04, 0x05, 0x06, 0x07>(src_[0]);
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::float_e4m3_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::float_e4m3_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::float_e4m3_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
|
||||
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
|
||||
0xC8, 0xC4, 0xC0, 0xB8, //
|
||||
0x00, 0x38, 0x40, 0x44, //
|
||||
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
|
||||
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
|
||||
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
|
||||
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
|
||||
// We use inline asm instead of __byte_perm intrinsic since we don't want
|
||||
// the documented (& 0x7) on the index. NVCC might be able to optimize it
|
||||
// out since the index is a constexpr, but we choose to be safe about it
|
||||
// here.
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
static_assert(RegArray::kElements <= 4,
|
||||
"Too many inputs for F16 -> I4 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
|
||||
// we are trying to construct x and a fp16 value
|
||||
// The below XOR does the following:
|
||||
// 1) Sets the exponent bits of the FP16 to the correct value for the
|
||||
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
|
||||
// where x1 in the high nibble and x0 is the low nibble then using hfma
|
||||
// to subtract 1032 from that
|
||||
// The AND does the following:
|
||||
// 1) Clear the set bits for the int4 we will ignore.
|
||||
// We use lop3 so that we can use 1 instruction for AND and XOR.
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
static constexpr uint32_t and_mask = 0xFFF0FF0F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// We will issue 2 hfmas that do the following:
|
||||
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
|
||||
// = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
|
||||
static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032}
|
||||
static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1}
|
||||
|
||||
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
|
||||
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::half_t, machete_uint4b8_t, N,
|
||||
Round, void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
auto src_ = src >> (4 * (ii));
|
||||
r[ii + 0] = src_;
|
||||
r[ii + 1] = src_;
|
||||
|
||||
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
static constexpr uint32_t high_nib_mask = 0x00F000F0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 1])
|
||||
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
|
||||
// For high nibble:
|
||||
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
|
||||
// - {72, 72}
|
||||
static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032}
|
||||
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
|
||||
static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
|
||||
}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(high_nib_scale),
|
||||
reinterpret_cast<const half2&>(high_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::half_t, uint4_t, N, Round,
|
||||
void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<uint4_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
auto src_ = src >> (4 * (ii));
|
||||
r[ii + 0] = src_;
|
||||
r[ii + 1] = src_;
|
||||
|
||||
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
static constexpr uint32_t high_nib_mask = 0x00F000F0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 1])
|
||||
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
|
||||
// For high nibble:
|
||||
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
|
||||
static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024}
|
||||
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
|
||||
static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
|
||||
}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(high_nib_scale),
|
||||
reinterpret_cast<const half2&>(high_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
uint32_t const prmt_indices[2] = {0x5150, 0x5352};
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(start_byte_for_fp16),
|
||||
"r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
|
||||
static constexpr uint32_t bias_rep = 0x64806480;
|
||||
const half2& bias = reinterpret_cast<const half2&>(bias_rep);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hsub2(fp16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::float, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<float, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<float, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
PackedResultType r;
|
||||
|
||||
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
|
||||
// u8x4 source and stores the result in r (without introducing extra
|
||||
// cvt.u32.u8 instruction)
|
||||
uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
|
||||
uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
|
||||
for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
|
||||
result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
|
||||
// Subtract the magic number 0x4B000000 from tmp in floating-point
|
||||
// arithmetic to obtain final result
|
||||
r[ii] -= (8388608.f + 128.f); // fold in -128 bias
|
||||
}
|
||||
|
||||
return r;
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src_reg = src_[0];
|
||||
// Hold output BF16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
uint32_t src_reg_shifted = src_reg >> 4;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
|
||||
static_assert(RegArray::kElements <= 4,
|
||||
"Too many inputs for uint4b8_t -> BF16 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
|
||||
// we are trying to construct x and a BF16 value
|
||||
// The below XOR does the following:
|
||||
// 1) Sets the exponent bits of the BF16 to the correct value for the
|
||||
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
|
||||
// and subtracting 136 to get {x1, x0}
|
||||
static constexpr uint32_t xor_mask = 0x43004300;
|
||||
static constexpr uint32_t and_mask = 0x000F000F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// We will issue 2 bfmas that do the following:
|
||||
// high BF16:
|
||||
// hi_bf16 - 136, lo_bf16 - 136
|
||||
|
||||
// This is the BF16 {136, 136} represented as an integer.
|
||||
static constexpr uint32_t bias_rep = 0x43084308;
|
||||
const __nv_bfloat162& bias =
|
||||
reinterpret_cast<const __nv_bfloat162&>(bias_rep);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hsub2(bf16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::bfloat16_t, machete_uint4b8_t, N,
|
||||
Round, void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t or_mask = 0x43004300;
|
||||
|
||||
// Unlike float16 where the mantissa is large enough to contain 2
|
||||
// nibbles, bfloat16 can only fit one, so we can only convert one
|
||||
// nibble at a time
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
r[ii] = src >> (4 * ii);
|
||||
|
||||
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
|
||||
static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136}
|
||||
|
||||
{
|
||||
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val,
|
||||
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::bfloat16_t, uint4_t, N, Round,
|
||||
void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<uint4_t, N>;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t or_mask = 0x43004300;
|
||||
|
||||
// Unlike float16 where the mantissa is large enough to contain 2
|
||||
// nibbles, bfloat16 can only fit one, so we can only convert one
|
||||
// nibble at a time
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
r[ii] = src >> (4 * ii);
|
||||
|
||||
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
|
||||
static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128}
|
||||
|
||||
{
|
||||
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val,
|
||||
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
|
||||
using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
|
||||
using src_packed_4_t = Array<machete_uint8b128_t, 4>;
|
||||
using src_packed_2_t = Array<machete_uint8b128_t, 2>;
|
||||
|
||||
// Not Valid, not supported, only here to satisfy the interface and to avoid
|
||||
// a compile error. ScalarConverter will not actually work until
|
||||
// NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round> is
|
||||
// implemented
|
||||
using ScalarConverter =
|
||||
NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round>;
|
||||
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||
PackedSrcType const& source) {
|
||||
static_assert(
|
||||
(platform::is_same<PackedSrcType, src_packed_2_t>::value &&
|
||||
platform::is_same<PackedResultType, result_packed_2_t>::value) ||
|
||||
(platform::is_same<PackedSrcType, src_packed_4_t>::value &&
|
||||
platform::is_same<PackedResultType, result_packed_4_t>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
|
||||
"convert dispatch.");
|
||||
|
||||
NumericArrayConverter<float, machete_uint8b128_t, PackedResultType::kElements,
|
||||
Round>
|
||||
convert_uint8_to_f32;
|
||||
Array<float, PackedResultType::kElements> tmp =
|
||||
convert_uint8_to_f32(source);
|
||||
NumericArrayConverter<cutlass::bfloat16_t, float,
|
||||
PackedResultType::kElements, Round>
|
||||
convert_f32_to_bf16_;
|
||||
return convert_f32_to_bf16_(tmp);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
result_type result;
|
||||
using ConverterType =
|
||||
NumericArrayConverter<typename result_type::Element,
|
||||
typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
|
||||
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<cutlass::half_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
|
||||
template <typename PackedResultType, int src_regs>
|
||||
CUTLASS_DEVICE static PackedResultType convert(
|
||||
Array<uint32_t, src_regs> src) {
|
||||
// Hold output int8s in reg. We need 1 reg for every 4 elements
|
||||
using RegArray = cutlass::AlignedArray<
|
||||
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
|
||||
RegArray r;
|
||||
|
||||
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
|
||||
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
|
||||
|
||||
*reinterpret_cast<half2*>(&src[0]) =
|
||||
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
|
||||
|
||||
if constexpr (src_regs > 1) {
|
||||
*reinterpret_cast<half2*>(&src[1]) =
|
||||
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
|
||||
}
|
||||
|
||||
static_assert(PackedResultType::kElements <= 4);
|
||||
uint32_t uint8s;
|
||||
static constexpr uint32_t MASK_0246 = 0x6420;
|
||||
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
||||
: "=r"(uint8s)
|
||||
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
|
||||
"n"(MASK_0246));
|
||||
|
||||
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(int8s);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,43 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cuda_bf16.h"
|
||||
|
||||
#include "machete_custom_types.cuh"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
template <typename T>
|
||||
struct nameof {
|
||||
static constexpr char const* value = "unknown";
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline constexpr auto nameof_v = nameof<T>::value;
|
||||
|
||||
#define NAMEOF_TYPE(T) \
|
||||
template <> \
|
||||
struct nameof<T> { \
|
||||
static constexpr char const* value = #T; \
|
||||
};
|
||||
|
||||
NAMEOF_TYPE(float_e4m3_t)
|
||||
NAMEOF_TYPE(float_e5m2_t)
|
||||
NAMEOF_TYPE(half_t)
|
||||
NAMEOF_TYPE(nv_bfloat16)
|
||||
NAMEOF_TYPE(bfloat16_t)
|
||||
NAMEOF_TYPE(float)
|
||||
|
||||
NAMEOF_TYPE(int4b_t)
|
||||
NAMEOF_TYPE(int8_t)
|
||||
NAMEOF_TYPE(int32_t)
|
||||
NAMEOF_TYPE(int64_t)
|
||||
|
||||
NAMEOF_TYPE(machete_uint4b8_t)
|
||||
NAMEOF_TYPE(uint4b_t)
|
||||
NAMEOF_TYPE(uint8_t)
|
||||
NAMEOF_TYPE(machete_uint8b128_t)
|
||||
NAMEOF_TYPE(uint32_t)
|
||||
NAMEOF_TYPE(uint64_t)
|
||||
|
||||
}; // namespace cutlass
|
||||
@@ -0,0 +1,161 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/half.h"
|
||||
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
|
||||
namespace cute {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F, class G, int... I>
|
||||
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
|
||||
seq<I...>) {
|
||||
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
|
||||
}
|
||||
|
||||
template <class F, int... I>
|
||||
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
|
||||
return make_shape(f(I)...);
|
||||
}
|
||||
|
||||
}; // namespace detail
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
|
||||
if constexpr (cute::is_tuple<T>::value) {
|
||||
return detail::tapply_with_idx(
|
||||
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
|
||||
tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// calls: make_shape(f(0), f(1), ..., f(N-1))
|
||||
template <int N, class F>
|
||||
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
|
||||
return detail::make_shape_from_idx(f, make_seq<N>{});
|
||||
}
|
||||
|
||||
}; // namespace cute
|
||||
|
||||
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
|
||||
// shape of the passed in tensor and the strides are of type `Stride` and
|
||||
// contain the strides of the passed in tensor, checking that any static strides
|
||||
// in `Stride{}` match the strides of the passed in tensor.
|
||||
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and the extra
|
||||
// strides are set to be 0 or 1.
|
||||
template <typename Stride>
|
||||
static inline auto make_cute_layout(paddle::Tensor const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
PD_CHECK(tensor.shape().size() <= rank(Stride{}));
|
||||
auto stride = cute::transform_with_idx(
|
||||
Stride{}, [&](auto const& stride_ele, auto const& idx) {
|
||||
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||
|
||||
if (idx < tensor.shape().size()) {
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
PD_CHECK(StrideEle::value == tensor.strides()[idx], "Expected ",
|
||||
name, ".strides()[", idx, "] to be ", StrideEle::value, ", but got ", tensor.strides()[idx], ". ");
|
||||
return StrideEle{};
|
||||
} else {
|
||||
if (tensor.shape()[idx] == 1) {
|
||||
// use 0 stride for dims with size 1, this is easier for
|
||||
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
||||
return StrideEle{0};
|
||||
} else {
|
||||
return tensor.strides()[idx];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Extra strides are assumed to be 0 or 1
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||
}
|
||||
return StrideEle{};
|
||||
}
|
||||
});
|
||||
|
||||
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
|
||||
if (idx < tensor.shape().size())
|
||||
return tensor.shape()[idx];
|
||||
else
|
||||
return int64_t(1);
|
||||
});
|
||||
|
||||
return make_layout(shape, stride);
|
||||
}
|
||||
|
||||
template <typename Stride>
|
||||
static inline auto maybe_make_cute_layout(
|
||||
std::optional<paddle::Tensor> const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||
|
||||
if (tensor) {
|
||||
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
|
||||
} else {
|
||||
return std::optional<Layout>{};
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Paddle dtype to Cutlass Type (equivalent_cutlass_type)
|
||||
//
|
||||
|
||||
template <typename T>
|
||||
struct equivalent_cutlass_type {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<phi::dtype::float16> {
|
||||
using type = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<phi::dtype::bfloat16> {
|
||||
using type = cutlass::bfloat16_t;
|
||||
};
|
||||
|
||||
//
|
||||
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
|
||||
//
|
||||
|
||||
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
|
||||
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
|
||||
template <typename T>
|
||||
struct equivalent_scalar_type {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::half_t> {
|
||||
using type = phi::dtype::float16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::bfloat16_t> {
|
||||
using type = phi::dtype::bfloat16;
|
||||
};
|
||||
|
||||
// get equivalent c10::ScalarType tag from compile time type
|
||||
template <typename T>
|
||||
static inline constexpr paddle::DataType equivalent_scalar_type_v =
|
||||
phi::CppTypeToDataType<equivalent_scalar_type_t<T>>::Type();
|
||||
@@ -0,0 +1,372 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/phi/common/data_type.h"
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
namespace machete {
|
||||
|
||||
//
|
||||
// ScalarType can represent a wide range of floating point and integer types,
|
||||
// in particular it can be used to represent sub-byte data types (something
|
||||
// that torch.dtype currently does not support).
|
||||
//
|
||||
// The type definitions on the Python side can be found in: vllm/scalar_type.py
|
||||
// these type definitions should be kept up to date with any Python API changes
|
||||
// here.
|
||||
//
|
||||
class ScalarType {
|
||||
public:
|
||||
enum NanRepr : uint8_t {
|
||||
NAN_NONE = 0, // nans are not supported
|
||||
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
||||
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
||||
|
||||
NAN_REPR_ID_MAX
|
||||
};
|
||||
|
||||
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
|
||||
int32_t bias, bool finite_values_only = false,
|
||||
NanRepr nan_repr = NAN_IEEE_754)
|
||||
: exponent(exponent),
|
||||
mantissa(mantissa),
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr) {};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
}
|
||||
|
||||
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits, false, bias);
|
||||
}
|
||||
|
||||
// IEEE 754 compliant floating point type
|
||||
static constexpr ScalarType float_IEEE754(uint8_t exponent,
|
||||
uint8_t mantissa) {
|
||||
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
|
||||
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
|
||||
}
|
||||
|
||||
// IEEE 754 non-compliant floating point type
|
||||
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
|
||||
bool finite_values_only,
|
||||
NanRepr nan_repr) {
|
||||
// PADDLE_ENFORCE(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
|
||||
// PADDLE_ENFORCE(nan_repr != NAN_IEEE_754,
|
||||
// "use `float_IEEE754` constructor for floating point types that "
|
||||
// "follow IEEE 754 conventions");
|
||||
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
|
||||
nan_repr);
|
||||
}
|
||||
|
||||
uint8_t const exponent; // size of the exponent field (0 for integer types)
|
||||
uint8_t const mantissa; // size of the mantissa field (size of the integer
|
||||
// excluding the sign bit for integer types)
|
||||
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
||||
// sign bit)
|
||||
int32_t const bias; // stored values equal value + bias,
|
||||
// used for quantized type
|
||||
|
||||
// Extra Floating point info
|
||||
bool const finite_values_only; // i.e. no +/-inf if true
|
||||
NanRepr const nan_repr; // how NaNs are represented
|
||||
// (not applicable for integer types)
|
||||
|
||||
using Id = int64_t;
|
||||
|
||||
private:
|
||||
// Field size in id
|
||||
template <typename T_>
|
||||
static constexpr size_t member_id_field_width() {
|
||||
using T = std::decay_t<T_>;
|
||||
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init, typename Member, typename... Rest>
|
||||
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
|
||||
Rest... rest) {
|
||||
auto new_val = f(val, member);
|
||||
if constexpr (sizeof...(rest) > 0) {
|
||||
return reduce_members_helper(f, new_val, rest...);
|
||||
} else {
|
||||
return new_val;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
constexpr auto reduce_members(Fn f, Init init) const {
|
||||
// Should be in constructor order for `from_id`
|
||||
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
|
||||
finite_values_only, nan_repr);
|
||||
};
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
static constexpr auto reduce_member_types(Fn f, Init init) {
|
||||
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
|
||||
return dummy_type.reduce_members(f, init);
|
||||
};
|
||||
|
||||
static constexpr auto id_size_bits() {
|
||||
return reduce_member_types(
|
||||
[](int acc, auto member) -> int {
|
||||
return acc + member_id_field_width<decltype(member)>();
|
||||
},
|
||||
0);
|
||||
}
|
||||
|
||||
public:
|
||||
// unique id for this scalar type that can be computed at compile time for
|
||||
// c++17 template specialization this is not needed once we migrate to
|
||||
// c++20 and can pass literal classes as template parameters
|
||||
constexpr Id id() const {
|
||||
static_assert(id_size_bits() <= sizeof(Id) * 8,
|
||||
"ScalarType id is too large to be stored");
|
||||
|
||||
auto or_and_advance = [](std::pair<Id, uint32_t> result,
|
||||
auto member) -> std::pair<Id, uint32_t> {
|
||||
auto [id, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<decltype(member)>();
|
||||
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
|
||||
<< bit_offset,
|
||||
bit_offset + bits};
|
||||
};
|
||||
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
|
||||
}
|
||||
|
||||
// create a ScalarType from an id, for c++17 template specialization,
|
||||
// this is not needed once we migrate to c++20 and can pass literal
|
||||
// classes as template parameters
|
||||
static constexpr ScalarType from_id(Id id) {
|
||||
auto extract_and_advance = [id](auto result, auto member) {
|
||||
using T = decltype(member);
|
||||
auto [tuple, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<T>();
|
||||
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
|
||||
((uint64_t(1) << bits) - 1));
|
||||
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
|
||||
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
|
||||
};
|
||||
|
||||
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
|
||||
std::pair<std::tuple<>, int>{});
|
||||
return std::apply([](auto... args) { return ScalarType(args...); },
|
||||
tuple_args);
|
||||
}
|
||||
|
||||
constexpr int64_t size_bits() const {
|
||||
return mantissa + exponent + is_signed();
|
||||
}
|
||||
constexpr bool is_signed() const { return signed_; }
|
||||
constexpr bool is_integer() const { return exponent == 0; }
|
||||
constexpr bool is_floating_point() const { return exponent > 0; }
|
||||
constexpr bool is_ieee_754() const {
|
||||
return is_floating_point() && finite_values_only == false &&
|
||||
nan_repr == NAN_IEEE_754;
|
||||
}
|
||||
constexpr bool has_nans() const {
|
||||
return is_floating_point() && nan_repr != NAN_NONE;
|
||||
}
|
||||
constexpr bool has_infs() const {
|
||||
return is_floating_point() && finite_values_only == false;
|
||||
}
|
||||
constexpr bool has_bias() const { return bias != 0; }
|
||||
|
||||
private:
|
||||
double _floating_point_max() const {
|
||||
PADDLE_ENFORCE(mantissa <= 52 && exponent <= 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
|
||||
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
|
||||
max_mantissa -= 1;
|
||||
}
|
||||
|
||||
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
|
||||
PADDLE_ENFORCE(exponent < 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
max_exponent += 1;
|
||||
}
|
||||
|
||||
// adjust the exponent to match that of a double
|
||||
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
|
||||
// is the exponent bits), there is some precedent for non-standard biases,
|
||||
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
|
||||
// but to avoid premature over complication we are just assuming the
|
||||
// standard exponent bias until there is a need to support non-standard
|
||||
// biases
|
||||
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
|
||||
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
|
||||
|
||||
uint64_t max_exponent_double =
|
||||
max_exponent - exponent_bias + exponent_bias_double;
|
||||
|
||||
// shift the mantissa into the position for a double and
|
||||
// the exponent
|
||||
uint64_t double_raw =
|
||||
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
|
||||
|
||||
return *reinterpret_cast<double*>(&double_raw);
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_max() const {
|
||||
if (is_floating_point()) {
|
||||
return {_floating_point_max()};
|
||||
} else {
|
||||
// PADDLE_ENFORCE(size_bits() < 64 || size_bits() == 64 && is_signed(),
|
||||
// "Cannot represent max as a int64_t");
|
||||
return {(int64_t(1) << mantissa) - 1};
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_min() const {
|
||||
if (is_floating_point()) {
|
||||
// PADDLE_ENFORCE(is_signed(),
|
||||
// "We currently assume all floating point types are signed");
|
||||
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
|
||||
|
||||
double max = _floating_point_max();
|
||||
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
|
||||
uint64_t min_raw = max_raw | sign_bit_double;
|
||||
return {*reinterpret_cast<double*>(&min_raw)};
|
||||
} else {
|
||||
// PADDLE_ENFORCE(!is_signed() || size_bits() <= 64,
|
||||
// "Cannot represent min as a int64_t");
|
||||
if (is_signed()) {
|
||||
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
|
||||
// then perform an arithmetic shift right to set all the bits above
|
||||
// (size_bits() - 1) to 1
|
||||
return {INT64_MIN >> (64 - size_bits())};
|
||||
} else {
|
||||
return {int64_t(0)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Max representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> max() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_max());
|
||||
}
|
||||
|
||||
// Min representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> min() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_min());
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
* for floating point types (leading f) the scheme is:
|
||||
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
* flags:
|
||||
* - no-flags: means it follows IEEE 754 conventions
|
||||
* - f: means finite values only (no infinities)
|
||||
* - n: means nans are supported (non-standard encoding)
|
||||
* for integer types the scheme is:
|
||||
* `[u]int<size_bits>[b<bias>]`
|
||||
* - if bias is not present it means its zero
|
||||
*/
|
||||
if (is_floating_point()) {
|
||||
auto ret = "float" + std::to_string(size_bits()) + "_e" +
|
||||
std::to_string(exponent) + "m" + std::to_string(mantissa);
|
||||
if (!is_ieee_754()) {
|
||||
if (finite_values_only) {
|
||||
ret += "f";
|
||||
}
|
||||
if (nan_repr != NAN_NONE) {
|
||||
ret += "n";
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
} else {
|
||||
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
|
||||
if (has_bias()) {
|
||||
ret += "b" + std::to_string(bias);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr bool operator==(ScalarType const& other) const {
|
||||
return mantissa == other.mantissa && exponent == other.exponent &&
|
||||
bias == other.bias && signed_ == other.signed_ &&
|
||||
finite_values_only == other.finite_values_only &&
|
||||
nan_repr == other.nan_repr;
|
||||
}
|
||||
};
|
||||
|
||||
using ScalarTypeId = machete::ScalarType::Id;
|
||||
|
||||
// "rust style" names generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
static inline constexpr auto kS4 = machete::ScalarType::int_(4);
|
||||
static inline constexpr auto kU4 = machete::ScalarType::uint(4);
|
||||
static inline constexpr auto kU4B8 = machete::ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kS8 = machete::ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = machete::ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = machete::ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE2M1f =
|
||||
machete::ScalarType::float_(2, 1, true, machete::ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE3M2f =
|
||||
machete::ScalarType::float_(3, 2, true, machete::ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
machete::ScalarType::float_(4, 3, true, machete::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = machete::ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = machete::ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = machete::ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
// // Fixed width style names, generally following:
|
||||
// // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
constexpr auto kInt4 = kS4;
|
||||
constexpr auto kUint4 = kU4;
|
||||
constexpr auto kUint4b8 = kU4B8;
|
||||
constexpr auto kInt8 = kS8;
|
||||
constexpr auto kUint8 = kU8;
|
||||
constexpr auto kUint8b128 = kU8B128;
|
||||
constexpr auto kFloat4_e2m1f = kFE2M1f;
|
||||
constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
constexpr auto kFloat16_e8m7 = kFE8M7;
|
||||
constexpr auto kFloat16_e5m10 = kFE5M10;
|
||||
|
||||
// colloquial names
|
||||
constexpr auto kHalf = kFE5M10;
|
||||
constexpr auto kFloat16 = kHalf;
|
||||
constexpr auto kFloat16Id = kFloat16.id();
|
||||
|
||||
constexpr auto kInt32 = phi::DataType::INT32;
|
||||
constexpr auto kInt64 = phi::DataType::INT64;
|
||||
constexpr auto kBool = phi::DataType::BOOL;
|
||||
constexpr auto kFloat8_e4m3fn = phi::DataType::FLOAT8_E4M3FN;
|
||||
constexpr auto kBFloat16 = phi::DataType::BFLOAT16;
|
||||
constexpr auto kFloat32 = phi::DataType::FLOAT32;
|
||||
constexpr auto kByte = phi::DataType::INT8;
|
||||
|
||||
}; // namespace machete
|
||||
Reference in New Issue
Block a user