mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
@@ -1,12 +1,12 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
@@ -18,14 +18,15 @@
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
@@ -38,315 +39,331 @@
|
||||
|
||||
// Config
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10))
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && \
|
||||
(__CUDACC_VER_MAJOR__ >= 10))
|
||||
#define CUTE_ARCH_RED_F16_SM70_ENABLED
|
||||
#endif
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \
|
||||
(__CUDACC_VER_MAJOR__ >= 12))
|
||||
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
|
||||
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
|
||||
#endif
|
||||
|
||||
namespace cute
|
||||
{
|
||||
namespace cute {
|
||||
|
||||
//////////////////////////////////
|
||||
// Wrapper around CUDA's atomicAdd
|
||||
//////////////////////////////////
|
||||
|
||||
template <class T>
|
||||
struct TypedAtomicAdd
|
||||
{
|
||||
using SRegisters = T[1];
|
||||
using DRegisters = T[1];
|
||||
struct TypedAtomicAdd {
|
||||
using SRegisters = T[1];
|
||||
using DRegisters = T[1];
|
||||
|
||||
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
|
||||
{
|
||||
atomicAdd(&dst, src);
|
||||
}
|
||||
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) {
|
||||
atomicAdd(&dst, src);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct Copy_Traits<TypedAtomicAdd<T>>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<TypedAtomicAdd<T>> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
// F16 ADD PTX
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM70_RED_ADD_NOFTZ_F16
|
||||
{
|
||||
using SRegisters = uint16_t[1];
|
||||
using DRegisters = uint16_t[1];
|
||||
struct SM70_RED_ADD_NOFTZ_F16 {
|
||||
using SRegisters = uint16_t[1];
|
||||
using DRegisters = uint16_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
|
||||
{
|
||||
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) {
|
||||
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
|
||||
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
|
||||
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst),
|
||||
"h"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _16>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _16>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM70_RED_ADD_NOFTZ_F16x2
|
||||
{
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
struct SM70_RED_ADD_NOFTZ_F16x2 {
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
|
||||
{
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) {
|
||||
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
|
||||
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
|
||||
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst),
|
||||
"r"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _32>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _32>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_F16x2_V2
|
||||
{
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint64_t[1];
|
||||
struct SM90_RED_ADD_NOFTZ_F16x2_V2 {
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint64_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
|
||||
{
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0,
|
||||
uint32_t const& src1,
|
||||
uint64_t& gmem_dst) {
|
||||
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
|
||||
asm volatile(
|
||||
"red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst),
|
||||
"r"(src0),
|
||||
"r"(src1));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _64>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _64>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_F16x2_V4
|
||||
{
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
struct SM90_RED_ADD_NOFTZ_F16x2_V4 {
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(
|
||||
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
|
||||
{
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0,
|
||||
uint32_t const& src1,
|
||||
uint32_t const& src2,
|
||||
uint32_t const& src3,
|
||||
uint128_t& gmem_dst) {
|
||||
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
|
||||
"r"(src2), "r"(src3));
|
||||
asm volatile(
|
||||
"red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(
|
||||
&gmem_dst),
|
||||
"r"(src0),
|
||||
"r"(src1),
|
||||
"r"(src2),
|
||||
"r"(src3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _128>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _128>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
// BF16 ADD PTX
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16
|
||||
{
|
||||
using SRegisters = uint16_t[1];
|
||||
using DRegisters = uint16_t[1];
|
||||
struct SM90_RED_ADD_NOFTZ_BF16 {
|
||||
using SRegisters = uint16_t[1];
|
||||
using DRegisters = uint16_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
|
||||
{
|
||||
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) {
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
|
||||
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst),
|
||||
"h"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Trying to use red.global.bf16 without "
|
||||
"CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _16>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _16>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _16>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2
|
||||
{
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2 {
|
||||
using SRegisters = uint32_t[1];
|
||||
using DRegisters = uint32_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
|
||||
{
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) {
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
|
||||
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst),
|
||||
"r"(src0));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Trying to use red.global.bf16 without "
|
||||
"CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _32>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _32>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _32>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2_V2
|
||||
{
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint64_t[1];
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2_V2 {
|
||||
using SRegisters = uint32_t[2];
|
||||
using DRegisters = uint64_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
|
||||
{
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0,
|
||||
uint32_t const& src1,
|
||||
uint64_t& gmem_dst) {
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
|
||||
asm volatile(
|
||||
"red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst),
|
||||
"r"(src0),
|
||||
"r"(src1));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Trying to use red.global.bf16 without "
|
||||
"CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _64>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _64>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _64>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2_V4
|
||||
{
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
struct SM90_RED_ADD_NOFTZ_BF16x2_V4 {
|
||||
using SRegisters = uint32_t[4];
|
||||
using DRegisters = uint128_t[1];
|
||||
|
||||
CUTE_HOST_DEVICE static void copy(
|
||||
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
|
||||
{
|
||||
CUTE_HOST_DEVICE static void copy(uint32_t const& src0,
|
||||
uint32_t const& src1,
|
||||
uint32_t const& src2,
|
||||
uint32_t const& src3,
|
||||
uint128_t& gmem_dst) {
|
||||
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
|
||||
asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
|
||||
"r"(src2), "r"(src3));
|
||||
asm volatile(
|
||||
"red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(
|
||||
&gmem_dst),
|
||||
"r"(src0),
|
||||
"r"(src1),
|
||||
"r"(src2),
|
||||
"r"(src3));
|
||||
#else
|
||||
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"Trying to use red.global.bf16 without "
|
||||
"CUTE_ARCH_RED_BF16_SM90_ENABLED.");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
|
||||
{
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4> {
|
||||
// Logical thread id to thread idx (one-thread)
|
||||
using ThrID = Layout<_1>;
|
||||
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _128>>;
|
||||
// Map from (src-thr,src-val) to bit
|
||||
using SrcLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _128>>;
|
||||
// Map from (dst-thr,dst-val) to bit
|
||||
using DstLayout = Layout<Shape<_1, _128>>;
|
||||
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
// Reference map from (thr,val) to bit
|
||||
using RefLayout = SrcLayout;
|
||||
};
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
} // end namespace cute
|
||||
} // end namespace cute
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
@@ -18,14 +18,15 @@
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
@@ -59,8 +60,9 @@ template <
|
||||
bool GlobalToShared = true>
|
||||
struct copy;
|
||||
|
||||
/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate
|
||||
/// the entire transfer, zeros are written to SMEM if the guard predicate is false.
|
||||
/// Initiates an asynchronous copy from global memory to shared memory. Rather
|
||||
/// than predicate the entire transfer, zeros are written to SMEM if the guard
|
||||
/// predicate is false.
|
||||
///
|
||||
/// cp.async
|
||||
///
|
||||
@@ -72,7 +74,8 @@ template <
|
||||
bool GlobalToShared = true>
|
||||
struct copy_zfill;
|
||||
|
||||
/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
|
||||
/// Blocks until all but <N> previous cp.async.commit_group operations have
|
||||
/// committed.
|
||||
///
|
||||
/// cp.async
|
||||
///
|
||||
@@ -86,11 +89,11 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Always, true> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
|
||||
cp_async<SizeInBytes, CacheOperation::Always>(
|
||||
smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -99,15 +102,15 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Always, false> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) =
|
||||
*static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -116,11 +119,11 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Always, true> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
cp_async_zfill<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
|
||||
cp_async_zfill<SizeInBytes, CacheOperation::Always>(
|
||||
smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -129,20 +132,19 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Always, false> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) =
|
||||
*static_cast<AccessType const *>(global_ptr);
|
||||
} else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -153,11 +155,11 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Global, true> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
|
||||
cp_async<SizeInBytes, CacheOperation::Global>(
|
||||
smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -166,15 +168,15 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Global, false> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) =
|
||||
*static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -183,11 +185,11 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Global, true> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async_zfill<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
|
||||
cp_async_zfill<SizeInBytes, CacheOperation::Global>(
|
||||
smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -196,31 +198,29 @@ template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Global, false> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) =
|
||||
*static_cast<AccessType const *>(global_ptr);
|
||||
} else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
|
||||
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does
|
||||
/// not block.
|
||||
template <bool GlobalToShared>
|
||||
CUTLASS_DEVICE
|
||||
void copy_fence() {}
|
||||
CUTLASS_DEVICE void copy_fence() {}
|
||||
|
||||
template <>
|
||||
CUTLASS_DEVICE
|
||||
void copy_fence<true>() {
|
||||
CUTLASS_DEVICE void copy_fence<true>() {
|
||||
cp_async_fence();
|
||||
}
|
||||
|
||||
@@ -229,7 +229,6 @@ void copy_fence<true>() {
|
||||
/// Partial specialization
|
||||
template <int N>
|
||||
struct copy_wait<N, false> {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
copy_wait() {}
|
||||
};
|
||||
@@ -237,7 +236,6 @@ struct copy_wait<N, false> {
|
||||
/// Partial specialization
|
||||
template <int N>
|
||||
struct copy_wait<N, true> {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
copy_wait() { cp_async_wait<N>(); }
|
||||
};
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
@@ -18,14 +18,15 @@
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
@@ -37,10 +38,8 @@
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace arch
|
||||
{
|
||||
namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
// Tag which triggers MMA which will trigger
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA;
|
||||
@@ -52,8 +51,8 @@ struct OpMultiplyAddDequantizeInterleavedBToA;
|
||||
split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
|
||||
with the quantization op before instantiating the GEMM pieces.
|
||||
|
||||
Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
|
||||
code we need to duplicate.
|
||||
Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount
|
||||
of code we need to duplicate.
|
||||
*/
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
|
||||
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
|
||||
@@ -61,60 +60,59 @@ struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
|
||||
|
||||
// The default just forwards the original operator
|
||||
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
|
||||
struct TagOperator
|
||||
{
|
||||
using TaggedOperator = MmaOp;
|
||||
struct TagOperator {
|
||||
using TaggedOperator = MmaOp;
|
||||
};
|
||||
|
||||
// Specializations below attach more information to the operator
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
|
||||
WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY> {
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
|
||||
WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY> {
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
|
||||
{
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
|
||||
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
|
||||
WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS> {
|
||||
using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
|
||||
};
|
||||
|
||||
// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original
|
||||
// operator + the extra information. If no extra info was tagged, the dequant op per column scaling
|
||||
// as a default.
|
||||
// Here we instantiate some structs to "detag" the tagged operator. It splits it
|
||||
// back to the original operator + the extra information. If no extra info was
|
||||
// tagged, the dequant op per column scaling as a default.
|
||||
template <typename TaggedMmaOp>
|
||||
struct DetagOperator
|
||||
{
|
||||
using Operator = TaggedMmaOp;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
|
||||
struct DetagOperator {
|
||||
using Operator = TaggedMmaOp;
|
||||
static constexpr WeightOnlyQuantOp QuantOp =
|
||||
WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale> {
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp =
|
||||
WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale> {
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp =
|
||||
WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
|
||||
{
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
|
||||
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias> {
|
||||
using Operator = OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr WeightOnlyQuantOp QuantOp =
|
||||
WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
|
||||
};
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
|
||||
Reference in New Issue
Block a user