Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
@@ -38,315 +39,331 @@
// Config
// Feature-detection guards for the PTX `red` (reduction) instructions used
// below: red.global.add.noftz.f16[x2] needs SM70+ / CUDA 10+, while the
// vectorized (v2/v4) and bf16 forms need SM90+ / CUDA 12+.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && \
     (__CUDACC_VER_MAJOR__ >= 10))
#define CUTE_ARCH_RED_F16_SM70_ENABLED
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \
     (__CUDACC_VER_MAJOR__ >= 12))
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
#endif
namespace cute
{
namespace cute {
//////////////////////////////////
// Wrapper around CUDA's atomicAdd
//////////////////////////////////
template <class T>
struct TypedAtomicAdd
{
using SRegisters = T[1];
using DRegisters = T[1];
struct TypedAtomicAdd {
using SRegisters = T[1];
using DRegisters = T[1];
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
{
atomicAdd(&dst, src);
}
CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) {
atomicAdd(&dst, src);
}
};
template <class T>
struct Copy_Traits<TypedAtomicAdd<T>> {
  // Logical thread id to thread idx (one-thread)
  using ThrID = Layout<_1>;

  // Map from (src-thr,src-val) to bit
  using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;

  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;
};
//////////////////////////////////
// F16 ADD PTX
//////////////////////////////////
struct SM70_RED_ADD_NOFTZ_F16
{
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
struct SM70_RED_ADD_NOFTZ_F16 {
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
{
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) {
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst),
"h"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
CUTE_INVALID_CONTROL_PATH(
"Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
}
}
};
template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16> {
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM70_RED_ADD_NOFTZ_F16x2
{
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
struct SM70_RED_ADD_NOFTZ_F16x2 {
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
{
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) {
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst),
"r"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
CUTE_INVALID_CONTROL_PATH(
"Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
}
}
};
template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2> {
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM90_RED_ADD_NOFTZ_F16x2_V2
{
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
struct SM90_RED_ADD_NOFTZ_F16x2_V2 {
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
{
CUTE_HOST_DEVICE static void copy(uint32_t const& src0,
uint32_t const& src1,
uint64_t& gmem_dst) {
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
asm volatile(
"red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst),
"r"(src0),
"r"(src1));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH(
"Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
}
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2> {
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
struct SM90_RED_ADD_NOFTZ_F16x2_V4
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
struct SM90_RED_ADD_NOFTZ_F16x2_V4 {
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void copy(
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
{
CUTE_HOST_DEVICE static void copy(uint32_t const& src0,
uint32_t const& src1,
uint32_t const& src2,
uint32_t const& src3,
uint128_t& gmem_dst) {
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
"r"(src2), "r"(src3));
asm volatile(
"red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(
&gmem_dst),
"r"(src0),
"r"(src1),
"r"(src2),
"r"(src3));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH(
"Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
}
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4> {
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
// BF16 ADD PTX
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16
{
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
struct SM90_RED_ADD_NOFTZ_BF16 {
using SRegisters = uint16_t[1];
using DRegisters = uint16_t[1];
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
{
CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst),
"h"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH(
"Trying to use red.global.bf16 without "
"CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16> {
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _16>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2
{
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
struct SM90_RED_ADD_NOFTZ_BF16x2 {
using SRegisters = uint32_t[1];
using DRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
{
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst),
"r"(src0));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH(
"Trying to use red.global.bf16 without "
"CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2> {
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _32>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2_V2
{
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
struct SM90_RED_ADD_NOFTZ_BF16x2_V2 {
using SRegisters = uint32_t[2];
using DRegisters = uint64_t[1];
CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
{
CUTE_HOST_DEVICE static void copy(uint32_t const& src0,
uint32_t const& src1,
uint64_t& gmem_dst) {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
asm volatile(
"red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst),
"r"(src0),
"r"(src1));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH(
"Trying to use red.global.bf16 without "
"CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2> {
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _64>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
struct SM90_RED_ADD_NOFTZ_BF16x2_V4
{
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
struct SM90_RED_ADD_NOFTZ_BF16x2_V4 {
using SRegisters = uint32_t[4];
using DRegisters = uint128_t[1];
CUTE_HOST_DEVICE static void copy(
uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
{
CUTE_HOST_DEVICE static void copy(uint32_t const& src0,
uint32_t const& src1,
uint32_t const& src2,
uint32_t const& src3,
uint128_t& gmem_dst) {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
"r"(src2), "r"(src3));
asm volatile(
"red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(
&gmem_dst),
"r"(src0),
"r"(src1),
"r"(src2),
"r"(src3));
#else
CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
CUTE_INVALID_CONTROL_PATH(
"Trying to use red.global.bf16 without "
"CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
}
}
};
template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
{
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4> {
// Logical thread id to thread idx (one-thread)
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, _128>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
};
//////////////////////////////////
} // end namespace cute
} // end namespace cute
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
@@ -59,8 +60,9 @@ template <
bool GlobalToShared = true>
struct copy;
/// Initiates an asynchronous copy from global memory to shared memory. Rather
/// than predicate the entire transfer, zeros are written to SMEM if the guard
/// predicate is false.
///
/// cp.async
///
@@ -72,7 +74,8 @@ template <
bool GlobalToShared = true>
struct copy_zfill;
/// Blocks until all but <N> previous cp.async.commit_group operations have
/// committed.
///
/// cp.async
///
@@ -86,11 +89,11 @@ template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Always, true> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
cp_async<SizeInBytes, CacheOperation::Always>(
smem_ptr, global_ptr, pred_guard);
}
};
@@ -99,15 +102,15 @@ template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Always, false> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) =
*static_cast<AccessType const *>(global_ptr);
}
}
};
@@ -116,11 +119,11 @@ template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Always, true> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
cp_async_zfill<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
cp_async_zfill<SizeInBytes, CacheOperation::Always>(
smem_ptr, global_ptr, pred_guard);
}
};
@@ -129,20 +132,19 @@ template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Always, false> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
using AccessType = Array<uint8_t, SizeInBytes>;
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) =
*static_cast<AccessType const *>(global_ptr);
} else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
}
};
@@ -153,11 +155,11 @@ template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Global, true> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
cp_async<SizeInBytes, CacheOperation::Global>(
smem_ptr, global_ptr, pred_guard);
}
};
@@ -166,15 +168,15 @@ template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Global, false> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) =
*static_cast<AccessType const *>(global_ptr);
}
}
};
@@ -183,11 +185,11 @@ template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Global, true> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async_zfill<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
cp_async_zfill<SizeInBytes, CacheOperation::Global>(
smem_ptr, global_ptr, pred_guard);
}
};
@@ -196,31 +198,29 @@ template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Global, false> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) =
*static_cast<AccessType const *>(global_ptr);
} else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
}
};
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does
/// not block.
template <bool GlobalToShared>
CUTLASS_DEVICE
void copy_fence() {}
CUTLASS_DEVICE void copy_fence() {}
template <>
CUTLASS_DEVICE
void copy_fence<true>() {
CUTLASS_DEVICE void copy_fence<true>() {
cp_async_fence();
}
@@ -229,7 +229,6 @@ void copy_fence<true>() {
/// Partial specialization
template <int N>
struct copy_wait<N, false> {
CUTLASS_DEVICE
copy_wait() {}
};
@@ -237,7 +236,6 @@ struct copy_wait<N, false> {
/// Partial specialization
template <int N>
struct copy_wait<N, true> {
CUTLASS_DEVICE
copy_wait() { cp_async_wait<N>(); }
};
@@ -1,12 +1,12 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
@@ -37,10 +38,8 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace arch
{
namespace cutlass {
namespace arch {
// Tag which triggers MMA which will trigger
struct OpMultiplyAddDequantizeInterleavedBToA;
@@ -52,8 +51,8 @@ struct OpMultiplyAddDequantizeInterleavedBToA;
split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
with the quantization op before instantiating the GEMM pieces.
Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount
of code we need to duplicate.
*/
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
@@ -61,60 +60,59 @@ struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
// The default just forwards the original operator
// The default just forwards the original operator unchanged.
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
struct TagOperator {
  using TaggedOperator = MmaOp;
};
// Specializations below attach more information to the operator
// Attaches per-column-scale-only quantization info to the operator tag.
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
                   WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY> {
  using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
};
// Attaches finegrained-scale-only quantization info to the operator tag.
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
                   WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY> {
  using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
};
// Attaches finegrained scale-and-zeros quantization info to the operator tag.
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
                   WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS> {
  using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
};
// Here we instantiate some structs to "detag" the tagged operator. It splits it
// back to the original operator + the extra information. If no extra info was
// tagged, the dequant op per column scaling as a default.
// Default detag: passes the operator through and assumes per-column scaling.
template <typename TaggedMmaOp>
struct DetagOperator {
  using Operator = TaggedMmaOp;
  static constexpr WeightOnlyQuantOp QuantOp =
      WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};
// Detag: per-column-scale-only tag -> original operator + quant op.
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale> {
  using Operator = OpMultiplyAddDequantizeInterleavedBToA;
  static constexpr WeightOnlyQuantOp QuantOp =
      WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};
// Detag: finegrained-scale-only tag -> original operator + quant op.
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale> {
  using Operator = OpMultiplyAddDequantizeInterleavedBToA;
  static constexpr WeightOnlyQuantOp QuantOp =
      WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
};
// Detag: finegrained scale-and-zeros tag -> original operator + quant op.
template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias> {
  using Operator = OpMultiplyAddDequantizeInterleavedBToA;
  static constexpr WeightOnlyQuantOp QuantOp =
      WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
};
} // namespace arch
} // namespace cutlass
} // namespace arch
} // namespace cutlass